strategy_norm.py

# -*- coding: utf-8 -*-
"""
@author: yq
@time: 2025/4/3
@desc: Value normalization, similar to binning
"""
from typing import Dict, List

import pandas as pd
import xgboost as xgb
from pandas.core.dtypes.common import is_numeric_dtype

from commom import GeneralException, f_display_title
from data import DataExplore
from entitys import DataSplitEntity, MetricFucResultEntity
from enums import ResultCodesEnum, ContextEnum
from feature.feature_strategy_base import FeatureStrategyBase
from init import context
from .utils import f_format_value, OneHot, f_format_bin


class StrategyNorm(FeatureStrategyBase):
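    """
    Feature strategy that normalizes values rather than fully binning them:
    numeric columns are optionally coarse-binned onto fixed cut points, string
    columns are one-hot encoded, and an XGBoost importance pass drops weak
    variables before modelling.
    """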
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.x_columns = None
        self.one_hot_encoder_dict: Dict[str, OneHot] = {}
        self.points_dict: Dict[str, List[float]] = {}

    def _f_fast_filter(self, data: DataSplitEntity) -> List[str]:
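        """
        Fast feature filter: preprocesses all candidate columns, trains a
        throwaway XGBoost classifier, and keeps at most max_feature_num
        variables with positive importance.
        """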
        y_column = self.ml_config.y_column
        x_columns = self.ml_config.x_columns
        columns_exclude = self.ml_config.columns_exclude
        format_bin = self.ml_config.format_bin
        params_xgb = self.ml_config.params_xgb
        max_feature_num = self.ml_config.max_feature_num

        train_data = data.train_data.copy()
        test_data = data.test_data.copy()

        # Feature column configuration
        if len(x_columns) == 0:
            x_columns = train_data.columns.tolist()
        if y_column in x_columns:
            x_columns.remove(y_column)
        for column in columns_exclude:
            if column in x_columns:
                x_columns.remove(column)
        # Basic consistency check on data types
        check_msg = DataExplore.check_type(data.data[x_columns])
        if check_msg != "":
            print(f"Data type analysis:\n{check_msg}\nKeep each variable's data type consistent")
            raise GeneralException(ResultCodesEnum.ILLEGAL_PARAMS, message="Data type error.")

        # Data processing
        model_columns = []
        num_columns = []
        str_columns = []
        for x_column in x_columns:
            if is_numeric_dtype(train_data[x_column]):
                num_columns.append(x_column)
                # Coarse binning
                if format_bin:
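                    # NOTE (assumption): f_format_bin is expected to derive ascending
                    # cut points from the describe() summary (the 10%/90% percentiles
                    # temper outliers), and f_format_value to map each raw value onto
                    # one of those points, leaving a small fixed value set.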
                    data_x_describe = train_data[x_column].describe(percentiles=[0.1, 0.9])
                    points = f_format_bin(data_x_describe)
                    self.points_dict[x_column] = points
                    train_data[x_column] = train_data[x_column].apply(lambda x: f_format_value(points, x))
                    test_data[x_column] = test_data[x_column].apply(lambda x: f_format_value(points, x))
            else:
                str_columns.append(x_column)
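                # The encoder is fitted on the full dataset (train + test) so that
                # every category level gets a column; for one-hot encoding this only
                # fixes the column set and does not leak label information.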
                one_hot_encoder = OneHot()
                one_hot_encoder.fit(data.data, x_column)
                one_hot_encoder.encoder(train_data)
                one_hot_encoder.encoder(test_data)
                model_columns.extend(one_hot_encoder.columns_onehot)
                self.one_hot_encoder_dict[x_column] = one_hot_encoder
        model_columns.extend(num_columns)

        # Drop weak variables by feature importance
        model = xgb.XGBClassifier(objective=params_xgb.get("objective"),
                                  n_estimators=params_xgb.get("num_boost_round"),
                                  max_depth=params_xgb.get("max_depth"),
                                  learning_rate=params_xgb.get("learning_rate"),
                                  random_state=params_xgb.get("random_state"),
                                  reg_alpha=params_xgb.get("alpha"),
                                  subsample=params_xgb.get("subsample"),
                                  colsample_bytree=params_xgb.get("colsample_bytree"),
                                  importance_type='weight')
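        # NOTE: passing eval_metric / early_stopping_rounds to fit() follows the
        # pre-1.6 xgboost sklearn API; newer releases expect them as constructor
        # arguments, so pin the xgboost version or move them when upgrading.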
        model.fit(X=train_data[model_columns], y=train_data[y_column],
                  eval_set=[(train_data[model_columns], train_data[y_column]),
                            (test_data[model_columns], test_data[y_column])],
                  eval_metric=params_xgb.get("eval_metric"),
                  early_stopping_rounds=params_xgb.get("early_stopping_rounds"),
                  verbose=False)

        # Merge importances: a string variable's importance is the sum over its
        # one-hot sub-variables
        importance = model.feature_importances_
        feature = []
        importance_weight = []
        for x_column in num_columns:
            for i, j in zip(model_columns, importance):
                if i == x_column:
                    feature.append(x_column)
                    importance_weight.append(j)
                    break
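        # NOTE (assumption): the startswith match below relies on OneHot naming its
        # encoded columns "<column>(<value>)"; if the naming scheme differs, the
        # summed importance for string variables will silently be zero.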
        for x_column in str_columns:
            feature_cache = 0
            for i, j in zip(model_columns, importance):
                if i.startswith(f"{x_column}("):
                    feature_cache += j
            feature.append(x_column)
            importance_weight.append(feature_cache)

        df_importance = pd.DataFrame({'feature': feature, 'importance_weight': importance_weight})
        df_importance.sort_values(by=["importance_weight"], ascending=[False], inplace=True)
        df_importance.reset_index(drop=True, inplace=True)
        # Keep only variables with positive importance, then take the top max_feature_num
        df_importance_rank = df_importance[df_importance["importance_weight"] > 0].reset_index(drop=True)
        x_columns_filter = list(df_importance_rank["feature"])[0:max_feature_num]

        context.set_filter_info(ContextEnum.FILTER_FAST,
                                f"Number of variables before filtering: {len(x_columns)}\n{x_columns}\n"
                                f"Variables removed by fast filtering: {len(x_columns) - len(x_columns_filter)}",
                                detail=df_importance)
        return x_columns_filter

    def feature_search(self, data: DataSplitEntity, *args, **kwargs):
        x_columns = self._f_fast_filter(data)
        # Sort to guard against order-dependent bugs
        x_columns.sort()
        self.x_columns = x_columns

    def variable_analyse(self, *args, **kwargs):
        pass

    def feature_generate(self, data: pd.DataFrame, *args, **kwargs) -> pd.DataFrame:
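        """
        Apply the fitted transformations to new data: numeric columns are mapped
        onto their stored cut points, string columns are one-hot encoded, and the
        resulting model columns are returned in a fixed order.
        """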
        df = data.copy()
        model_columns = []
        for x_column in self.x_columns:
            if x_column in self.points_dict.keys():
                points = self.points_dict[x_column]
                df[x_column] = df[x_column].apply(lambda x: f_format_value(points, x))
                model_columns.append(x_column)
            elif x_column in self.one_hot_encoder_dict.keys():
                one_hot_encoder = self.one_hot_encoder_dict[x_column]
                one_hot_encoder.encoder(df)
                model_columns.extend(one_hot_encoder.columns_onehot)
            else:
                model_columns.append(x_column)
        return df[model_columns]

    def feature_save(self, *args, **kwargs):
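        # Persistence is not actually implemented; this only resets the fitted
        # state, mirroring the empty feature_load below.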
        self.x_columns = None
        self.one_hot_encoder_dict: Dict[str, OneHot] = {}
        self.points_dict: Dict[str, List[float]] = {}

    def feature_load(self, path: str, *args, **kwargs):
        pass

    def feature_report(self, data: DataSplitEntity, *args, **kwargs) -> Dict[str, MetricFucResultEntity]:
        y_column = self.ml_config.y_column
        metric_value_dict = {}
        # Sample distribution
        metric_value_dict["Sample distribution"] = MetricFucResultEntity(table=data.get_distribution(y_column),
                                                                         table_font_size=10,
                                                                         table_cell_width=3)
        self.jupyter_print(metric_value_dict)
        return metric_value_dict

    def jupyter_print(self, metric_value_dict, *args, **kwargs):
        from IPython import display

        max_feature_num = self.ml_config.max_feature_num
        f_display_title(display, "Sample distribution")
        display.display(metric_value_dict["Sample distribution"].table)

        filter_fast = context.get(ContextEnum.FILTER_FAST)
        f_display_title(display, "Fast filtering process")
        print(f"Variables ranked below {max_feature_num} by importance are removed")
        print(filter_fast.get("overview"))
        display.display(filter_fast["detail"])
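

# A minimal usage sketch (assumptions, not the repo's confirmed interface): the
# StrategyNorm constructor kwargs and the ml_config / data_split objects below
# are hypothetical, inferred from how self.ml_config and data are read above.
#
#   strategy = StrategyNorm(ml_config=ml_config)      # hypothetical kwargs
#   strategy.feature_search(data_split)               # fast filter, then sort columns
#   X_train = strategy.feature_generate(data_split.train_data)
#   report = strategy.feature_report(data_split)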