# app.py (6.0 KB)
  1. # -*- coding: utf-8 -*-
  2. """
  3. @author: yq
  4. @time: 2024/12/4
  5. @desc:
  6. """
  7. import gradio as gr
  8. from init import init
  9. from webui import f_project_is_exist, f_data_upload, engine, f_train, f_download_report
  10. init()
  11. input_elems = set()
  12. elem_dict = {}
  13. with gr.Blocks() as demo:
  14. gr.HTML('<h1 ><center><font size="5">Easy-ML</font></center></h1>')
  15. gr.HTML('<h2 ><center><font size="2">快速建模工具</font></center></h2>')
  16. gr.State([])
  17. with gr.Tabs():
  18. with gr.TabItem("数据"):
  19. with gr.Row():
  20. project_name = gr.Textbox(label="项目名称", placeholder="请输入不重复的项目名称",
  21. info="项目名称将会被作为缓存目录名称,如果重复会导致结果被覆盖")
  22. with gr.Row():
  23. file_data = gr.File(label="建模数据", file_types=[".csv", ".xlsx"])
  24. with gr.Row():
  25. data_upload = gr.Dataframe(visible=False, label="当前上传数据", max_height=300)
  26. with gr.Row():
  27. data_insight = gr.Dataframe(visible=False, label="数据探查", max_height=600, wrap=True)
  28. input_elems.update(
  29. {project_name, file_data, data_upload, data_insight})
  30. elem_dict.update(dict(
  31. project_name=project_name,
  32. file_data=file_data,
  33. data_upload=data_upload,
  34. data_insight=data_insight
  35. ))
  36. with gr.TabItem("训练"):
  37. with gr.Row():
  38. with gr.Column():
  39. with gr.Row():
  40. model_type = gr.Dropdown(["lr"], value="lr", label="模型")
  41. search_strategy = gr.Dropdown(["iv"], value="iv", label="特征搜索策略")
  42. with gr.Row():
  43. y_column = gr.Dropdown(label="Y标签列", interactive=True, info="其值应该是0或者1")
  44. x_columns = gr.Dropdown(label="X特征列", multiselect=True, interactive=True,
  45. info="不应包含Y特征列,不选择则使用全部特征")
  46. with gr.Row():
  47. max_feature_num = gr.Number(value=10, label="建模最多保留特征数", info="保留最重要的N个特征",
  48. interactive=True)
  49. bin_sample_rate = gr.Slider(0.05, 1, value=0.1, label="分箱组合采样率", info="对2-5箱所有分箱组合进行采样",
  50. step=0.01, interactive=True)
  51. special_values = gr.Textbox(label="特殊值", placeholder="可以是dict list str格式",
  52. info="分箱时特殊值会单独一个分箱")
  53. with gr.Row():
  54. test_split_strategy = gr.Dropdown(["随机"], value="随机", label="测试集划分方式")
  55. test_split_rate = gr.Slider(0, 0.5, value=0.3, label="测试集划分比例", step=0.05, interactive=True)
  56. train_button = gr.Button("开始训练", variant="primary", elem_id="train_button")
  57. input_elems.update(
  58. {model_type, search_strategy, y_column, x_columns, max_feature_num, bin_sample_rate,
  59. special_values, test_split_strategy, test_split_rate, train_button
  60. })
  61. elem_dict.update(dict(
  62. model_type=model_type,
  63. feature_search_strategy=search_strategy,
  64. y_column=y_column,
  65. x_columns=x_columns,
  66. max_feature_num=max_feature_num,
  67. bin_sample_rate=bin_sample_rate,
  68. special_values=special_values,
  69. test_split_strategy=test_split_strategy,
  70. test_split_rate=test_split_rate,
  71. train_button=train_button))
  72. with gr.Column():
  73. with gr.Row():
  74. train_progress = gr.Textbox(label="训练进度", scale=4)
  75. download_report = gr.DownloadButton(label="报告下载", variant="primary", elem_id="download_report",
  76. visible=False, scale=1)
  77. file_report = gr.File(visible=False)
  78. with gr.Row():
  79. auc_df = gr.Dataframe(visible=False, label="auc ks", max_height=300, interactive=False)
  80. with gr.Row():
  81. gallery_auc = gr.Gallery(label="auc ks", columns=[1], rows=[2], object_fit="contain",
  82. height="auto", visible=False, interactive=False)
  83. input_elems.update(
  84. {train_progress, download_report, file_report, auc_df, gallery_auc})
  85. elem_dict.update(dict(
  86. train_progress=train_progress,
  87. download_report=download_report,
  88. file_report=file_report,
  89. auc_df=auc_df,
  90. gallery_auc=gallery_auc))
  91. engine.add_elems(elem_dict)
  92. project_name.change(fn=f_project_is_exist, inputs=input_elems)
  93. file_data.upload(fn=f_data_upload, inputs=input_elems, outputs=[data_upload, data_insight, y_column,
  94. x_columns])
  95. train_button.click(fn=f_train, inputs=input_elems,
  96. outputs=[train_progress, auc_df, gallery_auc, download_report])
  97. download_report.click(fn=f_download_report, inputs=input_elems, outputs=download_report)
  98. demo.queue(default_concurrency_limit=5)
  99. demo.launch(share=False, show_error=True, server_name="0.0.0.0", server_port=18066)
  100. if __name__ == "__main__":
  101. pass