|
@@ -71,6 +71,7 @@ class StrategyIv(FilterStrategyBase):
|
|
|
iv_threshold = self.data_process_config.iv_threshold
|
|
|
special_values = self.data_process_config.get_special_values(x_column)
|
|
|
y_column = self.data_process_config.y_column
|
|
|
+ sample_rate = self.data_process_config.sample_rate
|
|
|
|
|
|
def _n0(x):
|
|
|
return sum(x == 0)
|
|
@@ -141,6 +142,11 @@ class StrategyIv(FilterStrategyBase):
|
|
|
iv = bins['total_iv'].values[0]
|
|
|
return iv
|
|
|
|
|
|
+ def _f_sampling(distribute_list: list, sample_rate: float):
|
|
|
+ # 采样,完全贪婪搜索耗时太长
|
|
|
+ sampled_list = distribute_list[::int(1 / sample_rate)]
|
|
|
+ return sampled_list
|
|
|
+
|
|
|
train_data = data.train_data
|
|
|
train_data_filter = train_data[~train_data[x_column].isin(special_values)]
|
|
|
train_data_filter = train_data_filter.sort_values(by=x_column, ascending=True)
|
|
@@ -157,7 +163,17 @@ class StrategyIv(FilterStrategyBase):
|
|
|
distribute_list = []
|
|
|
points_list = []
|
|
|
for bin_num in list(range(2, 6)):
|
|
|
- distribute_list.extend(_f_distribute_balls(int(1 / interval), bin_num))
|
|
|
+ distribute_list_cache = _f_distribute_balls(int(1 / interval), bin_num)
|
|
|
+ # 4箱及以上得采样,不然耗时太久
|
|
|
+ sample_num = 1000 * sample_rate
|
|
|
+ if sample_rate <= 0.15:
|
|
|
+ sample_num *= 2
|
|
|
+ if bin_num == 4 and len(distribute_list_cache) >= sample_num:
|
|
|
+ distribute_list_cache = _f_sampling(distribute_list_cache, sample_num / len(distribute_list_cache))
|
|
|
+ sample_num = 4000 * sample_rate
|
|
|
+ if bin_num == 5 and len(distribute_list_cache) >= sample_num:
|
|
|
+ distribute_list_cache = _f_sampling(distribute_list_cache, sample_num / len(distribute_list_cache))
|
|
|
+ distribute_list.extend(distribute_list_cache)
|
|
|
for distribute in distribute_list:
|
|
|
point_list_cache = []
|
|
|
point_percentile_list = [sum(distribute[0:idx + 1]) * interval for idx, _ in enumerate(distribute[0:-1])]
|