pipeline.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. # -*- coding: utf-8 -*-
  2. """
  3. @author: yq
  4. @time: 2024/11/1
  5. @desc: 模型训练管道
  6. """
  7. import pandas as pd
  8. from entitys import DataSplitEntity, MlConfigEntity, DataFeatureEntity
  9. from feature import FeatureStrategyFactory
  10. from feature.feature_strategy_base import FeatureStrategyBase
  11. from init import init
  12. from model import ModelBase
  13. from model import ModelFactory
  14. from monitor.report_generate import Report
# Module-level call to the project's ``init`` hook (imported above from the
# ``init`` module). It runs as a side effect of importing this module, before
# the Pipeline class is defined.
# NOTE(review): what init() actually sets up is not visible here - confirm.
init()
  16. class Pipeline():
  17. def __init__(self, ml_config: MlConfigEntity = None, data: DataSplitEntity = None, *args, **kwargs):
  18. if ml_config is not None:
  19. self._ml_config = ml_config
  20. else:
  21. self._ml_config = MlConfigEntity(*args, **kwargs)
  22. feature_strategy_clazz = FeatureStrategyFactory.get_strategy(self._ml_config.feature_strategy)
  23. self._feature_strategy: FeatureStrategyBase = feature_strategy_clazz(self._ml_config)
  24. model_clazz = ModelFactory.get_model(self._ml_config.model_type)
  25. self._model: ModelBase = model_clazz(self._ml_config)
  26. self._data = data
  27. def train(self, ):
  28. # 特征筛选
  29. self._feature_strategy.feature_search(self._data)
  30. metric_feature = self._feature_strategy.feature_report(self._data)
  31. # 生成训练数据
  32. train_data = self._feature_strategy.feature_generate(self._data.train_data)
  33. train_data = DataFeatureEntity(data_x=train_data, data_y=self._data.train_data[self._ml_config.y_column])
  34. test_data = self._feature_strategy.feature_generate(self._data.test_data)
  35. test_data = DataFeatureEntity(data_x=test_data, data_y=self._data.test_data[self._ml_config.y_column])
  36. self._model.train(train_data, test_data)
  37. metric_model = self._model.train_report(self._data)
  38. self.metric_value_dict = {**metric_feature, **metric_model}
  39. def prob(self, data: pd.DataFrame):
  40. feature = self._feature_strategy.feature_generate(data)
  41. prob = self._model.prob(feature)
  42. return prob
  43. def score(self, data: pd.DataFrame):
  44. return self._model.score(data)
  45. def score_rule(self, data: pd.DataFrame):
  46. return self._model.score_rule(data)
  47. def report(self, ):
  48. save_path = self._ml_config.f_get_save_path("模型报告.docx")
  49. Report.generate_report(self.metric_value_dict, self._model.get_report_template_path(), save_path=save_path)
  50. print(f"模型报告文件储存路径:{save_path}")
  51. def save(self):
  52. self._ml_config.config_save()
  53. self._feature_strategy.feature_save()
  54. self._model.model_save()
  55. @staticmethod
  56. def load(path: str):
  57. ml_config = MlConfigEntity.from_config(path)
  58. pipeline = Pipeline(ml_config=ml_config)
  59. pipeline._feature_strategy.feature_load(path)
  60. pipeline._model.model_load(path)
  61. return pipeline
  62. def variable_analyse(self, column: str, format_bin=None):
  63. self._feature_strategy.variable_analyse(self._data, column, format_bin)
if __name__ == "__main__":
    # Placeholder entry point: running this module as a script does nothing
    # beyond the import-time statements above.
    pass