data_process_config_entity.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236
  1. # -*- coding: utf-8 -*-
  2. """
  3. @author: yq
  4. @time: 2024/11/1
  5. @desc: 数据处理配置类
  6. """
  7. import json
  8. import os
  9. from typing import List, Union
  10. from commom import GeneralException, f_get_datetime
  11. from config import BaseConfig
  12. from enums import ResultCodesEnum
  13. class DataProcessConfigEntity():
  14. def __init__(self, y_column: str, x_columns_candidate: List[str] = None, fill_method: str = None, fill_value=None,
  15. split_method: str = None, feature_search_strategy: str = 'iv', bin_search_interval: float = 0.05,
  16. iv_threshold: float = 0.03, iv_threshold_wide: float = 0.05, corr_threshold: float = 0.4,
  17. sample_rate: float = 0.1, x_candidate_num: int = 10, special_values: Union[dict, list, str] = None,
  18. project_name: str = None, format_bin: str = False, breaks_list: dict = None, pos_neg_cnt=1,
  19. monto_contrast_change_cnt=0, jupyter=False, strees=False, strees_sample_times=100,
  20. strees_bad_rate_list: List[float] = [], *args, **kwargs):
  21. # 是否开启下输出内容
  22. self._strees = strees
  23. # jupyter下输出内容
  24. self._strees_sample_times = strees_sample_times
  25. # jupyter下输出内容
  26. self._strees_bad_rate_list = strees_bad_rate_list
  27. # jupyter下输出内容
  28. self._jupyter = jupyter
  29. # 单调性允许变化次数
  30. self._pos_neg_cnt = pos_neg_cnt
  31. # 变量趋势一致性允许变化次数
  32. self._monto_contrast_change_cnt = monto_contrast_change_cnt
  33. # 是否启用粗分箱
  34. self._format_bin = format_bin
  35. # 项目名称,和缓存路径有关
  36. self._project_name = project_name
  37. # 定义y变量
  38. self._y_column = y_column
  39. # 候选x变量
  40. self._x_columns_candidate = x_columns_candidate
  41. # 缺失值填充方法
  42. self._fill_method = fill_method
  43. # 缺失值填充值
  44. self._fill_value = fill_value
  45. # 数据划分方法
  46. self._split_method = split_method
  47. # 最优特征搜索方法
  48. self._feature_search_strategy = feature_search_strategy
  49. # 使用iv筛变量时的阈值
  50. self._iv_threshold = iv_threshold
  51. # 使用iv粗筛变量时的阈值
  52. self._iv_threshold_wide = iv_threshold_wide
  53. # 贪婪搜索分箱时数据粒度大小,应该在0.01-0.1之间
  54. self._bin_search_interval = bin_search_interval
  55. # 最终保留多少x变量
  56. self._x_candidate_num = x_candidate_num
  57. self._special_values = special_values
  58. self._breaks_list = breaks_list
  59. # 变量相关性阈值
  60. self._corr_threshold = corr_threshold
  61. # 贪婪搜索采样比例,只针对4箱5箱时有效
  62. self._sample_rate = sample_rate
  63. if self._project_name is None or len(self._project_name) == 0:
  64. self._base_dir = os.path.join(BaseConfig.train_path, f"{f_get_datetime()}")
  65. else:
  66. self._base_dir = os.path.join(BaseConfig.train_path, self._project_name)
  67. os.makedirs(self._base_dir, exist_ok=True)
  68. @property
  69. def strees(self):
  70. return self._strees
  71. @property
  72. def strees_sample_times(self):
  73. return self._strees_sample_times
  74. @property
  75. def strees_bad_rate_list(self):
  76. return self._strees_bad_rate_list
  77. @property
  78. def jupyter(self):
  79. return self._jupyter
  80. @property
  81. def base_dir(self):
  82. return self._base_dir
  83. @property
  84. def pos_neg_cnt(self):
  85. return self._pos_neg_cnt
  86. @property
  87. def monto_contrast_change_cnt(self):
  88. return self._monto_contrast_change_cnt
  89. @property
  90. def format_bin(self):
  91. return self._format_bin
  92. @property
  93. def project_name(self):
  94. return self._project_name
  95. @property
  96. def sample_rate(self):
  97. return self._sample_rate
  98. @property
  99. def corr_threshold(self):
  100. return self._corr_threshold
  101. @property
  102. def iv_threshold_wide(self):
  103. return self._iv_threshold_wide
  104. @property
  105. def candidate_num(self):
  106. return self._x_candidate_num
  107. @property
  108. def y_column(self):
  109. return self._y_column
  110. @property
  111. def x_columns_candidate(self):
  112. return self._x_columns_candidate
  113. @property
  114. def fill_value(self):
  115. return self._fill_value
  116. @property
  117. def fill_method(self):
  118. return self._fill_method
  119. @property
  120. def split_method(self):
  121. return self._split_method
  122. @property
  123. def feature_search_strategy(self):
  124. return self._feature_search_strategy
  125. @property
  126. def iv_threshold(self):
  127. return self._iv_threshold
  128. @property
  129. def bin_search_interval(self):
  130. return self._bin_search_interval
  131. @property
  132. def special_values(self):
  133. if self._special_values is None or len(self._special_values) == 0:
  134. return None
  135. if isinstance(self._special_values, str):
  136. return [self._special_values]
  137. if isinstance(self._special_values, (dict, list)):
  138. return self._special_values
  139. return None
  140. def get_special_values(self, column: str = None):
  141. if self._special_values is None or len(self._special_values) == 0:
  142. return []
  143. if isinstance(self._special_values, str):
  144. return [self._special_values]
  145. if isinstance(self._special_values, list):
  146. return self._special_values
  147. if isinstance(self._special_values, dict) and column is not None:
  148. return self._special_values.get(column, [])
  149. return []
  150. @property
  151. def breaks_list(self):
  152. if self._breaks_list is None:
  153. return {}
  154. if isinstance(self._breaks_list, dict):
  155. return self._breaks_list
  156. return {}
  157. def get_breaks_list(self, column: str = None):
  158. if self._breaks_list is None or len(self._breaks_list) == 0:
  159. return []
  160. if isinstance(self._breaks_list, dict) and column is not None:
  161. return self._breaks_list.get(column, [])
  162. return []
  163. def f_get_save_path(self, file_name: str) -> str:
  164. path = os.path.join(self._base_dir, file_name)
  165. return path
  166. @staticmethod
  167. def from_config(config_path: str):
  168. """
  169. 从配置文件生成实体类
  170. """
  171. if os.path.exists(config_path):
  172. with open(config_path, mode="r", encoding="utf-8") as f:
  173. j = json.loads(f.read())
  174. else:
  175. raise GeneralException(ResultCodesEnum.NOT_FOUND, message=f"指配置文件【{config_path}】不存在")
  176. return DataProcessConfigEntity(**j)
  177. if __name__ == "__main__":
  178. pass