
add: web training page

yq, 4 months ago
parent
commit
2300140eb2

+ 49 - 15
app.py

@@ -7,7 +7,10 @@
 
 import gradio as gr
 
-from webui import f_project_is_exist, f_data_upload, engine
+from init import init
+from webui import f_project_is_exist, f_data_upload, engine, f_train
+
+init()
 
 input_elems = set()
 elem_dict = {}
@@ -27,30 +30,61 @@ with gr.Blocks() as demo:
             with gr.Row():
                 data_insight = gr.Dataframe(visible=False, label="数据探查", max_height=600, wrap=True)
 
+            input_elems.update(
+                {project_name, file_data, data_upload})
+            elem_dict.update(dict(
+                project_name=project_name,
+                file_data=file_data,
+                data_upload=data_upload
+            ))
+
         with gr.TabItem("训练"):
             with gr.Row():
                 with gr.Column():
-                    model_type = gr.Dropdown(["lr"], value="lr", label="模型")
-                    search_strategy = gr.Dropdown(["iv"], value="iv", label="特征搜索策略")
-                    y_column = gr.Textbox(label="Y标签")
-                    x_columns = gr.Textbox(label="X特征")
-                    gr.Slider(0.05, 1, value=0.1, label="分箱组合采样率", step=0.01),
+                    with gr.Row():
+                        model_type = gr.Dropdown(["lr"], value="lr", label="模型")
+                        search_strategy = gr.Dropdown(["iv"], value="iv", label="特征搜索策略")
+                    with gr.Row():
+                        y_column = gr.Dropdown(label="Y标签列", interactive=True, info="其值应该是0或者1")
+                        x_columns_candidate = gr.Dropdown(label="X特征列", multiselect=True, interactive=True,
+                                                          info="不应包含Y特征列,不选择则使用全部特征")
+                    with gr.Row():
+                        x_candidate_num = gr.Number(value=10, label="建模最多保留特征数",  info="保留最重要的N个特征",
+                                                    interactive=True)
+                        sample_rate = gr.Slider(0.05, 1, value=0.1, label="分箱组合采样率", info="对2-5箱所有分箱组合进行采样",
+                                                step=0.01, interactive=True)
+                        special_values = gr.Textbox(label="特殊值", placeholder="可以是dict list str格式",
+                                                    info="分箱时特殊值会单独一个分箱")
+                    with gr.Row():
+                        test_split_strategy = gr.Dropdown(["随机"], value="随机", label="测试集划分方式")
+                        test_split_rate = gr.Slider(0, 0.5, value=0.3, label="测试集划分比例", step=0.05, interactive=True)
+
                     train_button = gr.Button("开始训练", variant="primary")
                 with gr.Column():
                     gr.Textbox(value="输出")
 
-        input_elems.update({project_name, file_data, data_upload, model_type, search_strategy})
-        elem_dict.update(dict(
-            project_name=project_name,
-            file_data=file_data,
-            data_upload=data_upload,
-            model_type=model_type,
-            search_strategy=search_strategy
-        ))
+            input_elems.update(
+                {model_type, search_strategy, y_column, x_columns_candidate, x_candidate_num, sample_rate,
+                 special_values, test_split_strategy, test_split_rate
+                 })
+            elem_dict.update(dict(
+                model_type=model_type,
+                feature_search_strategy=search_strategy,
+                y_column=y_column,
+                x_columns_candidate=x_columns_candidate,
+                x_candidate_num=x_candidate_num,
+                sample_rate=sample_rate,
+                special_values=special_values,
+                test_split_strategy=test_split_strategy,
+                test_split_rate=test_split_rate,
+            ))
+
         engine.add_elems(elem_dict)
 
         project_name.change(fn=f_project_is_exist, inputs=input_elems)
-        file_data.upload(fn=f_data_upload, inputs=input_elems, outputs=[data_upload, data_insight])
+        file_data.upload(fn=f_data_upload, inputs=input_elems, outputs=[data_upload, data_insight, y_column,
+                                                                        x_columns_candidate])
+        train_button.click(fn=f_train, inputs=input_elems)
 
     demo.launch(share=True)
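
Note on the wiring above: f_project_is_exist, f_data_upload and f_train all receive the whole input_elems set as `inputs`. When Gradio is given a set of components, it calls the handler with a single dict mapping each component to its current value, which is what webui.manager's engine.get / engine.get_all rely on. A minimal, self-contained sketch of that pattern (component names here are illustrative, not from this repo):

    import gradio as gr

    with gr.Blocks() as demo:
        name = gr.Textbox(label="name")
        age = gr.Number(label="age")
        out = gr.Textbox(label="out")
        btn = gr.Button("run")

        # Passing a *set* of components as inputs makes Gradio call the
        # handler with one dict of {component: current value}.
        def handler(data):
            return f"{data[name]} / {data[age]}"

        btn.click(fn=handler, inputs={name, age}, outputs=out)

    # demo.launch()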
 

+ 47 - 6
entitys/data_process_config_entity.py

@@ -8,15 +8,20 @@ import json
 import os
 from typing import List, Union
 
-from commom import GeneralException
+from commom import GeneralException, f_get_datetime
+from config import BaseConfig
 from enums import ResultCodesEnum
 
 
 class DataProcessConfigEntity():
-    def __init__(self, y_column: str, x_columns_candidate: List[str] = None, fill_method: str = None,
+    def __init__(self, y_column: str, x_columns_candidate: List[str] = None, fill_method: str = None, fill_value=None,
                  split_method: str = None, feature_search_strategy: str = 'iv', bin_search_interval: float = 0.05,
                  iv_threshold: float = 0.03, iv_threshold_wide: float = 0.05, corr_threshold: float = 0.4,
-                 sample_rate: float = 0.1, x_candidate_num: int = 10, special_values: Union[dict, list] = None):
+                 sample_rate: float = 0.1, x_candidate_num: int = 10, special_values: Union[dict, list, str] = None,
+                 project_name: str = None, *args, **kwargs):
+
+        # Project name; determines the cache/output path
+        self._project_name = project_name
 
         # Define the y (target) column
         self._y_column = y_column
@@ -27,6 +32,9 @@ class DataProcessConfigEntity():
         # Missing-value fill method
         self._fill_method = fill_method
 
+        # Missing-value fill value
+        self._fill_value = fill_value
+
         # Data split method
         self._split_method = split_method
 
@@ -53,6 +61,21 @@ class DataProcessConfigEntity():
         # Sampling rate for the greedy bin search; only effective for 4- and 5-bin combinations
         self._sample_rate = sample_rate
 
+        if self._project_name is None or len(self._project_name) == 0:
+            self._base_dir = os.path.join(BaseConfig.train_path, f"{f_get_datetime()}")
+        else:
+            self._base_dir = os.path.join(BaseConfig.train_path, self._project_name)
+
+        os.makedirs(self._base_dir, exist_ok=True)
+
+    @property
+    def base_dir(self):
+        return self._base_dir
+
+    @property
+    def project_name(self):
+        return self._project_name
+
     @property
     def sample_rate(self):
         return self._sample_rate
@@ -77,6 +100,10 @@ class DataProcessConfigEntity():
     def x_columns_candidate(self):
         return self._x_columns_candidate
 
+    @property
+    def fill_value(self):
+        return self._fill_value
+
     @property
     def fill_method(self):
         return self._fill_method
@@ -99,14 +126,28 @@ class DataProcessConfigEntity():
 
     @property
     def special_values(self):
-        return self._special_values
+        if self._special_values is None or len(self._special_values) == 0:
+            return None
+        if isinstance(self._special_values, str):
+            return [self._special_values]
+        if isinstance(self._special_values, (dict, list)):
+            return self._special_values
+        return None
 
     def get_special_values(self, column: str = None):
+        if self._special_values is None or len(self._special_values) == 0:
+            return None
+        if isinstance(self._special_values, str):
+            return [self._special_values]
         if column is None or isinstance(self._special_values, list):
             return self._special_values
         if isinstance(self._special_values, dict) and column is not None:
-            return self._special_values.get(column, [])
-        return []
+            return self._special_values.get(column, None)
+        return None
+
+    def f_get_save_path(self, file_name: str) -> str:
+        path = os.path.join(self._base_dir, file_name)
+        return path
 
     @staticmethod
     def from_config(config_path: str):
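
The new base_dir / f_get_save_path pair replaces the module-level save path that used to live in init/__init__.py: artifacts now go to <BaseConfig.train_path>/<project_name>, or to a timestamped directory when no project name is given. A rough usage sketch, assuming BaseConfig.train_path points at a writable directory; the project name is illustrative and the import path just follows the file layout shown here:

    from entitys.data_process_config_entity import DataProcessConfigEntity

    # With a project name: everything is written under <train_path>/demo_project
    cfg = DataProcessConfigEntity(y_column="creditability", project_name="demo_project")
    print(cfg.base_dir)
    print(cfg.f_get_save_path("var_corr.png"))

    # Without a project name: a timestamped directory is created instead
    cfg_tmp = DataProcessConfigEntity(y_column="creditability")
    print(cfg_tmp.base_dir)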

+ 6 - 1
entitys/train_config_entity.py

@@ -12,14 +12,19 @@ from enums import ResultCodesEnum
 
 
 class TrainConfigEntity():
-    def __init__(self, lr: float = None):
+    def __init__(self, lr: float = None, *args, **kwargs):
         # Learning rate
         self._lr = lr
+        # Placeholder; set later via set_save_path_func
+        self.f_get_save_path = None
 
     @property
     def lr(self):
         return self._lr
 
+    def set_save_path_func(self, f):
+        self.f_get_save_path = f
+
     @staticmethod
     def from_config(config_path: str):
         """

+ 8 - 8
feature/filter_strategy_factory.py

@@ -11,16 +11,16 @@ from enums import FilterStrategyEnum, ResultCodesEnum
 from .filter_strategy_base import FilterStrategyBase
 from .strategy_iv import StrategyIv
 
+strategy_map = {
+    FilterStrategyEnum.IV.value: StrategyIv
+}
 
-class FilterStrategyFactory():
 
-    def __init__(self, ):
-        self.strategy_map = {
-            FilterStrategyEnum.IV.value: StrategyIv
-        }
+class FilterStrategyFactory():
 
-    def get_strategy(self, strategy: str) -> Type[FilterStrategyBase]:
-        if strategy not in self.strategy_map.keys():
+    @staticmethod
+    def get_strategy(strategy: str) -> Type[FilterStrategyBase]:
+        if strategy not in strategy_map.keys():
             raise GeneralException(ResultCodesEnum.NOT_FOUND, message=f"特征搜索策略【{strategy}】不存在")
-        strategy = self.strategy_map.get(strategy)
+        strategy = strategy_map.get(strategy)
         return strategy
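
With the map hoisted to module level and get_strategy made static, registering an additional strategy no longer requires touching the factory class. The enum value and strategy class below are hypothetical, shown only to illustrate the extension point:

    from feature.filter_strategy_base import FilterStrategyBase
    from feature.filter_strategy_factory import FilterStrategyFactory, strategy_map

    class StrategyPsi(FilterStrategyBase):  # hypothetical new strategy
        pass

    # "psi" stands in for a new FilterStrategyEnum member's value
    strategy_map["psi"] = StrategyPsi

    clazz = FilterStrategyFactory.get_strategy("psi")  # -> StrategyPsi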

+ 2 - 5
feature/strategy_iv.py

@@ -15,12 +15,9 @@ import seaborn as sns
 from pandas.core.dtypes.common import is_numeric_dtype
 
 from entitys import DataSplitEntity, CandidateFeatureEntity, DataPreparedEntity, DataFeatureEntity, MetricFucEntity
-from init import f_get_save_path
 from .feature_utils import f_judge_monto, f_get_corr, f_get_ivf
 from .filter_strategy_base import FilterStrategyBase
 
-plt.rcParams['figure.figsize'] = (8, 8)
-
 
 class StrategyIv(FilterStrategyBase):
 
@@ -41,7 +38,7 @@ class StrategyIv(FilterStrategyBase):
         plt.title('Variables Correlation', fontsize=15)
         plt.yticks(rotation=0)
         plt.xticks(rotation=90)
-        path = f_get_save_path(f"var_corr.png")
+        path = self.data_process_config.f_get_save_path(f"var_corr.png")
         plt.savefig(path)
         return path
 
@@ -52,7 +49,7 @@ class StrategyIv(FilterStrategyBase):
             # bin_df["bin"] = bin_df["bin"].apply(lambda x: re.sub(r"(\d+\.\d+)",
             #                                                      lambda m: "{:.2f}".format(float(m.group(0))), x))
             sc.woebin_plot(bin_df)
-            path = f_get_save_path(f"{prefix}_{k}.png")
+            path = self.data_process_config.f_get_save_path(f"{prefix}_{k}.png")
             plt.savefig(path)
             image_path_list.append(path)
         return image_path_list

+ 6 - 16
init/__init__.py

@@ -5,29 +5,19 @@
 @desc: Initialization of some shared resources
 """
 
-import os
-
 import matplotlib
+
 matplotlib.use('Agg')
 
 import matplotlib.pyplot as plt
 
-from commom import f_get_datetime
-from config import BaseConfig
-
-__all__ = ['f_get_save_path']
-
-# 设置支持中文的字体
-plt.rcParams['font.sans-serif'] = ['SimHei']  # 使用黑体
-plt.rcParams['axes.unicode_minus'] = False  # 解决负号显示问题
-
-save_path = os.path.join(BaseConfig.train_path, f"{f_get_datetime()}")
-os.makedirs(save_path, exist_ok=True)
+__all__ = ['init']
 
 
-def f_get_save_path(file_name: str) -> str:
-    path = os.path.join(save_path, file_name)
-    return path
+def init():
+    plt.rcParams['font.sans-serif'] = ['SimHei']  # Font that supports Chinese characters (SimHei)
+    plt.rcParams['axes.unicode_minus'] = False  # Fix minus-sign rendering
+    plt.rcParams['figure.figsize'] = (8, 8)
 
 
 if __name__ == "__main__":

+ 8 - 8
model/model_factory.py

@@ -11,15 +11,15 @@ from enums import ModelEnum, ResultCodesEnum
 from model import ModelBase
 from .model_lr import ModelLr
 
+model_map = {
+    ModelEnum.LR.value: ModelLr
+}
 
-class ModelFactory():
 
-    def __init__(self, ):
-        self.model_map = {
-            ModelEnum.LR.value: ModelLr
-        }
+class ModelFactory():
 
-    def get_model(self, model_type: str) -> Type[ModelBase]:
-        if model_type not in self.model_map.keys():
+    @staticmethod
+    def get_model(model_type: str) -> Type[ModelBase]:
+        if model_type not in model_map.keys():
             raise GeneralException(ResultCodesEnum.ILLEGAL_PARAMS, message=f"模型【{model_type}】没有实现")
-        return self.model_map.get(model_type)
+        return model_map.get(model_type)

+ 6 - 7
model/model_lr.py

@@ -13,7 +13,6 @@ from sklearn.linear_model import LogisticRegression
 from commom import f_df_to_image
 from entitys import DataPreparedEntity, MetricFucEntity, DataSplitEntity
 from feature import f_calcu_model_ks, f_get_model_score_bin, f_calcu_model_psi
-from init import f_get_save_path
 from .model_base import ModelBase
 
 
@@ -50,7 +49,7 @@ class ModelLr(ModelBase):
         card_df = pd.DataFrame(columns=card['basepoints'].keys())
         for k, v in card.items():
             card_df = pd.concat((card_df, v))
-        card_df_path = f_get_save_path(f"card_df.png")
+        card_df_path = self._train_config.f_get_save_path(f"card_df.png")
         f_df_to_image(card_df, card_df_path)
         metric_value_dict["评分卡"] = MetricFucEntity(image_path=card_df_path)
 
@@ -65,7 +64,7 @@ class ModelLr(ModelBase):
         train_prob = self.lr.predict_proba(train_data.get_Xdata())[:, 1]
         image_path_list = []
         train_perf = sc.perf_eva(train_y, train_prob, title="train", show_plot=True)
-        path = f_get_save_path(f"train_perf.png")
+        path = self._train_config.f_get_save_path(f"train_perf.png")
         train_perf["pic"].savefig(path)
         image_path_list.append(path)
 
@@ -78,7 +77,7 @@ class ModelLr(ModelBase):
             test_prob = self.lr.predict_proba(test_data.get_Xdata())[:, 1]
             test_y = test_data.get_Ydata()
             test_perf = sc.perf_eva(test_y, test_prob, title="test", show_plot=True)
-            path = f_get_save_path(f"test_perf.png")
+            path = self._train_config.f_get_save_path(f"test_perf.png")
             test_perf["pic"].savefig(path)
             image_path_list.append(path)
             test_auc = test_perf["KS"]
@@ -94,19 +93,19 @@ class ModelLr(ModelBase):
         # 评分卡分箱
         train_data_original, score_bins = f_get_model_score_bin(train_data_original, card)
         train_data_gain = f_calcu_model_ks(train_data_original, y_column, sort_ascending=True)
-        train_data_gain_path = f_get_save_path(f"train_data_gain.png")
+        train_data_gain_path = self._train_config.f_get_save_path(f"train_data_gain.png")
         f_df_to_image(train_data_gain, train_data_gain_path)
         metric_value_dict["训练集分数分箱"] = MetricFucEntity(image_path=train_data_gain_path)
         if test_data is not None:
             test_data_original, bins = f_get_model_score_bin(test_data_original, card, score_bins)
             test_data_gain = f_calcu_model_ks(test_data_original, y_column, sort_ascending=True)
-            test_data_gain_path = f_get_save_path(f"test_data_gain.png")
+            test_data_gain_path = self._train_config.f_get_save_path(f"test_data_gain.png")
             f_df_to_image(test_data_gain, test_data_gain_path)
             metric_value_dict["测试集分数分箱"] = MetricFucEntity(image_path=test_data_gain_path)
 
         # 模型分psi
         model_psi = f_calcu_model_psi(train_data_original, test_data_original)
-        model_psi_path = f_get_save_path(f"model_psi.png")
+        model_psi_path = self._train_config.f_get_save_path(f"model_psi.png")
         f_df_to_image(model_psi, model_psi_path)
         metric_value_dict["模型稳定性"] = MetricFucEntity(value=model_psi["psi"].sum().round(4), image_path=model_psi_path)
 

+ 2 - 4
train_test.py

@@ -22,16 +22,14 @@ if __name__ == "__main__":
 
     # Feature processing
     ## Get the feature selection strategy
-    filter_strategy_factory = FilterStrategyFactory()
-    filter_strategy_clazz = filter_strategy_factory.get_strategy("iv")
+    filter_strategy_clazz = FilterStrategyFactory.get_strategy("iv")
     ## Parameters can be passed in directly
     # filter_strategy = filter_strategy_clazz(y_column="creditability")
     ## Or loaded from a config file
     filter_strategy = filter_strategy_clazz(DataProcessConfigEntity.from_config('./config/data_process_config_template.json'))
 
     # Choose the model
-    model_factory = ModelFactory()
-    model_clazz = model_factory.get_model("lr")
+    model_clazz = ModelFactory.get_model("lr")
     model = model_clazz()
 
     # Train and generate the report

+ 5 - 2
trainer/train.py

@@ -7,16 +7,19 @@
 
 from entitys import DataSplitEntity
 from feature.filter_strategy_base import FilterStrategyBase
-from init import f_get_save_path
+from init import init
 from model import ModelBase
 from monitor.report_generate import Report
 
+init()
+
 
 class TrainPipeline():
     def __init__(self, filter_strategy: FilterStrategyBase, model: ModelBase, data: DataSplitEntity):
         self._filter_strategy = filter_strategy
         self._model = model
         self._data = data
+        self._model._train_config.set_save_path_func(self._filter_strategy.data_process_config.f_get_save_path)
 
     def train(self, ):
         # Process the data and get candidate features
@@ -31,7 +34,7 @@ class TrainPipeline():
 
     def generate_report(self, ):
         Report.generate_report(self.metric_value_dict, self._model.get_template_path(),
-                               save_path=f_get_save_path("模型报告.docx"))
+                               save_path=self._filter_strategy.data_process_config.f_get_save_path("模型报告.docx"))
 
 
 if __name__ == "__main__":

+ 2 - 2
webui/__init__.py

@@ -6,6 +6,6 @@
 
 """
 from .manager import engine
-from .utils import f_project_is_exist, f_data_upload
+from .utils import f_project_is_exist, f_data_upload, f_train
 
-__all__ = ['engine', 'f_project_is_exist', 'f_data_upload']
+__all__ = ['engine', 'f_project_is_exist', 'f_data_upload', 'f_train']

+ 6 - 0
webui/manager.py

@@ -22,5 +22,11 @@ class Manager:
     def get(self, data, key):
         return data[self._get_elem_by_id(key)]
 
+    def get_all(self, data) -> Dict:
+        all = {}
+        for k, v in self._id_to_elem.items():
+            all[k] = data[v]
+        return all
+
 
 engine = Manager()

+ 62 - 4
webui/utils.py

@@ -10,16 +10,34 @@ from typing import List
 
 import gradio as gr
 import pandas as pd
+from sklearn.model_selection import train_test_split
 
 from config import BaseConfig
 from data import DataLoaderExcel, DataExplore
+from entitys import DataSplitEntity
+from feature import FilterStrategyFactory
+from model import ModelFactory
+from trainer import TrainPipeline
 from .manager import engine
 
-DATA_DIR = "data"
+DATA_SUB_DIR = "data"
 UPLOAD_DATA_PREFIX = "prefix_upload_data_"
 data_loader = DataLoaderExcel()
 
 
+def _clean_base_dir(data):
+    base_dir = _get_base_dir(data)
+    file_name_list: List[str] = os.listdir(base_dir)
+    for file_name in file_name_list:
+        if file_name in [DATA_SUB_DIR]:
+            continue
+        file_path = os.path.join(base_dir, file_name)
+        if os.path.isdir(file_path):
+            shutil.rmtree(file_path)
+        else:
+            os.remove(file_path)
+
+
 def _check_save_dir(data):
     project_name = engine.get(data, "project_name")
     if project_name is None or len(project_name) == 0:
@@ -43,7 +61,7 @@ def _get_base_dir(data):
 
 def _get_upload_data(data) -> pd.DataFrame:
     base_dir = _get_base_dir(data)
-    save_path = os.path.join(base_dir, DATA_DIR)
+    save_path = os.path.join(base_dir, DATA_SUB_DIR)
     file_path = _get_prefix_file(save_path, UPLOAD_DATA_PREFIX)
     df = data_loader.get_data(file_path)
     return df
@@ -74,9 +92,49 @@ def f_data_upload(data):
     if not _check_save_dir(data):
         return
     file_data = engine.get(data, "file_data")
-    data_path = f_get_save_path(data, file_data.name, DATA_DIR, UPLOAD_DATA_PREFIX)
+    data_path = f_get_save_path(data, file_data.name, DATA_SUB_DIR, UPLOAD_DATA_PREFIX)
     shutil.copy(file_data.name, data_path)
     df = _get_upload_data(data)
     distribution = DataExplore.distribution(df)
+    columns = df.columns.to_list()
+    return gr.update(value=df, visible=True), gr.update(value=distribution, visible=True), gr.update(
+        choices=columns), gr.update(choices=columns)
+
+
+def f_verify_param(data):
+    y_column = engine.get(data, "y_column")
+    if y_column is None:
+        gr.Warning(message=f'Y标签列不能为空', duration=5)
+        return False
+    return True
+
+
+def f_train(data):
+    feature_search_strategy = engine.get(data, "feature_search_strategy")
+    model_type = engine.get(data, "model_type")
+    test_split_rate = engine.get(data, "test_split_rate")
+    data_upload = engine.get(data, "data_upload")
+    all_param = engine.get_all(data)
+    # Clear the project output directory
+    _clean_base_dir(data)
+    # Validate parameters
+    if not f_verify_param(data):
+        return
+
+    # Train/test split
+    train_data, test_data = train_test_split(data_upload, test_size=test_split_rate, shuffle=True, random_state=2025)
+    data_split = DataSplitEntity(train_data=train_data, val_data=None, test_data=test_data)
+
+    # Feature processing
+    ## Get the feature selection strategy
+    filter_strategy_clazz = FilterStrategyFactory.get_strategy(feature_search_strategy)
+    filter_strategy = filter_strategy_clazz(**all_param)
+
+    # Choose the model
+    model_clazz = ModelFactory.get_model(model_type)
+    model = model_clazz(**all_param)
 
-    return gr.update(value=df, visible=True), gr.update(value=distribution, visible=True),
+    # Train and generate the report
+    train_pipeline = TrainPipeline(filter_strategy, model, data_split)
+    train_pipeline.train()
+    train_pipeline.generate_report()
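
Note on f_train: it builds both the strategy and the model by splatting engine.get_all(data) into their constructors, so the keys in elem_dict (app.py) must match the constructor parameter names; this is why the dropdown is registered as feature_search_strategy rather than search_strategy, and why *args/**kwargs were added to the config entities so that keys they do not declare (file_data, data_upload, ...) are silently ignored. A compressed sketch of that hand-off, with illustrative values:

    from feature import FilterStrategyFactory
    from model import ModelFactory

    # Roughly what engine.get_all(data) hands to f_train: keys are the
    # elem_dict ids registered in app.py, values the current component values.
    all_param = {
        "project_name": "demo_project",
        "y_column": "creditability",
        "x_columns_candidate": None,
        "feature_search_strategy": "iv",
        "model_type": "lr",
        "x_candidate_num": 10,
        "sample_rate": 0.1,
        "special_values": "",
        "test_split_strategy": "随机",
        "test_split_rate": 0.3,
    }

    # Each constructor keeps the kwargs it declares and ignores the rest.
    filter_strategy = FilterStrategyFactory.get_strategy(all_param["feature_search_strategy"])(**all_param)
    model = ModelFactory.get_model(all_param["model_type"])(**all_param)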