pipeline.py 3.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
# -*- coding: utf-8 -*-
"""
@author: yq
@time: 2024/11/1
@desc: Model training pipeline
"""
  7. import pandas as pd
  8. from entitys import DataSplitEntity, MlConfigEntity, DataFeatureEntity
  9. from feature import FeatureStrategyFactory, FeatureStrategyBase
  10. from init import init
  11. from model import ModelBase, ModelFactory, f_add_rules
  12. from monitor import ReportWord
# Call the `init` hook imported from the local `init` module. This runs as an
# import-time side effect, before the Pipeline class is defined.
# NOTE(review): what init() actually sets up is not visible from this file —
# confirm before moving or removing this call.
init()
  14. class Pipeline():
  15. def __init__(self, ml_config: MlConfigEntity = None, data: DataSplitEntity = None, *args, **kwargs):
  16. if ml_config is not None:
  17. self._ml_config = ml_config
  18. else:
  19. self._ml_config = MlConfigEntity(*args, **kwargs)
  20. feature_strategy_clazz = FeatureStrategyFactory.get_strategy(self._ml_config.feature_strategy)
  21. self._feature_strategy: FeatureStrategyBase = feature_strategy_clazz(self._ml_config)
  22. model_clazz = ModelFactory.get_model(self._ml_config.model_type)
  23. self._model: ModelBase = model_clazz(self._ml_config)
  24. self._data = data
  25. def train(self, ):
  26. # 特征筛选
  27. self._feature_strategy.feature_search(self._data)
  28. metric_feature = self._feature_strategy.feature_report(self._data)
  29. # 生成训练数据
  30. train_data = self._feature_strategy.feature_generate(self._data.train_data)
  31. train_data = DataFeatureEntity(data_x=train_data, data_y=self._data.train_data[self._ml_config.y_column])
  32. test_data = self._feature_strategy.feature_generate(self._data.test_data)
  33. test_data = DataFeatureEntity(data_x=test_data, data_y=self._data.test_data[self._ml_config.y_column])
  34. self._model.train(train_data, test_data)
  35. metric_model = self._model.train_report(self._data)
  36. self.metric_value_dict = {**metric_feature, **metric_model}
  37. def prob(self, data: pd.DataFrame):
  38. feature = self._feature_strategy.feature_generate(data)
  39. prob = self._model.prob(feature)
  40. return prob
  41. def score(self, data: pd.DataFrame):
  42. return self._model.score(data)
  43. def score_rule(self, data: pd.DataFrame):
  44. return self._model.score_rule(data)
  45. def report(self, ):
  46. save_path = self._ml_config.f_get_save_path("模型报告.docx")
  47. ReportWord.generate_report(self.metric_value_dict, self._model.get_report_template_path(), save_path=save_path)
  48. print(f"模型报告文件储存路径:{save_path}")
  49. def save(self):
  50. self._ml_config.config_save()
  51. self._feature_strategy.feature_save()
  52. self._model.model_save()
  53. @staticmethod
  54. def load(path: str):
  55. ml_config = MlConfigEntity.from_config(path)
  56. pipeline = Pipeline(ml_config=ml_config)
  57. pipeline._feature_strategy.feature_load(path)
  58. pipeline._model.model_load(path)
  59. return pipeline
  60. def variable_analyse(self, column: str, format_bin=None):
  61. self._feature_strategy.variable_analyse(self._data, column, format_bin)
  62. def rules_test(self, ):
  63. rules = self._ml_config.rules
  64. df = self._data.train_data.copy()
  65. df["SCORE"] = [0] * len(df)
  66. f_add_rules(df, rules)
if __name__ == "__main__":
    # Placeholder entry point: running this module as a script performs no
    # additional work.
    pass