# -*- coding: utf-8 -*-
"""
@author: yq
@time: 2025/4/3
@desc: value normalization, similar to binning
"""
import os
from typing import Dict, List

import joblib
import pandas as pd
import xgboost as xgb
from pandas.core.dtypes.common import is_numeric_dtype

from commom import GeneralException, f_display_title
from data import DataExplore
from entitys import DataSplitEntity, MetricFucResultEntity
from enums import ResultCodesEnum, ContextEnum, FileEnum
from feature.feature_strategy_base import FeatureStrategyBase
from init import context
from .utils import f_format_value, OneHot, f_format_bin


class StrategyNorm(FeatureStrategyBase):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.x_columns = None
        self.one_hot_encoder_dict: Dict[str, OneHot] = {}
        self.points_dict: Dict[str, List[float]] = {}

    def _f_fast_filter(self, data: DataSplitEntity) -> List[str]:
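        """Fast filter: coarse-bin numeric columns, one-hot encode string
        columns, fit an XGBoost classifier, and keep at most max_feature_num
        features ranked by importance weight."""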
        y_column = self.ml_config.y_column
        x_columns = self.ml_config.x_columns
        columns_exclude = self.ml_config.columns_exclude
        format_bin = self.ml_config.format_bin
        params_xgb = self.ml_config.params_xgb
        max_feature_num = self.ml_config.max_feature_num
        columns_anns = self.ml_config.columns_anns

        train_data = data.train_data.copy()
        test_data = data.test_data.copy()

        # Feature column configuration; copy so the list held by ml_config is not mutated by remove() below
        x_columns = list(x_columns)
        if len(x_columns) == 0:
            x_columns = train_data.columns.tolist()
        if y_column in x_columns:
            x_columns.remove(y_column)
        for column in columns_exclude:
            if column in x_columns:
                x_columns.remove(column)

        # Basic check that each variable keeps a consistent data type
        check_msg = DataExplore.check_type(data.data[x_columns])
        if check_msg != "":
            print(f"数据类型分析:\n{check_msg}\n同一变量请保持数据类型一致")
            raise GeneralException(ResultCodesEnum.ILLEGAL_PARAMS, message=f"数据类型错误.")

        # Data processing
        model_columns = []
        num_columns = []
        str_columns = []
        for x_column in x_columns:
            if is_numeric_dtype(train_data[x_column]):
                num_columns.append(x_column)
                # Coarse binning
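                # Hedged note: f_format_bin is assumed to derive split points
                # from this 10%/90% percentile summary, and f_format_value to
                # snap raw values onto those points (see .utils).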
                if format_bin:
                    data_x_describe = train_data[x_column].describe(percentiles=[0.1, 0.9])
                    points = f_format_bin(data_x_describe)
                    if points is not None:
                        self.points_dict[x_column] = points
                        train_data[x_column] = train_data[x_column].apply(lambda x: f_format_value(points, x))
                        test_data[x_column] = test_data[x_column].apply(lambda x: f_format_value(points, x))
            else:
                str_columns.append(x_column)
                one_hot_encoder = OneHot()
                one_hot_encoder.fit(data.data, x_column)
                one_hot_encoder.encoder(train_data)
                one_hot_encoder.encoder(test_data)
                model_columns.extend(one_hot_encoder.columns_onehot)
                self.one_hot_encoder_dict[x_column] = one_hot_encoder

        model_columns.extend(num_columns)

        # Use importance to drop weak variables
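        # The classifier below is fit only to rank features. With
        # importance_type='weight', a feature's importance is the number of
        # times it is used as a split across all boosted trees.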
        model = xgb.XGBClassifier(objective=params_xgb.get("objective"),
                                  n_estimators=params_xgb.get("num_boost_round"),
                                  max_depth=params_xgb.get("max_depth"),
                                  learning_rate=params_xgb.get("learning_rate"),
                                  random_state=params_xgb.get("random_state"),
                                  reg_alpha=params_xgb.get("alpha"),
                                  subsample=params_xgb.get("subsample"),
                                  colsample_bytree=params_xgb.get("colsample_bytree"),
                                  importance_type='weight'
                                  )
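        # Note: eval_metric / early_stopping_rounds as fit() kwargs work on the
        # older sklearn wrapper; xgboost >= 2.0 expects them as constructor
        # arguments instead.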

        model.fit(X=train_data[model_columns], y=train_data[y_column],
                  eval_set=[(train_data[model_columns], train_data[y_column]),
                            (test_data[model_columns], test_data[y_column])],
                  eval_metric=params_xgb.get("eval_metric"),
                  early_stopping_rounds=params_xgb.get("early_stopping_rounds"),
                  verbose=False,
                  )

        # Merge importances: a categorical variable's importance is the sum over its one-hot sub-columns
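        # One-hot sub-columns are assumed to be named "<column>(<value>)",
        # which is what the startswith() match below relies on.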
        importance = model.feature_importances_
        feature = []
        importance_weight = []
        for x_column in num_columns:
            for col, imp in zip(model_columns, importance):
                if col == x_column:
                    feature.append(x_column)
                    importance_weight.append(imp)
                    break
        for x_column in str_columns:
            importance_sum = 0
            for col, imp in zip(model_columns, importance):
                if col.startswith(f"{x_column}("):
                    importance_sum += imp
            feature.append(x_column)
            importance_weight.append(importance_sum)

        anns = [columns_anns.get(column, "-") for column in feature]
        df_importance = pd.DataFrame({'feature': feature, 'importance_weight': importance_weight, "annotation": anns})
        df_importance.sort_values(by=["importance_weight"], ascending=[False], inplace=True)
        df_importance.reset_index(drop=True, inplace=True)
        # Chain reset_index on the filtered frame to avoid pandas' SettingWithCopyWarning
        df_importance_rank = df_importance[df_importance["importance_weight"] > 0].reset_index(drop=True)

        x_columns_filter = list(df_importance_rank["feature"])[0:max_feature_num]

        context.set_filter_info(ContextEnum.FILTER_FAST,
                                f"筛选前变量数量:{len(x_columns)}\n{x_columns}\n"
                                f"快速筛选剔除变量数量:{len(x_columns) - len(x_columns_filter)}", detail=df_importance)

        context.set(ContextEnum.XGB_COLUMNS_NUM, num_columns)
        context.set(ContextEnum.XGB_POINTS, self.points_dict)

        return x_columns_filter

    def feature_search(self, data: DataSplitEntity, *args, **kwargs):
        x_columns = self._f_fast_filter(data)
        # Sort to avoid potential ordering-related bugs
        x_columns.sort()
        self.x_columns = x_columns
        context.set(ContextEnum.XGB_COLUMNS_SELECTED, x_columns)

    def variable_analyse(self, *args, **kwargs):
        pass

    def feature_generate(self, data: pd.DataFrame, *args, **kwargs) -> pd.DataFrame:
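        """Apply the saved binning points and one-hot encoders so that
        generated features match those built during training."""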
        df = data.copy()
        model_columns = []
        for x_column in self.x_columns:
            if x_column in self.points_dict.keys():
                points = self.points_dict[x_column]
                df[x_column] = df[x_column].apply(lambda x: f_format_value(points, x))
                model_columns.append(x_column)
            elif x_column in self.one_hot_encoder_dict.keys():
                one_hot_encoder = self.one_hot_encoder_dict[x_column]
                one_hot_encoder.encoder(df)
                model_columns.extend(one_hot_encoder.columns_onehot)
            else:
                model_columns.append(x_column)

        return df[model_columns]

    def feature_save(self, *args, **kwargs):
        if self.x_columns is None:
            raise GeneralException(ResultCodesEnum.NOT_FOUND, message="feature does not exist")

        path = self.ml_config.f_get_save_path(FileEnum.FEATURE_PKL.value)
        feature_info = {
            "x_columns": self.x_columns,
            "one_hot_encoder_dict": self.one_hot_encoder_dict,
            "points_dict": self.points_dict,
        }
        joblib.dump(feature_info, path)
        print(f"feature save to【{path}】success. ")

    def feature_load(self, path: str, *args, **kwargs):
        if os.path.isdir(path):
            path = os.path.join(path, FileEnum.FEATURE_PKL.value)
        if not os.path.isfile(path) or not path.endswith(FileEnum.FEATURE_PKL.value):
            raise GeneralException(ResultCodesEnum.NOT_FOUND, message=f"feature info【{FileEnum.FEATURE_PKL.value}】does not exist")

        feature_info = joblib.load(path)
        self.x_columns = feature_info["x_columns"]
        self.one_hot_encoder_dict = feature_info["one_hot_encoder_dict"]
        self.points_dict = feature_info["points_dict"]
        print(f"feature load from【{path}】success.")

    def feature_report(self, data: DataSplitEntity, *args, **kwargs) -> Dict[str, MetricFucResultEntity]:

        y_column = self.ml_config.y_column

        metric_value_dict = {}
        # Sample distribution
        metric_value_dict["Sample distribution"] = MetricFucResultEntity(table=data.get_distribution(y_column), table_font_size=10,
                                                                         table_cell_width=3)

        self.jupyter_print(metric_value_dict)
        return metric_value_dict

    def jupyter_print(self, metric_value_dict, *args, **kwargs):
        from IPython import display

        max_feature_num = self.ml_config.max_feature_num
        filter_fast = context.get(ContextEnum.FILTER_FAST)

        f_display_title(display, "样本分布")
        display.display(metric_value_dict["样本分布"].table)

        df_importance = filter_fast["detail"]
        df_importance = df_importance[df_importance["feature"].isin(self.x_columns)]
        f_display_title(display, "入模变量")
        display.display(df_importance)

        f_display_title(display, "快速筛选过程")
        print(f"剔除变量重要性排名{max_feature_num}以后的变量")
        print(filter_fast.get("overview"))
        display.display(filter_fast["detail"])
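

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, kept commented out): the exact
# constructor arguments of StrategyNorm and DataSplitEntity are
# project-specific, so the kwargs below are hypothetical placeholders.
# ---------------------------------------------------------------------------
# if __name__ == "__main__":
#     strategy = StrategyNorm(ml_config=ml_config)  # hypothetical ctor kwarg
#     data = DataSplitEntity(train_data=df_train, test_data=df_test)  # hypothetical
#     strategy.feature_search(data)                  # fast filter + select columns
#     x_train = strategy.feature_generate(df_train)  # model-ready feature matrix
#     strategy.feature_save()                        # persist columns/encoders/points
#     strategy.feature_load(path="...")              # restore them for inference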