ol_config_entity.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
  1. # -*- coding: utf-8 -*-
  2. """
  3. @author: yq
  4. @time: 2024/11/1
  5. @desc: Configuration class for OnlineLearning parameters
  6. """
  7. import json
  8. import os
  9. from typing import List
  10. from commom import GeneralException, f_get_datetime
  11. from config import BaseConfig
  12. from enums import ResultCodesEnum, FileEnum
  13. from init import warning_ignore
  14. class OnlineLearningConfigEntity():
  15. def __init__(self,
  16. path_resources: str,
  17. y_column: str,
  18. project_name: str = None,
  19. lr: float = 0.01,
  20. random_state: int = 2025,
  21. batch_size: int = 64,
  22. epochs: int = 50,
  23. columns_anns: dict = {},
  24. jupyter_print=False,
  25. save_pmml=True,
  26. stress_test=False,
  27. stress_sample_times=100,
  28. stress_bad_rate_list: List[float] = [],
  29. params_xgb={},
  30. *args, **kwargs):
  31. self._path_resources = path_resources
  32. # 定义y变量
  33. self._y_column = y_column
  34. # 项目名称,和缓存路径有关
  35. self._project_name = project_name
  36. # 学习率
  37. self._lr = lr
  38. # 学习率
  39. self._random_state = random_state
  40. # 模型单次更新使用数据量
  41. self._batch_size = batch_size
  42. # 最大训练轮数
  43. self._epochs = epochs
  44. # 变量注释
  45. self._columns_anns = columns_anns
  46. # jupyter下输出内容
  47. self._jupyter_print = jupyter_print
  48. self._save_pmml = save_pmml
  49. self._stress_test = stress_test
  50. self._stress_sample_times = stress_sample_times
  51. self._stress_bad_rate_list = stress_bad_rate_list
  52. self._params_xgb = params_xgb
  53. if self._project_name is None or len(self._project_name) == 0:
  54. self._base_dir = os.path.join(BaseConfig.train_path, f"{f_get_datetime()}")
  55. else:
  56. self._base_dir = os.path.join(BaseConfig.train_path, self._project_name)
  57. os.makedirs(self._base_dir, exist_ok=True)
  58. print(f"项目路径:【{self._base_dir}】")
  59. if self._jupyter_print:
  60. warning_ignore()
  61. @property
  62. def path_resources(self):
  63. return self._path_resources
  64. @property
  65. def y_column(self):
  66. return self._y_column
  67. @property
  68. def lr(self):
  69. return self._lr
  70. @property
  71. def random_state(self):
  72. return self._random_state
  73. @property
  74. def batch_size(self):
  75. return self._batch_size
  76. @property
  77. def epochs(self):
  78. return self._epochs
  79. @property
  80. def columns_anns(self):
  81. return self._columns_anns
  82. @property
  83. def jupyter_print(self):
  84. return self._jupyter_print
  85. @property
  86. def save_pmml(self):
  87. return self._save_pmml
  88. @property
  89. def stress_test(self):
  90. return self._stress_test
  91. @property
  92. def stress_sample_times(self):
  93. return self._stress_sample_times
  94. @property
  95. def stress_bad_rate_list(self):
  96. return self._stress_bad_rate_list
  97. @property
  98. def params_xgb(self):
  99. params = {
  100. 'objective': 'binary:logistic',
  101. 'eval_metric': 'auc',
  102. 'learning_rate': 0.1,
  103. 'max_depth': 3,
  104. 'subsample': None,
  105. 'colsample_bytree': None,
  106. 'alpha': 0,
  107. 'lambda': 1,
  108. 'num_boost_round': 100,
  109. 'early_stopping_rounds': 20,
  110. 'verbose_eval': 10,
  111. 'random_state': 2025,
  112. 'save_pmml': True,
  113. 'trees_print': False,
  114. # tree_add tree_refresh
  115. 'oltype': "refresh",
  116. 'add_columns': []
  117. }
  118. params.update(self._params_xgb)
  119. return params
  120. @staticmethod
  121. def from_config(config_path: str):
  122. """
  123. 从配置文件生成实体类
  124. """
  125. if os.path.isdir(config_path):
  126. config_path = os.path.join(config_path, FileEnum.OL_CFG.value)
  127. if os.path.exists(config_path):
  128. with open(config_path, mode="r", encoding="utf-8") as f:
  129. j = json.loads(f.read())
  130. else:
  131. raise GeneralException(ResultCodesEnum.NOT_FOUND, message=f"指配置文件【{config_path}】不存在")
  132. print(f"olcfg load from【{config_path}】success. ")
  133. return OnlineLearningConfigEntity(**j)
  134. def config_save(self):
  135. path = self.f_get_save_path(FileEnum.OL_CFG.value)
  136. with open(path, mode="w", encoding="utf-8") as f:
  137. j = {k.lstrip("_"): v for k, v in self.__dict__.items()}
  138. j = json.dumps(j, ensure_ascii=False)
  139. f.write(j)
  140. print(f"olcfg save to【{path}】success. ")
  141. def f_get_save_path(self, file_name: str) -> str:
  142. path = os.path.join(self._base_dir, file_name)
  143. return path
  144. if __name__ == "__main__":
  145. pass