Explorar el Código

modify: gr参数传递

yq hace 4 meses
padre
commit
ddb7523d13
Se han modificado 4 ficheros con 53 adiciones y 36 borrados
  1. 39 29
      app.py
  2. 1 0
      requirements-py310.txt
  3. 2 2
      webui/manager.py
  4. 11 5
      webui/utils.py

+ 39 - 29
app.py

@@ -32,11 +32,12 @@ with gr.Blocks() as demo:
                 data_insight = gr.Dataframe(visible=False, label="数据探查", max_height=600, wrap=True)
 
             input_elems.update(
-                {project_name, file_data, data_upload})
+                {project_name, file_data, data_upload, data_insight})
             elem_dict.update(dict(
                 project_name=project_name,
                 file_data=file_data,
-                data_upload=data_upload
+                data_upload=data_upload,
+                data_insight=data_insight
             ))
 
         with gr.TabItem("训练"):
@@ -59,8 +60,24 @@ with gr.Blocks() as demo:
                     with gr.Row():
                         test_split_strategy = gr.Dropdown(["随机"], value="随机", label="测试集划分方式")
                         test_split_rate = gr.Slider(0, 0.5, value=0.3, label="测试集划分比例", step=0.05, interactive=True)
-
                     train_button = gr.Button("开始训练", variant="primary", elem_id="train_button")
+
+                    input_elems.update(
+                        {model_type, search_strategy, y_column, x_columns_candidate, x_candidate_num, sample_rate,
+                         special_values, test_split_strategy, test_split_rate, train_button
+                         })
+                    elem_dict.update(dict(
+                        model_type=model_type,
+                        feature_search_strategy=search_strategy,
+                        y_column=y_column,
+                        x_columns_candidate=x_columns_candidate,
+                        x_candidate_num=x_candidate_num,
+                        sample_rate=sample_rate,
+                        special_values=special_values,
+                        test_split_strategy=test_split_strategy,
+                        test_split_rate=test_split_rate,
+                        train_button=train_button))
+
                 with gr.Column():
                     with gr.Row():
                         train_progress = gr.Textbox(label="训练进度", scale=4)
@@ -74,33 +91,26 @@ with gr.Blocks() as demo:
                         gallery_auc = gr.Gallery(label="auc ks", columns=[1], rows=[2], object_fit="contain",
                                                  height="auto", visible=False, interactive=False)
 
-                input_elems.update(
-                    {model_type, search_strategy, y_column, x_columns_candidate, x_candidate_num, sample_rate,
-                     special_values, test_split_strategy, test_split_rate
-                     })
-                elem_dict.update(dict(
-                    model_type=model_type,
-                    feature_search_strategy=search_strategy,
-                    y_column=y_column,
-                    x_columns_candidate=x_columns_candidate,
-                    x_candidate_num=x_candidate_num,
-                    sample_rate=sample_rate,
-                    special_values=special_values,
-                    test_split_strategy=test_split_strategy,
-                    test_split_rate=test_split_rate,
-                ))
+                    input_elems.update(
+                        {train_progress, download_report, file_report, auc_df, gallery_auc})
+                    elem_dict.update(dict(
+                        train_progress=train_progress,
+                        download_report=download_report,
+                        file_report=file_report,
+                        auc_df=auc_df,
+                        gallery_auc=gallery_auc))
 
-            engine.add_elems(elem_dict)
+                engine.add_elems(elem_dict)
 
-            project_name.change(fn=f_project_is_exist, inputs=input_elems)
-            file_data.upload(fn=f_data_upload, inputs=input_elems, outputs=[data_upload, data_insight, y_column,
-                                                                            x_columns_candidate])
-            train_button.click(fn=f_train, inputs=input_elems,
-                               outputs=[train_progress, auc_df, gallery_auc, download_report])
-            download_report.click(fn=f_download_report, inputs=input_elems, outputs=download_report)
+                project_name.change(fn=f_project_is_exist, inputs=input_elems)
+                file_data.upload(fn=f_data_upload, inputs=input_elems, outputs=[data_upload, data_insight, y_column,
+                                                                                x_columns_candidate])
+                train_button.click(fn=f_train, inputs=input_elems,
+                                   outputs=[train_progress, auc_df, gallery_auc, download_report])
+                download_report.click(fn=f_download_report, inputs=input_elems, outputs=download_report)
 
-        demo.queue(default_concurrency_limit=5)
-        demo.launch(share=False, show_error=True)
+            demo.queue(default_concurrency_limit=5)
+            demo.launch(share=False, show_error=True, server_name="0.0.0.0", server_port=18066)
 
-    if __name__ == "__main__":
-        pass
+        if __name__ == "__main__":
+            pass

+ 1 - 0
requirements-py310.txt

@@ -6,4 +6,5 @@ toad==0.1.4
 dataframe_image==0.1.14
 gradio==5.8.0
 matplotlib==3.9.3
+numpy==1.26.4
 scikit-learn==1.1.3

+ 2 - 2
webui/manager.py

@@ -13,14 +13,14 @@ class Manager:
             self._id_to_elem[elem_id] = elem
             self._elem_to_id[elem] = elem_id
 
-    def _get_elem_by_id(self, elem_id: str) -> "Component":
+    def get_elem_by_id(self, elem_id: str) -> "Component":
         return self._id_to_elem[elem_id]
 
     def _get_id_by_elem(self, elem: "Component") -> str:
         return self._elem_to_id[elem]
 
     def get(self, data, key):
-        return data[self._get_elem_by_id(key)]
+        return data[self.get_elem_by_id(key)]
 
     def get_all(self, data) -> Dict:
         all = {}

+ 11 - 5
webui/utils.py

@@ -101,14 +101,18 @@ def f_data_upload(data):
     df = _get_upload_data(data)
     distribution = DataExplore.distribution(df)
     columns = df.columns.to_list()
-    return gr.update(value=df, visible=True), gr.update(value=distribution, visible=True), gr.update(
-        choices=columns), gr.update(choices=columns)
+    return {
+        engine.get_elem_by_id("data_upload"): gr.update(value=df, visible=True),
+        engine.get_elem_by_id("data_insight"): gr.update(value=distribution, visible=True),
+        engine.get_elem_by_id("y_column"): gr.update(choices=columns),
+        engine.get_elem_by_id("x_columns_candidate"): gr.update(choices=columns)
+    }
 
 
 def f_download_report(data):
     file_path = _get_save_path(data, "模型报告.docx")
     if os.path.exists(file_path):
-        return gr.update(value=file_path)
+        return {engine.get_elem_by_id("download_report"): gr.update(value=file_path)}
     else:
         raise FileNotFoundError(f"{file_path} not found.")
 
@@ -159,5 +163,7 @@ def f_train(data, progress=gr.Progress(track_tqdm=True)):
 
     auc_df = metric_value_dict["模型结果"].table
 
-    return gr.update(value="训练完成"), gr.update(value=auc_df, visible=True), \
-           gr.update(value=_get_auc_ks_images(data), visible=True), gr.update(visible=True)
+    return {engine.get_elem_by_id("train_progress"): gr.update(value="训练完成"),
+            engine.get_elem_by_id("auc_df"): gr.update(value=auc_df, visible=True),
+            engine.get_elem_by_id("gallery_auc"): gr.update(value=_get_auc_ks_images(data), visible=True),
+            engine.get_elem_by_id("download_report"): gr.update(visible=True)}