# ml_config_entity.py
# -*- coding: utf-8 -*-
"""
@author: yq
@time: 2024/11/1
@desc: Data-processing configuration class.
"""
import json
import os
from typing import List, Optional, Union

from commom import GeneralException, f_get_datetime
from config import BaseConfig
from enums import ResultCodesEnum, FileEnum, ModelEnum, FeatureStrategyEnum
from init import warning_ignore
  14. class MlConfigEntity():
  15. def __init__(self,
  16. y_column: str,
  17. project_name: str = None,
  18. x_columns: List[str] = [],
  19. columns_exclude: List[str] = [],
  20. columns_include: List[str] = [],
  21. columns_anns: dict = {},
  22. bin_search_interval: float = 0.05,
  23. bin_sample_rate: float = 0.1,
  24. iv_threshold: float = 0.01,
  25. corr_threshold: float = 0.4,
  26. psi_threshold: float = 0.2,
  27. vif_threshold: float = 10,
  28. monto_shift_threshold=1,
  29. trend_shift_threshold=0,
  30. max_feature_num: int = 10,
  31. special_values: Union[dict, list, str] = None,
  32. breaks_list: dict = None,
  33. format_bin: str = False,
  34. jupyter_print=False,
  35. bin_detail_print=True,
  36. stress_test=False,
  37. stress_sample_times=100,
  38. stress_bad_rate_list: List[float] = [],
  39. model_type="lr",
  40. feature_strategy="woe",
  41. params_xgb={},
  42. rules=[],
  43. *args, **kwargs):
  44. self._model_type = model_type
  45. self._feature_strategy = feature_strategy
  46. self._params_xgb = params_xgb
  47. self._psi_threshold = psi_threshold
  48. self._vif_threshold = vif_threshold
  49. # 排除的x列
  50. self._columns_exclude = columns_exclude
  51. # 强制保留的x列
  52. self._columns_include = columns_include
  53. # 变量注释
  54. self._columns_anns = columns_anns
  55. # 是否开启下输出内容
  56. self._stress_test = stress_test
  57. # jupyter下输出内容
  58. self._stress_sample_times = stress_sample_times
  59. # jupyter下输出内容
  60. self._stress_bad_rate_list = stress_bad_rate_list
  61. # jupyter下输出内容
  62. self._jupyter_print = jupyter_print
  63. # jupyter下输出内容
  64. self._bin_detail_print = bin_detail_print
  65. # 单调性允许变化次数
  66. self._monto_shift_threshold = monto_shift_threshold
  67. # 变量趋势一致性允许变化次数
  68. self._trend_shift_threshold = trend_shift_threshold
  69. # 是否启用粗分箱
  70. self._format_bin = format_bin
  71. # 项目名称,和缓存路径有关
  72. self._project_name = project_name
  73. # 定义y变量
  74. self._y_column = y_column
  75. # 候选x变量
  76. self._x_columns = x_columns
  77. # 使用iv筛变量时的阈值
  78. self._iv_threshold = iv_threshold
  79. # 贪婪搜索分箱时数据粒度大小,应该在0.01-0.1之间
  80. self._bin_search_interval = bin_search_interval
  81. # 最终保留多少x变量
  82. self._max_feature_num = max_feature_num
  83. self._special_values = special_values
  84. self._breaks_list = breaks_list
  85. # 变量相关性阈值
  86. self._corr_threshold = corr_threshold
  87. # 贪婪搜索采样比例,只针对4箱5箱时有效
  88. self._bin_sample_rate = bin_sample_rate
  89. # 加减分规则
  90. self._rules = rules
  91. if self._project_name is None or len(self._project_name) == 0:
  92. self._base_dir = os.path.join(BaseConfig.train_path, f"{f_get_datetime()}")
  93. else:
  94. self._base_dir = os.path.join(BaseConfig.train_path, self._project_name)
  95. self._include = columns_include + list(self.breaks_list.keys())
  96. os.makedirs(self._base_dir, exist_ok=True)
  97. print(f"项目路径:【{self._base_dir}】")
  98. if self._jupyter_print:
  99. warning_ignore()
  100. @property
  101. def model_type(self):
  102. return self._model_type
  103. @property
  104. def feature_strategy(self):
  105. if ModelEnum.LR.value == self._model_type:
  106. return FeatureStrategyEnum.WOE.value
  107. if ModelEnum.XGB.value == self._model_type:
  108. return FeatureStrategyEnum.NORM.value
  109. @property
  110. def params_xgb(self):
  111. params = {
  112. 'objective': 'binary:logistic',
  113. 'eval_metric': 'auc',
  114. 'learning_rate': 0.1,
  115. 'max_depth': 3,
  116. 'subsample': None,
  117. 'colsample_bytree': None,
  118. 'alpha': 0,
  119. 'lambda': 1,
  120. 'num_boost_round': 100,
  121. 'early_stopping_rounds': 20,
  122. 'verbose_eval': 10,
  123. 'random_state': 2025,
  124. 'save_pmml': True,
  125. 'trees_print': False,
  126. }
  127. params.update(self._params_xgb)
  128. return params
  129. @property
  130. def psi_threshold(self):
  131. return self._psi_threshold
  132. @property
  133. def vif_threshold(self):
  134. return self._vif_threshold
  135. @property
  136. def stress_test(self):
  137. return self._stress_test
  138. @property
  139. def stress_sample_times(self):
  140. return self._stress_sample_times
  141. @property
  142. def stress_bad_rate_list(self):
  143. return self._stress_bad_rate_list
  144. @property
  145. def jupyter_print(self):
  146. return self._jupyter_print
  147. @property
  148. def bin_detail_print(self):
  149. return self._bin_detail_print
  150. @property
  151. def base_dir(self):
  152. return self._base_dir
  153. @property
  154. def monto_shift_threshold(self):
  155. return self._monto_shift_threshold
  156. @property
  157. def trend_shift_threshold(self):
  158. return self._trend_shift_threshold
  159. @property
  160. def format_bin(self):
  161. return self._format_bin
  162. @property
  163. def project_name(self):
  164. return self._project_name
  165. @property
  166. def bin_sample_rate(self):
  167. return self._bin_sample_rate
  168. @property
  169. def rules(self):
  170. return self._rules
  171. @property
  172. def corr_threshold(self):
  173. return self._corr_threshold
  174. @property
  175. def max_feature_num(self):
  176. return self._max_feature_num
  177. @property
  178. def y_column(self):
  179. return self._y_column
  180. @property
  181. def x_columns(self):
  182. return self._x_columns
  183. @property
  184. def columns_exclude(self):
  185. return self._columns_exclude
  186. @property
  187. def columns_include(self):
  188. return self._columns_include
  189. @property
  190. def columns_anns(self):
  191. return self._columns_anns
  192. @property
  193. def iv_threshold(self):
  194. return self._iv_threshold
  195. @property
  196. def bin_search_interval(self):
  197. return self._bin_search_interval
  198. @property
  199. def special_values(self):
  200. if self._special_values is None or len(self._special_values) == 0:
  201. return None
  202. if isinstance(self._special_values, str):
  203. return [self._special_values]
  204. if isinstance(self._special_values, (dict, list)):
  205. return self._special_values
  206. return None
  207. def get_special_values(self, column: str = None):
  208. if self._special_values is None or len(self._special_values) == 0:
  209. return []
  210. if isinstance(self._special_values, str):
  211. return [self._special_values]
  212. if isinstance(self._special_values, list):
  213. return self._special_values
  214. if isinstance(self._special_values, dict) and column is not None:
  215. return self._special_values.get(column, [])
  216. return []
  217. @property
  218. def breaks_list(self):
  219. if self._breaks_list is None:
  220. return {}
  221. if isinstance(self._breaks_list, dict):
  222. return self._breaks_list
  223. return {}
  224. def get_breaks_list(self, column: str = None):
  225. if self._breaks_list is None or len(self._breaks_list) == 0:
  226. return []
  227. if isinstance(self._breaks_list, dict) and column is not None:
  228. return self._breaks_list.get(column, [])
  229. return []
  230. def is_include(self, column: str) -> bool:
  231. return column in self._include
  232. def f_get_save_path(self, file_name: str) -> str:
  233. path = os.path.join(self._base_dir, file_name)
  234. return path
  235. @staticmethod
  236. def from_config(config_path: str):
  237. """
  238. 从配置文件生成实体类
  239. """
  240. if os.path.isdir(config_path):
  241. config_path = os.path.join(config_path, FileEnum.ML_CFG.value)
  242. if os.path.exists(config_path):
  243. with open(config_path, mode="r", encoding="utf-8") as f:
  244. j = json.loads(f.read())
  245. else:
  246. raise GeneralException(ResultCodesEnum.NOT_FOUND, message=f"指配置文件【{config_path}】不存在")
  247. print(f"mlcfg load from【{config_path}】success. ")
  248. return MlConfigEntity(**j)
  249. def config_save(self):
  250. path = self.f_get_save_path(FileEnum.ML_CFG.value)
  251. with open(path, mode="w", encoding="utf-8") as f:
  252. j = {k.lstrip("_"): v for k, v in self.__dict__.items()}
  253. j = json.dumps(j, ensure_ascii=False)
  254. f.write(j)
  255. print(f"mlcfg save to【{path}】success. ")
  256. if __name__ == "__main__":
  257. pass