1234567891011121314151617181920212223242526272829303132333435 |
- # -*- coding: utf-8 -*-
- """
- @author: yq
- @time: 2024/11/27
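- @desc: Example entry script: runs feature filtering, model training and report generation on the scorecardpy German credit dataset.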
- """
- import time
- from entitys import DataSplitEntity, DataProcessConfigEntity, TrainConfigEntity
- from feature import FilterStrategyFactory
- from trainer import TrainPipeline
if __name__ == "__main__":
    # End-to-end demo: load the German credit sample data, filter/generate
    # features, train, and produce a report, then print the elapsed seconds.

    # perf_counter() is a monotonic, high-resolution clock, so the elapsed
    # time cannot be skewed by system clock adjustments (unlike time.time()).
    started_at = time.perf_counter()

    # Imported here so scorecardpy is only required when run as a script;
    # the import cost is deliberately inside the timed region.
    import scorecardpy as sc

    # Row index at which the dataset is cut into the two parts passed below.
    split_row = 709

    dat = sc.germancredit()
    # Encode the label as an integer: "bad" -> 1, anything else -> 0.
    dat["creditability"] = dat["creditability"].apply(lambda x: 1 if x == "bad" else 0)
    # First `split_row` rows go to the first slot, the remainder to the third;
    # the middle slot is left empty.
    # NOTE(review): presumably (train, validation, test) -- confirm against DataSplitEntity.
    data = DataSplitEntity(dat[:split_row], None, dat[split_row:])

    # Feature processing
    filter_strategy_factory = FilterStrategyFactory(
        DataProcessConfigEntity.from_config('./config/data_process_config_template.json'))
    strategy = filter_strategy_factory.get_strategy()
    candidate_feature = strategy.filter(data)
    data_prepared = strategy.feature_generate(data, candidate_feature)

    # Training
    train_pipeline = TrainPipeline(TrainConfigEntity.from_config('./config/train_config_template.json'))
    metric_value_dict_train = train_pipeline.train(data_prepared)

    # Report generation: merge the feature-level metrics into the training
    # metrics so a single dict is handed to the report generator.
    metric_value_dict_feature = strategy.feature_report(data, candidate_feature)
    metric_value_dict_train.update(metric_value_dict_feature)
    train_pipeline.generate_report(metric_value_dict_train)

    # Total run time in seconds.
    print(time.perf_counter() - started_at)
|