train_test.py 1.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142
  1. # -*- coding: utf-8 -*-
  2. """
  3. @author: yq
  4. @time: 2024/11/27
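  5. @desc: Demo script: filters features, trains a model and generates a report on scorecardpy's germancredit data.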
  6. """
  7. import time
  8. from entitys import DataSplitEntity
  9. from feature import FilterStrategyFactory
  10. from model import ModelFactory
  11. from trainer import TrainPipeline
  12. if __name__ == "__main__":
  13. time_now = time.time()
  14. import scorecardpy as sc
  15. # 加载数据
  16. dat = sc.germancredit()
  17. dat["creditability"] = dat["creditability"].apply(lambda x: 1 if x == "bad" else 0)
  18. data = DataSplitEntity(train_data=dat[:709], val_data=None, test_data=dat[709:])
  19. # 特征处理
  20. ## 获取特征筛选策略
  21. filter_strategy_factory = FilterStrategyFactory()
  22. filter_strategy_clazz = filter_strategy_factory.get_strategy("iv")
  23. ## 可传入参数
  24. filter_strategy = filter_strategy_clazz(y_column="creditability")
  25. ## 也可从配置文件加载
  26. # filter_strategy = filter_strategy_clazz(DataProcessConfigEntity.from_config('./config/data_process_config_template.json'))
  27. # 选择模型
  28. model_factory = ModelFactory()
  29. model_clazz = model_factory.get_model("lr")
  30. model = model_clazz()
  31. # 训练并生成报告
  32. train_pipeline = TrainPipeline(filter_strategy, model, data)
  33. train_pipeline.train()
  34. train_pipeline.generate_report()
  35. print(time.time() - time_now)