Selaa lähdekoodia

add: 下载报告

yq 4 kuukautta sitten
vanhempi
sitoutus
9fc3f92634
4 muutettua tiedostoa jossa 27 lisäystä ja 13 poistoa
  1. 10 4
      app.py
  2. 4 4
      start.sh
  3. 2 2
      webui/__init__.py
  4. 11 3
      webui/utils.py

+ 10 - 4
app.py

@@ -8,7 +8,7 @@
 import gradio as gr
 
 from init import init
-from webui import f_project_is_exist, f_data_upload, engine, f_train
+from webui import f_project_is_exist, f_data_upload, engine, f_train, f_download_report
 
 init()
 
@@ -63,7 +63,11 @@ with gr.Blocks() as demo:
                     train_button = gr.Button("开始训练", variant="primary", elem_id="train_button")
                 with gr.Column():
                     with gr.Row():
-                        train_progress = gr.Textbox(label="训练进度")
+                        train_progress = gr.Textbox(label="训练进度", scale=4)
+                        download_report = gr.DownloadButton(label="报告下载", variant="primary", elem_id="download_report",
+                                                            visible=False, scale=1)
+                        file_report = gr.File(visible=False)
+
                     with gr.Row():
                         auc_df = gr.Dataframe(visible=False, label="auc ks", max_height=300, interactive=False)
                     with gr.Row():
@@ -91,9 +95,11 @@ with gr.Blocks() as demo:
             project_name.change(fn=f_project_is_exist, inputs=input_elems)
             file_data.upload(fn=f_data_upload, inputs=input_elems, outputs=[data_upload, data_insight, y_column,
                                                                             x_columns_candidate])
-            train_button.click(fn=f_train, inputs=input_elems, outputs=[train_progress, auc_df, gallery_auc])
+            train_button.click(fn=f_train, inputs=input_elems,
+                               outputs=[train_progress, auc_df, gallery_auc, download_report])
+            download_report.click(fn=f_download_report, inputs=input_elems, outputs=download_report)
 
-        demo.queue(concurrency_count=3)
+        demo.queue(default_concurrency_limit=5)
         demo.launch(share=False, show_error=True)
 
     if __name__ == "__main__":

+ 4 - 4
start.sh

@@ -1,11 +1,11 @@
 #!/bin/bash
 
-source activate chatglm2
+source activate easy_ml
 
 PATH_APP=$(pwd)
 
 function get_pid() {
-  APP_PID=$(ps -ef | grep "python $PATH_APP/main.py" | grep -v grep | awk '{print $2}')
+  APP_PID=$(ps -ef | grep "python $PATH_APP/app.py" | grep -v grep | awk '{print $2}')
 }
 
 function kill_app() {
@@ -18,8 +18,8 @@ function kill_app() {
 }
 
 function start_app() {
-  echo $(date +%F%n%T) "开始启动model-api-classify..."
-  PYTHONIOENCODING=utf-8 nohup python $PATH_APP/main.py > $PATH_APP/nohup.out 2>&1 &
+  echo $(date +%F%n%T) "开始启动 app..."
+  PYTHONIOENCODING=utf-8 nohup python $PATH_APP/app.py > $PATH_APP/nohup.out 2>&1 &
   sleep 3
   echo $(tail -50 $PATH_APP/nohup.out)
   echo "启动完成..."

+ 2 - 2
webui/__init__.py

@@ -6,6 +6,6 @@
 
 """
 from .manager import engine
-from .utils import f_project_is_exist, f_data_upload, f_train
+from .utils import f_project_is_exist, f_data_upload, f_train, f_download_report
 
-__all__ = ['engine', 'f_project_is_exist', 'f_data_upload', 'f_train']
+__all__ = ['engine', 'f_project_is_exist', 'f_data_upload', 'f_train', 'f_download_report']

+ 11 - 3
webui/utils.py

@@ -79,7 +79,7 @@ def f_project_is_exist(data):
         gr.Warning(message='项目名称已被使用', duration=5)
 
 
-def f_get_save_path(data, file_name: str, sub_dir="", name_prefix=""):
+def _get_save_path(data, file_name: str, sub_dir="", name_prefix=""):
     base_dir = _get_base_dir(data)
     save_path = os.path.join(base_dir, sub_dir)
     os.makedirs(save_path, exist_ok=True)
@@ -96,7 +96,7 @@ def f_data_upload(data):
     if not _check_save_dir(data):
         return
     file_data = engine.get(data, "file_data")
-    data_path = f_get_save_path(data, file_data.name, DATA_SUB_DIR, UPLOAD_DATA_PREFIX)
+    data_path = _get_save_path(data, file_data.name, DATA_SUB_DIR, UPLOAD_DATA_PREFIX)
     shutil.copy(file_data.name, data_path)
     df = _get_upload_data(data)
     distribution = DataExplore.distribution(df)
@@ -105,6 +105,14 @@ def f_data_upload(data):
         choices=columns), gr.update(choices=columns)
 
 
+def f_download_report(data):
+    file_path = _get_save_path(data, "模型报告.docx")
+    if os.path.exists(file_path):
+        return gr.update(value=file_path)
+    else:
+        raise FileNotFoundError(f"{file_path} not found.")
+
+
 def f_verify_param(data):
     y_column = engine.get(data, "y_column")
     if y_column is None:
@@ -152,4 +160,4 @@ def f_train(data, progress=gr.Progress(track_tqdm=True)):
     auc_df = metric_value_dict["模型结果"].table
 
     return gr.update(value="训练完成"), gr.update(value=auc_df, visible=True), \
-           gr.update(value=_get_auc_ks_images(data), visible=True)
+           gr.update(value=_get_auc_ks_images(data), visible=True), gr.update(visible=True)