utils.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. # -*- coding: utf-8 -*-
  2. """
  3. @author: yq
  4. @time: 2024/12/5
  5. @desc:
  6. """
  7. import os
  8. import shutil
  9. from typing import List
  10. import gradio as gr
  11. import pandas as pd
  12. from sklearn.model_selection import train_test_split
  13. from config import BaseConfig
  14. from data import DataLoaderExcel, DataExplore
  15. from entitys import DataSplitEntity
  16. from feature import FilterStrategyFactory
  17. from model import ModelFactory
  18. from trainer import TrainPipeline
  19. from .manager import engine
  20. DATA_SUB_DIR = "data"
  21. UPLOAD_DATA_PREFIX = "prefix_upload_data_"
  22. data_loader = DataLoaderExcel()
  23. def _clean_base_dir(data):
  24. base_dir = _get_base_dir(data)
  25. file_name_list: List[str] = os.listdir(base_dir)
  26. for file_name in file_name_list:
  27. if file_name in [DATA_SUB_DIR]:
  28. continue
  29. file_path = os.path.join(base_dir, file_name)
  30. if os.path.isdir(file_path):
  31. shutil.rmtree(file_path)
  32. else:
  33. os.remove(file_path)
  34. def _check_save_dir(data):
  35. project_name = engine.get(data, "project_name")
  36. if project_name is None or len(project_name) == 0:
  37. gr.Warning(message='项目名称不能为空', duration=5)
  38. return False
  39. return True
  40. def _get_prefix_file(save_path, prefix):
  41. file_name_list: List[str] = os.listdir(save_path)
  42. for file_name in file_name_list:
  43. if prefix in file_name:
  44. return os.path.join(save_path, file_name)
  45. def _get_base_dir(data):
  46. project_name = engine.get(data, "project_name")
  47. base_dir = os.path.join(BaseConfig.train_path, project_name)
  48. return base_dir
  49. def _get_upload_data(data) -> pd.DataFrame:
  50. base_dir = _get_base_dir(data)
  51. save_path = os.path.join(base_dir, DATA_SUB_DIR)
  52. file_path = _get_prefix_file(save_path, UPLOAD_DATA_PREFIX)
  53. df = data_loader.get_data(file_path)
  54. return df
  55. def f_project_is_exist(data):
  56. project_name = engine.get(data, "project_name")
  57. if project_name is None or len(project_name) == 0:
  58. gr.Warning(message='项目名称不能为空', duration=5)
  59. elif os.path.exists(_get_base_dir(data)):
  60. gr.Warning(message='项目名称已被使用', duration=5)
  61. def f_get_save_path(data, file_name: str, sub_dir="", name_prefix=""):
  62. base_dir = _get_base_dir(data)
  63. save_path = os.path.join(base_dir, sub_dir)
  64. os.makedirs(save_path, exist_ok=True)
  65. # 有前缀标示的先删除
  66. if name_prefix:
  67. file = _get_prefix_file(save_path, name_prefix)
  68. if file:
  69. os.remove(file)
  70. save_path = os.path.join(save_path, name_prefix + os.path.basename(file_name))
  71. return save_path
  72. def f_data_upload(data):
  73. if not _check_save_dir(data):
  74. return
  75. file_data = engine.get(data, "file_data")
  76. data_path = f_get_save_path(data, file_data.name, DATA_SUB_DIR, UPLOAD_DATA_PREFIX)
  77. shutil.copy(file_data.name, data_path)
  78. df = _get_upload_data(data)
  79. distribution = DataExplore.distribution(df)
  80. columns = df.columns.to_list()
  81. return gr.update(value=df, visible=True), gr.update(value=distribution, visible=True), gr.update(
  82. choices=columns), gr.update(choices=columns)
  83. def f_verify_param(data):
  84. y_column = engine.get(data, "y_column")
  85. if y_column is None:
  86. gr.Warning(message=f'Y标签列不能为空', duration=5)
  87. return False
  88. return True
  89. def f_train(data):
  90. feature_search_strategy = engine.get(data, "feature_search_strategy")
  91. model_type = engine.get(data, "model_type")
  92. test_split_rate = engine.get(data, "test_split_rate")
  93. data_upload = engine.get(data, "data_upload")
  94. all_param = engine.get_all(data)
  95. # 清空储存目录
  96. _clean_base_dir(data)
  97. # 校验参数
  98. if not f_verify_param(data):
  99. return
  100. # 数据集划分
  101. train_data, test_data = train_test_split(data_upload, test_size=test_split_rate, shuffle=True, random_state=2025)
  102. data_split = DataSplitEntity(train_data=train_data, val_data=None, test_data=test_data)
  103. # 特征处理
  104. ## 获取特征筛选策略
  105. filter_strategy_clazz = FilterStrategyFactory.get_strategy(feature_search_strategy)
  106. filter_strategy = filter_strategy_clazz(**all_param)
  107. # 选择模型
  108. model_clazz = ModelFactory.get_model(model_type)
  109. model = model_clazz(**all_param)
  110. # 训练并生成报告
  111. train_pipeline = TrainPipeline(filter_strategy, model, data_split)
  112. train_pipeline.train()
  113. train_pipeline.generate_report()