Procházet zdrojové kódy

add: 增加逻辑回归预测、KS和AUC计算代码

wangzhaoyang před 5 měsíci
rodič
revize
e5c5186151
2 změnil soubory, kde provedl 9 přidání a 2 odebrání
  1. 2 0
      entitys/data_process_config_entity.py
  2. 7 2
      model/model_lr.py

+ 2 - 0
entitys/data_process_config_entity.py

@@ -10,6 +10,8 @@ import os
 from commom import GeneralException
 from enums import ResultCodesEnum
 
+from sklearn.model_selection import train_test_split
+
 
 class DataProcessConfigEntity():
     def __init__(self, y_column: str, fill_method: str, split_method: str):

+ 7 - 2
model/model_lr.py

@@ -10,6 +10,8 @@ from sklearn.linear_model import LogisticRegression
 from entitys import DataFeatureEntity, MetricTrainEntity, TrainConfigEntity
 from .model_base import ModelBase
 
+from toad.metrics import KS, AUC
+
 
 class ModelLr(ModelBase):
     def __init__(self, train_config: TrainConfigEntity):
@@ -18,13 +20,16 @@ class ModelLr(ModelBase):
 
     def train(self, data: DataFeatureEntity, *args, **kwargs) -> MetricTrainEntity:
         self.lr.fit(data.get_Xdata(), data.get_Ydata())
-        return MetricTrainEntity(0.7, 0.4)
+        pred_y = self.predict(data.get_Xdata())
+        ks = KS(pred_y, data.get_Ydata())
+        auc = AUC(pred_y, data.get_Ydata())
+        return MetricTrainEntity(auc, ks)
 
     def predict_prob(self, x: pd.DataFrame, *args, **kwargs):
         return self.lr.predict_proba(x)[:, 1]
 
     def predict(self, x: pd.DataFrame, *args, **kwargs):
-        pass
+        return self.lr.predict(x)
 
     def export_model_file(self):
         pass