# -*- coding:utf-8 -*-
"""
@author: yq
@time: 2024/1/2
@desc: iv值及单调性筛选类
"""
from itertools import combinations, combinations_with_replacement
from typing import List

import numpy as np
import pandas as pd

from entitys import DataSplitEntity, CandidateFeatureEntity, DataProcessConfigEntity
from .feature_utils import f_judge_monto
from .filter_strategy_base import FilterStrategyBase


class StrategyIv(FilterStrategyBase):
    """Candidate-feature filter based on IV (information value) and monotonicity.

    For each candidate column it greedily searches the 2-5 bin split whose
    bad-rate is monotonic on the train set and whose combined train+test IV
    is highest, then ranks features by that IV.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _f_get_best_bins(self, data: DataSplitEntity, x_column: str):
        """Greedily search the binning of ``x_column`` with the highest
        combined IV (train + test) among monotonic candidates.

        Args:
            data: split dataset; ``data.test_data`` may be ``None`` or empty.
            x_column: name of the feature column to bin.

        Returns:
            Tuple ``(iv_max, breaks_list)``. ``iv_max`` is ``0`` and
            ``breaks_list`` is empty when no candidate binning is monotonic
            on the train set with IV >= the configured threshold.
        """
        from tqdm import tqdm  # function-scope on purpose: keeps tqdm a lazy dependency

        interval = self.data_process_config.bin_search_interval
        iv_threshold = self.data_process_config.iv_threshold
        special_values = self.data_process_config.get_special_values(x_column)
        y_column = self.data_process_config.y_column

        def _n0(x):
            # number of good samples (y == 0) in the group
            return sum(x == 0)

        def _n1(x):
            # number of bad samples (y == 1) in the group
            return sum(x == 1)

        def _f_distribute_balls(balls, boxes):
            # All ways to split `balls` into `boxes` strictly positive parts
            # (stars and bars): choose boxes-1 distinct cut points inside
            # (0, balls) and take the differences between consecutive bounds.
            # Unlike enumerate-and-discard over combinations_with_replacement,
            # every tuple generated here is valid.
            distribute_list = []
            for cuts in combinations(range(1, balls), boxes - 1):
                bounds = (0,) + cuts + (balls,)
                distribute_list.append(
                    [bounds[i + 1] - bounds[i] for i in range(boxes)])
            return distribute_list

        def _get_sv_bins(df, x_column, y_column, special_values):
            # Build one single-value bin (good/bad counts) per special value
            # that actually occurs in df.
            sv_bin_list = []
            for special in special_values:
                # .copy(): the column assignment below would otherwise write
                # to a slice of df and raise SettingWithCopyWarning
                dtm = df[df[x_column] == special].copy()
                if len(dtm) != 0:
                    dtm['bin'] = [str(special)] * len(dtm)
                    binning = dtm.groupby(['bin'], group_keys=False)[y_column].agg(
                        [_n0, _n1]).reset_index().rename(columns={'_n0': 'good', '_n1': 'bad'})
                    binning['is_special_values'] = [True] * len(binning)
                    sv_bin_list.append(binning)
            return sv_bin_list

        def _get_bins(df, x_column, y_column, breaks_list):
            # Cut df[x_column] at breaks_list (left-closed intervals spanning
            # -inf..+inf) and aggregate good/bad counts per bin.
            dtm = pd.DataFrame({'y': df[y_column], 'value': df[x_column]})
            bstbrks = [-np.inf] + breaks_list + [np.inf]
            labels = ['[{},{})'.format(bstbrks[i], bstbrks[i + 1]) for i in range(len(bstbrks) - 1)]
            dtm.loc[:, 'bin'] = pd.cut(dtm['value'], bstbrks, right=False, labels=labels)
            dtm['bin'] = dtm['bin'].astype(str)
            bins = dtm.groupby(['bin'], group_keys=False)['y'].agg([_n0, _n1]) \
                .reset_index().rename(columns={'_n0': 'good', '_n1': 'bad'})
            bins['is_special_values'] = [False] * len(bins)
            return bins

        def _calculation_iv(bins):
            # Total IV of the binning; returns -1 when the bad-rate over the
            # non-special bins is not monotonic (per f_judge_monto).
            bins['count'] = bins['good'] + bins['bad']
            bins['badprob'] = bins['bad'] / bins['count']
            # monotonicity check on regular (non-special) bins only
            bad_prob = bins[bins['is_special_values'] == False]['badprob'].values.tolist()
            if not f_judge_monto(bad_prob):
                return -1
            # IV computation; zero counts are replaced with 0.9 so the log
            # stays finite
            infovalue = pd.DataFrame({'good': bins['good'], 'bad': bins['bad']}) \
                .replace(0, 0.9) \
                .assign(
                DistrBad=lambda x: x.bad / sum(x.bad),
                DistrGood=lambda x: x.good / sum(x.good)
            ) \
                .assign(iv=lambda x: (x.DistrBad - x.DistrGood) * np.log(x.DistrBad / x.DistrGood)) \
                .iv
            bins['bin_iv'] = infovalue
            bins['total_iv'] = bins['bin_iv'].sum()
            iv = bins['total_iv'].values[0]
            return iv

        train_data = data.train_data
        train_data_filter = train_data[~train_data[x_column].isin(special_values)]
        train_data_filter = train_data_filter.sort_values(by=x_column, ascending=True)
        train_data_x = train_data_filter[x_column]

        test_data = data.test_data
        test_data_filter = None
        if test_data is not None and len(test_data) != 0:
            test_data_filter = test_data[~test_data[x_column].isin(special_values)]
            test_data_filter = test_data_filter.sort_values(by=x_column, ascending=True)

        # Build candidate split points, considering 2 to 5 bins; split points
        # are read from the sorted train column at the quantile positions
        # implied by each distribution.
        distribute_list = []
        points_list = []
        for bin_num in list(range(2, 6)):
            distribute_list.extend(_f_distribute_balls(int(1 / interval), bin_num))
        for distribute in distribute_list:
            point_list_cache = []
            point_percentile_list = [sum(distribute[0:idx + 1]) * interval for idx, _ in enumerate(distribute[0:-1])]
            for point_percentile in point_percentile_list:
                point = train_data_x.iloc[int(len(train_data_x) * point_percentile)]
                if point not in point_list_cache:
                    point_list_cache.append(point)
            if point_list_cache not in points_list:
                points_list.append(point_list_cache)
        # Filter candidates by IV and monotonicity.
        iv_max = 0
        breaks_list = []
        train_sv_bin_list = _get_sv_bins(train_data, x_column, y_column, special_values)
        test_sv_bin_list = None
        if test_data_filter is not None:
            test_sv_bin_list = _get_sv_bins(test_data, x_column, y_column, special_values)
        for point_list in tqdm(points_list):
            train_bins = _get_bins(train_data_filter, x_column, y_column, point_list)
            # merge special-value bins before computing IV
            for sv_bin in train_sv_bin_list:
                train_bins = pd.concat((train_bins, sv_bin))
            train_iv = _calculation_iv(train_bins)
            # only the train set is constrained by monotonicity and the IV
            # threshold (non-monotonic train binnings return -1 above)
            if train_iv < iv_threshold:
                continue

            test_iv = 0
            if test_data_filter is not None:
                test_bins = _get_bins(test_data_filter, x_column, y_column, point_list)
                for sv_bin in test_sv_bin_list:
                    test_bins = pd.concat((test_bins, sv_bin))
                test_iv = _calculation_iv(test_bins)

            iv = train_iv + test_iv
            if iv > iv_max:
                iv_max = iv
                breaks_list = point_list

        return iv_max, breaks_list

    def filter(self, data: DataSplitEntity, *args, **kwargs):
        """Score every candidate column and return the top ``candidate_num``
        features sorted by descending IV."""
        x_columns_candidate = self.data_process_config.x_columns_candidate
        candidate_num = self.data_process_config.candidate_num
        candidate_list: List[CandidateFeatureEntity] = []
        for x_column in x_columns_candidate:
            iv_max, breaks_list = self._f_get_best_bins(data, x_column)
            candidate_list.append(CandidateFeatureEntity(x_column, breaks_list, iv_max))
        candidate_list.sort(key=lambda x: x.iv_max, reverse=True)

        return candidate_list[0:candidate_num]