Jelajahi Sumber

add: FilterStrategyFactory

yq 4 bulan lalu
induk
melakukan
b6897e8745

+ 4 - 1
config/data_process_config_template.json

@@ -2,7 +2,10 @@
   "y_column": "creditability",
   "x_columns_candidate": [
     "duration_in_month",
-    "credit_amount"
+    "credit_amount",
+    "installment_rate_in_percentage_of_disposable_income",
+    "present_residence_since",
+    "age_in_years"
   ],
   "bin_search_interval": 0.05
 }

+ 16 - 2
entitys/data_process_config_entity.py

@@ -13,9 +13,9 @@ from enums import ResultCodesEnum
 
 
 class DataProcessConfigEntity():
-    def __init__(self, y_column: str, x_columns_candidate: List[str], fill_method: str = None, split_method: str = None,
+    def __init__(self, y_column: str, x_columns_candidate: List[str] = None, fill_method: str = None, split_method: str = None,
                  feature_search_strategy: str = 'iv', bin_search_interval: float = 0.05, iv_threshold: float = 0.03,
-                 x_candidate_num: int = 10, special_values: Union[dict, list] = None):
+                 iv_threshold_wide: float = 0.05, corr_threshold: float = 0.4, x_candidate_num: int = 10, special_values: Union[dict, list] = None):
         # 定义y变量
         self._y_column = y_column
 
@@ -34,6 +34,9 @@ class DataProcessConfigEntity():
         # 使用iv筛变量时的阈值
         self._iv_threshold = iv_threshold
 
+        # 使用iv粗筛变量时的阈值
+        self._iv_threshold_wide = iv_threshold_wide
+
         # 贪婪搜索分箱时数据粒度大小,应该在0.01-0.1之间
         self._bin_search_interval = bin_search_interval
 
@@ -42,6 +45,17 @@ class DataProcessConfigEntity():
 
         self._special_values = special_values
 
+        # 变量相关性阈值
+        self._corr_threshold = corr_threshold
+
+    @property
+    def corr_threshold(self):
+        return self._corr_threshold
+
+    @property
+    def iv_threshold_wide(self):
+        return self._iv_threshold_wide
+
     @property
     def candidate_num(self):
         return self._x_candidate_num

+ 2 - 3
feature/__init__.py

@@ -5,7 +5,6 @@
 @desc: 特征挖掘
 """
 
+from .filter_strategy_factory import FilterStrategyFactory
 
-
-if __name__ == "__main__":
-    pass
+__all__ = ['FilterStrategyFactory']

+ 2 - 2
feature/feature_utils.py

@@ -116,8 +116,8 @@ def f_get_psi(train_data: DataSplitEntity, oot_data: DataSplitEntity) -> pd.Data
     return td.metrics.PSI(train_data, oot_data)
 
 
-def f_get_corr(data: DataSplitEntity, meth: str = 'spearman') -> pd.DataFrame:
-    return data.train_data().corr(method=meth)
+def f_get_corr(data: pd.DataFrame, meth: str = 'spearman') -> pd.DataFrame:
+    return data.corr(method=meth)
 
 
 def f_get_ivf(data: DataSplitEntity) -> pd.DataFrame:

+ 16 - 2
feature/filter_strategy_factory.py

@@ -4,6 +4,20 @@
 @time: 2024/11/25
 @desc: 特征筛选策略工厂
 """
+from entitys import DataProcessConfigEntity
+from enums import FilterStrategyEnum
+from .filter_strategy_base import FilterStrategyBase
+from .strategy_iv import StrategyIv
 
-if __name__ == "__main__":
-    pass
+
+class FilterStrategyFactory():
+
+    def __init__(self, data_process_config: DataProcessConfigEntity, *args, **kwargs):
+        self._data_process_config = data_process_config
+        self.strategy_map = {
+            FilterStrategyEnum.IV.value: StrategyIv(data_process_config, *args, **kwargs)
+        }
+
+    def get_strategy(self, ) -> FilterStrategyBase:
+        strategy = self.strategy_map.get(self._data_process_config.feature_search_strategy)
+        return strategy

+ 62 - 7
feature/strategy_iv.py

@@ -5,13 +5,14 @@
 @desc: iv值及单调性筛选类
 """
 from itertools import combinations_with_replacement
-from typing import List
+from typing import List, Dict
 
 import numpy as np
 import pandas as pd
+import scorecardpy as sc
 
-from entitys import DataSplitEntity, CandidateFeatureEntity, DataProcessConfigEntity
-from .feature_utils import f_judge_monto
+from entitys import DataSplitEntity, CandidateFeatureEntity
+from .feature_utils import f_judge_monto, f_get_corr
 from .filter_strategy_base import FilterStrategyBase
 
 
@@ -20,6 +21,50 @@ class StrategyIv(FilterStrategyBase):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
+    def _f_corr_filter(self, data: DataSplitEntity, candidate_dict: Dict[str, CandidateFeatureEntity]) -> List[str]:
+        # 相关性剔除变量
+        corr_threshold = self.data_process_config.corr_threshold
+        train_data = data.train_data
+        x_columns_candidate = list(candidate_dict.keys())
+        corr_df = f_get_corr(train_data[x_columns_candidate])
+        corr_dict = corr_df.to_dict()
+        for column, corr in corr_dict.items():
+            if column not in x_columns_candidate:
+                continue
+            for challenger_column, challenger_corr in corr.items():
+                if challenger_corr < corr_threshold or column == challenger_column \
+                        or challenger_column not in x_columns_candidate:
+                    continue
+                iv_max = candidate_dict[column].iv_max
+                challenger_iv_max = candidate_dict[challenger_column].iv_max
+                if iv_max > challenger_iv_max:
+                    x_columns_candidate.remove(challenger_column)
+                else:
+                    x_columns_candidate.remove(column)
+                    break
+        return x_columns_candidate
+
+    def _f_wide_filter(self, data: DataSplitEntity) -> List[str]:
+        # 粗筛变量
+        train_data = data.train_data
+        y_column = self.data_process_config.y_column
+        iv_threshold_wide = self.data_process_config.iv_threshold_wide
+        x_columns_candidate = self.data_process_config.x_columns_candidate
+        if x_columns_candidate is None or len(x_columns_candidate) == 0:
+            x_columns_candidate = train_data.columns.tolist()
+            x_columns_candidate.remove(y_column)
+
+        bins = sc.woebin(train_data[x_columns_candidate + [y_column]], y=y_column)
+        bins_iv_list = []
+        columns = []
+        for column, bin in bins.items():
+            total_iv = bin['total_iv'][0]
+            if total_iv < iv_threshold_wide:
+                continue
+            bins_iv_list.append({column: total_iv})
+            columns.append(column)
+        bins_iv_list.sort(key=lambda x: list(x.values())[0], reverse=True)
+        return columns
+
     def _f_get_best_bins(self, data: DataSplitEntity, x_column: str):
         # 贪婪搜索【训练集】及【测试集】加起来【iv】值最高的且【单调】的分箱
         interval = self.data_process_config.bin_search_interval
@@ -154,13 +199,23 @@ class StrategyIv(FilterStrategyBase):
 
         return iv_max, breaks_list
 
-    def filter(self, data: DataSplitEntity, *args, **kwargs):
-        x_columns_candidate = self.data_process_config.x_columns_candidate
+    def filter(self, data: DataSplitEntity, *args, **kwargs) -> List[CandidateFeatureEntity]:
+        # 粗筛
+        x_columns_candidate = self._f_wide_filter(data)
         candidate_num = self.data_process_config.candidate_num
-        candidate_list: List[CandidateFeatureEntity] = []
+
+        candidate_dict: Dict[str, CandidateFeatureEntity] = {}
         for x_column in x_columns_candidate:
             iv_max, breaks_list = self._f_get_best_bins(data, x_column)
-            candidate_list.append(CandidateFeatureEntity(x_column, breaks_list, iv_max))
+            candidate_dict[x_column] = CandidateFeatureEntity(x_column, breaks_list, iv_max)
+
+        # 相关性进一步剔除变量
+        x_columns_candidate = self._f_corr_filter(data, candidate_dict)
+        candidate_list: List[CandidateFeatureEntity] = []
+        for x_column, v in candidate_dict.items():
+            if x_column in x_columns_candidate:
+                candidate_list.append(v)
+
         candidate_list.sort(key=lambda x: x.iv_max, reverse=True)
 
         return candidate_list[0:candidate_num]

+ 5 - 0
metric_test2.py

@@ -44,6 +44,11 @@ if __name__ == "__main__":
     f_register_metric_func(AMetric)
     f_register_metric_func(BMetric)
     data_loader = DataLoaderExcel()
+
+    a = data_loader.get_data("cache/报表自动化需求-2411.xlsx")
+    a.write("cache/a.xlsx")
+
+
     monitor_metric = MonitorMetric("./cache/model_monitor_config1.json")
     monitor_metric.calculate_metric(data_loader=data_loader)
     monitor_metric.generate_report()

+ 8 - 2
strategy_test1.py

@@ -4,13 +4,19 @@
 @time: 2024/11/1
 @desc: 
 """
+import time
+
 from entitys import DataSplitEntity, DataProcessConfigEntity
+from feature import FilterStrategyFactory
 from feature.strategy_iv import StrategyIv
 
 if __name__ == "__main__":
+    time_now = time.time()
     import scorecardpy as sc
     dat = sc.germancredit()
     dat["creditability"] = dat["creditability"].apply(lambda x: 1 if x == "bad" else 0)
     data = DataSplitEntity(dat[:700], None, dat[700:])
-    strategy = StrategyIv(DataProcessConfigEntity.from_config('./config/data_process_config_template.json'))
-    strategy.filter(data)
+    filter_strategy_factory = FilterStrategyFactory(DataProcessConfigEntity.from_config('./config/data_process_config_template.json'))
+    strategy = filter_strategy_factory.get_strategy()
+    strategy.filter(data)
+    print(time.time() - time_now)