|
@@ -10,6 +10,8 @@ from sklearn.linear_model import LogisticRegression
|
|
|
from entitys import DataFeatureEntity, MetricTrainEntity, TrainConfigEntity
|
|
|
from .model_base import ModelBase
|
|
|
|
|
|
+from toad.metrics import KS, AUC
|
|
|
+
|
|
|
|
|
|
class ModelLr(ModelBase):
|
|
|
def __init__(self, train_config: TrainConfigEntity):
|
|
@@ -18,13 +20,16 @@ class ModelLr(ModelBase):
|
|
|
|
|
|
def train(self, data: DataFeatureEntity, *args, **kwargs) -> MetricTrainEntity:
|
|
|
self.lr.fit(data.get_Xdata(), data.get_Ydata())
|
|
|
- return MetricTrainEntity(0.7, 0.4)
|
|
|
+ pred_y = self.predict(data.get_Xdata())
|
|
|
+ ks = KS(pred_y, data.get_Ydata())
|
|
|
+ auc = AUC(pred_y, data.get_Ydata())
|
|
|
+ return MetricTrainEntity(auc, ks)
|
|
|
|
|
|
def predict_prob(self, x: pd.DataFrame, *args, **kwargs):
|
|
|
return self.lr.predict_proba(x)[:, 1]
|
|
|
|
|
|
def predict(self, x: pd.DataFrame, *args, **kwargs):
|
|
|
- pass
|
|
|
+ return self.lr.predict(x)
|
|
|
|
|
|
def export_model_file(self):
|
|
|
pass
|