train_test.py 804 B

123456789101112131415161718192021222324252627282930
  1. # -*- coding: utf-8 -*-
  2. """
  3. @author: yq
  4. @time: 2024/11/27
  5. @desc:
  6. """
  7. import time
  8. from entitys import DataSplitEntity, MlConfigEntity
  9. from pipeline import Pipeline
  10. if __name__ == "__main__":
  11. time_now = time.time()
  12. import scorecardpy as sc
  13. # 加载数据
  14. dat = sc.germancredit()
  15. dat_columns = dat.columns.tolist()
  16. dat_columns = [c.replace(".","_") for c in dat_columns]
  17. dat.columns = dat_columns
  18. dat["creditability"] = dat["creditability"].apply(lambda x: 1 if x == "bad" else 0)
  19. data = DataSplitEntity(train_data=dat[:709], test_data=dat[709:])
  20. # 训练并生成报告
  21. train_pipeline = Pipeline(MlConfigEntity.from_config('./config/ml_config_template.json'), data)
  22. train_pipeline.train()
  23. train_pipeline.report()
  24. print(time.time() - time_now)