data_process_config_entity.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. # -*- coding: utf-8 -*-
  2. """
  3. @author: yq
  4. @time: 2024/11/1
  5. @desc: 数据处理配置类
  6. """
  7. import json
  8. import os
  9. from typing import List, Union
  10. from commom import GeneralException, f_get_datetime
  11. from config import BaseConfig
  12. from enums import ResultCodesEnum
  class DataProcessConfigEntity:
      """Data-processing configuration entity.

      Holds the target column, the candidate feature columns and the
      thresholds / strategies used by the data-processing step. Constructing an
      instance has a filesystem side effect: it creates an output directory
      named after ``f_get_datetime()`` (presumably a timestamp — confirm) under
      ``BaseConfig.train_path`` and stores its path in ``save_path``.
      """

      def __init__(self, y_column: str, x_columns_candidate: Union[List[str], None] = None,
                   fill_method: Union[str, None] = None, split_method: Union[str, None] = None,
                   feature_search_strategy: str = 'iv', bin_search_interval: float = 0.05,
                   iv_threshold: float = 0.03, iv_threshold_wide: float = 0.05,
                   corr_threshold: float = 0.4, sample_rate: float = 0.1,
                   x_candidate_num: int = 10, special_values: Union[dict, list, None] = None):
          """Store the configuration and create this run's output directory.

          Args:
              y_column: Name of the target (y) column.
              x_columns_candidate: Candidate x (feature) columns.
              fill_method: Missing-value fill method.
              split_method: Data split method.
              feature_search_strategy: Strategy used to search for the best
                  features (default ``'iv'``).
              bin_search_interval: Data granularity for the greedy bin search;
                  intended to lie in 0.01-0.1 (not validated here).
              iv_threshold: IV threshold used when screening variables by IV.
              iv_threshold_wide: IV threshold used for the coarse (wide) IV
                  screening.
              corr_threshold: Variable correlation threshold.
              sample_rate: Sampling ratio for the greedy search; only takes
                  effect for 4 and 5 bins.
              x_candidate_num: Number of x variables to keep in the end.
              special_values: Special values, either one list applied to all
                  columns or a dict mapping column name -> list of values.
          """
          self.save_path = os.path.join(BaseConfig.train_path, f"{f_get_datetime()}")
          os.makedirs(self.save_path, exist_ok=True)
          self._y_column = y_column
          self._x_columns_candidate = x_columns_candidate
          self._fill_method = fill_method
          self._split_method = split_method
          self._feature_search_strategy = feature_search_strategy
          self._iv_threshold = iv_threshold
          self._iv_threshold_wide = iv_threshold_wide
          self._bin_search_interval = bin_search_interval
          self._x_candidate_num = x_candidate_num
          self._special_values = special_values
          self._corr_threshold = corr_threshold
          self._sample_rate = sample_rate

      @property
      def sample_rate(self):
          """Sampling ratio for the greedy bin search (4/5 bins only)."""
          return self._sample_rate

      @property
      def corr_threshold(self):
          """Variable correlation threshold."""
          return self._corr_threshold

      @property
      def iv_threshold_wide(self):
          """IV threshold for the coarse (wide) screening."""
          return self._iv_threshold_wide

      @property
      def candidate_num(self):
          """Number of x variables to keep (the ``x_candidate_num`` argument)."""
          return self._x_candidate_num

      @property
      def y_column(self):
          """Name of the target (y) column."""
          return self._y_column

      @property
      def x_columns_candidate(self):
          """Candidate x (feature) columns, or None if not given."""
          return self._x_columns_candidate

      @property
      def fill_method(self):
          """Missing-value fill method, or None if not given."""
          return self._fill_method

      @property
      def split_method(self):
          """Data split method, or None if not given."""
          return self._split_method

      @property
      def feature_search_strategy(self):
          """Strategy used to search for the best features."""
          return self._feature_search_strategy

      @property
      def iv_threshold(self):
          """IV threshold used when screening variables by IV."""
          return self._iv_threshold

      @property
      def bin_search_interval(self):
          """Data granularity for the greedy bin search."""
          return self._bin_search_interval

      @property
      def special_values(self):
          """Raw special-values setting (list, dict or None) as passed in."""
          return self._special_values

      def get_special_values(self, column: Union[str, None] = None):
          """Return the special values that apply to ``column``.

          Args:
              column: Column name to look up. If None, the raw setting is
                  returned.

          Returns:
              - the raw setting as-is (possibly None) when ``column`` is None or
                the setting is a list (one list shared by all columns);
              - ``special_values.get(column, [])`` when the setting is a dict;
              - ``[]`` otherwise (e.g. no special values configured).
          """
          values = self._special_values
          if column is None or isinstance(values, list):
              return values
          if isinstance(values, dict):
              return values.get(column, [])
          return []

      @classmethod
      def from_config(cls, config_path: str) -> "DataProcessConfigEntity":
          """Build an entity from a JSON config file.

          The file's top-level JSON object is passed as keyword arguments to
          the constructor, so its keys must match ``__init__`` parameter names.

          Args:
              config_path: Path to a UTF-8 encoded JSON file.

          Returns:
              A new instance of ``cls`` (the subclass when called on one).

          Raises:
              GeneralException: With ``ResultCodesEnum.NOT_FOUND`` when
                  ``config_path`` is not an existing regular file.
              json.JSONDecodeError: If the file is not valid JSON.
              TypeError: If the JSON is not an object, has unknown keys, or
                  lacks ``y_column``.
          """
          # isfile rather than exists: a directory at this path would otherwise
          # pass the check and then fail inside open() with an unrelated OS error.
          if not os.path.isfile(config_path):
              raise GeneralException(ResultCodesEnum.NOT_FOUND, message=f"指定配置文件【{config_path}】不存在")
          with open(config_path, mode="r", encoding="utf-8") as f:
              params = json.load(f)
          return cls(**params)

      def _get_save_path(self, file_name: str) -> str:
          """Return ``file_name`` joined onto this run's ``save_path`` directory."""
          return os.path.join(self.save_path, file_name)
  if __name__ == "__main__":
      # Nothing to run when this module is executed directly.
      ...