# -*- coding: utf-8 -*-
"""
@author: yq
@time: 2025/2/27
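@desc: Online-learning trainer for a WOE-based logistic regression model: fine-tunes the coefficients on new data and generates a comparison report.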
"""
import json
import math
import os
import re
from os.path import dirname, realpath
from typing import Dict, List

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scorecardpy as sc
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

from commom import GeneralException, f_image_crop_white_borders, f_df_to_image, f_display_title, \
    f_display_images_by_side
from entitys import DataSplitEntity, OnlineLearningConfigEntity, MetricFucResultEntity
from enums import ResultCodesEnum, ConstantEnum, ContextEnum, FileEnum
from feature import f_woebin_load
from init import init, context
from model import f_get_model_score_bin, f_calcu_model_ks, f_stress_test, f_calcu_model_psi
from monitor import ReportWord
from .utils import LR

# Project-level initialization from the local `init` module; runs when this module is imported.
init()


class OnlineLearningTrainerLr:
    """Online-learning trainer for a WOE-encoded logistic regression (LR) model.

    Loads the original model coefficients, the WOE bins and an optional scorecard
    config from ``ol_config.path_resources``. It fine-tunes a copy of the
    coefficients on new data (Adam + BCE loss) and writes a Word report that
    compares the original and the optimized model.
    """

    def __init__(self, data: DataSplitEntity = None, ol_config: OnlineLearningConfigEntity = None, *args, **kwargs):
        """
        Args:
            data: train/test split used by ``train`` and ``report``.
            ol_config: online-learning configuration. When ``None`` it is built from
                ``*args`` / ``**kwargs``.
        """
        if ol_config is not None:
            self._ol_config = ol_config
        else:
            self._ol_config = OnlineLearningConfigEntity(*args, **kwargs)
        self._data = data
        self._columns = None
        self._model_original: LR
        self._model_optimized: LR
        self._df_param_optimized = None
        self.sc_woebin = None
        self.card_cfg = None
        self.card = None
        # Report template
        self._template_path = os.path.join(dirname(dirname(realpath(__file__))),
                                           "./template/OnlineLearning报告模板_lr.docx")
        self._init(self._ol_config.path_resources)

    def _init(self, path: str):
        """Load model coefficients, the optional scorecard config and WOE bins from ``path``.

        Raises:
            GeneralException: if ``path`` is not a directory, the coefficient file is
                missing, or a model variable has no WOE bin.
        """
        if not os.path.isdir(path):
            raise GeneralException(ResultCodesEnum.ILLEGAL_PARAMS, message=f"【{path}】不是文件夹")

        path_coef = os.path.join(path, FileEnum.COEF.value)
        if not os.path.isfile(path_coef):
            raise GeneralException(ResultCodesEnum.NOT_FOUND, message=f"模型系数文件【{path_coef}】不存在")
        with open(path_coef, mode="r", encoding="utf-8") as f:
            coef = json.load(f)
        print(f"coef load from【{path_coef}】success.")

        path_card_cfg = os.path.join(path, FileEnum.CARD_CFG.value)
        if os.path.isfile(path_card_cfg):
            with open(path_card_cfg, mode="r", encoding="utf-8") as f:
                self.card_cfg = json.load(f)
            print(f"{FileEnum.CARD_CFG.value} load from【{path_card_cfg}】success.")

        # Sort the names so the weight order is deterministic; `_feature_generate`
        # builds its feature matrix in this same order.
        self._columns = sorted(coef.keys())
        weight = [coef[k] for k in self._columns]
        # Two independent parameter tensors: the original one is kept for comparison,
        # the optimized one is updated by `train`.
        self._model_original = LR(nn.Parameter(torch.tensor(np.array(weight))))
        self._model_optimized = LR(nn.Parameter(torch.tensor(np.array(weight))))

        self._columns = [re.sub('_woe$', '', i) for i in self._columns]
        # Drop the intercept, because the WOE bins contain no intercept variable.
        self._columns_intercept_remove = self._columns.copy()
        if ConstantEnum.INTERCEPT.value in self._columns_intercept_remove:
            self._columns_intercept_remove.remove(ConstantEnum.INTERCEPT.value)
        # WOE-encoded columns carry the `_woe` suffix.
        self._columns_woe = [f"{i}_woe" for i in self._columns]

        self.sc_woebin = f_woebin_load(path)
        for k in self._columns_intercept_remove:
            if k not in self.sc_woebin:
                raise GeneralException(ResultCodesEnum.ILLEGAL_PARAMS, message=f"模型变量【{k}】在woe特征里不存在")

    def _feature_generate(self, data: pd.DataFrame) -> np.ndarray:
        """WOE-encode ``data`` and return the model feature matrix.

        Columns follow ``self._columns_woe`` (the model weight order). The intercept
        column is filled with 1.
        """
        data_woe = sc.woebin_ply(data[self._columns_intercept_remove], self.sc_woebin, print_info=False)
        data_woe[f"{ConstantEnum.INTERCEPT.value}_woe"] = 1
        return data_woe[self._columns_woe].to_numpy()

    def _f_get_model(self, model_type: str) -> LR:
        """Return the optimized model for "新模型", otherwise the original model."""
        return self._model_optimized if model_type == "新模型" else self._model_original

    def _f_get_best_model(self, df_param: pd.DataFrame, epoch: int = None) -> LR:
        """Build an LR model from one row of the optimization history.

        Args:
            df_param: history built by ``train``: the weights followed by the five
                metric columns ``auc_test, ks_test, epoch, loss_train, loss_test``.
            epoch: value of the ``epoch`` column to pick. 0 is the original model and
                1..epochs are the states after each epoch. ``None`` picks the row with
                the highest test KS, with ties broken by test AUC.

        Raises:
            GeneralException: if there is no history (``train`` not called) or the
                requested epoch is absent.
        """
        if df_param is None or len(df_param) == 0:
            raise GeneralException(ResultCodesEnum.NOT_FOUND, message=f"参数优化记录不存在,请先执行train")
        if epoch is None:
            row = df_param.sort_values(by=["ks_test", "auc_test"], ascending=[False, False]).iloc[0]
            print(f"选择最佳参数:\n{row.to_dict()}")
        else:
            df_epoch = df_param[df_param["epoch"] == epoch]
            if df_epoch.empty:
                raise GeneralException(ResultCodesEnum.ILLEGAL_PARAMS, message=f"epoch【{epoch}】不存在")
            row = df_epoch.iloc[0]
            print(f"选择epoch:【{epoch}】的参数:\n{row.to_dict()}")
        # The last 5 columns are metrics, not weights.
        weight = list(row)[:-5]
        return LR(nn.Parameter(torch.tensor(np.array(weight))))

    def _f_get_scorecard(self, ):
        """Build ``self.card`` from the optimized weights with ``scorecardpy.scorecard``.

        scorecardpy expects a sklearn-like object with ``coef_`` and ``intercept_``.
        The intercept weight goes into ``intercept_`` so it reaches the base points.
        Passed as a regular variable, it would have no WOE bins and would be dropped.
        """
        weights = dict(zip(self._columns, self._model_optimized.linear.weight.tolist()))
        intercept = weights.pop(ConstantEnum.INTERCEPT.value, 0)

        class M:
            pass

        m = M()
        m.coef_ = [list(weights.values())]
        m.intercept_ = [intercept]
        self.card = sc.scorecard(self.sc_woebin, m, [f"{c}_woe" for c in weights], **(self.card_cfg or {}))

    def _f_get_metric_auc_ks(self, model_type: str):
        """Compute AUC/KS and save the perf plots on the full, train and test data.

        Args:
            model_type: "新模型" evaluates the optimized model, anything else the original.
        """
        model = self._f_get_model(model_type)

        def _get_auc_ks(data, title):
            y = data[self._ol_config.y_column]
            y_prob = self.prob(data, model)
            perf = sc.perf_eva(y, y_prob, title=f"{title}", show_plot=True)
            path = self._ol_config.f_get_save_path(f"perf_{title}.png")
            perf["pic"].savefig(path)
            f_image_crop_white_borders(path, path)
            return perf["AUC"], perf["KS"], path

        samples = [
            ("建模数据", self._data.data),
            ("训练集", self._data.train_data),
            ("测试集", self._data.test_data),
        ]
        names, aucs, kss, img_path_auc_ks = [], [], [], []
        for name, data in samples:
            auc, ks, path = _get_auc_ks(data, f"{model_type}-{name}")
            names.append(name)
            aucs.append(auc)
            kss.append(ks)
            img_path_auc_ks.append(path)

        df_auc_ks = pd.DataFrame({"样本集": names, "AUC": aucs, "KS": kss})
        return MetricFucResultEntity(table=df_auc_ks, image_path=img_path_auc_ks, image_size=5, table_font_size=10)

    def _f_get_metric_trend(self, ):
        """Re-bin the full data with the saved breaks and save one WOE trend plot per variable."""
        y_column = self._ol_config.y_column
        data = self._data.data

        # Variable trends on the modelling sample, reusing the stored breaks / special values.
        breaks_list = {}
        special_values = {}
        for column, bin_df in self.sc_woebin.items():
            breaks_list[column] = list(bin_df[bin_df["is_special_values"].eq(False)]['breaks'])
            sv = list(bin_df[bin_df["is_special_values"].eq(True)]['breaks'])
            if sv:
                special_values[column] = sv
        woebin = sc.woebin(data[self._columns_intercept_remove + [y_column]], y=y_column, breaks_list=breaks_list,
                           special_values=special_values, print_info=False)

        imgs_path = []
        for k, df_bin in woebin.items():
            sc.woebin_plot(df_bin)
            path = self._ol_config.f_get_save_path(f"trend_{k}.png")
            # woebin_plot draws a new figure; savefig writes the current one.
            plt.savefig(path)
            imgs_path.append(path)
        return MetricFucResultEntity(image_path=imgs_path, image_size=4)

    def _f_get_metric_coef(self, ):
        """Table (and image) comparing original vs optimized WOE coefficients per variable."""
        columns_anns = self._ol_config.columns_anns
        df = pd.DataFrame()
        df["变量"] = self._columns
        df["原变量WOE拟合系数"] = [round(i, 4) for i in self._model_original.linear.weight.tolist()]
        df["新变量WOE拟合系数"] = [round(i, 4) for i in self._model_optimized.linear.weight.tolist()]
        df["释义"] = [columns_anns.get(column, "-") for column in self._columns]
        img_path_coef = self._ol_config.f_get_save_path("coef.png")
        f_df_to_image(df, img_path_coef)
        return MetricFucResultEntity(table=df, image_path=img_path_coef)

    def _f_get_metric_gain(self, model_type: str):
        """Score-bin gain/KS table on the full data for the selected model."""
        y_column = self._ol_config.y_column
        data = self._data.data
        model = self._f_get_model(model_type)

        score = self.prob(data, model)
        score_bin, _ = f_get_model_score_bin(data, score)
        gain = f_calcu_model_ks(score_bin, y_column, sort_ascending=False)
        img_path_gain = self._ol_config.f_get_save_path(f"{model_type}-gain.png")
        f_df_to_image(gain, img_path_gain)

        return MetricFucResultEntity(table=gain, image_path=img_path_gain)

    def _f_get_stress_test(self, ):
        """Stress test of the optimized model's score bins on the full data."""
        stress_sample_times = self._ol_config.stress_sample_times
        stress_bad_rate_list = self._ol_config.stress_bad_rate_list
        y_column = self._ol_config.y_column
        data = self._data.data
        score = self.prob(data, self._model_optimized)
        score_bin, _ = f_get_model_score_bin(data, score)
        df_stress = f_stress_test(score_bin, sample_times=stress_sample_times, bad_rate_list=stress_bad_rate_list,
                                  target_column=y_column, score_column=ConstantEnum.SCORE.value, sort_ascending=False)

        img_path_stress = self._ol_config.f_get_save_path("stress.png")
        f_df_to_image(df_stress, img_path_stress)
        return MetricFucResultEntity(table=df_stress, image_path=img_path_stress)

    def prob(self, x: pd.DataFrame, model=None):
        """Return the model output for the raw rows ``x`` as a numpy array.

        Args:
            x: raw (not WOE-encoded) data containing the model variables.
            model: model to use, defaults to the optimized model.
        """
        if model is None:
            model = self._model_optimized
        model.eval()
        with torch.no_grad():
            features = torch.tensor(self._feature_generate(x), dtype=torch.float64)
            return model(features).detach().numpy()

    def score(self, x: pd.DataFrame) -> np.ndarray:
        """Apply the scorecard built in ``report`` to ``x`` and return the total scores.

        Raises:
            GeneralException: if no scorecard has been built.
        """
        if self.card is None:
            raise GeneralException(ResultCodesEnum.NOT_FOUND, message=f"评分卡不存在")
        return np.array(sc.scorecard_ply(x, self.card, print_step=0)["score"])

    def psi(self, x1: pd.DataFrame, x2: pd.DataFrame, points: List[float] = None) -> pd.DataFrame:
        """PSI of the optimized model output between ``x1`` (reference bins) and ``x2``."""
        y1 = self.prob(x1)
        y2 = self.prob(x2)
        x1_score_bin, score_bins = f_get_model_score_bin(x1, y1, points)
        x2_score_bin, _ = f_get_model_score_bin(x2, y2, score_bins)
        model_psi = f_calcu_model_psi(x1_score_bin, x2_score_bin, sort_ascending=False)
        print(f"模型psi: {model_psi['psi'].sum()}")
        return model_psi

    def train(self, ):
        """Fine-tune ``self._model_optimized`` on the training split.

        Runs mini-batch Adam on BCE loss for ``ol_config.epochs`` epochs. Weights plus
        test AUC/KS/loss are appended to ``self._df_param_optimized`` once before
        training (original model, epoch 0) and after every epoch (epoch 1..epochs).
        ``loss_train`` is the loss of the last batch of the epoch.
        """
        epochs = self._ol_config.epochs
        batch_size = self._ol_config.batch_size
        y_column = self._ol_config.y_column
        train_data = self._data.train_data
        test_data = self._data.test_data
        train_x = self._feature_generate(train_data)
        train_y = train_data[y_column].to_numpy()
        test_x = torch.tensor(self._feature_generate(test_data), dtype=torch.float64)
        test_y = test_data[y_column]
        test_y_tensor = torch.tensor(test_y.to_numpy(), dtype=torch.float64)

        criterion = nn.BCELoss()
        optimizer = optim.Adam(self._model_optimized.parameters(), lr=self._ol_config.lr)

        df_param_columns = self._columns + ["auc_test", "ks_test", "epoch", "loss_train", "loss_test"]
        self._df_param_optimized = pd.DataFrame(columns=df_param_columns)

        def _record(model: LR, epoch: int, loss_train: float):
            # Evaluate `model` on the test set and append one history row.
            model.eval()
            with torch.no_grad():
                y_prob = model(test_x)
                loss_test = criterion(y_prob, test_y_tensor).detach().item()
                perf = sc.perf_eva(test_y, y_prob.detach().numpy(), show_plot=False)
            row = model.linear.weight.tolist() + [perf["AUC"], perf["KS"], epoch + 1, loss_train, loss_test]
            self._df_param_optimized.loc[len(self._df_param_optimized)] = dict(zip(df_param_columns, row))

        # Before optimization: record the original model as epoch 0.
        loss_train = 0
        _record(self._model_original, -1, loss_train)

        n_batches = math.ceil(len(train_x) / batch_size)
        for epoch in tqdm(range(epochs)):
            self._model_optimized.train()
            for i in range(n_batches):
                start, end = i * batch_size, (i + 1) * batch_size
                x_batch = torch.tensor(train_x[start:end], dtype=torch.float64)
                y_batch = torch.tensor(train_y[start:end], dtype=torch.float64)
                optimizer.zero_grad()
                loss = criterion(self._model_optimized(x_batch), y_batch)
                loss.backward()
                optimizer.step()
                loss_train = loss.detach().item()
            # Test-set evaluation after each epoch.
            _record(self._model_optimized, epoch, loss_train)

    def save(self):
        """Save the config, WOE bins, optimized coefficients and, if built, the scorecard.

        Raises:
            GeneralException: if the WOE bins or the optimized model are missing.
        """
        # Validate before writing anything to avoid a partially saved output.
        if self.sc_woebin is None:
            raise GeneralException(ResultCodesEnum.NOT_FOUND, message=f"feature不存在")
        if getattr(self, "_model_optimized", None) is None:
            raise GeneralException(ResultCodesEnum.NOT_FOUND, message=f"模型不存在")

        self._ol_config.config_save()

        df_woebin = pd.concat(self.sc_woebin.values())
        path = self._ol_config.f_get_save_path(FileEnum.FEATURE.value)
        df_woebin.to_csv(path)
        print(f"feature save to【{path}】success. ")

        path = self._ol_config.f_get_save_path(FileEnum.COEF.value)
        coef = dict(zip(self._columns, self._model_optimized.linear.weight.tolist()))
        with open(path, mode="w", encoding="utf-8") as f:
            f.write(json.dumps(coef, ensure_ascii=False))
        print(f"model save to【{path}】success. ")

        if self.card is not None:
            df_card = pd.concat(self.card.values())
            path = self._ol_config.f_get_save_path(FileEnum.CARD.value)
            df_card.to_csv(path)
            print(f"card save to【{path}】success. ")

    @staticmethod
    def load(path: str):
        """Create a trainer from a config saved under ``path``, loading resources from ``path``."""
        ol_config = OnlineLearningConfigEntity.from_config(path)
        ol_config._path_resources = path
        return OnlineLearningTrainerLr(ol_config=ol_config)

    def report(self, epoch: int = None):
        """Select the final optimized model and generate the Word report.

        Args:
            epoch: ``epoch`` value from the training history to use as the final model
                (0 = original model). ``None`` selects the best test KS/AUC.

        Raises:
            GeneralException: if ``train`` has not been run or ``epoch`` is unknown.
        """
        self._model_optimized = self._f_get_best_model(self._df_param_optimized, epoch)

        if self._ol_config.jupyter_print:
            from IPython import display
            f_display_title(display, "模型系数优化过程")
            display.display(self._df_param_optimized)

        metric_value_dict = {}

        # Scorecard
        if self.card_cfg is not None:
            self._f_get_scorecard()
            df_card = pd.concat(self.card.values())
            img_path_card = self._ol_config.f_get_save_path("card.png")
            f_df_to_image(df_card, img_path_card)
            metric_value_dict["评分卡"] = MetricFucResultEntity(table=df_card, image_path=img_path_card)

        # Sample distribution
        metric_value_dict["样本分布"] = MetricFucResultEntity(table=self._data.get_distribution(self._ol_config.y_column),
                                                          table_font_size=10, table_cell_width=3)

        # Model result comparison
        metric_value_dict["模型结果-新模型"] = self._f_get_metric_auc_ks("新模型")
        metric_value_dict["模型结果-原模型"] = self._f_get_metric_auc_ks("原模型")

        # Variable trends
        metric_value_dict["变量趋势-建模数据"] = self._f_get_metric_trend()

        # Coefficient comparison
        metric_value_dict["模型系数"] = self._f_get_metric_coef()

        # Score binning
        metric_value_dict["分数分箱-建模数据-新模型"] = self._f_get_metric_gain("新模型")
        metric_value_dict["分数分箱-建模数据-原模型"] = self._f_get_metric_gain("原模型")

        # Stress test
        if self._ol_config.stress_test:
            metric_value_dict["压力测试"] = self._f_get_stress_test()

        if self._ol_config.jupyter_print:
            self.jupyter_print(metric_value_dict)

        save_path = self._ol_config.f_get_save_path("OnlineLearning报告.docx")
        ReportWord.generate_report(metric_value_dict, self._template_path, save_path=save_path)
        print(f"模型报告文件储存路径:{save_path}")

    def jupyter_print(self, metric_value_dict: Dict[str, MetricFucResultEntity]):
        """Display the report metrics in a Jupyter notebook.

        Args:
            metric_value_dict: metrics keyed by section name, as built in ``report``.
        """
        from IPython import display

        f_display_title(display, "样本分布")
        display.display(metric_value_dict["样本分布"].table)

        f_display_title(display, "模型结果")
        print("原模型")
        display.display(metric_value_dict["模型结果-原模型"].table)
        f_display_images_by_side(display, metric_value_dict["模型结果-原模型"].image_path)
        print("新模型")
        display.display(metric_value_dict["模型结果-新模型"].table)
        f_display_images_by_side(display, metric_value_dict["模型结果-新模型"].image_path)

        f_display_title(display, "模型系数")
        display.display(metric_value_dict["模型系数"].table)

        f_display_title(display, "分数分箱")
        print("建模数据上分数分箱")
        print("原模型")
        display.display(metric_value_dict["分数分箱-建模数据-原模型"].table)
        print("新模型")
        display.display(metric_value_dict["分数分箱-建模数据-新模型"].table)

        f_display_title(display, "变量趋势")
        print("建模数据上变量趋势")
        f_display_images_by_side(display, metric_value_dict["变量趋势-建模数据"].image_path)

        if "压力测试" in metric_value_dict:
            f_display_title(display, "压力测试")
            display.display(metric_value_dict["压力测试"].table)

        # Scorecard
        if "评分卡" in metric_value_dict:
            f_display_title(display, "评分卡")
            display.display(metric_value_dict["评分卡"].table)


# No script entry point: running this file directly does nothing beyond the import-time setup.
if __name__ == "__main__":
    pass