# ol_test_xgb.py
  # -*- coding: utf-8 -*-
"""
@author: yq
@time: 2024/11/27
@desc: Demo / smoke test for XGBoost online learning.

Loads scorecardpy's German credit dataset and splits it positionally into
train and test parts. It then runs ``OnlineLearningTrainerXgb`` with a demo
configuration, generates its report, and prints the elapsed time in seconds.
"""
import time

from entitys import DataSplitEntity
from online_learning import OnlineLearningTrainerXgb

# Rows [0, TRAIN_ROWS) are used for training and the remaining rows for
# testing. The split is positional; no shuffling is applied.
TRAIN_ROWS = 609


def _load_german_credit():
    """Load the German credit demo dataset with a 0/1 ``creditability`` target.

    Every "." in a column name is replaced by "_". The ``creditability`` label
    is mapped to 1 for "bad" and to 0 for any other value.

    Returns:
        The DataFrame returned by ``scorecardpy.germancredit()``, modified in
        place as described above.
    """
    # Imported lazily: scorecardpy is only needed to fetch the demo data.
    import scorecardpy as sc

    dat = sc.germancredit()
    dat.columns = [c.replace(".", "_") for c in dat.columns]
    # NOTE(review): ``apply`` is kept deliberately. A vectorised
    # ``== "bad"`` comparison always yields a plain bool/int column, which
    # could differ in dtype if germancredit() returns this column as
    # categorical. Confirm before changing.
    dat["creditability"] = dat["creditability"].apply(lambda x: 1 if x == "bad" else 0)
    return dat


def _build_config():
    """Return the keyword arguments passed to ``OnlineLearningTrainerXgb``."""
    return {
        # Base model resources: see the files in the ol_resources_demo
        # directory. Model file: model.pkl
        "path_resources": "/root/notebook/ol_resources_demo",
        # Project name; determines where data is stored.
        "project_name": "OnlineLearningDemo",
        "y_column": "creditability",
        # Learning rate. NOTE(review): this is separate from
        # params_xgb["learning_rate"]. How the trainer uses each is defined
        # in online_learning; confirm there.
        "lr": 0.01,
        "jupyter_print": True,
        # Stress test switch.
        "stress_test": False,
        # Number of sampling rounds for the stress test.
        "stress_sample_times": 10,
        # Column annotations ("年龄" means "age").
        "columns_anns": {
            "age_in_years": "年龄"
        },
        "params_xgb": {
            'objective': 'binary:logistic',
            'eval_metric': 'auc',
            'learning_rate': 0.1,
            'max_depth': 3,
            'subsample': None,
            'colsample_bytree': None,
            'alpha': 0,
            'lambda': 1,
            'num_boost_round': 7,
            # NOTE(review): early_stopping_rounds (20) >= num_boost_round (7).
            # If both are forwarded to xgboost.train, early stopping can
            # never trigger. Confirm this is intended.
            'early_stopping_rounds': 20,
            'verbose_eval': 10,
            'random_state': 2025,
            'save_pmml': True,
            'trees_print': False,
            # Online-learning mode: "tree_refresh" or "tree_add".
            'oltype': "tree_add",
            # Columns to add; presumably only relevant for "tree_add".
            # Confirm in online_learning.
            'add_columns': ['age_in_years'],
        }
    }


def main():
    """Train the online-learning XGBoost demo, generate the report, print elapsed seconds."""
    # perf_counter is monotonic, unlike time.time(), so system clock
    # adjustments during the run cannot skew the measured duration.
    start = time.perf_counter()
    # Load data
    dat = _load_german_credit()
    data = DataSplitEntity(train_data=dat[:TRAIN_ROWS], test_data=dat[TRAIN_ROWS:])
    # Train and generate the report
    trainer = OnlineLearningTrainerXgb(data=data, **_build_config())
    trainer.train()
    trainer.report()
    print(time.perf_counter() - start)


if __name__ == "__main__":
    main()