Browse Source

add: 贪婪搜索增加采样

yq 4 tháng trước cách đây
mục cha
commit
145dbaf03d
3 tập tin đã thay đổi với 27 bổ sung3 xóa
  1. 9 1
      entitys/data_process_config_entity.py
  2. 17 1
      feature/strategy_iv.py
  3. 1 1
      strategy_test1.py

+ 9 - 1
entitys/data_process_config_entity.py

@@ -15,7 +15,8 @@ from enums import ResultCodesEnum
 class DataProcessConfigEntity():
     def __init__(self, y_column: str, x_columns_candidate: List[str] = None, fill_method: str = None, split_method: str = None,
                  feature_search_strategy: str = 'iv', bin_search_interval: float = 0.05, iv_threshold: float = 0.03,
-                 iv_threshold_wide: float = 0.05, corr_threshold: float = 0.4, x_candidate_num: int = 10, special_values: Union[dict, list] = None):
+                 iv_threshold_wide: float = 0.05, corr_threshold: float = 0.4, sample_rate: float = 0.1,
+                 x_candidate_num: int = 10, special_values: Union[dict, list] = None):
         # 定义y变量
         self._y_column = y_column
 
@@ -48,6 +49,13 @@ class DataProcessConfigEntity():
         # 变量相关性阈值
         self._corr_threshold = corr_threshold
 
+        # 贪婪搜索采样比例,只针对4箱5箱时有效
+        self._sample_rate = sample_rate
+
+    @property
+    def sample_rate(self):
+        return self._sample_rate
+
     @property
     def corr_threshold(self):
         return self._corr_threshold

+ 17 - 1
feature/strategy_iv.py

@@ -71,6 +71,7 @@ class StrategyIv(FilterStrategyBase):
         iv_threshold = self.data_process_config.iv_threshold
         special_values = self.data_process_config.get_special_values(x_column)
         y_column = self.data_process_config.y_column
+        sample_rate = self.data_process_config.sample_rate
 
         def _n0(x):
             return sum(x == 0)
@@ -141,6 +142,11 @@ class StrategyIv(FilterStrategyBase):
             iv = bins['total_iv'].values[0]
             return iv
 
+        def _f_sampling(distribute_list: list, sample_rate: float):
+            # 采样,完全贪婪搜索耗时太长
+            sampled_list = distribute_list[::int(1 / sample_rate)]
+            return sampled_list
+
         train_data = data.train_data
         train_data_filter = train_data[~train_data[x_column].isin(special_values)]
         train_data_filter = train_data_filter.sort_values(by=x_column, ascending=True)
@@ -157,7 +163,17 @@ class StrategyIv(FilterStrategyBase):
         distribute_list = []
         points_list = []
         for bin_num in list(range(2, 6)):
-            distribute_list.extend(_f_distribute_balls(int(1 / interval), bin_num))
+            distribute_list_cache = _f_distribute_balls(int(1 / interval), bin_num)
+            # 4箱及以上得采样,不然耗时太久
+            sample_num = 1000 * sample_rate
+            if sample_rate <= 0.15:
+                sample_num *= 2
+            if bin_num == 4 and len(distribute_list_cache) >= sample_num:
+                distribute_list_cache = _f_sampling(distribute_list_cache, sample_num / len(distribute_list_cache))
+            sample_num = 4000 * sample_rate
+            if bin_num == 5 and len(distribute_list_cache) >= sample_num:
+                distribute_list_cache = _f_sampling(distribute_list_cache, sample_num / len(distribute_list_cache))
+            distribute_list.extend(distribute_list_cache)
         for distribute in distribute_list:
             point_list_cache = []
             point_percentile_list = [sum(distribute[0:idx + 1]) * interval for idx, _ in enumerate(distribute[0:-1])]

+ 1 - 1
strategy_test1.py

@@ -18,5 +18,5 @@ if __name__ == "__main__":
     data = DataSplitEntity(dat[:700], None, dat[700:])
     filter_strategy_factory= FilterStrategyFactory(DataProcessConfigEntity.from_config('./config/data_process_config_template.json'))
     strategy = filter_strategy_factory.get_strategy()
-    strategy.filter(data)
+    a = strategy.filter(data)
     print(time.time() - time_now)