# -*- coding: utf-8 -*-
"""
@author: yq
@time: 2024/11/1
@desc: 数据处理配置类
"""
import json
import os
from typing import List, Optional, Union

from commom import GeneralException, f_get_datetime
from config import BaseConfig
from enums import ResultCodesEnum


class DataProcessConfigEntity:
    """Read-only configuration holder for the data-processing stage.

    Every constructor argument is stored on a private attribute and exposed
    through a property. On construction the entity also resolves its output
    directory (``base_dir``) under ``BaseConfig.train_path`` and creates it
    if it does not exist yet.
    """

    def __init__(self, y_column: str, x_columns_candidate: Optional[List[str]] = None,
                 fill_method: Optional[str] = None, fill_value=None, split_method: Optional[str] = None,
                 feature_search_strategy: str = 'iv', bin_search_interval: float = 0.05,
                 iv_threshold: float = 0.03, iv_threshold_wide: float = 0.05, corr_threshold: float = 0.4,
                 sample_rate: float = 0.1, x_candidate_num: int = 10,
                 special_values: Optional[Union[dict, list, str]] = None,
                 project_name: Optional[str] = None, format_bin: bool = False,
                 breaks_list: Optional[dict] = None, pos_neg_cnt: int = 1,
                 jupyter: bool = False, *args, **kwargs):
        """Store the configuration values and prepare the output directory.

        Args:
            y_column: Name of the target (y) column.
            x_columns_candidate: Candidate feature (x) columns.
            fill_method: Method used to fill missing values.
            fill_value: Value used to fill missing values.
            split_method: Method used to split the data.
            feature_search_strategy: Strategy used to search for the best features.
            bin_search_interval: Data granularity for the greedy bin search;
                expected to lie between 0.01 and 0.1.
            iv_threshold: IV threshold used when selecting variables.
            iv_threshold_wide: IV threshold used for the coarse pre-selection.
            corr_threshold: Correlation threshold between variables.
            sample_rate: Sampling ratio for the greedy search; only effective
                for 4-bin and 5-bin searches.
            x_candidate_num: Number of x variables kept in the end.
            special_values: A single string, a list, or a dict keyed by column
                name (see ``get_special_values``).
            project_name: Project name; determines the output directory.
            format_bin: Whether coarse binning is enabled.
            breaks_list: Dict keyed by column name (see ``get_breaks_list``).
            pos_neg_cnt: Number of allowed changes in monotonicity.
            jupyter: Flag stored as-is.
            *args: Ignored.
            **kwargs: Ignored; lets ``from_config`` pass a JSON object that
                carries keys this class does not know about.
        """
        # NOTE(review): the original comment here was a copy of the
        # pos_neg_cnt one; presumably this flags notebook execution — confirm.
        self._jupyter = jupyter

        # Number of allowed changes in monotonicity.
        self._pos_neg_cnt = pos_neg_cnt

        # Whether coarse binning is enabled.
        self._format_bin = format_bin

        # Project name; it determines the cache/output path below.
        self._project_name = project_name

        # Target (y) variable.
        self._y_column = y_column

        # Candidate x variables.
        self._x_columns_candidate = x_columns_candidate

        # Missing-value fill method.
        self._fill_method = fill_method

        # Missing-value fill value.
        self._fill_value = fill_value

        # Data split method.
        self._split_method = split_method

        # Best-feature search strategy.
        self._feature_search_strategy = feature_search_strategy

        # Threshold used when filtering variables by IV.
        self._iv_threshold = iv_threshold

        # Threshold used for the coarse IV filter.
        self._iv_threshold_wide = iv_threshold_wide

        # Data granularity of the greedy bin search; should be within 0.01-0.1.
        self._bin_search_interval = bin_search_interval

        # How many x variables are kept in the end.
        self._x_candidate_num = x_candidate_num

        self._special_values = special_values

        self._breaks_list = breaks_list

        # Correlation threshold between variables.
        self._corr_threshold = corr_threshold

        # Sampling ratio of the greedy search; only effective for 4 and 5 bins.
        self._sample_rate = sample_rate

        # Without a project name, fall back to a directory named after
        # whatever f_get_datetime() returns.
        if not self._project_name:
            dir_name = f"{f_get_datetime()}"
        else:
            dir_name = self._project_name
        self._base_dir = os.path.join(BaseConfig.train_path, dir_name)

        os.makedirs(self._base_dir, exist_ok=True)

    @property
    def jupyter(self):
        return self._jupyter

    @property
    def base_dir(self):
        """Directory that ``f_get_save_path`` resolves file names against."""
        return self._base_dir

    @property
    def pos_neg_cnt(self):
        return self._pos_neg_cnt

    @property
    def format_bin(self):
        return self._format_bin

    @property
    def project_name(self):
        return self._project_name

    @property
    def sample_rate(self):
        return self._sample_rate

    @property
    def corr_threshold(self):
        return self._corr_threshold

    @property
    def iv_threshold_wide(self):
        return self._iv_threshold_wide

    @property
    def candidate_num(self):
        """Value passed to the constructor as ``x_candidate_num``."""
        return self._x_candidate_num

    @property
    def y_column(self):
        return self._y_column

    @property
    def x_columns_candidate(self):
        return self._x_columns_candidate

    @property
    def fill_value(self):
        return self._fill_value

    @property
    def fill_method(self):
        return self._fill_method

    @property
    def split_method(self):
        return self._split_method

    @property
    def feature_search_strategy(self):
        return self._feature_search_strategy

    @property
    def iv_threshold(self):
        return self._iv_threshold

    @property
    def bin_search_interval(self):
        return self._bin_search_interval

    @property
    def special_values(self):
        """Return the configured special values in normalised form.

        A string is wrapped into a one-element list, a dict or list is
        returned unchanged, and ``None``, an empty value or any other type
        yields ``None``.
        """
        values = self._special_values
        if values is None or len(values) == 0:
            return None
        if isinstance(values, str):
            return [values]
        if isinstance(values, (dict, list)):
            return values
        return None

    def get_special_values(self, column: Optional[str] = None):
        """Return the special values that apply to ``column`` as a list.

        A string or list configuration applies to every column. A dict
        configuration is looked up by ``column`` and gives ``[]`` when the
        column is missing or ``column`` is ``None``.
        """
        values = self._special_values
        if values is None or len(values) == 0:
            return []
        if isinstance(values, str):
            return [values]
        if isinstance(values, list):
            return values
        if isinstance(values, dict) and column is not None:
            return values.get(column, [])
        return []

    @property
    def breaks_list(self):
        """Return the configured breaks dict, or ``{}`` if it is not a dict."""
        if isinstance(self._breaks_list, dict):
            return self._breaks_list
        return {}

    def get_breaks_list(self, column: Optional[str] = None):
        """Return the breaks configured for ``column``, or ``[]`` if none."""
        # Type check first, so a non-dict value reaches the [] fallback
        # instead of failing inside len().
        if isinstance(self._breaks_list, dict) and column is not None:
            return self._breaks_list.get(column, [])
        return []

    def f_get_save_path(self, file_name: str) -> str:
        """Join ``file_name`` onto ``base_dir`` and return the resulting path."""
        return os.path.join(self._base_dir, file_name)

    @classmethod
    def from_config(cls, config_path: str) -> "DataProcessConfigEntity":
        """Build an entity from a JSON configuration file.

        The top-level JSON object is expanded into constructor keyword
        arguments. Because this is a classmethod, calling it on a subclass
        returns an instance of that subclass.

        Raises:
            GeneralException: With ``ResultCodesEnum.NOT_FOUND`` when
                ``config_path`` does not exist.
        """
        if not os.path.exists(config_path):
            raise GeneralException(ResultCodesEnum.NOT_FOUND, message=f"指配置文件【{config_path}】不存在")

        with open(config_path, mode="r", encoding="utf-8") as f:
            options = json.load(f)

        return cls(**options)


# Script entry point: intentionally a no-op, so running this module directly
# does nothing beyond defining the class above.
if __name__ == "__main__":
    pass