# -*- coding: utf-8 -*-
"""
@author: yq
@time: 2024/11/1
@desc: 数据处理配置类
"""
import json
import os
from typing import List, Union

from commom import GeneralException, f_get_datetime
from config import BaseConfig
from enums import ResultCodesEnum
from init import warning_ignore


class MlConfigEntity:
    """Configuration entity for a data-processing / modelling run.

    Construction resolves the output directory (``base_dir``) under
    ``BaseConfig.train_path``: a sub-directory named after ``project_name``,
    or after ``f_get_datetime()`` when no project name is given. The directory
    is created if missing and its path is printed. The configuration can be
    persisted with :meth:`config_save` and restored with :meth:`from_config`.

    List/dict parameters default to ``None`` and are replaced by a fresh,
    per-instance empty container. Python evaluates default arguments only
    once, so mutable defaults would otherwise be shared by every instance.
    """

    def __init__(self,
                 y_column: str,
                 project_name: str = None,
                 x_columns: List[str] = None,
                 columns_exclude: List[str] = None,
                 columns_include: List[str] = None,
                 columns_anns: dict = None,
                 bin_search_interval: float = 0.05,
                 bin_sample_rate: float = 0.1,
                 iv_threshold: float = 0.01,
                 corr_threshold: float = 0.4,
                 psi_threshold: float = 0.2,
                 vif_threshold: float = 10,
                 monto_shift_threshold=1,
                 trend_shift_threshold=0,
                 max_feature_num: int = 10,
                 special_values: Union[dict, list, str] = None,
                 breaks_list: dict = None,
                 format_bin: bool = False,
                 jupyter_print=False,
                 bin_detail_print=True,
                 stress_test=False,
                 stress_sample_times=100,
                 stress_bad_rate_list: List[float] = None,
                 model_type="lr",
                 feature_strategy="woe",
                 rules: list = None,
                 fill_method: str = None,
                 fill_value=None,
                 *args, **kwargs):
        # NOTE: attribute assignment order is kept stable because config_save
        # serialises self.__dict__ in insertion order.

        # Model type (default "lr").
        self._model_type = model_type

        # Feature strategy (default "woe").
        self._feature_strategy = feature_strategy

        # PSI threshold.
        self._psi_threshold = psi_threshold

        # VIF threshold.
        self._vif_threshold = vif_threshold

        # x columns to exclude.
        self._columns_exclude = self._new_if_none(columns_exclude, list)

        # x columns that must always be kept.
        self._columns_include = self._new_if_none(columns_include, list)

        # Variable annotations (presumably column -> description; confirm).
        self._columns_anns = self._new_if_none(columns_anns, dict)

        # Stress-test switch and settings. Their meaning is defined by the
        # consumers of this config; presumably the number of resampling rounds
        # and the list of bad rates to simulate -- confirm.
        self._stress_test = stress_test
        self._stress_sample_times = stress_sample_times
        self._stress_bad_rate_list = self._new_if_none(stress_bad_rate_list, list)

        # Output switches: Jupyter-oriented printing and per-bin detail output.
        self._jupyter_print = jupyter_print
        self._bin_detail_print = bin_detail_print

        # Number of monotonicity changes allowed.
        self._monto_shift_threshold = monto_shift_threshold

        # Number of trend-consistency changes allowed per variable.
        self._trend_shift_threshold = trend_shift_threshold

        # Whether to enable coarse binning.
        self._format_bin = format_bin

        # Project name; selects the output directory (see _base_dir below).
        self._project_name = project_name

        # Target (y) column.
        self._y_column = y_column

        # Candidate x columns.
        self._x_columns = self._new_if_none(x_columns, list)

        # Missing-value fill method and fill value.
        self._fill_method = fill_method
        self._fill_value = fill_value

        # IV threshold used when screening variables.
        self._iv_threshold = iv_threshold

        # Data granularity of the greedy bin search; expected within 0.01-0.1.
        self._bin_search_interval = bin_search_interval

        # Number of x variables to keep in the end.
        self._max_feature_num = max_feature_num

        # Raw special values / manual bin breaks; normalised on read by the
        # special_values / breaks_list properties and their get_* helpers.
        self._special_values = special_values
        self._breaks_list = breaks_list

        # Correlation threshold between variables.
        self._corr_threshold = corr_threshold

        # Sampling rate for the greedy search (per the original notes, only
        # effective for 4- and 5-bin searches).
        self._bin_sample_rate = bin_sample_rate

        # Score bonus/penalty rules.
        self._rules = self._new_if_none(rules, list)

        if self._project_name:
            self._base_dir = os.path.join(BaseConfig.train_path, self._project_name)
        else:
            # No project name: fall back to a directory named after
            # f_get_datetime() (presumably a timestamp string; confirm).
            self._base_dir = os.path.join(BaseConfig.train_path, f"{f_get_datetime()}")

        # Columns reported by is_include(): the forced includes plus every
        # column that has manual breaks.
        self._include = self._columns_include + list(self.breaks_list.keys())

        os.makedirs(self._base_dir, exist_ok=True)
        print(f"项目路径:【{self._base_dir}】")

        if self._jupyter_print:
            # Presumably silences warnings for cleaner notebook output; confirm.
            warning_ignore()

    @staticmethod
    def _new_if_none(value, factory):
        """Return ``value``, or a fresh ``factory()`` result when it is None."""
        return factory() if value is None else value

    @property
    def model_type(self):
        return self._model_type

    @property
    def feature_strategy(self):
        return self._feature_strategy

    @property
    def psi_threshold(self):
        return self._psi_threshold

    @property
    def vif_threshold(self):
        return self._vif_threshold

    @property
    def stress_test(self):
        return self._stress_test

    @property
    def stress_sample_times(self):
        return self._stress_sample_times

    @property
    def stress_bad_rate_list(self):
        return self._stress_bad_rate_list

    @property
    def jupyter_print(self):
        return self._jupyter_print

    @property
    def bin_detail_print(self):
        return self._bin_detail_print

    @property
    def base_dir(self):
        """Output directory resolved (and created) at construction."""
        return self._base_dir

    @property
    def monto_shift_threshold(self):
        return self._monto_shift_threshold

    @property
    def trend_shift_threshold(self):
        return self._trend_shift_threshold

    @property
    def format_bin(self):
        return self._format_bin

    @property
    def project_name(self):
        return self._project_name

    @property
    def bin_sample_rate(self):
        return self._bin_sample_rate

    @property
    def rules(self):
        return self._rules

    @property
    def corr_threshold(self):
        return self._corr_threshold

    @property
    def max_feature_num(self):
        return self._max_feature_num

    @property
    def y_column(self):
        return self._y_column

    @property
    def x_columns(self):
        return self._x_columns

    @property
    def columns_exclude(self):
        return self._columns_exclude

    @property
    def columns_include(self):
        return self._columns_include

    @property
    def columns_anns(self):
        return self._columns_anns

    @property
    def fill_value(self):
        return self._fill_value

    @property
    def fill_method(self):
        return self._fill_method

    @property
    def iv_threshold(self):
        return self._iv_threshold

    @property
    def bin_search_interval(self):
        return self._bin_search_interval

    @property
    def special_values(self):
        """Normalised special values.

        Returns None when unset or empty, and a single-element list when a
        string was given. A dict or list is returned as-is. Any other type
        gives None.
        """
        if self._special_values is None or len(self._special_values) == 0:
            return None
        if isinstance(self._special_values, str):
            return [self._special_values]
        if isinstance(self._special_values, (dict, list)):
            return self._special_values
        return None

    def get_special_values(self, column: str = None):
        """Return the special values that apply to ``column``.

        A string or list configured globally applies to every column. A dict
        is looked up by ``column``; a missing column, or ``column=None``,
        gives ``[]``. Returns ``[]`` when nothing is configured.
        """
        if self._special_values is None or len(self._special_values) == 0:
            return []
        if isinstance(self._special_values, str):
            return [self._special_values]
        if isinstance(self._special_values, list):
            return self._special_values
        if isinstance(self._special_values, dict) and column is not None:
            return self._special_values.get(column, [])
        return []

    @property
    def breaks_list(self):
        """Manual breaks dict, or ``{}`` when unset or not a dict."""
        if self._breaks_list is None:
            return {}
        if isinstance(self._breaks_list, dict):
            return self._breaks_list
        return {}

    def get_breaks_list(self, column: str = None):
        """Return the manual breaks for ``column``.

        Returns ``[]`` when none are configured for it, or when ``column``
        is None.
        """
        if self._breaks_list is None or len(self._breaks_list) == 0:
            return []
        if isinstance(self._breaks_list, dict) and column is not None:
            return self._breaks_list.get(column, [])
        return []

    def is_include(self, column: str) -> bool:
        """Return True if ``column`` is force-included or has manual breaks.

        The membership list is a snapshot taken at construction.
        """
        return column in self._include

    def f_get_save_path(self, file_name: str) -> str:
        """Return ``file_name`` joined onto the output directory."""
        return os.path.join(self._base_dir, file_name)

    @classmethod
    def from_config(cls, config_path: str):
        """Build an entity from a saved ``mlcfg.json``.

        Args:
            config_path: path to the JSON file, or to a directory that
                contains ``mlcfg.json``.

        Returns:
            A new instance built from the file's key/value pairs. Keys that
            are not constructor parameters are absorbed by ``**kwargs``; this
            includes ``base_dir`` and ``include``, which :meth:`config_save`
            writes. The output directory is recomputed from ``project_name``.

        Raises:
            GeneralException: with ``ResultCodesEnum.NOT_FOUND`` when the
                file does not exist.
        """
        if os.path.isdir(config_path):
            config_path = os.path.join(config_path, "mlcfg.json")

        if not os.path.exists(config_path):
            raise GeneralException(ResultCodesEnum.NOT_FOUND, message=f"指配置文件【{config_path}】不存在")

        with open(config_path, mode="r", encoding="utf-8") as f:
            j = json.load(f)
        print(f"mlcfg load from【{config_path}】success. ")
        return cls(**j)

    def config_save(self):
        """Persist every instance attribute to ``<base_dir>/mlcfg.json``.

        Attribute names are written with their leading underscores stripped.
        Encoding happens before the file is opened. A value that ``json``
        cannot encode therefore raises (e.g. ``TypeError``) and leaves any
        existing config file untouched.
        """
        path = self.f_get_save_path("mlcfg.json")
        payload = {k.lstrip("_"): v for k, v in self.__dict__.items()}
        # Encode first: opening with mode "w" truncates the file immediately.
        text = json.dumps(payload, ensure_ascii=False)
        with open(path, mode="w", encoding="utf-8") as f:
            f.write(text)
        print(f"mlcfg save to【{path}】success. ")


if __name__ == "__main__":
    # This module only defines the config entity; running it directly is a no-op.
    ...