|
@@ -125,10 +125,11 @@ def f_verify_param(data):
|
|
|
|
|
|
|
|
|
def f_train(data, progress=gr.Progress(track_tqdm=True)):
|
|
|
- # import time
|
|
|
- # print(1111111)
|
|
|
- # time.sleep(5)
|
|
|
- # return gr.update(elem_id="train_button", value="111")
|
|
|
+ def _reset_component_state():
|
|
|
+ return {engine.get_elem_by_id("download_report"): gr.update(visible=False),
|
|
|
+ engine.get_elem_by_id("auc_df"): gr.update(visible=False),
|
|
|
+ engine.get_elem_by_id("gallery_auc"): gr.update(visible=False)}
|
|
|
+
|
|
|
progress(0, desc="Starting")
|
|
|
feature_search_strategy = engine.get(data, "feature_search_strategy")
|
|
|
model_type = engine.get(data, "model_type")
|
|
@@ -139,9 +140,11 @@ def f_train(data, progress=gr.Progress(track_tqdm=True)):
|
|
|
_clean_base_dir(data)
|
|
|
# 校验参数
|
|
|
if not f_verify_param(data):
|
|
|
- return
|
|
|
+ yield _reset_component_state()
|
|
|
|
|
|
- # 数据集划分
|
|
|
+ yield _reset_component_state()
|
|
|
+
|
|
|
+ # 数据集划分
|
|
|
train_data, test_data = train_test_split(data_upload, test_size=test_split_rate, shuffle=True, random_state=2025)
|
|
|
data_split = DataSplitEntity(train_data=train_data, val_data=None, test_data=test_data)
|
|
|
progress(0.01)
|
|
@@ -163,7 +166,9 @@ def f_train(data, progress=gr.Progress(track_tqdm=True)):
|
|
|
|
|
|
auc_df = metric_value_dict["模型结果"].table
|
|
|
|
|
|
- return {engine.get_elem_by_id("train_progress"): gr.update(value="训练完成"),
|
|
|
- engine.get_elem_by_id("auc_df"): gr.update(value=auc_df, visible=True),
|
|
|
- engine.get_elem_by_id("gallery_auc"): gr.update(value=_get_auc_ks_images(data), visible=True),
|
|
|
- engine.get_elem_by_id("download_report"): gr.update(visible=True)}
|
|
|
+ report_file_path = _get_save_path(data, "模型报告.docx")
|
|
|
+
|
|
|
+ yield {engine.get_elem_by_id("train_progress"): gr.update(value="训练完成"),
|
|
|
+ engine.get_elem_by_id("auc_df"): gr.update(value=auc_df, visible=True),
|
|
|
+ engine.get_elem_by_id("gallery_auc"): gr.update(value=_get_auc_ks_images(data), visible=True),
|
|
|
+ engine.get_elem_by_id("download_report"): gr.update(value=report_file_path, visible=True)}
|