1234567891011121314151617181920212223242526272829303132333435363738394041424344454647 |
- # -*- coding: utf-8 -*-
- """
- @author: yq
- @time: 2024/11/1
- @desc:
- """
- import pandas as pd
- from sklearn.linear_model import LogisticRegression
- from toad.metrics import KS, AUC
- from entitys import MetricTrainEntity, TrainConfigEntity, DataPreparedEntity
- from .model_base import ModelBase
class ModelLr(ModelBase):
    """Logistic-regression model with an L1 penalty, built on scikit-learn.

    Wraps ``sklearn.linear_model.LogisticRegression`` and, after fitting,
    reports AUC and KS (computed with ``toad.metrics``) on the train and
    test splits.
    """

    def __init__(self, train_config: TrainConfigEntity):
        """Create the (not yet fitted) estimator.

        Args:
            train_config: training configuration, forwarded to ``ModelBase``.
                The hyper-parameters below are fixed and are not read from it.
        """
        super().__init__(train_config)
        # L1 penalty with inverse regularisation strength C=0.9; 'saga' is one
        # of the scikit-learn solvers that accepts penalty='l1'.
        self.lr = LogisticRegression(penalty='l1', C=0.9, solver='saga', n_jobs=-1)

    def _score(self, x: pd.DataFrame, y):
        """Return ``(auc, ks)`` of the fitted model's positive-class probability.

        Args:
            x: feature matrix to score.
            y: true labels aligned with ``x``.
        """
        prob = self.lr.predict_proba(x)[:, 1]
        return AUC(prob, y), KS(prob, y)

    def train(self, data: DataPreparedEntity, *args, **kwargs) -> MetricTrainEntity:
        """Fit the model on the train split and evaluate it on both splits.

        Args:
            data: prepared data; its ``train_data`` and ``test_data`` must
                provide ``get_Xdata()`` (features) and ``get_Ydata()`` (labels).

        Returns:
            ``MetricTrainEntity(train_auc, train_ks, test_auc, test_ks)``.
        """
        train_data = data.train_data
        test_data = data.test_data

        # Fetch features/labels once; they are reused for fitting and scoring.
        train_x = train_data.get_Xdata()
        train_y = train_data.get_Ydata()
        self.lr.fit(train_x, train_y)

        train_auc, train_ks = self._score(train_x, train_y)
        test_auc, test_ks = self._score(test_data.get_Xdata(), test_data.get_Ydata())
        return MetricTrainEntity(train_auc, train_ks, test_auc, test_ks)

    def predict_prob(self, x: pd.DataFrame, *args, **kwargs):
        """Return the predicted probability of the second class in ``self.lr.classes_``.

        For 0/1 targets this is the probability of label 1.
        """
        return self.lr.predict_proba(x)[:, 1]

    def predict(self, x: pd.DataFrame, *args, **kwargs):
        """Return the predicted class labels for ``x``."""
        return self.lr.predict(x)

    def export_model_file(self):
        """Export the trained model to a file.

        TODO: not implemented yet; currently a no-op.
        """
        pass
if __name__ == "__main__":
    # This module has no standalone entry point; running it directly does nothing.
    ...
|