
Merge branch 'dev-v1.0.0' of http://101.126.81.2:18002/model/easy-ml into dev-v1.0.0

# Conflicts:
#	monitor/report_generate.py
qiuya, 4 months ago
parent
current commit
01356e12f1

+ 1 - 0
.gitignore

@@ -61,3 +61,4 @@ target/
 /logs
 /cache
 */image
+*/~$*

+ 2 - 2
commom/__init__.py

@@ -7,7 +7,7 @@
 from .logger import get_logger
 from .placeholder_func import f_fill_placeholder
 from .user_exceptions import GeneralException
-from .utils import f_get_clazz_in_module, f_clazz_to_json, f_get_date, f_get_datetime, f_save_train_df
+from .utils import f_get_clazz_in_module, f_clazz_to_json, f_get_date, f_get_datetime, f_save_train_df, f_format_float
 
 __all__ = ['f_get_clazz_in_module', 'f_clazz_to_json', 'GeneralException', 'get_logger', 'f_fill_placeholder',
-           'f_get_date', 'f_get_datetime', 'f_save_train_df']
+           'f_get_date', 'f_get_datetime', 'f_save_train_df', 'f_format_float']

+ 4 - 0
commom/utils.py

@@ -16,6 +16,10 @@ import pytz
 from config import BaseConfig
 
 
+def f_format_float(num: float, n: int = 3):
+    return f"{num:.{n}f}"
+
+
 def f_get_date(offset: int = 0, connect: str = "-") -> str:
     current_date = datetime.datetime.now(pytz.timezone("Asia/Shanghai")).date() + datetime.timedelta(days=offset)
     return current_date.strftime(f"%Y{connect}%m{connect}%d")
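
A quick sanity check of the new formatter (minimal sketch; with the corrected format spec there is no leading space in the output):

    from commom import f_format_float

    f_format_float(0.123456)  # -> '0.123' (defaults to 3 decimal places)
    f_format_float(70.0, 2)   # -> '70.00', used for the percentage strings below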

+ 16 - 0
config/data_process_config_template.json

@@ -0,0 +1,16 @@
+{
+  "sample_rate": 0.01,
+  "bin_search_interval": 0.05,
+  "feature_search_strategy": "iv",
+  "x_candidate_num": 10,
+  "special_values": null,
+  "y_column": "creditability",
+  "x_columns_candidate": [
+    "duration_in_month",
+    "credit_amount",
+    "age_in_years",
+    "purpose",
+    "credit_history",
+    "savings_account_and_bonds"
+  ]
+}
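
The keys mirror the DataProcessConfigEntity constructor below, so the template can be loaded directly (sketch; assumes from_config passes the parsed JSON through as keyword arguments, which is how the test scripts at the bottom use it):

    from entitys import DataProcessConfigEntity

    config = DataProcessConfigEntity.from_config("./config/data_process_config_template.json")
    config.y_column       # -> 'creditability'
    config.candidate_num  # -> 10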

+ 3 - 0
config/train_config_template.json

@@ -0,0 +1,3 @@
+{
+  "model_type": "lr"
+}

+ 3 - 2
entitys/__init__.py

@@ -6,14 +6,15 @@
 """
 from .train_config_entity import TrainConfigEntity
 from .data_process_config_entity import DataProcessConfigEntity
-from .data_feaure_entity import DataFeatureEntity, DataSplitEntity, DataPreparedEntity
+from .data_feaure_entity import DataFeatureEntity, DataSplitEntity, DataPreparedEntity, CandidateFeatureEntity
 from .db_config_entity import DbConfigEntity
 from .metric_config_entity import MetricConfigEntity
 from .metric_entity import MetricTrainEntity, MetricFucEntity
 from .monitor_metric_config_entity import MonitorMetricConfigEntity
 
 __all__ = ['DataFeatureEntity', 'DbConfigEntity', 'MetricTrainEntity', 'MonitorMetricConfigEntity', 'MetricConfigEntity',
-           'MetricFucEntity', 'DataSplitEntity', 'DataProcessConfigEntity', 'TrainConfigEntity', 'DataPreparedEntity']
+           'MetricFucEntity', 'DataSplitEntity', 'DataProcessConfigEntity', 'TrainConfigEntity', 'DataPreparedEntity',
+           'CandidateFeatureEntity']
 
 if __name__ == "__main__":
     pass

+ 59 - 8
entitys/data_feaure_entity.py

@@ -4,10 +4,40 @@
 @time: 2024/11/1
 @desc: 
 """
+
 import pandas as pd
 
+from commom import f_format_float
+
+
+class CandidateFeatureEntity():
+    """
+    Feature information retained after feature selection
+    """
+
+    def __init__(self, x_column: str, breaks_list: list = None, iv_max: float = None):
+        self._x_column = x_column
+        self._breaks_list = breaks_list
+        self._iv_max = iv_max
+
+    @property
+    def x_column(self):
+        return self._x_column
+
+    @property
+    def breaks_list(self):
+        return self._breaks_list
+
+    @property
+    def iv_max(self):
+        return self._iv_max
+
 
 class DataFeatureEntity():
+    """
+    Prepared feature data
+    """
+
     def __init__(self, data: pd.DataFrame, x_columns: list, y_column: str):
         self._data = data
         self._x_columns = x_columns
@@ -33,6 +63,10 @@ class DataFeatureEntity():
 
 
 class DataPreparedEntity():
+    """
+    Prepared train/val/test feature data
+    """
+
     def __init__(self, train_data: DataFeatureEntity, val_data: DataFeatureEntity, test_data: DataFeatureEntity):
         self._train_data = train_data
         self._val_data = val_data
@@ -50,13 +84,15 @@ class DataPreparedEntity():
     def test_data(self):
         return self._test_data
 
-
 class DataSplitEntity():
-    def __init__(self, train_data: pd.DataFrame, val_data: pd.DataFrame, test_data: pd.DataFrame, y_column: str):
+    """
+    Train/test split of the raw data
+    """
+
+    def __init__(self, train_data: pd.DataFrame, val_data: pd.DataFrame, test_data: pd.DataFrame):
         self._train_data = train_data
         self._val_data = val_data
         self._test_data = test_data
-        self._y_column = y_column
 
     @property
     def train_data(self):
@@ -70,10 +106,25 @@ class DataSplitEntity():
     def test_data(self):
         return self._test_data
 
-    @property
-    def y_column(self):
-        return self._y_column
-
-
+    def get_distribution(self, y_column) -> pd.DataFrame:
+        df = pd.DataFrame()
+        train_data_len = len(self._train_data)
+        test_data_len = len(self._test_data)
+        total = train_data_len + test_data_len
+        train_bad_len = len(self._train_data[self._train_data[y_column] == 1])
+        test_bad_len = len(self._test_data[self._test_data[y_column] == 1])
+        bad_total = train_bad_len + test_bad_len
+
+        df["样本"] = ["训练集", "测试集", "合计"]
+        df["样本数"] = [train_data_len, test_data_len, total]
+        df["样本占比"] = [f"{f_format_float(train_data_len / total * 100, 2)}%",
+                      f"{f_format_float(test_data_len / total * 100, 2)}%", "100%"]
+        df["坏样本数"] = [train_bad_len, test_bad_len, bad_total]
+        df["坏样本比例"] = [f"{f_format_float(train_bad_len / train_data_len * 100, 2)}%",
+                       f"{f_format_float(test_bad_len / test_data_len * 100, 2)}%",
+                       f"{f_format_float(bad_total / total * 100, 2)}%"]
+
+        return df
+    
 if __name__ == "__main__":
     pass
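
Usage sketch for the new get_distribution helper, using the same germancredit split as strategy_test1.py further down:

    import scorecardpy as sc
    from entitys import DataSplitEntity

    dat = sc.germancredit()
    dat["creditability"] = dat["creditability"].apply(lambda x: 1 if x == "bad" else 0)
    data = DataSplitEntity(dat[:700], None, dat[700:])
    # returns the 样本 / 样本数 / 样本占比 / 坏样本数 / 坏样本比例 table embedded in the report
    print(data.get_distribution("creditability"))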

+ 77 - 3
entitys/data_process_config_entity.py

@@ -6,26 +6,77 @@
 """
 import json
 import os
+from typing import List, Union
 
 from commom import GeneralException
 from enums import ResultCodesEnum
 
-from sklearn.model_selection import train_test_split
-
 
 class DataProcessConfigEntity():
-    def __init__(self, y_column: str, fill_method: str, split_method: str):
+    def __init__(self, y_column: str, x_columns_candidate: List[str] = None, fill_method: str = None,
+                 split_method: str = None, feature_search_strategy: str = 'iv', bin_search_interval: float = 0.05,
+                 iv_threshold: float = 0.03, iv_threshold_wide: float = 0.05, corr_threshold: float = 0.4,
+                 sample_rate: float = 0.1, x_candidate_num: int = 10, special_values: Union[dict, list] = None):
+
         # target (y) column
         self._y_column = y_column
+
+        # candidate x columns
+        self._x_columns_candidate = x_columns_candidate
+
         # missing-value fill method
         self._fill_method = fill_method
+
         # data split method
         self._split_method = split_method
 
+        # optimal-feature search strategy
+        self._feature_search_strategy = feature_search_strategy
+
+        # IV threshold for variable selection
+        self._iv_threshold = iv_threshold
+
+        # IV threshold for coarse variable filtering
+        self._iv_threshold_wide = iv_threshold_wide
+
+        # data granularity for the greedy bin search; should be between 0.01 and 0.1
+        self._bin_search_interval = bin_search_interval
+
+        # how many x variables to keep in the end
+        self._x_candidate_num = x_candidate_num
+
+        self._special_values = special_values
+
+        # variable correlation threshold
+        self._corr_threshold = corr_threshold
+
+        # sampling rate for the greedy search; only applies to 4 and 5 bins
+        self._sample_rate = sample_rate
+
+    @property
+    def sample_rate(self):
+        return self._sample_rate
+
+    @property
+    def corr_threshold(self):
+        return self._corr_threshold
+
+    @property
+    def iv_threshold_wide(self):
+        return self._iv_threshold_wide
+
+    @property
+    def candidate_num(self):
+        return self._x_candidate_num
+
     @property
     def y_column(self):
         return self._y_column
 
+    @property
+    def x_columns_candidate(self):
+        return self._x_columns_candidate
+
     @property
     def fill_method(self):
         return self._fill_method
@@ -34,6 +85,29 @@ class DataProcessConfigEntity():
     def split_method(self):
         return self._split_method
 
+    @property
+    def feature_search_strategy(self):
+        return self._feature_search_strategy
+
+    @property
+    def iv_threshold(self):
+        return self._iv_threshold
+
+    @property
+    def bin_search_interval(self):
+        return self._bin_search_interval
+
+    @property
+    def special_values(self):
+        return self._special_values
+
+    def get_special_values(self, column: str = None):
+        if column is None or isinstance(self._special_values, list):
+            return self._special_values
+        if isinstance(self._special_values, dict) and column is not None:
+            return self._special_values.get(column, [])
+        return []
+
     @staticmethod
     def from_config(config_path: str):
         """

+ 50 - 8
entitys/metric_entity.py

@@ -4,25 +4,46 @@
 @time: 2024/11/1
 @desc: collection of common metric entities
 """
+from typing import Union
+
 import pandas as pd
 
+from commom import f_format_float
+
 
 class MetricTrainEntity():
     """
     Model training result metrics
     """
 
-    def __init__(self, auc: float, ks: float):
-        self._auc = auc
-        self._ks = ks
+    def __init__(self, train_auc: float, train_ks: float, test_auc: float, test_ks: float,
+                 train_perf_image_path: str = None, test_perf_image_path: str = None):
+        self._train_auc = f_format_float(train_auc)
+        self._train_ks = f_format_float(train_ks)
+        self._train_perf_image_path = train_perf_image_path
+
+        self._test_auc = f_format_float(test_auc)
+        self._test_ks = f_format_float(test_ks)
+        self._test_perf_image_path = test_perf_image_path
+
+    def __str__(self):
+        return f"train_auc:{self._train_auc} train_ks:{self._train_ks}\ntest_auc:{self._test_auc} test_ks:{self._test_ks}"
+
+    @property
+    def train_auc(self):
+        return self._train_auc
+
+    @property
+    def train_ks(self):
+        return self._train_ks
 
     @property
-    def auc(self):
-        return self._auc
+    def test_auc(self):
+        return self._test_auc
 
     @property
-    def ks(self):
-        return self._ks
+    def test_ks(self):
+        return self._test_ks
 
 
 class MetricFucEntity():
@@ -30,10 +51,28 @@ class MetricFucEntity():
     Result of a metric-calculation function
     """
 
-    def __init__(self, table: pd.DataFrame = None, value: str = None, image_path: str = None):
+    def __init__(self, table: pd.DataFrame = None, value: str = None, image_path: Union[str, list] = None,
+                 table_font_size=12, table_autofit=False, table_cell_width=None, image_size: int = 6):
         self._table = table
+        self._table_font_size = table_font_size
+        self._table_cell_width = table_cell_width
+        self._table_autofit = table_autofit
+
         self._value = value
         self._image_path = image_path
+        self._image_size = image_size
+
+    @property
+    def table_cell_width(self):
+        return self._table_cell_width
+
+    @property
+    def table_autofit(self):
+        return self._table_autofit
+
+    @property
+    def table_font_size(self):
+        return self._table_font_size
 
     @property
     def table(self) -> pd.DataFrame:
@@ -47,6 +86,9 @@ class MetricFucEntity():
     def image_path(self):
         return self._image_path
 
+    @property
+    def image_size(self):
+        return self._image_size
 
 if __name__ == "__main__":
     pass
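
MetricFucEntity now carries the table/image rendering hints consumed by Report below (sketch; the table content and png paths are placeholders):

    import pandas as pd
    from entitys import MetricFucEntity

    table_entity = MetricFucEntity(table=pd.DataFrame({"变量": ["age"], "IV": [0.12]}),
                                   table_font_size=12, table_cell_width=3)
    image_entity = MetricFucEntity(image_path=["age_train.png", "age_test.png"], image_size=4)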

+ 17 - 2
entitys/train_config_entity.py

@@ -8,13 +8,26 @@ import json
 import os
 
 from commom import GeneralException
-from enums import ResultCodesEnum
+from enums import ResultCodesEnum, ModelEnum
 
 
 class TrainConfigEntity():
-    def __init__(self, lr: float):
+    def __init__(self, model_type: str, lr: float = None):
+        # model type
+        self._model_type = model_type
         # learning rate
         self._lr = lr
+        # report template (defaults to None for model types without one)
+        self._template_path = None
+        if model_type == ModelEnum.LR.value:
+            self._template_path = "./template/模型开发报告模板_lr.docx"
+
+    @property
+    def template_path(self):
+        return self._template_path
+
+    @property
+    def model_type(self):
+        return self._model_type
 
     @property
     def lr(self):

+ 3 - 1
enums/__init__.py

@@ -5,7 +5,9 @@
 @desc: enum values
 """
 from .bins_strategy_enum import BinsStrategyEnum
+from .filter_strategy_enum import FilterStrategyEnum
+from .model_enum import ModelEnum
 from .placeholder_prefix_enum import PlaceholderPrefixEnum
 from .result_codes_enum import ResultCodesEnum
 
-__all__ = ['ResultCodesEnum', 'PlaceholderPrefixEnum', 'BinsStrategyEnum']
+__all__ = ['ResultCodesEnum', 'PlaceholderPrefixEnum', 'BinsStrategyEnum', 'FilterStrategyEnum', 'ModelEnum']

+ 11 - 0
enums/filter_strategy_enum.py

@@ -0,0 +1,11 @@
+# -*- coding: utf-8 -*-
+"""
+@author: yq
+@time: 2024/11/14
+@desc: feature selection strategy enum values
+"""
+from enum import Enum
+
+
+class FilterStrategyEnum(Enum):
+    IV = "iv"

+ 11 - 0
enums/model_enum.py

@@ -0,0 +1,11 @@
+# -*- coding: utf-8 -*-
+"""
+@author: yq
+@time: 2024/11/14
+@desc: model type enum values
+"""
+from enum import Enum
+
+
+class ModelEnum(Enum):
+    LR = "lr"

+ 3 - 2
feature/__init__.py

@@ -5,5 +5,6 @@
 @desc: feature mining
 """
 
-if __name__ == "__main__":
-    pass
+from .filter_strategy_factory import FilterStrategyFactory
+
+__all__ = ['FilterStrategyFactory']

+ 1 - 0
feature/feature_filter.py

@@ -12,6 +12,7 @@ class FeatureFilter():
         pass
 
     def feature_filter(self, data: DataSplitEntity) -> DataPreparedEntity:
+        # compute the optimal binning
         pass
 
 

+ 8 - 35
feature/feature_utils.py

@@ -4,12 +4,13 @@
 @time: 2023/12/28
 @desc: feature utilities
 """
+
 import pandas as pd
+import toad as td
 from sklearn.preprocessing import KBinsDiscretizer
+
 from entitys import DataSplitEntity
 from enums import BinsStrategyEnum
-import scorecardpy as sc
-import toad as td
 
 
 def f_get_bins(data: DataSplitEntity, feat: str, strategy: str = 'quantile', nbins: int = 10) -> pd.DataFrame:
@@ -85,7 +86,10 @@ def f_judge_monto(bd_list: list, pos_neg_cnt: int = 1) -> int:
             continue
         else:
             # record one sign change
+            start_tr = tmp_tr
             pos_neg_flag += 1
+            if pos_neg_flag > pos_neg_cnt:
+                return False
     # keep variables that satisfy the trend requirement
     if pos_neg_flag <= pos_neg_cnt:
         return True
@@ -112,40 +116,9 @@ def f_get_psi(train_data: DataSplitEntity, oot_data: DataSplitEntity) -> pd.Data
     return td.metrics.PSI(train_data, oot_data)
 
 
-def f_get_corr(data: DataSplitEntity, meth: str = 'spearman') -> pd.DataFrame:
-    return data.train_data().corr(method=meth)
+def f_get_corr(data: pd.DataFrame, meth: str = 'spearman') -> pd.DataFrame:
+    return data.corr(method=meth)
 
 
 def f_get_ivf(data: DataSplitEntity) -> pd.DataFrame:
     pass
-
-
-def f_get_best_bins(data: DataSplitEntity, x_column: str, special_values: list = []):
-    interval = 0.05
-    # greedy search for the monotonic binning with the highest train + test IV
-    train_data = data.train_data
-    train_data_filter = train_data[~train_data[x_column].isin(special_values)]
-    train_data_filter = train_data_filter.sort_values(by=x_column, ascending=True)
-    # special values get their own bin
-    # train_data_special_list = []
-    # for special in special_values:
-    #     df_cache = train_data[train_data[x_column] == special]
-    #     if len(df_cache) != 0:
-    #         train_data_special_list.append(df_cache)
-    x_train_data = train_data_filter[x_column]
-    # evaluate the 2 - 5 bin cases
-    bin_num_list = list(range(2, 6))
-    for bin_num in bin_num_list:
-        # build split points
-        point_list = []
-        init_point_percentile_list = [interval * i for i in range(1, bin_num)]
-        init_point_percentile_list.append(1 - point_list[-1])
-        for point_percentile in init_point_percentile_list:
-            point = x_train_data.iloc[int(len(x_train_data) * point_percentile)]
-            if point not in point_list:
-                point_list.append(point)
-        # get the binning result
-        bins = sc.woebin(train_data, y=data.y_column, breaks_list=point_list)
-        # monotonicity check
-
-    pass

+ 32 - 0
feature/filter_strategy_base.py

@@ -0,0 +1,32 @@
+# -*- coding:utf-8 -*-
+"""
+@author: yq
+@time: 2024/1/2
+@desc: feature selection base class
+"""
+import abc
+from typing import Dict, List
+
+from entitys import DataProcessConfigEntity, DataPreparedEntity, CandidateFeatureEntity, MetricFucEntity
+
+
+class FilterStrategyBase(metaclass=abc.ABCMeta):
+
+    def __init__(self, data_process_config: DataProcessConfigEntity, *args, **kwargs):
+        self._data_process_config = data_process_config
+
+    @property
+    def data_process_config(self):
+        return self._data_process_config
+
+    @abc.abstractmethod
+    def filter(self, *args, **kwargs) -> Dict[str, CandidateFeatureEntity]:
+        pass
+
+    @abc.abstractmethod
+    def feature_generate(self, *args, **kwargs) -> DataPreparedEntity:
+        pass
+
+    @abc.abstractmethod
+    def feature_report(self, *args, **kwargs) -> Dict[str, MetricFucEntity]:
+        pass

+ 23 - 0
feature/filter_strategy_factory.py

@@ -0,0 +1,23 @@
+# -*- coding: utf-8 -*-
+"""
+@author: yq
+@time: 2024/11/25
+@desc: feature selection strategy factory
+"""
+from entitys import DataProcessConfigEntity
+from enums import FilterStrategyEnum
+from .filter_strategy_base import FilterStrategyBase
+from .strategy_iv import StrategyIv
+
+
+class FilterStrategyFactory():
+
+    def __init__(self, data_process_config: DataProcessConfigEntity, *args, **kwargs):
+        self._data_process_config = data_process_config
+        self.strategy_map = {
+            FilterStrategyEnum.IV.value: StrategyIv(data_process_config, *args, **kwargs)
+        }
+
+    def get_strategy(self) -> FilterStrategyBase:
+        strategy = self.strategy_map.get(self._data_process_config.feature_search_strategy)
+        return strategy
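
Note that get_strategy returns None when feature_search_strategy is not registered in strategy_map; a caller-side guard might look like this (sketch):

    from commom import GeneralException
    from enums import ResultCodesEnum

    strategy = FilterStrategyFactory(data_process_config).get_strategy()
    if strategy is None:
        raise GeneralException(ResultCodesEnum.ILLEGAL_PARAMS,
                               message=f"策略【{data_process_config.feature_search_strategy}】没有实现")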

+ 371 - 0
feature/strategy_iv.py

@@ -0,0 +1,371 @@
+# -*- coding:utf-8 -*-
+"""
+@author: yq
+@time: 2024/1/2
+@desc: IV and monotonicity based feature filter
+"""
+from itertools import combinations_with_replacement
+from typing import List, Dict
+
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import scorecardpy as sc
+import seaborn as sns
+from pandas.core.dtypes.common import is_numeric_dtype
+from tqdm import tqdm
+
+from entitys import DataSplitEntity, CandidateFeatureEntity, DataPreparedEntity, DataFeatureEntity, MetricFucEntity
+from init import f_get_save_path
+from .feature_utils import f_judge_monto, f_get_corr
+from .filter_strategy_base import FilterStrategyBase
+
+plt.rcParams['figure.figsize'] = (8, 8)
+
+
+class StrategyIv(FilterStrategyBase):
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def _f_save_var_trend(self, bins, x_columns_candidate, prefix):
+        image_path_list = []
+        for k in x_columns_candidate:
+            bin_df = bins[k]
+            # bin_df["bin"] = bin_df["bin"].apply(lambda x: re.sub(r"(\d+\.\d+)",
+            #                                                      lambda m: "{:.2f}".format(float(m.group(0))), x))
+            sc.woebin_plot(bin_df)
+            path = f_get_save_path(f"{prefix}_{k}.png")
+            plt.savefig(path)
+            image_path_list.append(path)
+        return image_path_list
+
+    def _f_get_bins_by_breaks(self, data: pd.DataFrame, candidate_dict: Dict[str, CandidateFeatureEntity],
+                              y_column=None):
+        y_column = self.data_process_config.y_column if y_column is None else y_column
+        special_values = self.data_process_config.special_values
+        x_columns_candidate = list(candidate_dict.keys())
+        breaks_list = {}
+        for column, candidate in candidate_dict.items():
+            breaks_list[column] = candidate.breaks_list
+        bins = sc.woebin(data[x_columns_candidate + [y_column]], y=y_column, breaks_list=breaks_list,
+                         special_values=special_values)
+        return bins
+
+    def _f_corr_filter(self, data: DataSplitEntity, candidate_dict: Dict[str, CandidateFeatureEntity]) -> List[str]:
+        # drop correlated variables
+        corr_threshold = self.data_process_config.corr_threshold
+        train_data = data.train_data
+        x_columns_candidate = list(candidate_dict.keys())
+
+        bins = self._f_get_bins_by_breaks(train_data, candidate_dict)
+        train_woe = sc.woebin_ply(train_data[x_columns_candidate], bins)
+        corr_df = f_get_corr(train_woe)
+        corr_dict = corr_df.to_dict()
+        for column, corr in corr_dict.items():
+            column = column.replace("_woe", "")
+            if column not in x_columns_candidate:
+                continue
+            for challenger_column, challenger_corr in corr.items():
+                challenger_column = challenger_column.replace("_woe", "")
+                if challenger_corr < corr_threshold or column == challenger_column \
+                        or challenger_column not in x_columns_candidate:
+                    continue
+                iv_max = candidate_dict[column].iv_max
+                challenger_iv_max = candidate_dict[challenger_column].iv_max
+                if iv_max > challenger_iv_max:
+                    x_columns_candidate.remove(challenger_column)
+                else:
+                    x_columns_candidate.remove(column)
+                    break
+        return x_columns_candidate
+
+    def _f_wide_filter(self, data: DataSplitEntity) -> Dict:
+        # coarse variable filter
+        train_data = data.train_data
+        test_data = data.test_data
+        special_values = self.data_process_config.special_values
+        y_column = self.data_process_config.y_column
+        iv_threshold_wide = self.data_process_config.iv_threshold_wide
+        x_columns_candidate = self.data_process_config.x_columns_candidate
+        if x_columns_candidate is None or len(x_columns_candidate) == 0:
+            x_columns_candidate = train_data.columns.tolist()
+            x_columns_candidate.remove(y_column)
+
+        bins_train = sc.woebin(train_data[x_columns_candidate + [y_column]], y=y_column, special_values=special_values,
+                               bin_num_limit=5)
+
+        breaks_list = {}
+        for column, bin in bins_train.items():
+            breaks_list[column] = list(bin['breaks'])
+        bins_test = None
+        if test_data is not None and len(test_data) != 0:
+            bins_test = sc.woebin(test_data[x_columns_candidate + [y_column]], y=y_column, breaks_list=breaks_list,
+                                  special_values=special_values
+                                  )
+        bins_iv_dict = {}
+        for column, bin_train in bins_train.items():
+            train_iv = bin_train['total_iv'][0]
+            test_iv = 0
+            if bins_test is not None:
+                bin_test = bins_test[column]
+                test_iv = bin_test['total_iv'][0]
+            iv_max = train_iv + test_iv
+            if train_iv < iv_threshold_wide:
+                continue
+            bins_iv_dict[column] = {"iv_max": iv_max, "breaks_list": breaks_list[column]}
+        return bins_iv_dict
+
+    def _f_get_best_bins_numeric(self, data: DataSplitEntity, x_column: str):
+        # greedy search for the monotonic binning with the highest combined train + test IV
+        interval = self.data_process_config.bin_search_interval
+        iv_threshold = self.data_process_config.iv_threshold
+        special_values = self.data_process_config.get_special_values(x_column)
+        y_column = self.data_process_config.y_column
+        sample_rate = self.data_process_config.sample_rate
+
+        def _n0(x):
+            return sum(x == 0)
+
+        def _n1(x):
+            return sum(x == 1)
+
+        def _f_distribute_balls(balls, boxes):
+            # count the ways to place boxes - 1 dividers into balls - 1 gaps
+            total_ways = combinations_with_replacement(range(balls + boxes - 1), boxes - 1)
+            distribute_list = []
+            # iterate over all possible divider positions
+            for combo in total_ways:
+                # distribute the balls according to the divider positions
+                distribution = [0] * boxes
+                start = 0
+                for i, divider in enumerate(combo):
+                    distribution[i] = divider - start + 1
+                    start = divider + 1
+                distribution[-1] = balls - start  # balls in the last box
+                # make sure every box gets at least one ball
+                if all(x > 0 for x in distribution):
+                    distribute_list.append(distribution)
+            return distribute_list
+
+        def _get_sv_bins(df, x_column, y_column, special_values):
+            # special_values_bins
+            sv_bin_list = []
+            for special in special_values:
+                dtm = df[df[x_column] == special]
+                if len(dtm) != 0:
+                    dtm['bin'] = [str(special)] * len(dtm)
+                    binning = dtm.groupby(['bin'], group_keys=False)[y_column].agg(
+                        [_n0, _n1]).reset_index().rename(columns={'_n0': 'good', '_n1': 'bad'})
+                    binning['is_special_values'] = [True] * len(binning)
+                    sv_bin_list.append(binning)
+            return sv_bin_list
+
+        def _get_bins(df, x_column, y_column, breaks_list):
+            dtm = pd.DataFrame({'y': df[y_column], 'value': df[x_column]})
+            bstbrks = [-np.inf] + breaks_list + [np.inf]
+            labels = ['[{},{})'.format(bstbrks[i], bstbrks[i + 1]) for i in range(len(bstbrks) - 1)]
+            dtm.loc[:, 'bin'] = pd.cut(dtm['value'], bstbrks, right=False, labels=labels)
+            dtm['bin'] = dtm['bin'].astype(str)
+            bins = dtm.groupby(['bin'], group_keys=False)['y'].agg([_n0, _n1]) \
+                .reset_index().rename(columns={'_n0': 'good', '_n1': 'bad'})
+            bins['is_special_values'] = [False] * len(bins)
+            return bins
+
+        def _calculation_iv(bins):
+            bins['count'] = bins['good'] + bins['bad']
+            bins['badprob'] = bins['bad'] / bins['count']
+            # monotonicity check
+            bad_prob = bins[bins['is_special_values'] == False]['badprob'].values.tolist()
+            if not f_judge_monto(bad_prob):
+                return -1
+            # compute IV
+            infovalue = pd.DataFrame({'good': bins['good'], 'bad': bins['bad']}) \
+                .replace(0, 0.9) \
+                .assign(
+                DistrBad=lambda x: x.bad / sum(x.bad),
+                DistrGood=lambda x: x.good / sum(x.good)
+            ) \
+                .assign(iv=lambda x: (x.DistrBad - x.DistrGood) * np.log(x.DistrBad / x.DistrGood)) \
+                .iv
+            bins['bin_iv'] = infovalue
+            bins['total_iv'] = bins['bin_iv'].sum()
+            iv = bins['total_iv'].values[0]
+            return iv
+
+        def _f_sampling(distribute_list: list, sample_rate: float):
+            # sampling; an exhaustive greedy search takes too long
+            sampled_list = distribute_list[::int(1 / sample_rate)]
+            return sampled_list
+
+        train_data = data.train_data
+        train_data_filter = train_data[~train_data[x_column].isin(special_values)]
+        train_data_filter = train_data_filter.sort_values(by=x_column, ascending=True)
+        train_data_x = train_data_filter[x_column]
+
+        test_data = data.test_data
+        test_data_filter = None
+        if test_data is not None and len(test_data) != 0:
+            test_data_filter = test_data[~test_data[x_column].isin(special_values)]
+            test_data_filter = test_data_filter.sort_values(by=x_column, ascending=True)
+
+        # build candidate split points
+        # evaluate the 2 - 5 bin cases
+        distribute_list = []
+        points_list = []
+        for bin_num in list(range(2, 6)):
+            distribute_list_cache = _f_distribute_balls(int(1 / interval), bin_num)
+            # sample for 4+ bins, otherwise it takes too long
+            sample_num = 1000 * sample_rate
+            if sample_rate <= 0.15:
+                sample_num *= 2
+            if bin_num == 4 and len(distribute_list_cache) >= sample_num:
+                distribute_list_cache = _f_sampling(distribute_list_cache, sample_num / len(distribute_list_cache))
+            sample_num = 4000 * sample_rate
+            if bin_num == 5 and len(distribute_list_cache) >= sample_num:
+                distribute_list_cache = _f_sampling(distribute_list_cache, sample_num / len(distribute_list_cache))
+            distribute_list.extend(distribute_list_cache)
+        for distribute in distribute_list:
+            point_list_cache = []
+            point_percentile_list = [sum(distribute[0:idx + 1]) * interval for idx, _ in enumerate(distribute[0:-1])]
+            for point_percentile in point_percentile_list:
+                point = train_data_x.iloc[int(len(train_data_x) * point_percentile)]
+                if point not in point_list_cache:
+                    point_list_cache.append(point)
+            if point_list_cache not in points_list:
+                points_list.append(point_list_cache)
+        # filter by IV and monotonicity
+        iv_max = 0
+        breaks_list = []
+        train_sv_bin_list = _get_sv_bins(train_data, x_column, y_column, special_values)
+        test_sv_bin_list = None
+        if test_data_filter is not None:
+            test_sv_bin_list = _get_sv_bins(test_data, x_column, y_column, special_values)
+        for point_list in tqdm(points_list):
+            train_bins = _get_bins(train_data_filter, x_column, y_column, point_list)
+            # merge with the special_values bins before computing IV
+            for sv_bin in train_sv_bin_list:
+                train_bins = pd.concat((train_bins, sv_bin))
+            train_iv = _calculation_iv(train_bins)
+            # only constrain monotonicity and the IV threshold on the training set
+            if train_iv < iv_threshold:
+                continue
+
+            test_iv = 0
+            if test_data_filter is not None:
+                test_bins = _get_bins(test_data_filter, x_column, y_column, point_list)
+                for sv_bin in test_sv_bin_list:
+                    test_bins = pd.concat((test_bins, sv_bin))
+                test_iv = _calculation_iv(test_bins)
+            iv = train_iv + test_iv
+            if iv > iv_max:
+                iv_max = iv
+                breaks_list = point_list
+
+        return iv_max, breaks_list
+
+    def filter(self, data: DataSplitEntity, *args, **kwargs) -> Dict[str, CandidateFeatureEntity]:
+        # coarse filtering
+        bins_iv_dict = self._f_wide_filter(data)
+        x_columns_candidate = list(bins_iv_dict.keys())
+        candidate_num = self.data_process_config.candidate_num
+        candidate_dict: Dict[str, CandidateFeatureEntity] = {}
+        for x_column in x_columns_candidate:
+            if is_numeric_dtype(data.train_data[x_column]):
+                iv_max, breaks_list = self._f_get_best_bins_numeric(data, x_column)
+                candidate_dict[x_column] = CandidateFeatureEntity(x_column, breaks_list, iv_max)
+            else:
+                # categorical variables are handled by scorecardpy for now
+                candidate_dict[x_column] = CandidateFeatureEntity(x_column, bins_iv_dict[x_column]["breaks_list"],
+                                                                  bins_iv_dict[x_column]["iv_max"])
+
+        # further drop variables by correlation
+        x_columns_candidate = self._f_corr_filter(data, candidate_dict)
+        candidate_list: List[CandidateFeatureEntity] = []
+        for x_column, v in candidate_dict.items():
+            if x_column in x_columns_candidate:
+                candidate_list.append(v)
+
+        candidate_list.sort(key=lambda x: x.iv_max, reverse=True)
+        candidate_list = candidate_list[0:candidate_num]
+        candidate_dict = {}
+        for candidate in candidate_list:
+            candidate_dict[candidate.x_column] = candidate
+        return candidate_dict
+
+    def feature_generate(self, data: DataSplitEntity, candidate_dict: Dict[str, CandidateFeatureEntity], *args,
+                         **kwargs) -> DataPreparedEntity:
+        train_data = data.train_data
+        val_data = data.val_data
+        test_data = data.test_data
+        y_column = self.data_process_config.y_column
+        x_columns_candidate = list(candidate_dict.keys())
+        bins = self._f_get_bins_by_breaks(train_data, candidate_dict)
+
+        train_woe = sc.woebin_ply(train_data[x_columns_candidate], bins)
+        train_data_feature = DataFeatureEntity(pd.concat((train_woe, train_data[y_column]), axis=1),
+                                               train_woe.columns.tolist(), y_column)
+
+        val_data_feature = None
+        if val_data is not None and len(val_data) != 0:
+            val_woe = sc.woebin_ply(val_data[x_columns_candidate], bins)
+            val_data_feature = DataFeatureEntity(pd.concat((val_woe, val_data[y_column]), axis=1),
+                                                 train_woe.columns.tolist(), y_column)
+
+        test_data_feature = None
+        if test_data is not None and len(test_data) != 0:
+            test_woe = sc.woebin_ply(test_data[x_columns_candidate], bins)
+            test_data_feature = DataFeatureEntity(pd.concat((test_woe, test_data[y_column]), axis=1),
+                                                  train_woe.columns.tolist(), y_column)
+        return DataPreparedEntity(train_data_feature, val_data_feature, test_data_feature)
+
+    def feature_report(self, data: DataSplitEntity, candidate_dict: Dict[str, CandidateFeatureEntity], *args,
+                       **kwargs) -> Dict[str, MetricFucEntity]:
+        y_column = self.data_process_config.y_column
+        x_columns_candidate = list(candidate_dict.keys())
+        train_data = data.train_data
+        test_data = data.test_data
+
+        metric_value_dict = {}
+        # sample distribution
+        metric_value_dict["样本分布"] = MetricFucEntity(table=data.get_distribution(y_column), table_font_size=12,
+                                                    table_cell_width=3)
+        # variable IV and PSI
+        train_bins = self._f_get_bins_by_breaks(train_data, candidate_dict)
+        train_iv = {key_: [round(value_['total_iv'].max(), 4)] for key_, value_ in train_bins.items()}
+        train_iv = pd.DataFrame.from_dict(train_iv, orient='index', columns=['IV']).reset_index()
+        train_iv = train_iv.sort_values('IV', ascending=False).reset_index(drop=True)
+        train_iv.columns = ['变量', 'IV']
+
+        if test_data is not None and len(test_data) != 0:
+            # PSI can be computed by simply relabeling y to distinguish train vs. test
+            psi_df = pd.concat((train_data, test_data))
+            psi_df["#target#"] = [1] * len(train_data) + [0] * len(test_data)
+            psi = self._f_get_bins_by_breaks(psi_df, candidate_dict, y_column="#target#")
+            psi = {key_: [round(value_['total_iv'].max(), 4)] for key_, value_ in psi.items()}
+            psi = pd.DataFrame.from_dict(psi, orient='index', columns=['psi']).reset_index()
+            psi.columns = ['变量', 'psi']
+            train_iv = pd.merge(train_iv, psi, on="变量", how="left")
+
+            # variable trend - test set
+            test_bins = self._f_get_bins_by_breaks(test_data, candidate_dict)
+            image_path_list = self._f_save_var_trend(test_bins, x_columns_candidate, "test")
+            metric_value_dict["变量趋势-测试集"] = MetricFucEntity(image_path=image_path_list, image_size=4)
+
+        metric_value_dict["变量iv"] = MetricFucEntity(table=train_iv, table_font_size=12, table_cell_width=3)
+        # variable trend - training set
+        image_path_list = self._f_save_var_trend(train_bins, x_columns_candidate, "train")
+        metric_value_dict["变量趋势-训练集"] = MetricFucEntity(image_path=image_path_list, image_size=4)
+        # variable validity (correlation heatmap)
+        train_woe = sc.woebin_ply(train_data[x_columns_candidate], train_bins)
+        train_corr = f_get_corr(train_woe)
+        plt.figure(figsize=(12, 12))
+        sns.heatmap(train_corr, vmax=1, square=True, cmap='RdBu', annot=True)
+        plt.title('Variables Correlation', fontsize=15)
+        plt.yticks(rotation=0)
+        plt.xticks(rotation=90)
+        path = f_get_save_path("var_corr.png")
+        plt.savefig(path)
+        metric_value_dict["变量有效性"] = MetricFucEntity(image_path=path)
+
+        return metric_value_dict
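
The split-point search above enumerates positive integer compositions through _f_distribute_balls; the same enumeration can be written more directly with itertools.combinations over divider positions (equivalent standalone sketch, not the code used above):

    from itertools import combinations

    def distribute_balls(balls: int, boxes: int):
        # stars and bars: choose boxes - 1 divider positions among balls - 1 gaps
        return [[b - a for a, b in zip((0,) + combo, combo + (balls,))]
                for combo in combinations(range(1, balls), boxes - 1)]

    distribute_balls(4, 2)  # -> [[1, 3], [2, 2], [3, 1]]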

+ 17 - 1
init/__init__.py

@@ -2,8 +2,24 @@
 """
 @author: yq
 @time: 2024/10/31
-@desc: model and metric computation class initialization
+@desc: miscellaneous resource initialization
 """
 
+import os
+
+from commom import f_get_datetime
+from config import BaseConfig
+
+__all__ = ['f_get_save_path']
+
+save_path = os.path.join(BaseConfig.train_path, f"{f_get_datetime()}")
+os.makedirs(save_path, exist_ok=True)
+
+
+def f_get_save_path(file_name: str) -> str:
+    path = os.path.join(save_path, file_name)
+    return path
+
+
 if __name__ == "__main__":
     pass
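
All artifacts from one run (bin plots, the correlation heatmap, the generated docx) land in a single timestamped directory (usage sketch):

    from init import f_get_save_path

    f_get_save_path("var_corr.png")
    # -> <BaseConfig.train_path>/<datetime>/var_corr.png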

+ 5 - 0
metric_test2.py

@@ -44,6 +44,11 @@ if __name__ == "__main__":
     f_register_metric_func(AMetric)
     f_register_metric_func(BMetric)
     data_loader = DataLoaderExcel()
+
+    a = data_loader.get_data("cache/报表自动化需求-2411.xlsx")
+    a.to_excel("cache/a.xlsx")  # assuming get_data returns a pandas DataFrame
+
+
     monitor_metric = MonitorMetric("./cache/model_monitor_config1.json")
     monitor_metric.calculate_metric(data_loader=data_loader)
     monitor_metric.generate_report()

+ 14 - 2
model/__init__.py

@@ -4,11 +4,23 @@
 @time: 2023/12/28
 @desc: model-related
 """
-
+from commom import GeneralException
+from enums import ModelEnum, ResultCodesEnum
 from .model_base import ModelBase
 from .model_lr import ModelLr
 
-__all__ = ['ModelBase', 'ModelLr']
+__all__ = ['ModelBase', 'f_get_model']
+
+model_map = {
+    ModelEnum.LR.value: ModelLr
+}
+
+
+def f_get_model(model_type: str):
+    if model_type not in model_map.keys():
+        raise GeneralException(ResultCodesEnum.ILLEGAL_PARAMS, message=f"模型【{model_type}】没有实现")
+    return model_map[model_type]
+
 
 if __name__ == "__main__":
     pass
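
Usage sketch for the new model factory:

    from entitys import TrainConfigEntity
    from model import f_get_model

    model_clazz = f_get_model("lr")  # -> the ModelLr class
    model = model_clazz(TrainConfigEntity(model_type="lr"))
    # f_get_model("xgb") raises GeneralException: 模型【xgb】没有实现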

+ 2 - 2
model/model_base.py

@@ -8,7 +8,7 @@ import abc
 
 import pandas as pd
 
-from entitys import DataFeatureEntity, MetricTrainEntity, TrainConfigEntity
+from entitys import MetricTrainEntity, TrainConfigEntity, DataPreparedEntity
 
 
 class ModelBase(metaclass=abc.ABCMeta):
@@ -17,7 +17,7 @@ class ModelBase(metaclass=abc.ABCMeta):
         self._train_config = train_config
 
     @abc.abstractmethod
-    def train(self, data: DataFeatureEntity, *args, **kwargs) -> MetricTrainEntity:
+    def train(self, data: DataPreparedEntity, *args, **kwargs) -> MetricTrainEntity:
         pass
 
     @abc.abstractmethod

+ 17 - 9
model/model_lr.py

@@ -4,26 +4,34 @@
 @time: 2024/11/1
 @desc: 
 """
+
 import pandas as pd
 from sklearn.linear_model import LogisticRegression
+from toad.metrics import KS, AUC
 
-from entitys import DataFeatureEntity, MetricTrainEntity, TrainConfigEntity
+from entitys import MetricTrainEntity, TrainConfigEntity, DataPreparedEntity
 from .model_base import ModelBase
 
-from toad.metrics import KS, AUC
-
 
 class ModelLr(ModelBase):
     def __init__(self, train_config: TrainConfigEntity):
         super().__init__(train_config)
         self.lr = LogisticRegression(penalty='l1', C=0.9, solver='saga', n_jobs=-1)
 
-    def train(self, data: DataFeatureEntity, *args, **kwargs) -> MetricTrainEntity:
-        self.lr.fit(data.get_Xdata(), data.get_Ydata())
-        pred_y = self.predict(data.get_Xdata())
-        ks = KS(pred_y, data.get_Ydata())
-        auc = AUC(pred_y, data.get_Ydata())
-        return MetricTrainEntity(auc, ks)
+    def train(self, data: DataPreparedEntity, *args, **kwargs) -> MetricTrainEntity:
+        train_data = data.train_data
+        test_data = data.test_data
+        self.lr.fit(train_data.get_Xdata(), train_data.get_Ydata())
+
+        train_prob = self.lr.predict_proba(train_data.get_Xdata())[:, 1]
+        train_auc = AUC(train_prob, train_data.get_Ydata())
+        train_ks = KS(train_prob, train_data.get_Ydata())
+
+        test_prob = self.lr.predict_proba(test_data.get_Xdata())[:, 1]
+        test_auc = AUC(test_prob, test_data.get_Ydata())
+        test_ks = KS(test_prob, test_data.get_Ydata())
+
+        return MetricTrainEntity(train_auc, train_ks, test_auc, test_ks)
 
     def predict_prob(self, x: pd.DataFrame, *args, **kwargs):
         return self.lr.predict_proba(x)[:, 1]

+ 42 - 68
monitor/report_generate.py

@@ -7,15 +7,12 @@
 import os
 from typing import Dict
 
-import pandas as pd
-
 from docx import Document
 from docx.enum.table import WD_ALIGN_VERTICAL
 from docx.enum.text import WD_ALIGN_PARAGRAPH
 from docx.oxml import OxmlElement
 from docx.oxml.ns import qn
-from docx.shared import Inches, Cm
-from docx.shared import Pt
+from docx.shared import Inches, Cm, Pt
 
 from commom import GeneralException, f_get_datetime
 from config import BaseConfig
@@ -26,42 +23,37 @@ from enums import ResultCodesEnum, PlaceholderPrefixEnum
 class Report():
 
     @staticmethod
-    def _set_cell_width(cell):
-        text = cell.text
-        if len(text) >= 10:
-            cell.width = Cm(2)
-        elif len(text) >= 15:
-            cell.width = Cm(2.5)
-        elif len(text) >= 25:
-            cell.width = Cm(3)
-        else:
-            cell.width = Cm(1.5)
+    def _set_cell_width(table, table_cell_width):
+        for column in table.columns:
+            if table_cell_width is not None:
+                column.width = Cm(table_cell_width)
+            # elif len(text) >= 10:
+            #     cell.width = Cm(2)
+            # elif len(text) >= 15:
+            #     cell.width = Cm(2.5)
+            # elif len(text) >= 25:
+            #     cell.width = Cm(3)
+            # else:
+            #     cell.width = Cm(1.5)
 
     @staticmethod
-    def _set_cell_format(cell, pt=11):
-        cell.paragraphs[0].alignment = WD_ALIGN_PARAGRAPH.CENTER
-        cell.vertical_alignment = WD_ALIGN_VERTICAL.CENTER
-
-        # set the font
+    def _set_cell_format(cell, font_size=None):
         for paragraph in cell.paragraphs:
+            paragraph.alignment = WD_ALIGN_PARAGRAPH.CENTER
             for run in paragraph.runs:
-                # check whether the text contains Chinese characters
-                if any('\u4e00' <= char <= '\u9fff' for char in run.text):
-                    run.font.name = '宋体'  # use SimSun for Chinese text
-                else:
-                    run.font.name = 'Times New Roman'  # use Times New Roman otherwise
-            run.font.size = Pt(pt)
+                if font_size is not None:
+                    run.font.size = Pt(font_size)
+        cell.vertical_alignment = WD_ALIGN_VERTICAL.CENTER
 
     @staticmethod
-    def _merge_cell_column(pre_cell, curr_cell):
+    def _merge_cell_column(pre_cell, curr_cell, table_font_size, table_cell_width):
         if curr_cell.text == pre_cell.text:
             column_name = curr_cell.text
             pre_cell.merge(curr_cell)
             pre_cell.text = column_name
             for run in pre_cell.paragraphs[0].runs:
                 run.bold = True
-            Report._set_cell_format(pre_cell)
-            Report._set_cell_width(pre_cell)
+            Report._set_cell_format(pre_cell, table_font_size)
 
     @staticmethod
     def _set_table_singleBoard(table):
@@ -121,10 +113,6 @@ class Report():
                     run.text = ''
                 paragraph.runs[-1].text = text
 
-    @staticmethod
-    def _get_text_length(text):
-        return sum(3 if '\u4e00' <= char <= '\u9fff' else 1 for char in text)
-
     @staticmethod
     def _fill_table_placeholder(doc: Document, metric_value_dict: Dict[str, MetricFucEntity]):
         # replace table placeholders
@@ -132,6 +120,9 @@ class Report():
             for metric_code, metric_fuc_entity in metric_value_dict.items():
                 placeholder = Report._get_placeholder(PlaceholderPrefixEnum.TABLE, metric_code)
                 metric_table = metric_fuc_entity.table
+                table_font_size = metric_fuc_entity.table_font_size
+                table_autofit = metric_fuc_entity.table_autofit
+                table_cell_width = metric_fuc_entity.table_cell_width
                 if metric_table is None:
                     continue
                 if not placeholder in paragraph.text:
@@ -142,56 +133,32 @@ class Report():
                 table = doc.add_table(rows=metric_table.shape[0] + 1, cols=metric_table.shape[1])
                 table.alignment = WD_ALIGN_PARAGRAPH.CENTER
                 paragraph._element.addnext(table._element)
-
-                # compute cell widths from the column names; readjust any below the minimum width
-                # TODO: adjust cell widths based on both column names and content
-                a4_width = 21 - 2 * 3.18
-                total_columns = metric_table.shape[1]
-                col_lengthes = [Report._get_text_length(c) for c in metric_table.columns]
-                cell_width_unit = a4_width / sum(col_lengthes)
-                cell_widths = [c * cell_width_unit for c in col_lengthes]
-                min_cell_width = 1
-                adjusted_cell_widths = [max(c, min_cell_width) for c in cell_widths]
-                adjusted_width = sum(adjusted_cell_widths)
-                if adjusted_width > a4_width:
-                    excess_width = adjusted_width - a4_width
-                    excess_width_per_column = excess_width / total_columns
-                    adjusted_cell_widths = [max(min_cell_width, c - excess_width_per_column) for c in
-                                            adjusted_cell_widths]
-
                 # column names
                 for column_idx, column_name in enumerate(metric_table.columns):
                     cell = table.cell(0, column_idx)
                     cell.text = str(column_name)
                     for run in cell.paragraphs[0].runs:
                         run.bold = True
-                    Report._set_cell_format(cell, 11)
-                    Report._set_cell_width(cell)
-                    table.columns[column_idx].width = Cm(adjusted_cell_widths[column_idx])
-                    # Report._set_cell_width(cell, cell_widths[column_idx])
+                    Report._set_cell_format(cell, table_font_size)
                     # merge identical column names
                     if column_idx != 0 and BaseConfig.merge_table_column:
                         pre_cell = table.cell(0, column_idx - 1)
-                        Report._merge_cell_column(pre_cell, cell)
+                        Report._merge_cell_column(pre_cell, cell, table_font_size, table_cell_width)
                 # cell values
                 for row_idx, row in metric_table.iterrows():
                     for column_idx, value in enumerate(row):
                         cell = table.cell(row_idx + 1, column_idx)
-                        if "率" in metric_table.columns[column_idx] or (
-                                "率" in str(row[0]) and pd.notna(value) and (column_idx != 0)):
-                            value = f"{float(value) * 100:.2f}%" if pd.notna(value) else '/'
-                        else:
-                            value = str(value) if pd.notna(value) else '/'
-                        cell.text = value
-                        Report._set_cell_format(cell, 10.5)
-                        # Report._set_cell_width(cell)
+                        cell.text = str(value)
+                        Report._set_cell_format(cell, table_font_size)
                         # merge when the first data row also acts as a column header
                         if row_idx == 0:
-                            Report._merge_cell_column(table.cell(0, column_idx), cell)
+                            Report._merge_cell_column(table.cell(0, column_idx), cell, table_font_size,
+                                                      table_cell_width)
 
+                Report._set_cell_width(table, table_cell_width)
                 Report._set_table_singleBoard(table)
                 # disable table autofit
-                if len(metric_table.columns) <= 12:
+                if len(metric_table.columns) <= 12 or not table_autofit:
                     table.autofit = False
 
     @staticmethod
@@ -201,21 +168,26 @@ class Report():
             for metric_code, metric_fuc_entity in metric_value_dict.items():
                 placeholder = Report._get_placeholder(PlaceholderPrefixEnum.IMAGE, metric_code)
                 image_path = metric_fuc_entity.image_path
+                image_size = metric_fuc_entity.image_size
                 if image_path is None:
                     continue
                 if not placeholder in paragraph.text:
                     continue
-                if not os.path.exists(image_path):
-                    raise GeneralException(ResultCodesEnum.NOT_FOUND, message=f"文件【{image_path}】不存在")
+                if isinstance(image_path, str):
+                    image_path = [image_path]
+                for path in image_path:
+                    if not os.path.exists(path):
+                        raise GeneralException(ResultCodesEnum.NOT_FOUND, message=f"文件【{path}】不存在")
                 # clear the placeholder
                 for run in paragraph.runs:
                     if placeholder not in run.text:
                         continue
                     run.text = run.text.replace(placeholder, "")
-                    run.add_picture(image_path, width=Inches(6))
+                    for path in image_path:
+                        run.add_picture(path, width=Inches(image_size))
 
     @staticmethod
-    def generate_report(metric_value_dict: Dict[str, MetricFucEntity], template_path: str):
+    def generate_report(metric_value_dict: Dict[str, MetricFucEntity], template_path: str, save_path=None):
         if os.path.exists(template_path):
             doc = Document(template_path)
         else:
@@ -225,6 +197,8 @@ class Report():
         Report._fill_table_placeholder(doc, metric_value_dict)
         Report._fill_image_placeholder(doc, metric_value_dict)
         new_path = template_path.replace(".docx", f"{f_get_datetime()}.docx")
+        if save_path is not None:
+            new_path = save_path
         doc.save(f"./{new_path}")
 
 

+ 2 - 0
requirements.txt

@@ -1,3 +1,5 @@
 pymysql==1.0.2
 python-docx==0.8.11
 xlrd==1.2.0
+scorecardpy==0.1.9.7
+toad==0.0.64

+ 24 - 0
strategy_test1.py

@@ -0,0 +1,24 @@
+# -*- coding: utf-8 -*-
+"""
+@author: yq
+@time: 2024/11/1
+@desc: 
+"""
+import time
+
+from entitys import DataSplitEntity, DataProcessConfigEntity
+from feature import FilterStrategyFactory
+
+if __name__ == "__main__":
+    time_now = time.time()
+    import scorecardpy as sc
+    dat = sc.germancredit()
+    dat["creditability"] = dat["creditability"].apply(lambda x: 1 if x == "bad" else 0)
+    data = DataSplitEntity(dat[:700], None, dat[700:])
+    filter_strategy_factory = FilterStrategyFactory(
+        DataProcessConfigEntity.from_config('./config/data_process_config_template.json'))
+    strategy = filter_strategy_factory.get_strategy()
+    candidate_feature = strategy.filter(data)
+    candidate_feature = strategy.feature_generate(data, candidate_feature)
+
+    print(time.time() - time_now)

Binary
template/模型开发报告模板_lr.docx


+ 34 - 0
train_test.py

@@ -0,0 +1,34 @@
+# -*- coding: utf-8 -*-
+"""
+@author: yq
+@time: 2024/11/27
+@desc: 
+"""
+import time
+
+from entitys import DataSplitEntity, DataProcessConfigEntity, TrainConfigEntity
+from feature import FilterStrategyFactory
+from trainer import TrainPipeline
+
+if __name__ == "__main__":
+    time_now = time.time()
+    import scorecardpy as sc
+
+    dat = sc.germancredit()
+    dat["creditability"] = dat["creditability"].apply(lambda x: 1 if x == "bad" else 0)
+    data = DataSplitEntity(dat[:700], None, dat[700:])
+
+    # feature processing
+    filter_strategy_factory = FilterStrategyFactory(
+        DataProcessConfigEntity.from_config('./config/data_process_config_template.json'))
+    strategy = filter_strategy_factory.get_strategy()
+    candidate_feature = strategy.filter(data)
+    data_prepared = strategy.feature_generate(data, candidate_feature)
+    # training
+    train_pipeline = TrainPipeline(TrainConfigEntity.from_config('./config/train_config_template.json'))
+    train_pipeline.train(data_prepared)
+    # report generation
+    metric_value_dict = strategy.feature_report(data, candidate_feature)
+    train_pipeline.generate_report(metric_value_dict)
+
+    print(time.time() - time_now)

+ 15 - 9
trainer/train.py

@@ -4,20 +4,26 @@
 @time: 2024/11/1
 @desc: model training pipeline
 """
-from entitys import DataFeatureEntity
-from model import ModelBase
+from typing import Dict
+
+from entitys import DataPreparedEntity, TrainConfigEntity, MetricFucEntity
+from init import f_get_save_path
+from model import f_get_model
+from monitor.report_generate import Report
 
 
 class TrainPipeline():
-    def __init__(self, model: ModelBase):
-        self.model = model
+    def __init__(self, train_config: TrainConfigEntity):
+        self._train_config = train_config
+        model_clazz = f_get_model(self._train_config.model_type)
+        self.model = model_clazz(self._train_config)
 
-    def train(self, train_data: DataFeatureEntity, test_data: DataFeatureEntity):
-        metric_train = self.model.train(train_data)
-        self.model.predict_prob(test_data.get_Xdata())
+    def train(self, data: DataPreparedEntity):
+        metric_train = self.model.train(data)
+        print(metric_train)
 
-    def generate_report(self):
-        pass
+    def generate_report(self, metric_value_dict: Dict[str, MetricFucEntity]):
+        Report.generate_report(metric_value_dict, self._train_config.template_path, save_path=f_get_save_path("模型报告.docx"))
 
 
 if __name__ == "__main__":