# -*- coding: utf-8 -*-
"""
@author: yq
@time: 2024/11/1
@desc: Model training pipeline
"""
import pandas as pd

from entitys import DataSplitEntity, MlConfigEntity, DataFeatureEntity
from feature import FeatureStrategyFactory
from feature.feature_strategy_base import FeatureStrategyBase
from init import init
from model import ModelBase
from model import ModelFactory
from monitor.report_generate import Report

# Project-level initialisation; runs once at import time.
init()


class Pipeline:
    """Tie a feature strategy and a model together behind one entry point.

    The concrete feature strategy and model classes are looked up through
    ``FeatureStrategyFactory`` / ``ModelFactory`` using the
    ``feature_strategy`` and ``model_type`` fields of the configuration.
    """

    def __init__(self, ml_config: MlConfigEntity = None, data: DataSplitEntity = None, *args, **kwargs):
        """Build the pipeline from a configuration and (optionally) a dataset.

        Args:
            ml_config: Ready-made configuration. When ``None``, a
                ``MlConfigEntity`` is built from ``*args`` / ``**kwargs``.
            data: Dataset split used by ``train`` and ``variable_analyse``.
                May be ``None`` (e.g. for a pipeline restored via ``load``).
            *args: Positional arguments forwarded to ``MlConfigEntity`` when
                ``ml_config`` is not given.
            **kwargs: Keyword arguments forwarded to ``MlConfigEntity`` when
                ``ml_config`` is not given.
        """
        if ml_config is not None:
            self._ml_config = ml_config
        else:
            self._ml_config = MlConfigEntity(*args, **kwargs)

        feature_strategy_clazz = FeatureStrategyFactory.get_strategy(self._ml_config.feature_strategy)
        self._feature_strategy: FeatureStrategyBase = feature_strategy_clazz(self._ml_config)

        model_clazz = ModelFactory.get_model(self._ml_config.model_type)
        self._model: ModelBase = model_clazz(self._ml_config)

        self._data = data
        # Filled in by ``train``; stays ``None`` until training has run so
        # that ``report`` can detect the missing metrics explicitly.
        self.metric_value_dict = None

    def train(self) -> None:
        """Run feature search, build train/test features and train the model.

        On success the feature metrics and model metrics are merged into
        ``self.metric_value_dict`` (model metrics win on key collisions).

        Raises:
            RuntimeError: If the pipeline was created without ``data``.
        """
        if self._data is None:
            # Fail fast instead of crashing later on ``None.train_data``.
            raise RuntimeError("Pipeline.train() requires `data`, but the pipeline was created without it.")

        # Feature selection
        self._feature_strategy.feature_search(self._data)
        metric_feature = self._feature_strategy.feature_report(self._data)

        # Build the training / test data
        y_column = self._ml_config.y_column
        train_x = self._feature_strategy.feature_generate(self._data.train_data)
        train_data = DataFeatureEntity(data_x=train_x, data_y=self._data.train_data[y_column])
        test_x = self._feature_strategy.feature_generate(self._data.test_data)
        test_data = DataFeatureEntity(data_x=test_x, data_y=self._data.test_data[y_column])

        self._model.train(train_data, test_data)
        metric_model = self._model.train_report(self._data)

        self.metric_value_dict = {**metric_feature, **metric_model}

    def prob(self, data: pd.DataFrame):
        """Generate features for ``data`` and return the model's ``prob`` output."""
        feature = self._feature_strategy.feature_generate(data)
        return self._model.prob(feature)

    def score(self, data: pd.DataFrame):
        """Return ``self._model.score(data)``; ``data`` is passed through unchanged."""
        return self._model.score(data)

    def score_rule(self, data: pd.DataFrame):
        """Return ``self._model.score_rule(data)``; ``data`` is passed through unchanged."""
        return self._model.score_rule(data)

    def report(self) -> None:
        """Render the metrics collected by ``train`` into a report file.

        The output path is obtained from the configuration and printed once
        the report has been generated.

        Raises:
            RuntimeError: If no metrics are available because ``train`` has
                not been run on this instance.
        """
        if self.metric_value_dict is None:
            raise RuntimeError("Pipeline.report() requires metrics; call train() first.")

        save_path = self._ml_config.f_get_save_path("模型报告.docx")
        Report.generate_report(self.metric_value_dict, self._model.get_report_template_path(), save_path=save_path)
        print(f"模型报告文件储存路径:{save_path}")

    def save(self) -> None:
        """Persist the configuration, the feature strategy and the model."""
        self._ml_config.config_save()
        self._feature_strategy.feature_save()
        self._model.model_save()

    @classmethod
    def load(cls, path: str):
        """Restore a pipeline from ``path``.

        The configuration is read with ``MlConfigEntity.from_config`` and the
        feature strategy and model are then loaded from the same ``path``.
        The returned pipeline carries no dataset and no training metrics.

        Args:
            path: Location handed to the config / feature / model loaders.

        Returns:
            The restored pipeline instance.
        """
        ml_config = MlConfigEntity.from_config(path)
        pipeline = cls(ml_config=ml_config)
        pipeline._feature_strategy.feature_load(path)
        pipeline._model.model_load(path)
        return pipeline

    def variable_analyse(self, column: str, format_bin=None) -> None:
        """Delegate analysis of ``column`` to the feature strategy.

        Args:
            column: Name of the variable to analyse.
            format_bin: Forwarded unchanged to the feature strategy.
        """
        self._feature_strategy.variable_analyse(self._data, column, format_bin)


if __name__ == "__main__":
    pass