strategy_woe.py

# -*- coding:utf-8 -*-
"""
@author: yq
@time: 2024/1/2
@desc: IV-value and monotonicity based feature filtering strategy
"""
import json
import os.path
from itertools import combinations_with_replacement
from typing import Dict, Optional, Union

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scorecardpy as sc
import seaborn as sns
from pandas.core.dtypes.common import is_numeric_dtype
from tqdm import tqdm

from commom import f_display_images_by_side, NumpyEncoder, GeneralException, f_df_to_image, f_display_title, \
    f_image_crop_white_borders
from entitys import DataSplitEntity, MetricFucResultEntity
from enums import ContextEnum, ResultCodesEnum
from feature.feature_strategy_base import FeatureStrategyBase
from init import context
from .entity import BinInfo, HomologousBinInfo
from .utils import f_monto_shift, f_get_corr, f_get_vif, f_format_bin, f_trend_shift, f_get_psi

class StrategyWoe(FeatureStrategyBase):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Binning info needed for WOE encoding; reuses the scorecardpy format.
        self.sc_woebin = None
    def _f_get_img_corr(self, train_woe) -> Union[str, None]:
        if len(train_woe.columns.to_list()) <= 1:
            return None
        train_corr = f_get_corr(train_woe)
        plt.figure(figsize=(12, 12))
        sns.heatmap(train_corr, vmax=1, square=True, cmap='RdBu', annot=True)
        plt.title('Variables Correlation', fontsize=15)
        plt.yticks(rotation=0)
        plt.xticks(rotation=90)
        img_path = self.ml_config.f_get_save_path(f"corr.png")
        plt.savefig(img_path)
        f_image_crop_white_borders(img_path, img_path)
        return img_path
    def _f_get_img_trend(self, sc_woebin, x_columns, prefix):
        imgs_path = []
        for k in x_columns:
            df_bin = sc_woebin[k]
            # df_bin["bin"] = df_bin["bin"].apply(lambda x: re.sub(r"(\d+\.\d+)",
            #                                     lambda m: "{:.2f}".format(float(m.group(0))), x))
            sc.woebin_plot(df_bin)
            path = self.ml_config.f_get_save_path(f"{prefix}_{k}.png")
            plt.savefig(path)
            imgs_path.append(path)
        return imgs_path
    def _f_get_sc_woebin(self, data: pd.DataFrame, bin_info_dict: Dict[str, BinInfo]) -> Dict[str, pd.DataFrame]:
        y_column = self.ml_config.y_column
        special_values = self.ml_config.special_values
        x_columns = list(bin_info_dict.keys())
        breaks_list = {column: bin_info.points for column, bin_info in bin_info_dict.items()}
        sc_woebin = sc.woebin(data[x_columns + [y_column]], y=y_column, breaks_list=breaks_list,
                              special_values=special_values, print_info=False)
        return sc_woebin
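    # Shape of the breaks_list handed to sc.woebin, e.g. (illustrative values only):
    #   {"age": [25, 35, 50], "income": [3000.0, 8000.0]}
    # scorecardpy turns each list of cut points into left-closed intervals such as
    # [-inf,25), [25,35), [35,50), [50,inf).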
    def _handle_numeric(self, data: DataSplitEntity, x_column: str) -> HomologousBinInfo:
        # Greedily search for the monotonic binning whose combined IV over the
        # train and test sets is highest.
        def _n0(x):
            return sum(x == 0)

        def _n1(x):
            return sum(x == 1)

        def _get_bins_sv(df, x_column):
            y_column = self.ml_config.y_column
            special_values = self.ml_config.get_special_values(x_column)
            # special_values_bins
            bins_sv = pd.DataFrame()
            for special in special_values:
                dtm = df[df[x_column] == special]
                if len(dtm) != 0:
                    dtm['bin'] = [str(special)] * len(dtm)
                    bin = dtm.groupby(['bin'], group_keys=False)[y_column].agg([_n0, _n1]) \
                        .reset_index().rename(columns={'_n0': 'good', '_n1': 'bad'})
                    bin['is_special_values'] = [True] * len(bin)
                    bins_sv = pd.concat((bins_sv, bin))
            return bins_sv

        def _get_bins_nsv(df, x_column, breaks_list):
            # no_special_values_bins
            def _left_value(bin: str):
                if "," not in bin:
                    return float(bin)
                left = bin.split(",")[0]
                return float(left[1:])

            y_column = self.ml_config.y_column
            dtm = pd.DataFrame({'y': df[y_column], 'value': df[x_column]})
            bstbrks = [-np.inf] + breaks_list + [np.inf]
            labels = ['[{},{})'.format(bstbrks[i], bstbrks[i + 1]) for i in range(len(bstbrks) - 1)]
            dtm.loc[:, 'bin'] = pd.cut(dtm['value'], bstbrks, right=False, labels=labels)
            dtm['bin'] = dtm['bin'].astype(str)
            bins = dtm.groupby(['bin'], group_keys=False)['y'].agg([_n0, _n1]) \
                .reset_index().rename(columns={'_n0': 'good', '_n1': 'bad'})
            bins['is_special_values'] = [False] * len(bins)
            bins["ordered"] = bins['bin'].apply(_left_value)
            # Sort by the left bin edge so monotonicity is computed on correctly ordered bins.
            bins = bins.sort_values(by=["ordered"], ascending=[True])
            return bins
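        # Label shape produced above, assuming breaks_list=[0.5, 2.0] (illustrative):
        #   '[-inf,0.5)', '[0.5,2.0)', '[2.0,inf)'
        # _left_value extracts the left edge for sorting; '-inf' parses via float("-inf").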
        def _get_badprobs(bins):
            bins['count'] = bins['good'] + bins['bad']
            bins['badprob'] = bins['bad'] / bins['count']
            return bins['badprob'].values.tolist()

        def _get_iv(bins):
            infovalue = pd.DataFrame({'good': bins['good'], 'bad': bins['bad']}) \
                .replace(0, 0.9) \
                .assign(DistrBad=lambda x: x.bad / sum(x.bad), DistrGood=lambda x: x.good / sum(x.good)) \
                .assign(iv=lambda x: (x.DistrBad - x.DistrGood) * np.log(x.DistrBad / x.DistrGood)) \
                .iv
            bins['bin_iv'] = infovalue
            bins['total_iv'] = bins['bin_iv'].sum()
            iv = bins['total_iv'].values[0]
            return iv.round(3)
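        # Per-bin IV is (DistrBad_i - DistrGood_i) * ln(DistrBad_i / DistrGood_i),
        # summed over all bins (zero counts are smoothed to 0.9 above). Illustrative
        # check: a bin holding 180 of 900 goods and 40 of 100 bads contributes
        # (0.4 - 0.2) * ln(0.4 / 0.2) ≈ 0.139.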
        def _get_points(data_ascending, column):
            def _sampling(raw_list: list, num: int):
                # Sample at a fixed step.
                return raw_list[::int(len(raw_list) / num)]

            def _distribute(interval, bin_num):
                parts = int(1 / interval)
                # Enumerate every distribution via the stars-and-bars method.
                total_ways = combinations_with_replacement(range(parts + bin_num - 1), bin_num - 1)
                distributions = []
                # Iterate over all possible divider placements.
                for combo in total_ways:
                    # Allocate the balls according to the divider positions.
                    distribution = [0] * bin_num
                    start = 0
                    for i, divider in enumerate(combo):
                        distribution[i] = divider - start + 1
                        start = divider + 1
                    distribution[-1] = parts - start  # balls in the last box
                    # Keep only distributions where every box holds at least one ball.
                    if all(x > 0 for x in distribution):
                        distributions.append(distribution)
                return distributions
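            # Worked example: interval=0.25 gives parts=4; bin_num=2 then yields the
            # compositions [1, 3], [2, 2] and [3, 1], i.e. every way to split the
            # sorted sample into two bins in 25% steps.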
            interval = self.ml_config.bin_search_interval
            bin_sample_rate = self.ml_config.bin_sample_rate
            format_bin = self.ml_config.format_bin
            data_x = data_ascending[column]
            data_x_describe = data_x.describe(percentiles=[0.1, 0.9])
            data_x_max = data_x.max()
            # Evaluate binnings with 2 to 5 bins.
            distributions_list = []
            for bin_num in list(range(2, 6)):
                distributions = _distribute(interval, bin_num)
                # Candidates for 4+ bins must be sampled, otherwise the search takes too long.
                sample_num = 1000 * bin_sample_rate
                if bin_sample_rate <= 0.15:
                    sample_num *= 2
                if bin_num == 5:
                    sample_num = 4000 * bin_sample_rate
                if bin_num in (4, 5) and len(distributions) >= sample_num:
                    distributions = _sampling(distributions, sample_num)
                distributions_list.extend(distributions)
            points_list = []
            for distributions in distributions_list:
                points = []
                point_percentile = [sum(distributions[0:idx + 1]) * interval for idx, _ in
                                    enumerate(distributions[0:-1])]
                for percentile in point_percentile:
                    point = data_x.iloc[int(len(data_x) * percentile)]
                    point = float(point)
                    if format_bin:
                        point = f_format_bin(data_x_describe, point)
                    point = round(point, 2)
                    if point == 0:
                        continue
                    # Skip cut points pushed out of range by the coarse rounding.
                    if point not in points and point < data_x_max:
                        points.append(point)
                if points not in points_list and len(points) != 0:
                    points_list.append(points)
            return points_list
        special_values = self.ml_config.get_special_values(x_column)
        breaks_list = self.ml_config.get_breaks_list(x_column)
        iv_threshold = self.ml_config.iv_threshold
        psi_threshold = self.ml_config.psi_threshold
        monto_shift_threshold = self.ml_config.monto_shift_threshold
        trend_shift_threshold = self.ml_config.trend_shift_threshold

        train_data = data.train_data
        test_data = data.test_data
        train_data_ascending_nsv = train_data[~train_data[x_column].isin(special_values)] \
            .sort_values(by=x_column, ascending=True)
        test_data_ascending_nsv = test_data[~test_data[x_column].isin(special_values)] \
            .sort_values(by=x_column, ascending=True)
        train_bins_sv = _get_bins_sv(train_data, x_column)
        test_bins_sv = _get_bins_sv(test_data, x_column)

        # Gather the metrics for every candidate binning.
        # Build the candidate cut points.
        is_auto_bins = 1
        if len(breaks_list) != 0:
            points_list_nsv = [breaks_list]
            is_auto_bins = 0
        else:
            points_list_nsv = _get_points(train_data_ascending_nsv, x_column)
        homo_bin_info = HomologousBinInfo(x_column, is_auto_bins, self.ml_config.is_include(x_column))
        # Compute IV, PSI, monotonicity shift, etc.
        for points in points_list_nsv:
            bin_info = BinInfo()
            bin_info.x_column = x_column
            bin_info.bin_num = len(points) + 1
            bin_info.points = points
            bin_info.is_auto_bins = is_auto_bins
            # Variable IV, computed jointly with the special-value bins.
            train_bins_nsv = _get_bins_nsv(train_data_ascending_nsv, x_column, points)
            train_bins = pd.concat((train_bins_nsv, train_bins_sv))
            train_iv = _get_iv(train_bins)
            test_bins_nsv = _get_bins_nsv(test_data_ascending_nsv, x_column, points)
            test_bins = pd.concat((test_bins_nsv, test_bins_sv))
            test_iv = _get_iv(test_bins)
            bin_info.train_iv = train_iv
            bin_info.test_iv = test_iv
            bin_info.iv = train_iv + test_iv
            bin_info.is_qualified_iv_train = 1 if train_iv > iv_threshold else 0
            # Number of monotonicity reversals in the bad rate.
            train_badprobs_nsv = _get_badprobs(train_bins_nsv)
            monto_shift_train_nsv = f_monto_shift(train_badprobs_nsv)
            bin_info.monto_shift_nsv = monto_shift_train_nsv
            bin_info.is_qualified_monto_train_nsv = 0 if monto_shift_train_nsv > monto_shift_threshold else 1
            # Trend consistency between train and test.
            test_badprobs_nsv = _get_badprobs(test_bins_nsv)
            trend_shift_nsv = f_trend_shift(train_badprobs_nsv, test_badprobs_nsv)
            bin_info.trend_shift_nsv = trend_shift_nsv
            bin_info.is_qualified_trend_nsv = 0 if trend_shift_nsv > trend_shift_threshold else 1
            # Variable PSI.
            psi = f_get_psi(train_bins, test_bins)
            bin_info.psi = psi
            bin_info.is_qualified_psi = 1 if psi < psi_threshold else 0
            homo_bin_info.add(bin_info)
        return homo_bin_info
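    # PSI between the train and test bins is conventionally
    #   sum_i (pct_train_i - pct_test_i) * ln(pct_train_i / pct_test_i)
    # (the exact implementation lives in utils.f_get_psi); readings below ~0.1 are
    # usually taken as stable, 0.1-0.25 as moderate drift.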
    def _f_fast_filter(self, data: DataSplitEntity) -> Dict[str, BinInfo]:
        # Coarse-filter variables on IV.
        train_data = data.train_data
        test_data = data.test_data
        y_column = self.ml_config.y_column
        x_columns = self.ml_config.x_columns
        columns_exclude = self.ml_config.columns_exclude
        special_values = self.ml_config.special_values
        breaks_list = self.ml_config.breaks_list.copy()
        iv_threshold = self.ml_config.iv_threshold
        psi_threshold = self.ml_config.psi_threshold

        if len(x_columns) == 0:
            x_columns = train_data.columns.tolist()
        if y_column in x_columns:
            x_columns.remove(y_column)
        for column in columns_exclude:
            if column in x_columns:
                x_columns.remove(column)

        bins_train = sc.woebin(train_data[x_columns + [y_column]], y=y_column, bin_num_limit=5,
                               special_values=special_values, breaks_list=breaks_list, print_info=False)
        for column, bin in bins_train.items():
            breaks_list[column] = list(bin[bin["is_special_values"] == False]['breaks'])
        bins_test = sc.woebin(test_data[x_columns + [y_column]], y=y_column,
                              special_values=special_values, breaks_list=breaks_list, print_info=False)

        bin_info_fast: Dict[str, BinInfo] = {}
        filter_fast_overview = ""
        for column, bin_train in bins_train.items():
            train_iv = bin_train['total_iv'][0].round(3)
            if train_iv <= iv_threshold and not self.ml_config.is_include(column):
                filter_fast_overview = f"{filter_fast_overview}{column} 因为train_iv【{train_iv}】小于阈值被剔除\n"
                continue
            bin_test = bins_test[column]
            test_iv = bin_test['total_iv'][0].round(3)
            iv = round(train_iv + test_iv, 3)
            psi = f_get_psi(bin_train, bin_test)
            # if psi >= psi_threshold and not self.ml_config.is_include(column):
            #     filter_fast_overview = f"{filter_fast_overview}{column} 因为psi【{psi}】大于阈值被剔除\n"
            #     continue
            bin_info_fast[column] = BinInfo.ofConvertByDict(
                {"x_column": column, "train_iv": train_iv, "iv": iv, "psi": psi, "points": breaks_list[column]}
            )
        context.set_filter_info(ContextEnum.FILTER_FAST,
                                f"筛选前变量数量:{len(x_columns)}\n{x_columns}\n"
                                f"快速筛选剔除变量数量:{len(x_columns) - len(bin_info_fast)}\n{filter_fast_overview}")
        return bin_info_fast
    def _f_corr_filter(self, data: DataSplitEntity, bin_info_dict: Dict[str, BinInfo]) -> Dict[str, BinInfo]:
        # Drop variables by pairwise correlation.
        corr_threshold = self.ml_config.corr_threshold
        train_data = data.train_data
        x_columns = list(bin_info_dict.keys())

        sc_woebin = self._f_get_sc_woebin(train_data, bin_info_dict)
        train_woe = sc.woebin_ply(train_data[x_columns], sc_woebin, print_info=False)
        corr_df = f_get_corr(train_woe)
        corr_dict = corr_df.to_dict()

        filter_corr_overview = ""
        filter_corr_detail = {}
        # Check each variable's correlation against all the others.
        for column, corr in corr_dict.items():
            column = column.replace("_woe", "")
            column_remove = []
            overview = f"{column}: "
            if column not in x_columns:
                continue
            for challenger_column, challenger_corr in corr.items():
                challenger_corr = round(challenger_corr, 3)
                challenger_column = challenger_column.replace("_woe", "")
                if challenger_corr < corr_threshold or column == challenger_column \
                        or challenger_column not in x_columns:
                    continue
                # When the correlation exceeds the threshold, keep the variable with the larger IV.
                iv = bin_info_dict[column].iv
                challenger_iv = bin_info_dict[challenger_column].iv
                if iv > challenger_iv:
                    if not self.ml_config.is_include(challenger_column):
                        column_remove.append(challenger_column)
                        overview = f"{overview}【{challenger_column}_iv{challenger_iv}_corr{challenger_corr}】 "
                else:
                    # If the current variable itself loses, record nothing for it.
                    column_remove = []
                    overview = ""
                    break
            # Remove the variables correlated with the current one.
            for c in column_remove:
                if c in x_columns:
                    x_columns.remove(c)
            if len(column_remove) != 0:
                filter_corr_overview = f"{filter_corr_overview}{overview}\n"
                filter_corr_detail[column] = column_remove
        for column in list(bin_info_dict.keys()):
            if column not in x_columns:
                bin_info_dict.pop(column)
        context.set_filter_info(ContextEnum.FILTER_CORR, filter_corr_overview, filter_corr_detail)
        return bin_info_dict
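    # Keep/drop rule sketch (illustrative numbers): with corr_threshold=0.8,
    # corr(a_woe, b_woe)=0.85, iv_a=0.42 and iv_b=0.31, variable b is dropped
    # unless it is force-included via ml_config.is_include.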
    def _f_vif_filter(self, data: DataSplitEntity, bin_info_dict: Dict[str, BinInfo]) -> Dict[str, BinInfo]:
        vif_threshold = self.ml_config.vif_threshold
        train_data = data.train_data
        x_columns = list(bin_info_dict.keys())

        sc_woebin = self._f_get_sc_woebin(train_data, bin_info_dict)
        train_woe = sc.woebin_ply(train_data[x_columns], sc_woebin, print_info=False)
        df_vif = f_get_vif(train_woe)
        if df_vif is None:
            return bin_info_dict

        filter_vif_overview = ""
        filter_vif_detail = []
        for _, row in df_vif.iterrows():
            column = row["变量"]
            vif = row["vif"]
            if vif < vif_threshold or self.ml_config.is_include(column):
                continue
            filter_vif_overview = f"{filter_vif_overview}{column} 因为vif【{vif}】大于阈值被剔除\n"
            filter_vif_detail.append(column)
            bin_info_dict.pop(column)
        context.set_filter_info(ContextEnum.FILTER_VIF, filter_vif_overview, filter_vif_detail)
        return bin_info_dict
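    # VIF of a variable is conventionally 1 / (1 - R^2), with R^2 from regressing its
    # WOE column on all the other WOE columns (the computation lives in utils.f_get_vif);
    # a common rule of thumb treats VIF >= 10 as serious multicollinearity.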
    def post_filter(self, data: DataSplitEntity, bin_info_dict: Dict[str, BinInfo]):
        # Filters that compare the surviving variables against each other.
        max_feature_num = self.ml_config.max_feature_num
        bin_info_filtered = self._f_corr_filter(data, bin_info_dict)
        bin_info_filtered = self._f_vif_filter(data, bin_info_filtered)
        bin_info_filtered = BinInfo.ivTopN(bin_info_filtered, max_feature_num)
        self.sc_woebin = self._f_get_sc_woebin(data.train_data, bin_info_filtered)
        context.set(ContextEnum.BIN_INFO_FILTERED, bin_info_filtered)
        context.set(ContextEnum.WOEBIN, self.sc_woebin)
    def feature_search(self, data: DataSplitEntity, *args, **kwargs):
        # Coarse filtering first.
        bin_info_fast = self._f_fast_filter(data)
        x_columns = list(bin_info_fast.keys())

        bin_info_filtered: Dict[str, BinInfo] = {}
        # Intermediate results for each numeric variable's candidate binnings.
        homo_bin_info_numeric_set: Dict[str, HomologousBinInfo] = {}
        filter_numeric_overview = ""
        filter_numeric_detail = []
        for x_column in tqdm(x_columns):
            if is_numeric_dtype(data.train_data[x_column]):
                # Numeric variable selection.
                homo_bin_info_numeric: HomologousBinInfo = self._handle_numeric(data, x_column)
                if homo_bin_info_numeric.is_auto_bins:
                    homo_bin_info_numeric_set[x_column] = homo_bin_info_numeric
                # Filter on IV, PSI, monotonicity and trend consistency.
                bin_info: Optional[BinInfo] = homo_bin_info_numeric.filter()
                if bin_info is not None:
                    bin_info_filtered[x_column] = bin_info
                else:
                    # Dropped for failing the requirements.
                    filter_numeric_overview = f"{filter_numeric_overview}{x_column} {homo_bin_info_numeric.drop_reason()}\n"
                    filter_numeric_detail.append(x_column)
            else:
                # Categorical variables are handled by scorecardpy for now.
                bin_info_filtered[x_column] = bin_info_fast[x_column]
        self.post_filter(data, bin_info_filtered)
        context.set(ContextEnum.HOMO_BIN_INFO_NUMERIC_SET, homo_bin_info_numeric_set)
        context.set_filter_info(ContextEnum.FILTER_NUMERIC, filter_numeric_overview, filter_numeric_detail)
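    # End-to-end filter order: _f_fast_filter (coarse IV cut) -> _handle_numeric
    # (per-variable binning search) -> post_filter (_f_corr_filter -> _f_vif_filter
    # -> BinInfo.ivTopN); afterwards self.sc_woebin holds the final binning.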
    def variable_analyse(self, data: DataSplitEntity, column: str, format_bin=None, *args, **kwargs):
        from IPython import display
        if is_numeric_dtype(data.train_data[column]):
            train_data = data.train_data
            test_data = data.test_data
            format_bin_mlcfg = self.ml_config.format_bin
            if format_bin is not None:
                self.ml_config._format_bin = format_bin
            homo_bin_info_numeric: HomologousBinInfo = self._handle_numeric(data, column)
            bins_info = homo_bin_info_numeric.get_best_bins()
            print(f"-----【{column}】不同分箱数下变量的推荐切分点-----")
            imgs_path_trend_train = []
            imgs_path_trend_test = []
            for bin_info in bins_info:
                print(json.dumps(bin_info.points, ensure_ascii=False, cls=NumpyEncoder))
                breaks_list = [str(i) for i in bin_info.points]
                sc_woebin_train = self._f_get_sc_woebin(train_data, {column: bin_info})
                image_path = self._f_get_img_trend(sc_woebin_train, [column],
                                                   f"train_{column}_{'_'.join(breaks_list)}")
                imgs_path_trend_train.append(image_path[0])
                sc_woebin_test = self._f_get_sc_woebin(test_data, {column: bin_info})
                image_path = self._f_get_img_trend(sc_woebin_test, [column],
                                                   f"test_{column}_{'_'.join(breaks_list)}")
                imgs_path_trend_test.append(image_path[0])
            f_display_images_by_side(display, imgs_path_trend_train, title=f"训练集",
                                     image_path_list2=imgs_path_trend_test, title2="测试集")
            self.ml_config._format_bin = format_bin_mlcfg
        else:
            print("只能针对数值型变量进行分析。")
    def feature_save(self, *args, **kwargs):
        if self.sc_woebin is None:
            raise GeneralException(ResultCodesEnum.NOT_FOUND, message=f"feature不存在")
        df_woebin = pd.concat(self.sc_woebin.values())
        path = self.ml_config.f_get_save_path(f"feature.csv")
        df_woebin.to_csv(path)
        print(f"feature saved to【{path}】successfully.")
    def feature_load(self, path: str, *args, **kwargs):
        if os.path.isdir(path):
            path = os.path.join(path, "feature.csv")
        if not os.path.isfile(path) or "feature.csv" not in path:
            raise GeneralException(ResultCodesEnum.NOT_FOUND, message=f"特征信息【feature.csv】不存在")
        df_woebin = pd.read_csv(path)
        variables = df_woebin["variable"].unique().tolist()
        self.sc_woebin = {}
        for variable in variables:
            self.sc_woebin[variable] = df_woebin[df_woebin["variable"] == variable]
        print(f"feature loaded from【{path}】successfully.")
    def feature_generate(self, data: pd.DataFrame, *args, **kwargs) -> pd.DataFrame:
        x_columns = list(self.sc_woebin.keys())
        # Sort the columns to rule out ordering-related bugs.
        x_columns.sort()
        data_woe = sc.woebin_ply(data[x_columns], self.sc_woebin, print_info=False)
        return data_woe
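    # Usage sketch (hypothetical frame; columns must match the fitted binning):
    #   df_raw = pd.DataFrame({"age": [23, 41], "income": [3500.0, 9200.0]})
    #   df_woe = strategy.feature_generate(df_raw)  # scorecardpy appends the _woe suffix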
    def feature_report(self, data: DataSplitEntity, *args, **kwargs) -> Dict[str, MetricFucResultEntity]:
        y_column = self.ml_config.y_column
        columns_anns = self.ml_config.columns_anns
        x_columns = list(self.sc_woebin.keys())
        train_data = data.train_data
        test_data = data.test_data
        # Intermediate results are shared across modules, so fetch them from the context.
        bin_info_filtered: Dict[str, BinInfo] = context.get(ContextEnum.BIN_INFO_FILTERED)

        metric_value_dict = {}
        # Sample distribution.
        metric_value_dict["样本分布"] = MetricFucResultEntity(table=data.get_distribution(y_column),
                                                          table_font_size=10, table_cell_width=3)
        # Variable correlation.
        sc_woebin_train = self._f_get_sc_woebin(train_data, bin_info_filtered)
        train_woe = sc.woebin_ply(train_data[x_columns], sc_woebin_train, print_info=False)
        img_path_corr = self._f_get_img_corr(train_woe)
        metric_value_dict["变量相关性"] = MetricFucResultEntity(image_path=img_path_corr)
        # Variable IV, PSI and VIF.
        df_iv_psi_vif = pd.DataFrame()
        train_iv = [bin_info_filtered[column].train_iv for column in x_columns]
        psi = [bin_info_filtered[column].psi for column in x_columns]
        anns = [columns_anns.get(column, "-") for column in x_columns]
        df_iv_psi_vif["变量"] = x_columns
        df_iv_psi_vif["iv"] = train_iv
        df_iv_psi_vif["psi"] = psi
        df_vif = f_get_vif(train_woe)
        if df_vif is not None:
            df_iv_psi_vif = pd.merge(df_iv_psi_vif, df_vif, on="变量", how="left")
        df_iv_psi_vif["释义"] = anns
        df_iv_psi_vif.sort_values(by=["iv"], ascending=[False], inplace=True)
        img_path_iv = self.ml_config.f_get_save_path(f"iv.png")
        f_df_to_image(df_iv_psi_vif, img_path_iv)
        metric_value_dict["变量iv"] = MetricFucResultEntity(table=df_iv_psi_vif, image_path=img_path_iv)
        # Variable trend on the train set.
        imgs_path_trend_train = self._f_get_img_trend(sc_woebin_train, x_columns, "train")
        metric_value_dict["变量趋势-训练集"] = MetricFucResultEntity(image_path=imgs_path_trend_train, image_size=4)
        # Variable trend on the test set.
        sc_woebin_test = self._f_get_sc_woebin(test_data, bin_info_filtered)
        imgs_path_trend_test = self._f_get_img_trend(sc_woebin_test, x_columns, "test")
        metric_value_dict["变量趋势-测试集"] = MetricFucResultEntity(image_path=imgs_path_trend_test, image_size=4)
        # context.set(ContextEnum.METRIC_FEATURE.value, metric_value_dict)

        if self.ml_config.jupyter_print:
            self.jupyter_print(data, metric_value_dict)
        return metric_value_dict
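    # Keys carried by the returned metric_value_dict (consumed by jupyter_print and
    # downstream report rendering): "样本分布", "变量相关性", "变量iv",
    # "变量趋势-训练集", "变量趋势-测试集".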
    def jupyter_print(self, data: DataSplitEntity, metric_value_dict: Dict[str, MetricFucResultEntity]):
        from IPython import display

        def detail_print(detail):
            if isinstance(detail, str):
                detail = [detail]
            if isinstance(detail, list):
                for column in detail:
                    homo_bin_info_numeric = homo_bin_info_numeric_set.get(column)
                    if homo_bin_info_numeric is None:
                        continue
                    bins_info = homo_bin_info_numeric.get_best_bins()
                    print(f"-----【{column}】不同分箱数下变量的推荐切分点-----")
                    imgs_path_trend_train = []
                    imgs_path_trend_test = []
                    for bin_info in bins_info:
                        print(json.dumps(bin_info.points, ensure_ascii=False, cls=NumpyEncoder))
                        breaks_list = [str(i) for i in bin_info.points]
                        sc_woebin_train = self._f_get_sc_woebin(train_data, {column: bin_info})
                        image_path = self._f_get_img_trend(sc_woebin_train, [column],
                                                           f"train_{column}_{'_'.join(breaks_list)}")
                        imgs_path_trend_train.append(image_path[0])
                        sc_woebin_test = self._f_get_sc_woebin(test_data, {column: bin_info})
                        image_path = self._f_get_img_trend(sc_woebin_test, [column],
                                                           f"test_{column}_{'_'.join(breaks_list)}")
                        imgs_path_trend_test.append(image_path[0])
                    f_display_images_by_side(display, imgs_path_trend_train, title=f"训练集",
                                             image_path_list2=imgs_path_trend_test, title2="测试集")
            if isinstance(detail, dict):
                for column, challenger_columns in detail.items():
                    print(f"-----相关性筛选保留的【{column}】-----")
                    detail_print(column)
                    detail_print(challenger_columns)

        def filter_print(filter, title, notes=""):
            f_display_title(display, title)
            print(notes)
            print(filter.get("overview"))
            detail = filter.get("detail")
            if detail is not None and self.ml_config.bin_detail_print:
                detail_print(detail)

        train_data = data.train_data
        test_data = data.test_data
        bin_info_filtered: Dict[str, BinInfo] = context.get(ContextEnum.BIN_INFO_FILTERED)
        homo_bin_info_numeric_set: Dict[str, HomologousBinInfo] = context.get(ContextEnum.HOMO_BIN_INFO_NUMERIC_SET)
        filter_fast = context.get(ContextEnum.FILTER_FAST)
        filter_numeric = context.get(ContextEnum.FILTER_NUMERIC)
        filter_corr = context.get(ContextEnum.FILTER_CORR)
        filter_vif = context.get(ContextEnum.FILTER_VIF)
        filter_ivtop = context.get(ContextEnum.FILTER_IVTOP)

        f_display_title(display, "样本分布")
        display.display(metric_value_dict["样本分布"].table)
        # Print variable IV.
        f_display_title(display, "变量iv")
        display.display(metric_value_dict["变量iv"].table)
        # Print variable correlation.
        f_display_images_by_side(display, metric_value_dict["变量相关性"].image_path, width=800)
        # Print variable trends.
        f_display_title(display, "变量趋势")
        imgs_path_trend_train = metric_value_dict["变量趋势-训练集"].image_path
        imgs_path_trend_test = metric_value_dict.get("变量趋势-测试集").image_path
        f_display_images_by_side(display, imgs_path_trend_train, title="训练集",
                                 image_path_list2=imgs_path_trend_test, title2="测试集")
        # Print the breaks_list.
        breaks_list = {column: bin_info.points for column, bin_info in bin_info_filtered.items()}
        print("变量切分点:")
        print(json.dumps(breaks_list, ensure_ascii=False, indent=2, cls=NumpyEncoder))
        print("选中变量不同分箱数下变量的推荐切分点:")
        detail_print(list(bin_info_filtered.keys()))
        # Print how each filter stage decided.
        filter_print(filter_fast, "快速筛选过程", "剔除train_iv小于阈值")
        filter_print(filter_numeric, "数值变量筛选过程")
        filter_print(filter_corr, "相关性筛选过程")
        filter_print(filter_vif, "vif筛选过程")
        filter_print(filter_ivtop, "ivtop筛选过程", "iv = train_iv + test_iv")
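
# Minimal end-to-end sketch, assuming a project config object and a DataSplitEntity
# built elsewhere (the names below are hypothetical, not part of this module):
#   strategy = StrategyWoe(ml_config)
#   strategy.feature_search(data_split)           # fit binning and run all filters
#   report = strategy.feature_report(data_split)  # tables and trend plots
#   train_woe = strategy.feature_generate(data_split.train_data)
#   strategy.feature_save()                       # persist feature.csv for reuse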