123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318 |
- # -*- coding: utf-8 -*-
- """
- @author: yq
- @time: 2024/11/1
- @desc: 数据处理配置类
- """
- import json
- import os
- from typing import List, Union
- from commom import GeneralException, f_get_datetime
- from config import BaseConfig
- from enums import ResultCodesEnum
- from init import warning_ignore
class MlConfigEntity:
    """Configuration entity for the data-processing / modelling pipeline.

    The constructor stores every option on a private attribute, exposes it
    through a read-only property, resolves the project directory under
    ``BaseConfig.train_path`` and creates that directory on disk.

    The entity can be persisted with :meth:`config_save` and rebuilt with
    :meth:`from_config`; both use a ``mlcfg.json`` file.
    """

    def __init__(self,
                 y_column: str,
                 project_name: str = None,
                 x_columns: List[str] = None,
                 columns_exclude: List[str] = None,
                 columns_include: List[str] = None,
                 columns_anns: dict = None,
                 bin_search_interval: float = 0.05,
                 bin_sample_rate: float = 0.1,
                 iv_threshold: float = 0.01,
                 corr_threshold: float = 0.4,
                 psi_threshold: float = 0.2,
                 vif_threshold: float = 10,
                 monto_shift_threshold=1,
                 trend_shift_threshold=0,
                 max_feature_num: int = 10,
                 special_values: Union[dict, list, str] = None,
                 breaks_list: dict = None,
                 format_bin: bool = False,
                 jupyter_print=False,
                 bin_detail_print=True,
                 stress_test=False,
                 stress_sample_times=100,
                 stress_bad_rate_list: List[float] = None,
                 model_type="lr",
                 feature_strategy="woe",
                 rules=None,
                 fill_method: str = None,
                 fill_value=None,
                 *args, **kwargs):
        """Store the options and prepare the project directory.

        Collection-typed parameters default to ``None`` and are replaced by
        a fresh empty list/dict per instance, so instances never share a
        mutable default. Extra positional/keyword arguments are accepted and
        ignored; this lets :meth:`from_config` feed back a file written by
        :meth:`config_save`, which also contains derived keys.

        Side effects: creates the project directory, prints its path, and
        calls ``warning_ignore()`` when ``jupyter_print`` is truthy.
        """
        # Replace None (the default, or a JSON null) by a fresh container.
        x_columns = [] if x_columns is None else x_columns
        columns_exclude = [] if columns_exclude is None else columns_exclude
        columns_include = [] if columns_include is None else columns_include
        columns_anns = {} if columns_anns is None else columns_anns
        stress_bad_rate_list = [] if stress_bad_rate_list is None else stress_bad_rate_list
        rules = [] if rules is None else rules

        # NOTE: the assignment order below defines the key order of the JSON
        # written by config_save(); keep it stable.
        self._model_type = model_type
        self._feature_strategy = feature_strategy
        self._psi_threshold = psi_threshold
        self._vif_threshold = vif_threshold
        # x columns to exclude
        self._columns_exclude = columns_exclude
        # x columns that are always kept
        self._columns_include = columns_include
        # variable annotations
        self._columns_anns = columns_anns
        # Stress-test settings; only stored here. Presumably: on/off switch,
        # number of resampling rounds, target bad rates - confirm with the
        # code that consumes them.
        self._stress_test = stress_test
        self._stress_sample_times = stress_sample_times
        self._stress_bad_rate_list = stress_bad_rate_list
        # output content when running under jupyter
        self._jupyter_print = jupyter_print
        # presumably toggles printing of per-bin details - confirm with consumer
        self._bin_detail_print = bin_detail_print
        # allowed number of monotonicity changes
        self._monto_shift_threshold = monto_shift_threshold
        # allowed number of changes for variable trend consistency
        self._trend_shift_threshold = trend_shift_threshold
        # whether coarse binning is enabled
        self._format_bin = format_bin
        # project name; determines the cache/output path
        self._project_name = project_name
        # the y (target) column
        self._y_column = y_column
        # candidate x columns
        self._x_columns = x_columns
        # method used to fill missing values
        self._fill_method = fill_method
        # value used to fill missing values
        self._fill_value = fill_value
        # threshold used when screening variables by IV
        self._iv_threshold = iv_threshold
        # data granularity for greedy-search binning, should be within 0.01-0.1
        self._bin_search_interval = bin_search_interval
        # how many x variables are finally kept
        self._max_feature_num = max_feature_num
        self._special_values = special_values
        self._breaks_list = breaks_list
        # variable correlation threshold
        self._corr_threshold = corr_threshold
        # sampling ratio for greedy search; only effective for 4 and 5 bins
        self._bin_sample_rate = bin_sample_rate
        # score add/subtract rules
        self._rules = rules

        # Without a project name, fall back to a timestamp-named directory.
        if not self._project_name:
            self._base_dir = os.path.join(BaseConfig.train_path, f"{f_get_datetime()}")
        else:
            self._base_dir = os.path.join(BaseConfig.train_path, self._project_name)

        # Columns with user-supplied breaks are treated as included as well.
        self._include = list(columns_include) + list(self.breaks_list.keys())

        os.makedirs(self._base_dir, exist_ok=True)
        print(f"项目路径:【{self._base_dir}】")

        if self._jupyter_print:
            warning_ignore()

    @property
    def model_type(self):
        """Configured model type (default ``"lr"``)."""
        return self._model_type

    @property
    def feature_strategy(self):
        """Configured feature strategy (default ``"woe"``)."""
        return self._feature_strategy

    @property
    def psi_threshold(self):
        """Configured PSI threshold."""
        return self._psi_threshold

    @property
    def vif_threshold(self):
        """Configured VIF threshold."""
        return self._vif_threshold

    @property
    def stress_test(self):
        """Stored ``stress_test`` option."""
        return self._stress_test

    @property
    def stress_sample_times(self):
        """Stored ``stress_sample_times`` option."""
        return self._stress_sample_times

    @property
    def stress_bad_rate_list(self):
        """Stored ``stress_bad_rate_list`` option."""
        return self._stress_bad_rate_list

    @property
    def jupyter_print(self):
        """Stored ``jupyter_print`` option."""
        return self._jupyter_print

    @property
    def bin_detail_print(self):
        """Stored ``bin_detail_print`` option."""
        return self._bin_detail_print

    @property
    def base_dir(self):
        """Project directory resolved (and created) by the constructor."""
        return self._base_dir

    @property
    def monto_shift_threshold(self):
        """Allowed number of monotonicity changes."""
        return self._monto_shift_threshold

    @property
    def trend_shift_threshold(self):
        """Allowed number of trend-consistency changes."""
        return self._trend_shift_threshold

    @property
    def format_bin(self):
        """Whether coarse binning is enabled."""
        return self._format_bin

    @property
    def project_name(self):
        """Project name as passed to the constructor (may be ``None``)."""
        return self._project_name

    @property
    def bin_sample_rate(self):
        """Sampling ratio used by the greedy bin search."""
        return self._bin_sample_rate

    @property
    def rules(self):
        """Score add/subtract rules."""
        return self._rules

    @property
    def corr_threshold(self):
        """Variable correlation threshold."""
        return self._corr_threshold

    @property
    def max_feature_num(self):
        """Maximum number of x variables finally kept."""
        return self._max_feature_num

    @property
    def y_column(self):
        """Name of the y (target) column."""
        return self._y_column

    @property
    def x_columns(self):
        """Candidate x columns."""
        return self._x_columns

    @property
    def columns_exclude(self):
        """x columns to exclude."""
        return self._columns_exclude

    @property
    def columns_include(self):
        """x columns that are always kept."""
        return self._columns_include

    @property
    def columns_anns(self):
        """Variable annotations."""
        return self._columns_anns

    @property
    def fill_value(self):
        """Value used to fill missing values."""
        return self._fill_value

    @property
    def fill_method(self):
        """Method used to fill missing values."""
        return self._fill_method

    @property
    def iv_threshold(self):
        """IV screening threshold."""
        return self._iv_threshold

    @property
    def bin_search_interval(self):
        """Data granularity of the greedy bin search."""
        return self._bin_search_interval

    @property
    def special_values(self):
        """Normalised special values.

        Returns ``None`` when nothing is configured, a one-element list when
        a single string was given, and the stored object unchanged when it
        is a dict or a list. Any other type yields ``None``.
        """
        values = self._special_values
        if not values:
            return None
        if isinstance(values, str):
            return [values]
        if isinstance(values, (dict, list)):
            return values
        return None

    def get_special_values(self, column: str = None):
        """Return the special values that apply to ``column`` as a list.

        A string or list configuration applies to every column. A dict
        configuration is looked up by ``column`` and gives ``[]`` when the
        column is missing or ``column`` is ``None``.
        """
        values = self._special_values
        if not values:
            return []
        if isinstance(values, str):
            return [values]
        if isinstance(values, list):
            return values
        if isinstance(values, dict) and column is not None:
            return values.get(column, [])
        return []

    @property
    def breaks_list(self):
        """User-supplied breaks as a dict; ``{}`` when unset or not a dict."""
        if isinstance(self._breaks_list, dict):
            return self._breaks_list
        return {}

    def get_breaks_list(self, column: str = None):
        """Return the breaks configured for ``column``, or ``[]`` if none."""
        if isinstance(self._breaks_list, dict) and column is not None:
            return self._breaks_list.get(column, [])
        return []

    def is_include(self, column: str) -> bool:
        """Tell whether ``column`` is force-included or has explicit breaks."""
        return column in self._include

    def f_get_save_path(self, file_name: str) -> str:
        """Return the path of ``file_name`` inside the project directory."""
        return os.path.join(self._base_dir, file_name)

    @classmethod
    def from_config(cls, config_path: str):
        """Build an entity from a JSON configuration file.

        ``config_path`` may be the file itself or a directory containing
        ``mlcfg.json``.

        Raises:
            GeneralException: with ``ResultCodesEnum.NOT_FOUND`` when the
                file does not exist.
        """
        if os.path.isdir(config_path):
            config_path = os.path.join(config_path, "mlcfg.json")

        if not os.path.exists(config_path):
            raise GeneralException(ResultCodesEnum.NOT_FOUND, message=f"指配置文件【{config_path}】不存在")

        with open(config_path, mode="r", encoding="utf-8") as f:
            options = json.load(f)
        print(f"mlcfg load to【{config_path}】success. ")
        return cls(**options)

    def config_save(self):
        """Write every instance attribute to ``mlcfg.json`` in the project dir.

        Leading underscores are stripped from the attribute names so the keys
        match the constructor parameters; the derived ``base_dir`` and
        ``include`` entries are written too.
        """
        path = self.f_get_save_path("mlcfg.json")
        # Serialise before opening the file so that a serialisation error
        # cannot leave a truncated, empty config behind.
        payload = json.dumps(
            {name.lstrip("_"): value for name, value in self.__dict__.items()},
            ensure_ascii=False,
        )
        with open(path, mode="w", encoding="utf-8") as f:
            f.write(payload)
        print(f"mlcfg save to【{path}】success. ")
- if __name__ == "__main__":
- pass
|