yq 4 months ago
parent commit cc211f8b61

+ 7 - 2
feature/filter_strategy_base.py

@@ -5,8 +5,9 @@
 @desc: Feature selection base class
 """
 import abc
+from typing import Dict
 
-from entitys import DataProcessConfigEntity
+from entitys import DataProcessConfigEntity, DataPreparedEntity, CandidateFeatureEntity
 
 
 class FilterStrategyBase(metaclass=abc.ABCMeta):
@@ -19,5 +20,9 @@ class FilterStrategyBase(metaclass=abc.ABCMeta):
         return self._data_process_config
 
     @abc.abstractmethod
-    def filter(self, *args, **kwargs):
+    def filter(self, *args, **kwargs) -> Dict[str, CandidateFeatureEntity]:
+        pass
+
+    @abc.abstractmethod
+    def feature_generate(self, *args, **kwargs) -> DataPreparedEntity:
         pass
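
With feature_generate added alongside filter, every concrete strategy now has to implement both abstract methods. A minimal sketch of the contract a subclass must satisfy, assuming the import path follows the file layout above and that DataPreparedEntity accepts None for missing splits (neither is confirmed by this diff):

from typing import Dict

from entitys import CandidateFeatureEntity, DataPreparedEntity, DataSplitEntity
from feature.filter_strategy_base import FilterStrategyBase


class DummyStrategy(FilterStrategyBase):
    # Hypothetical pass-through strategy, shown only to illustrate the two
    # methods an implementation is now forced to provide.

    def filter(self, data: DataSplitEntity, *args, **kwargs) -> Dict[str, CandidateFeatureEntity]:
        # A real strategy scores features here and returns the survivors keyed by column name.
        return {}

    def feature_generate(self, data: DataSplitEntity, candidate_dict: Dict[str, CandidateFeatureEntity],
                         *args, **kwargs) -> DataPreparedEntity:
        # A real strategy builds a DataFeatureEntity per split here; None is a placeholder.
        return DataPreparedEntity(None, None, None)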

+ 45 - 10
feature/strategy_iv.py

@@ -12,7 +12,7 @@ import pandas as pd
 import scorecardpy as sc
 from pandas.core.dtypes.common import is_numeric_dtype
 
-from entitys import DataSplitEntity, CandidateFeatureEntity
+from entitys import DataSplitEntity, CandidateFeatureEntity, DataPreparedEntity, DataFeatureEntity
 from .feature_utils import f_judge_monto, f_get_corr
 from .filter_strategy_base import FilterStrategyBase
 
@@ -22,23 +22,29 @@ class StrategyIv(FilterStrategyBase):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
-    def _f_corr_filter(self, data: DataSplitEntity, candidate_dict: Dict[str, CandidateFeatureEntity]) -> List[str]:
-        # drop highly correlated variables
-        corr_threshold = self.data_process_config.corr_threshold
+    def _f_get_bins_by_breaks(self, data: pd.DataFrame, candidate_dict: Dict[str, CandidateFeatureEntity]):
         y_column = self.data_process_config.y_column
         special_values = self.data_process_config.special_values
-        train_data = data.train_data
         x_columns_candidate = list(candidate_dict.keys())
         breaks_list = {}
         for column, candidate in candidate_dict.items():
             breaks_list[column] = candidate.breaks_list
-        bins = sc.woebin(train_data[x_columns_candidate + [y_column]], y=y_column, breaks_list=breaks_list,
-                              special_values=special_values)
+        bins = sc.woebin(data[x_columns_candidate + [y_column]], y=y_column, breaks_list=breaks_list,
+                         special_values=special_values)
+        return bins
+
+    def _f_corr_filter(self, data: DataSplitEntity, candidate_dict: Dict[str, CandidateFeatureEntity]) -> List[str]:
+        # drop highly correlated variables
+        corr_threshold = self.data_process_config.corr_threshold
+        train_data = data.train_data
+        x_columns_candidate = list(candidate_dict.keys())
+
+        bins = self._f_get_bins_by_breaks(train_data, candidate_dict)
         train_woe = sc.woebin_ply(train_data[x_columns_candidate], bins)
         corr_df = f_get_corr(train_woe)
         corr_dict = corr_df.to_dict()
         for column, corr in corr_dict.items():
-            column = column.replace("_woe","")
+            column = column.replace("_woe", "")
             if column not in x_columns_candidate:
                 continue
             for challenger_column, challenger_corr in corr.items():
@@ -240,7 +246,7 @@ class StrategyIv(FilterStrategyBase):
 
         return iv_max, breaks_list
 
-    def filter(self, data: DataSplitEntity, *args, **kwargs) -> List[CandidateFeatureEntity]:
+    def filter(self, data: DataSplitEntity, *args, **kwargs) -> Dict[str, CandidateFeatureEntity]:
         # coarse screening
         bins_iv_dict = self._f_wide_filter(data)
         x_columns_candidate = list(bins_iv_dict.keys())
@@ -263,5 +269,34 @@ class StrategyIv(FilterStrategyBase):
                 candidate_list.append(v)
 
         candidate_list.sort(key=lambda x: x.iv_max, reverse=True)
+        candidate_list = candidate_list[0:candidate_num]
+        candidate_dict = {}
+        for candidate in candidate_list:
+            candidate_dict[candidate.x_column] = candidate
+        return candidate_dict
+
+    def feature_generate(self, data: DataSplitEntity, candidate_dict: Dict[str, CandidateFeatureEntity], *args,
+                         **kwargs) -> DataPreparedEntity:
+        train_data = data.train_data
+        val_data = data.val_data
+        test_data = data.test_data
+        y_column = self.data_process_config.y_column
+        x_columns_candidate = list(candidate_dict.keys())
+        bins = self._f_get_bins_by_breaks(train_data, candidate_dict)
+
+        train_woe = sc.woebin_ply(train_data[x_columns_candidate], bins)
+        train_data_feature = DataFeatureEntity(pd.concat((train_woe, train_data[y_column]), axis=1),
+                                               train_woe.columns.tolist(), y_column)
 
-        return candidate_list[0:candidate_num]
+        val_data_feature = None
+        if val_data is not None and len(val_data) != 0:
+            val_woe = sc.woebin_ply(val_data[x_columns_candidate], bins)
+            val_data_feature = DataFeatureEntity(pd.concat((val_woe, val_data[y_column]), axis=1),
+                                                 val_woe.columns.tolist(), y_column)
+
+        test_data_feature = None
+        if test_data is not None and len(test_data) != 0:
+            test_woe = sc.woebin_ply(test_data[x_columns_candidate], bins)
+            test_data_feature = DataFeatureEntity(pd.concat((test_woe, test_data[y_column]), axis=1),
+                                                  test_woe.columns.tolist(), y_column)
+        return DataPreparedEntity(train_data_feature, val_data_feature, test_data_feature)
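
The new feature_generate fits the WOE bins on the training split only (via _f_get_bins_by_breaks) and then applies those bins to every split with sc.woebin_ply, so validation and test rows never influence the binning. A standalone sketch of that pattern on synthetic data (the toy DataFrame and column names are illustrative, not from the project):

import numpy as np
import pandas as pd
import scorecardpy as sc

# Toy data standing in for the train/test members of DataSplitEntity.
rng = np.random.default_rng(0)
df = pd.DataFrame({"x1": rng.normal(size=1000), "x2": rng.normal(size=1000)})
df["y"] = (df["x1"] + 0.5 * rng.normal(size=1000) > 0).astype(int)
train, test = df.iloc[:700], df.iloc[700:]

# Fit WOE bins on the training rows only.
bins = sc.woebin(train[["x1", "x2", "y"]], y="y")

# Apply the same bins to each split, as feature_generate does for train/val/test.
train_woe = sc.woebin_ply(train[["x1", "x2"]], bins)
test_woe = sc.woebin_ply(test[["x1", "x2"]], bins)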

+ 3 - 1
strategy_test1.py

@@ -18,5 +18,7 @@ if __name__ == "__main__":
     data = DataSplitEntity(dat[:700], None, dat[700:])
     filter_strategy_factory = FilterStrategyFactory(DataProcessConfigEntity.from_config('./config/data_process_config_template.json'))
     strategy = filter_strategy_factory.get_strategy()
-    a = strategy.filter(data)
+    candidate_dict = strategy.filter(data)
+    data_prepared = strategy.feature_generate(data, candidate_dict)
+
     print(time.time() - time_now)
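
Since filter() now returns a dict keyed by column name, the selected candidates can be inspected before calling feature_generate. A sketch of that inspection for the script above (the iv_max and breaks_list attribute names are taken from how strategy_iv.py uses CandidateFeatureEntity in this commit):

    # Hypothetical inspection step, placed after strategy.filter(data) in the __main__ block.
    for x_column, candidate in candidate_dict.items():
        print(x_column, candidate.iv_max, candidate.breaks_list)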