# -*- coding:utf-8 -*-
"""
@author: yq
@time: 2023/12/28
@desc:  特征工具类
"""
from typing import Union

import numpy as np
import pandas as pd
from statsmodels.stats.outliers_influence import variance_inflation_factor as vif

# Standard ("nice") cut-point grids used by ``f_format_bin`` to snap raw split
# values onto human-readable boundaries. Every grid is sorted ascending; the
# dict order matters because ``f_format_bin`` keeps the first grid on ties.
FORMAT_DICT = {
    # Ratio features: -1 .. 1 in steps of 0.1.
    # Built from integers and divided by 10 so every point is the exact double
    # nearest to k/10. ``np.arange(-1, 1.1, 0.1)`` accumulates rounding error
    # (its last point is 0.9999999999999996, not 1.0, so a feature whose 90th
    # percentile is exactly 1.0 would not fit this grid).
    "bin_rate1": np.arange(-10, 11, 1) / 10,

    # Count features 1: 0 .. 10
    "bin_cnt1": np.arange(0, 11, 1),
    # Count features 2: 0 .. 20
    "bin_cnt2": [0, 1, 2, 3, 4, 5, 8, 10, 15, 20],
    # Count features 3: 0 .. 50
    "bin_cnt3": [0, 2, 4, 6, 8, 10, 15, 20, 25, 30, 35, 40, 45, 50],
    # Count features 4: 0 .. 100
    "bin_cnt4": [0, 3, 6, 10, 15, 20, 30, 40, 50, 80, 100],

    # Amount features 1: 0 .. 10k
    "bin_amt1": np.arange(0, 1.1e4, 1e3),
    # Amount features 2: 0 .. 50k
    "bin_amt2": np.arange(0, 5.5e4, 5e3),
    # Amount features 3: 0 .. 100k
    "bin_amt3": np.arange(0, 11e4, 1e4),
    # Amount features 4: 0 .. 200k
    "bin_amt4": [0, 1e4, 2e4, 3e4, 4e4, 5e4, 8e4, 10e4, 15e4, 20e4],
    # Amount features 5: 0 .. 1M
    "bin_amt5": [0, 5e4, 10e4, 15e4, 20e4, 25e4, 30e4, 40e4, 50e4, 100e4],

    # Age
    "bin_age": [20, 25, 30, 35, 40, 45, 50, 55, 60, 65],
}


def f_format_bin(data_describe: pd.Series, raw_v, format_dict: Union[dict, None] = None):
    """Coarse binning: snap a raw split point onto the closest "nice" cut point.

    Grid selection: among the grids in ``format_dict`` whose range covers the
    feature's 10th..90th percentile, the one with the smallest maximum is used
    (the first one in dict order on ties). ``raw_v`` is then clamped to that
    grid's range and replaced by the nearest grid point; a value exactly
    halfway between two points goes to the lower one.

    Args:
        data_describe: Summary statistics of the feature; must provide the
            ``"10%"`` and ``"90%"`` entries.
        raw_v: The raw split value to format.
        format_dict: Optional mapping of grid name -> ascending cut points.
            Defaults to the module-level ``FORMAT_DICT``.

    Returns:
        The snapped cut point, or ``raw_v`` unchanged when no grid covers the
        feature's 10%..90% range (or when ``raw_v`` is not comparable, e.g. NaN).
    """
    if format_dict is None:
        format_dict = FORMAT_DICT

    percent10 = data_describe["10%"]
    percent90 = data_describe["90%"]

    # Pick the tightest grid that still covers the bulk of the distribution.
    best_key, best_max = None, None
    for key, cut_points in format_dict.items():
        grid_min, grid_max = min(cut_points), max(cut_points)
        if percent10 >= grid_min and percent90 <= grid_max:
            # Strict '<' keeps the earliest grid when two share the same max.
            if best_key is None or grid_max < best_max:
                best_key, best_max = key, grid_max

    if best_key is None:
        return raw_v

    cut_points = format_dict[best_key]

    # Values outside the grid are clamped to its end points.
    if raw_v < cut_points[0]:
        return cut_points[0]
    if raw_v > cut_points[-1]:
        return cut_points[-1]

    # Nearest-point rule inside the enclosing interval; ties go to the left.
    for left, right in zip(cut_points[:-1], cut_points[1:]):
        if left <= raw_v <= right:
            return right if (raw_v - left) > (right - raw_v) else left

    return raw_v


def f_monto_shift(badprobs: list) -> int:
    """Count how many times a bad-rate sequence changes trend direction.

    Each step ``badprobs[i] - badprobs[i - 1]`` is classified as up or down.
    A change is counted whenever a step's direction differs from the last
    non-flat step before it. Flat steps (equal neighbours) and NaN steps
    carry no direction and are skipped, so they never trigger or hide a
    change.

    Args:
        badprobs: Ordered per-bin bad rates.

    Returns:
        Number of direction changes; 0 for sequences of length <= 2.
    """
    if len(badprobs) <= 2:
        return 0

    change_cnt = 0
    prev_sign = 0  # direction of the last non-flat step; 0 = not known yet
    for prev_rate, cur_rate in zip(badprobs, badprobs[1:]):
        step = cur_rate - prev_rate
        if step > 0:
            sign = 1
        elif step < 0:
            sign = -1
        else:
            # Flat (or NaN) step: keep the previously established direction.
            continue
        # Only count once a direction exists; a leading flat run must not
        # pin the reference direction at "zero" forever.
        if prev_sign and sign != prev_sign:
            change_cnt += 1
        prev_sign = sign
    return change_cnt


def f_trend_shift(train_badprobs: list, test_badprobs: list) -> int:
    """Count the steps where train and test bad-rate trends disagree.

    Every consecutive step is mapped to +1 (non-decreasing, i.e. diff >= 0)
    or -1 (decreasing). The result is the number of step positions whose
    direction differs between the two sequences.

    Args:
        train_badprobs: Per-bin bad rates on the training sample.
        test_badprobs: Per-bin bad rates on the test sample, same bins.

    Returns:
        Number of disagreeing steps; 0 if the lengths differ or there are
        fewer than two bins.
    """
    n_bins = len(train_badprobs)
    if n_bins != len(test_badprobs) or n_bins < 2:
        return 0

    def _directions(rates):
        # +1 for a flat or rising step, -1 for a falling one.
        steps = np.diff(np.asarray(rates))
        return np.where(steps >= 0, 1, -1)

    mismatched = _directions(train_badprobs) != _directions(test_badprobs)
    return int(np.count_nonzero(mismatched))


def f_get_psi(train_bins, test_bins, epsilon: float = 1e-6):
    """Population Stability Index between a train and a test binning.

    PSI = sum((p_train - p_test) * ln(p_train / p_test)) over bins, where
    each p is a bin's share of the total good + bad count. Bins are aligned
    by index.

    NOTE: as before, both inputs are modified in place. ``count`` and
    ``proportion`` columns are added to them, and the stored ``proportion``
    keeps its raw (unfloored) values.

    Args:
        train_bins: DataFrame with ``good`` and ``bad`` count columns.
        test_bins: DataFrame with ``good`` and ``bad`` count columns.
        epsilon: Lower bound applied to bin shares before the log. An empty
            bin would otherwise make the PSI infinite.

    Returns:
        PSI rounded to 3 decimals.
    """
    for bins in (train_bins, test_bins):
        bins['count'] = bins['good'] + bins['bad']
        bins['proportion'] = bins['count'] / bins['count'].sum()

    # Floor empty bins so log(p/0) / log(0/q) cannot produce +/-inf.
    expected = train_bins['proportion'].clip(lower=epsilon)
    actual = test_bins['proportion'].clip(lower=epsilon)

    psi = (expected - actual) * np.log(expected / actual)
    return psi.sum().round(3)


def f_get_corr(data: pd.DataFrame, meth: str = 'spearman') -> pd.DataFrame:
    """Return the pairwise correlation matrix of ``data``'s columns.

    Args:
        data: Frame whose columns are correlated against each other.
        meth: Method forwarded to ``DataFrame.corr`` (``'pearson'``,
            ``'kendall'`` or ``'spearman'``); defaults to Spearman rank
            correlation.

    Returns:
        Square DataFrame indexed and labelled by ``data``'s columns.
    """
    corr_matrix = data.corr(method=meth)
    return corr_matrix


def f_get_vif(data: pd.DataFrame) -> Union[pd.DataFrame, None]:
    """Variance inflation factor of every column of ``data``.

    NOTE(review): statsmodels' ``variance_inflation_factor`` regresses each
    column on the others exactly as given (no intercept is added). Confirm
    whether callers are expected to include a constant column.

    Args:
        data: Numeric feature frame (presumably WOE-encoded ``*_woe`` columns).

    Returns:
        DataFrame with columns ``变量`` (column name with ``_woe`` removed)
        and ``vif`` (rounded to 3 decimals), one row per column of ``data``.
        Returns None when ``data`` has fewer than two columns.
    """
    if data.shape[1] <= 1:
        return None

    # Materialise the matrix once rather than once per column.
    exog = data.values
    # Address columns by position: ``columns.get_loc`` returns a slice/mask
    # for duplicated names, which ``vif`` cannot use as a column index.
    vif_v = [round(vif(exog, col_idx), 3) for col_idx in range(exog.shape[1])]

    df_vif = pd.DataFrame()
    # str() so non-string column labels do not crash on .replace.
    df_vif["变量"] = [str(column).replace("_woe", "") for column in data.columns]
    df_vif['vif'] = vif_v
    return df_vif