utils.py 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. # -*- coding: utf-8 -*-
  2. """
  3. @author: yq
  4. @time: 2024/12/5
  5. @desc:
  6. """
  7. import os
  8. import shutil
  9. from typing import List
  10. import gradio as gr
  11. import pandas as pd
  12. from config import BaseConfig
  13. from data import DataLoaderExcel, DataExplore
  14. from .manager import engine
  15. DATA_DIR = "data"
  16. UPLOAD_DATA_PREFIX = "prefix_upload_data_"
  17. data_loader = DataLoaderExcel()
  18. def _check_save_dir(data):
  19. project_name = engine.get(data, "project_name")
  20. if project_name is None or len(project_name) == 0:
  21. gr.Warning(message='项目名称不能为空', duration=5)
  22. return False
  23. return True
  24. def _get_prefix_file(save_path, prefix):
  25. file_name_list: List[str] = os.listdir(save_path)
  26. for file_name in file_name_list:
  27. if prefix in file_name:
  28. return os.path.join(save_path, file_name)
  29. def _get_base_dir(data):
  30. project_name = engine.get(data, "project_name")
  31. base_dir = os.path.join(BaseConfig.train_path, project_name)
  32. return base_dir
  33. def _get_upload_data(data) -> pd.DataFrame:
  34. base_dir = _get_base_dir(data)
  35. save_path = os.path.join(base_dir, DATA_DIR)
  36. file_path = _get_prefix_file(save_path, UPLOAD_DATA_PREFIX)
  37. df = data_loader.get_data(file_path)
  38. return df
  39. def f_project_is_exist(data):
  40. project_name = engine.get(data, "project_name")
  41. if project_name is None or len(project_name) == 0:
  42. gr.Warning(message='项目名称不能为空', duration=5)
  43. elif os.path.exists(_get_base_dir(data)):
  44. gr.Warning(message='项目名称已被使用', duration=5)
  45. def f_get_save_path(data, file_name: str, sub_dir="", name_prefix=""):
  46. base_dir = _get_base_dir(data)
  47. save_path = os.path.join(base_dir, sub_dir)
  48. os.makedirs(save_path, exist_ok=True)
  49. # 有前缀标示的先删除
  50. if name_prefix:
  51. file = _get_prefix_file(save_path, name_prefix)
  52. if file:
  53. os.remove(file)
  54. save_path = os.path.join(save_path, name_prefix + os.path.basename(file_name))
  55. return save_path
  56. def f_data_upload(data):
  57. if not _check_save_dir(data):
  58. return
  59. file_data = engine.get(data, "file_data")
  60. data_path = f_get_save_path(data, file_data.name, DATA_DIR, UPLOAD_DATA_PREFIX)
  61. shutil.copy(file_data.name, data_path)
  62. df = _get_upload_data(data)
  63. distribution = DataExplore.distribution(df)
  64. return gr.update(value=df, visible=True), gr.update(value=distribution, visible=True),