# -*- coding: utf-8 -*-
"""
@author: yq
@time: 2024/11/27
@desc:
"""
import time

from entitys import DataSplitEntity
from online_learning import OnlineLearningTrainer


if __name__ == "__main__":
    time_now = time.time()
    import scorecardpy as sc

    # 加载数据
    dat = sc.germancredit()
    dat_columns = dat.columns.tolist()
    dat_columns = [c.replace(".","_") for c in dat_columns]
    dat.columns = dat_columns

    dat["creditability"] = dat["creditability"].apply(lambda x: 1 if x == "bad" else 0)

    data = DataSplitEntity(train_data=dat[:709], test_data=dat[709:])

    # 特征处理
    cfg = {
        # 模型系数,分箱信息等,请参考ol_resources_demo目录下文件
        # 模型系数文件 coef.dict(如果有常数项(截距)请用const作为key)
        # 分箱信息文件 feature.csv(数值型的分箱信息请按升序排列)
        "path_resources": "/root/notebook/ol_resources_demo",
        # 项目名称,影响数据存储位置
        "project_name": "OnlineLearningDemo",
        "y_column": "creditability",
        # 学习率
        "lr": 0.01,
        # 单次更新批大小
        "batch_size": 64,
        # 训练轮数
        "epochs": 20,
        "jupyter_print": True,
        # 压力测试
        "stress_test": True,
        # 压力测试抽样次数
        "stress_sample_times": 10,
    }

    # 训练并生成报告
    trainer = OnlineLearningTrainer(data=data, **cfg)
    trainer.train()
    trainer.report()

    print(time.time() - time_now)