model_lr.py 1.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647
  1. # -*- coding: utf-8 -*-
  2. """
  3. @author: yq
  4. @time: 2024/11/1
  5. @desc:
  6. """
  7. import pandas as pd
  8. from sklearn.linear_model import LogisticRegression
  9. from toad.metrics import KS, AUC
  10. from entitys import MetricTrainEntity, TrainConfigEntity, DataPreparedEntity
  11. from .model_base import ModelBase
  12. class ModelLr(ModelBase):
  13. def __init__(self, train_config: TrainConfigEntity):
  14. super().__init__(train_config)
  15. self.lr = LogisticRegression(penalty='l1', C=0.9, solver='saga', n_jobs=-1)
  16. def train(self, data: DataPreparedEntity, *args, **kwargs) -> MetricTrainEntity:
  17. train_data = data.train_data
  18. test_data = data.test_data
  19. self.lr.fit(train_data.get_Xdata(), train_data.get_Ydata())
  20. train_prob = self.lr.predict_proba(train_data.get_Xdata())[:, 1]
  21. train_auc = AUC(train_prob, train_data.get_Ydata())
  22. train_ks = KS(train_prob, train_data.get_Ydata())
  23. test_prob = self.lr.predict_proba(test_data.get_Xdata())[:, 1]
  24. test_auc = AUC(test_prob, test_data.get_Ydata())
  25. test_ks = KS(test_prob, test_data.get_Ydata())
  26. return MetricTrainEntity(train_auc, train_ks, test_auc, test_ks)
  27. def predict_prob(self, x: pd.DataFrame, *args, **kwargs):
  28. return self.lr.predict_proba(x)[:, 1]
  29. def predict(self, x: pd.DataFrame, *args, **kwargs):
  30. return self.lr.predict(x)
  31. def export_model_file(self):
  32. pass
  33. if __name__ == "__main__":
  34. pass