yq 5 months ago
parent
commit
10e60da04a
6 changed files with 87 additions and 41 deletions
  1. 10 8
      app.py
  2. 2 2
      commom/__init__.py
  3. 9 2
      commom/utils.py
  4. 4 3
      config/base_config.py
  5. 13 13
      strategy_parse.py
  6. 49 13
      webui/utils.py

+ 10 - 8
app.py

@@ -6,6 +6,8 @@
 """
 import matplotlib
 
+from config import BaseConfig
+
 matplotlib.use('Agg')
 
 import gradio as gr
@@ -28,7 +30,8 @@ with gr.Blocks() as demo:
 
                     with gr.Row():
                         file_data = gr.File(label="策略文档", file_types=[".xlsx"], scale=3)
-                        sheet_name = gr.Dropdown(choices=["流程"], value="流程", label="策略查看", interactive=True,
+                        sheet_name = gr.Dropdown(choices=[BaseConfig.flow_sheet_name], value=BaseConfig.flow_sheet_name,
+                                                 label="策略查看", interactive=True,
                                                  info="流程及节点信息查看", scale=1)
 
                     with gr.Row():
@@ -36,7 +39,7 @@ with gr.Blocks() as demo:
                     code_generate = gr.Button("生成代码", variant="primary")
 
                     input_elems.update(
-                        {project_name, sheet_name, file_data, data_upload})
+                        {project_name, sheet_name, file_data, data_upload, code_generate})
                     elem_dict.update(dict(
                         project_name=project_name,
                         sheet_name=sheet_name,
@@ -49,21 +52,20 @@ with gr.Blocks() as demo:
                         generate_progress = gr.Textbox(label="生成进度", scale=4)
                         download_code = gr.DownloadButton(label="代码下载", variant="primary",
                                                           visible=False, scale=1)
-                        file_report = gr.File(visible=False)
                     with gr.Row():
-                        code_view = gr.Code()
+                        code_view = gr.Code(visible=False)
                     input_elems.update(
-                        {generate_progress, download_code, file_report})
+                        {generate_progress, download_code, code_view})
                     elem_dict.update(dict(
                         generate_progress=generate_progress,
-                        download_report=download_code,
-                        file_report=file_report
+                        download_code=download_code,
+                        code_view=code_view
                     ))
 
                     engine.add_elems(elem_dict)
 
                     project_name.change(fn=f_project_is_exist, inputs=input_elems)
-                    sheet_name.change(fn=f_get_sheet_data, inputs=input_elems, outputs=[data_upload])
+                    sheet_name.change(fn=f_get_sheet_data, inputs=input_elems, outputs=[data_upload, code_view])
                     file_data.upload(fn=f_data_upload, inputs=input_elems, outputs=[data_upload, sheet_name])
                     code_generate.click(fn=f_code_generate, inputs=input_elems,
                                         outputs=[generate_progress, code_view, download_code])

+ 2 - 2
commom/__init__.py

@@ -6,7 +6,7 @@
 """
 from .llm_call import call_llm, f_file_upload
 from .user_exceptions import GeneralException
-from .utils import f_get_date, f_get_datetime, f_get_save_path, create_zip
+from .utils import f_get_date, f_get_datetime, f_get_save_path, f_create_zip, f_read_file
 
 __all__ = ['GeneralException', 'f_get_date', 'f_get_datetime', 'f_get_save_path', 'call_llm', 'f_file_upload',
-           'create_zip']
+           'f_create_zip','f_read_file']

+ 9 - 2
commom/utils.py

@@ -30,7 +30,14 @@ def f_get_save_path(file_name: str, sub_path=""):
     os.makedirs(os.path.join(base_dir, sub_path), exist_ok=True)
     return os.path.join(base_dir, sub_path, file_name)
 
-def create_zip(zip_name, files):
+
+def f_create_zip(zip_name, files):
     with zipfile.ZipFile(zip_name, 'w') as zipf:
         for file in files:
-            zipf.write(file)
+            zipf.write(file)
+
+
+def f_read_file(file_path) -> str:
+    with open(file_path, mode="r", encoding="utf8") as f:
+        s = f.read()
+    return s

+ 4 - 3
config/base_config.py

@@ -14,6 +14,7 @@ class BaseConfig:
     bot_id = "7397344489205153807"
     file_upload_url = "https://api.coze.cn/v1/files/upload"
     base_dir = os.path.join(".", "cache")
-
-
-
+    flow_sheet_name = "流程"
+    code_zip_name = "code.zip"
+    node_map_name = "node_func_dict.json"
+    

+ 13 - 13
strategy_parse.py

@@ -15,7 +15,7 @@ from PIL import Image
 from openpyxl import load_workbook
 from tqdm import tqdm
 
-from commom import call_llm, f_file_upload, GeneralException, f_get_datetime, create_zip
+from commom import call_llm, f_file_upload, GeneralException, f_get_datetime, f_create_zip
 from config import BaseConfig
 from enums import ResultCodesEnum
 from prompt import f_get_prompt_parse_node, f_get_prompt_parse_flow, f_get_prompt_parse_flow_image
@@ -86,7 +86,7 @@ class StrategyParse:
 
     def _f_parse_flow(self, node_list: list, df: pd.DataFrame):
 
-        node_func_dict = {"流程": "flow.py"}
+        node_func_dict = {BaseConfig.flow_sheet_name: "flow.py"}
         func = ""
         node_func_map = ""
         func_import = ""
@@ -96,7 +96,7 @@ class StrategyParse:
             node_func_map = f"{node_func_map}{node_name}: {func_name}\n"
             func_import = f"{func_import}from {func_name} import {func_name}\n"
 
-        save_path = self._f_get_save_path("node_func_dict.json")
+        save_path = self._f_get_save_path(BaseConfig.node_map_name)
         with open(save_path, mode="w", encoding="utf8") as f:
             f.write(json.dumps(node_func_dict, ensure_ascii=False))
 
@@ -168,37 +168,37 @@ class StrategyParse:
         wb = load_workbook(file_path)
         excel = pd.ExcelFile(file_path)
         sheet_names = excel.sheet_names
-        if "流程图" not in sheet_names:
-            GeneralException(ResultCodesEnum.NOT_FOUND, message=f"sheet【流程图】不存在")
+        if BaseConfig.flow_sheet_name not in sheet_names:
+            raise GeneralException(ResultCodesEnum.NOT_FOUND, message=f"sheet【{BaseConfig.flow_sheet_name}】不存在")
         node_list = []
         for node_name in tqdm(sheet_names):
-            if node_name == "流程图":
+            if node_name == BaseConfig.flow_sheet_name:
                 continue
             df = excel.parse(sheet_name=node_name)
             func_name, code = self._f_parse_node(df, node_name)
             node_list.append((node_name, func_name, code))
-        self._f_parse_flow_image(wb["流程图"], node_list)
+        self._f_parse_flow_image(wb[BaseConfig.flow_sheet_name], node_list)
         wb.close()
         excel.close()
 
     def f_parse_strategy(self, excel: pd.ExcelFile, progress=None):
         sheet_names = excel.sheet_names
-        if "流程" not in sheet_names:
-            GeneralException(ResultCodesEnum.NOT_FOUND, message=f"sheet【流程】不存在")
+        if BaseConfig.flow_sheet_name not in sheet_names:
+            raise GeneralException(ResultCodesEnum.NOT_FOUND, message=f"sheet【{BaseConfig.flow_sheet_name}】不存在")
         node_list = []
         for node_name in tqdm(sheet_names):
-            if node_name == "流程":
+            if node_name == BaseConfig.flow_sheet_name:
                 continue
             df = excel.parse(sheet_name=node_name)
             func_name, code = self._f_parse_node(df, node_name)
             node_list.append((node_name, func_name, code))
         if progress is not None:
             progress(0.9)
-        self._f_parse_flow(node_list, excel.parse(sheet_name="流程"))
+        self._f_parse_flow(node_list, excel.parse(sheet_name=BaseConfig.flow_sheet_name))
 
-        save_path = self._f_get_save_path("code.zip")
+        save_path = self._f_get_save_path(BaseConfig.code_zip_name)
         py_files = self._f_get_py_files()
-        create_zip(save_path, py_files)
+        f_create_zip(save_path, py_files)
 
 
 if __name__ == "__main__":

+ 49 - 13
webui/utils.py

@@ -4,15 +4,17 @@
 @time: 2024/12/5
 @desc: 
 """
+import json
 import os
 import shutil
+import time
 from typing import List
 
 import gradio as gr
 import pandas as pd
 
+from commom import f_read_file
 from config import BaseConfig
-from strategy_parse import StrategyParse
 from .manager import engine
 
 DATA_SUB_DIR = "data"
@@ -60,6 +62,16 @@ def _get_upload_data(data) -> pd.ExcelFile:
     return excel
 
 
+def _get_node_func_dict(data) -> dict:
+    node_func_dict_path = _get_save_path(data, BaseConfig.node_map_name)
+    if not os.path.exists(node_func_dict_path):
+        return None
+    with open(node_func_dict_path, mode="r", encoding="utf8") as f:
+        node_func_dict = f.read()
+    node_func_dict = json.loads(node_func_dict)
+    return node_func_dict
+
+
 def f_project_is_exist(data):
     project_name = engine.get(data, "project_name")
     if project_name is None or len(project_name) == 0:
@@ -88,7 +100,7 @@ def f_data_upload(data):
     data_path = _get_save_path(data, file_data.name, DATA_SUB_DIR, UPLOAD_DATA_PREFIX)
     shutil.copy(file_data.name, data_path)
     excel = _get_upload_data(data)
-    df = excel.parse(sheet_name="流程")
+    df = excel.parse(sheet_name=BaseConfig.flow_sheet_name)
     columns = excel.sheet_names
     excel.close()
     return {
@@ -102,13 +114,22 @@ def f_get_sheet_data(data):
     excel = _get_upload_data(data)
     df = excel.parse(sheet_name=sheet_name)
     excel.close()
+
+    node_func_dict = _get_node_func_dict(data)
+    if node_func_dict is not None:
+        code = f_read_file(_get_save_path(data, node_func_dict[sheet_name]))
+        return {
+            engine.get_elem_by_id("data_upload"): gr.update(value=df),
+            engine.get_elem_by_id("code_view"): gr.update(value=code),
+        }
+
     return {
-        engine.get_elem_by_id("data_upload"): gr.update(value=df, visible=True)
+        engine.get_elem_by_id("data_upload"): gr.update(value=df)
     }
 
 
 def f_download_code(data):
-    file_path = _get_save_path(data, "code.zip")
+    file_path = _get_save_path(data, BaseConfig.code_zip_name)
     if os.path.exists(file_path):
         return {engine.get_elem_by_id("download_code"): gr.update(value=file_path)}
     else:
@@ -119,8 +140,8 @@ def f_verify_param(data):
     excel = _get_upload_data(data)
     columns = excel.sheet_names
     excel.close()
-    if "流程" not in columns:
-        raise gr.Error(message=f'【流程】sheet不能为空', duration=5)
+    if BaseConfig.flow_sheet_name not in columns:
+        raise gr.Error(message=f'【{BaseConfig.flow_sheet_name}】sheet不能为空', duration=5)
     return True
 
 
@@ -135,7 +156,7 @@ def f_code_generate(data, progress=gr.Progress(track_tqdm=True)):
     all_param = engine.get_all(data)
 
     # 清空储存目录
-    _clean_base_dir(data)
+    # _clean_base_dir(data)
     # 校验参数
     if not f_verify_param(data):
         yield _reset_component_state()
@@ -143,16 +164,31 @@ def f_code_generate(data, progress=gr.Progress(track_tqdm=True)):
     yield _reset_component_state()
 
     progress(0.01)
-    excel = _get_upload_data(data)
 
-    strategy_parse = StrategyParse(**all_param)
-    strategy_parse.f_parse_strategy(excel, progress)
+    time.sleep(2)
+    excel = None
+    try:
+        excel = _get_upload_data(data)
 
-    excel.close()
+        # strategy_parse = StrategyParse(**all_param)
+        # strategy_parse.f_parse_strategy(excel, progress)
+
+        excel.close()
+    except Exception as msg:
+        if excel is not None:
+            excel.close()
+        yield _reset_component_state()
+        raise gr.Error(message=f"系统错误【{msg}】", duration=5)
+
+    code_zip_file_path = _get_save_path(data, BaseConfig.code_zip_name)
+
+    node_func_dict = _get_node_func_dict(data)
 
-    code_file_path = _get_save_path(data, "code.zip")
+    flow_code = f_read_file(_get_save_path(data, node_func_dict[BaseConfig.flow_sheet_name]))
 
     progress(1)
 
     yield {engine.get_elem_by_id("generate_progress"): gr.update(value="生成完成"),
-           engine.get_elem_by_id("download_code"): gr.update(value=code_file_path, visible=True)}
+           engine.get_elem_by_id("download_code"): gr.update(value=code_zip_file_path, visible=True),
+           engine.get_elem_by_id("code_view"): gr.update(value=flow_code, visible=True),
+           }