Files
kis_bot/ml_predictor.py
2026-02-22 18:05:14 +09:00

230 lines
8.2 KiB
Python
Raw Permalink Blame History

This file contains invisible Unicode characters
This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
KIS Bot용 ML 승률 예측 모델
- kis_bot/quant_bot.db의 trade_history 데이터로 학습
- 매수 신호의 승률 예측 (0.0 ~ 1.0)
- 주간 단위 자동 재학습
"""
import logging
import os
import pickle
import sqlite3
from datetime import datetime, timedelta
from pathlib import Path
from typing import Optional

import numpy as np
import pandas as pd
# Module-wide logger; every message emitted by this file goes through it.
logger = logging.getLogger("KIS_MLPredictor")

# scikit-learn is optional: if it cannot be imported, ML features are switched
# off (ML_AVAILABLE = False) instead of the module failing at import time.
try:
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.model_selection import train_test_split
    from sklearn.metrics import accuracy_score, precision_score, recall_score
except ImportError:
    ML_AVAILABLE = False
    logger.warning("⚠️ scikit-learn 미설치! ML 기능 사용 불가")
    logger.warning(" 설치: pip install scikit-learn")
else:
    ML_AVAILABLE = True

# Absolute directory containing this module; default DB/model paths hang off it.
SCRIPT_DIR = Path(__file__).resolve().parent
class MLPredictor:
    """Predict the win probability of a buy signal with a RandomForest model.

    Labels come from ``trade_history.profit_rate`` in the bot's SQLite DB
    (``profit_rate > 0`` counts as a win). The features are still a
    prototype: ``train_model`` trains on random placeholder values, not on
    real entry-time indicators, so predictions carry no real signal yet.
    """

    def __init__(
        self,
        db_path: Optional[str] = None,
        model_path: Optional[str] = None,
    ):
        """Configure paths and try to load a previously saved model.

        Args:
            db_path: SQLite DB path; defaults to ``quant_bot.db`` next to
                this module.
            model_path: Pickle file for the model; defaults to
                ``ml_model.pkl`` next to this module.
        """
        self.db_path = db_path or str(SCRIPT_DIR / "quant_bot.db")
        self.model_path = model_path or str(SCRIPT_DIR / "ml_model.pkl")
        self.model = None
        # Column order used for training and enforced on prediction input.
        self.feature_names = [
            "rsi",
            "volume_ratio",
            "tail_length_pct",
            "ma5_gap_pct",
            "ma20_gap_pct",
            "foreign_net_buy",
            "institution_net_buy",
            "market_hour",
        ]
        # Fewer closed trades than this and training is skipped.
        self.min_train_samples = 30
        if not ML_AVAILABLE:
            logger.error("❌ scikit-learn이 설치되지 않았습니다!")
            return
        self.load_model()

    def extract_features_from_db(self, days: int = 90) -> Optional[pd.DataFrame]:
        """Load trades closed in the last ``days`` days and label them win/loss.

        Only the labels are real for now: the frame holds the raw
        ``profit_rate, buy_date, sell_date, strategy`` columns plus ``is_win``
        (1 if ``profit_rate > 0`` else 0). Entry-time features are not stored
        yet, so ``train_model`` substitutes placeholder values.

        Args:
            days: Look-back window, compared against ``sell_date``.

        Returns:
            The labelled DataFrame, or None when fewer than
            ``min_train_samples`` rows exist or anything fails (logged).
        """
        # NOTE(review): the cutoff is compared as a string, which assumes
        # sell_date is stored in ISO "YYYY-MM-DD..." form - confirm.
        cutoff_date = (datetime.now() - timedelta(days=days)).strftime("%Y-%m-%d")
        query = """
            SELECT profit_rate, buy_date, sell_date, strategy
            FROM trade_history
            WHERE sell_date >= ?
            ORDER BY sell_date DESC
        """
        try:
            conn = sqlite3.connect(self.db_path)
            try:
                df = pd.read_sql_query(query, conn, params=(cutoff_date,))
            finally:
                # Close the connection even when the query fails (e.g. no table).
                conn.close()

            if len(df) < self.min_train_samples:
                logger.warning(
                    f"⚠️ 학습 데이터 부족: {len(df)}건 (최소 {self.min_train_samples}건 필요)"
                )
                return None

            df["is_win"] = (df["profit_rate"] > 0).astype(int)
            wins = int(df["is_win"].sum())
            logger.info(
                f"📊 학습 데이터 로드: {len(df)}건 "
                f"(익절: {wins}건, 손절: {len(df) - wins}건)"
            )
            return df
        except Exception as e:
            # Best-effort boundary: callers treat None as "no training data".
            logger.error(f"❌ 피처 추출 실패: {e}")
            return None

    @staticmethod
    def _prototype_features(n_rows: int) -> pd.DataFrame:
        """Generate ``n_rows`` rows of random placeholder features (prototype only).

        A private ``RandomState(42)`` yields the same values that
        ``np.random.seed(42)`` followed by the global ``np.random.*`` calls
        would, without resetting NumPy's global RNG for the rest of the process.
        """
        rng = np.random.RandomState(42)
        return pd.DataFrame(
            {
                "rsi": rng.uniform(20, 80, n_rows),
                "volume_ratio": rng.uniform(0.5, 5.0, n_rows),
                "tail_length_pct": rng.uniform(0, 5, n_rows),
                "ma5_gap_pct": rng.uniform(-5, 5, n_rows),
                "ma20_gap_pct": rng.uniform(-10, 10, n_rows),
                "foreign_net_buy": rng.uniform(-1000, 1000, n_rows),
                "institution_net_buy": rng.uniform(-500, 500, n_rows),
                "market_hour": rng.randint(9, 15, n_rows),
            }
        )

    def train_model(self, retrain: bool = False, days: int = 90) -> bool:
        """Train (or reuse) the RandomForest model and save it to disk.

        Args:
            retrain: Train even if a model is already loaded.
            days: Look-back window of closed trades used for the labels.

        Returns:
            True if a usable model is in place afterwards, False otherwise.
        """
        if not ML_AVAILABLE:
            logger.error("❌ scikit-learn 미설치로 학습 불가")
            return False
        if self.model is not None and not retrain:
            logger.info("✅ 기존 모델 사용")
            return True

        df = self.extract_features_from_db(days=days)
        if df is None or len(df) < self.min_train_samples:
            logger.warning("⚠️ 학습 데이터 부족 - ML 모델 사용 불가")
            return False

        y = df["is_win"].to_numpy()
        class_counts = np.bincount(y, minlength=2)  # [losses, wins]
        if class_counts.min() == 0:
            # A single-class target cannot produce a win probability.
            logger.warning(
                "⚠️ 익절/손절 중 한쪽 데이터만 존재 - 학습 불가 "
                f"(익절: {class_counts[1]}건, 손절: {class_counts[0]}건)"
            )
            return False

        logger.warning("⚠️ [프로토타입] 랜덤 피처로 학습 중")
        logger.warning(" → 실제 운영 시: active_trades 테이블에 진입 피처 저장 후 사용")
        # TODO: replace the placeholder features with real entry-time features.
        # Selecting by feature_names pins the column order that
        # predict_win_probability uses.
        X = self._prototype_features(len(df))[self.feature_names]

        # Stratified splitting requires at least 2 samples in every class.
        stratify = y if class_counts.min() >= 2 else None
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.2, random_state=42, stratify=stratify
        )

        logger.info("🤖 RandomForest 학습 시작...")
        model = RandomForestClassifier(
            n_estimators=100,
            max_depth=10,
            min_samples_split=5,
            random_state=42,
        )
        model.fit(X_train, y_train)
        # Assign only after a successful fit so a failure keeps the old model.
        self.model = model

        y_pred = model.predict(X_test)
        accuracy = accuracy_score(y_test, y_pred)
        precision = precision_score(y_test, y_pred, zero_division=0)
        recall = recall_score(y_test, y_pred, zero_division=0)
        logger.info("✅ 학습 완료!")
        logger.info(f" 정확도: {accuracy:.2%}")
        logger.info(f" 정밀도: {precision:.2%}")
        logger.info(f" 재현율: {recall:.2%}")

        feature_importance = sorted(
            zip(X.columns, model.feature_importances_),
            key=lambda item: item[1],
            reverse=True,
        )
        logger.info(" 중요 피처:")
        for fname, importance in feature_importance[:5]:
            logger.info(f" {fname}: {importance:.3f}")

        self.save_model()
        return True

    @staticmethod
    def _extract_win_probability(model, X: pd.DataFrame) -> Optional[float]:
        """Return the model's probability of class 1 (win) for the first row of ``X``.

        The column is located through ``model.classes_`` rather than assumed to
        be index 1. Returns None for a single-class model, which cannot
        express a win probability.
        """
        classes = list(model.classes_)
        if len(classes) < 2 or 1 not in classes:
            return None
        return float(model.predict_proba(X)[0][classes.index(1)])

    def predict_win_probability(self, features: dict) -> float:
        """Predict the win probability (0.0 ~ 1.0) of a buy signal.

        Args:
            features: Mapping containing every name in ``self.feature_names``;
                extra keys are ignored.

        Returns:
            The model's P(win), or the neutral 0.5 when scikit-learn is missing,
            no model is loaded, the model is single-class, or prediction fails
            (failures are logged).
        """
        if not ML_AVAILABLE or self.model is None:
            return 0.5
        try:
            X = pd.DataFrame([features])[self.feature_names]
            win_prob = self._extract_win_probability(self.model, X)
        except Exception as e:
            logger.error(f"❌ 예측 실패: {e}")
            return 0.5
        if win_prob is None:
            logger.warning("⚠️ 단일 클래스 모델 - 승률 예측 불가, 중립값 0.5 반환")
            return 0.5
        return win_prob

    def save_model(self) -> None:
        """Pickle the current model to ``self.model_path``; errors are logged, not raised.

        The pickle goes to a temporary sibling file first and is then moved
        over the target with ``os.replace``, so a crash mid-write cannot leave
        a truncated model file behind.
        """
        if self.model is None:
            logger.warning("⚠️ 저장할 모델 없음")
            return
        tmp_path = f"{self.model_path}.tmp"
        try:
            with open(tmp_path, "wb") as f:
                pickle.dump(self.model, f)
            os.replace(tmp_path, self.model_path)
            logger.info(f"💾 모델 저장: {self.model_path}")
        except Exception as e:
            logger.error(f"❌ 모델 저장 실패: {e}")

    def load_model(self) -> bool:
        """Load a pickled model from ``self.model_path`` if the file exists.

        Returns:
            True if a model was loaded; False when scikit-learn is missing, no
            file exists, or unpickling fails (logged).
        """
        if not ML_AVAILABLE:
            return False
        if not os.path.exists(self.model_path):
            logger.info(" 저장된 모델 없음 - 첫 실행 시 학습 필요")
            return False
        try:
            # SECURITY: pickle.load can execute arbitrary code from the file.
            # Only point model_path at files this bot wrote itself.
            with open(self.model_path, "rb") as f:
                self.model = pickle.load(f)
        except Exception as e:
            logger.error(f"❌ 모델 로드 실패: {e}")
            return False
        logger.info(f"✅ 모델 로드: {self.model_path}")
        return True

    def should_retrain(self, max_age_days: int = 7) -> bool:
        """Return True if no saved model exists or it is ``max_age_days`` or more days old.

        Age is measured from the model file's modification time.
        """
        if not os.path.exists(self.model_path):
            return True
        model_mtime = datetime.fromtimestamp(os.path.getmtime(self.model_path))
        days_old = (datetime.now() - model_mtime).days
        if days_old >= max_age_days:
            logger.info(f"🔄 모델 {days_old}일 경과 → 재학습 필요")
            return True
        return False