커서가 망쳐놓은 듯
This commit is contained in:
@@ -1,13 +1,12 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
KIS Bot용 ML 승률 예측 모델
|
||||
- kis_bot/quant_bot.db의 trade_history 데이터로 학습
|
||||
- MariaDB kis_quant_db의 trade_history 데이터로 학습
|
||||
- 매수 신호의 승률 예측 (0.0 ~ 1.0)
|
||||
- 주간 단위 자동 재학습
|
||||
"""
|
||||
import os
|
||||
import pickle
|
||||
import sqlite3
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timedelta
|
||||
@@ -20,7 +19,6 @@ logger = logging.getLogger("KIS_MLPredictor")
|
||||
|
||||
try:
|
||||
from sklearn.ensemble import RandomForestClassifier
|
||||
from sklearn.model_selection import train_test_split
|
||||
from sklearn.metrics import accuracy_score, precision_score, recall_score
|
||||
ML_AVAILABLE = True
|
||||
except ImportError:
|
||||
@@ -37,11 +35,10 @@ class MLPredictor:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db_path: str = None,
|
||||
db_path: str = None, # 하위 호환용 (무시됨) — MariaDB 사용
|
||||
model_path: str = None,
|
||||
):
|
||||
# 기본값: kis_bot/quant_bot.db, kis_bot/ml_model.pkl
|
||||
self.db_path = db_path or str(SCRIPT_DIR / "quant_bot.db")
|
||||
# db_path: 하위 호환을 위해 파라미터 유지하나 내부적으로 MariaDB 사용
|
||||
self.model_path = model_path or str(SCRIPT_DIR / "ml_model.pkl")
|
||||
self.model = None
|
||||
self.feature_names = [
|
||||
@@ -63,40 +60,53 @@ class MLPredictor:
|
||||
self.load_model()
|
||||
|
||||
def extract_features_from_db(self, days: int = 90) -> pd.DataFrame:
|
||||
"""DB에서 학습용 피처 추출
|
||||
"""MariaDB trade_history 에서 학습용 피처 추출.
|
||||
|
||||
현재는 trade_history의 profit_rate 기반으로 승/패 라벨만 생성하고,
|
||||
피처는 프로토타입 단계로 랜덤 값을 사용한다.
|
||||
(실전에서는 active_trades에 진입 시점 피처를 저장해서 사용해야 함)
|
||||
- profit_rate 로 승/패 라벨 생성
|
||||
- 진입 시점 피처(rsi, volume_ratio 등)는 매수 시 DB에 저장된 값 사용
|
||||
- 시간순 정렬 후 반환 (시계열 분리 시 미래 참조 방지)
|
||||
"""
|
||||
try:
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
from database import TradeDB
|
||||
db = TradeDB()
|
||||
cutoff_date = (datetime.now() - timedelta(days=days)).strftime("%Y-%m-%d")
|
||||
|
||||
query = f"""
|
||||
SELECT profit_rate, buy_date, sell_date, strategy
|
||||
FROM trade_history
|
||||
WHERE sell_date >= '{cutoff_date}'
|
||||
ORDER BY sell_date DESC
|
||||
"""
|
||||
|
||||
df = pd.read_sql_query(query, conn)
|
||||
conn.close()
|
||||
|
||||
feat_cols = ", ".join(f"`{c}`" for c in self.feature_names)
|
||||
rows = db.conn.execute(
|
||||
f"SELECT profit_rate, buy_date, sell_date, strategy, {feat_cols} "
|
||||
f"FROM trade_history WHERE sell_date >= %s ORDER BY sell_date ASC",
|
||||
(cutoff_date,)
|
||||
).fetchall()
|
||||
|
||||
if not rows:
|
||||
logger.warning("⚠️ 학습 데이터 없음 (trade_history 조회 결과 0건)")
|
||||
return None
|
||||
|
||||
# DictCursor 반환 → DataFrame 직접 생성
|
||||
df = pd.DataFrame(rows)
|
||||
|
||||
if len(df) < self.min_train_samples:
|
||||
logger.warning(
|
||||
f"⚠️ 학습 데이터 부족: {len(df)}건 (최소 {self.min_train_samples}건 필요)"
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
# 진입 피처가 전부 NULL인 과거 데이터 제외
|
||||
feature_ok = df[self.feature_names].notna().any(axis=1)
|
||||
df = df.loc[feature_ok].copy()
|
||||
if len(df) < self.min_train_samples:
|
||||
logger.warning(
|
||||
f"⚠️ 진입 피처가 있는 데이터 부족: {len(df)}건 (최소 {self.min_train_samples}건 필요). "
|
||||
"매수 시 entry_features 저장 후 누적되면 학습 가능합니다."
|
||||
)
|
||||
return None
|
||||
|
||||
df["is_win"] = (df["profit_rate"] > 0).astype(int)
|
||||
logger.info(
|
||||
f"📊 학습 데이터 로드: {len(df)}건 "
|
||||
f"(익절: {df['is_win'].sum()}건, 손절: {(1 - df['is_win']).sum()}건)"
|
||||
f"📊 학습 데이터 로드: {len(df)}건 (시간순) "
|
||||
f"(익절: {df['is_win'].sum()}건, 손절: {(~df['is_win']).sum()}건)"
|
||||
)
|
||||
|
||||
return df
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 피처 추출 실패: {e}")
|
||||
return None
|
||||
@@ -117,29 +127,19 @@ class MLPredictor:
|
||||
logger.warning("⚠️ 학습 데이터 부족 - ML 모델 사용 불가")
|
||||
return False
|
||||
|
||||
logger.warning("⚠️ [프로토타입] 랜덤 피처로 학습 중")
|
||||
logger.warning(" → 실제 운영 시: active_trades 테이블에 진입 피처 저장 후 사용")
|
||||
|
||||
# TODO: 실제 피처 데이터로 교체 필요
|
||||
# 현재는 데모용 랜덤 피처 사용
|
||||
np.random.seed(42)
|
||||
X = pd.DataFrame(
|
||||
{
|
||||
"rsi": np.random.uniform(20, 80, len(df)),
|
||||
"volume_ratio": np.random.uniform(0.5, 5.0, len(df)),
|
||||
"tail_length_pct": np.random.uniform(0, 5, len(df)),
|
||||
"ma5_gap_pct": np.random.uniform(-5, 5, len(df)),
|
||||
"ma20_gap_pct": np.random.uniform(-10, 10, len(df)),
|
||||
"foreign_net_buy": np.random.uniform(-1000, 1000, len(df)),
|
||||
"institution_net_buy": np.random.uniform(-500, 500, len(df)),
|
||||
"market_hour": np.random.randint(9, 15, len(df)),
|
||||
}
|
||||
)
|
||||
# 실제 DB 진입 피처 사용 (누락값은 0으로 보정)
|
||||
X = df[self.feature_names].fillna(0).astype(float)
|
||||
y = df["is_win"].values
|
||||
|
||||
X_train, X_test, y_train, y_test = train_test_split(
|
||||
X, y, test_size=0.2, random_state=42, stratify=y
|
||||
)
|
||||
# 시계열 분리: 무작위 셔플 금지. 과거 80% = 학습, 최근 20% = 테스트 (미래 참조 방지)
|
||||
n = len(X)
|
||||
train_size = int(n * 0.8)
|
||||
if train_size < 10 or (n - train_size) < 5:
|
||||
logger.warning("⚠️ 데이터 적어 시계열 분리 불가 (학습/테스트 각 10건·5건 이상 권장)")
|
||||
train_size = max(10, n - 5)
|
||||
X_train, X_test = X.iloc[:train_size], X.iloc[train_size:]
|
||||
y_train, y_test = y[:train_size], y[train_size:]
|
||||
logger.info(f"📅 시계열 분리: 학습 {len(y_train)}건 (과거) / 테스트 {len(y_test)}건 (최근)")
|
||||
|
||||
logger.info("🤖 RandomForest 학습 시작...")
|
||||
self.model = RandomForestClassifier(
|
||||
@@ -173,7 +173,8 @@ class MLPredictor:
|
||||
return True
|
||||
|
||||
def predict_win_probability(self, features: dict) -> float:
|
||||
"""매수 신호의 승률 예측 (0.0 ~ 1.0)"""
|
||||
"""매수 신호의 승률 예측 (0.0 ~ 1.0).
|
||||
매매 로직에서는 ML_MIN_PROBABILITY(권장 0.65 이상) 미만일 때 진입 보류 권장."""
|
||||
if not ML_AVAILABLE or self.model is None:
|
||||
return 0.5
|
||||
|
||||
|
||||
Reference in New Issue
Block a user