Files
kis_bot/ml_predictor.py
2026-03-17 12:33:30 +09:00

231 lines
8.7 KiB
Python
Raw Blame History

This file contains invisible Unicode characters
This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
KIS Bot용 ML 승률 예측 모델
- MariaDB kis_quant_db의 trade_history 데이터로 학습
- 매수 신호의 승률 예측 (0.0 ~ 1.0)
- 주간 단위 자동 재학습
"""
import os
import pickle
import logging
from pathlib import Path
from datetime import datetime, timedelta
import numpy as np
import pandas as pd
# Logger 설정
logger = logging.getLogger("KIS_MLPredictor")
try:
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, precision_score, recall_score
ML_AVAILABLE = True
except ImportError:
ML_AVAILABLE = False
logger.warning("⚠️ scikit-learn 미설치! ML 기능 사용 불가")
logger.warning(" 설치: pip install scikit-learn")
SCRIPT_DIR = Path(__file__).resolve().parent
class MLPredictor:
"""매수 신호 승률 예측 모델"""
def __init__(
self,
db_path: str = None, # 하위 호환용 (무시됨) — MariaDB 사용
model_path: str = None,
):
# db_path: 하위 호환을 위해 파라미터 유지하나 내부적으로 MariaDB 사용
self.model_path = model_path or str(SCRIPT_DIR / "ml_model.pkl")
self.model = None
self.feature_names = [
"rsi",
"volume_ratio",
"tail_length_pct",
"ma5_gap_pct",
"ma20_gap_pct",
"foreign_net_buy",
"institution_net_buy",
"market_hour",
]
self.min_train_samples = 30
if not ML_AVAILABLE:
logger.error("❌ scikit-learn이 설치되지 않았습니다!")
return
self.load_model()
def extract_features_from_db(self, days: int = 90) -> pd.DataFrame:
"""MariaDB trade_history 에서 학습용 피처 추출.
- profit_rate 로 승/패 라벨 생성
- 진입 시점 피처(rsi, volume_ratio 등)는 매수 시 DB에 저장된 값 사용
- 시간순 정렬 후 반환 (시계열 분리 시 미래 참조 방지)
"""
try:
from database import TradeDB
db = TradeDB()
cutoff_date = (datetime.now() - timedelta(days=days)).strftime("%Y-%m-%d")
feat_cols = ", ".join(f"`{c}`" for c in self.feature_names)
rows = db.conn.execute(
f"SELECT profit_rate, buy_date, sell_date, strategy, {feat_cols} "
f"FROM trade_history WHERE sell_date >= %s ORDER BY sell_date ASC",
(cutoff_date,)
).fetchall()
if not rows:
logger.warning("⚠️ 학습 데이터 없음 (trade_history 조회 결과 0건)")
return None
# DictCursor 반환 → DataFrame 직접 생성
df = pd.DataFrame(rows)
if len(df) < self.min_train_samples:
logger.warning(
f"⚠️ 학습 데이터 부족: {len(df)}건 (최소 {self.min_train_samples}건 필요)"
)
return None
# 진입 피처가 전부 NULL인 과거 데이터 제외
feature_ok = df[self.feature_names].notna().any(axis=1)
df = df.loc[feature_ok].copy()
if len(df) < self.min_train_samples:
logger.warning(
f"⚠️ 진입 피처가 있는 데이터 부족: {len(df)}건 (최소 {self.min_train_samples}건 필요). "
"매수 시 entry_features 저장 후 누적되면 학습 가능합니다."
)
return None
df["is_win"] = (df["profit_rate"] > 0).astype(int)
logger.info(
f"📊 학습 데이터 로드: {len(df)}건 (시간순) "
f"(익절: {df['is_win'].sum()}건, 손절: {(~df['is_win']).sum()}건)"
)
return df
except Exception as e:
logger.error(f"❌ 피처 추출 실패: {e}")
return None
def train_model(self, retrain: bool = False) -> bool:
"""모델 학습"""
if not ML_AVAILABLE:
logger.error("❌ scikit-learn 미설치로 학습 불가")
return False
if self.model is not None and not retrain:
logger.info("✅ 기존 모델 사용")
return True
df = self.extract_features_from_db(days=90)
if df is None or len(df) < self.min_train_samples:
logger.warning("⚠️ 학습 데이터 부족 - ML 모델 사용 불가")
return False
# 실제 DB 진입 피처 사용 (누락값은 0으로 보정)
X = df[self.feature_names].fillna(0).astype(float)
y = df["is_win"].values
# 시계열 분리: 무작위 셔플 금지. 과거 80% = 학습, 최근 20% = 테스트 (미래 참조 방지)
n = len(X)
train_size = int(n * 0.8)
if train_size < 10 or (n - train_size) < 5:
logger.warning("⚠️ 데이터 적어 시계열 분리 불가 (학습/테스트 각 10건·5건 이상 권장)")
train_size = max(10, n - 5)
X_train, X_test = X.iloc[:train_size], X.iloc[train_size:]
y_train, y_test = y[:train_size], y[train_size:]
logger.info(f"📅 시계열 분리: 학습 {len(y_train)}건 (과거) / 테스트 {len(y_test)}건 (최근)")
logger.info("🤖 RandomForest 학습 시작...")
self.model = RandomForestClassifier(
n_estimators=100,
max_depth=10,
min_samples_split=5,
random_state=42,
)
self.model.fit(X_train, y_train)
y_pred = self.model.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
precision = precision_score(y_test, y_pred, zero_division=0)
recall = recall_score(y_test, y_pred, zero_division=0)
logger.info("✅ 학습 완료!")
logger.info(f" 정확도: {accuracy:.2%}")
logger.info(f" 정밀도: {precision:.2%}")
logger.info(f" 재현율: {recall:.2%}")
feature_importance = sorted(
zip(self.feature_names, self.model.feature_importances_),
key=lambda x: x[1],
reverse=True,
)
logger.info(" 중요 피처:")
for fname, importance in feature_importance[:5]:
logger.info(f" {fname}: {importance:.3f}")
self.save_model()
return True
def predict_win_probability(self, features: dict) -> float:
"""매수 신호의 승률 예측 (0.0 ~ 1.0).
매매 로직에서는 ML_MIN_PROBABILITY(권장 0.65 이상) 미만일 때 진입 보류 권장."""
if not ML_AVAILABLE or self.model is None:
return 0.5
try:
X = pd.DataFrame([features])[self.feature_names]
proba = self.model.predict_proba(X)[0]
win_prob = proba[1]
return float(win_prob)
except Exception as e:
logger.error(f"❌ 예측 실패: {e}")
return 0.5
def save_model(self) -> None:
"""모델 파일로 저장"""
try:
with open(self.model_path, "wb") as f:
pickle.dump(self.model, f)
logger.info(f"💾 모델 저장: {self.model_path}")
except Exception as e:
logger.error(f"❌ 모델 저장 실패: {e}")
def load_model(self) -> bool:
"""저장된 모델 로드"""
if not ML_AVAILABLE:
return False
if os.path.exists(self.model_path):
try:
with open(self.model_path, "rb") as f:
self.model = pickle.load(f)
logger.info(f"✅ 모델 로드: {self.model_path}")
return True
except Exception as e:
logger.error(f"❌ 모델 로드 실패: {e}")
else:
logger.info(" 저장된 모델 없음 - 첫 실행 시 학습 필요")
return False
def should_retrain(self) -> bool:
"""재학습이 필요한지 체크 (7일 경과 시)"""
if not os.path.exists(self.model_path):
return True
model_mtime = datetime.fromtimestamp(os.path.getmtime(self.model_path))
days_old = (datetime.now() - model_mtime).days
if days_old >= 7:
logger.info(f"🔄 모델 {days_old}일 경과 → 재학습 필요")
return True
return False