123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163 |
- # -*- coding: utf-8 -*-
- """
- @author: yq
- @time: 2024/12/5
- @desc:
- """
- import os
- import shutil
- from typing import List
- import gradio as gr
- import pandas as pd
- from sklearn.model_selection import train_test_split
- from config import BaseConfig
- from data import DataLoaderExcel, DataExplore
- from entitys import DataSplitEntity
- from feature import FilterStrategyFactory
- from model import ModelFactory
- from trainer import TrainPipeline
- from .manager import engine
# Sub-directory (under each project's base dir) that holds uploaded datasets;
# it is preserved when the rest of the project dir is cleaned.
DATA_SUB_DIR = "data"
# Filename prefix marking the currently uploaded dataset file.
UPLOAD_DATA_PREFIX = "prefix_upload_data_"
# Shared Excel loader instance used by all upload handlers.
data_loader = DataLoaderExcel()
def _clean_base_dir(data):
    """Delete everything in the project's base directory except the data sub-dir.

    Files are removed directly; directories are removed recursively.
    """
    base_dir = _get_base_dir(data)
    for entry in os.listdir(base_dir):
        # Keep the uploaded-data sub-directory intact.
        if entry == DATA_SUB_DIR:
            continue
        entry_path = os.path.join(base_dir, entry)
        if os.path.isdir(entry_path):
            shutil.rmtree(entry_path)
        else:
            os.remove(entry_path)
def _check_save_dir(data):
    """Ensure a non-empty project name is configured.

    Raises:
        gr.Error: when the project name is missing or empty.
    Returns:
        True when the name is valid.
    """
    project_name = engine.get(data, "project_name")
    if not project_name:
        raise gr.Error(message='项目名称不能为空', duration=5)
    return True
- def _get_prefix_file(save_path, prefix):
- file_name_list: List[str] = os.listdir(save_path)
- for file_name in file_name_list:
- if prefix in file_name:
- return os.path.join(save_path, file_name)
def _get_base_dir(data):
    """Return the working directory for the current project (train_path/project_name)."""
    return os.path.join(BaseConfig.train_path, engine.get(data, "project_name"))
def _get_upload_data(data) -> pd.DataFrame:
    """Load the previously uploaded dataset for the current project.

    Returns:
        The dataset as a DataFrame.
    Raises:
        FileNotFoundError: when no uploaded data file is present.
    """
    base_dir = _get_base_dir(data)
    save_path = os.path.join(base_dir, DATA_SUB_DIR)
    file_path = _get_prefix_file(save_path, UPLOAD_DATA_PREFIX)
    if file_path is None:
        # Fail fast with a clear error instead of handing None to the loader.
        raise FileNotFoundError(f"no uploaded data file found in {save_path}")
    return data_loader.get_data(file_path)
def _get_auc_ks_images(data):
    """Return the paths of the train/test performance images in the project dir."""
    base_dir = _get_base_dir(data)
    return [os.path.join(base_dir, name) for name in ("train_perf.png", "test_perf.png")]
def f_project_is_exist(data):
    """Show a non-blocking warning when the project name is empty or already taken."""
    project_name = engine.get(data, "project_name")
    if not project_name:
        gr.Warning(message='项目名称不能为空', duration=5)
        return
    if os.path.exists(_get_base_dir(data)):
        gr.Warning(message='项目名称已被使用', duration=5)
def _get_save_path(data, file_name: str, sub_dir="", name_prefix=""):
    """Build (and create) the destination path for file_name under the project dir.

    Args:
        data: UI state used to resolve the project directory.
        file_name: source file name; only its basename is kept.
        sub_dir: optional sub-directory under the project dir.
        name_prefix: optional marker prepended to the stored file name.
    Returns:
        The full destination path.
    """
    target_dir = os.path.join(_get_base_dir(data), sub_dir)
    os.makedirs(target_dir, exist_ok=True)
    if name_prefix:
        # A previously stored file carrying this prefix is replaced: delete it first.
        stale = _get_prefix_file(target_dir, name_prefix)
        if stale:
            os.remove(stale)
    return os.path.join(target_dir, name_prefix + os.path.basename(file_name))
def f_data_upload(data):
    """Persist the uploaded file into the project data dir and preview it.

    Copies the upload into place, reloads it as a DataFrame, computes its
    distribution summary, and returns the gradio updates for the preview
    table, the distribution table, and the two column dropdowns.
    """
    if not _check_save_dir(data):
        return
    file_data = engine.get(data, "file_data")
    data_path = _get_save_path(data, file_data.name, DATA_SUB_DIR, UPLOAD_DATA_PREFIX)
    shutil.copy(file_data.name, data_path)
    df = _get_upload_data(data)
    distribution = DataExplore.distribution(df)
    columns = df.columns.to_list()
    return (
        gr.update(value=df, visible=True),
        gr.update(value=distribution, visible=True),
        gr.update(choices=columns),
        gr.update(choices=columns),
    )
def f_download_report(data):
    """Return a gradio update pointing at the generated model report.

    Raises:
        FileNotFoundError: when the report has not been generated yet.
    """
    file_path = _get_save_path(data, "模型报告.docx")
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"{file_path} not found.")
    return gr.update(value=file_path)
def f_verify_param(data):
    """Validate training parameters before a run.

    Raises:
        gr.Error: when the Y label column has not been selected.
    Returns:
        True when all parameters are valid.
    """
    y_column = engine.get(data, "y_column")
    if y_column is None:
        # Plain literal: the message has no placeholders, so no f-string (F541).
        raise gr.Error(message='Y标签列不能为空', duration=5)
    return True
def f_train(data, progress=gr.Progress(track_tqdm=True)):
    """Run the full training pipeline and return the result-widget updates.

    Steps: read the UI parameters, reset the project directory (keeping the
    uploaded data), validate, split the dataset, build the feature-filter
    strategy and the model from their factories, train, generate the report,
    and return updates for the status text, metric table, performance images,
    and download button.
    """
    progress(0, desc="Starting")
    feature_search_strategy = engine.get(data, "feature_search_strategy")
    model_type = engine.get(data, "model_type")
    test_split_rate = engine.get(data, "test_split_rate")
    data_upload = engine.get(data, "data_upload")
    all_param = engine.get_all(data)
    # Reset the project directory (the data sub-directory is preserved).
    _clean_base_dir(data)
    # Parameter validation (raises gr.Error on failure).
    if not f_verify_param(data):
        return
    # Train/test split; fixed random_state keeps runs reproducible.
    train_data, test_data = train_test_split(
        data_upload, test_size=test_split_rate, shuffle=True, random_state=2025
    )
    data_split = DataSplitEntity(train_data=train_data, val_data=None, test_data=test_data)
    progress(0.01)
    # Feature handling: resolve the selected feature-filter strategy.
    filter_strategy_clazz = FilterStrategyFactory.get_strategy(feature_search_strategy)
    filter_strategy = filter_strategy_clazz(**all_param)
    # Resolve the selected model type.
    model_clazz = ModelFactory.get_model(model_type)
    model = model_clazz(**all_param)
    # Train and generate the report.
    train_pipeline = TrainPipeline(filter_strategy, model, data_split)
    metric_value_dict = train_pipeline.train()
    progress(0.95)
    train_pipeline.generate_report()
    auc_df = metric_value_dict["模型结果"].table
    return gr.update(value="训练完成"), gr.update(value=auc_df, visible=True), \
        gr.update(value=_get_auc_ks_images(data), visible=True), gr.update(visible=True)
|