ml_pid_cbm.tools.plotting_tools

Module with plotting tools

  1"""
  2Module with plotting tools
  3"""
  4import gc
  5import itertools
  6from typing import List, Tuple
  7
  8import fasttreeshap as shap
  9import matplotlib as mplt
 10import matplotlib.colors
 11import matplotlib.pyplot as plt
 12import numpy as np
 13import pandas as pd
 14from hipe4ml.model_handler import ModelHandler
 15from hipe4ml.plot_utils import (plot_corr, plot_distr, plot_output_train_test,
 16                                plot_roc)
 17from hipe4ml.tree_handler import TreeHandler
 18from matplotlib import rcParams
 19from optuna.study import Study
 20from optuna.visualization import plot_contour, plot_optimization_history
 21from sklearn.utils import resample
 22
 23from . import json_tools
 24
 25PARAMS = {
 26    "axes.titlesize": "22",
 27    "axes.labelsize": "22",
 28    "xtick.labelsize": "22",
 29    "ytick.labelsize": "22",
 30    "figure.figsize": "10, 7",
 31    "figure.dpi": "300",
 32    "legend.fontsize": "20",
 33}
 34rcParams.update(PARAMS)
 35
 36
 37def tof_plot(
 38    df: pd.DataFrame,
 39    json_file_name: str,
 40    particles_title: str,
 41    file_name: str = "tof_plot",
 42    x_axis_range: List[int] = [-13, 13],
 43    y_axis_range: List[str] = [-1, 2],
 44    save_fig: bool = True,
 45) -> None:
 46    """
 47    Method for creating tof plots.
 48
 49    Args:
 50        df (pd.DataFrame): Dataframe with particles to plot
 51        json_file_name (str): Name of the config.json file
 52        particles_title (str): Name of the particle type.
 53        file_name (str, optional): Filename to be created. Defaults to "tof_plot".
 54            Will add the particles_title after the tof_plot_ when saved.
 55        x_axis_range (List[int], optional): X-axis range. Defaults to [-13, 13].
 56        y_axis_range (List[str], optional): Y-axi range. Defaults to [-1, 2].
 57        save_fig (bool, optional): Where the figure should be saved. Defaults to True.
 58
 59    Returns:
 60        None.
 61    """
 62    # load variable names
 63    charge_var_name = json_tools.load_var_name(json_file_name, "charge")
 64    momentum_var_name = json_tools.load_var_name(json_file_name, "momentum")
 65    mass2_var_name = json_tools.load_var_name(json_file_name, "mass2")
 66    # prepare plot variables
 67    ranges = [x_axis_range, y_axis_range]
 68    qp = df[charge_var_name] * df[momentum_var_name]
 69    mass2 = df[mass2_var_name]
 70    x_axis_name = r"sign($q$) $\cdot p$ (GeV/c)"
 71    y_axis_name = r"$m^2$ $(GeV/c^2)^2$"
 72    # plot graph
 73    fig, _ = plt.subplots(figsize=(15, 10), dpi=300)
 74    plt.hist2d(qp, mass2, bins=200, norm=matplotlib.colors.LogNorm(), range=ranges)
 75    plt.xlabel(x_axis_name, fontsize=20, loc="right")
 76    plt.ylabel(y_axis_name, fontsize=20, loc="top")
 77    title = f"TOF 2D plot for {particles_title}"
 78    plt.title(title, fontsize=20)
 79    fig.tight_layout()
 80    plt.colorbar()
 81    title = title.replace(" ", "_")
 82    # savefig
 83    if save_fig:
 84        file_name = particles_title.replace(" ", "_")
 85        plt.savefig(f"tof_plot_{file_name}.png")
 86        plt.savefig(f"tof_plot_{file_name}.pdf")
 87        plt.close()
 88    else:
 89        plt.show()
 90    return fig
 91
 92
 93def var_distributions_plot(
 94    vars_to_draw: list,
 95    data_list: List[TreeHandler],
 96    leg_labels: List[str] = ["protons", "kaons", "pions"],
 97    save_fig: bool = True,
 98    filename: str = "vars_disitributions",
 99):
100    """
101    Plots distributions of given variables using plot_distr from hipe4ml.
102
103    Args:
104        vars_to_draw (list): List of variables to draw.
105        data_list (List[TreeHandler]): List of TreeHandlers with data.
106        leg_labels (List[str], optional): Names of the particles which are given in the list of TreeHandlers.
107            Defaults to ["protons", "kaons", "pions"].
108        save_fig (bool, optional): Whether should save the plot. Defaults to True.
109        filename (str, optional): Name of the plot to be saved. Defaults to "vars_disitributions".
110    """
111    plot_distr(
112        data_list,
113        vars_to_draw,
114        bins=100,
115        labels=leg_labels,
116        log=True,
117        figsize=(40, 40),
118        alpha=0.3,
119        grid=False,
120    )
121    if save_fig:
122        plt.savefig(f"{filename}.png")
123        plt.savefig(f"{filename}.pdf")
124        plt.close()
125    else:
126        plt.show()
127
128
def correlations_plot(
    vars_to_draw: list,
    data_list: List[TreeHandler],
    leg_labels: List[str] = ["protons", "kaons", "pions"],
    save_fig: bool = True,
):
    """
    Creates correlation plots with hipe4ml's plot_corr.

    Args:
        vars_to_draw (list): Variables to check correlations.
        data_list (List[TreeHandler]): List of TreeHandlers with data.
        leg_labels (List[str], optional): Names of the particles which are given in
            the list of TreeHandlers. Defaults to ["protons", "kaons", "pions"].
        save_fig (bool, optional): Whether should save the plots (True) or show
            them (False). Defaults to True.

    Files are saved as ``correlations_plot_{i}.png``/``.pdf`` when plot_corr returns
    a list of figures, and as ``correlations_plot.png``/``.pdf`` for a single figure.
    """
    # Layout is left to plot_corr. Calling plt.subplots_adjust() before plot_corr
    # only touches whatever figure happens to be current at that moment (creating
    # and leaking an empty one if there is none), not the figures returned here.
    cor_plots = plot_corr(data_list, vars_to_draw, leg_labels)

    # plot_corr yields either a single figure or a list of figures.
    is_list = isinstance(cor_plots, list)
    figures = cor_plots if is_list else [cor_plots]
    for idx, figure in enumerate(figures):
        suffix = f"_{idx}" if is_list else ""
        if save_fig:
            figure.savefig(f"correlations_plot{suffix}.png")
            figure.savefig(f"correlations_plot{suffix}.pdf")
            plt.close(figure)
        else:
            figure.show()
164
165
def opt_history_plot(study: Study, save_fig: bool = True):
    """
    Export (or display) the optuna optimization history of a study.

    Args:
        study (Study): optuna study whose history is plotted.
        save_fig (bool, optional): Write ``optimization_history.png``/``.pdf`` when
            True, otherwise show the interactive figure. Defaults to True.
    """
    history_figure = plot_optimization_history(study)
    if not save_fig:
        history_figure.show()
    else:
        # Static image export of the plotly figure needs the python-kaleido package.
        for extension in ("png", "pdf"):
            history_figure.write_image(f"optimization_history.{extension}")
    plt.close()
182
183
def opt_contour_plot(study: Study, save_fig: bool = True):
    """
    Saves (or shows) the optuna optimization contour plot.

    Args:
        study (Study): optuna.Study to be plotted.
        save_fig (bool, optional): Write ``optimization_contour.png``/``.pdf`` when
            True, otherwise show the interactive figure. Defaults to True.
    """
    fig = plot_contour(study)
    if save_fig:
        # Static image export of the plotly figure needs the python-kaleido package.
        fig.write_image("optimization_contour.png")
        fig.write_image("optimization_contour.pdf")
    else:
        # plot_contour returns a plotly figure, so it has to be shown through
        # plotly; plt.show() would only display matplotlib figures.
        fig.show()
199
200
def output_train_test_plot(
    model_hdl: ModelHandler,
    train_test_data,
    leg_labels: List[str] = ["protons", "kaons", "pions"],
    logscale: bool = False,
    save_fig: bool = True,
):
    """
    Output training plot as in hipe4ml.plot_output_train_test.

    Args:
        model_hdl (ModelHandler): Model handler to be tested.
        train_test_data (_type_): List created by PrepareModel.prepare_train_test_data.
        leg_labels (List[str], optional): Names of the classified particles.
            Defaults to ["protons", "kaons", "pions"].
        logscale (bool, optional): Whether should use logscale. Defaults to False.
        save_fig (bool, optional): Whether should save the plots (True) or show
            them (False). Defaults to True.

    Files are saved as ``output_train_test_plot_{i}.png``/``.pdf`` when several
    figures are returned, otherwise as ``output_train_test_plot.png``/``.pdf``.
    """
    ml_out_fig = plot_output_train_test(
        model_hdl,
        train_test_data,
        100,
        False,
        leg_labels,
        logscale=logscale,
        density=False,  # if true histograms are normalized
    )
    # Dispatch on what hipe4ml actually returned (a list of figures or a single
    # figure) instead of on len(leg_labels): a binary model with two labels
    # returns one figure, which cannot be enumerated.
    is_list = isinstance(ml_out_fig, list)
    figures = ml_out_fig if is_list else [ml_out_fig]
    for idx, fig in enumerate(figures):
        suffix = f"_{idx}" if is_list else ""
        if save_fig:
            fig.savefig(f"output_train_test_plot{suffix}.png")
            fig.savefig(f"output_train_test_plot{suffix}.pdf")
            # Close every saved figure, not only the current one.
            plt.close(fig)
        else:
            fig.show()
241
242
def roc_plot(
    test_df: pd.DataFrame,
    test_labels_array: np.ndarray,
    leg_labels: List[str] = ["protons", "kaons", "pions"],
    save_fig: bool = True,
):
    """
    Draw the ROC curves of the model with hipe4ml's plot_roc (one-vs-one).

    Args:
        test_df (pd.DataFrame): Dataframe with the test dataset.
        test_labels_array (np.ndarray): Labels belonging to test_df.
        leg_labels (List[str], optional): Names of the classified particles.
            Defaults to ["protons", "kaons", "pions"].
        save_fig (bool, optional): Save to ``roc_plot.png``/``.pdf`` when True,
            otherwise show the plot. Defaults to True.
    """
    # NOTE(review): both arrays are forwarded positionally as plot_roc's first two
    # arguments (y_truth, y_score in hipe4ml) — confirm the caller's argument order.
    plot_roc(
        test_df,
        test_labels_array,
        None,
        leg_labels,
        multi_class_opt="ovo",
    )
    if not save_fig:
        plt.show()
        return
    for extension in ("png", "pdf"):
        plt.savefig(f"roc_plot.{extension}")
    plt.close()
265
266
def plot_confusion_matrix(
    cnf_matrix: np.ndarray,
    classes: List[str] = ["proton", "kaon", "pion", "bckgr"],
    normalize: bool = False,
    title: str = "Confusion matrix",
    cmap=mplt.colormaps["Blues"],
    save_fig: bool = True,
):
    """
    Plot created earlier confusion matrix.

    Args:
        cnf_matrix (np.ndarray): Confusion matrix. Rows are drawn as the true label
            and columns as the predicted label.
        classes (List[str], optional): List of the names of the classes.
            Defaults to ["proton", "kaon", "pion", "bckgr"].
        normalize (bool, optional): Whether to normalize every row to sum to 1.
            Rows without entries are left as zeros. Defaults to False.
        title (str, optional): Title of the plot. Defaults to "Confusion matrix".
        cmap (_type_, optional): Cmap used for colors. Defaults to mplt.colormaps["Blues"].
        save_fig (bool, optional): Whether should save the plot (as
            "confusion_matrix" or "confusion_matrix (norm)" .png/.pdf). Defaults to True.
    """
    cnf_matrix = np.asarray(cnf_matrix)
    filename = "confusion_matrix"
    if normalize:
        row_sums = cnf_matrix.sum(axis=1, keepdims=True)
        # Row-wise normalization; empty rows stay 0 instead of becoming NaN (0/0).
        cnf_matrix = np.divide(
            cnf_matrix.astype("float"),
            row_sums,
            out=np.zeros(cnf_matrix.shape, dtype=float),
            where=row_sums != 0,
        )
        print("Normalized confusion matrix")
        title = title + " (normalized)"
        filename = filename + " (norm)"
    else:
        print("Confusion matrix, without normalization")

    # Scoped precision: np.set_printoptions would change numpy's formatting
    # for the whole process.
    with np.printoptions(precision=2):
        print(cnf_matrix)

    fig, axs = plt.subplots(figsize=(10, 8), dpi=300)
    axs.yaxis.set_label_coords(-0.04, 0.5)
    axs.xaxis.set_label_coords(0.5, -0.005)
    plt.imshow(cnf_matrix, interpolation="nearest", cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    # The "d" format only accepts integers; float matrices (normalized or not)
    # are printed with two decimals.
    fmt = "d" if np.issubdtype(cnf_matrix.dtype, np.integer) else ".2f"
    # Cells above half the maximum get white text for contrast.
    thresh = cnf_matrix.max() / 2.0 if cnf_matrix.size else 0.0
    for i, j in itertools.product(
        range(cnf_matrix.shape[0]), range(cnf_matrix.shape[1])
    ):
        plt.text(
            j,
            i,
            format(cnf_matrix[i, j], fmt),
            horizontalalignment="center",
            color="white" if cnf_matrix[i, j] > thresh else "black",
        )

    plt.tight_layout()
    plt.ylabel("True label", fontsize=15)
    plt.xlabel("Predicted label", fontsize=15)
    if save_fig:
        plt.savefig(f"{filename}.png")
        plt.savefig(f"{filename}.pdf")
        plt.close(fig)
    else:
        plt.show()
330
331
def plot_mass2(
    xgb_mass: pd.Series,
    sim_mass: pd.Series,
    particles_title: str,
    range1: Tuple[float, float],
    y_axis_log: bool = False,
    save_fig: bool = True,
):
    """
    Overlay the mass^2 histogram of XGBoost-selected particles on the MC-true one.

    Args:
        xgb_mass (pd.Series): mass^2 of the particles selected by XGBoost.
        sim_mass (pd.Series): mass^2 of all MC-true (simulated) particles.
        particles_title (str): Particle name used in legend, title and file names.
        range1 (tuple[float, float]): mass^2 range shown on the x-axis (300 bins).
        y_axis_log (bool, optional): Use a logarithmic y-axis. Defaults to False.
        save_fig (bool, optional): Save to ``mass2_{particles_title}.png``/``.pdf``
            when True, otherwise show the plot. Defaults to True.
    """
    _, axs = plt.subplots(figsize=(15, 10), dpi=300)

    # Selected (red) first, then simulated (blue), with identical binning.
    for masses, colour in ((xgb_mass, "red"), (sim_mass, "blue")):
        axs.hist(masses, bins=300, facecolor=colour, alpha=0.3, range=range1)

    axs.set_ylabel("counts", fontsize=15)
    axs.legend(
        ("XGBoost selected " + particles_title, "all simulated " + particles_title),
        loc="upper right",
    )
    if y_axis_log:
        axs.set_yscale("log")

    plt.xlabel(r"$m^2$ $(GeV/c^2)^2$", fontsize=20, loc="right")
    plt.ylabel(r"Counts", fontsize=20, loc="top")
    axs.set_title(f"{particles_title} $mass^2$ histogram", fontsize=20)
    axs.grid()
    axs.tick_params(axis="both", which="major", labelsize=18)

    if not save_fig:
        plt.show()
        return
    for extension in ("png", "pdf"):
        plt.savefig(f"mass2_{particles_title}.{extension}")
    plt.close()
385
386
def plot_all_particles_mass2(
    xgb_selected: pd.DataFrame,
    mass2_variable_name: str,
    pid_variable_name: str,
    particles_title: str,
    range1: Tuple[float, float],
    y_axis_log: bool = False,
    save_fig: bool = True,
):
    """
    Plots the mass^2 of XGBoost-selected particles split by their MC-true type.

    Args:
        xgb_selected (pd.DataFrame): Dataframe with XGBoost-selected particles; must
            contain the mass2 and pid columns named below.
        mass2_variable_name (str): Name of the mass2 column.
        pid_variable_name (str): Name of the MC-true pid column.
        particles_title (str): Name of the plot (used in title and file names).
        range1 (tuple[float, float]): Range of the x-axis (300 bins).
        y_axis_log (bool, optional): If should use logscale in y-scale. Defaults to False.
        save_fig (bool, optional): Whether should save the plot as
            ``mass2_all_selected_{particles_title}.png``/``.pdf``. Defaults to True.
    """
    fig, axs = plt.subplots(figsize=(15, 10), dpi=300)

    # MC-true pid codes 0/1/2 are drawn as protons/kaons/pions, in that order.
    true_species = (
        (0, "protons", "blue"),
        (1, "kaons", "orange"),
        (2, "pions", "green"),
    )
    for pid, _, colour in true_species:
        masses = xgb_selected.loc[
            xgb_selected[pid_variable_name] == pid, mass2_variable_name
        ]
        axs.hist(masses, bins=300, facecolor=colour, alpha=0.4, range=range1)

    axs.set_ylabel("counts", fontsize=15)
    axs.legend(
        tuple(f"XGBoost selected true {name}" for _, name, _ in true_species),
        loc="upper right",
    )
    if y_axis_log:
        axs.set_yscale("log")
    title = f"ALL XGBoost selected (true and false positive) {particles_title} $mass^2$ histogram"
    plt.xlabel(r"$m^2$ $(GeV/c^2)^2$", loc="right")
    plt.ylabel(r"Counts", loc="top")
    axs.set_title(title)
    axs.grid()
    axs.tick_params(axis="both", which="major", labelsize=18)
    if save_fig:
        plt.savefig(f"mass2_all_selected_{particles_title}.png")
        plt.savefig(f"mass2_all_selected_{particles_title}.pdf")
        plt.close(fig)
    else:
        plt.show()
459
460
def plot_eff_pT_rap(
    df: pd.DataFrame,
    pid: int,
    pid_var_name: str = "Complex_pid",
    rapidity_var_name: str = "Complex_rapidity",
    pT_var_name: str = "Complex_pT",
    ranges: Tuple[Tuple[float, float], Tuple[float, float]] = [[0, 5], [0, 3]],
    nbins: int = 50,
    save_fig: bool = True,
    particle_names: List[str] = ["protons", "kaons", "pions", "bckgr"],
):
    """
    Plots the pT-rapidity efficiency map (selected / simulated) for one particle type.

    Args:
        df (pd.DataFrame): Dataframe with the MC-true pid column, an "xgb_preds"
            column with the model predictions, and rapidity and pT columns.
        pid (int): Pid of the particle type to be plotted.
        pid_var_name (str, optional): Name of the MC-true pid variable.
            Defaults to "Complex_pid".
        rapidity_var_name (str, optional): Name of the rapidity variable.
            Defaults to "Complex_rapidity".
        pT_var_name (str, optional): Name of the pT variable. Defaults to "Complex_pT".
        ranges (Tuple[Tuple[float, float], Tuple[float, float]], optional):
            Histogram ranges as [rapidity_range, pT_range]. Defaults to [[0, 5], [0, 3]].
        nbins (int, optional): Number of bins in each axis. Defaults to 50.
        save_fig (bool, optional): Whether should save the figure. Defaults to True.
        particle_names (List[str], optional): Names of the particles indexed by pid.
            Defaults to ["protons", "kaons", "pions", "bckgr"].
    """
    df_true = df[df[pid_var_name] == pid]  # simulated (MC-true) particles
    df_reco = df[df["xgb_preds"] == pid]  # selected by xgboost as this type

    # np.histogram2d bins its first argument along axis 0, i.e. rapidity is
    # axis 0 and pT axis 1. imshow draws axis 0 vertically, so the maps are
    # transposed below to put rapidity on the horizontal axis as labelled.
    true, rap_edges, pt_edges = np.histogram2d(
        np.array(df_true[rapidity_var_name]),
        np.array(df_true[pT_var_name]),
        bins=nbins,
        range=ranges,
    )
    reco, _, _ = np.histogram2d(
        np.array(df_reco[rapidity_var_name]),
        np.array(df_reco[pT_var_name]),
        bins=(rap_edges, pt_edges),
    )

    # Efficiency = selected / simulated (as on the colour-bar label); bins without
    # simulated particles stay empty. df_reco also holds misidentified particles,
    # so values above 1 are possible and get clipped by vmax=1.
    eff = np.divide(reco, true, out=np.zeros_like(true), where=true != 0)
    eff[eff == 0] = np.nan  # zero bins are left blank (white)

    fig = plt.figure(figsize=(8, 10), dpi=300)
    plt.title(f"$p_T$-rapidity efficiency for all selected {particle_names[pid]}")
    img = plt.imshow(
        eff.T,
        interpolation="nearest",
        origin="lower",
        aspect="auto",
        vmin=0,
        vmax=1,
        extent=[rap_edges[0], rap_edges[-1], pt_edges[0], pt_edges[-1]],
    )

    cbar = fig.colorbar(img, fraction=0.025, pad=0.08)
    cbar.set_label("efficiency (selected/simulated)", rotation=270, labelpad=20)

    plt.xlabel("rapidity")
    plt.ylabel("$p_T$ (GeV/c)")
    plt.tight_layout()
    if save_fig:
        plt.savefig(f"eff_pT_rap_{particle_names[pid]}.png")
        plt.savefig(f"eff_pT_rap_{particle_names[pid]}.pdf")
        plt.close(fig)
    else:
        plt.show()
509
510
def plot_pt_rapidity(
    df: pd.DataFrame,
    pid: int,
    pid_var_name: str = "Complex_pid",
    rapidity_var_name: str = "Complex_rapidity",
    pT_var_name: str = "Complex_pT",
    ranges: Tuple[Tuple[float, float], Tuple[float, float]] = [[0, 5], [0, 3]],
    nbins=50,
    save_fig: bool = True,
    particle_names: List[str] = ["protons", "kaons", "pions", "bckgr"],
):
    """
    Plots pt-rapidity 2D histogram of the MC-true particles of one type.

    Args:
        df (pd.DataFrame): Dataframe with input data.
        pid (int): Pid of the particle type to be plotted.
        pid_var_name (str, optional): Name of the pid variable. Defaults to "Complex_pid".
        rapidity_var_name (str, optional): Name of the rapidity variable.
            Defaults to "Complex_rapidity".
        pT_var_name (str, optional): Name of the pT variable. Defaults to "Complex_pT".
        ranges (Tuple[Tuple[float, float], Tuple[float, float]], optional):
            Ranges of the plot as [rapidity_range, pT_range]. Defaults to [[0, 5], [0, 3]].
        nbins (int, optional): Number of bins in each axis. Defaults to 50.
        save_fig (bool, optional): Whether should save the figure. Defaults to True.
        particle_names (List[str], optional): Names of the particles corresponding to pid.
            Defaults to ["protons", "kaons", "pions", "bckgr"].
    """
    df_true = df[df[pid_var_name] == pid]  # simulated

    # np.histogram2d bins its first argument (rapidity) along axis 0, which imshow
    # would draw vertically; the histogram is transposed below so that rapidity
    # is on the horizontal axis, matching the axis labels.
    counts, rap_edges, pt_edges = np.histogram2d(
        np.array(df_true[rapidity_var_name]),
        np.array(df_true[pT_var_name]),
        bins=nbins,
        range=ranges,
    )
    counts[counts == 0] = np.nan  # empty bins are left blank (white)

    fig = plt.figure(figsize=(8, 10), dpi=300)
    plt.title(f"$p_T$-rapidity graph for all simulated {particle_names[pid]}")
    img = plt.imshow(
        counts.T,
        interpolation="nearest",
        origin="lower",
        aspect="auto",
        extent=[rap_edges[0], rap_edges[-1], pt_edges[0], pt_edges[-1]],
    )

    cbar = fig.colorbar(img, fraction=0.025, pad=0.08)
    cbar.set_label("counts", rotation=270, labelpad=20)

    plt.xlabel("rapidity")
    plt.ylabel("$p_T$ (GeV/c)")
    plt.tight_layout()
    if save_fig:
        plt.savefig(f"plot_pt_rapidity_{particle_names[pid]}.png")
        plt.savefig(f"plot_pt_rapidity_{particle_names[pid]}.pdf")
        plt.close(fig)
    else:
        plt.show()
568
569
def _shap_summary(
    shap_values,
    x_train_resampled: pd.DataFrame,
    features_names: List[str],
    particle_name: str,
    save_fig: bool = True,
):
    """
    Draw the SHAP summary (beeswarm) plot for one class.

    Args:
        shap_values (_type_): SHAP values of the class being plotted.
        x_train_resampled (pd.DataFrame): Feature values matching shap_values.
        features_names (List[str]): Names of the training variables.
        particle_name (str): Class name used in the x-label and file names.
        save_fig (bool, optional): Save to ``shap_summary_{particle_name}.png``/``.pdf``
            when True, otherwise show the plot. Defaults to True.
    """
    fig, ax = plt.subplots(figsize=(8, 6), dpi=300)
    # summary_plot is called without an axes argument, so it presumably draws
    # into the current figure created above.
    shap.summary_plot(
        shap_values,
        x_train_resampled,
        feature_names=features_names,
        show=False,
    )

    current = plt.gcf()
    width, height = current.get_size_inches()
    current.set_size_inches(height + 2, height)
    current.set_size_inches(width, width * 3 / 4)
    # Last axes of the figure — presumably the colour bar added by summary_plot.
    last_axes = current.axes[-1]
    last_axes.set_aspect("auto")
    last_axes.set_box_aspect(50)

    plt.xlabel(f"SHAP values for  {particle_name}", fontsize=18)
    for side in ("top", "right", "bottom", "left"):
        ax.spines[side].set_visible(True)
    ax.tick_params(
        axis="both", which="major", length=10, direction="in", labelsize=15, zorder=4
    )
    ax.minorticks_on()
    ax.tick_params(
        axis="both", which="minor", length=5, direction="in", labelsize=15, zorder=5
    )
    fig.tight_layout()

    if not save_fig:
        plt.show()
        return
    for extension in ("png", "pdf"):
        plt.savefig(f"shap_summary_{particle_name}.{extension}")
    plt.close()
623
624
def _shap_interaction(
    shap_values,
    x_train_resampled,
    features_names,
    particle_name: str,
    save_fig: bool = True,
):
    """
    Internal method for plotting SHAP dependence plots, one per feature.

    Args:
        shap_values (_type_): SHAP values of the class being plotted.
        x_train_resampled (pd.DataFrame): Dataframe with X training variables.
        features_names (List[str]): List of the training variables.
        particle_name (str): Name of the particle (used in labels and file names).
        save_fig (bool, optional): Whether should save the plots as
            ``shap_{feature}_{particle_name}.png``/``.pdf``. Defaults to True.
    """
    for feature in features_names:
        fig, ax = plt.subplots(figsize=(8, 6), dpi=300)
        # Draw into the axes created above. Without ``ax=`` dependence_plot opens
        # a figure of its own, leaving the styling below applied to an empty
        # figure that is neither saved nor closed.
        shap.dependence_plot(
            feature,
            shap_values,
            x_train_resampled,
            display_features=x_train_resampled,
            ax=ax,
            show=False,
        )
        width, _ = fig.get_size_inches()
        fig.set_size_inches(width, width * 3 / 4)
        # With an interaction colour bar the figure has a second axes; only then
        # reshape it (otherwise axes[-1] would be the main plot itself).
        if len(fig.axes) > 1:
            fig.axes[-1].set_aspect("auto")
            fig.axes[-1].set_box_aspect(50)
        ax.set_xlabel(f"{feature} for {particle_name}", fontsize=18)
        for side in ("top", "right", "bottom", "left"):
            ax.spines[side].set_visible(True)
        ax.tick_params(
            axis="both",
            which="major",
            length=10,
            direction="in",
            labelsize=15,
            zorder=4,
        )
        ax.minorticks_on()
        ax.tick_params(
            axis="both", which="minor", length=5, direction="in", labelsize=15, zorder=5
        )
        fig.tight_layout()
        if save_fig:
            fig.savefig(f"shap_{feature}_{particle_name}.png")
            fig.savefig(f"shap_{feature}_{particle_name}.pdf")
            plt.close(fig)
        else:
            plt.show()
680
681
def plot_shap_summary(
    x_train: pd.DataFrame,
    y_train: pd.DataFrame,
    model_hdl: ModelHandler,
    features_names: List[str],
    n_workers: int = 1,
    save_fig: bool = True,
    approximate: bool = False,
    n_samples: int = 50000,
    particle_names: List[str] = ["protons", "kaons", "pions"],
):
    """
    Method for plotting shap plots (summary + dependence plots for every class).

    Args:
        x_train (pd.DataFrame): pd.Dataframe with X training dataset.
        y_train (pd.DataFrame): Training labels, one per row of x_train (matched
            by position).
        model_hdl (ModelHandler): Model Handler to be explained.
        features_names (List[str]): List of the training variables.
        n_workers (int, optional): Number of threads for multithreading.
            Note: it uses the fasttreeshap library, not shap. Defaults to 1.
        save_fig (bool, optional): Whether should save the plots. Defaults to True.
        approximate (bool, optional): Whether should use approximate SHAP values.
            Defaults to False.
        n_samples (int, optional): Maximal number of samples in each class.
            Defaults to 50000.
        particle_names (List[str], optional): List of the classified particle names,
            indexed like the model's output classes.
            Defaults to ["protons", "kaons", "pions"].
    """
    print("Creating shap plots...")
    explainer = shap.TreeExplainer(
        model_hdl.get_original_model(), n_jobs=n_workers, approximate=approximate
    )

    # Attach the labels positionally. Concatenating a separate label frame on
    # axis=1 aligns on the index, which misaligns labels (and adds NaN rows)
    # whenever x_train does not carry a default RangeIndex.
    labelled_df = x_train.assign(true_class=np.asarray(y_train).ravel())
    # Keep at most n_samples rows of every class.
    resampled_df = pd.concat(
        [
            resample(group, n_samples=min(n_samples, len(group)), replace=False)
            for _, group in labelled_df.groupby("true_class")
        ]
    )

    # Split the resampled pd.DataFrame back into input data and label data.
    x_train_resampled = resampled_df.drop(columns="true_class")
    y_train_resampled = resampled_df["true_class"].to_numpy()
    del labelled_df, resampled_df
    gc.collect()

    shap_values = explainer.shap_values(
        x_train_resampled, y_train_resampled, check_additivity=False
    )
    # Multi-class explanations arrive as a list with one array per class. Handle
    # a single 3-D array (classes on the last axis) and a single 2-D array
    # (single-output model) as well, instead of iterating over sample rows.
    if isinstance(shap_values, list):
        per_class_values = shap_values
    else:
        shap_values = np.asarray(shap_values)
        if shap_values.ndim == 3:
            per_class_values = [
                shap_values[:, :, k] for k in range(shap_values.shape[2])
            ]
        else:
            per_class_values = [shap_values]

    for i, class_values in enumerate(per_class_values):
        _shap_summary(
            class_values,
            x_train_resampled,
            features_names,
            particle_names[i],
            save_fig=save_fig,
        )
        _shap_interaction(
            class_values,
            x_train_resampled,
            features_names,
            particle_names[i],
            save_fig=save_fig,
        )
749
750
def plot_efficiency_purity(
    probas: np.ndarray,
    efficiencies: List[List[float]],
    purities: List[List[float]],
    save_fig: bool = True,
    particle_names: List[str] = ["protons", "kaons", "pions"],
):
    """
    Plots efficiency and purity as a function of the probability (BDT) cut.

    Args:
        probas (np.ndarray): Probability cuts.
        efficiencies (List[List[float]]): List of lists of efficiencies for each class.
        purities (List[List[float]]): List of lists of purities for each class.
        save_fig (bool, optional): Whether should save the figures as
            ``efficiency_purity_{particle_name}.png``/``.pdf``. Defaults to True.
        particle_names (List[str], optional): List of the particle names.
            Defaults to ["protons", "kaons", "pions"].
    """
    # Saved figures get print resolution, interactive ones a screen resolution.
    dpi = 300 if save_fig else 100
    for i, (eff, pur) in enumerate(zip(efficiencies, purities)):
        particle_name = particle_names[i]
        fig, ax = plt.subplots(figsize=(10, 7), dpi=dpi)
        ax.plot(probas, eff, label="efficiency")
        ax.plot(probas, pur, label="purity")
        ax.set_xlabel("BDT cut")
        # Raw string: same text as before, without the invalid "\%" escape warning.
        ax.set_ylabel(r"\% ")
        ax.legend(loc="upper right")
        ax.set_title(
            f"Efficiency and purity in function of BDT cut for {particle_name}"
        )
        ax.grid(which="major", linestyle="-")
        ax.minorticks_on()
        ax.grid(which="minor", linestyle="--")
        if save_fig:
            # Same base name for both formats (previously png and pdf differed).
            base_name = f"efficiency_purity_{particle_name}"
            fig.savefig(f"{base_name}.png")
            fig.savefig(f"{base_name}.pdf")
            plt.close(fig)
        else:
            plt.show()
791
792
# deprecated
def plot_before_after_variables(
    df: pd.DataFrame,
    pid: float,
    pid_variable_name: str,
    training_variables: List[str],
    save_fig: bool = True,
    log_yscale: bool = True,
):
    """
    Plots each variable before and after selection.
    Legacy: var_distributions_plot should rather be used.

    For every training variable, the distribution of MC-true particles with the
    given pid ("Simulated") is overlaid with that of particles predicted as this
    pid in the "xgb_preds" column ("XGB-selected").

    Args:
        df (pd.DataFrame): Dataframe with the MC-true pid column, an "xgb_preds"
            column and the training variables.
        pid (float): Pid of the particle type to compare.
        pid_variable_name (str): Name of the MC-true pid column.
        training_variables (List[str]): Variables to plot, one figure each.
        save_fig (bool, optional): Whether should save the plots as
            ``{variable}_before_after_pid_{pid}.png``/``.pdf``. Defaults to True.
        log_yscale (bool, optional): Whether to use a logarithmic y-axis.
            Defaults to True.

    Returns:
        None.
    """
    df_true = df[df[pid_variable_name] == pid]  # simulated
    df_reco = df[df["xgb_preds"] == pid]  # reconstructed by xgboost

    def variable_plot(
        df_true,
        df_reco,
        variable_name: str,
        log_yscale: bool = True,
        leg1="Simulated",
        leg2="XGB-selected",
        bins=100,
    ):
        """Build one before/after histogram figure for ``variable_name``."""
        fig, ax = plt.subplots(figsize=(15, 15), dpi=300)
        ax.hist(
            df_true[variable_name],
            bins=bins,
            facecolor="blue",
            alpha=0.6,
            histtype="step",
            fill=False,
            linewidth=2,
        )
        ax.hist(
            df_reco[variable_name],
            bins=bins,
            facecolor="red",
            alpha=0.7,
            histtype="step",
            fill=False,
            linewidth=2,
        )
        ax.grid()
        ax.set_xlabel(variable_name, fontsize=15, loc="right")
        ax.set_ylim(bottom=1)
        ax.set_ylabel("counts", fontsize=15)
        if log_yscale:
            ax.set_yscale("log")
        ax.legend((leg1, leg2), fontsize=15, loc="upper right")
        ax.set_title(
            f"{variable_name} before and after XGB selection for pid={pid}", fontsize=15
        )
        return fig

    for training_variable in training_variables:
        fig = variable_plot(df_true, df_reco, training_variable, log_yscale)
        if save_fig:
            # Same base name for both formats (the pdf previously lacked "pid_").
            base_name = f"{training_variable}_before_after_pid_{pid}"
            fig.savefig(f"{base_name}.png")
            fig.savefig(f"{base_name}.pdf")
            plt.close(fig)
        else:
            fig.show()
def tof_plot( df: pandas.core.frame.DataFrame, json_file_name: str, particles_title: str, file_name: str = 'tof_plot', x_axis_range: List[int] = [-13, 13], y_axis_range: List[str] = [-1, 2], save_fig: bool = True) -> None:
def tof_plot(
    df: pd.DataFrame,
    json_file_name: str,
    particles_title: str,
    file_name: str = "tof_plot",
    x_axis_range: List[float] = [-13, 13],
    y_axis_range: List[float] = [-1, 2],
    save_fig: bool = True,
) -> "plt.Figure":
    """
    Method for creating tof plots.

    Draws a 2D histogram (log colour scale) of sign(q)*p versus mass^2.

    Args:
        df (pd.DataFrame): Dataframe with particles to plot.
        json_file_name (str): Name of the config.json file, used to look up the
            charge, momentum and mass2 column names.
        particles_title (str): Name of the particle type, used in the title and
            in the saved file name.
        file_name (str, optional): Prefix of the saved files. Defaults to "tof_plot".
            The particles_title (spaces replaced by underscores) is appended, giving
            "{file_name}_{particles_title}.png" / ".pdf".
        x_axis_range (List[float], optional): X-axis range. Defaults to [-13, 13].
        y_axis_range (List[float], optional): Y-axis range. Defaults to [-1, 2].
        save_fig (bool, optional): Whether the figure should be saved (png and pdf)
            instead of shown. Defaults to True.

    Returns:
        plt.Figure: The created figure (already closed if it was saved).
    """
    # load variable names
    charge_var_name = json_tools.load_var_name(json_file_name, "charge")
    momentum_var_name = json_tools.load_var_name(json_file_name, "momentum")
    mass2_var_name = json_tools.load_var_name(json_file_name, "mass2")
    # x-axis is labelled sign(q)*p; presumably the charge column is +-1 -- confirm
    qp = df[charge_var_name] * df[momentum_var_name]
    mass2 = df[mass2_var_name]

    fig, ax = plt.subplots(figsize=(15, 10), dpi=300)
    *_, image = ax.hist2d(
        qp,
        mass2,
        bins=200,
        norm=matplotlib.colors.LogNorm(),
        range=[x_axis_range, y_axis_range],
    )
    ax.set_xlabel(r"sign($q$) $\cdot p$ (GeV/c)", fontsize=20, loc="right")
    ax.set_ylabel(r"$m^2$ $(GeV/c^2)^2$", fontsize=20, loc="top")
    ax.set_title(f"TOF 2D plot for {particles_title}", fontsize=20)
    fig.colorbar(image, ax=ax)
    # lay out after the colorbar exists so it is taken into account
    fig.tight_layout()
    if save_fig:
        stem = f"{file_name}_{particles_title.replace(' ', '_')}"
        fig.savefig(f"{stem}.png")
        fig.savefig(f"{stem}.pdf")
        plt.close(fig)
    else:
        plt.show()
    return fig

Method for creating tof plots.

Args: df (pd.DataFrame): Dataframe with particles to plot. json_file_name (str): Name of the config.json file. particles_title (str): Name of the particle type. file_name (str, optional): Filename to be created. Defaults to "tof_plot". Will add the particles_title after the tof_plot_ when saved. x_axis_range (List[int], optional): X-axis range. Defaults to [-13, 13]. y_axis_range (List[str], optional): Y-axis range. Defaults to [-1, 2]. save_fig (bool, optional): Whether the figure should be saved. Defaults to True.

Returns: The created matplotlib figure.

def var_distributions_plot( vars_to_draw: list, data_list: List[hipe4ml.tree_handler.TreeHandler], leg_labels: List[str] = ['protons', 'kaons', 'pions'], save_fig: bool = True, filename: str = 'vars_disitributions'):
 94def var_distributions_plot(
 95    vars_to_draw: list,
 96    data_list: List[TreeHandler],
 97    leg_labels: List[str] = ["protons", "kaons", "pions"],
 98    save_fig: bool = True,
 99    filename: str = "vars_disitributions",
100):
101    """
102    Plots distributions of given variables using plot_distr from hipe4ml.
103
104    Args:
105        vars_to_draw (list): List of variables to draw.
106        data_list (List[TreeHandler]): List of TreeHandlers with data.
107        leg_labels (List[str], optional): Names of the particles which are given in the list of TreeHandlers.
108            Defaults to ["protons", "kaons", "pions"].
109        save_fig (bool, optional): Whether should save the plot. Defaults to True.
110        filename (str, optional): Name of the plot to be saved. Defaults to "vars_disitributions".
111    """
112    plot_distr(
113        data_list,
114        vars_to_draw,
115        bins=100,
116        labels=leg_labels,
117        log=True,
118        figsize=(40, 40),
119        alpha=0.3,
120        grid=False,
121    )
122    if save_fig:
123        plt.savefig(f"{filename}.png")
124        plt.savefig(f"{filename}.pdf")
125        plt.close()
126    else:
127        plt.show()

Plots distributions of given variables using plot_distr from hipe4ml.

Args: vars_to_draw (list): List of variables to draw. data_list (List[TreeHandler]): List of TreeHandlers with data. leg_labels (List[str], optional): Names of the particles which are given in the list of TreeHandlers. Defaults to ["protons", "kaons", "pions"]. save_fig (bool, optional): Whether to save the plot. Defaults to True. filename (str, optional): Name of the plot to be saved. Defaults to "vars_disitributions".

def correlations_plot( vars_to_draw: list, data_list: List[hipe4ml.tree_handler.TreeHandler], leg_labels: List[str] = ['protons', 'kaons', 'pions'], save_fig: bool = True):
def correlations_plot(
    vars_to_draw: list,
    data_list: List[TreeHandler],
    leg_labels: List[str] = ["protons", "kaons", "pions"],
    save_fig: bool = True,
):
    """
    Creates correlation plots with hipe4ml's plot_corr.

    Saved files are "correlations_plot_{i}.png/.pdf" when plot_corr returns a list
    of figures, and "correlations_plot.png/.pdf" when it returns a single figure.

    Args:
        vars_to_draw (list): Variables to check correlations.
        data_list (List[TreeHandler]): List of TreeHandlers with data.
        leg_labels (List[str], optional): Names of the particles which are given in the
            list of TreeHandlers. Defaults to ["protons", "kaons", "pions"].
        save_fig (bool, optional): Whether to save the plots instead of showing them.
            Defaults to True.
    """
    # plot_corr creates its own figures, so layout tweaks on the "current" figure
    # made before this call never reached them (and could leave a stray figure open).
    cor_plots = plot_corr(data_list, vars_to_draw, leg_labels)
    # plot_corr returns either a list of figures or a single figure
    if isinstance(cor_plots, list):
        named_figs = [(f"correlations_plot_{i}", fig) for i, fig in enumerate(cor_plots)]
    else:
        named_figs = [("correlations_plot", cor_plots)]
    for stem, fig in named_figs:
        if save_fig:
            fig.savefig(f"{stem}.png")
            fig.savefig(f"{stem}.pdf")
            plt.close(fig)
        else:
            fig.show()

Creates correlation plots

Args: vars_to_draw (list): Variables to check correlations. data_list (List[TreeHandler]): List of TreeHandlers with data. leg_labels (List[str], optional): Names of the particles which are given in the list of TreeHandlers. Defaults to ["protons", "kaons", "pions"]. save_fig (bool, optional): Whether to save the plot. Defaults to True.

def opt_history_plot(study: optuna.study.study.Study, save_fig: bool = True):
def opt_history_plot(study: Study, save_fig: bool = True):
    """
    Saves or shows the optimization history of an optuna study.

    Args:
        study (Study): optuna.Study whose optimization history is plotted.
        save_fig (bool, optional): Whether to save the plot (png and pdf) instead of
            showing it. Defaults to True.
    """
    # optuna.visualization returns a plotly figure; saving needs the kaleido package.
    # No matplotlib figure is involved, so nothing is closed via plt here.
    fig = plot_optimization_history(study)
    if save_fig:
        fig.write_image("optimization_history.png")
        fig.write_image("optimization_history.pdf")
    else:
        fig.show()

Saves optimization history.

Args: study (Study): optuna.Study to be saved. save_fig (bool, optional): Whether to save the plot. Defaults to True.

def opt_contour_plot(study: optuna.study.study.Study, save_fig: bool = True):
def opt_contour_plot(study: Study, save_fig: bool = True):
    """
    Saves or shows the optimization contour plot of an optuna study.

    Args:
        study (Study): optuna.Study whose contour plot is drawn.
        save_fig (bool, optional): Whether to save the plot (png and pdf) instead of
            showing it. Defaults to True.
    """
    # plot_contour returns a plotly figure: it must be shown via fig.show(),
    # plt.show() would only display (unrelated) matplotlib figures.
    fig = plot_contour(study)
    if save_fig:
        fig.write_image("optimization_contour.png")
        fig.write_image("optimization_contour.pdf")
    else:
        fig.show()

Saves optimization contour plot

Args: study (Study): optuna.Study to be saved. save_fig (bool, optional): Whether to save the plot. Defaults to True.

def output_train_test_plot( model_hdl: hipe4ml.model_handler.ModelHandler, train_test_data, leg_labels: List[str] = ['protons', 'kaons', 'pions'], logscale: bool = False, save_fig: bool = True):
def output_train_test_plot(
    model_hdl: ModelHandler,
    train_test_data,
    leg_labels: List[str] = ["protons", "kaons", "pions"],
    logscale: bool = False,
    save_fig: bool = True,
):
    """
    Output training plot, as in hipe4ml.plot_output_train_test.

    Saved files are "output_train_test_plot_{idx}.png/.pdf" when several figures are
    returned, and "output_train_test_plot.png/.pdf" for a single figure.

    Args:
        model_hdl (ModelHandler): Model handler to be tested.
        train_test_data (_type_): List created by PrepareModel.prepare_train_test_data.
        leg_labels (List[str], optional): Names of the classified particles.
            Defaults to ["protons", "kaons", "pions"].
        logscale (bool, optional): Whether to use a log scale. Defaults to False.
        save_fig (bool, optional): Whether to save the plots instead of showing them.
            Defaults to True.
    """
    ml_out_fig = plot_output_train_test(
        model_hdl,
        train_test_data,
        100,  # number of bins
        False,  # NOTE(review): presumably output_margin=False (probabilities) - confirm in hipe4ml
        leg_labels,
        logscale=logscale,
        density=False,  # if true histograms are normalized
    )
    # Decide on what hipe4ml actually returned rather than on len(leg_labels):
    # a single figure must not be iterated, a list must not be saved directly.
    if isinstance(ml_out_fig, (list, tuple)):
        named_figs = [
            (f"output_train_test_plot_{idx}", fig) for idx, fig in enumerate(ml_out_fig)
        ]
    else:
        named_figs = [("output_train_test_plot", ml_out_fig)]
    for stem, fig in named_figs:
        if save_fig:
            fig.savefig(f"{stem}.png")
            fig.savefig(f"{stem}.pdf")
            # close every saved figure, not only the current one, to avoid leaks
            plt.close(fig)
        else:
            fig.show()

Output training plot, as in hipe4ml.plot_output_train_test.

Args: model_hdl (ModelHandler): Model handler to be tested. train_test_data (_type_): List created by PrepareModel.prepare_train_test_data. leg_labels (List[str], optional): Names of the classified particles. Defaults to ["protons", "kaons", "pions"]. logscale (bool, optional): Whether to use a log scale. Defaults to False. save_fig (bool, optional): Whether to save the plots. Defaults to True.

def roc_plot( test_df: pandas.core.frame.DataFrame, test_labels_array: numpy.ndarray, leg_labels: List[str] = ['protons', 'kaons', 'pions'], save_fig: bool = True):
def roc_plot(
    test_df: pd.DataFrame,
    test_labels_array: np.ndarray,
    leg_labels: List[str] = ["protons", "kaons", "pions"],
    save_fig: bool = True,
):
    """
    Draw the ROC curves of the model with hipe4ml's plot_roc (multi_class_opt="ovo").

    Args:
        test_df (pd.DataFrame): Dataframe containing the test dataset with particles.
        test_labels_array (np.ndarray): Array containing the labels of test_df.
        leg_labels (List[str], optional): Names of the classified particles.
            Defaults to ["protons", "kaons", "pions"].
        save_fig (bool, optional): Save the plot as "roc_plot.png"/"roc_plot.pdf"
            instead of showing it. Defaults to True.
    """
    plot_roc(test_df, test_labels_array, None, leg_labels, multi_class_opt="ovo")
    if not save_fig:
        plt.show()
        return
    for suffix in ("png", "pdf"):
        plt.savefig(f"roc_plot.{suffix}")
    plt.close()

Roc plot of the model

Args: test_df (pd.DataFrame): Dataframe containing the test dataset with particles. test_labels_array (np.ndarray): Ndarray containing the labels of test_df. leg_labels (List[str], optional): Names of the classified particles. Defaults to ["protons", "kaons", "pions"]. save_fig (bool, optional): Whether to save the plot. Defaults to True.

def plot_confusion_matrix( cnf_matrix: numpy.ndarray, classes: List[str] = ['proton', 'kaon', 'pion', 'bckgr'], normalize: bool = False, title: str = 'Confusion matrix', cmap=<matplotlib.colors.LinearSegmentedColormap object>, save_fig: bool = True):
def plot_confusion_matrix(
    cnf_matrix: np.ndarray,
    classes: List[str] = ["proton", "kaon", "pion", "bckgr"],
    normalize: bool = False,
    title: str = "Confusion matrix",
    cmap=mplt.colormaps["Blues"],
    save_fig: bool = True,
):
    """
    Plot created earlier confusion matrix.

    Rows are drawn as true labels and columns as predicted labels; each cell is
    annotated with its value.

    Args:
        cnf_matrix (np.ndarray): Confusion matrix.
        classes (List[str], optional): List of the names of the classes.
            Defaults to ["proton", "kaon", "pion", "bckgr"].
        normalize (bool, optional): Whether to normalize each row to 1. Rows that sum
            to 0 are left as 0. Defaults to False.
        title (str, optional): Title of the plot. Defaults to "Confusion matrix".
        cmap (_type_, optional): Cmap used for colors. Defaults to mplt.colormaps["Blues"].
        save_fig (bool, optional): Whether to save the plot instead of showing it.
            Defaults to True.
    """
    filename = "confusion_matrix"
    cnf_matrix = np.asarray(cnf_matrix)
    if normalize:
        row_sums = cnf_matrix.sum(axis=1, keepdims=True)
        # empty rows would divide by zero; keep them at 0 instead of NaN
        cnf_matrix = np.divide(
            cnf_matrix.astype("float"),
            row_sums,
            out=np.zeros(cnf_matrix.shape, dtype=float),
            where=row_sums != 0,
        )
        print("Normalized confusion matrix")
        title = title + " (normalized)"
        filename = filename + " (norm)"
    else:
        print("Confusion matrix, without normalization")

    # two-decimal printing without permanently changing numpy's global print options
    with np.printoptions(precision=2):
        print(cnf_matrix)

    fig, axs = plt.subplots(figsize=(10, 8), dpi=300)
    axs.yaxis.set_label_coords(-0.04, 0.5)
    axs.xaxis.set_label_coords(0.5, -0.005)
    plt.imshow(cnf_matrix, interpolation="nearest", cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    # "d" only formats integers; any float matrix (normalized or not) uses 2 decimals
    fmt = "d" if np.issubdtype(cnf_matrix.dtype, np.integer) else ".2f"
    thresh = cnf_matrix.max() / 2.0
    for i, j in itertools.product(
        range(cnf_matrix.shape[0]), range(cnf_matrix.shape[1])
    ):
        plt.text(
            j,
            i,
            format(cnf_matrix[i, j], fmt),
            horizontalalignment="center",
            # light text on dark cells for readability
            color="white" if cnf_matrix[i, j] > thresh else "black",
        )

    plt.ylabel("True label", fontsize=15)
    plt.xlabel("Predicted label", fontsize=15)
    # lay out once all labels exist so they are taken into account
    plt.tight_layout()
    if save_fig:
        fig.savefig(f"{filename}.png")
        fig.savefig(f"{filename}.pdf")
        plt.close(fig)
    else:
        plt.show()

Plot created earlier confusion matrix.

Args: cnf_matrix (np.ndarray): Confusion matrix. classes (List[str], optional): List of the names of the classes. Defaults to ["proton", "kaon", "pion", "bckgr"]. normalize (bool, optional): Whether to normalize the plot. Defaults to False. title (str, optional): Title of the plot. Defaults to "Confusion matrix". cmap (_type_, optional): Cmap used for colors. Defaults to mplt.colormaps["Blues"]. save_fig (bool, optional): Whether to save the plot. Defaults to True.

def plot_mass2( xgb_mass: pandas.core.series.Series, sim_mass: pandas.core.series.Series, particles_title: str, range1: Tuple[float, float], y_axis_log: bool = False, save_fig: bool = True):
def plot_mass2(
    xgb_mass: pd.Series,
    sim_mass: pd.Series,
    particles_title: str,
    range1: Tuple[float, float],
    y_axis_log: bool = False,
    save_fig: bool = True,
):
    """
    Overlay the mass^2 histograms of XGB-selected and MC-true (simulated) particles.

    Both histograms use 300 bins over range1; the selected sample is drawn in red
    and the simulated one in blue.

    Args:
        xgb_mass (pd.Series): mass^2 of the XGB-selected particles.
        sim_mass (pd.Series): mass^2 of the MC-true particles.
        particles_title (str): Particle name, used in the title, legend and file name.
        range1 (tuple[float, float]): mass^2 range shown on the x-axis.
        y_axis_log (bool, optional): Use a logarithmic y-axis. Defaults to False.
        save_fig (bool, optional): Save as "mass2_{particles_title}.png/.pdf"
            instead of showing. Defaults to True.
    """
    fig, axs = plt.subplots(figsize=(15, 10), dpi=300)
    hist_style = dict(bins=300, alpha=0.3, range=range1)
    axs.hist(xgb_mass, facecolor="red", **hist_style)
    axs.hist(sim_mass, facecolor="blue", **hist_style)
    axs.legend(
        ("XGBoost selected " + particles_title, "all simulated " + particles_title),
        loc="upper right",
    )
    if y_axis_log:
        axs.set_yscale("log")
    axs.set_xlabel(r"$m^2$ $(GeV/c^2)^2$", fontsize=20, loc="right")
    axs.set_ylabel(r"Counts", fontsize=20, loc="top")
    axs.set_title(f"{particles_title} $mass^2$ histogram", fontsize=20)
    axs.grid()
    axs.tick_params(axis="both", which="major", labelsize=18)
    if not save_fig:
        plt.show()
        return
    plt.savefig(f"mass2_{particles_title}.png")
    plt.savefig(f"mass2_{particles_title}.pdf")
    plt.close()

Plots mass^2

Args: xgb_mass (pd.Series): pd.Series containing the XGB-selected mass^2. sim_mass (pd.Series): pd.Series containing the MC-true mass^2. particles_title (str): Name of the plot. range1 (tuple[float, float]): Range of the mass2 to be plotted on the x-axis. y_axis_log (bool, optional): Whether to use a log scale on the y-axis. Defaults to False. save_fig (bool, optional): Whether to save the plot. Defaults to True.

def plot_all_particles_mass2( xgb_selected: pandas.core.series.Series, mass2_variable_name: str, pid_variable_name: str, particles_title: str, range1: Tuple[float, float], y_axis_log: bool = False, save_fig: bool = True):
def plot_all_particles_mass2(
    xgb_selected: pd.DataFrame,
    mass2_variable_name: str,
    pid_variable_name: str,
    particles_title: str,
    range1: Tuple[float, float],
    y_axis_log: bool = False,
    save_fig: bool = True,
):
    """
    Plots the mass^2 of XGB-selected particles split by their MC-true particle type.

    True pids 0, 1 and 2 are drawn as protons (blue), kaons (orange) and pions (green).

    Args:
        xgb_selected (pd.DataFrame): XGB-selected particles; must contain the
            pid_variable_name and mass2_variable_name columns.
        mass2_variable_name (str): Name of the mass2 column.
        pid_variable_name (str): Name of the MC-true pid column.
        particles_title (str): Name used in the title and in the file name.
        range1 (tuple[float, float]): Range of the x-axis.
        y_axis_log (bool, optional): Whether to use a log scale on the y-axis.
            Defaults to False.
        save_fig (bool, optional): Whether to save the plot instead of showing it.
            Defaults to True.
    """
    fig, axs = plt.subplots(figsize=(15, 10), dpi=300)
    true_pid = xgb_selected[pid_variable_name]
    # (MC-true pid, colour, legend entry), drawn in this order so the legend matches
    species = (
        (0, "blue", "XGBoost selected true protons"),
        (1, "orange", "XGBoost selected true kaons"),
        (2, "green", "XGBoost selected true pions"),
    )
    for pid, colour, _ in species:
        axs.hist(
            xgb_selected.loc[true_pid == pid, mass2_variable_name],
            bins=300,
            facecolor=colour,
            alpha=0.4,
            range=range1,
        )
    axs.legend([label for _, _, label in species], loc="upper right")
    if y_axis_log:
        axs.set_yscale("log")
    # both axis labels use the rc default size (previously a stray fontsize=15 leaked
    # into the y-label only)
    axs.set_xlabel(r"$m^2$ $(GeV/c^2)^2$", loc="right")
    axs.set_ylabel(r"Counts", loc="top")
    axs.set_title(
        f"ALL XGBoost selected (true and false positive) {particles_title} $mass^2$ histogram"
    )
    axs.grid()
    axs.tick_params(axis="both", which="major", labelsize=18)
    if save_fig:
        fig.savefig(f"mass2_all_selected_{particles_title}.png")
        fig.savefig(f"mass2_all_selected_{particles_title}.pdf")
        plt.close(fig)
    else:
        plt.show()

Plots mc-true particle type in xgb_selected particles

Args: xgb_selected (pd.Series): pd.Series with XGB-selected particles. mass2_variable_name (str): Name of the mass2 variable. pid_variable_name (str): Name of the pid variable. particles_title (str): Name of the plot. range1 (tuple[float, float]): Range of the x-axis. y_axis_log (bool, optional): Whether to use a log scale on the y-axis. Defaults to False. save_fig (bool, optional): Whether to save the plot. Defaults to True.

def plot_eff_pT_rap( df: pandas.core.frame.DataFrame, pid: int, pid_var_name: str = 'Complex_pid', rapidity_var_name: str = 'Complex_rapidity', pT_var_name: str = 'Complex_pT', ranges: Tuple[Tuple[float, float], Tuple[float, float]] = [[0, 5], [0, 3]], nbins: int = 50, save_fig: bool = True, particle_names: List[str] = ['protons', 'kaons', 'pions', 'bckgr']):
def plot_eff_pT_rap(
    df: pd.DataFrame,
    pid: int,
    pid_var_name: str = "Complex_pid",
    rapidity_var_name: str = "Complex_rapidity",
    pT_var_name: str = "Complex_pT",
    ranges: Tuple[Tuple[float, float], Tuple[float, float]] = [[0, 5], [0, 3]],
    nbins: int = 50,
    save_fig: bool = True,
    particle_names: List[str] = ["protons", "kaons", "pions", "bckgr"],
):
    """
    Plots the pT-rapidity efficiency map (selected / simulated) for one particle type.

    The numerator counts every particle with xgb_preds == pid (true and false
    positives), the denominator every particle with MC-true pid == pid. Bins with no
    simulated particles are undefined and drawn white; the colour scale is clipped
    to [0, 1].

    Args:
        df (pd.DataFrame): Dataframe with the MC-true pid, "xgb_preds", rapidity and pT.
        pid (int): Pid of the particles to be plotted (also indexes particle_names).
        pid_var_name (str, optional): Name of the pid variable. Defaults to "Complex_pid".
        rapidity_var_name (str, optional): Name of the rapidity variable.
            Defaults to "Complex_rapidity".
        pT_var_name (str, optional): Name of the pT variable. Defaults to "Complex_pT".
        ranges (Tuple[Tuple[float, float], Tuple[float, float]], optional):
            [[rapidity_min, rapidity_max], [pT_min, pT_max]]. Defaults to [[0, 5], [0, 3]].
        nbins (int, optional): Number of bins in each axis. Defaults to 50.
        save_fig (bool, optional): Whether to save the figure instead of showing it.
            Defaults to True.
        particle_names (List[str], optional): Names of the particles corresponding to pid.
            Defaults to ["protons", "kaons", "pions", "bckgr"].
    """
    df_true = df[df[pid_var_name] == pid]  # simulated
    df_reco = df[df["xgb_preds"] == pid]  # selected by xgboost

    # histogram2d puts its first argument (rapidity) along axis 0 of the result
    true, rap_edges, pt_edges = np.histogram2d(
        np.asarray(df_true[rapidity_var_name]),
        np.asarray(df_true[pT_var_name]),
        bins=nbins,
        range=ranges,
    )
    reco, _, _ = np.histogram2d(
        np.asarray(df_reco[rapidity_var_name]),
        np.asarray(df_reco[pT_var_name]),
        bins=(rap_edges, pt_edges),
    )

    # efficiency = selected / simulated; NaN (white) where nothing was simulated
    eff = np.divide(reco, true, out=np.full_like(true, np.nan), where=true != 0)

    fig = plt.figure(figsize=(8, 10), dpi=300)
    plt.title(f"$p_T$-rapidity efficiency for all selected {particle_names[pid]}")
    # transpose so rapidity runs along x and pT along y, matching the axis labels;
    # aspect="auto" because the two axes have different units
    img = plt.imshow(
        eff.T,
        interpolation="nearest",
        origin="lower",
        vmin=0,
        vmax=1,
        extent=[rap_edges[0], rap_edges[-1], pt_edges[0], pt_edges[-1]],
        aspect="auto",
    )

    cbar = fig.colorbar(img, fraction=0.025, pad=0.08)
    cbar.set_label("efficiency (selected/simulated)", rotation=270, labelpad=20)

    plt.xlabel("rapidity")
    plt.ylabel("$p_T$ (GeV/c)")
    plt.tight_layout()
    if save_fig:
        fig.savefig(f"eff_pT_rap_{particle_names[pid]}.png")
        fig.savefig(f"eff_pT_rap_{particle_names[pid]}.pdf")
        plt.close(fig)
    else:
        plt.show()
def plot_pt_rapidity( df: pandas.core.frame.DataFrame, pid: int, pid_var_name: str = 'Complex_pid', rapidity_var_name: str = 'Complex_rapidity', pT_var_name: str = 'Complex_pT', ranges: Tuple[Tuple[float, float], Tuple[float, float]] = [[0, 5], [0, 3]], nbins=50, save_fig: bool = True, particle_names: List[str] = ['protons', 'kaons', 'pions', 'bckgr']):
def plot_pt_rapidity(
    df: pd.DataFrame,
    pid: int,
    pid_var_name: str = "Complex_pid",
    rapidity_var_name: str = "Complex_rapidity",
    pT_var_name: str = "Complex_pT",
    ranges: Tuple[Tuple[float, float], Tuple[float, float]] = [[0, 5], [0, 3]],
    nbins=50,
    save_fig: bool = True,
    particle_names: List[str] = ["protons", "kaons", "pions", "bckgr"],
):
    """
    Plots pt-rapidity 2D histogram of the MC-true (simulated) particles of one type.

    Args:
        df (pd.DataFrame): Dataframe with input data.
        pid (int): Pid of the particles to be plotted (also indexes particle_names).
        pid_var_name (str, optional): Name of the pid variable. Defaults to "Complex_pid".
        rapidity_var_name (str, optional): Name of the rapidity variable.
            Defaults to "Complex_rapidity".
        pT_var_name (str, optional): Name of the pT variable. Defaults to "Complex_pT".
        ranges (Tuple[Tuple[float, float], Tuple[float, float]], optional):
            [[rapidity_min, rapidity_max], [pT_min, pT_max]]. Defaults to [[0, 5], [0, 3]].
        nbins (int, optional): Number of bins in each axis. Defaults to 50.
        save_fig (bool, optional): Whether to save the figure instead of showing it.
            Defaults to True.
        particle_names (List[str], optional): Names of the particles corresponding to pid.
            Defaults to ["protons", "kaons", "pions", "bckgr"].
    """
    df_true = df[df[pid_var_name] == pid]  # simulated

    # histogram2d puts its first argument (rapidity) along axis 0 of the result
    counts, rap_edges, pt_edges = np.histogram2d(
        np.asarray(df_true[rapidity_var_name]),
        np.asarray(df_true[pT_var_name]),
        bins=nbins,
        range=ranges,
    )
    counts[counts == 0] = np.nan  # show empty bins as white

    fig = plt.figure(figsize=(8, 10), dpi=300)
    plt.title(f"$p_T$-rapidity graph for all simulated {particle_names[pid]}")
    # transpose so rapidity runs along x and pT along y, matching the axis labels;
    # aspect="auto" because the two axes have different units
    img = plt.imshow(
        counts.T,
        interpolation="nearest",
        origin="lower",
        extent=[rap_edges[0], rap_edges[-1], pt_edges[0], pt_edges[-1]],
        aspect="auto",
    )

    cbar = fig.colorbar(img, fraction=0.025, pad=0.08)
    cbar.set_label("counts", rotation=270, labelpad=20)

    plt.xlabel("rapidity")
    plt.ylabel("$p_T$ (GeV/c)")
    plt.tight_layout()
    if save_fig:
        fig.savefig(f"plot_pt_rapidity_{particle_names[pid]}.png")
        fig.savefig(f"plot_pt_rapidity_{particle_names[pid]}.pdf")
        plt.close(fig)
    else:
        plt.show()

Plots pt-rapidity 2D histogram.

Args: df (pd.DataFrame): Dataframe with input data. pid (int): Pid of the particles to be plotted. pid_var_name (str, optional): Name of the pid variable. Defaults to "Complex_pid". rapidity_var_name (str, optional): Name of the rapidity variable. Defaults to "Complex_rapidity". pT_var_name (str, optional): Name of the pT variable. Defaults to "Complex_pT". ranges (Tuple[Tuple[float, float], Tuple[float, float]], optional): Ranges of the plot. Defaults to [[0, 5], [0, 3]]. nbins (int, optional): Number of bins in each axis. Defaults to 50. save_fig (bool, optional): Whether to save the figure. Defaults to True. particle_names (List[str], optional): Names of the particles corresponding to pid. Defaults to ["protons", "kaons", "pions", "bckgr"].

def plot_shap_summary( x_train: pandas.core.frame.DataFrame, y_train: pandas.core.frame.DataFrame, model_hdl: hipe4ml.model_handler.ModelHandler, features_names: List[str], n_workers: int = 1, save_fig: bool = True, approximate: bool = False, n_samples: int = 50000, particle_names: List[str] = ['protons', 'kaons', 'pions']):
def plot_shap_summary(
    x_train: pd.DataFrame,
    y_train: pd.DataFrame,
    model_hdl: ModelHandler,
    features_names: List[str],
    n_workers: int = 1,
    save_fig: bool = True,
    approximate: bool = False,
    n_samples: int = 50000,
    particle_names: List[str] = ["protons", "kaons", "pions"],
):
    """
    Method for plotting shap plots.

    At most n_samples rows per true class are drawn (without replacement) from the
    training data, SHAP values are computed with fasttreeshap, and a summary and an
    interaction plot are produced for every class.

    Args:
        x_train (pd.DataFrame): X training dataset.
        y_train (pd.DataFrame): Training labels, one per row of x_train, in the same
            order (array, Series or single-column DataFrame).
        model_hdl (ModelHandler): Model Handler to be explained.
        features_names (List[str]): List of the training variables.
        n_workers (int, optional): Number of threads for multithreading.
            Note: it uses the fasttreeshap library, not shap. Defaults to 1.
        save_fig (bool, optional): Whether to save the plots. Defaults to True.
        approximate (bool, optional): Whether to compute approximate SHAP values.
            Defaults to False.
        n_samples (int, optional): Maximal number of samples in each class.
            Defaults to 50000.
        particle_names (List[str], optional): List of the classified particle names.
            Defaults to ["protons", "kaons", "pions"].

    Raises:
        ValueError: If y_train and x_train have a different number of rows.
    """
    print("Creating shap plots...")
    explainer = shap.TreeExplainer(
        model_hdl.get_original_model(), n_jobs=n_workers, approximate=approximate
    )
    # Attach labels positionally. Concatenating on axis=1 aligns by index, which
    # misplaces labels (or yields NaN) when x_train has a non-default index.
    labels = np.asarray(y_train).ravel()
    if len(labels) != len(x_train):
        raise ValueError(
            f"y_train has {len(labels)} labels but x_train has {len(x_train)} rows"
        )
    merged_df = x_train.assign(true_class=labels)
    # Apply n_samples in each class
    resampled_df = pd.concat(
        [
            resample(group, n_samples=min(n_samples, len(group)), replace=False)
            for _, group in merged_df.groupby("true_class")
        ]
    )

    # Split the resampled pd.DataFrame back into input data and label data
    x_train_resampled = resampled_df.iloc[:, :-1]
    y_train_resampled = resampled_df.iloc[:, -1].to_numpy()
    del merged_df, resampled_df
    gc.collect()

    shap_values = explainer.shap_values(
        x_train_resampled, y_train_resampled, check_additivity=False
    )
    # one array of SHAP values per class
    for i, class_shap_values in enumerate(shap_values):
        _shap_summary(
            class_shap_values,
            x_train_resampled,
            features_names,
            particle_names[i],
            save_fig=save_fig,
        )
        _shap_interaction(
            class_shap_values,
            x_train_resampled,
            features_names,
            particle_names[i],
            save_fig=save_fig,
        )

Method for plotting shap plots

Args: x_train (pd.DataFrame): pd.DataFrame with the X training dataset. y_train (pd.DataFrame): Labels of the X training dataset. model_hdl (ModelHandler): Model Handler to be explained. features_names (List[str]): List of the training variables. n_workers (int, optional): Number of threads for multithreading. Note: it uses the fasttreeshap library, not shap. Defaults to 1. save_fig (bool, optional): Whether to save the plots. Defaults to True. approximate (bool, optional): Whether to compute approximate SHAP values. Defaults to False. n_samples (int, optional): Maximal number of samples in each class. Defaults to 50000. particle_names (List[str], optional): List of the classified particle names. Defaults to ["protons", "kaons", "pions"].

def plot_efficiency_purity( probas: numpy.ndarray, efficiencies: List[List[float]], purities: List[List[float]], save_fig: bool = True, particle_names: List[str] = ['protons', 'kaons', 'pions']):
def plot_efficiency_purity(
    probas: np.ndarray,
    efficiencies: List[List[float]],
    purities: List[List[float]],
    save_fig: bool = True,
    particle_names: List[str] = ["protons", "kaons", "pions"],
) -> None:
    """
    Plots efficiency and purity as a function of probability cuts.

    One figure is produced per class; the i-th entries of ``efficiencies``,
    ``purities`` and ``particle_names`` describe the same class.

    Args:
        probas (np.ndarray): Probability cuts (x-axis values).
        efficiencies (List[List[float]]): List of lists of efficiencies, one per class.
        purities (List[List[float]]): List of lists of purities, one per class.
        save_fig (bool, optional): Whether to save each figure as
            ``efficiency_purity_id_<particle>.png`` and ``.pdf`` instead of
            showing it. Defaults to True.
        particle_names (List[str], optional): List of the particle names, used in
            titles and file names. Defaults to ["protons", "kaons", "pions"].

    Raises:
        IndexError: If ``particle_names`` has fewer entries than ``efficiencies``.
    """
    # Saved figures get a higher resolution than interactively shown ones.
    dpi = 300 if save_fig else 100
    for i, (eff, pur) in enumerate(zip(efficiencies, purities)):
        particle_name = particle_names[i]
        fig, ax = plt.subplots(figsize=(10, 7), dpi=dpi)
        ax.plot(probas, eff, label="efficiency")
        ax.plot(probas, pur, label="purity")
        ax.set_xlabel("BDT cut")
        # Raw string: "\%" is not a valid Python escape; presumably it is meant
        # as a LaTeX-escaped percent sign for usetex rendering.
        ax.set_ylabel(r"\% ")
        ax.legend(loc="upper right")
        ax.set_title(
            f"Efficiency and purity in function of BDT cut for {particle_name}"
        )
        ax.grid(which="major", linestyle="-")
        ax.minorticks_on()
        ax.grid(which="minor", linestyle="--")
        if save_fig:
            # PNG and PDF share one base name so the two outputs pair up.
            fig.savefig(f"efficiency_purity_id_{particle_name}.png")
            fig.savefig(f"efficiency_purity_id_{particle_name}.pdf")
            # Close this specific figure to free memory in long loops.
            plt.close(fig)
        else:
            plt.show()

Plots efficiency and purity as a function of probability cuts.

Args: probas (np.ndarray): Probability cuts. efficiencies (List[List[float]]): List of lists of efficiencies, one per class. purities (List[List[float]]): List of lists of purities, one per class. save_fig (bool, optional): Whether to save the figures. Defaults to True. particle_names (List[str], optional): List of the particle names. Defaults to ["protons", "kaons", "pions"].

def plot_before_after_variables( df: pandas.core.frame.DataFrame, pid: float, pid_variable_name: str, training_variables: List[str], save_fig: bool = True, log_yscale: bool = True):
def plot_before_after_variables(
    df: pd.DataFrame,
    pid: float,
    pid_variable_name: str,
    training_variables: List[str],
    save_fig: bool = True,
    log_yscale: bool = True,
) -> None:
    """
    Plots each variable before and after selection.
    Legacy: var_distributions_plot should rather be used.

    For every variable, overlays the distribution of entries whose true
    (simulated) id equals ``pid`` with the distribution of entries that the
    classifier assigned to ``pid`` (column ``"xgb_preds"``).

    Args:
        df (pd.DataFrame): Dataframe containing the ``pid_variable_name`` column,
            the ``"xgb_preds"`` column and every column in ``training_variables``.
        pid (float): Particle id to select in both the true and predicted columns.
        pid_variable_name (str): Name of the column with the true (simulated) id.
        training_variables (List[str]): Columns to plot, one figure each.
        save_fig (bool, optional): Whether to save each figure as
            ``<variable>_before_after_pid_<pid>.png`` and ``.pdf`` instead of
            showing it. Defaults to True.
        log_yscale (bool, optional): Whether to use a logarithmic y-axis.
            Defaults to True.
    """
    df_true = df[df[pid_variable_name] == pid]  # simulated
    df_reco = df[df["xgb_preds"] == pid]  # reconstructed by xgboost

    def variable_plot(
        df_true,
        df_reco,
        variable_name: str,
        log_yscale: bool = True,
        leg1="Simulated",
        leg2="XGB-selected",
        bins=100,
    ):
        """Draws both step histograms of ``variable_name`` on one new figure."""
        fig, ax = plt.subplots(figsize=(15, 15), dpi=300)
        # histtype="step" with fill=False draws only the outline, so the colour
        # must be passed via `color` (a facecolor would never be drawn).
        ax.hist(
            df_true[variable_name],
            bins=bins,
            color="blue",
            alpha=0.6,
            histtype="step",
            fill=False,
            linewidth=2,
            label=leg1,
        )
        ax.hist(
            df_reco[variable_name],
            bins=bins,
            color="red",
            alpha=0.7,
            histtype="step",
            fill=False,
            linewidth=2,
            label=leg2,
        )
        ax.grid()
        ax.set_xlabel(variable_name, fontsize=15, loc="right")
        ax.set_ylim(bottom=1)
        ax.set_ylabel("counts", fontsize=15)
        if log_yscale:
            ax.set_yscale("log")
        ax.legend(fontsize=15, loc="upper right")
        ax.set_title(
            f"{variable_name} before and after XGB selection for pid={pid}", fontsize=15
        )
        return fig

    for training_variable in training_variables:
        fig = variable_plot(df_true, df_reco, training_variable, log_yscale)
        if save_fig:
            # PNG and PDF share one base name so the two outputs pair up.
            fig.savefig(f"{training_variable}_before_after_pid_{pid}.png")
            fig.savefig(f"{training_variable}_before_after_pid_{pid}.pdf")
            # Close this specific figure to free memory across many variables.
            plt.close(fig)
        else:
            # pyplot.show (not Figure.show) so the figure is actually displayed
            # in scripts, consistent with the other plotting helpers.
            plt.show()

Plots each variable before and after selection. Legacy: use var_distributions_plot instead.

Args: df (pd.DataFrame): Dataframe with the true-id column, the "xgb_preds" column and the training variables. pid (float): Particle id to select. pid_variable_name (str): Name of the column with the true (simulated) particle id. training_variables (List[str]): Variables to plot, one figure each. save_fig (bool, optional): Whether to save the figures instead of showing them. Defaults to True. log_yscale (bool, optional): Whether to use a logarithmic y-axis. Defaults to True.

Returns: None.