# -*- coding: utf-8 -*-
"""
@author: yq
@time: 2024/12/4
@desc: Gradio front end for Easy-ML. Builds a two-tab UI ("数据" for data
       upload/inspection, "训练" for model configuration and training),
       registers every component with the shared ``engine`` and wires the
       UI events to the handlers exported by ``webui``.
"""
import gradio as gr

from init import init
from webui import f_project_is_exist, f_data_upload, engine, f_train, f_download_report

init()

# Every component handed to the handlers. It is passed to Gradio as a *set*,
# so each handler receives a single dict mapping component -> current value
# instead of positional arguments.
input_elems = set()
# Name -> component registry given to ``engine.add_elems`` so the backend can
# look components up by name.
elem_dict = {}

with gr.Blocks() as demo:
    # Page title and subtitle.
    gr.HTML('<h1 style="text-align: center;">Easy-ML</h1>')
    gr.HTML('<h3 style="text-align: center;">快速建模工具</h3>')

    with gr.Tabs():
        # ---- Data tab: project name, file upload and data preview/insight ----
        with gr.TabItem("数据"):
            with gr.Row():
                project_name = gr.Textbox(
                    label="项目名称",
                    placeholder="请输入不重复的项目名称",
                    info="项目名称将会被作为缓存目录名称,如果重复会导致结果被覆盖",
                )
            with gr.Row():
                file_data = gr.File(label="建模数据", file_types=[".csv", ".xlsx"])
            with gr.Row():
                # Hidden until a file is uploaded (visibility is expected to be
                # toggled by f_data_upload, which lists it as an output).
                data_upload = gr.Dataframe(visible=False, label="当前上传数据", max_height=300)
            with gr.Row():
                data_insight = gr.Dataframe(visible=False, label="数据探查", max_height=600, wrap=True)

            input_elems.update({project_name, file_data, data_upload, data_insight})
            elem_dict.update(dict(
                project_name=project_name,
                file_data=file_data,
                data_upload=data_upload,
                data_insight=data_insight,
            ))

        # ---- Training tab: left column = settings, right column = results ----
        with gr.TabItem("训练"):
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        model_type = gr.Dropdown(["lr"], value="lr", label="模型")
                        search_strategy = gr.Dropdown(["iv"], value="iv", label="特征搜索策略")
                    with gr.Row():
                        # Choices are filled in after upload (both dropdowns are
                        # outputs of f_data_upload).
                        y_column = gr.Dropdown(label="Y标签列", interactive=True, info="其值应该是0或者1")
                        x_columns = gr.Dropdown(
                            label="X特征列",
                            multiselect=True,
                            interactive=True,
                            info="不应包含Y特征列,不选择则使用全部特征",
                        )
                    with gr.Row():
                        max_feature_num = gr.Number(
                            value=10,
                            label="建模最多保留特征数",
                            info="保留最重要的N个特征",
                            interactive=True,
                        )
                        bin_sample_rate = gr.Slider(
                            0.05, 1,
                            value=0.1,
                            label="分箱组合采样率",
                            info="对2-5箱所有分箱组合进行采样",
                            step=0.01,
                            interactive=True,
                        )
                        special_values = gr.Textbox(
                            label="特殊值",
                            placeholder="可以是dict list str格式",
                            info="分箱时特殊值会单独一个分箱",
                        )
                    with gr.Row():
                        test_split_strategy = gr.Dropdown(["随机"], value="随机", label="测试集划分方式")
                        test_split_rate = gr.Slider(
                            0, 0.5,
                            value=0.3,
                            label="测试集划分比例",
                            step=0.05,
                            interactive=True,
                        )
                    train_button = gr.Button("开始训练", variant="primary", elem_id="train_button")

                    input_elems.update({
                        model_type, search_strategy, y_column, x_columns,
                        max_feature_num, bin_sample_rate, special_values,
                        test_split_strategy, test_split_rate, train_button,
                    })
                    # NOTE: the registry key for the search-strategy dropdown is
                    # "feature_search_strategy", not the variable name.
                    elem_dict.update(dict(
                        model_type=model_type,
                        feature_search_strategy=search_strategy,
                        y_column=y_column,
                        x_columns=x_columns,
                        max_feature_num=max_feature_num,
                        bin_sample_rate=bin_sample_rate,
                        special_values=special_values,
                        test_split_strategy=test_split_strategy,
                        test_split_rate=test_split_rate,
                        train_button=train_button,
                    ))

                with gr.Column():
                    with gr.Row():
                        train_progress = gr.Textbox(label="训练进度", scale=4)
                        # Hidden until training finishes (it is an output of
                        # f_train and f_download_report).
                        download_report = gr.DownloadButton(
                            label="报告下载",
                            variant="primary",
                            elem_id="download_report",
                            visible=False,
                            scale=1,
                        )
                        file_report = gr.File(visible=False)
                    with gr.Row():
                        auc_df = gr.Dataframe(visible=False, label="auc ks", max_height=300, interactive=False)
                    with gr.Row():
                        gallery_auc = gr.Gallery(
                            label="auc ks",
                            columns=[1],
                            rows=[2],
                            object_fit="contain",
                            height="auto",
                            visible=False,
                            interactive=False,
                        )

                    input_elems.update({train_progress, download_report, file_report, auc_df, gallery_auc})
                    elem_dict.update(dict(
                        train_progress=train_progress,
                        download_report=download_report,
                        file_report=file_report,
                        auc_df=auc_df,
                        gallery_auc=gallery_auc,
                    ))

    engine.add_elems(elem_dict)

    # Event wiring. Every handler receives the full set of input components
    # (as a dict, see input_elems above).
    project_name.change(fn=f_project_is_exist, inputs=input_elems)
    file_data.upload(fn=f_data_upload, inputs=input_elems,
                     outputs=[data_upload, data_insight, y_column, x_columns])
    train_button.click(fn=f_train, inputs=input_elems,
                       outputs=[train_progress, auc_df, gallery_auc, download_report])
    download_report.click(fn=f_download_report, inputs=input_elems, outputs=download_report)

# Allow at most 5 concurrent executions per event.
demo.queue(default_concurrency_limit=5)

if __name__ == "__main__":
    # Launch only when run as a script; importing this module no longer starts
    # (and blocks on) the web server.
    demo.launch(share=False, show_error=True, server_name="0.0.0.0", server_port=18066)