Explorar el Código

add: iv筛选增加字符串类型变量处理

yq hace 4 meses
padre
commit
810ea0460b
Se han modificado 1 ficheros con 50 adiciones y 20 borrados
  1. 50 20
      feature/strategy_iv.py

+ 50 - 20
feature/strategy_iv.py

@@ -10,6 +10,7 @@ from typing import List, Dict
 import numpy as np
 import pandas as pd
 import scorecardpy as sc
+from pandas.core.dtypes.common import is_numeric_dtype
 
 from entitys import DataSplitEntity, CandidateFeatureEntity
 from .feature_utils import f_judge_monto, f_get_corr
@@ -24,14 +25,24 @@ class StrategyIv(FilterStrategyBase):
     def _f_corr_filter(self, data: DataSplitEntity, candidate_dict: Dict[str, CandidateFeatureEntity]) -> List[str]:
         # 相关性剔除变量
         corr_threshold = self.data_process_config.corr_threshold
+        y_column = self.data_process_config.y_column
+        special_values = self.data_process_config.special_values
         train_data = data.train_data
         x_columns_candidate = list(candidate_dict.keys())
-        corr_df = f_get_corr(train_data[x_columns_candidate])
+        breaks_list = {}
+        for column, candidate in candidate_dict.items():
+            breaks_list[column] = candidate.breaks_list
+        bins = sc.woebin(train_data[x_columns_candidate + [y_column]], y=y_column, breaks_list=breaks_list,
+                              special_values=special_values)
+        train_woe = sc.woebin_ply(train_data[x_columns_candidate], bins)
+        corr_df = f_get_corr(train_woe)
         corr_dict = corr_df.to_dict()
         for column, corr in corr_dict.items():
+            column = column.replace("_woe","")
             if column not in x_columns_candidate:
                 continue
             for challenger_column, challenger_corr in corr.items():
+                challenger_column = challenger_column.replace("_woe", "")
                 if challenger_corr < corr_threshold or column == challenger_column \
                         or challenger_column not in x_columns_candidate:
                     continue
@@ -44,28 +55,43 @@ class StrategyIv(FilterStrategyBase):
                     break
         return x_columns_candidate
 
-    def _f_wide_filter(self, data: DataSplitEntity) -> List[str]:
+    def _f_wide_filter(self, data: DataSplitEntity) -> Dict:
         # 粗筛变量
         train_data = data.train_data
+        test_data = data.test_data
+        special_values = self.data_process_config.special_values
         y_column = self.data_process_config.y_column
         iv_threshold_wide = self.data_process_config.iv_threshold_wide
         x_columns_candidate = self.data_process_config.x_columns_candidate
         if x_columns_candidate is None or len(x_columns_candidate) == 0:
-            x_columns_candidate = train_data.columns.tolist().remove(y_column)
-
-        bins = sc.woebin(train_data[x_columns_candidate + [y_column]], y=y_column)
-        bins_iv_list = []
-        columns = []
-        for column, bin in bins.items():
-            total_iv = bin['total_iv'][0]
-            if total_iv < iv_threshold_wide:
+            x_columns_candidate = train_data.columns.tolist()
+            x_columns_candidate.remove(y_column)
+
+        bins_train = sc.woebin(train_data[x_columns_candidate + [y_column]], y=y_column, special_values=special_values,
+                               bin_num_limit=5)
+
+        breaks_list = {}
+        for column, bin in bins_train.items():
+            breaks_list[column] = list(bin['breaks'])
+        bins_test = None
+        if test_data is not None and len(test_data) != 0:
+            bins_test = sc.woebin(test_data[x_columns_candidate + [y_column]], y=y_column, breaks_list=breaks_list,
+                                  special_values=special_values
+                                  )
+        bins_iv_dict = {}
+        for column, bin_train in bins_train.items():
+            train_iv = bin_train['total_iv'][0]
+            test_iv = 0
+            if bins_test is not None:
+                bin_test = bins_test[column]
+                test_iv = bin_test['total_iv'][0]
+            iv_max = train_iv + test_iv
+            if train_iv < iv_threshold_wide:
                 continue
-            bins_iv_list.append({column: total_iv})
-            columns.append(column)
-        bins_iv_list = bins_iv_list.sort(key=lambda x: list(x.values())[0], reverse=True)
-        return columns
+            bins_iv_dict[column] = {"iv_max": iv_max, "breaks_list": breaks_list[column]}
+        return bins_iv_dict
 
-    def _f_get_best_bins(self, data: DataSplitEntity, x_column: str):
+    def _f_get_best_bins_numeric(self, data: DataSplitEntity, x_column: str):
         # 贪婪搜索【训练集】及【测试集】加起来【iv】值最高的且【单调】的分箱
         interval = self.data_process_config.bin_search_interval
         iv_threshold = self.data_process_config.iv_threshold
@@ -207,7 +233,6 @@ class StrategyIv(FilterStrategyBase):
                 for sv_bin in test_sv_bin_list:
                     test_bins = pd.concat((test_bins, sv_bin))
                 test_iv = _calculation_iv(test_bins)
-
             iv = train_iv + test_iv
             if iv > iv_max:
                 iv_max = iv
@@ -217,13 +242,18 @@ class StrategyIv(FilterStrategyBase):
 
     def filter(self, data: DataSplitEntity, *args, **kwargs) -> List[CandidateFeatureEntity]:
         # 粗筛
-        x_columns_candidate = self._f_wide_filter(data)
+        bins_iv_dict = self._f_wide_filter(data)
+        x_columns_candidate = list(bins_iv_dict.keys())
         candidate_num = self.data_process_config.candidate_num
-
         candidate_dict: Dict[str, CandidateFeatureEntity] = {}
         for x_column in x_columns_candidate:
-            iv_max, breaks_list = self._f_get_best_bins(data, x_column)
-            candidate_dict[x_column] = CandidateFeatureEntity(x_column, breaks_list, iv_max)
+            if is_numeric_dtype(data.train_data[x_column]):
+                iv_max, breaks_list = self._f_get_best_bins_numeric(data, x_column)
+                candidate_dict[x_column] = CandidateFeatureEntity(x_column, breaks_list, iv_max)
+            else:
+                # 字符型暂时用scorecardpy来处理
+                candidate_dict[x_column] = CandidateFeatureEntity(x_column, bins_iv_dict[x_column]["breaks_list"],
+                                                                  bins_iv_dict[x_column]["iv_max"])
 
         # 相关性进一步剔除变量
         x_columns_candidate = self._f_corr_filter(data, candidate_dict)