ml_pid_cbm.tools.plotting_tools
Module with plotting tools
"""
Module with plotting tools.

Plotting helpers for the ML PID analysis: TOF plots, variable distributions,
correlations, optuna optimisation plots, model output, ROC, confusion matrix,
mass^2 histograms, pT-rapidity maps, SHAP plots and efficiency/purity curves.
"""
import gc
import itertools
from typing import List, Tuple

import fasttreeshap as shap
import matplotlib as mplt
import matplotlib.colors
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from hipe4ml.model_handler import ModelHandler
from hipe4ml.plot_utils import (plot_corr, plot_distr, plot_output_train_test,
                                plot_roc)
from hipe4ml.tree_handler import TreeHandler
from matplotlib import rcParams
from optuna.study import Study
from optuna.visualization import plot_contour, plot_optimization_history
from sklearn.utils import resample

from . import json_tools

PARAMS = {
    "axes.titlesize": "22",
    "axes.labelsize": "22",
    "xtick.labelsize": "22",
    "ytick.labelsize": "22",
    "figure.figsize": "10, 7",
    "figure.dpi": "300",
    "legend.fontsize": "20",
}
rcParams.update(PARAMS)


def tof_plot(
    df: pd.DataFrame,
    json_file_name: str,
    particles_title: str,
    file_name: str = "tof_plot",
    x_axis_range: List[float] = [-13, 13],
    y_axis_range: List[float] = [-1, 2],
    save_fig: bool = True,
) -> plt.Figure:
    """
    Method for creating tof plots.

    Draws a 2D histogram (200 bins per axis, logarithmic colour scale) of the
    mass^2 variable versus charge * momentum.

    Args:
        df (pd.DataFrame): Dataframe with particles to plot.
        json_file_name (str): Name of the config.json file from which the charge,
            momentum and mass2 variable names are loaded.
        particles_title (str): Name of the particle type, used in the title and
            in the saved file names.
        file_name (str, optional): Prefix of the saved files. Defaults to "tof_plot".
            The particles_title (spaces replaced by underscores) is appended after
            "<file_name>_" when saved, e.g. "tof_plot_protons.png".
        x_axis_range (List[float], optional): X-axis range. Defaults to [-13, 13].
        y_axis_range (List[float], optional): Y-axis range. Defaults to [-1, 2].
        save_fig (bool, optional): Whether the figure should be saved as PNG and PDF
            (and closed) instead of shown. Defaults to True.

    Returns:
        matplotlib.figure.Figure: The created figure (already closed when saved).
    """
    # load variable names
    charge_var_name = json_tools.load_var_name(json_file_name, "charge")
    momentum_var_name = json_tools.load_var_name(json_file_name, "momentum")
    mass2_var_name = json_tools.load_var_name(json_file_name, "mass2")
    # prepare plot variables
    ranges = [x_axis_range, y_axis_range]
    # product of the charge and momentum columns; the axis label presents it as
    # sign(q)*p, which presumes |charge| == 1 -- TODO confirm for multi-charged particles
    qp = df[charge_var_name] * df[momentum_var_name]
    mass2 = df[mass2_var_name]
    x_axis_name = r"sign($q$) $\cdot p$ (GeV/c)"
    y_axis_name = r"$m^2$ $(GeV/c^2)^2$"
    # plot graph
    fig, _ = plt.subplots(figsize=(15, 10), dpi=300)
    plt.hist2d(qp, mass2, bins=200, norm=matplotlib.colors.LogNorm(), range=ranges)
    plt.xlabel(x_axis_name, fontsize=20, loc="right")
    plt.ylabel(y_axis_name, fontsize=20, loc="top")
    plt.title(f"TOF 2D plot for {particles_title}", fontsize=20)
    fig.tight_layout()
    plt.colorbar()
    if save_fig:
        # file_name is the prefix; the particle title (without spaces) is the suffix
        output_name = f"{file_name}_{particles_title.replace(' ', '_')}"
        plt.savefig(f"{output_name}.png")
        plt.savefig(f"{output_name}.pdf")
        plt.close()
    else:
        plt.show()
    return fig


def var_distributions_plot(
    vars_to_draw: list,
    data_list: List[TreeHandler],
    leg_labels: List[str] = ["protons", "kaons", "pions"],
    save_fig: bool = True,
    filename: str = "vars_disitributions",
):
    """
    Plots distributions of given variables using plot_distr from hipe4ml.

    Each variable is histogrammed with 100 bins and log=True, one semi-transparent
    histogram per TreeHandler in data_list.

    Args:
        vars_to_draw (list): List of variables to draw.
        data_list (List[TreeHandler]): List of TreeHandlers with data.
        leg_labels (List[str], optional): Names of the particles given in data_list,
            in the same order. Defaults to ["protons", "kaons", "pions"].
        save_fig (bool, optional): Whether to save the plot (PNG and PDF) instead of
            showing it. Defaults to True.
        filename (str, optional): Output file name without extension.
            Defaults to "vars_disitributions" (historical spelling, kept so that
            existing output names do not change).
    """
    plot_distr(
        data_list,
        vars_to_draw,
        bins=100,
        labels=leg_labels,
        log=True,
        figsize=(40, 40),
        alpha=0.3,
        grid=False,
    )
    if save_fig:
        plt.savefig(f"{filename}.png")
        plt.savefig(f"{filename}.pdf")
        plt.close()
    else:
        plt.show()


def correlations_plot(
    vars_to_draw: list,
    data_list: List[TreeHandler],
    leg_labels: List[str] = ["protons", "kaons", "pions"],
    save_fig: bool = True,
):
    """
    Creates correlation plots using plot_corr from hipe4ml.

    When plot_corr returns a list of figures they are saved as
    correlations_plot_<i>.png/.pdf, otherwise as correlations_plot.png/.pdf.

    Args:
        vars_to_draw (list): Variables to check correlations.
        data_list (List[TreeHandler]): List of TreeHandlers with data.
        leg_labels (List[str], optional): Names of the particles given in data_list.
            Defaults to ["protons", "kaons", "pions"].
        save_fig (bool, optional): Whether to save the plots instead of showing them.
            Defaults to True.
    """
    # NOTE: no plt.subplots_adjust here. Called before plot_corr, it acted on whatever
    # figure happened to be current (creating an empty, never-closed one if none
    # existed) rather than on the correlation figures.
    cor_plots = plot_corr(data_list, vars_to_draw, leg_labels)
    # plot_corr may return either a single figure or a list of figures
    is_list = isinstance(cor_plots, list)
    figures = cor_plots if is_list else [cor_plots]
    for i, plot in enumerate(figures):
        stem = f"correlations_plot_{i}" if is_list else "correlations_plot"
        if save_fig:
            plot.savefig(f"{stem}.png")
            plot.savefig(f"{stem}.pdf")
            plt.close(plot)
        else:
            plot.show()


def opt_history_plot(study: Study, save_fig: bool = True):
    """
    Saves optimization history.

    Builds optuna's optimization-history figure for ``study`` and writes it to
    optimization_history.png and optimization_history.pdf, or shows it.

    Args:
        study (Study): optuna.Study whose history is plotted.
        save_fig (bool, optional): Whether to save the plot instead of showing it.
            Defaults to True.
    """
    fig = plot_optimization_history(study)
    if save_fig:
        # write_image needs the python-kaleido package
        fig.write_image("optimization_history.png")
        fig.write_image("optimization_history.pdf")
    else:
        fig.show()
    # No plt.close(): this figure is not a pyplot figure (it is written with
    # write_image), so closing the current matplotlib figure only hit unrelated plots.


def opt_contour_plot(study: Study, save_fig: bool = True):
    """
    Saves optimization contour plot.

    Builds optuna's contour figure for ``study`` and writes it to
    optimization_contour.png and optimization_contour.pdf, or shows it.

    Args:
        study (Study): optuna.Study to be plotted.
        save_fig (bool, optional): Whether to save the plot instead of showing it.
            Defaults to True.
    """
    fig = plot_contour(study)
    if save_fig:
        # write_image needs the python-kaleido package
        fig.write_image("optimization_contour.png")
        fig.write_image("optimization_contour.pdf")
    else:
        # show the optuna figure itself; plt.show() would not display it
        fig.show()


def output_train_test_plot(
    model_hdl: ModelHandler,
    train_test_data,
    leg_labels: List[str] = ["protons", "kaons", "pions"],
    logscale: bool = False,
    save_fig: bool = True,
):
    """
    Output training plot as in hipe4ml.plot_output_train_test.

    When hipe4ml returns several figures they are saved as
    output_train_test_plot_<idx>.png/.pdf, a single figure as
    output_train_test_plot.png/.pdf.

    Args:
        model_hdl (ModelHandler): Model handler to be tested.
        train_test_data (_type_): List created by PrepareModel.prepare_train_test_data.
        leg_labels (List[str], optional): Names of the classified particles.
            Defaults to ["protons", "kaons", "pions"].
        logscale (bool, optional): Whether to use a log scale. Defaults to False.
        save_fig (bool, optional): Whether to save the plots instead of showing them.
            Defaults to True.
    """
    ml_out_fig = plot_output_train_test(
        model_hdl,
        train_test_data,
        100,
        False,
        leg_labels,
        logscale=logscale,
        density=False,  # if true histograms are normalized
    )
    # Branch on what was returned (a list of figures or a single figure) instead of
    # on len(leg_labels), so a single returned figure never gets iterated.
    is_list = isinstance(ml_out_fig, (list, tuple))
    figures = list(ml_out_fig) if is_list else [ml_out_fig]
    for idx, fig in enumerate(figures):
        stem = f"output_train_test_plot_{idx}" if is_list else "output_train_test_plot"
        if save_fig:
            fig.savefig(f"{stem}.png")
            fig.savefig(f"{stem}.pdf")
            # close every saved figure, not only the current one
            plt.close(fig)
        else:
            fig.show()


def roc_plot(
    test_df: pd.DataFrame,
    test_labels_array: np.ndarray,
    leg_labels: List[str] = ["protons", "kaons", "pions"],
    save_fig: bool = True,
):
    """
    Roc plot of the model.

    Draws hipe4ml's multi-class ROC curves with the one-vs-one ("ovo") option and
    saves them to roc_plot.png / roc_plot.pdf, or shows them.

    Args:
        test_df (pd.DataFrame): Dataframe containing the test dataset with particles.
        test_labels_array (np.ndarray): Ndarray containing the labels of test_df.
        leg_labels (List[str], optional): Names of the classified particles.
            Defaults to ["protons", "kaons", "pions"].
        save_fig (bool, optional): Whether to save the plot instead of showing it.
            Defaults to True.
    """
    # NOTE(review): hipe4ml's plot_roc takes (y_truth, y_score, pos_label, labels, ...);
    # confirm callers pass the true labels first and the predicted scores second.
    plot_roc(test_df, test_labels_array, None, leg_labels, multi_class_opt="ovo")
    if save_fig:
        plt.savefig("roc_plot.png")
        plt.savefig("roc_plot.pdf")
        plt.close()
    else:
        plt.show()


def plot_confusion_matrix(
    cnf_matrix: np.ndarray,
    classes: List[str] = ["proton", "kaon", "pion", "bckgr"],
    normalize: bool = False,
    title: str = "Confusion matrix",
    cmap=mplt.colormaps["Blues"],
    save_fig: bool = True,
):
    """
    Plot created earlier confusion matrix.

    Prints the (optionally row-normalised) matrix, draws it with the value of
    every cell written on top, and saves it as confusion_matrix[ (norm)].png/.pdf.

    Args:
        cnf_matrix (np.ndarray): Confusion matrix (rows: true label, columns: predicted
            label, as the axis labels state).
        classes (List[str], optional): List of the names of the classes.
            Defaults to ["proton", "kaon", "pion", "bckgr"].
        normalize (bool, optional): Whether to normalise each row to sum 1; rows
            without entries are shown as 0. Defaults to False.
        title (str, optional): Title of the plot. Defaults to "Confusion matrix".
        cmap (_type_, optional): Cmap used for colors. Defaults to mplt.colormaps["Blues"].
        save_fig (bool, optional): Whether to save the plot instead of showing it.
            Defaults to True.
    """
    filename = "confusion_matrix"
    cnf_matrix = np.asarray(cnf_matrix)
    if normalize:
        row_sums = cnf_matrix.sum(axis=1, keepdims=True)
        # rows with no entries become zeros instead of NaN (0 / 0)
        cnf_matrix = np.divide(
            cnf_matrix.astype(float),
            row_sums,
            out=np.zeros(cnf_matrix.shape, dtype=float),
            where=row_sums != 0,
        )
        print("Normalized confusion matrix")
        title = title + " (normalized)"
        filename = filename + " (norm)"
    else:
        print("Confusion matrix, without normalization")

    # apply the 2-digit precision to this print only, without changing global options
    with np.printoptions(precision=2):
        print(cnf_matrix)

    _, axs = plt.subplots(figsize=(10, 8), dpi=300)
    axs.yaxis.set_label_coords(-0.04, 0.5)
    axs.xaxis.set_label_coords(0.5, -0.005)
    plt.imshow(cnf_matrix, interpolation="nearest", cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    # integer counts are printed as integers, anything else with two decimals
    is_integer = np.issubdtype(cnf_matrix.dtype, np.integer)
    fmt = "d" if is_integer and not normalize else ".2f"
    thresh = cnf_matrix.max() / 2.0
    for i, j in itertools.product(
        range(cnf_matrix.shape[0]), range(cnf_matrix.shape[1])
    ):
        plt.text(
            j,
            i,
            format(cnf_matrix[i, j], fmt),
            horizontalalignment="center",
            # white text on dark cells, black on light ones
            color="white" if cnf_matrix[i, j] > thresh else "black",
        )

    plt.ylabel("True label", fontsize=15)
    plt.xlabel("Predicted label", fontsize=15)
    # lay out after the labels exist so they are taken into account
    plt.tight_layout()
    if save_fig:
        plt.savefig(f"{filename}.png")
        plt.savefig(f"{filename}.pdf")
        plt.close()
    else:
        plt.show()


def plot_mass2(
    xgb_mass: pd.Series,
    sim_mass: pd.Series,
    particles_title: str,
    range1: Tuple[float, float],
    y_axis_log: bool = False,
    save_fig: bool = True,
):
    """
    Plots mass^2.

    Overlays 300-bin histograms of the model-selected and the MC-true mass^2 and
    saves them as mass2_<particles_title>.png/.pdf.

    Args:
        xgb_mass (pd.Series): pd.Series containing xgb-selected mass^2.
        sim_mass (pd.Series): pd.Series containing MC-true mass^2.
        particles_title (str): Name of the plot.
        range1 (tuple[float, float]): Range of the mass2 to be plotted on x-axis.
        y_axis_log (bool, optional): Whether to use a log scale on the y-axis.
            Defaults to False.
        save_fig (bool, optional): Whether to save the plot instead of showing it.
            Defaults to True.
    """
    _, axs = plt.subplots(figsize=(15, 10), dpi=300)
    axs.hist(
        xgb_mass,
        bins=300,
        facecolor="red",
        alpha=0.3,
        range=range1,
        label="XGBoost selected " + particles_title,
    )
    axs.hist(
        sim_mass,
        bins=300,
        facecolor="blue",
        alpha=0.3,
        range=range1,
        label="all simulated " + particles_title,
    )
    axs.legend(loc="upper right")
    if y_axis_log:
        axs.set_yscale("log")
    axs.set_xlabel(r"$m^2$ $(GeV/c^2)^2$", fontsize=20, loc="right")
    axs.set_ylabel(r"Counts", fontsize=20, loc="top")
    axs.set_title(f"{particles_title} $mass^2$ histogram", fontsize=20)
    axs.grid()
    axs.tick_params(axis="both", which="major", labelsize=18)
    if save_fig:
        plt.savefig(f"mass2_{particles_title}.png")
        plt.savefig(f"mass2_{particles_title}.pdf")
        plt.close()
    else:
        plt.show()


def plot_all_particles_mass2(
    xgb_selected: pd.DataFrame,
    mass2_variable_name: str,
    pid_variable_name: str,
    particles_title: str,
    range1: Tuple[float, float],
    y_axis_log: bool = False,
    save_fig: bool = True,
):
    """
    Plots mc-true particle type in xgb_selected particles.

    Splits the selected rows by their MC-true pid (0, 1, 2 shown as protons,
    kaons, pions) and overlays their mass^2 histograms (300 bins).

    Args:
        xgb_selected (pd.DataFrame): Dataframe with xgb-selected particles.
        mass2_variable_name (str): Name of the mass2 variable.
        pid_variable_name (str): Name of the MC-true pid variable.
        particles_title (str): Name of the plot.
        range1 (tuple[float, float]): Range of the x-axis.
        y_axis_log (bool, optional): Whether to use a log scale on the y-axis.
            Defaults to False.
        save_fig (bool, optional): Whether to save the plot instead of showing it.
            Defaults to True.
    """
    _, axs = plt.subplots(figsize=(15, 10), dpi=300)

    # (MC-true pid, colour, legend entry)
    species = (
        (0, "blue", "XGBoost selected true protons"),
        (1, "orange", "XGBoost selected true kaons"),
        (2, "green", "XGBoost selected true pions"),
    )
    for pid, colour, legend_entry in species:
        selected = xgb_selected[xgb_selected[pid_variable_name] == pid][
            mass2_variable_name
        ]
        axs.hist(
            selected,
            bins=300,
            facecolor=colour,
            alpha=0.4,
            range=range1,
            label=legend_entry,
        )

    axs.legend(loc="upper right")
    if y_axis_log:
        axs.set_yscale("log")
    axs.set_xlabel(r"$m^2$ $(GeV/c^2)^2$", loc="right")
    axs.set_ylabel(r"Counts", fontsize=15, loc="top")
    axs.set_title(
        f"ALL XGBoost selected (true and false positive) {particles_title} $mass^2$ histogram"
    )
    axs.grid()
    axs.tick_params(axis="both", which="major", labelsize=18)
    if save_fig:
        plt.savefig(f"mass2_all_selected_{particles_title}.png")
        plt.savefig(f"mass2_all_selected_{particles_title}.pdf")
        plt.close()
    else:
        plt.show()


def plot_eff_pT_rap(
    df: pd.DataFrame,
    pid: int,
    pid_var_name: str = "Complex_pid",
    rapidity_var_name: str = "Complex_rapidity",
    pT_var_name: str = "Complex_pT",
    ranges: Tuple[Tuple[float, float], Tuple[float, float]] = [[0, 5], [0, 3]],
    nbins: int = 50,
    save_fig: bool = True,
    particle_names: List[str] = ["protons", "kaons", "pions", "bckgr"],
):
    """
    Plots the pT-rapidity efficiency map (selected / simulated) of one particle type.

    "Simulated" are rows with ``df[pid_var_name] == pid``; "selected" are rows with
    ``df["xgb_preds"] == pid`` (true and false positives alike), so a bin may exceed
    1. The colour scale is fixed to [0, 1]; empty/zero bins are drawn white.

    Args:
        df (pd.DataFrame): Dataframe with the MC-true pid, the model predictions in
            column "xgb_preds", rapidity and pT.
        pid (int): Pid of the particle type; also indexes particle_names.
        pid_var_name (str, optional): Name of the pid variable. Defaults to "Complex_pid".
        rapidity_var_name (str, optional): Name of the rapidity variable.
            Defaults to "Complex_rapidity".
        pT_var_name (str, optional): Name of the pT variable. Defaults to "Complex_pT".
        ranges (Tuple[Tuple[float, float], Tuple[float, float]], optional):
            [[rapidity_min, rapidity_max], [pT_min, pT_max]]. Defaults to [[0, 5], [0, 3]].
        nbins (int, optional): Number of bins along each axis. Defaults to 50.
        save_fig (bool, optional): Whether to save the plot instead of showing it.
            Defaults to True.
        particle_names (List[str], optional): Names of the particles indexed by pid.
            Defaults to ["protons", "kaons", "pions", "bckgr"].
    """
    df_true = df[df[pid_var_name] == pid]  # simulated
    df_reco = df[df["xgb_preds"] == pid]  # selected by the model

    # histogram2d(x, y) returns H indexed as H[x_bin, y_bin] plus the x and y edges
    true, rap_edges, pt_edges = np.histogram2d(
        np.asarray(df_true[rapidity_var_name]),
        np.asarray(df_true[pT_var_name]),
        bins=nbins,
        range=ranges,
    )
    reco, _, _ = np.histogram2d(
        np.asarray(df_reco[rapidity_var_name]),
        np.asarray(df_reco[pT_var_name]),
        bins=(rap_edges, pt_edges),
    )

    # efficiency = selected / simulated; bins without simulated entries stay 0
    eff = np.divide(reco, true, out=np.zeros_like(true), where=true != 0)
    eff[eff == 0] = np.nan  # show zeros as white

    fig = plt.figure(figsize=(8, 10), dpi=300)
    plt.title(f"$p_T$-rapidity efficiency for all selected {particle_names[pid]}")
    # transpose so rapidity runs along x and pT along y, matching the axis labels;
    # aspect="auto" lets the map fill the axes
    img = plt.imshow(
        eff.T,
        interpolation="nearest",
        origin="lower",
        aspect="auto",
        vmin=0,
        vmax=1,
        extent=[rap_edges[0], rap_edges[-1], pt_edges[0], pt_edges[-1]],
    )

    cbar = fig.colorbar(img, fraction=0.025, pad=0.08)
    cbar.set_label("efficiency (selected/simulated)", rotation=270, labelpad=20)

    plt.xlabel("rapidity")
    plt.ylabel("$p_T$ (GeV/c)")
    plt.tight_layout()
    if save_fig:
        plt.savefig(f"eff_pT_rap_{particle_names[pid]}.png")
        plt.savefig(f"eff_pT_rap_{particle_names[pid]}.pdf")
        plt.close()
    else:
        plt.show()


def plot_pt_rapidity(
    df: pd.DataFrame,
    pid: int,
    pid_var_name: str = "Complex_pid",
    rapidity_var_name: str = "Complex_rapidity",
    pT_var_name: str = "Complex_pT",
    ranges: Tuple[Tuple[float, float], Tuple[float, float]] = [[0, 5], [0, 3]],
    nbins=50,
    save_fig: bool = True,
    particle_names: List[str] = ["protons", "kaons", "pions", "bckgr"],
):
    """
    Plots pt-rapidity 2D histogram of the simulated particles of one type.

    Args:
        df (pd.DataFrame): Dataframe with input data.
        pid (int): Pid of the particle type to be plotted; also indexes particle_names.
        pid_var_name (str, optional): Name of the pid variable. Defaults to "Complex_pid".
        rapidity_var_name (str, optional): Name of the rapidity variable.
            Defaults to "Complex_rapidity".
        pT_var_name (str, optional): Name of the pT variable. Defaults to "Complex_pT".
        ranges (Tuple[Tuple[float, float], Tuple[float, float]], optional):
            [[rapidity_min, rapidity_max], [pT_min, pT_max]]. Defaults to [[0, 5], [0, 3]].
        nbins (int, optional): Number of bins in each axis. Defaults to 50.
        save_fig (bool, optional): Whether to save the figure instead of showing it.
            Defaults to True.
        particle_names (List[str], optional): Names of the particles corresponding to pid.
            Defaults to ["protons", "kaons", "pions", "bckgr"].
    """
    df_true = df[df[pid_var_name] == pid]  # simulated

    fig = plt.figure(figsize=(8, 10), dpi=300)
    plt.title(f"$p_T$-rapidity graph for all simulated {particle_names[pid]}")

    # histogram2d(x, y) returns H indexed as H[x_bin, y_bin] plus the x and y edges
    counts, rap_edges, pt_edges = np.histogram2d(
        np.asarray(df_true[rapidity_var_name]),
        np.asarray(df_true[pT_var_name]),
        bins=nbins,
        range=ranges,
    )
    counts[counts == 0] = np.nan  # show zeros as white

    # transpose so rapidity runs along x and pT along y, matching the axis labels
    img = plt.imshow(
        counts.T,
        interpolation="nearest",
        origin="lower",
        aspect="auto",
        extent=[rap_edges[0], rap_edges[-1], pt_edges[0], pt_edges[-1]],
    )

    cbar = fig.colorbar(img, fraction=0.025, pad=0.08)
    cbar.set_label("counts", rotation=270, labelpad=20)

    plt.xlabel("rapidity")
    plt.ylabel("$p_T$ (GeV/c)")
    plt.tight_layout()
    if save_fig:
        plt.savefig(f"plot_pt_rapidity_{particle_names[pid]}.png")
        plt.savefig(f"plot_pt_rapidity_{particle_names[pid]}.pdf")
        plt.close()
    else:
        plt.show()


def _style_shap_figure(fig: plt.Figure, ax, xlabel: str) -> None:
    """
    Apply the common size, colour-bar, label and tick styling of the SHAP plots.

    Args:
        fig (plt.Figure): Figure holding the SHAP plot.
        ax: Main axes of the SHAP plot.
        xlabel (str): Label of the x-axis.
    """
    width, _ = fig.get_size_inches()
    fig.set_size_inches(width, width * 3 / 4)
    if len(fig.axes) > 1:
        # the last axes is presumably shap's colour bar; make it a thin, tall bar
        # (skipped when there is no extra axes, which would squash the main plot)
        colorbar_ax = fig.axes[-1]
        colorbar_ax.set_aspect("auto")
        colorbar_ax.set_box_aspect(50)
    ax.set_xlabel(xlabel, fontsize=18)
    for side in ("top", "right", "bottom", "left"):
        ax.spines[side].set_visible(True)
    ax.tick_params(
        axis="both",
        which="major",
        length=10,
        direction="in",
        labelsize=15,
        zorder=4,
    )
    ax.minorticks_on()
    ax.tick_params(
        axis="both", which="minor", length=5, direction="in", labelsize=15, zorder=5
    )
    fig.tight_layout()


def _shap_summary(
    shap_values,
    x_train_resampled: pd.DataFrame,
    features_names: List[str],
    particle_name: str,
    save_fig: bool = True,
):
    """
    Internal method for plotting summary shap plots.

    Args:
        shap_values (_type_): Shap values of one class.
        x_train_resampled (pd.DataFrame): Dataframe with X training variables.
        features_names (List[str]): List of the training variables.
        particle_name (str): Name of the particle.
        save_fig (bool, optional): Whether to save the plot instead of showing it.
            Defaults to True.
    """
    fig, ax = plt.subplots(figsize=(8, 6), dpi=300)
    # summary_plot draws into the current pyplot axes (ax)
    shap.summary_plot(
        shap_values,
        x_train_resampled,
        feature_names=features_names,
        show=False,
    )
    _style_shap_figure(fig, ax, f"SHAP values for {particle_name}")
    if save_fig:
        fig.savefig(f"shap_summary_{particle_name}.png")
        fig.savefig(f"shap_summary_{particle_name}.pdf")
        plt.close(fig)
    else:
        plt.show()


def _shap_interaction(
    shap_values,
    x_train_resampled,
    features_names,
    particle_name: str,
    save_fig: bool = True,
):
    """
    Internal method for plotting shap interaction (dependence) plots, one per feature.

    Args:
        shap_values (_type_): Shap values of one class.
        x_train_resampled (pd.DataFrame): Dataframe with X training variables.
        features_names (List[str]): List of the training variables.
        particle_name (str): Name of the particle.
        save_fig (bool, optional): Whether to save the plots instead of showing them.
            Defaults to True.
    """
    for feature in features_names:
        fig, ax = plt.subplots(figsize=(8, 6), dpi=300)
        # pass ax explicitly: without it dependence_plot opens its own figure, the one
        # created above leaks and the styling below lands on the wrong axes
        shap.dependence_plot(
            feature,
            shap_values,
            x_train_resampled,
            display_features=x_train_resampled,
            ax=ax,
            show=False,
        )
        _style_shap_figure(fig, ax, f"{feature} for {particle_name}")
        if save_fig:
            fig.savefig(f"shap_{feature}_{particle_name}.png")
            fig.savefig(f"shap_{feature}_{particle_name}.pdf")
            plt.close(fig)
        else:
            plt.show()


def _split_shap_values_per_class(shap_values) -> list:
    """
    Return SHAP values as a list with one (n_samples, n_features) array per class.

    Accepts a list of per-class arrays, a 3D array laid out as
    (n_samples, n_features, n_classes), or a single 2D array (a single-output model),
    which becomes a one-element list.
    """
    if isinstance(shap_values, list):
        return shap_values
    values = np.asarray(shap_values)
    if values.ndim == 3:
        return [values[:, :, k] for k in range(values.shape[2])]
    return [values]


def plot_shap_summary(
    x_train: pd.DataFrame,
    y_train: pd.DataFrame,
    model_hdl: ModelHandler,
    features_names: List[str],
    n_workers: int = 1,
    save_fig: bool = True,
    approximate: bool = False,
    n_samples: int = 50000,
    particle_names: List[str] = ["protons", "kaons", "pions"],
):
    """
    Method for plotting shap plots.

    Down-samples the training set to at most n_samples rows per class, computes
    SHAP values with fasttreeshap's TreeExplainer and draws a summary plot plus one
    dependence plot per feature for every class.

    Args:
        x_train (pd.DataFrame): pd.Dataframe with X training dataset.
        y_train (pd.DataFrame): Training labels, one per row of x_train (same order).
        model_hdl (ModelHandler): Model Handler to be explained.
        features_names (List[str]): List of the training variables.
        n_workers (int, optional): Number of threads for multithreading.
            Note: it uses fasttreeshap library, not shap. Defaults to 1.
        save_fig (bool, optional): Whether to save the plots instead of showing them.
            Defaults to True.
        approximate (bool, optional): Whether to compute approximate SHAP values.
            Defaults to False.
        n_samples (int, optional): Maximal number of samples in each class.
            Defaults to 50000.
        particle_names (List[str], optional): List of the classified particle names,
            indexed by class. Defaults to ["protons", "kaons", "pions"].
    """
    print("Creating shap plots...")
    explainer = shap.TreeExplainer(
        model_hdl.get_original_model(), n_jobs=n_workers, approximate=approximate
    )
    # Attach the labels positionally. A concat(axis=1) with a freshly indexed label
    # frame aligns on the index and misassigns labels whenever x_train does not
    # carry a default RangeIndex (e.g. after a shuffled train/test split).
    labels = np.asarray(y_train).ravel()
    merged_df = x_train.assign(true_class=labels)
    # keep at most n_samples rows of each class
    resampled_df = pd.concat(
        [
            resample(group, n_samples=min(n_samples, len(group)), replace=False)
            for _, group in merged_df.groupby("true_class")
        ]
    )

    # Split the resampled pd.DataFrame back into input data and label data
    x_train_resampled = resampled_df.iloc[:, :-1]
    y_train_resampled = resampled_df.iloc[:, -1].to_numpy()
    del merged_df, resampled_df
    gc.collect()

    shap_values = explainer.shap_values(
        x_train_resampled, y_train_resampled, check_additivity=False
    )
    # a bare 2D array would otherwise be iterated row by row (per sample)
    for i, class_values in enumerate(_split_shap_values_per_class(shap_values)):
        _shap_summary(
            class_values,
            x_train_resampled,
            features_names,
            particle_names[i],
            save_fig=save_fig,
        )
        _shap_interaction(
            class_values,
            x_train_resampled,
            features_names,
            particle_names[i],
            save_fig=save_fig,
        )


def plot_efficiency_purity(
    probas: np.ndarray,
    efficiencies: List[List[float]],
    purities: List[List[float]],
    save_fig: bool = True,
    particle_names: List[str] = ["protons", "kaons", "pions"],
):
    """
    Plots efficiency and purity as a function of the probability cut, one figure per class.

    Args:
        probas (np.ndarray): Probability cuts.
        efficiencies (List[List[float]]): Efficiencies for each class, one list per class.
        purities (List[List[float]]): Purities for each class, one list per class.
        save_fig (bool, optional): Whether to save the figures (dpi 300) instead of
            showing them (dpi 100). Defaults to True.
        particle_names (List[str], optional): List of the particle names.
            Defaults to ["protons", "kaons", "pions"].
    """
    dpi = 300 if save_fig else 100
    for i, (eff, pur) in enumerate(zip(efficiencies, purities)):
        fig, ax = plt.subplots(figsize=(10, 7), dpi=dpi)
        ax.plot(probas, eff, label="efficiency")
        ax.plot(probas, pur, label="purity")
        ax.set_xlabel("BDT cut")
        ax.set_ylabel(r"\% ")  # raw string: same text, no invalid-escape warning
        ax.legend(loc="upper right")
        ax.set_title(
            f"Efficiency and purity in function of BDT cut for {particle_names[i]}"
        )
        ax.grid(which="major", linestyle="-")
        ax.minorticks_on()
        ax.grid(which="minor", linestyle="--")
        if save_fig:
            # NOTE(review): the PNG and PDF stems differ ("__" vs "_id_"); kept as-is
            # so existing output names do not change -- confirm which one is intended.
            fig.savefig(f"efficiency_purity__{particle_names[i]}.png")
            fig.savefig(f"efficiency_purity_id_{particle_names[i]}.pdf")
            plt.close(fig)
        else:
            plt.show()


# deprecated
def plot_before_after_variables(
    df: pd.DataFrame,
    pid: float,
    pid_variable_name: str,
    training_variables: List[str],
    save_fig: bool = True,
    log_yscale: bool = True,
):
    """
    Plots each variable before and after selection.
    Legacy: var_distributions_plot should rather be used.

    For every training variable, overlays step histograms (100 bins) of the rows
    whose MC-true pid equals ``pid`` ("Simulated") and of the rows whose model
    prediction ``df["xgb_preds"]`` equals ``pid`` ("XGB-selected").

    Args:
        df (pd.DataFrame): Dataframe with the variables, the MC-true pid column and
            the model predictions in column "xgb_preds".
        pid (float): Pid of the particle type to compare.
        pid_variable_name (str): Name of the MC-true pid column.
        training_variables (List[str]): Variables to plot, one figure each.
        save_fig (bool, optional): Whether to save the plots instead of showing them.
            Defaults to True.
        log_yscale (bool, optional): Whether to use a logarithmic y-axis.
            Defaults to True.

    Returns:
        None
    """
    df_true = df[df[pid_variable_name] == pid]  # simulated
    df_reco = df[df["xgb_preds"] == pid]  # reconstructed by xgboost

    def variable_plot(
        df_true,
        df_reco,
        variable_name: str,
        log_yscale: bool = True,
        leg1="Simulated",
        leg2="XGB-selected",
        bins=100,
    ):
        """Return a figure overlaying the simulated and selected distributions of one variable."""
        fig, ax = plt.subplots(figsize=(15, 15), dpi=300)
        for frame, colour, alpha in ((df_true, "blue", 0.6), (df_reco, "red", 0.7)):
            ax.hist(
                frame[variable_name],
                bins=bins,
                facecolor=colour,
                alpha=alpha,
                histtype="step",
                fill=False,
                linewidth=2,
            )
        ax.grid()
        ax.set_xlabel(variable_name, fontsize=15, loc="right")
        ax.set_ylim(bottom=1)
        ax.set_ylabel("counts", fontsize=15)
        if log_yscale:
            ax.set_yscale("log")
        ax.legend((leg1, leg2), fontsize=15, loc="upper right")
        ax.set_title(
            f"{variable_name} before and after XGB selection for pid={pid}", fontsize=15
        )
        return fig

    for training_variable in training_variables:
        plot = variable_plot(df_true, df_reco, training_variable, log_yscale)
        if save_fig:
            # NOTE(review): "_pid_" appears only in the PNG name; kept as-is.
            plot.savefig(f"{training_variable}_before_after_pid_{pid}.png")
            plot.savefig(f"{training_variable}_before_after_{pid}.pdf")
            plt.close(plot)
        else:
            plot.show()
def tof_plot(
    df: pd.DataFrame,
    json_file_name: str,
    particles_title: str,
    file_name: str = "tof_plot",
    x_axis_range: List[float] = [-13, 13],
    y_axis_range: List[float] = [-1, 2],
    save_fig: bool = True,
) -> plt.Figure:
    """
    Method for creating tof plots.

    Draws a 2D histogram (200 bins per axis, logarithmic colour scale) of the
    mass^2 variable versus charge * momentum.

    Args:
        df (pd.DataFrame): Dataframe with particles to plot.
        json_file_name (str): Name of the config.json file from which the charge,
            momentum and mass2 variable names are loaded.
        particles_title (str): Name of the particle type, used in the title and
            in the saved file names.
        file_name (str, optional): Prefix of the saved files. Defaults to "tof_plot".
            The particles_title (spaces replaced by underscores) is appended after
            "<file_name>_" when saved, e.g. "tof_plot_protons.png".
        x_axis_range (List[float], optional): X-axis range. Defaults to [-13, 13].
        y_axis_range (List[float], optional): Y-axis range. Defaults to [-1, 2].
        save_fig (bool, optional): Whether the figure should be saved as PNG and PDF
            (and closed) instead of shown. Defaults to True.

    Returns:
        matplotlib.figure.Figure: The created figure (already closed when saved).
    """
    # load variable names
    charge_var_name = json_tools.load_var_name(json_file_name, "charge")
    momentum_var_name = json_tools.load_var_name(json_file_name, "momentum")
    mass2_var_name = json_tools.load_var_name(json_file_name, "mass2")
    # prepare plot variables
    ranges = [x_axis_range, y_axis_range]
    # product of the charge and momentum columns; the axis label presents it as
    # sign(q)*p, which presumes |charge| == 1 -- TODO confirm for multi-charged particles
    qp = df[charge_var_name] * df[momentum_var_name]
    mass2 = df[mass2_var_name]
    x_axis_name = r"sign($q$) $\cdot p$ (GeV/c)"
    y_axis_name = r"$m^2$ $(GeV/c^2)^2$"
    # plot graph
    fig, _ = plt.subplots(figsize=(15, 10), dpi=300)
    plt.hist2d(qp, mass2, bins=200, norm=matplotlib.colors.LogNorm(), range=ranges)
    plt.xlabel(x_axis_name, fontsize=20, loc="right")
    plt.ylabel(y_axis_name, fontsize=20, loc="top")
    plt.title(f"TOF 2D plot for {particles_title}", fontsize=20)
    fig.tight_layout()
    plt.colorbar()
    if save_fig:
        # file_name is the prefix; the particle title (without spaces) is the suffix
        output_name = f"{file_name}_{particles_title.replace(' ', '_')}"
        plt.savefig(f"{output_name}.png")
        plt.savefig(f"{output_name}.pdf")
        plt.close()
    else:
        plt.show()
    return fig
Method for creating tof plots.
Args: df (pd.DataFrame): Dataframe with particles to plot. json_file_name (str): Name of the config.json file. particles_title (str): Name of the particle type. file_name (str, optional): Filename to be created. Defaults to "tof_plot". The particles_title is appended after "tof_plot_" when the figure is saved. x_axis_range (List[int], optional): X-axis range. Defaults to [-13, 13]. y_axis_range (List[float], optional): Y-axis range. Defaults to [-1, 2]. save_fig (bool, optional): Whether the figure should be saved. Defaults to True.
Returns: matplotlib.figure.Figure: The created figure.
def var_distributions_plot(
    vars_to_draw: list,
    data_list: List[TreeHandler],
    leg_labels: List[str] = ["protons", "kaons", "pions"],
    save_fig: bool = True,
    filename: str = "vars_disitributions",
):
    """
    Draw the distributions of the requested variables with hipe4ml's plot_distr.

    Every variable is histogrammed with 100 bins and log=True, with one
    semi-transparent (alpha 0.3) histogram per TreeHandler in data_list.

    Args:
        vars_to_draw (list): Names of the variables to draw.
        data_list (List[TreeHandler]): TreeHandlers holding the data, one per species.
        leg_labels (List[str], optional): Legend labels in the same order as data_list.
            Defaults to ["protons", "kaons", "pions"].
        save_fig (bool, optional): When True, save the current figure as PNG and PDF
            and close it; otherwise show it. Defaults to True.
        filename (str, optional): Output file name without extension.
            Defaults to "vars_disitributions".
    """
    distr_options = dict(
        bins=100,
        labels=leg_labels,
        log=True,
        figsize=(40, 40),
        alpha=0.3,
        grid=False,
    )
    plot_distr(data_list, vars_to_draw, **distr_options)

    if not save_fig:
        plt.show()
        return

    plt.savefig(f"{filename}.png")
    plt.savefig(f"{filename}.pdf")
    plt.close()
Plots distributions of given variables using plot_distr from hipe4ml.
Args: vars_to_draw (list): List of variables to draw. data_list (List[TreeHandler]): List of TreeHandlers with data. leg_labels (List[str], optional): Names of the particles in the list of TreeHandlers, in the same order. Defaults to ["protons", "kaons", "pions"]. save_fig (bool, optional): Whether to save the plot. Defaults to True. filename (str, optional): Name of the saved plot file, without extension. Defaults to "vars_disitributions".
def correlations_plot(
    vars_to_draw: list,
    data_list: List[TreeHandler],
    leg_labels: List[str] = ["protons", "kaons", "pions"],
    save_fig: bool = True,
):
    """
    Creates correlation plots using plot_corr from hipe4ml.

    When plot_corr returns a list of figures they are saved as
    correlations_plot_<i>.png/.pdf, otherwise as correlations_plot.png/.pdf.

    Args:
        vars_to_draw (list): Variables to check correlations.
        data_list (List[TreeHandler]): List of TreeHandlers with data.
        leg_labels (List[str], optional): Names of the particles given in data_list.
            Defaults to ["protons", "kaons", "pions"].
        save_fig (bool, optional): Whether to save the plots instead of showing them.
            Defaults to True.
    """
    # NOTE: no plt.subplots_adjust here. Called before plot_corr, it acted on whatever
    # figure happened to be current (creating an empty, never-closed one if none
    # existed) rather than on the correlation figures.
    cor_plots = plot_corr(data_list, vars_to_draw, leg_labels)
    # plot_corr may return either a single figure or a list of figures
    is_list = isinstance(cor_plots, list)
    figures = cor_plots if is_list else [cor_plots]
    for i, plot in enumerate(figures):
        stem = f"correlations_plot_{i}" if is_list else "correlations_plot"
        if save_fig:
            plot.savefig(f"{stem}.png")
            plot.savefig(f"{stem}.pdf")
            plt.close(plot)
        else:
            plot.show()
Creates correlation plots
Args: vars_to_draw (list): Variables to check correlations. data_list (List[TreeHandler]): List of TreeHandlers with data. leg_labels (List[str], optional): Names of the particles in the list of TreeHandlers, in the same order. Defaults to ["protons", "kaons", "pions"]. save_fig (bool, optional): Whether to save the plot. Defaults to True.
def opt_history_plot(study: Study, save_fig: bool = True):
    """
    Plot the optimization history of an optuna study.

    Args:
        study (Study): optuna study whose trials are plotted.
        save_fig (bool, optional): If True, write ``optimization_history.png``
            and ``optimization_history.pdf``; otherwise show the figure.
            Defaults to True.
    """
    history_fig = plot_optimization_history(study)
    if not save_fig:
        history_fig.show()
    else:
        # Static image export of this plotly figure needs python-kaleido.
        for extension in ("png", "pdf"):
            history_fig.write_image(f"optimization_history.{extension}")
    plt.close()
Saves optimization history.
Args: study (Study): optuna.Study whose optimization history is plotted. save_fig (bool, optional): Whether to save the plot. Defaults to True.
def opt_contour_plot(study: Study, save_fig: bool = True):
    """
    Plot the optimization contour of an optuna study.

    Args:
        study (Study): optuna study whose trials are plotted.
        save_fig (bool, optional): If True, write ``optimization_contour.png``
            and ``optimization_contour.pdf``; otherwise show the figure.
            Defaults to True.
    """
    fig = plot_contour(study)
    if save_fig:
        fig.write_image("optimization_contour.png")
        fig.write_image("optimization_contour.pdf")
        plt.close()
    else:
        # plot_contour returns a plotly figure (it has write_image); plt.show()
        # only displays matplotlib figures, so show the plotly figure itself,
        # as opt_history_plot does.
        fig.show()
Saves optimization contour plot
Args: study (Study): optuna.Study whose optimization contour is plotted. save_fig (bool, optional): Whether to save the plot. Defaults to True.
def output_train_test_plot(
    model_hdl: ModelHandler,
    train_test_data,
    leg_labels: List[str] = ["protons", "kaons", "pions"],
    logscale: bool = False,
    save_fig: bool = True,
):
    """
    Plot the model output on the training and test sets, as in
    hipe4ml's ``plot_output_train_test``.

    Args:
        model_hdl (ModelHandler): Model handler to be tested.
        train_test_data (list): List created by PrepareModel.prepare_train_test_data.
        leg_labels (List[str], optional): Names of the classified particles.
            Defaults to ["protons", "kaons", "pions"].
        logscale (bool, optional): Whether to use a log scale. Defaults to False.
        save_fig (bool, optional): If True, save each figure as
            ``output_train_test_plot[_<idx>].png/.pdf`` and close it;
            otherwise show the figures. Defaults to True.
    """
    ml_out_fig = plot_output_train_test(
        model_hdl,
        train_test_data,
        100,  # presumably bins=100, raw=False in hipe4ml's signature — confirm
        False,
        leg_labels,
        logscale=logscale,
        density=False,  # if true histograms are normalized
    )
    # hipe4ml returns either a single Figure or a list of Figures. Decide on
    # the returned type rather than on len(leg_labels) > 1: a binary model
    # with two legend labels returns one Figure, which the old check tried to
    # iterate. File names keep the index suffix only for the list case.
    if isinstance(ml_out_fig, list):
        named_figs = [
            (f"output_train_test_plot_{idx}", fig) for idx, fig in enumerate(ml_out_fig)
        ]
    else:
        named_figs = [("output_train_test_plot", ml_out_fig)]

    if not save_fig:
        for _, fig in named_figs:
            fig.show()
        plt.close()
        return

    for stem, fig in named_figs:
        fig.savefig(f"{stem}.png")
        fig.savefig(f"{stem}.pdf")
        # Close every saved figure, not only the current (last) one.
        plt.close(fig)
Output training plot, as in hipe4ml.plot_output_train_test.
Args: model_hdl (ModelHandler): Model handler to be tested. train_test_data (list): List created by PrepareModel.prepare_train_test_data. leg_labels (List[str], optional): Names of the classified particles. Defaults to ["protons", "kaons", "pions"]. logscale (bool, optional): Whether to use a log scale. Defaults to False. save_fig (bool, optional): Whether to save the plots. Defaults to True.
def roc_plot(
    test_df: pd.DataFrame,
    test_labels_array: np.ndarray,
    leg_labels: List[str] = ["protons", "kaons", "pions"],
    save_fig: bool = True,
):
    """
    Draw the ROC curves of the model with hipe4ml's ``plot_roc``
    (one-vs-one multi-class option).

    Args:
        test_df (pd.DataFrame): Dataframe containing the test dataset with particles.
        test_labels_array (np.ndarray): Labels of ``test_df``.
        leg_labels (List[str], optional): Names of the classified particles.
            Defaults to ["protons", "kaons", "pions"].
        save_fig (bool, optional): If True, write ``roc_plot.png`` and
            ``roc_plot.pdf`` and close the figure; otherwise show it.
            Defaults to True.
    """
    # NOTE(review): arguments are forwarded positionally to plot_roc;
    # confirm their order matches hipe4ml's (y_truth, y_score, ...) signature.
    plot_roc(test_df, test_labels_array, None, leg_labels, multi_class_opt="ovo")
    if not save_fig:
        plt.show()
        return
    for extension in ("png", "pdf"):
        plt.savefig(f"roc_plot.{extension}")
    plt.close()
ROC plot of the model.
Args: test_df (pd.DataFrame): Dataframe containing the test dataset with particles. test_labels_array (np.ndarray): Ndarray containing labels of the test_df. leg_labels (List[str], optional): Names of the classified particles. Defaults to ["protons", "kaons", "pions"]. save_fig (bool, optional): Whether to save the plot. Defaults to True.
def plot_confusion_matrix(
    cnf_matrix: np.ndarray,
    classes: List[str] = ["proton", "kaon", "pion", "bckgr"],
    normalize: bool = False,
    title: str = "Confusion matrix",
    cmap=mplt.colormaps["Blues"],
    save_fig: bool = True,
):
    """
    Plot a previously computed confusion matrix.

    The matrix is also printed to stdout (with 2 decimals).

    Args:
        cnf_matrix (np.ndarray): Confusion matrix (rows: true label,
            columns: predicted label, as the axis labels state).
        classes (List[str], optional): List of the names of the classes.
            Defaults to ["proton", "kaon", "pion", "bckgr"].
        normalize (bool, optional): If True, normalize each row to sum to 1;
            rows without entries stay 0. Defaults to False.
        title (str, optional): Title of the plot. Defaults to "Confusion matrix".
        cmap (optional): Colormap used for the cells. Defaults to mplt.colormaps["Blues"].
        save_fig (bool, optional): If True, write ``confusion_matrix[ (norm)].png/.pdf``
            and close the figure; otherwise show it. Defaults to True.
    """
    filename = "confusion_matrix"
    if normalize:
        # Row-normalize; guard empty rows instead of producing NaN and a
        # divide-by-zero RuntimeWarning.
        row_sums = cnf_matrix.sum(axis=1, keepdims=True)
        cnf_matrix = np.divide(
            cnf_matrix.astype("float"),
            row_sums,
            out=np.zeros(cnf_matrix.shape, dtype=float),
            where=row_sums != 0,
        )
        print("Normalized confusion matrix")
        title = title + " (normalized)"
        filename = filename + " (norm)"
    else:
        print("Confusion matrix, without normalization")

    # Scope the 2-decimal precision to this print. The old np.set_printoptions
    # call came after the print (so it had no effect here) and changed numpy's
    # global print settings for the rest of the process.
    with np.printoptions(precision=2):
        print(cnf_matrix)

    _, axs = plt.subplots(figsize=(10, 8), dpi=300)
    axs.yaxis.set_label_coords(-0.04, 0.5)
    axs.xaxis.set_label_coords(0.5, -0.005)
    plt.imshow(cnf_matrix, interpolation="nearest", cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    # Annotate every cell; white text on dark cells (above half the maximum).
    fmt = ".2f" if normalize else "d"
    thresh = cnf_matrix.max() / 2.0
    for i, j in itertools.product(
        range(cnf_matrix.shape[0]), range(cnf_matrix.shape[1])
    ):
        plt.text(
            j,
            i,
            format(cnf_matrix[i, j], fmt),
            horizontalalignment="center",
            color="white" if cnf_matrix[i, j] > thresh else "black",
        )

    plt.tight_layout()
    plt.ylabel("True label", fontsize=15)
    plt.xlabel("Predicted label", fontsize=15)
    if save_fig:
        plt.savefig(f"{filename}.png")
        plt.savefig(f"{filename}.pdf")
        plt.close()
    else:
        plt.show()
Plots a previously computed confusion matrix.
Args: cnf_matrix (np.ndarray): Confusion matrix. classes (List[str], optional): List of the names of the classes. Defaults to ["proton", "kaon", "pion", "bckgr"]. normalize (bool, optional): Whether to normalize the plot. Defaults to False. title (str, optional): Title of the plot. Defaults to "Confusion matrix". cmap (optional): Colormap used for the colors. Defaults to mplt.colormaps["Blues"]. save_fig (bool, optional): Whether to save the plot. Defaults to True.
def plot_mass2(
    xgb_mass: pd.Series,
    sim_mass: pd.Series,
    particles_title: str,
    range1: Tuple[float, float],
    y_axis_log: bool = False,
    save_fig: bool = True,
):
    """
    Overlay the mass^2 histograms of XGBoost-selected and MC-true particles.

    Args:
        xgb_mass (pd.Series): mass^2 of the XGBoost-selected particles (red).
        sim_mass (pd.Series): MC-true mass^2 (blue).
        particles_title (str): Particle name used in legend, title and file name.
        range1 (Tuple[float, float]): x-axis (mass^2) range of both histograms.
        y_axis_log (bool, optional): Use a logarithmic y-axis. Defaults to False.
        save_fig (bool, optional): If True, write ``mass2_<particles_title>.png/.pdf``
            and close the figure; otherwise show it. Defaults to True.
    """
    _, axs = plt.subplots(figsize=(15, 10), dpi=300)
    # Same 300-bin binning and range for both samples.
    for masses, colour in ((xgb_mass, "red"), (sim_mass, "blue")):
        axs.hist(masses, bins=300, facecolor=colour, alpha=0.3, range=range1)
    axs.set_ylabel("counts", fontsize=15)
    axs.legend(
        ("XGBoost selected " + particles_title, "all simulated " + particles_title),
        loc="upper right",
    )
    if y_axis_log:
        axs.set_yscale("log")
    plt.xlabel(r"$m^2$ $(GeV/c^2)^2$", fontsize=20, loc="right")
    plt.ylabel("Counts", fontsize=20, loc="top")
    axs.set_title(f"{particles_title} $mass^2$ histogram", fontsize=20)
    axs.grid()
    axs.tick_params(axis="both", which="major", labelsize=18)
    if not save_fig:
        plt.show()
        return
    for extension in ("png", "pdf"):
        plt.savefig(f"mass2_{particles_title}.{extension}")
    plt.close()
Plots mass^2
Args: xgb_mass (pd.Series): pd.Series containing the XGBoost-selected mass^2. sim_mass (pd.Series): pd.Series containing the MC-true mass^2. particles_title (str): Name of the plot. range1 (tuple[float, float]): Range of the mass2 to be plotted on the x-axis. y_axis_log (bool, optional): Whether to use a log scale on the y-axis. Defaults to False. save_fig (bool, optional): Whether to save the plot. Defaults to True.
def plot_all_particles_mass2(
    xgb_selected: pd.Series,
    mass2_variable_name: str,
    pid_variable_name: str,
    particles_title: str,
    range1: Tuple[float, float],
    y_axis_log: bool = False,
    save_fig: bool = True,
):
    """
    Plot mass^2 of the XGBoost-selected particles, split by MC-true type.

    True pid 0, 1 and 2 are drawn as protons (blue), kaons (orange) and
    pions (green), so both true and false positives are visible.

    Args:
        xgb_selected (pd.Series): XGBoost-selected particles. It is indexed by
            column name, so in practice a pd.DataFrame containing
            ``mass2_variable_name`` and ``pid_variable_name`` is expected.
        mass2_variable_name (str): Name of the mass^2 column.
        pid_variable_name (str): Name of the MC-true pid column.
        particles_title (str): Particle name used in the title and file name.
        range1 (Tuple[float, float]): x-axis (mass^2) range.
        y_axis_log (bool, optional): Use a logarithmic y-axis. Defaults to False.
        save_fig (bool, optional): If True, write
            ``mass2_all_selected_<particles_title>.png/.pdf`` and close the
            figure; otherwise show it. Defaults to True.
    """
    _, axs = plt.subplots(figsize=(15, 10), dpi=300)
    for true_pid, colour in ((0, "blue"), (1, "orange"), (2, "green")):
        is_true_type = xgb_selected[pid_variable_name] == true_pid
        axs.hist(
            xgb_selected[is_true_type][mass2_variable_name],
            bins=300,
            facecolor=colour,
            alpha=0.4,
            range=range1,
        )
    axs.set_ylabel("counts", fontsize=15)
    legend_entries = tuple(
        f"XGBoost selected true {species}" for species in ("protons", "kaons", "pions")
    )
    axs.legend(legend_entries, loc="upper right")
    if y_axis_log:
        axs.set_yscale("log")
    plt.xlabel(r"$m^2$ $(GeV/c^2)^2$", loc="right")
    plt.ylabel("Counts", loc="top")
    axs.set_title(
        f"ALL XGBoost selected (true and false positive) {particles_title} $mass^2$ histogram"
    )
    axs.grid()
    axs.tick_params(axis="both", which="major", labelsize=18)
    if not save_fig:
        plt.show()
        return
    for extension in ("png", "pdf"):
        plt.savefig(f"mass2_all_selected_{particles_title}.{extension}")
    plt.close()
Plots the MC-true particle types among the XGBoost-selected particles.
Args: xgb_selected (pd.Series): XGBoost-selected particles (indexed by column name, so a DataFrame is expected). mass2_variable_name (str): Name of the mass2 variable. pid_variable_name (str): Name of the pid variable. particles_title (str): Name of the plot. range1 (tuple[float, float]): Range of the x-axis. y_axis_log (bool, optional): Whether to use a log scale on the y-axis. Defaults to False. save_fig (bool, optional): Whether to save the plot. Defaults to True.
def plot_eff_pT_rap(
    df: pd.DataFrame,
    pid: int,
    pid_var_name: str = "Complex_pid",
    rapidity_var_name: str = "Complex_rapidity",
    pT_var_name: str = "Complex_pT",
    ranges: Tuple[Tuple[float, float], Tuple[float, float]] = [[0, 5], [0, 3]],
    nbins: int = 50,
    save_fig: bool = True,
    particle_names: List[str] = ["protons", "kaons", "pions", "bckgr"],
):
    """
    Plot the pT-rapidity map of the selection efficiency for one particle type.

    In each bin, efficiency = (particles with ``xgb_preds == pid``) /
    (particles whose MC-true pid is ``pid``). The numerator includes false
    positives, so a bin can exceed 1; the colour scale is clipped to [0, 1].
    Bins without simulated particles are undefined and drawn white.

    Args:
        df (pd.DataFrame): Particles; must contain ``pid_var_name``,
            ``rapidity_var_name``, ``pT_var_name`` and ``"xgb_preds"``.
        pid (int): Particle id to plot; also indexes ``particle_names``.
        pid_var_name (str, optional): MC-true pid column. Defaults to "Complex_pid".
        rapidity_var_name (str, optional): Rapidity column. Defaults to "Complex_rapidity".
        pT_var_name (str, optional): pT column. Defaults to "Complex_pT".
        ranges (optional): [[rapidity_min, rapidity_max], [pT_min, pT_max]].
            Defaults to [[0, 5], [0, 3]].
        nbins (int, optional): Number of bins on each axis. Defaults to 50.
        save_fig (bool, optional): If True, write
            ``eff_pT_rap_<particle name>.png/.pdf`` and close the figure;
            otherwise show it. Defaults to True.
        particle_names (List[str], optional): Names indexed by pid.
            Defaults to ["protons", "kaons", "pions", "bckgr"].
    """
    df_true = df[df[pid_var_name] == pid]  # simulated
    df_reco = df[df["xgb_preds"] == pid]  # selected by XGBoost

    # np.histogram2d bins its first argument (rapidity) along axis 0 and its
    # second (pT) along axis 1.
    true, rap_edges, pt_edges = np.histogram2d(
        np.array(df_true[rapidity_var_name]),
        np.array(df_true[pT_var_name]),
        bins=nbins,
        range=ranges,
    )
    reco, _, _ = np.histogram2d(
        np.array(df_reco[rapidity_var_name]),
        np.array(df_reco[pT_var_name]),
        bins=(rap_edges, pt_edges),
    )

    # Efficiency = selected / simulated (the ratio was inverted before);
    # undefined (NaN, drawn white) where nothing was simulated.
    eff = np.full(true.shape, np.nan)
    np.divide(reco, true, out=eff, where=true != 0)

    fig = plt.figure(figsize=(8, 10), dpi=300)
    plt.title(f"$p_T$-rapidity efficiency for all selected {particle_names[pid]}")
    # imshow draws axis 0 vertically, so transpose to put rapidity on the
    # horizontal axis and pT on the vertical one, matching the axis labels.
    # aspect="auto": rapidity and GeV/c are different units.
    img = plt.imshow(
        eff.T,
        interpolation="nearest",
        origin="lower",
        vmin=0,
        vmax=1,
        extent=[rap_edges[0], rap_edges[-1], pt_edges[0], pt_edges[-1]],
        aspect="auto",
    )

    cbar = fig.colorbar(img, fraction=0.025, pad=0.08)
    cbar.set_label("efficiency (selected/simulated)", rotation=270, labelpad=20)

    plt.xlabel("rapidity")
    plt.ylabel("$p_T$ (GeV/c)")
    plt.tight_layout()
    if save_fig:
        plt.savefig(f"eff_pT_rap_{particle_names[pid]}.png")
        plt.savefig(f"eff_pT_rap_{particle_names[pid]}.pdf")
        plt.close()
    else:
        plt.show()
def plot_pt_rapidity(
    df: pd.DataFrame,
    pid: int,
    pid_var_name: str = "Complex_pid",
    rapidity_var_name: str = "Complex_rapidity",
    pT_var_name: str = "Complex_pT",
    ranges: Tuple[Tuple[float, float], Tuple[float, float]] = [[0, 5], [0, 3]],
    nbins=50,
    save_fig: bool = True,
    particle_names: List[str] = ["protons", "kaons", "pions", "bckgr"],
):
    """
    Plots pt-rapidity 2D histogram of the MC-true (simulated) particles of one type.

    Args:
        df (pd.DataFrame): Dataframe with input data.
        pid (int): Particle id to plot; also indexes ``particle_names``.
        pid_var_name (str, optional): Name of the pid variable. Defaults to "Complex_pid".
        rapidity_var_name (str, optional): Name of the rapidity variable. Defaults to "Complex_rapidity".
        pT_var_name (str, optional): Name of the pT variable. Defaults to "Complex_pT".
        ranges (Tuple[Tuple[float, float], Tuple[float, float]], optional):
            [[rapidity_min, rapidity_max], [pT_min, pT_max]]. Defaults to [[0, 5], [0, 3]].
        nbins (int, optional): Number of bins in each axis. Defaults to 50.
        save_fig (bool, optional): Whether to save the figure as
            ``plot_pt_rapidity_<particle name>.png/.pdf``. Defaults to True.
        particle_names (List[str], optional): Names of the particles corresponding to pid.
            Defaults to ["protons", "kaons", "pions", "bckgr"].
    """
    df_true = df[df[pid_var_name] == pid]  # simulated

    fig = plt.figure(figsize=(8, 10), dpi=300)
    plt.title(f"$p_T$-rapidity graph for all simulated {particle_names[pid]}")

    # np.histogram2d bins rapidity along axis 0 and pT along axis 1.
    counts, rap_edges, pt_edges = np.histogram2d(
        np.array(df_true[rapidity_var_name]),
        np.array(df_true[pT_var_name]),
        bins=nbins,
        range=ranges,
    )
    counts[counts == 0] = np.nan  # show zeros as white

    # Transpose so rapidity is horizontal and pT vertical, matching the axis
    # labels (imshow draws axis 0 vertically). aspect="auto": the two axes
    # have different units.
    img = plt.imshow(
        counts.T,
        interpolation="nearest",
        origin="lower",
        extent=[rap_edges[0], rap_edges[-1], pt_edges[0], pt_edges[-1]],
        aspect="auto",
    )

    cbar = fig.colorbar(img, fraction=0.025, pad=0.08)
    cbar.set_label("counts", rotation=270, labelpad=20)

    plt.xlabel("rapidity")
    plt.ylabel("$p_T$ (GeV/c)")
    plt.tight_layout()
    if save_fig:
        plt.savefig(f"plot_pt_rapidity_{particle_names[pid]}.png")
        plt.savefig(f"plot_pt_rapidity_{particle_names[pid]}.pdf")
        plt.close()
    else:
        plt.show()
Plots pt-rapidity 2D histogram.
Args: df (pd.DataFrame): Dataframe with input data. pid (int): PID of the particle species to be plotted. pid_var_name (str, optional): Name of the pid variable. Defaults to "Complex_pid". rapidity_var_name (str, optional): Name of the rapidity variable. Defaults to "Complex_rapidity". pT_var_name (str, optional): Name of the pT variable. Defaults to "Complex_pT". ranges (Tuple[Tuple[float, float], Tuple[float, float]], optional): Ranges of the plot. Defaults to [[0, 5], [0, 3]]. nbins (int, optional): Number of bins in each axis. Defaults to 50. save_fig (bool, optional): Whether to save the figure. Defaults to True. particle_names (List[str], optional): Names of the particles corresponding to pid. Defaults to ["protons", "kaons", "pions", "bckgr"].
def plot_shap_summary(
    x_train: pd.DataFrame,
    y_train: pd.DataFrame,
    model_hdl: ModelHandler,
    features_names: List[str],
    n_workers: int = 1,
    save_fig: bool = True,
    approximate: bool = False,
    n_samples: int = 50000,
    particle_names: List[str] = ["protons", "kaons", "pions"],
):
    """
    Create SHAP summary and interaction plots for every output class.

    Args:
        x_train (pd.DataFrame): X training dataset.
        y_train (pd.DataFrame): Training labels, one per row of ``x_train``
            (flattened to 1-D).
        model_hdl (ModelHandler): Model Handler to be explained.
        features_names (List[str]): List of the training variables.
        n_workers (int, optional): Number of threads. Note: this uses the
            fasttreeshap library, not shap. Defaults to 1.
        save_fig (bool, optional): Whether to save the plots. Defaults to True.
        approximate (bool, optional): Whether to compute approximate SHAP
            values. Defaults to False.
        n_samples (int, optional): Maximum number of samples drawn (without
            replacement) from each class. Defaults to 50000.
        particle_names (List[str], optional): Names of the classified particles,
            indexed by class. Defaults to ["protons", "kaons", "pions"].
    """
    print("Creating shap plots...")
    explainer = shap.TreeExplainer(
        model_hdl.get_original_model(), n_jobs=n_workers, approximate=approximate
    )
    # Draw at most n_samples rows per class, working with row positions.
    # The old approach concatenated x_train with a label frame built on a
    # fresh RangeIndex; pd.concat(axis=1) aligns on the index, so labels were
    # paired with the wrong rows (or NaN) whenever x_train's index was not
    # 0..n-1 (e.g. after a shuffled train/test split).
    labels = np.asarray(y_train).ravel()
    per_class_rows = (np.flatnonzero(labels == cls) for cls in np.unique(labels))
    positions = np.concatenate(
        [
            resample(rows, n_samples=min(n_samples, len(rows)), replace=False)
            for rows in per_class_rows
        ]
    )
    x_train_resampled = x_train.iloc[positions]
    y_train_resampled = labels[positions]

    shap_values = explainer.shap_values(
        x_train_resampled, y_train_resampled, check_additivity=False
    )
    # Assumes shap_values is a per-class list (multi-class tree explainer output).
    for i, class_shap_values in enumerate(shap_values):
        _shap_summary(
            class_shap_values,
            x_train_resampled,
            features_names,
            particle_names[i],
            save_fig=save_fig,
        )
        _shap_interaction(
            class_shap_values,
            x_train_resampled,
            features_names,
            particle_names[i],
            save_fig=save_fig,
        )
Creates SHAP plots.
Args: x_train (pd.DataFrame): pd.DataFrame with the X training dataset. y_train (pd.DataFrame): Labels of the X training dataset. model_hdl (ModelHandler): Model Handler to be explained. features_names (List[str]): List of the training variables. n_workers (int, optional): Number of threads for multithreading. Note: it uses the fasttreeshap library, not shap. Defaults to 1. save_fig (bool, optional): Whether to save the plots. Defaults to True. approximate (bool, optional): Whether to compute approximate SHAP values. Defaults to False. n_samples (int, optional): Maximum number of samples in each class. Defaults to 50000. particle_names (List[str], optional): List of the classified particle names. Defaults to ["protons", "kaons", "pions"].
def plot_efficiency_purity(
    probas: np.ndarray,
    efficiencies: List[List[float]],
    purities: List[List[float]],
    save_fig: bool = True,
    particle_names: List[str] = ["protons", "kaons", "pions"],
):
    """
    Plot efficiency and purity as a function of the probability (BDT) cut,
    one figure per class.

    Args:
        probas (np.ndarray): Probability cuts (x-axis values).
        efficiencies (List[List[float]]): Efficiencies for each class, one list per class.
        purities (List[List[float]]): Purities for each class, one list per class.
        save_fig (bool, optional): If True, write
            ``efficiency_purity__<particle name>.png/.pdf`` (300 dpi) and close
            each figure; otherwise show it (100 dpi). Defaults to True.
        particle_names (List[str], optional): Names of the particles, indexed
            by class. Defaults to ["protons", "kaons", "pions"].
    """
    dpi = 300 if save_fig else 100
    for i, (eff, pur) in enumerate(zip(efficiencies, purities)):
        fig, ax = plt.subplots(figsize=(10, 7), dpi=dpi)
        ax.plot(probas, eff, label="efficiency")
        ax.plot(probas, pur, label="purity")
        ax.set_xlabel("BDT cut")
        # Raw string: same "\% " text as before, without the invalid escape
        # sequence warning that "\%" triggers in a normal string.
        ax.set_ylabel(r"\% ")
        ax.legend(loc="upper right")
        ax.set_title(
            f"Efficiency and purity in function of BDT cut for {particle_names[i]}"
        )
        ax.grid(which="major", linestyle="-")
        ax.minorticks_on()
        ax.grid(which="minor", linestyle="--")
        if save_fig:
            # The PDF used a different stem ("efficiency_purity_id_...") than
            # the PNG; save both under the PNG's name.
            stem = f"efficiency_purity__{particle_names[i]}"
            fig.savefig(f"{stem}.png")
            fig.savefig(f"{stem}.pdf")
            plt.close(fig)
        else:
            plt.show()
Plots efficiency and purity as a function of the probability cut.
Args: probas (np.ndarray): Probability cuts. efficiencies (List[List[float]]): List of lists of efficiencies, one per class. purities (List[List[float]]): List of lists of purities, one per class. save_fig (bool, optional): Whether to save the figure. Defaults to True. particle_names (List[str], optional): List of the particle names. Defaults to ["protons", "kaons", "pions"].
def plot_before_after_variables(
    df: pd.DataFrame,
    pid: float,
    pid_variable_name: str,
    training_variables: List[str],
    save_fig: bool = True,
    log_yscale: bool = True,
):
    """
    Plots each variable before and after selection.
    Legacy: prefer var_distributions_plot.

    For every variable, draws one figure with two step histograms: the
    particles whose MC-true ``pid_variable_name`` equals ``pid``
    ("Simulated") and the particles with ``xgb_preds == pid`` ("XGB-selected").

    Args:
        df (pd.DataFrame): Particles; must contain ``pid_variable_name``,
            ``"xgb_preds"`` and every entry of ``training_variables``.
        pid (float): Particle id to compare.
        pid_variable_name (str): Column with the MC-true particle id.
        training_variables (List[str]): Columns to plot, one figure each.
        save_fig (bool, optional): If True, write
            ``<variable>_before_after_pid_<pid>.png/.pdf`` and close each
            figure; otherwise show it. Defaults to True.
        log_yscale (bool, optional): Use a logarithmic y-axis. Defaults to True.

    Returns:
        None.
    """
    df_true = df[(df[pid_variable_name] == pid)]  # simulated
    df_reco = df[(df["xgb_preds"] == pid)]  # reconstructed by xgboost

    def variable_plot(
        df_true,
        df_reco,
        variable_name: str,
        log_yscale: bool = True,
        leg1="Simulated",
        leg2="XGB-selected",
        bins=100,
    ):
        # Build one before/after comparison figure for a single variable.
        fig, ax = plt.subplots(figsize=(15, 15), dpi=300)
        ax.hist(
            df_true[variable_name],
            bins=bins,
            facecolor="blue",
            alpha=0.6,
            histtype="step",
            fill=False,
            linewidth=2,
        )
        ax.hist(
            df_reco[variable_name],
            bins=bins,
            facecolor="red",
            alpha=0.7,
            histtype="step",
            fill=False,
            linewidth=2,
        )
        ax.grid()
        ax.set_xlabel(variable_name, fontsize=15, loc="right")
        ax.set_ylim(bottom=1)
        ax.set_ylabel("counts", fontsize=15)
        if log_yscale:
            ax.set_yscale("log")
        ax.legend((leg1, leg2), fontsize=15, loc="upper right")
        ax.set_title(
            f"{variable_name} before and after XGB selection for pid={pid}", fontsize=15
        )
        return fig

    for training_variable in training_variables:
        plot = variable_plot(df_true, df_reco, training_variable, log_yscale)
        if save_fig:
            # The PDF previously lacked the "_pid" part of the PNG's name;
            # save both under the same stem.
            stem = f"{training_variable}_before_after_pid_{pid}"
            plot.savefig(f"{stem}.png")
            plot.savefig(f"{stem}.pdf")
            plt.close(plot)
        else:
            plot.show()
Plots each variable before and after selection. Legacy: prefer var_distributions_plot.
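Args: df (pd.DataFrame): Particles; must contain the pid column, the "xgb_preds" column and the training variables. pid (float): Particle id to compare. pid_variable_name (str): Name of the MC-true pid column. training_variables (List[str]): Variables to plot, one figure each. save_fig (bool, optional): Whether to save the plots. Defaults to True. log_yscale (bool, optional): Whether to use a log scale on the y-axis. Defaults to True.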
Returns: None.