Jelajahi Sumber

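bugfix: bug修复 (honor include-columns in WOE bin selection and VIF filtering, fix correlation-detail printing, select the context implementation via BaseConfig.run_env)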

yq 1 bulan lalu
induk
melakukan
a9d5b54ba4
5 berkas diubah dengan 92 penambahan dan 20 penghapusan
  1. 3 0
      config/base_config.py
  2. 11 7
      feature/woe/entity.py
  3. 12 10
      feature/woe/strategy_woe.py
  4. 10 3
      init/__init__.py
  5. 56 0
      train_test.py

+ 3 - 0
config/base_config.py

@@ -16,5 +16,8 @@ class BaseConfig:
     train_path = os.path.join(".", "cache", "train")
     os.makedirs(train_path, exist_ok=True)
 
+    # 运行环境,目前影响上下文的储存
+    run_env = "jupyter"
+
     # 表格合并相同列名的列
     merge_table_column = True

+ 11 - 7
feature/woe/entity.py

@@ -76,9 +76,10 @@ class HomologousBinInfo():
      同一变量不同分箱下的特征信息
      """
 
-    def __init__(self, x_column: str, is_auto_bins: int = None):
+    def __init__(self, x_column: str, is_auto_bins: int = None, is_include: bool = False):
         self.x_column = x_column
         self.is_auto_bins = is_auto_bins
+        self.is_include = is_include
         self.bins_info: List[BinInfo] = []
 
     def add(self, bin_info: BinInfo):
@@ -131,12 +132,15 @@ class HomologousBinInfo():
         # 人工指定切分点的直接返回
         if not self.is_auto_bins:
             return BinInfo.ofConvertByDict(df_bins_info.iloc[0].to_dict())
-        df_bins_info_filter = df_bins_info[
-            (df_bins_info["is_qualified_iv_train"] == 1)
-            & (df_bins_info["is_qualified_monto_train_nsv"] == 1)
-            & (df_bins_info["is_qualified_trend_nsv"] == 1)
-            & (df_bins_info["is_qualified_psi"] == 1)
-            ]
+        if self.is_include:
+            df_bins_info_filter = df_bins_info
+        else:
+            df_bins_info_filter = df_bins_info[
+                (df_bins_info["is_qualified_iv_train"] == 1)
+                & (df_bins_info["is_qualified_monto_train_nsv"] == 1)
+                & (df_bins_info["is_qualified_trend_nsv"] == 1)
+                & (df_bins_info["is_qualified_psi"] == 1)
+                ]
         # 选取单调性变化最少,iv最大,psi 最小的分箱
         df_bins_info_filter.sort_values(by=["monto_shift_nsv", "trend_shift_nsv", "iv", "psi"],
                                         ascending=[True, True, False, True], inplace=True)

+ 12 - 10
feature/woe/strategy_woe.py

@@ -221,7 +221,7 @@ class StrategyWoe(FeatureStrategyBase):
             is_auto_bins = 0
         else:
             points_list_nsv = _get_points(train_data_ascending_nsv, x_column)
-        homo_bin_info = HomologousBinInfo(x_column, is_auto_bins)
+        homo_bin_info = HomologousBinInfo(x_column, is_auto_bins, self.ml_config.is_include(x_column))
         # 计算iv psi monto_shift等
         for points in points_list_nsv:
             bin_info = BinInfo()
@@ -300,11 +300,11 @@ class StrategyWoe(FeatureStrategyBase):
                 continue
             bin_test = bins_test[column]
             test_iv = bin_test['total_iv'][0].round(3)
-            iv = train_iv + test_iv
+            iv = round(train_iv + test_iv, 3)
             psi = f_get_psi(bin_train, bin_test)
-            if psi >= psi_threshold and not self.ml_config.is_include(column):
-                filter_fast_overview = f"{filter_fast_overview}{column} 因为psi【{psi}】大于阈值被剔除\n"
-                continue
+            # if psi >= psi_threshold and not self.ml_config.is_include(column):
+            #     filter_fast_overview = f"{filter_fast_overview}{column} 因为psi【{psi}】大于阈值被剔除\n"
+            #     continue
             bin_info_fast[column] = BinInfo.ofConvertByDict(
                 {"x_column": column, "train_iv": train_iv, "iv": iv, "psi": psi, "points": breaks_list[column]}
             )
@@ -324,7 +324,7 @@ class StrategyWoe(FeatureStrategyBase):
         train_woe = sc.woebin_ply(train_data[x_columns], sc_woebin, print_info=False)
         corr_df = f_get_corr(train_woe)
         corr_dict = corr_df.to_dict()
-        filter_corr_overview = "corr_filter\n"
+        filter_corr_overview = "filter_corr\n"
         filter_corr_detail = {}
         # 依次判断每个变量对于其它变量的相关性
         for column, corr in corr_dict.items():
@@ -355,7 +355,7 @@ class StrategyWoe(FeatureStrategyBase):
             for c in column_remove:
                 if c in x_columns:
                     x_columns.remove(c)
-            if overview != "":
+            if len(column_remove) != 0:
                 filter_corr_overview = f"{filter_corr_overview}{overview}\n"
                 filter_corr_detail[column] = column_remove
         for column in list(bin_info_dict.keys()):
@@ -383,7 +383,7 @@ class StrategyWoe(FeatureStrategyBase):
             bin_info = bin_info_dict[column]
             bin_info.vif = vif
             bin_info_dict[column] = bin_info
-            if vif < vif_threshold:
+            if vif < vif_threshold or self.ml_config.is_include(column):
                 continue
             filter_vif_overview = f"{filter_vif_overview}{column} 因为vif【{vif}】大于阈值被剔除\n"
             filter_vif_detail.append(column)
@@ -519,6 +519,9 @@ class StrategyWoe(FeatureStrategyBase):
         from IPython import display
 
         def detail_print(detail):
+            if isinstance(detail, str):
+                detail = [detail]
+
             if isinstance(detail, list):
                 for column in detail:
                     homo_bin_info_numeric = homo_bin_info_numeric_set.get(column)
@@ -545,8 +548,7 @@ class StrategyWoe(FeatureStrategyBase):
                 for column, challenger_columns in detail.items():
                     print(f"-----相关性筛选保留的【{column}】-----")
                     detail_print(column)
-                    for challenger_column in challenger_columns:
-                        detail_print(challenger_column)
+                    detail_print(challenger_columns)
 
         train_data = data.train_data
         test_data = data.test_data

+ 10 - 3
init/__init__.py

@@ -10,15 +10,18 @@ import threading
 import matplotlib
 from contextvars import ContextVar
 
+from config import BaseConfig
+
 matplotlib.use('Agg')
 
 import matplotlib.pyplot as plt
 
 __all__ = ['init', 'warning_ignore', "context"]
 
+
 class Context:
     def __init__(self):
-        # 上下文,只在当前线程内有效,notebook下会失效
+        # 上下文,适合notebook下单个用户
         self._instance_lock = threading.Lock()
         self.context = {}
 
@@ -33,9 +36,10 @@ class Context:
         data = {"overview": overview, "detail": detail}
         self.set(key, data)
 
+
 class ContexThreading:
     def __init__(self):
-        # 上下文,只在当前线程内有效,notebook下会失效
+        # 上下文,web下多用户需要线程隔离,notebook下会失效
         self.context = ContextVar('context')
         self.context.set({})
 
@@ -53,7 +57,10 @@ class ContexThreading:
         self.set(key, data)
 
 
-context = Context()
+if BaseConfig.run_env == "jupyter":
+    context = Context()
+else:
+    context = ContexThreading()
 
 
 def init():

+ 56 - 0
train_test.py

@@ -20,10 +20,66 @@ if __name__ == "__main__":
     dat.columns = dat_columns
 
     dat["creditability"] = dat["creditability"].apply(lambda x: 1 if x == "bad" else 0)
+
+    dat["credit_amount_corr1"] = dat["credit_amount"] * 2
+    dat["credit_amount_corr2"] = dat["credit_amount"] * 3
+
     data = DataSplitEntity(train_data=dat[:709], test_data=dat[709:])
 
     # 训练并生成报告
     train_pipeline = Pipeline(MlConfigEntity.from_config('./config/ml_config_template.json'), data)
+    # 特征处理
+    cfg = {
+        "project_name": "demo",
+        # jupyter下输出内容
+        "jupyter_print": True,
+        # 是否开启粗分箱
+        "format_bin": False,
+        # 变量切分点搜索采样率
+        "bin_sample_rate": 0.01,
+        # 最多保留候选变量数
+        "max_feature_num": 10,
+        # 单调性允许变化次数
+        "monto_shift_threshold": 1,
+        "iv_threshold": 0.01,
+        "corr_threshold": 0.4,
+        "psi_threshold": 0.2,
+        "vif_threshold": 10,
+        # 压力测试
+        "stress_test": True,
+        "stress_sample_times": 10,
+        # 特殊值
+        "special_values": {"age_in_years": [36]},
+        # 手动定义切分点,字符型的变量以'%,%'合并枚举值
+        "breaks_list": {
+            #                 'duration_in_month': [12, 18, 48],
+            'credit_amount': [2000, 3500, 4000, 7000],
+            'purpose': ['retraining%,%car (used)', 'radio/television', 'furniture/equipment%,%business%,%repairs',
+                        'domestic appliances%,%education%,%car (new)%,%others'],
+            #                 'age_in_years': [27, 34, 58]
+        },
+        # y
+        "y_column": "creditability",
+        # 候选变量
+        "x_columns": [
+            "duration_in_month",
+            "credit_amount",
+            "age_in_years",
+            "purpose",
+            "credit_history",
+
+            "credit_amount_corr1",
+            "credit_amount_corr2",
+        ],
+        "columns_anns": {
+            "age_in_years": "年龄",
+            "credit_history": "借贷历史"
+        },
+        "columns_exclude": [],
+        # "columns_include": ["age_in_years"],
+    }
+
+    train_pipeline = Pipeline(data=data, **cfg)
     train_pipeline.train()
     train_pipeline.report()