# -*- coding: utf-8 -*-
"""
@author: yq
@time: 2024/11/1
@desc: Data-processing configuration class.
"""
import json
import os
from typing import List, Optional, Union

from commom import GeneralException, f_get_datetime
from config import BaseConfig
from enums import ResultCodesEnum, FileEnum
from init import warning_ignore


class MlConfigEntity():
    """Container for the settings of one modelling project.

    The constructor stores every option on a private attribute, exposes it
    through a read-only property, resolves the project directory under
    ``BaseConfig.train_path`` and creates that directory on disk.
    """

    def __init__(self, y_column: str, project_name: str = None,
                 x_columns: Optional[List[str]] = None,
                 columns_exclude: Optional[List[str]] = None,
                 columns_include: Optional[List[str]] = None,
                 columns_anns: Optional[dict] = None,
                 bin_search_interval: float = 0.05, bin_sample_rate: float = 0.1,
                 iv_threshold: float = 0.01, corr_threshold: float = 0.4,
                 psi_threshold: float = 0.2, vif_threshold: float = 10,
                 monto_shift_threshold=1, trend_shift_threshold=0,
                 max_feature_num: int = 10,
                 special_values: Optional[Union[dict, list, str]] = None,
                 breaks_list: Optional[dict] = None,
                 format_bin: bool = False, jupyter_print=False,
                 bin_detail_print=True, stress_test=False,
                 stress_sample_times=100,
                 stress_bad_rate_list: Optional[List[float]] = None,
                 model_type="lr", feature_strategy="woe", rules=None,
                 fill_method: str = None, fill_value=None, *args, **kwargs):
        """Store the configuration and prepare the project directory.

        Container-valued parameters default to ``None`` and are replaced by a
        fresh empty list/dict per instance, so instances never share state.
        An explicit ``None`` (e.g. a JSON ``null``) is treated the same way.

        Args:
            y_column: Name of the y (target) variable.
            project_name: Project name; determines the project directory.
                When empty or ``None`` the directory is named after
                ``f_get_datetime()`` instead.
            x_columns: Candidate x variables.
            columns_exclude: x columns to exclude.
            columns_include: x columns that are forcibly kept.
            columns_anns: Variable annotations (dict).
            bin_search_interval: Data granularity for the greedy bin search;
                should lie between 0.01 and 0.1.
            bin_sample_rate: Sampling rate for the greedy search; only
                effective for 4-bin and 5-bin searches.
            iv_threshold: Threshold used when filtering variables by IV.
            corr_threshold: Variable correlation threshold.
            psi_threshold: PSI threshold (consumer not visible in this file).
            vif_threshold: VIF threshold (consumer not visible in this file).
            monto_shift_threshold: Allowed number of monotonicity changes.
            trend_shift_threshold: Allowed number of changes in variable
                trend consistency.
            max_feature_num: How many x variables are finally kept.
            special_values: Special values as a str, a list, or a dict keyed
                by column name (see ``get_special_values``).
            breaks_list: Dict keyed by column name (see ``get_breaks_list``);
                its keys are also treated as included columns.
            format_bin: Whether coarse binning is enabled.
            jupyter_print: Jupyter output switch; when truthy
                ``warning_ignore()`` is called.
            bin_detail_print: Presumably toggles printing of binning
                details; confirm against the consumer.
            stress_test: Presumably enables stress testing; confirm against
                the consumer.
            stress_sample_times: Presumably the number of stress-test
                sampling rounds; confirm against the consumer.
            stress_bad_rate_list: Presumably the bad rates used by the
                stress test; confirm against the consumer.
            model_type: Model type identifier (default ``"lr"``).
            feature_strategy: Feature strategy identifier (default ``"woe"``).
            rules: Score add/subtract rules.
            fill_method: Missing-value fill method.
            fill_value: Missing-value fill value.
            *args: Ignored.
            **kwargs: Ignored. Lets ``from_config`` pass the derived keys
                (``base_dir``, ``include``) that ``config_save`` writes.
        """
        self._model_type = model_type
        self._feature_strategy = feature_strategy
        self._psi_threshold = psi_threshold
        self._vif_threshold = vif_threshold
        # x columns to exclude
        self._columns_exclude = [] if columns_exclude is None else columns_exclude
        # x columns that are forcibly kept
        self._columns_include = [] if columns_include is None else columns_include
        # variable annotations
        self._columns_anns = {} if columns_anns is None else columns_anns
        # stress-test related options (semantics defined by the consumer)
        self._stress_test = stress_test
        self._stress_sample_times = stress_sample_times
        self._stress_bad_rate_list = (
            [] if stress_bad_rate_list is None else stress_bad_rate_list
        )
        # output switches
        self._jupyter_print = jupyter_print
        self._bin_detail_print = bin_detail_print
        # allowed number of monotonicity changes
        self._monto_shift_threshold = monto_shift_threshold
        # allowed number of changes in variable trend consistency
        self._trend_shift_threshold = trend_shift_threshold
        # whether coarse binning is enabled
        self._format_bin = format_bin
        # project name, determines the project directory
        self._project_name = project_name
        # y variable
        self._y_column = y_column
        # candidate x variables
        self._x_columns = [] if x_columns is None else x_columns
        # missing-value fill method
        self._fill_method = fill_method
        # missing-value fill value
        self._fill_value = fill_value
        # threshold used when filtering variables by IV
        self._iv_threshold = iv_threshold
        # data granularity of the greedy bin search, should be within 0.01-0.1
        self._bin_search_interval = bin_search_interval
        # how many x variables are finally kept
        self._max_feature_num = max_feature_num
        self._special_values = special_values
        self._breaks_list = breaks_list
        # variable correlation threshold
        self._corr_threshold = corr_threshold
        # greedy-search sampling rate, only effective for 4-bin / 5-bin search
        self._bin_sample_rate = bin_sample_rate
        # score add/subtract rules
        self._rules = [] if rules is None else rules

        # Without a project name, fall back to a datetime-named directory.
        if not self._project_name:
            self._base_dir = os.path.join(BaseConfig.train_path, f"{f_get_datetime()}")
        else:
            self._base_dir = os.path.join(BaseConfig.train_path, self._project_name)

        # Columns with user-supplied breaks count as included as well.
        # Kept as a list (not a set) because config_save() dumps __dict__ to JSON.
        self._include = self._columns_include + list(self.breaks_list.keys())

        os.makedirs(self._base_dir, exist_ok=True)
        print(f"项目路径:【{self._base_dir}】")

        if self._jupyter_print:
            warning_ignore()

    @property
    def model_type(self):
        return self._model_type

    @property
    def feature_strategy(self):
        return self._feature_strategy

    @property
    def psi_threshold(self):
        return self._psi_threshold

    @property
    def vif_threshold(self):
        return self._vif_threshold

    @property
    def stress_test(self):
        return self._stress_test

    @property
    def stress_sample_times(self):
        return self._stress_sample_times

    @property
    def stress_bad_rate_list(self):
        return self._stress_bad_rate_list

    @property
    def jupyter_print(self):
        return self._jupyter_print

    @property
    def bin_detail_print(self):
        return self._bin_detail_print

    @property
    def base_dir(self):
        return self._base_dir

    @property
    def monto_shift_threshold(self):
        return self._monto_shift_threshold

    @property
    def trend_shift_threshold(self):
        return self._trend_shift_threshold

    @property
    def format_bin(self):
        return self._format_bin

    @property
    def project_name(self):
        return self._project_name

    @property
    def bin_sample_rate(self):
        return self._bin_sample_rate

    @property
    def rules(self):
        return self._rules

    @property
    def corr_threshold(self):
        return self._corr_threshold

    @property
    def max_feature_num(self):
        return self._max_feature_num

    @property
    def y_column(self):
        return self._y_column

    @property
    def x_columns(self):
        return self._x_columns

    @property
    def columns_exclude(self):
        return self._columns_exclude

    @property
    def columns_include(self):
        return self._columns_include

    @property
    def columns_anns(self):
        return self._columns_anns

    @property
    def fill_value(self):
        return self._fill_value

    @property
    def fill_method(self):
        return self._fill_method

    @property
    def iv_threshold(self):
        return self._iv_threshold

    @property
    def bin_search_interval(self):
        return self._bin_search_interval

    @property
    def special_values(self):
        """Return the special values in normalised form.

        Returns:
            ``None`` when nothing is configured (``None`` or empty), a
            one-element list when a single string was given, the stored
            object when it is a dict or list, and ``None`` for any other type.
        """
        if self._special_values is None or len(self._special_values) == 0:
            return None
        if isinstance(self._special_values, str):
            return [self._special_values]
        if isinstance(self._special_values, (dict, list)):
            return self._special_values
        return None

    def get_special_values(self, column: str = None):
        """Return the special values that apply to ``column`` as a list.

        A string or list configuration applies to every column; a dict
        configuration is looked up by ``column``. An empty list is returned
        when nothing is configured, when a dict is configured but ``column``
        is ``None`` or missing, or when the stored type is unsupported.
        """
        if self._special_values is None or len(self._special_values) == 0:
            return []
        if isinstance(self._special_values, str):
            return [self._special_values]
        if isinstance(self._special_values, list):
            return self._special_values
        if isinstance(self._special_values, dict) and column is not None:
            return self._special_values.get(column, [])
        return []

    @property
    def breaks_list(self):
        """Return the configured breaks dict, or ``{}`` if unset or not a dict."""
        if isinstance(self._breaks_list, dict):
            return self._breaks_list
        return {}

    def get_breaks_list(self, column: str = None):
        """Return the breaks configured for ``column``.

        Returns an empty list when no breaks are configured, when ``column``
        is ``None`` or has no entry, or when the stored value is not a dict.
        """
        if self._breaks_list is None or len(self._breaks_list) == 0:
            return []
        if isinstance(self._breaks_list, dict) and column is not None:
            return self._breaks_list.get(column, [])
        return []

    def is_include(self, column: str) -> bool:
        """Tell whether ``column`` is in ``columns_include`` or has breaks."""
        return column in self._include

    def f_get_save_path(self, file_name: str) -> str:
        """Return ``file_name`` joined onto the project directory."""
        return os.path.join(self._base_dir, file_name)

    @staticmethod
    def from_config(config_path: str):
        """Build an entity from a saved configuration file.

        Args:
            config_path: Path of the JSON config file, or of a directory that
                contains a file named ``FileEnum.MLCFG.value``.

        Returns:
            A ``MlConfigEntity`` constructed from the JSON object's keys.

        Raises:
            GeneralException: With ``ResultCodesEnum.NOT_FOUND`` when the
                resolved path does not exist.
        """
        if os.path.isdir(config_path):
            config_path = os.path.join(config_path, FileEnum.MLCFG.value)

        if not os.path.exists(config_path):
            raise GeneralException(ResultCodesEnum.NOT_FOUND, message=f"指配置文件【{config_path}】不存在")

        with open(config_path, mode="r", encoding="utf-8") as f:
            j = json.loads(f.read())
        print(f"mlcfg load from【{config_path}】success. ")
        return MlConfigEntity(**j)

    def config_save(self):
        """Write the configuration as JSON into the project directory.

        Every instance attribute is dumped with its leading underscores
        stripped, so the keys match the constructor's parameter names. The
        derived ``base_dir`` and ``include`` entries are written too and are
        absorbed by ``**kwargs`` when the file is loaded again.
        """
        path = self.f_get_save_path(FileEnum.MLCFG.value)
        # Serialise before opening: opening in "w" mode truncates the file, so
        # a non-serialisable value must fail before the old config is lost.
        j = {k.lstrip("_"): v for k, v in self.__dict__.items()}
        j = json.dumps(j, ensure_ascii=False)
        with open(path, mode="w", encoding="utf-8") as f:
            f.write(j)
        print(f"mlcfg save to【{path}】success. ")


if __name__ == "__main__":
    pass