Răsfoiți Sursa

add: 训练结果

yq 4 luni în urmă
părinte
comite
064d461287
5 a modificat fișierele cu 65 adăugiri și 40 ștergeri
  1. 35 27
      app.py
  2. 3 3
      feature/strategy_iv.py
  3. 2 2
      model/model_lr.py
  4. 4 2
      trainer/train.py
  5. 21 6
      webui/utils.py

+ 35 - 27
app.py

@@ -18,13 +18,14 @@ elem_dict = {}
 with gr.Blocks() as demo:
     gr.HTML('<h1 ><center><font size="5">Easy-ML</font></center></h1>')
     gr.HTML('<h2 ><center><font size="2">快速建模工具</font></center></h2>')
+    gr.State([])
     with gr.Tabs():
         with gr.TabItem("数据"):
             with gr.Row():
                 project_name = gr.Textbox(label="项目名称", placeholder="请输入不重复的项目名称",
                                           info="项目名称将会被作为缓存目录名称,如果重复会导致结果被覆盖")
             with gr.Row():
-                file_data = gr.File(label="建模数据")
+                file_data = gr.File(label="建模数据", file_types=[".csv", ".xlsx"])
             with gr.Row():
                 data_upload = gr.Dataframe(visible=False, label="当前上传数据", max_height=300)
             with gr.Row():
@@ -49,7 +50,7 @@ with gr.Blocks() as demo:
                         x_columns_candidate = gr.Dropdown(label="X特征列", multiselect=True, interactive=True,
                                                           info="不应包含Y特征列,不选择则使用全部特征")
                     with gr.Row():
-                        x_candidate_num = gr.Number(value=10, label="建模最多保留特征数",  info="保留最重要的N个特征",
+                        x_candidate_num = gr.Number(value=10, label="建模最多保留特征数", info="保留最重要的N个特征",
                                                     interactive=True)
                         sample_rate = gr.Slider(0.05, 1, value=0.1, label="分箱组合采样率", info="对2-5箱所有分箱组合进行采样",
                                                 step=0.01, interactive=True)
@@ -59,34 +60,41 @@ with gr.Blocks() as demo:
                         test_split_strategy = gr.Dropdown(["随机"], value="随机", label="测试集划分方式")
                         test_split_rate = gr.Slider(0, 0.5, value=0.3, label="测试集划分比例", step=0.05, interactive=True)
 
-                    train_button = gr.Button("开始训练", variant="primary")
+                    train_button = gr.Button("开始训练", variant="primary", elem_id="train_button")
                 with gr.Column():
-                    gr.Textbox(value="输出")
+                    with gr.Row():
+                        train_progress = gr.Textbox(label="训练进度")
+                    with gr.Row():
+                        auc_df = gr.Dataframe(visible=False, label="auc ks", max_height=300, interactive=False)
+                    with gr.Row():
+                        gallery_auc = gr.Gallery(label="auc ks", columns=[1], rows=[2], object_fit="contain",
+                                                 height="auto", visible=False, interactive=False)
 
-            input_elems.update(
-                {model_type, search_strategy, y_column, x_columns_candidate, x_candidate_num, sample_rate,
-                 special_values, test_split_strategy, test_split_rate
-                 })
-            elem_dict.update(dict(
-                model_type=model_type,
-                feature_search_strategy=search_strategy,
-                y_column=y_column,
-                x_columns_candidate=x_columns_candidate,
-                x_candidate_num=x_candidate_num,
-                sample_rate=sample_rate,
-                special_values=special_values,
-                test_split_strategy=test_split_strategy,
-                test_split_rate=test_split_rate,
-            ))
+                input_elems.update(
+                    {model_type, search_strategy, y_column, x_columns_candidate, x_candidate_num, sample_rate,
+                     special_values, test_split_strategy, test_split_rate
+                     })
+                elem_dict.update(dict(
+                    model_type=model_type,
+                    feature_search_strategy=search_strategy,
+                    y_column=y_column,
+                    x_columns_candidate=x_columns_candidate,
+                    x_candidate_num=x_candidate_num,
+                    sample_rate=sample_rate,
+                    special_values=special_values,
+                    test_split_strategy=test_split_strategy,
+                    test_split_rate=test_split_rate,
+                ))
 
-        engine.add_elems(elem_dict)
+            engine.add_elems(elem_dict)
 
-        project_name.change(fn=f_project_is_exist, inputs=input_elems)
-        file_data.upload(fn=f_data_upload, inputs=input_elems, outputs=[data_upload, data_insight, y_column,
-                                                                        x_columns_candidate])
-        train_button.click(fn=f_train, inputs=input_elems)
+            project_name.change(fn=f_project_is_exist, inputs=input_elems)
+            file_data.upload(fn=f_data_upload, inputs=input_elems, outputs=[data_upload, data_insight, y_column,
+                                                                            x_columns_candidate])
+            train_button.click(fn=f_train, inputs=input_elems, outputs=[train_progress, auc_df, gallery_auc])
 
-    demo.launch(share=True)
+        demo.queue(concurrency_count=3)
+        demo.launch(share=False, show_error=True)
 
-if __name__ == "__main__":
-    pass
+    if __name__ == "__main__":
+        pass

+ 3 - 3
feature/strategy_iv.py

@@ -13,6 +13,7 @@ import pandas as pd
 import scorecardpy as sc
 import seaborn as sns
 from pandas.core.dtypes.common import is_numeric_dtype
+from tqdm import tqdm
 
 from entitys import DataSplitEntity, CandidateFeatureEntity, DataPreparedEntity, DataFeatureEntity, MetricFucEntity
 from .feature_utils import f_judge_monto, f_get_corr, f_get_ivf
@@ -255,8 +256,7 @@ class StrategyIv(FilterStrategyBase):
         test_sv_bin_list = None
         if test_data_filter is not None:
             test_sv_bin_list = _get_sv_bins(test_data, x_column, y_column, special_values)
-        from tqdm import tqdm
-        for point_list in tqdm(points_list):
+        for point_list in points_list:
             train_bins = _get_bins(train_data_filter, x_column, y_column, point_list)
             # 与special_values合并计算iv
             for sv_bin in train_sv_bin_list:
@@ -285,7 +285,7 @@ class StrategyIv(FilterStrategyBase):
         x_columns_candidate = list(bins_iv_dict.keys())
         candidate_num = self.data_process_config.candidate_num
         candidate_dict: Dict[str, CandidateFeatureEntity] = {}
-        for x_column in x_columns_candidate:
+        for x_column in tqdm(x_columns_candidate):
             if is_numeric_dtype(data.train_data[x_column]):
                 iv_max, breaks_list = self._f_get_best_bins_numeric(data, x_column)
                 candidate_dict[x_column] = CandidateFeatureEntity(x_column, breaks_list, iv_max)

+ 2 - 2
model/model_lr.py

@@ -80,8 +80,8 @@ class ModelLr(ModelBase):
             path = self._train_config.f_get_save_path(f"test_perf.png")
             test_perf["pic"].savefig(path)
             image_path_list.append(path)
-            test_auc = test_perf["KS"]
-            test_ks = test_perf["AUC"]
+            test_auc = test_perf["AUC"]
+            test_ks = test_perf["KS"]
 
         df_auc = pd.DataFrame()
         df_auc["样本集"] = ["训练集", "测试集"]

+ 4 - 2
trainer/train.py

@@ -4,8 +4,9 @@
 @time: 2024/11/1
 @desc: 模型训练管道
 """
+from typing import Dict
 
-from entitys import DataSplitEntity
+from entitys import DataSplitEntity, MetricFucEntity
 from feature.filter_strategy_base import FilterStrategyBase
 from init import init
 from model import ModelBase
@@ -21,7 +22,7 @@ class TrainPipeline():
         self._data = data
         self._model._train_config.set_save_path_func(self._filter_strategy.data_process_config.f_get_save_path)
 
-    def train(self, ):
+    def train(self, ) -> Dict[str, MetricFucEntity]:
         # 处理数据,获取候选特征
         candidate_feature = self._filter_strategy.filter(self._data)
         # 生成训练数据
@@ -31,6 +32,7 @@ class TrainPipeline():
 
         metric_value_dict_train = self._model.train(data_prepared, *data_prepared.args, **data_prepared.kwargs)
         self.metric_value_dict = {**metric_value_dict_feature, **metric_value_dict_train}
+        return self.metric_value_dict
 
     def generate_report(self, ):
         Report.generate_report(self.metric_value_dict, self._model.get_template_path(),

+ 21 - 6
webui/utils.py

@@ -41,8 +41,7 @@ def _clean_base_dir(data):
 def _check_save_dir(data):
     project_name = engine.get(data, "project_name")
     if project_name is None or len(project_name) == 0:
-        gr.Warning(message='项目名称不能为空', duration=5)
-        return False
+        raise gr.Error(message='项目名称不能为空', duration=5)
     return True
 
 
@@ -67,6 +66,11 @@ def _get_upload_data(data) -> pd.DataFrame:
     return df
 
 
+def _get_auc_ks_images(data):
+    base_dir = _get_base_dir(data)
+    return [os.path.join(base_dir, "train_perf.png"), os.path.join(base_dir, "test_perf.png")]
+
+
 def f_project_is_exist(data):
     project_name = engine.get(data, "project_name")
     if project_name is None or len(project_name) == 0:
@@ -104,12 +108,16 @@ def f_data_upload(data):
 def f_verify_param(data):
     y_column = engine.get(data, "y_column")
     if y_column is None:
-        gr.Warning(message=f'Y标签列不能为空', duration=5)
-        return False
+        raise gr.Error(message=f'Y标签列不能为空', duration=5)
     return True
 
 
-def f_train(data):
+def f_train(data, progress=gr.Progress(track_tqdm=True)):
+    # import time
+    # print(1111111)
+    # time.sleep(5)
+    # return gr.update(elem_id="train_button", value="111")
+    progress(0, desc="Starting")
     feature_search_strategy = engine.get(data, "feature_search_strategy")
     model_type = engine.get(data, "model_type")
     test_split_rate = engine.get(data, "test_split_rate")
@@ -124,6 +132,7 @@ def f_train(data):
         # 数据集划分
     train_data, test_data = train_test_split(data_upload, test_size=test_split_rate, shuffle=True, random_state=2025)
     data_split = DataSplitEntity(train_data=train_data, val_data=None, test_data=test_data)
+    progress(0.01)
 
     # 特征处理
     ## 获取特征筛选策略
@@ -136,5 +145,11 @@ def f_train(data):
 
     # 训练并生成报告
     train_pipeline = TrainPipeline(filter_strategy, model, data_split)
-    train_pipeline.train()
+    metric_value_dict = train_pipeline.train()
+    progress(0.95)
     train_pipeline.generate_report()
+
+    auc_df = metric_value_dict["模型结果"].table
+
+    return gr.update(value="训练完成"), gr.update(value=auc_df, visible=True), \
+           gr.update(value=_get_auc_ks_images(data), visible=True)