# -*- coding: utf-8 -*-
"""
@author: yq
@time: 2024/11/27
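@desc: Example script: trains a logistic-regression model on the scorecardpy German credit dataset using IV-based feature filtering, then generates a report.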
"""
import time

from entitys import DataSplitEntity, DataProcessConfigEntity
from feature import FilterStrategyFactory
from model import ModelFactory
from trainer import TrainPipeline

if __name__ == "__main__":
    # perf_counter is monotonic and high-resolution, so it is the right clock for
    # measuring elapsed time (time.time() follows the wall clock and can jump).
    time_now = time.perf_counter()
    # Imported inside the main guard so this demo-only dependency is loaded only
    # when the file is executed as a script.
    import scorecardpy as sc

    # Load data: scorecardpy's bundled German credit dataset.
    dat = sc.germancredit()
    # Encode the target as binary: "bad" -> 1, any other value -> 0.
    dat["creditability"] = dat["creditability"].apply(lambda x: 1 if x == "bad" else 0)
    # Sequential split by row position (no shuffling): the first `train_size`
    # rows are used for training, the remainder for testing; no validation set.
    train_size = 709
    data = DataSplitEntity(
        train_data=dat[:train_size],
        val_data=None,
        test_data=dat[train_size:],
    )

    # Feature processing
    ## Look up the feature-filter strategy class registered under "iv".
    filter_strategy_clazz = FilterStrategyFactory.get_strategy("iv")
    ## Parameters can be passed directly, e.g.:
    ##     filter_strategy = filter_strategy_clazz(y_column="creditability")
    ## or, as done here, loaded from a config file:
    filter_strategy = filter_strategy_clazz(
        DataProcessConfigEntity.from_config('./config/data_process_config_template.json')
    )

    # Model selection: look up the model class registered under "lr" and instantiate it.
    model_clazz = ModelFactory.get_model("lr")
    model = model_clazz()

    # Train and generate the report.
    train_pipeline = TrainPipeline(filter_strategy, model, data)
    train_pipeline.train()
    train_pipeline.generate_report()

    # Total elapsed time in seconds.
    print(time.perf_counter() - time_now)