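# -*- coding: utf-8 -*-
"""
@author: yq
@time: 2024/11/27
@desc: Example entry point. Loads the scorecardpy German credit data, builds
       the "iv" feature filter strategy and the "lr" model, runs the
       TrainPipeline (train + report) and prints the elapsed seconds.
"""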
- import time
- from entitys import DataSplitEntity, DataProcessConfigEntity
- from feature import FilterStrategyFactory
- from model import ModelFactory
- from trainer import TrainPipeline
if __name__ == "__main__":
    # Example run: train a model on the scorecardpy German credit data,
    # generate the report, and print the elapsed time in seconds.

    # perf_counter() is a monotonic, high-resolution clock, so the measured
    # interval cannot be skewed by system clock adjustments (unlike time.time()).
    start_time = time.perf_counter()

    # Imported after the timer starts, as in the original script, so the
    # import cost is part of the reported duration.
    import scorecardpy as sc

    # Load data
    dat = sc.germancredit()
    # Encode the target column: 1 for "bad", 0 for anything else.
    dat["creditability"] = dat["creditability"].apply(lambda x: 1 if x == "bad" else 0)
    # Rows before index 709 are used for training, the rest for testing;
    # no validation set is supplied.
    data = DataSplitEntity(train_data=dat[:709], val_data=None, test_data=dat[709:])

    # Feature processing
    ## Get the feature filtering strategy
    filter_strategy_clazz = FilterStrategyFactory.get_strategy("iv")
    ## The configuration can also be loaded from a config file
    filter_strategy = filter_strategy_clazz(
        DataProcessConfigEntity.from_config('./config/data_process_config_template.json')
    )

    # Select the model
    model_clazz = ModelFactory.get_model("lr")
    model = model_clazz()

    # Train and generate the report
    train_pipeline = TrainPipeline(filter_strategy, model, data)
    train_pipeline.train()
    train_pipeline.generate_report()

    print(time.perf_counter() - start_time)
|