strategy_parse.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221
  1. # -*- coding: utf-8 -*-
  2. """
  3. @author: yq
  4. @time: 2024/12/18
  5. @desc: 策略流节点解析
  6. """
  7. import json
  8. import os
  9. import re
  10. import time
  11. from typing import List
  12. import pandas as pd
  13. from PIL import Image
  14. from openpyxl import load_workbook
  15. from tqdm import tqdm
  16. from commom import call_llm, f_file_upload, GeneralException, f_get_datetime, f_create_zip
  17. from config import BaseConfig
  18. from enums import ResultCodesEnum
  19. from prompt import f_get_prompt_parse_node, f_get_prompt_parse_flow, f_get_prompt_parse_flow_image
  20. class StrategyParse:
  21. def __init__(self, project_name: str = None, *args, **kwargs):
  22. # 项目名称,和缓存路径有关
  23. self._project_name = project_name
  24. if self._project_name is None or len(self._project_name) == 0:
  25. self._base_dir = os.path.join(BaseConfig.base_dir, f"{f_get_datetime()}")
  26. else:
  27. self._base_dir = os.path.join(BaseConfig.base_dir, self._project_name)
  28. os.makedirs(self._base_dir, exist_ok=True)
  29. @property
  30. def project_name(self):
  31. return self._project_name
  32. @property
  33. def base_dir(self):
  34. return self._base_dir
  35. def _f_get_save_path(self, file_name: str) -> str:
  36. path = os.path.join(self._base_dir, file_name)
  37. return path
  38. def _f_get_py_files(self, ):
  39. py_files = []
  40. file_name_list: List[str] = os.listdir(self._base_dir)
  41. for file_name in file_name_list:
  42. if ".py" in file_name:
  43. py_files.append(os.path.join(self._base_dir, file_name))
  44. return py_files
  45. # 未使用
  46. def _f_parse_flow_image(self, ws, node_list: list):
  47. image = ws._images[0]
  48. img = Image.open(image.ref).convert("RGB")
  49. save_path = self._f_get_save_path("流程图.png")
  50. img.save(save_path)
  51. time.sleep(1)
  52. file_id = f_file_upload(save_path)
  53. prompt = f_get_prompt_parse_flow_image(node_list)
  54. print(prompt)
  55. prompt = [
  56. {
  57. "type": "text",
  58. "text": prompt
  59. },
  60. {
  61. "type": "image",
  62. "file_id": file_id
  63. }
  64. ]
  65. prompt = json.dumps(prompt, ensure_ascii=False)
  66. llm_answer = call_llm(prompt, "object_string")
  67. print(llm_answer)
  68. code = re.findall(r"```python\n(.*)\n```", llm_answer, flags=re.DOTALL)[0]
  69. save_path = self._f_get_save_path("flow.py")
  70. with open(save_path, mode="w", encoding="utf8") as f:
  71. f.write(code)
  72. save_path = self._f_get_save_path("__init__.py")
  73. with open(save_path, mode="w", encoding="utf8") as f:
  74. f.write("")
  75. # 未使用
  76. def _f_parse_strategy_image(self, file_path):
  77. wb = load_workbook(file_path)
  78. excel = pd.ExcelFile(file_path)
  79. sheet_names = excel.sheet_names
  80. if BaseConfig.flow_sheet_name not in sheet_names:
  81. raise GeneralException(ResultCodesEnum.NOT_FOUND, message=f"sheet【{BaseConfig.flow_sheet_name}】不存在")
  82. node_list = []
  83. for node_name in tqdm(sheet_names):
  84. if node_name == BaseConfig.flow_sheet_name:
  85. continue
  86. df = excel.parse(sheet_name=node_name)
  87. func_name, code = self._f_parse_node(df, node_name)
  88. node_list.append((node_name, func_name, code))
  89. self._f_parse_flow_image(wb[BaseConfig.flow_sheet_name], node_list)
  90. wb.close()
  91. excel.close()
  92. def _f_parse_node(self, df: pd.DataFrame, node_name):
  93. rules = ""
  94. for idx, row in df.iterrows():
  95. var_name = row["变量"]
  96. var_name = var_name.replace("\n", " ")
  97. rule_content = row["逻辑"]
  98. rule_content = rule_content.replace("\n", " ")
  99. rule_out = row["输出"]
  100. notes_output = row["输出备注"]
  101. if notes_output is None or notes_output != notes_output:
  102. notes_output = ""
  103. else:
  104. notes_output = notes_output.replace("\n", " ")
  105. notes_output = f" 结果备注: {notes_output}"
  106. notes_input = row["输入备注"]
  107. if notes_input is None or notes_input != notes_input:
  108. notes_input = ""
  109. else:
  110. notes_input = notes_input.replace("\n", " ")
  111. notes_input = f" 变量备注: {notes_input}"
  112. rules = f"{rules}规则{idx + 1}: 变量:{var_name} 逻辑:{rule_content} 输出:{rule_out}{notes_input}{notes_output}\n"
  113. default_output = list(df["默认输出"])[0]
  114. if default_output is None or default_output != default_output:
  115. default_output = ""
  116. else:
  117. default_output = str(default_output).replace("\n", " ")
  118. default_output = f"{default_output}"
  119. # 构造提示词
  120. prompt = f_get_prompt_parse_node(node_name, rules, default_output)
  121. print(prompt)
  122. # 调用大模型
  123. llm_answer = call_llm(prompt)
  124. # 解析代码部分
  125. code = re.findall(r"```python\n(.*)\n```", llm_answer, flags=re.DOTALL)[0]
  126. # 解析函数名
  127. func_name = re.findall(r"def (.*)\(data", code)[0]
  128. # 保存节点代码
  129. save_path = self._f_get_save_path(f"{func_name}.py")
  130. print(code)
  131. with open(save_path, mode="w", encoding="utf8") as f:
  132. f.write(code)
  133. return func_name, code
  134. def _f_parse_flow(self, node_list: list, df: pd.DataFrame):
  135. node_func_dict = {BaseConfig.flow_sheet_name: "flow.py"}
  136. func = ""
  137. node_func_map = ""
  138. func_import = ""
  139. for node_name, func_name, code in node_list:
  140. node_func_dict[node_name] = f"{func_name}.py"
  141. func = f"{func}{code}\n\n"
  142. node_func_map = f"{node_func_map}{node_name}: {func_name}\n"
  143. func_import = f"{func_import}from {func_name} import {func_name}\n"
  144. save_path = self._f_get_save_path(BaseConfig.node_map_name)
  145. with open(save_path, mode="w", encoding="utf8") as f:
  146. f.write(json.dumps(node_func_dict, ensure_ascii=False))
  147. flow = ""
  148. for _, row in df.iterrows():
  149. strategy = row["策略流描述"]
  150. flow = f"{flow}{strategy}\n"
  151. flow = flow.strip()
  152. # 构造提示词
  153. prompt = f_get_prompt_parse_flow(func, node_func_map, func_import, flow)
  154. print(prompt)
  155. # 调用大模型
  156. llm_answer = call_llm(prompt)
  157. print(llm_answer)
  158. # 解析代码部分
  159. code = re.findall(r"```python\n(.*)\n```", llm_answer, flags=re.DOTALL)[0]
  160. # 保存代码
  161. save_path = self._f_get_save_path("flow.py")
  162. with open(save_path, mode="w", encoding="utf8") as f:
  163. f.write(code)
  164. save_path = self._f_get_save_path("__init__.py")
  165. with open(save_path, mode="w", encoding="utf8") as f:
  166. f.write("")
  167. def f_parse_strategy(self, excel: pd.ExcelFile, progress=None):
  168. sheet_names = excel.sheet_names
  169. if BaseConfig.flow_sheet_name not in sheet_names:
  170. raise GeneralException(ResultCodesEnum.NOT_FOUND, message=f"sheet【{BaseConfig.flow_sheet_name}】不存在")
  171. # 解析各节点
  172. node_list = []
  173. for node_name in tqdm(sheet_names):
  174. # 忽略“流程”的sheet
  175. if node_name == BaseConfig.flow_sheet_name:
  176. continue
  177. df = excel.parse(sheet_name=node_name)
  178. func_name, code = self._f_parse_node(df, node_name)
  179. node_list.append((node_name, func_name, code))
  180. if progress is not None:
  181. progress(0.9)
  182. # 解析流程
  183. self._f_parse_flow(node_list, excel.parse(sheet_name=BaseConfig.flow_sheet_name))
  184. # 打包文件
  185. save_path = self._f_get_save_path(BaseConfig.code_zip_name)
  186. py_files = self._f_get_py_files()
  187. f_create_zip(save_path, py_files)
  188. if __name__ == "__main__":
  189. excel = pd.ExcelFile("./template/demo.xlsx")
  190. strategy_parse = StrategyParse()
  191. strategy_parse.f_parse_strategy(excel)
  192. excel.close()