# ol_test.py
  # -*- coding: utf-8 -*-
"""
@author: yq
@time: 2024/11/27
@desc: Demo / smoke test for ``OnlineLearningTrainer``.

Loads the German Credit sample dataset shipped with ``scorecardpy``, splits it
into train/test partitions, then trains with ``OnlineLearningTrainer`` using
the resources found under ``path_resources`` and generates its report.
"""
import os
import time

from entitys import DataSplitEntity
from online_learning import OnlineLearningTrainer

# Number of leading rows used for training; the remaining rows form the test set.
TRAIN_ROWS = 709

# Directory holding the model resources (see the ol_resources_demo directory):
#   - coef.dict: model coefficients (if there is a constant term / intercept,
#     use ``const`` as its key)
#   - feature.csv: binning information (numeric bins must be in ascending order)
# Can be overridden with the OL_RESOURCES_PATH environment variable so the demo
# is runnable outside the original notebook machine.
DEFAULT_RESOURCES_PATH = "/root/notebook/ol_resources_demo"


def _load_data(train_rows=TRAIN_ROWS):
    """Load the German Credit dataset and split it into a DataSplitEntity.

    :param train_rows: number of leading rows placed in the training set;
        all remaining rows go to the test set.
    :return: ``DataSplitEntity`` with ``train_data`` and ``test_data`` set.
    """
    # Local import: scorecardpy is only needed to fetch this demo dataset.
    import scorecardpy as sc

    dat = sc.germancredit()
    # Replace dots in column names with underscores.
    dat.columns = [c.replace(".", "_") for c in dat.columns]
    # Encode the target: "bad" -> 1, anything else -> 0. A vectorised
    # comparison always yields a plain integer column, regardless of whether
    # the source column is categorical.
    dat["creditability"] = (dat["creditability"] == "bad").astype(int)
    # Explicit positional row slicing for the train/test split.
    return DataSplitEntity(
        train_data=dat.iloc[:train_rows],
        test_data=dat.iloc[train_rows:],
    )


def main():
    """Run the online-learning demo and print the elapsed wall time in seconds."""
    # perf_counter is monotonic, so the measurement is immune to clock changes.
    start = time.perf_counter()

    data = _load_data()

    # Feature processing / trainer configuration.
    cfg = {
        "path_resources": os.environ.get("OL_RESOURCES_PATH", DEFAULT_RESOURCES_PATH),
        # Project name; determines where data is stored.
        "project_name": "OnlineLearningDemo",
        "y_column": "creditability",
        # Learning rate.
        "lr": 0.01,
        # Batch size for a single update.
        "batch_size": 64,
        # Number of training epochs.
        "epochs": 20,
        "jupyter_print": True,
        # Run the stress test.
        "stress_test": True,
        # Number of sampling rounds for the stress test.
        "stress_sample_times": 10,
    }

    # Train and generate the report.
    trainer = OnlineLearningTrainer(data=data, **cfg)
    trainer.train()
    trainer.report()

    print(f"Elapsed time: {time.perf_counter() - start:.2f} s")


if __name__ == "__main__":
    main()