# -*- coding: utf-8 -*-
"""
@author: yq
@time: 2024/11/1
@desc: 模型训练管道
"""
import pandas as pd

from entitys import DataSplitEntity, MlConfigEntity, DataFeatureEntity
from feature import FeatureStrategyFactory, FeatureStrategyBase
from init import init
from model import ModelBase, ModelFactory, f_add_rules
from monitor import ReportWord

# Module-level side effect: the project's `init()` (imported from the `init`
# module) runs as soon as this file is imported. What it configures is not
# visible from this file — NOTE(review): confirm it is safe to run on import.
init()


class Pipeline:
    """Model-training pipeline driven by an ``MlConfigEntity``.

    The config selects a feature strategy (via ``FeatureStrategyFactory``) and
    a model (via ``ModelFactory``). This class runs them through feature
    search, training, scoring, reporting and persistence.
    """

    def __init__(self, ml_config: MlConfigEntity = None, data: DataSplitEntity = None, *args, **kwargs):
        """Build the feature strategy and model described by the config.

        Args:
            ml_config: Pipeline configuration. When ``None``, one is created
                with ``MlConfigEntity(*args, **kwargs)``.
            data: Train/test split used by ``train``, ``variable_analyse`` and
                ``rules_test``. May be left as ``None`` for a pipeline that is
                only used for inference (e.g. one returned by ``load``).
        """
        if ml_config is not None:
            self._ml_config = ml_config
        else:
            self._ml_config = MlConfigEntity(*args, **kwargs)
        feature_strategy_clazz = FeatureStrategyFactory.get_strategy(self._ml_config.feature_strategy)
        self._feature_strategy: FeatureStrategyBase = feature_strategy_clazz(self._ml_config)
        model_clazz = ModelFactory.get_model(self._ml_config.model_type)
        self._model: ModelBase = model_clazz(self._ml_config)
        self._data = data
        # Filled in by train(); report() refuses to run while it is still None.
        self.metric_value_dict = None

    def _require_data(self, caller: str):
        """Return the training split, or raise a clear error if none was supplied."""
        if self._data is None:
            raise ValueError(
                f"Pipeline.{caller}() needs training data: pass `data` (a DataSplitEntity) to the constructor."
            )
        return self._data

    def _feature_entity(self, df: pd.DataFrame) -> DataFeatureEntity:
        """Generate model features for ``df`` and pair them with its label column."""
        return DataFeatureEntity(
            data_x=self._feature_strategy.feature_generate(df),
            data_y=df[self._ml_config.y_column],
        )

    def train(self, ):
        """Run feature selection, fit the model and collect report metrics.

        Feature search and the feature report run on the whole split. The
        generated train and test feature entities are then both passed to
        ``self._model.train``. The feature and model metric dicts are merged
        into ``self.metric_value_dict`` for ``report``; model metrics win on
        key collisions.

        Raises:
            ValueError: if the pipeline was created without ``data``.
        """
        data = self._require_data("train")

        # Feature selection
        self._feature_strategy.feature_search(data)
        metric_feature = self._feature_strategy.feature_report(data)

        # Build model inputs, then train
        train_data = self._feature_entity(data.train_data)
        test_data = self._feature_entity(data.test_data)
        self._model.train(train_data, test_data)
        metric_model = self._model.train_report(data)

        self.metric_value_dict = {**metric_feature, **metric_model}

    def prob(self, data: pd.DataFrame):
        """Generate features for ``data`` and return the model's ``prob`` output on them."""
        feature = self._feature_strategy.feature_generate(data)
        prob = self._model.prob(feature)
        return prob

    def score(self, data: pd.DataFrame):
        """Return ``self._model.score(data)``.

        NOTE(review): unlike ``prob``, ``data`` is handed to the model without
        running ``feature_generate`` first. Confirm the model's ``score``
        expects raw input.
        """
        return self._model.score(data)

    def score_rule(self, data: pd.DataFrame):
        """Return ``self._model.score_rule(data)``.

        NOTE(review): as with ``score``, no feature generation happens here.
        Confirm that is intended.
        """
        return self._model.score_rule(data)

    def report(self, ):
        """Render the Word report from the metrics collected by ``train``.

        The output path comes from ``self._ml_config.f_get_save_path`` and the
        template from ``self._model.get_report_template_path()``. The final
        path is printed.

        Raises:
            RuntimeError: if ``train`` has not produced any metrics yet.
        """
        # getattr also covers instances whose __init__ was bypassed/overridden.
        metric_value_dict = getattr(self, "metric_value_dict", None)
        if metric_value_dict is None:
            raise RuntimeError("No metrics to report: call train() before report().")
        save_path = self._ml_config.f_get_save_path("模型报告.docx")
        ReportWord.generate_report(metric_value_dict, self._model.get_report_template_path(), save_path=save_path)
        print(f"模型报告文件储存路径:{save_path}")

    def save(self):
        """Persist the config, the feature strategy and the model.

        Each component is saved through its own save method. No path is passed
        here — NOTE(review): presumably each derives its destination from the
        shared config; confirm.
        """
        self._ml_config.config_save()
        self._feature_strategy.feature_save()
        self._model.model_save()

    @classmethod
    def load(cls, path: str):
        """Rebuild a pipeline from ``path`` (presumably written by ``save``).

        Implemented as a classmethod so that ``SubPipeline.load(path)`` returns
        a ``SubPipeline`` rather than a plain ``Pipeline``. ``Pipeline.load(path)``
        behaves exactly as before.

        Args:
            path: Location passed to ``MlConfigEntity.from_config`` and to the
                feature strategy's and model's load methods.

        Returns:
            A new ``cls`` instance with no training data attached.
        """
        ml_config = MlConfigEntity.from_config(path)
        pipeline = cls(ml_config=ml_config)
        pipeline._feature_strategy.feature_load(path)
        pipeline._model.model_load(path)
        return pipeline

    def variable_analyse(self, column: str, format_bin=None):
        """Delegate single-variable analysis of ``column`` to the feature strategy.

        Raises:
            ValueError: if the pipeline was created without ``data``.
        """
        self._feature_strategy.variable_analyse(self._require_data("variable_analyse"), column, format_bin)

    def rules_test(self, ):
        """Apply the configured rules to a copy of the training data.

        The original training frame is not modified.

        Returns:
            The copied training frame, with a ``SCORE`` column initialised to 0
            and then passed through ``f_add_rules``. NOTE(review): assumes
            ``f_add_rules`` updates the frame in place (its return value is
            ignored, as before) — confirm.

        Raises:
            ValueError: if the pipeline was created without ``data``.
        """
        rules = self._ml_config.rules
        df = self._require_data("rules_test").train_data.copy()
        df["SCORE"] = [0] * len(df)
        f_add_rules(df, rules)
        return df


if __name__ == "__main__":
    # Intentionally empty: this module has no command-line entry point.
    ...