# -*- coding: utf-8 -*-
"""
@author: yq
@time: 2024/12/5
@desc: Gradio callbacks for data upload, model training and report download
"""
import os
import shutil
from typing import List

import gradio as gr
import pandas as pd
from sklearn.model_selection import train_test_split

from config import BaseConfig
from data import DataLoaderExcel, DataExplore
from entitys import DataSplitEntity
from feature import FilterStrategyFactory
from model import ModelFactory
from trainer import TrainPipeline
from .manager import engine

DATA_SUB_DIR = "data"
UPLOAD_DATA_PREFIX = "prefix_upload_data_"
data_loader = DataLoaderExcel()


def _clean_base_dir(data):
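    """Remove everything under the project's base directory except the data sub-directory."""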
    base_dir = _get_base_dir(data)
    file_name_list: List[str] = os.listdir(base_dir)
    for file_name in file_name_list:
        if file_name in [DATA_SUB_DIR]:
            continue
        file_path = os.path.join(base_dir, file_name)
        if os.path.isdir(file_path):
            shutil.rmtree(file_path)
        else:
            os.remove(file_path)


def _check_save_dir(data):
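    """Ensure a project name has been provided; raises gr.Error otherwise."""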
    project_name = engine.get(data, "project_name")
    if project_name is None or len(project_name) == 0:
        raise gr.Error(message='项目名称不能为空', duration=5)
    return True


def _get_prefix_file(save_path, prefix):
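    """Return the first file in save_path whose name starts with prefix, or None if there is no match."""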
    file_name_list: List[str] = os.listdir(save_path)
    for file_name in file_name_list:
        if file_name.startswith(prefix):
            return os.path.join(save_path, file_name)


def _get_base_dir(data):
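    """Build the project's base directory path from the configured train_path and the project name."""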
    project_name = engine.get(data, "project_name")
    base_dir = os.path.join(BaseConfig.train_path, project_name)
    return base_dir


def _get_upload_data(data) -> pd.DataFrame:
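    """Load the previously uploaded Excel file for the current project as a DataFrame."""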
    base_dir = _get_base_dir(data)
    save_path = os.path.join(base_dir, DATA_SUB_DIR)
    file_path = _get_prefix_file(save_path, UPLOAD_DATA_PREFIX)
    df = data_loader.get_data(file_path)
    return df


def _get_auc_ks_images(data):
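    """Return the train/test performance image paths (AUC/KS plots) under the project's base directory."""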
    base_dir = _get_base_dir(data)
    return [os.path.join(base_dir, "train_perf.png"), os.path.join(base_dir, "test_perf.png")]


def f_project_is_exist(data):
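    """Warn the user if the project name is empty or already in use."""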
    project_name = engine.get(data, "project_name")
    if project_name is None or len(project_name) == 0:
        gr.Warning(message='项目名称不能为空', duration=5)
    elif os.path.exists(_get_base_dir(data)):
        gr.Warning(message='项目名称已被使用', duration=5)


def _get_save_path(data, file_name: str, sub_dir="", name_prefix=""):
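    """Build a save path under the project directory, creating sub_dir if needed;
    if name_prefix is given, any existing file with that prefix is removed first."""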
    base_dir = _get_base_dir(data)
    save_path = os.path.join(base_dir, sub_dir)
    os.makedirs(save_path, exist_ok=True)
    # If a file with this prefix already exists, remove it first
    if name_prefix:
        file = _get_prefix_file(save_path, name_prefix)
        if file:
            os.remove(file)
    save_path = os.path.join(save_path, name_prefix + os.path.basename(file_name))
    return save_path


def f_data_upload(data):
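    """Handle a data file upload: copy it into the project's data directory, load it,
    and update the preview, distribution and column-selection components."""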
    # Raises gr.Error if the project name is empty
    _check_save_dir(data)
    file_data = engine.get(data, "file_data")
    data_path = _get_save_path(data, file_data.name, DATA_SUB_DIR, UPLOAD_DATA_PREFIX)
    shutil.copy(file_data.name, data_path)
    df = _get_upload_data(data)
    distribution = DataExplore.distribution(df)
    columns = df.columns.to_list()
    return {
        engine.get_elem_by_id("data_upload"): gr.update(value=df, visible=True),
        engine.get_elem_by_id("data_insight"): gr.update(value=distribution, visible=True),
        engine.get_elem_by_id("y_column"): gr.update(choices=columns),
        engine.get_elem_by_id("x_columns_candidate"): gr.update(choices=columns)
    }


def f_download_report(data):
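    """Return the generated model report (模型报告.docx) for download, raising if it has not been produced yet."""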
    file_path = _get_save_path(data, "模型报告.docx")
    if os.path.exists(file_path):
        return {engine.get_elem_by_id("download_report"): gr.update(value=file_path)}
    else:
        raise FileNotFoundError(f"{file_path} not found.")


def f_verify_param(data):
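    """Validate training parameters; currently only checks that a target (Y) column has been selected."""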
    y_column = engine.get(data, "y_column")
    if y_column is None:
        raise gr.Error(message='Y标签列不能为空', duration=5)
    return True


def f_train(data, progress=gr.Progress(track_tqdm=True)):
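    """End-to-end training callback: split the uploaded data, run feature filtering and
    model training via TrainPipeline, then yield updates for the metrics table,
    performance plots and report download."""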
    def _reset_component_state():
        return {engine.get_elem_by_id("download_report"): gr.update(visible=False),
                engine.get_elem_by_id("auc_df"): gr.update(visible=False),
                engine.get_elem_by_id("gallery_auc"): gr.update(visible=False)}

    progress(0, desc="Starting")
    feature_search_strategy = engine.get(data, "feature_search_strategy")
    model_type = engine.get(data, "model_type")
    test_split_rate = engine.get(data, "test_split_rate")
    data_upload = engine.get(data, "data_upload")
    all_param = engine.get_all(data)
    # Clear the project's working directory (the uploaded data sub-directory is kept)
    _clean_base_dir(data)
    # Validate parameters; raises gr.Error if the Y column is not selected
    f_verify_param(data)
    # Reset the result components before a new training run
    yield _reset_component_state()

    # Split the dataset into train and test sets
    train_data, test_data = train_test_split(data_upload, test_size=test_split_rate, shuffle=True, random_state=2025)
    data_split = DataSplitEntity(train_data=train_data, val_data=None, test_data=test_data)
    progress(0.01)

    # Feature processing
    ## Get the configured feature filter strategy
    filter_strategy_clazz = FilterStrategyFactory.get_strategy(feature_search_strategy)
    filter_strategy = filter_strategy_clazz(**all_param)

    # Select the model
    model_clazz = ModelFactory.get_model(model_type)
    model = model_clazz(**all_param)

    # Train and generate the report
    train_pipeline = TrainPipeline(filter_strategy, model, data_split)
    metric_value_dict = train_pipeline.train()
    progress(0.95)
    train_pipeline.generate_report()

    auc_df = metric_value_dict["模型结果"].table

    report_file_path = _get_save_path(data, "模型报告.docx")

    yield {engine.get_elem_by_id("train_progress"): gr.update(value="训练完成"),
           engine.get_elem_by_id("auc_df"): gr.update(value=auc_df, visible=True),
           engine.get_elem_by_id("gallery_auc"): gr.update(value=_get_auc_ks_images(data), visible=True),
           engine.get_elem_by_id("download_report"): gr.update(value=report_file_path, visible=True)}