# -*- coding: utf-8 -*-
"""
@author: yq
@time: 2024/12/4
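@desc: Gradio web UI for Easy-ML (quick modeling tool): a data tab for uploading and inspecting modeling data, and a training tab for configuring, running and reporting on model training.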
"""

import gradio as gr

from init import init
from webui import f_project_is_exist, f_data_upload, engine, f_train, f_download_report

# Project-level initialisation from the local `init` module. It runs at import
# time (as before) so that the UI wiring below can rely on whatever it sets up.
init()

# Every component whose value the callbacks need. Gradio accepts a *set* as
# `inputs`; the callback then receives a single dict keyed by component instead
# of positional arguments, so the (unordered) set is fine here.
input_elems = set()
# Name -> component mapping passed to `engine.add_elems` below; presumably lets
# the backend look components up by name (TODO confirm in webui.engine).
elem_dict = {}

with gr.Blocks() as demo:
    gr.HTML('<h1 ><center><font size="5">Easy-ML</font></center></h1>')
    gr.HTML('<h2 ><center><font size="2">快速建模工具</font></center></h2>')
    # NOTE(review): this State is never bound to a name or wired to any event;
    # it looks unused — confirm before removing.
    gr.State([])
    with gr.Tabs():
        # ---- Data tab: project name, file upload, preview and exploration ----
        with gr.TabItem("数据"):
            with gr.Row():
                project_name = gr.Textbox(label="项目名称", placeholder="请输入不重复的项目名称",
                                          info="项目名称将会被作为缓存目录名称,如果重复会导致结果被覆盖")
            with gr.Row():
                file_data = gr.File(label="建模数据", file_types=[".csv", ".xlsx"])
            with gr.Row():
                # Hidden until a file is uploaded (f_data_upload outputs to it).
                data_upload = gr.Dataframe(visible=False, label="当前上传数据", max_height=300)
            with gr.Row():
                data_insight = gr.Dataframe(visible=False, label="数据探查", max_height=600, wrap=True)

            input_elems.update(
                {project_name, file_data, data_upload})
            elem_dict.update(dict(
                project_name=project_name,
                file_data=file_data,
                data_upload=data_upload
            ))

        # ---- Training tab: model/feature settings on the left, results on the right ----
        with gr.TabItem("训练"):
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        model_type = gr.Dropdown(["lr"], value="lr", label="模型")
                        search_strategy = gr.Dropdown(["iv"], value="iv", label="特征搜索策略")
                    with gr.Row():
                        # Choices are filled in by f_data_upload after a file is uploaded.
                        y_column = gr.Dropdown(label="Y标签列", interactive=True, info="其值应该是0或者1")
                        x_columns_candidate = gr.Dropdown(label="X特征列", multiselect=True, interactive=True,
                                                          info="不应包含Y特征列,不选择则使用全部特征")
                    with gr.Row():
                        x_candidate_num = gr.Number(value=10, label="建模最多保留特征数", info="保留最重要的N个特征",
                                                    interactive=True)
                        sample_rate = gr.Slider(0.05, 1, value=0.1, label="分箱组合采样率", info="对2-5箱所有分箱组合进行采样",
                                                step=0.01, interactive=True)
                        special_values = gr.Textbox(label="特殊值", placeholder="可以是dict list str格式",
                                                    info="分箱时特殊值会单独一个分箱")
                    with gr.Row():
                        test_split_strategy = gr.Dropdown(["随机"], value="随机", label="测试集划分方式")
                        test_split_rate = gr.Slider(0, 0.5, value=0.3, label="测试集划分比例", step=0.05, interactive=True)

                    train_button = gr.Button("开始训练", variant="primary", elem_id="train_button")
                with gr.Column():
                    with gr.Row():
                        train_progress = gr.Textbox(label="训练进度", scale=4)
                        # Hidden until f_train makes it available.
                        download_report = gr.DownloadButton(label="报告下载", variant="primary", elem_id="download_report",
                                                            visible=False, scale=1)
                        # NOTE(review): not referenced by any event in this file; confirm whether it is still needed.
                        file_report = gr.File(visible=False)

                    with gr.Row():
                        auc_df = gr.Dataframe(visible=False, label="auc ks", max_height=300, interactive=False)
                    with gr.Row():
                        gallery_auc = gr.Gallery(label="auc ks", columns=[1], rows=[2], object_fit="contain",
                                                 height="auto", visible=False, interactive=False)

                input_elems.update(
                    {model_type, search_strategy, y_column, x_columns_candidate, x_candidate_num, sample_rate,
                     special_values, test_split_strategy, test_split_rate
                     })
                # Note: the key for `search_strategy` is "feature_search_strategy".
                elem_dict.update(dict(
                    model_type=model_type,
                    feature_search_strategy=search_strategy,
                    y_column=y_column,
                    x_columns_candidate=x_columns_candidate,
                    x_candidate_num=x_candidate_num,
                    sample_rate=sample_rate,
                    special_values=special_values,
                    test_split_strategy=test_split_strategy,
                    test_split_rate=test_split_rate,
                ))

            engine.add_elems(elem_dict)

            # ---- Event wiring: every callback receives the full `input_elems` dict ----
            project_name.change(fn=f_project_is_exist, inputs=input_elems)
            file_data.upload(fn=f_data_upload, inputs=input_elems, outputs=[data_upload, data_insight, y_column,
                                                                            x_columns_candidate])
            train_button.click(fn=f_train, inputs=input_elems,
                               outputs=[train_progress, auc_df, gallery_auc, download_report])
            download_report.click(fn=f_download_report, inputs=input_elems, outputs=download_report)

if __name__ == "__main__":
    # Queue and launch only after the `with gr.Blocks()` context above has
    # exited, so the layout and event wiring are finalised. Launch only when
    # the file is run as a script, so that importing it (e.g. to reuse `demo`)
    # does not start a blocking server.
    demo.queue(default_concurrency_limit=5)
    demo.launch(share=False, show_error=True)