model_lr.py 736 B

1234567891011121314151617181920212223242526272829303132
  1. # -*- coding: utf-8 -*-
  2. """
  3. @author: yq
  4. @time: 2024/11/1
  5. @desc:
  6. """
  7. import pandas as pd
  8. from sklearn.linear_model import LogisticRegression
  9. from entitys import DataFeatureEntity
  10. from model_base import ModelBase
  11. class ModelLr(ModelBase):
  12. def __init__(self, ):
  13. self.lr = LogisticRegression(penalty='l1', C=0.9, solver='saga', n_jobs=-1)
  14. def train(self, data: DataFeatureEntity, *args, **kwargs):
  15. self.lr.fit(data.get_Xdata(), data.get_Ydata())
  16. def predict_prob(self, x: pd.DataFrame, *args, **kwargs):
  17. return self.lr.predict_proba(x)[:, 1]
  18. def predict(self, x: pd.DataFrame, *args, **kwargs):
  19. pass
  20. def export_model_file(self):
  21. pass
  22. if __name__ == "__main__":
  23. pass