train_test.py 1.2 KB

1234567891011121314151617181920212223242526272829303132333435
  # -*- coding: utf-8 -*-
"""
@author: yq
@time: 2024/11/27
@desc: End-to-end run of feature filtering/generation, training and report
    generation on the scorecardpy German-credit sample data, using the
    template configs under ./config. Prints the total elapsed time.
"""
import time

from entitys import DataSplitEntity, DataProcessConfigEntity, TrainConfigEntity
from feature import FilterStrategyFactory
from trainer import TrainPipeline

# Number of leading rows of the sample data passed as the first
# DataSplitEntity slot; the remaining rows are passed as the last slot.
TRAIN_SIZE = 700
DATA_PROCESS_CONFIG_PATH = './config/data_process_config_template.json'
TRAIN_CONFIG_PATH = './config/train_config_template.json'


def main(train_size=TRAIN_SIZE,
         data_process_config_path=DATA_PROCESS_CONFIG_PATH,
         train_config_path=TRAIN_CONFIG_PATH):
    """Run feature processing, training and report generation once.

    Args:
        train_size: number of leading rows passed as the first
            DataSplitEntity slot; the rest are passed as the last slot.
        data_process_config_path: JSON file loaded via
            DataProcessConfigEntity.from_config.
        train_config_path: JSON file loaded via TrainConfigEntity.from_config.

    Returns:
        The elapsed time in seconds (also printed).

    Raises:
        ValueError: if ``train_size`` would leave either split empty.
    """
    start = time.perf_counter()
    # Imported lazily: scorecardpy is only needed to supply the demo dataset.
    import scorecardpy as sc

    dat = sc.germancredit()
    if not 0 < train_size < len(dat):
        raise ValueError(
            f"train_size must be between 1 and {len(dat) - 1}, got {train_size}")
    # Binary target: 1 for "bad", 0 otherwise. The vectorized comparison
    # yields a plain int64 0/1 column regardless of the source column's dtype.
    dat["creditability"] = dat["creditability"].eq("bad").astype("int64")
    # NOTE(review): the middle slot is None — presumably the validation
    # split; confirm against DataSplitEntity.
    data = DataSplitEntity(dat.iloc[:train_size], None, dat.iloc[train_size:])

    # Feature processing: build the filter strategy from the data-process
    # config, select candidate features, then generate the prepared data.
    filter_strategy_factory = FilterStrategyFactory(
        DataProcessConfigEntity.from_config(data_process_config_path))
    strategy = filter_strategy_factory.get_strategy()
    candidate_feature = strategy.filter(data)
    data_prepared = strategy.feature_generate(data, candidate_feature)

    # Training.
    train_pipeline = TrainPipeline(TrainConfigEntity.from_config(train_config_path))
    metric_value_dict_train = train_pipeline.train(data_prepared)

    # Report generation: merge feature-report metrics into the training
    # metrics (feature entries win on key collisions) and emit the report.
    metric_value_dict_feature = strategy.feature_report(data, candidate_feature)
    metric_value_dict_train.update(metric_value_dict_feature)
    train_pipeline.generate_report(metric_value_dict_train)

    elapsed = time.perf_counter() - start
    print(f"Total elapsed time: {elapsed:.3f}s")
    return elapsed


if __name__ == "__main__":
    main()