pipeline.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. # -*- coding: utf-8 -*-
  2. """
  3. @author: yq
  4. @time: 2024/11/1
  5. @desc: 模型训练管道
  6. """
  7. from typing import List
  8. import pandas as pd
  9. from entitys import DataSplitEntity, MlConfigEntity, DataFeatureEntity
  10. from enums import ConstantEnum, ModelEnum
  11. from feature import FeatureStrategyFactory, FeatureStrategyBase
  12. from init import init
  13. from model import ModelBase, ModelFactory, f_add_rules, f_get_model_score_bin, f_calcu_model_psi
  14. from monitor import ReportWord
# Run the project's `init` routine (imported from the local `init` module) as
# soon as this module is imported. Importing `pipeline` therefore carries that
# side effect; what `init()` configures is defined outside this file.
init()
  16. class Pipeline():
  17. def __init__(self, ml_config: MlConfigEntity = None, data: DataSplitEntity = None, *args, **kwargs):
  18. if ml_config is not None:
  19. self._ml_config = ml_config
  20. else:
  21. self._ml_config = MlConfigEntity(*args, **kwargs)
  22. feature_strategy_clazz = FeatureStrategyFactory.get_strategy(self._ml_config.feature_strategy)
  23. self._feature_strategy: FeatureStrategyBase = feature_strategy_clazz(self._ml_config)
  24. model_clazz = ModelFactory.get_model(self._ml_config.model_type)
  25. self._model: ModelBase = model_clazz(self._ml_config)
  26. self._data = data
  27. def train(self, ):
  28. # 特征筛选
  29. self._feature_strategy.feature_search(self._data)
  30. metric_feature = self._feature_strategy.feature_report(self._data)
  31. # 生成训练数据
  32. train_data = self._feature_strategy.feature_generate(self._data.train_data)
  33. train_data = DataFeatureEntity(data_x=train_data, data_y=self._data.train_data[self._ml_config.y_column])
  34. test_data = self._feature_strategy.feature_generate(self._data.test_data)
  35. test_data = DataFeatureEntity(data_x=test_data, data_y=self._data.test_data[self._ml_config.y_column])
  36. self._model.train(train_data, test_data, data=self._data)
  37. metric_model = self._model.train_report(self._data)
  38. self.metric_value_dict = {**metric_feature, **metric_model}
  39. def prob(self, data: pd.DataFrame):
  40. if self._ml_config.model_type == ModelEnum.XGB.value:
  41. feature = data
  42. else:
  43. feature = self._feature_strategy.feature_generate(data)
  44. prob = self._model.prob(feature)
  45. return prob
  46. def score(self, data: pd.DataFrame):
  47. return self._model.score(data)
  48. def score_rule(self, data: pd.DataFrame):
  49. return self._model.score_rule(data)
  50. def psi(self, x1: pd.DataFrame, x2: pd.DataFrame, points: List[float] = None) -> pd.DataFrame:
  51. if self._ml_config.model_type == ModelEnum.XGB.value:
  52. y1 = self.prob(x1)
  53. y2 = self.prob(x2)
  54. sort_ascending = False
  55. else:
  56. sort_ascending = True
  57. if len(self._ml_config.rules) != 0:
  58. y1 = self.score_rule(x1)
  59. y2 = self.score_rule(x2)
  60. else:
  61. y1 = self.score(x1)
  62. y2 = self.score(x2)
  63. x1_score_bin, score_bins = f_get_model_score_bin(x1, y1, points)
  64. x2_score_bin, _ = f_get_model_score_bin(x2, y2, score_bins)
  65. model_psi = f_calcu_model_psi(x1_score_bin, x2_score_bin, sort_ascending)
  66. print(f"模型psi: {model_psi['psi'].sum()}")
  67. return model_psi
  68. def report(self, ):
  69. save_path = self._ml_config.f_get_save_path("模型报告.docx")
  70. ReportWord.generate_report(self.metric_value_dict, self._model.get_report_template_path(), save_path=save_path)
  71. print(f"模型报告文件储存路径:{save_path}")
  72. def save(self):
  73. self._ml_config.config_save()
  74. self._feature_strategy.feature_save()
  75. self._model.model_save()
  76. @staticmethod
  77. def load(path: str):
  78. ml_config = MlConfigEntity.from_config(path)
  79. pipeline = Pipeline(ml_config=ml_config)
  80. pipeline._feature_strategy.feature_load(path)
  81. pipeline._model.model_load(path)
  82. return pipeline
  83. def variable_analyse(self, column: str, format_bin=None):
  84. self._feature_strategy.variable_analyse(self._data, column, format_bin)
  85. def rules_test(self, ):
  86. rules = self._ml_config.rules
  87. df = self._data.train_data.copy()
  88. df[ConstantEnum.SCORE.value] = [0] * len(df)
  89. f_add_rules(df, rules)
# No script entry point: running this file directly does nothing beyond the
# module-level imports and the `init()` call above it.
if __name__ == "__main__":
    pass