model_lr.py 820 B

123456789101112131415161718192021222324252627282930313233
  1. # -*- coding: utf-8 -*-
  2. """
  3. @author: yq
  4. @time: 2024/11/1
  5. @desc:
  6. """
  7. import pandas as pd
  8. from sklearn.linear_model import LogisticRegression
  9. from entitys import DataFeatureEntity, MetricTrainEntity
  10. from .model_base import ModelBase
  11. class ModelLr(ModelBase):
  12. def __init__(self, ):
  13. self.lr = LogisticRegression(penalty='l1', C=0.9, solver='saga', n_jobs=-1)
  14. def train(self, data: DataFeatureEntity, *args, **kwargs) -> MetricTrainEntity:
  15. self.lr.fit(data.get_Xdata(), data.get_Ydata())
  16. return MetricTrainEntity(0.7, 0.4)
  17. def predict_prob(self, x: pd.DataFrame, *args, **kwargs):
  18. return self.lr.predict_proba(x)[:, 1]
  19. def predict(self, x: pd.DataFrame, *args, **kwargs):
  20. pass
  21. def export_model_file(self):
  22. pass
  23. if __name__ == "__main__":
  24. pass