train_test_xgb.py 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
# -*- coding: utf-8 -*-
"""
@author: yq
@time: 2024/11/27
@desc: Train an XGBoost model on the scorecardpy German credit data and generate its report.
"""
  7. import time
  8. from entitys import DataSplitEntity, MlConfigEntity
  9. from pipeline import Pipeline
  10. if __name__ == "__main__":
  11. time_now = time.time()
  12. import scorecardpy as sc
  13. # 加载数据
  14. dat = sc.germancredit()
  15. dat_columns = dat.columns.tolist()
  16. dat_columns = [c.replace(".","_") for c in dat_columns]
  17. dat.columns = dat_columns
  18. dat["creditability"] = dat["creditability"].apply(lambda x: 1 if x == "bad" else 0)
  19. # dat["credit_amount_corr1"] = dat["credit_amount"] * 2
  20. # dat["credit_amount_corr2"] = dat["credit_amount"] * 3
  21. data = DataSplitEntity(train_data=dat[:709], test_data=dat[709:])
  22. # 训练并生成报告
  23. # train_pipeline = Pipeline(MlConfigEntity.from_config('config/demo/ml_config_template.json'), data)
  24. # 特征处理
  25. cfg = {
  26. # 项目名称,影响数据存储位置
  27. "project_name": "demo",
  28. # jupyter下输出内容
  29. "jupyter_print": True,
  30. # 是否开启粗分箱
  31. "format_bin": True,
  32. "max_feature_num": 20,
  33. # 压力测试
  34. "stress_test": True,
  35. # 压力测试抽样次数
  36. "stress_sample_times": 10,
  37. # y
  38. "y_column": "creditability",
  39. # 参与建模的候选变量
  40. # "x_columns": [
  41. # "duration_in_month",
  42. # "credit_amount",
  43. # "age_in_years",
  44. # "purpose",
  45. # "credit_history",
  46. # "random",
  47. # "credit_amount_corr1",
  48. # "credit_amount_corr2",
  49. # ],
  50. # 变量释义
  51. "columns_anns": {
  52. "age_in_years": "年龄",
  53. "credit_history": "借贷历史"
  54. },
  55. # 被排除的变量
  56. "columns_exclude": [],
  57. # 强制使用的变量
  58. # "columns_include": ["credit_amount"],
  59. "model_type": "xgb",
  60. "feature_strategy": "norm",
  61. }
  62. train_pipeline = Pipeline(data=data, **cfg)
  63. train_pipeline.train()
  64. train_pipeline.report()
  65. print(time.time() - time_now)