Przeglądaj źródła

modify: 代码优化

yq 4 dni temu
rodzic
commit
0cc78a917e
1 zmienionych plików z 16 dodań i 18 usunięć
  1. +16 -18
      feature/woe/strategy_woe.py

+ 16 - 18
feature/woe/strategy_woe.py

@@ -33,6 +33,8 @@ class StrategyWoe(FeatureStrategyBase):
         super().__init__(*args, **kwargs)
         # woe编码需要的分箱信息,复用scorecardpy的格式
         self.sc_woebin = None
+        self._bin_info_filtered: Dict[str, BinInfo]
+        self._homo_bin_info_numeric_set: Dict[str, HomologousBinInfo]
 
     def _f_get_img_corr(self, train_woe) -> Union[str, None]:
         if len(train_woe.columns.to_list()) <= 1:
@@ -422,17 +424,17 @@ class StrategyWoe(FeatureStrategyBase):
         bin_info_filtered = self._f_vif_filter(data, bin_info_filtered)
         bin_info_filtered = BinInfo.ivTopN(bin_info_filtered, max_feature_num)
         self.sc_woebin = self._f_get_sc_woebin(data.train_data, bin_info_filtered)
-        context.set(ContextEnum.BIN_INFO_FILTERED, bin_info_filtered)
         context.set(ContextEnum.WOEBIN, self.sc_woebin)
+        return bin_info_filtered
 
     def feature_search(self, data: DataSplitEntity, *args, **kwargs):
         # 粗筛
         bin_info_fast = self._f_fast_filter(data)
         x_columns = list(bin_info_fast.keys())
 
-        bin_info_filtered: Dict[str, BinInfo] = {}
+        self._bin_info_filtered: Dict[str, BinInfo] = {}
         # 数值型变量多种分箱方式的中间结果
-        homo_bin_info_numeric_set: Dict[str, HomologousBinInfo] = {}
+        self._homo_bin_info_numeric_set: Dict[str, HomologousBinInfo] = {}
         filter_numeric_overview = ""
         filter_numeric_detail = []
         for x_column in tqdm(x_columns):
@@ -440,22 +442,21 @@ class StrategyWoe(FeatureStrategyBase):
                 # 数值型变量筛选
                 homo_bin_info_numeric: HomologousBinInfo = self._handle_numeric(data, x_column)
                 if homo_bin_info_numeric.is_auto_bins:
-                    homo_bin_info_numeric_set[x_column] = homo_bin_info_numeric
+                    self._homo_bin_info_numeric_set[x_column] = homo_bin_info_numeric
                 # iv psi 变量单调性 变量趋势一致性 筛选
                 bin_info: Optional[BinInfo] = homo_bin_info_numeric.filter()
                 if bin_info is not None:
-                    bin_info_filtered[x_column] = bin_info
+                    self._bin_info_filtered[x_column] = bin_info
                 else:
                     # 不满足要求被剔除
                     filter_numeric_overview = f"{filter_numeric_overview}{x_column} {homo_bin_info_numeric.drop_reason()}\n"
                     filter_numeric_detail.append(x_column)
             else:
                 # 字符型暂时用scorecardpy来处理
-                bin_info_filtered[x_column] = bin_info_fast[x_column]
+                self._bin_info_filtered[x_column] = bin_info_fast[x_column]
 
-        self.post_filter(data, bin_info_filtered)
+        self._bin_info_filtered = self.post_filter(data, self._bin_info_filtered)
 
-        context.set(ContextEnum.HOMO_BIN_INFO_NUMERIC_SET, homo_bin_info_numeric_set)
         context.set_filter_info(ContextEnum.FILTER_NUMERIC, filter_numeric_overview, filter_numeric_detail)
 
     def variable_analyse(self, data: DataSplitEntity, column: str, format_bin=None, *args, **kwargs):
@@ -497,7 +498,6 @@ class StrategyWoe(FeatureStrategyBase):
         train_data = data.train_data
         test_data = data.test_data
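         # NOTE(review): stale comment - it originally read "intermediate results are used across modules, so fetch them from the context"; the context lookup that followed was removed and self._bin_info_filtered is read instead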
-        bin_info_filtered: Dict[str, BinInfo] = context.get(ContextEnum.BIN_INFO_FILTERED)
 
         metric_value_dict = {}
         # 样本分布
@@ -505,15 +505,15 @@ class StrategyWoe(FeatureStrategyBase):
                                                           table_cell_width=3)
 
         # 变量相关性
-        sc_woebin_train = self._f_get_sc_woebin(train_data, bin_info_filtered)
+        sc_woebin_train = self._f_get_sc_woebin(train_data, self._bin_info_filtered)
         train_woe = sc.woebin_ply(train_data[x_columns], sc_woebin_train, print_info=False)
         img_path_corr = self._f_get_img_corr(train_woe)
         metric_value_dict["变量相关性"] = MetricFucResultEntity(image_path=img_path_corr)
 
         # 变量iv、psi、vif
         df_iv_psi_vif = pd.DataFrame()
-        train_iv = [bin_info_filtered[column].train_iv for column in x_columns]
-        psi = [bin_info_filtered[column].psi for column in x_columns]
+        train_iv = [self._bin_info_filtered[column].train_iv for column in x_columns]
+        psi = [self._bin_info_filtered[column].psi for column in x_columns]
         anns = [columns_anns.get(column, "-") for column in x_columns]
         df_iv_psi_vif["变量"] = x_columns
         df_iv_psi_vif["iv"] = train_iv
@@ -535,7 +535,7 @@ class StrategyWoe(FeatureStrategyBase):
         metric_value_dict["变量趋势-训练集"] = MetricFucResultEntity(image_path=imgs_path_trend_train, image_size=4)
 
         # 变量趋势-测试集
-        sc_woebin_test = self._f_get_sc_woebin(test_data, bin_info_filtered)
+        sc_woebin_test = self._f_get_sc_woebin(test_data, self._bin_info_filtered)
         imgs_path_trend_test = self._f_get_img_trend(sc_woebin_test, x_columns, "test")
         metric_value_dict["变量趋势-测试集"] = MetricFucResultEntity(image_path=imgs_path_trend_test, image_size=4)
 
@@ -554,7 +554,7 @@ class StrategyWoe(FeatureStrategyBase):
                 detail = [detail]
             if isinstance(detail, list):
                 for column in detail:
-                    homo_bin_info_numeric = homo_bin_info_numeric_set.get(column)
+                    homo_bin_info_numeric = self._homo_bin_info_numeric_set.get(column)
                     if homo_bin_info_numeric is None:
                         continue
                     self._f_best_bins_print(display, data, column, homo_bin_info_numeric)
@@ -572,8 +572,6 @@ class StrategyWoe(FeatureStrategyBase):
             if detail is not None and self.ml_config.bin_detail_print:
                 detail_print(detail)
 
-        bin_info_filtered: Dict[str, BinInfo] = context.get(ContextEnum.BIN_INFO_FILTERED)
-        homo_bin_info_numeric_set: Dict[str, HomologousBinInfo] = context.get(ContextEnum.HOMO_BIN_INFO_NUMERIC_SET)
         filter_fast = context.get(ContextEnum.FILTER_FAST)
         filter_numeric = context.get(ContextEnum.FILTER_NUMERIC)
         filter_corr = context.get(ContextEnum.FILTER_CORR)
@@ -597,11 +595,11 @@ class StrategyWoe(FeatureStrategyBase):
                                  title2="测试集")
 
         # 打印breaks_list
-        breaks_list = {column: bin_info.points for column, bin_info in bin_info_filtered.items()}
+        breaks_list = {column: bin_info.points for column, bin_info in self._bin_info_filtered.items()}
         print("变量切分点:")
         print(json.dumps(breaks_list, ensure_ascii=False, indent=2, cls=NumpyEncoder))
         print("选中变量不同分箱数下变量的推荐切分点:")
-        detail_print(list(bin_info_filtered.keys()))
+        detail_print(list(self._bin_info_filtered.keys()))
 
         # 打印fast_filter筛选情况
         filter_print(filter_fast, "快速筛选过程", "剔除train_iv小于阈值")