# -*- coding: utf-8 -*-
"""
@author: yq
@time: 2025/2/14
@desc: 
"""
from typing import Union, List

import pandas as pd

from enums import ContextEnum
from init import context


class BinInfo():
    """Binning result and quality-check outcome for one feature column.

    One instance describes a single candidate binning of ``x_column``: the
    number of bins, the cut points, IV values, the monotonic / trend shift
    counts, PSI, and the ``is_qualified_*`` flags that record whether each
    check was passed (``HomologousBinInfo`` treats the value ``1`` as "passed").
    """

    def __init__(self,
                 x_column: str = None,
                 bin_num: int = None,
                 points: list = None,
                 is_auto_bins: int = None,
                 train_iv: float = None,
                 test_iv: float = None,
                 iv: float = None,
                 is_qualified_iv_train: int = None,
                 monto_shift_nsv: int = None,
                 is_qualified_monto_train_nsv: int = None,
                 trend_shift_nsv: int = None,
                 is_qualified_trend_nsv: int = None,
                 psi: float = None,
                 is_qualified_psi: int = None,
                 ):
        """Store every field as-is; all fields default to ``None``.

        Args:
            x_column: name of the binned variable.
            bin_num: number of bins.
            points: cut points of the binning.
            is_auto_bins: whether the cut points were generated automatically.
            train_iv / test_iv / iv: IV values; ``iv`` is the one used for
                ranking in ``ivTopN``.
            is_qualified_iv_train: train-IV check flag.
            monto_shift_nsv: count of monotonicity changes across bins.
            is_qualified_monto_train_nsv: monotonicity check flag.
            trend_shift_nsv: count of trend-consistency changes.
            is_qualified_trend_nsv: trend-consistency check flag.
            psi: PSI value.
            is_qualified_psi: PSI check flag.
        """
        self.x_column = x_column
        self.bin_num = bin_num
        self.points = points
        self.is_auto_bins = is_auto_bins
        self.train_iv = train_iv
        self.test_iv = test_iv
        self.iv = iv
        self.is_qualified_iv_train = is_qualified_iv_train
        self.monto_shift_nsv = monto_shift_nsv
        self.is_qualified_monto_train_nsv = is_qualified_monto_train_nsv
        self.trend_shift_nsv = trend_shift_nsv
        self.is_qualified_trend_nsv = is_qualified_trend_nsv
        self.psi = psi
        self.is_qualified_psi = is_qualified_psi

    def to_dict(self) -> dict:
        """Return the instance attributes as a new dict.

        A shallow copy is returned so that mutating the result does not
        silently modify this object; nested values (e.g. ``points``) are
        still shared with the instance.
        """
        return dict(self.__dict__)

    @staticmethod
    def ivTopN(data: dict, top_n: int):
        """Keep the ``top_n`` entries of ``data`` with the highest ``iv``.

        Args:
            data: mapping whose values are ``BinInfo`` objects (presumably
                keyed by column name — confirm against callers).
            top_n: number of entries to keep.

        Returns:
            A new dict ``{x_column: BinInfo}`` containing the kept entries in
            descending ``iv`` order.

        Side effects:
            Always calls ``context.set_filter_info(ContextEnum.FILTER_IVTOP,
            overview, detail)``, where ``overview`` has one line per dropped
            column and ``detail`` lists the dropped column names (both empty
            when nothing is dropped).

        Raises:
            TypeError: if ``iv`` values are not mutually comparable (e.g.
                ``None``) or ``top_n`` is not comparable to an int.
        """
        ranked = sorted(data.values(), key=lambda info: info.iv, reverse=True)
        # The explicit comparison keeps the original behaviour of rejecting a
        # non-integer ``top_n`` instead of letting ``ranked[None:]`` drop everything.
        dropped = ranked[top_n:] if top_n < len(ranked) else []
        kept = ranked[0:top_n]
        # join instead of repeated ``+=`` to avoid quadratic string building.
        overview = "".join(f"{info.x_column} 因为ivtop【{info.iv}】被剔除\n" for info in dropped)
        detail = [info.x_column for info in dropped]
        context.set_filter_info(ContextEnum.FILTER_IVTOP, overview, detail)
        return {info.x_column: info for info in kept}

    @staticmethod
    def ofConvertByDict(data: dict):
        """Build a ``BinInfo`` from a mapping of attribute name to value.

        Every key in ``data`` becomes an attribute, including keys that are not
        constructor parameters; attributes absent from ``data`` keep their
        ``None`` defaults.
        """
        bin_info = BinInfo()
        for key, value in data.items():
            setattr(bin_info, key, value)
        return bin_info


class HomologousBinInfo():
    """
    Feature information for the same variable under different binnings.

    Collects several ``BinInfo`` candidates for one column, selects among them
    (``filter`` / ``get_best_bins``) or explains why every candidate was
    rejected (``drop_reason``).
    """

    # Ranking used to choose a binning: fewest monotonicity changes, fewest
    # trend-consistency changes, highest iv, lowest psi. Kept as lists because
    # pandas treats a tuple in ``by`` as a single (MultiIndex) column label.
    _RANK_BY = ["monto_shift_nsv", "trend_shift_nsv", "iv", "psi"]
    _RANK_ASCENDING = [True, True, False, True]

    def __init__(self, x_column: str, is_auto_bins: int = None, is_include: bool = False):
        """
        Args:
            x_column: name of the variable whose binnings are collected.
            is_auto_bins: falsy means the cut points were specified manually;
                ``filter`` then returns the first candidate unconditionally.
            is_include: if True, ``filter`` ranks all candidates without
                applying the ``is_qualified_*`` checks.
        """
        self.x_column = x_column
        self.is_auto_bins = is_auto_bins
        self.is_include = is_include
        self.bins_info: List[BinInfo] = []

    def add(self, bin_info: BinInfo):
        """Append one binning candidate."""
        self.bins_info.append(bin_info)

    def convert_to_df(self) -> pd.DataFrame:
        """Return all candidates as a DataFrame with one row per ``BinInfo``."""
        return pd.DataFrame(data=[bin_info.to_dict() for bin_info in self.bins_info])

    def drop_reason(self, ) -> str:
        """Return a message naming the first check that eliminated every candidate.

        The checks are applied cumulatively in this order: train IV, monotonicity
        shift count, trend-consistency shift count, PSI. The reported value is
        aggregated over the candidates that survived the previous checks (max
        ``train_iv`` over all candidates for the first check, min of the
        relevant column afterwards). If some candidate passes every check, the
        survivors are printed and an "unknown reason" message is returned.
        """
        df_bins_info = self.convert_to_df()
        # (qualification flag, column reported on failure, aggregate, message template)
        checks = (
            ("is_qualified_iv_train", "train_iv", pd.Series.max,
             "因为train_iv最大值【{}】小于阈值被剔除"),
            ("is_qualified_monto_train_nsv", "monto_shift_nsv", pd.Series.min,
             "因为monto单调变化最小次数【{}】大于阈值被剔除"),
            ("is_qualified_trend_nsv", "trend_shift_nsv", pd.Series.min,
             "因为trend变量趋势一致性变化最小次数【{}】大于阈值被剔除"),
            ("is_qualified_psi", "psi", pd.Series.min,
             "因为psi【{}】大于阈值被剔除"),
        )
        survivors = df_bins_info
        passed = pd.Series(True, index=df_bins_info.index)
        for flag_column, value_column, aggregate, message in checks:
            passed = passed & (df_bins_info[flag_column] == 1)
            if not passed.any():
                return message.format(aggregate(survivors[value_column]))
            survivors = df_bins_info[passed]

        # Reaching here means some candidate passed all checks, so the caller
        # should not have asked for a drop reason; dump the survivors for debugging.
        print(survivors)
        return "因为【未知原因】被剔除"

    def filter(self) -> Union[BinInfo, None]:
        """Pick the single best binning, or ``None`` if no candidate qualifies.

        Selection is based on iv, psi, monotonicity and trend consistency:
        - Manually specified cut points (``is_auto_bins`` falsy): the first
          candidate is returned as-is.
        - Otherwise the candidates passing all four ``is_qualified_*`` checks
          (every candidate when ``is_include`` is True) are ranked by
          ``_RANK_BY`` and the top one is returned.
        - With no candidates at all, ``None`` is returned.
        """
        # Without candidates the DataFrame below has no rows/columns and the
        # lookups would raise KeyError / IndexError.
        if not self.bins_info:
            return None
        df_bins_info = self.convert_to_df()
        if not self.is_auto_bins:
            return BinInfo.ofConvertByDict(df_bins_info.iloc[0].to_dict())
        if self.is_include:
            candidates = df_bins_info
        else:
            candidates = df_bins_info[
                (df_bins_info["is_qualified_iv_train"] == 1)
                & (df_bins_info["is_qualified_monto_train_nsv"] == 1)
                & (df_bins_info["is_qualified_trend_nsv"] == 1)
                & (df_bins_info["is_qualified_psi"] == 1)
                ]
        if candidates.empty:
            return None
        # Reassign rather than sort_values(inplace=True): in-place mutation of a
        # boolean-indexed sub-frame triggers pandas' SettingWithCopyWarning.
        ranked = candidates.sort_values(by=self._RANK_BY, ascending=self._RANK_ASCENDING)
        return BinInfo.ofConvertByDict(ranked.iloc[0].to_dict())

    def get_best_bins(self) -> List[BinInfo]:
        """Return the best binning(s) for each distinct ``bin_num``, ascending.

        No ``is_qualified_*`` filtering is applied. For each bin count the
        candidate ranked first by ``_RANK_BY`` is taken. The candidates are then
        re-ranked ignoring monotonicity, since variables such as age may
        legitimately have a turning point. That candidate is also appended when
        its ``monto_shift_nsv`` differs from the first pick.
        Returns an empty list when no candidates were added.
        """
        if not self.bins_info:
            return []
        df_bins_info = self.convert_to_df()
        bins_info = []
        for bin_num in sorted(df_bins_info["bin_num"].unique().tolist()):
            group = df_bins_info[df_bins_info["bin_num"] == bin_num]
            ranked = group.sort_values(by=self._RANK_BY, ascending=self._RANK_ASCENDING)
            best = ranked.iloc[0].to_dict()
            bins_info.append(BinInfo.ofConvertByDict(best))

            # Re-rank without the monotonicity criterion (turning points allowed).
            ranked = ranked.sort_values(by=["trend_shift_nsv", "iv", "psi"],
                                        ascending=[True, False, True])
            best_non_monotonic = ranked.iloc[0].to_dict()
            if best["monto_shift_nsv"] != best_non_monotonic["monto_shift_nsv"]:
                bins_info.append(BinInfo.ofConvertByDict(best_non_monotonic))

        return bins_info