Bläddra i källkod

add: web页面

yq 4 månader sedan
förälder
incheckning
e36b2c6b18
9 ändrade filer med 178 tillägg och 7 borttagningar
  1. 1 0
      .gitignore
  2. 53 0
      app.py
  3. 2 2
      config/base_config.py
  4. 5 1
      data/loader/data_loader_excel.py
  5. 1 1
      requirements-py310.txt
  6. 3 3
      train_test.py
  7. 11 0
      webui/__init__.py
  8. 26 0
      webui/manager.py
  9. 76 0
      webui/utils.py

+ 1 - 0
.gitignore

@@ -64,3 +64,4 @@ target/
 */~$*
 *.ipynb
 /flagged
+.gradio

+ 53 - 0
app.py

@@ -0,0 +1,53 @@
+# -*- coding: utf-8 -*-
+"""
+@author: yq
+@time: 2024/12/4
+@desc: 
+"""
+
+import gradio as gr
+
+from webui import f_project_is_exist, f_data_upload, engine
+
+# Components whose values are handed to every event handler. Passing a *set* as
+# ``inputs`` makes Gradio call the handler with a single dict keyed by component,
+# which is the shape ``engine.get(data, key)`` indexes into.
+input_elems = set()
+# Name -> component registry shared with the handlers through ``engine``.
+elem_dict = {}
+
+# ``title`` has to be given by keyword: the first positional parameter of
+# ``gr.Blocks`` is ``theme``, so a bare "Easy-ML" would be treated as a theme name.
+with gr.Blocks(title="Easy-ML") as demo:
+    gr.HTML('<h1 ><center><font size="5">Easy-ML</font></center></h1>')
+    gr.HTML('<h2 ><center><font size="2">快速建模工具</font></center></h2>')
+    with gr.Tabs():
+        # Data tab: project name and the file to model on.
+        with gr.TabItem("数据"):
+            with gr.Row():
+                project_name = gr.Textbox(label="项目名称", placeholder="请输入不重复的项目名称",
+                                          info="项目名称将会被作为缓存目录名称,如果重复会导致结果被覆盖")
+            with gr.Row():
+                file_data = gr.File(label="建模数据")
+
+        # Training tab: model and feature-search settings.
+        with gr.TabItem("训练"):
+            with gr.Row():
+                with gr.Column():
+                    model_type = gr.Dropdown(["lr"], value="lr", label="模型")
+                    search_strategy = gr.Dropdown(["iv"], value="iv", label="特征搜索策略")
+                    # The widgets below are displayed but not bound to a name or handler here.
+                    gr.Textbox(label="Y标签")
+                    gr.Textbox(label="X特征")
+                    gr.Slider(0.05, 1, value=0.1, label="分箱组合采样率", step=0.01)
+                    train_button = gr.Button("开始训练", variant="primary")
+                with gr.Column():
+                    gr.Textbox(value="输出")
+
+        input_elems.update({project_name, file_data, model_type, search_strategy})
+        elem_dict.update(dict(
+            project_name=project_name,
+            file_data=file_data,
+            model_type=model_type,
+            search_strategy=search_strategy
+        ))
+        # Register the components so handlers can look values up by name.
+        engine.add_elems(elem_dict)
+
+        # Event listeners must be declared while the Blocks context is still open.
+        project_name.change(fn=f_project_is_exist, inputs=input_elems)
+        file_data.upload(fn=f_data_upload, inputs=input_elems, outputs=[])
+
+
+if __name__ == "__main__":
+    # Start the server only when run as a script and only after the layout
+    # context has been closed, so that importing this module does not launch it.
+    demo.launch(share=True)

+ 2 - 2
config/base_config.py

@@ -9,11 +9,11 @@ import os
 
 class BaseConfig:
     # 图片缓存位置
-    image_path = "./cache/image"
+    image_path = os.path.join(".", "cache", "image")
     os.makedirs(image_path, exist_ok=True)
 
     # 模型训练中间结果
-    train_path = "./cache/train"
+    train_path = os.path.join(".", "cache", "train")
     os.makedirs(train_path, exist_ok=True)
 
     # 表格合并相同列名的列

+ 5 - 1
data/loader/data_loader_excel.py

@@ -23,7 +23,11 @@ class DataLoaderExcel(DataLoaderBase):
         pass
 
     def get_data(self, file_path: str, sheet_name: str = 0) -> pd.DataFrame:
-        df: pd.DataFrame = pd.read_excel(file_path, sheet_name=sheet_name, index_col=False, dtype=str)
+        df: pd.DataFrame = pd.DataFrame()
+        if ".xlsx" in file_path:
+            df = pd.read_excel(file_path, sheet_name=sheet_name, index_col=False, dtype=str)
+        elif ".csv" in file_path:
+            df = pd.read_csv(file_path)
         columns = df.columns.to_list()
         columns_new = []
         for idx, column in enumerate(columns):

+ 1 - 1
requirements-py310.txt

@@ -4,6 +4,6 @@ xlrd==1.2.0
 scorecardpy==0.1.9.7
 toad==0.1.4
 dataframe_image==0.1.14
-gradio==3.0.12
+gradio==5.8.0
 matplotlib==3.9.3
 scikit-learn==1.1.3

+ 3 - 3
train_test.py

@@ -6,7 +6,7 @@
 """
 import time
 
-from entitys import DataSplitEntity
+from entitys import DataSplitEntity, DataProcessConfigEntity
 from feature import FilterStrategyFactory
 from model import ModelFactory
 from trainer import TrainPipeline
@@ -25,9 +25,9 @@ if __name__ == "__main__":
     filter_strategy_factory = FilterStrategyFactory()
     filter_strategy_clazz = filter_strategy_factory.get_strategy("iv")
     ## 可传入参数
-    filter_strategy = filter_strategy_clazz(y_column="creditability")
+    # filter_strategy = filter_strategy_clazz(y_column="creditability")
     ## 也可从配置文件加载
-    # filter_strategy = filter_strategy_clazz(DataProcessConfigEntity.from_config('./config/data_process_config_template.json'))
+    filter_strategy = filter_strategy_clazz(DataProcessConfigEntity.from_config('./config/data_process_config_template.json'))
 
     # 选择模型
     model_factory = ModelFactory()

+ 11 - 0
webui/__init__.py

@@ -0,0 +1,11 @@
+# -*- coding: utf-8 -*-
+"""
+@author: yq
+@time: 2024/12/5
+@desc:
+
+"""
+from .manager import engine
+from .utils import f_project_is_exist, f_data_upload
+
+__all__ = ['engine', 'f_project_is_exist', 'f_data_upload']

+ 26 - 0
webui/manager.py

@@ -0,0 +1,26 @@
+from typing import Dict
+
+from gradio.components import Component
+
+
+class Manager:
+    """Two-way registry between string ids and UI components."""
+
+    def __init__(self) -> None:
+        # Forward (id -> component) and reverse (component -> id) lookup tables.
+        self._id_to_elem: Dict[str, "Component"] = {}
+        self._elem_to_id: Dict["Component", str] = {}
+
+    def add_elems(self, elem_dict: Dict[str, "Component"]) -> None:
+        """Record every ``id -> component`` pair of *elem_dict* in both tables."""
+        self._id_to_elem.update(elem_dict)
+        self._elem_to_id.update({component: name for name, component in elem_dict.items()})
+
+    def _get_elem_by_id(self, elem_id: str) -> "Component":
+        """Return the component registered under *elem_id* (``KeyError`` if unknown)."""
+        component = self._id_to_elem[elem_id]
+        return component
+
+    def _get_id_by_elem(self, elem: "Component") -> str:
+        """Return the id under which *elem* was registered (``KeyError`` if unknown)."""
+        name = self._elem_to_id[elem]
+        return name
+
+    def get(self, data, key):
+        """Return the entry of *data* stored under the component registered as *key*."""
+        component = self._get_elem_by_id(key)
+        return data[component]
+
+
+# Module-level instance imported by the rest of the package.
+engine = Manager()

+ 76 - 0
webui/utils.py

@@ -0,0 +1,76 @@
+# -*- coding: utf-8 -*-
+"""
+@author: yq
+@time: 2024/12/5
+@desc: 
+"""
+import os
+import shutil
+from typing import List
+
+import gradio as gr
+import pandas as pd
+
+from config import BaseConfig
+from data import DataLoaderExcel
+from .manager import engine
+
+# Sub-directory of the project directory in which uploaded data files are stored.
+DATA_DIR = "data"
+# Prefix prepended to the uploaded file's name when it is saved; also used to find it again.
+UPLOAD_DATA_PREFIX = "prefix_upload_data_"
+# Shared loader instance used to read the saved upload back.
+data_loader = DataLoaderExcel()
+
+
+def _check_save_dir(data):
+    """Tell whether *data* carries a non-empty project name, warning in the UI if not.
+
+    :param data: values handed to the event handler, looked up through ``engine``.
+    :return: ``True`` when a project name is present, ``False`` otherwise.
+    """
+    name = engine.get(data, "project_name")
+    has_name = name is not None and len(name) != 0
+    if not has_name:
+        gr.Warning(message='项目名称不能为空', duration=5)
+    return has_name
+
+
+def _get_prefix_file(save_path, prefix):
+    """Return the path of the first entry in *save_path* whose name starts with *prefix*.
+
+    :param save_path: directory to scan; ``os.listdir`` raises if it does not exist.
+    :param prefix: file-name prefix to look for.
+    :return: ``os.path.join(save_path, name)`` of the first match in ``os.listdir``
+        order, or ``None`` when no entry matches.
+    """
+    file_name_list: List[str] = os.listdir(save_path)
+    for file_name in file_name_list:
+        # Anchor the match at the start of the name: a plain substring test would
+        # also pick up files that merely contain the marker somewhere in the middle.
+        if file_name.startswith(prefix):
+            return os.path.join(save_path, file_name)
+    return None
+
+
+def _get_base_dir(data):
+    """Return the cache directory of the project named in *data*.
+
+    :param data: values handed to the event handler, looked up through ``engine``.
+    :return: ``BaseConfig.train_path`` joined with the project name.
+    :raises gr.Error: if the project name is not a single path component.
+    """
+    project_name = engine.get(data, "project_name")
+    # The name is typed by the user and becomes a directory name, so refuse anything
+    # that could leave BaseConfig.train_path: separators, drive prefixes, "." and "..".
+    if project_name in (os.curdir, os.pardir) or os.path.basename(project_name) != project_name:
+        raise gr.Error('项目名称不是合法的目录名称')
+    return os.path.join(BaseConfig.train_path, project_name)
+
+
+def _get_upload_data(data) -> pd.DataFrame:
+    """Load the data file previously uploaded for the current project.
+
+    The file is looked up by ``UPLOAD_DATA_PREFIX`` inside the project's
+    ``DATA_DIR`` sub-directory and read with the shared ``data_loader``.
+
+    :param data: values handed to the event handler, looked up through ``engine``.
+    :return: the frame produced by ``data_loader.get_data``.
+    """
+    base_dir = _get_base_dir(data)
+    save_path = os.path.join(base_dir, DATA_DIR)
+    file_path = _get_prefix_file(save_path, UPLOAD_DATA_PREFIX)
+    # Hand the frame back to the caller, as the return annotation promises.
+    return data_loader.get_data(file_path)
+
+
+def f_project_is_exist(data):
+    """Warn in the UI when the project name is empty or its directory already exists.
+
+    :param data: values handed to the event handler, looked up through ``engine``.
+    """
+    name = engine.get(data, "project_name")
+    # An empty name is reported on its own; the directory check is skipped then.
+    if name is None or len(name) == 0:
+        gr.Warning(message='项目名称不能为空', duration=5)
+        return
+    if os.path.exists(_get_base_dir(data)):
+        gr.Warning(message='项目名称已被使用', duration=5)
+
+
+def f_get_save_path(data, file_name: str, sub_dir="", name_prefix=""):
+    """Build the destination path for *file_name* inside the project directory.
+
+    The target directory is created if needed. When *name_prefix* is given, an
+    existing file carrying that prefix is deleted first.
+
+    :param data: values handed to the event handler, looked up through ``engine``.
+    :param file_name: source file name or path; only its base name is used.
+    :param sub_dir: sub-directory of the project directory to save into.
+    :param name_prefix: marker prepended to the saved file's name.
+    :return: full path the file should be written to.
+    """
+    base_dir = _get_base_dir(data)
+    save_dir = os.path.join(base_dir, sub_dir)
+    os.makedirs(save_dir, exist_ok=True)
+    # Delete the existing file that carries the prefix marker, if any.
+    if name_prefix:
+        stale_file = _get_prefix_file(save_dir, name_prefix)
+        # On the first upload there is nothing to delete: the lookup yields None,
+        # and os.remove(None) would raise TypeError.
+        if stale_file is not None:
+            os.remove(stale_file)
+    return os.path.join(save_dir, name_prefix + os.path.basename(file_name))
+
+
+def f_data_upload(data):
+    """Copy the uploaded file into the project's data directory.
+
+    Nothing is copied when the project-name check fails.
+
+    :param data: values handed to the event handler, looked up through ``engine``.
+    """
+    if _check_save_dir(data):
+        uploaded = engine.get(data, "file_data")
+        # The stored copy keeps the upload's base name behind UPLOAD_DATA_PREFIX.
+        target = f_get_save_path(data, uploaded.name, DATA_DIR, UPLOAD_DATA_PREFIX)
+        shutil.copy(uploaded.name, target)