ml_pid_cbm.gauss.validate_gauss

  1import argparse
  2import os
  3import sys
  4from collections import defaultdict
  5from shutil import copy2
  6from typing import List, Tuple
  7
  8import numpy as np
  9from sklearn.metrics import confusion_matrix
 10
 11from ml_pid_cbm.tools import json_tools, plotting_tools
 12from ml_pid_cbm.tools.load_data import LoadData
 13from ml_pid_cbm.tools.particles_id import ParticlesId as Pid
 14from ml_pid_cbm.validate_model import ValidateModel
 15
 16
 17class ValidateGauss(ValidateModel):
 18    """
 19    Class for testing the ml model
 20    """
 21
 22    def evaluate_probas(
 23        self,
 24        start: float = 0.35,
 25        stop: float = 1,
 26        n_steps: int = 40,
 27        purity_cut: float = 0.0,
 28        save_fig: bool = True,
 29    ) -> Tuple[float, float, float]:
 30        """Method for evaluating probability (BDT) cut effect on efficency and purity.
 31
 32        Args:
 33            start (float, optional): Lower range of probablity cuts. Defaults to 0.3.
 34            stop (float, optional): Upper range of probablity cuts. Defaults to 0.98.
 35            n_steps (int, optional): Number of probability cuts to try. Defaults to 30.
 36            pid_variable_name (str, optional): Name of the variable containing true Pid. Defaults to "Complex_pid".
 37            purity_cut (float, optional): Minimal purity for automatic cuts selection. Defaults to 0..
 38            save_fig (bool, optional): Saves figures (BDT cut vs efficiency and purity) to file if True. Defaults to True.
 39
 40        Returns:
 41            Tuple[float, float, float]: Probability cut for each variable.
 42        """
 43        print(
 44            f"Checking efficiency and purity for {int(n_steps)} probablity cuts between {start}, and {stop}..."
 45        )
 46        probas = np.linspace(start, stop, n_steps)
 47        efficienciess_protons, efficiencies_kaons, efficiencies_pions = [], [], []
 48        efficiencies = [efficienciess_protons, efficiencies_kaons, efficiencies_pions]
 49        purities_protons, purities_kaons, purities_pions = [], [], []
 50        purities = [purities_protons, purities_kaons, purities_pions]
 51        best_cuts = [0.0, 0.0, 0.0]
 52        max_efficiencies = [0.0, 0.0, 0.0]
 53        max_purities = [0.0, 0.0, 0.0]
 54
 55        for proba in probas:
 56            self.gauss_preds(proba, proba, proba)
 57            # confusion matrix
 58            cnf_matrix = confusion_matrix(
 59                self.particles_df[self.pid_variable_name],
 60                self.particles_df["Complex_gauss_preds"],
 61            )
 62            for pid in range(self.get_n_classes() - 1):
 63                efficiency, purity = self.efficiency_stats(
 64                    cnf_matrix, pid, print_output=False
 65                )
 66                efficiencies[pid].append(efficiency)
 67                purities[pid].append(purity)
 68                if purity_cut > 0.0:
 69                    # Minimal purity for automatic threshold selection.
 70                    # Will choose the highest efficiency for purity above this value.
 71                    if purity >= purity_cut:
 72                        if efficiency > max_efficiencies[pid]:
 73                            best_cuts[pid] = proba
 74                            max_efficiencies[pid] = efficiency
 75                            max_purities[pid] = purity
 76                    # If max purity is below this value, will choose the highest purity available.
 77                    else:
 78                        if purity > max_purities[pid]:
 79                            best_cuts[pid] = proba
 80                            max_efficiencies[pid] = efficiency
 81                            max_purities[pid] = purity
 82
 83        plotting_tools.plot_efficiency_purity(probas, efficiencies, purities, save_fig)
 84        if save_fig:
 85            print("Plots ready!")
 86        if purity_cut > 0:
 87            print(f"Selected probaility cuts: {best_cuts}")
 88            return (best_cuts[0], best_cuts[1], best_cuts[2])
 89        else:
 90            return (-1.0, -1.0, -1.0)
 91
 92    def gauss_preds(self, proba_proton: float, proba_kaon: float, proba_pion: float):
 93        """Gets particle type as selected by xgboost model if above probability threshold.
 94
 95        Args:
 96            proba_proton (float): Probablity threshold to classify particle as proton.
 97            proba_kaon (float): Probablity threshold to classify particle as kaon.
 98            proba_pion (float): Probablity threshold to classify particle as pion.
 99        """
100        df = self.particles_df
101        df["Complex_gauss_preds"] = df["Complex_gauss_pid"]
102
103        # setting to bckgr if smaller than probability threshold
104        proton = (df["Complex_gauss_pid"] == 0) & (df["Complex_prob_p"] > proba_proton)
105        pion = (df["Complex_gauss_pid"] == 1) & (df["Complex_prob_K"] > proba_kaon)
106        kaon = (df["Complex_gauss_pid"] == 2) & (df["Complex_prob_pi"] > proba_pion)
107        df.loc[~(proton | pion | kaon), "Complex_gauss_preds"] = 3
108
109        self.particles_df = df
110
111    def confusion_matrix_and_stats(
112        self, efficiency_filename: str = "efficiency_stats.txt"
113    ):
114        """
115        Generates confusion matrix and efficiency/purity stats.
116        """
117        cnf_matrix = confusion_matrix(
118            self.particles_df[self.pid_variable_name],
119            self.particles_df["Complex_gauss_preds"],
120        )
121        plotting_tools.plot_confusion_matrix(cnf_matrix)
122        plotting_tools.plot_confusion_matrix(cnf_matrix, normalize=True)
123        txt_file = open(efficiency_filename, "w+")
124        for pid in range(self.get_n_classes() - 1):
125            self.efficiency_stats(cnf_matrix, pid, txt_file)
126        txt_file.close()
127
128    def remap_gauss_names(self):
129        """
130        Remaps Pid of particles to output format from XGBoost Model.
131        Protons: 0; Kaons: 1; Pions, Electrons, Muons: 2; Other: 3
132
133        """
134        df = self.particles_df
135        if self.anti_particles:
136            df["Complex_gauss_pid"] = (
137                df["Complex_gauss_pid"]
138                .map(
139                    defaultdict(
140                        lambda: 3.0,
141                        {
142                            Pid.ANTI_PROTON.value: 0.0,
143                            Pid.NEG_KAON.value: 1.0,
144                            Pid.NEG_PION.value: 2.0,
145                            Pid.ELECTRON.value: 2.0,
146                            Pid.NEG_MUON.value: 2.0,
147                        },
148                    ),
149                    na_action="ignore",
150                )
151                .astype(float)
152            )
153        else:
154            df["Complex_gauss_pid"] = (
155                df["Complex_gauss_pid"]
156                .map(
157                    defaultdict(
158                        lambda: 3.0,
159                        {
160                            Pid.PROTON.value: 0.0,
161                            Pid.POS_KAON.value: 1.0,
162                            Pid.POS_PION.value: 2.0,
163                            Pid.POSITRON.value: 2.0,
164                            Pid.POS_MUON.value: 2.0,
165                        },
166                    ),
167                    na_action="ignore",
168                )
169                .astype(float)
170            )
171        self.particles_df = df
172
173    def _tof_plots(self):
174        """
175        Generates tof plots.
176        """
177        for pid, particle_name in enumerate(self.classes_names):
178            # simulated:
179            try:
180                plotting_tools.tof_plot(
181                    self.particles_df[self.particles_df[self.pid_variable_name] == pid],
182                    self.json_file_name,
183                    f"{particle_name} (all simulated)",
184                )
185            except ValueError:
186                print(f"No simulated {particle_name}s")
187            # xgb selected
188            try:
189                plotting_tools.tof_plot(
190                    self.particles_df[self.particles_df["Complex_gauss_preds"] == pid],
191                    self.json_file_name,
192                    f"{particle_name} (Gauss-selected)",
193                )
194            except ValueError:
195                print(f"No Gauss-selected {particle_name}s")
196
197    def _mass2_plots(self):
198        """
199        Generates mass2 plots.
200        """
201        protons_range = (-0.2, 1.8)
202        kaons_range = (-0.2, 0.6)
203        pions_range = (-0.3, 0.3)
204        ranges = [protons_range, kaons_range, pions_range, pions_range]
205        for pid, particle_name in enumerate(self.classes_names):
206            plotting_tools.plot_mass2(
207                self.particles_df[self.particles_df["Complex_gauss_preds"] == pid][
208                    self.mass2_variable_name
209                ],
210                self.particles_df[self.particles_df[self.pid_variable_name] == pid][
211                    self.mass2_variable_name
212                ],
213                particle_name,
214                ranges[pid],
215            )
216            plotting_tools.plot_all_particles_mass2(
217                self.particles_df[self.particles_df["Complex_gauss_preds"] == pid],
218                self.mass2_variable_name,
219                self.pid_variable_name,
220                particle_name,
221                ranges[pid],
222            )
223
224    def _vars_distributions_plots(self):
225        """
226        Generates distributions of variables and pT-rapidity graphs.
227        """
228        vars_to_draw = json_tools.load_vars_to_draw(self.json_file_name)
229        for pid, particle_name in enumerate(self.classes_names):
230            plotting_tools.var_distributions_plot(
231                vars_to_draw,
232                [
233                    self.particles_df[
234                        (self.particles_df[self.pid_variable_name] == pid)
235                    ],
236                    self.particles_df[
237                        (
238                            (self.particles_df[self.pid_variable_name] == pid)
239                            & (self.particles_df["Complex_gauss_preds"] == pid)
240                        )
241                    ],
242                    self.particles_df[
243                        (
244                            (self.particles_df[self.pid_variable_name] != pid)
245                            & (self.particles_df["Complex_gauss_preds"] == pid)
246                        )
247                    ],
248                ],
249                [
250                    f"true MC {particle_name}",
251                    f"true selected {particle_name}",
252                    f"false selected {particle_name}",
253                ],
254                filename=f"vars_dist_{particle_name}",
255            )
256
257
def parse_args(args: List[str]) -> argparse.Namespace:
    """
    Arguments parser for the main method.

    Args:
        args (List[str]): Arguments from the command line, should be sys.argv[1:].

    Returns:
        argparse.Namespace: argparse.Namespace containing args
    """
    parser = argparse.ArgumentParser(
        prog="ML_PID_CBM ValidateGauss",
        description="Program for validating Gaussian PID model",
    )
    parser.add_argument(
        "--config",
        "-c",
        nargs=1,
        required=True,
        type=str,
        help="Filename or path of config json file.",
    )
    parser.add_argument(
        "--nworkers",
        "-n",
        type=int,
        default=1,
        help="Max number of workers for ThreadPoolExecutor which reads Root tree with data.",
    )
    parser.add_argument(
        "--momentum",
        "-p",
        nargs=2,
        required=True,
        type=float,
        help="Lower and upper momentum limit, e.g., 1 3",
    )
    return parser.parse_args(args)
296
297
if __name__ == "__main__":
    # parser for main class
    args = parse_args(sys.argv[1:])
    # config arguments to be loaded from args
    json_file_name = args.config[0]
    n_workers = args.nworkers
    lower_p, upper_p = args.momentum
    # this script validates only the non-anti particle selection
    is_anti = False
    # loading test data
    data_file_name = json_tools.load_file_name(json_file_name, "test")
    loader = LoadData(data_file_name, json_file_name, lower_p, upper_p, is_anti)
    print(f"\nLoading data from {data_file_name}\n in ranges p = {lower_p}, {upper_p}")
    # results go to a per-momentum-range folder, together with a copy of the config
    json_file_path = os.path.join(os.getcwd(), json_file_name)
    folder_name = f"gauss_{lower_p}_{upper_p}"
    os.makedirs(folder_name, exist_ok=True)
    os.chdir(folder_name)
    copy2(json_file_path, os.getcwd())
    test_particles = loader.load_tree(max_workers=n_workers)
    # validate model object
    validate = ValidateGauss(
        lower_p, upper_p, is_anti, json_file_name, test_particles.get_data_frame()
    )
    # remap Pid to match output XGBoost format
    validate.remap_names()
    validate.remap_gauss_names()
    # NOTE(review): 90 presumes purity is reported in percent by efficiency_stats — confirm
    proba_proton, proba_kaon, proba_pion = validate.evaluate_probas(purity_cut=90)
    # apply the selected cuts, then produce stats and graphs
    validate.gauss_preds(proba_proton, proba_kaon, proba_pion)
    validate.confusion_matrix_and_stats()
    print("Generating plots...")
    validate.generate_plots()
    # save validated dataset
    validate.save_df()
class ValidateGauss(ml_pid_cbm.validate_model.ValidateModel):
 18class ValidateGauss(ValidateModel):
 19    """
 20    Class for testing the ml model
 21    """
 22
 23    def evaluate_probas(
 24        self,
 25        start: float = 0.35,
 26        stop: float = 1,
 27        n_steps: int = 40,
 28        purity_cut: float = 0.0,
 29        save_fig: bool = True,
 30    ) -> Tuple[float, float, float]:
 31        """Method for evaluating probability (BDT) cut effect on efficency and purity.
 32
 33        Args:
 34            start (float, optional): Lower range of probablity cuts. Defaults to 0.3.
 35            stop (float, optional): Upper range of probablity cuts. Defaults to 0.98.
 36            n_steps (int, optional): Number of probability cuts to try. Defaults to 30.
 37            pid_variable_name (str, optional): Name of the variable containing true Pid. Defaults to "Complex_pid".
 38            purity_cut (float, optional): Minimal purity for automatic cuts selection. Defaults to 0..
 39            save_fig (bool, optional): Saves figures (BDT cut vs efficiency and purity) to file if True. Defaults to True.
 40
 41        Returns:
 42            Tuple[float, float, float]: Probability cut for each variable.
 43        """
 44        print(
 45            f"Checking efficiency and purity for {int(n_steps)} probablity cuts between {start}, and {stop}..."
 46        )
 47        probas = np.linspace(start, stop, n_steps)
 48        efficienciess_protons, efficiencies_kaons, efficiencies_pions = [], [], []
 49        efficiencies = [efficienciess_protons, efficiencies_kaons, efficiencies_pions]
 50        purities_protons, purities_kaons, purities_pions = [], [], []
 51        purities = [purities_protons, purities_kaons, purities_pions]
 52        best_cuts = [0.0, 0.0, 0.0]
 53        max_efficiencies = [0.0, 0.0, 0.0]
 54        max_purities = [0.0, 0.0, 0.0]
 55
 56        for proba in probas:
 57            self.gauss_preds(proba, proba, proba)
 58            # confusion matrix
 59            cnf_matrix = confusion_matrix(
 60                self.particles_df[self.pid_variable_name],
 61                self.particles_df["Complex_gauss_preds"],
 62            )
 63            for pid in range(self.get_n_classes() - 1):
 64                efficiency, purity = self.efficiency_stats(
 65                    cnf_matrix, pid, print_output=False
 66                )
 67                efficiencies[pid].append(efficiency)
 68                purities[pid].append(purity)
 69                if purity_cut > 0.0:
 70                    # Minimal purity for automatic threshold selection.
 71                    # Will choose the highest efficiency for purity above this value.
 72                    if purity >= purity_cut:
 73                        if efficiency > max_efficiencies[pid]:
 74                            best_cuts[pid] = proba
 75                            max_efficiencies[pid] = efficiency
 76                            max_purities[pid] = purity
 77                    # If max purity is below this value, will choose the highest purity available.
 78                    else:
 79                        if purity > max_purities[pid]:
 80                            best_cuts[pid] = proba
 81                            max_efficiencies[pid] = efficiency
 82                            max_purities[pid] = purity
 83
 84        plotting_tools.plot_efficiency_purity(probas, efficiencies, purities, save_fig)
 85        if save_fig:
 86            print("Plots ready!")
 87        if purity_cut > 0:
 88            print(f"Selected probaility cuts: {best_cuts}")
 89            return (best_cuts[0], best_cuts[1], best_cuts[2])
 90        else:
 91            return (-1.0, -1.0, -1.0)
 92
 93    def gauss_preds(self, proba_proton: float, proba_kaon: float, proba_pion: float):
 94        """Gets particle type as selected by xgboost model if above probability threshold.
 95
 96        Args:
 97            proba_proton (float): Probablity threshold to classify particle as proton.
 98            proba_kaon (float): Probablity threshold to classify particle as kaon.
 99            proba_pion (float): Probablity threshold to classify particle as pion.
100        """
101        df = self.particles_df
102        df["Complex_gauss_preds"] = df["Complex_gauss_pid"]
103
104        # setting to bckgr if smaller than probability threshold
105        proton = (df["Complex_gauss_pid"] == 0) & (df["Complex_prob_p"] > proba_proton)
106        pion = (df["Complex_gauss_pid"] == 1) & (df["Complex_prob_K"] > proba_kaon)
107        kaon = (df["Complex_gauss_pid"] == 2) & (df["Complex_prob_pi"] > proba_pion)
108        df.loc[~(proton | pion | kaon), "Complex_gauss_preds"] = 3
109
110        self.particles_df = df
111
112    def confusion_matrix_and_stats(
113        self, efficiency_filename: str = "efficiency_stats.txt"
114    ):
115        """
116        Generates confusion matrix and efficiency/purity stats.
117        """
118        cnf_matrix = confusion_matrix(
119            self.particles_df[self.pid_variable_name],
120            self.particles_df["Complex_gauss_preds"],
121        )
122        plotting_tools.plot_confusion_matrix(cnf_matrix)
123        plotting_tools.plot_confusion_matrix(cnf_matrix, normalize=True)
124        txt_file = open(efficiency_filename, "w+")
125        for pid in range(self.get_n_classes() - 1):
126            self.efficiency_stats(cnf_matrix, pid, txt_file)
127        txt_file.close()
128
129    def remap_gauss_names(self):
130        """
131        Remaps Pid of particles to output format from XGBoost Model.
132        Protons: 0; Kaons: 1; Pions, Electrons, Muons: 2; Other: 3
133
134        """
135        df = self.particles_df
136        if self.anti_particles:
137            df["Complex_gauss_pid"] = (
138                df["Complex_gauss_pid"]
139                .map(
140                    defaultdict(
141                        lambda: 3.0,
142                        {
143                            Pid.ANTI_PROTON.value: 0.0,
144                            Pid.NEG_KAON.value: 1.0,
145                            Pid.NEG_PION.value: 2.0,
146                            Pid.ELECTRON.value: 2.0,
147                            Pid.NEG_MUON.value: 2.0,
148                        },
149                    ),
150                    na_action="ignore",
151                )
152                .astype(float)
153            )
154        else:
155            df["Complex_gauss_pid"] = (
156                df["Complex_gauss_pid"]
157                .map(
158                    defaultdict(
159                        lambda: 3.0,
160                        {
161                            Pid.PROTON.value: 0.0,
162                            Pid.POS_KAON.value: 1.0,
163                            Pid.POS_PION.value: 2.0,
164                            Pid.POSITRON.value: 2.0,
165                            Pid.POS_MUON.value: 2.0,
166                        },
167                    ),
168                    na_action="ignore",
169                )
170                .astype(float)
171            )
172        self.particles_df = df
173
174    def _tof_plots(self):
175        """
176        Generates tof plots.
177        """
178        for pid, particle_name in enumerate(self.classes_names):
179            # simulated:
180            try:
181                plotting_tools.tof_plot(
182                    self.particles_df[self.particles_df[self.pid_variable_name] == pid],
183                    self.json_file_name,
184                    f"{particle_name} (all simulated)",
185                )
186            except ValueError:
187                print(f"No simulated {particle_name}s")
188            # xgb selected
189            try:
190                plotting_tools.tof_plot(
191                    self.particles_df[self.particles_df["Complex_gauss_preds"] == pid],
192                    self.json_file_name,
193                    f"{particle_name} (Gauss-selected)",
194                )
195            except ValueError:
196                print(f"No Gauss-selected {particle_name}s")
197
198    def _mass2_plots(self):
199        """
200        Generates mass2 plots.
201        """
202        protons_range = (-0.2, 1.8)
203        kaons_range = (-0.2, 0.6)
204        pions_range = (-0.3, 0.3)
205        ranges = [protons_range, kaons_range, pions_range, pions_range]
206        for pid, particle_name in enumerate(self.classes_names):
207            plotting_tools.plot_mass2(
208                self.particles_df[self.particles_df["Complex_gauss_preds"] == pid][
209                    self.mass2_variable_name
210                ],
211                self.particles_df[self.particles_df[self.pid_variable_name] == pid][
212                    self.mass2_variable_name
213                ],
214                particle_name,
215                ranges[pid],
216            )
217            plotting_tools.plot_all_particles_mass2(
218                self.particles_df[self.particles_df["Complex_gauss_preds"] == pid],
219                self.mass2_variable_name,
220                self.pid_variable_name,
221                particle_name,
222                ranges[pid],
223            )
224
225    def _vars_distributions_plots(self):
226        """
227        Generates distributions of variables and pT-rapidity graphs.
228        """
229        vars_to_draw = json_tools.load_vars_to_draw(self.json_file_name)
230        for pid, particle_name in enumerate(self.classes_names):
231            plotting_tools.var_distributions_plot(
232                vars_to_draw,
233                [
234                    self.particles_df[
235                        (self.particles_df[self.pid_variable_name] == pid)
236                    ],
237                    self.particles_df[
238                        (
239                            (self.particles_df[self.pid_variable_name] == pid)
240                            & (self.particles_df["Complex_gauss_preds"] == pid)
241                        )
242                    ],
243                    self.particles_df[
244                        (
245                            (self.particles_df[self.pid_variable_name] != pid)
246                            & (self.particles_df["Complex_gauss_preds"] == pid)
247                        )
248                    ],
249                ],
250                [
251                    f"true MC {particle_name}",
252                    f"true selected {particle_name}",
253                    f"false selected {particle_name}",
254                ],
255                filename=f"vars_dist_{particle_name}",
256            )

Class for validating the Gaussian PID model.

def evaluate_probas(self, start: float = 0.35, stop: float = 1, n_steps: int = 40, purity_cut: float = 0.0, save_fig: bool = True) -> Tuple[float, float, float]:
23    def evaluate_probas(
24        self,
25        start: float = 0.35,
26        stop: float = 1,
27        n_steps: int = 40,
28        purity_cut: float = 0.0,
29        save_fig: bool = True,
30    ) -> Tuple[float, float, float]:
31        """Method for evaluating probability (BDT) cut effect on efficency and purity.
32
33        Args:
34            start (float, optional): Lower range of probablity cuts. Defaults to 0.3.
35            stop (float, optional): Upper range of probablity cuts. Defaults to 0.98.
36            n_steps (int, optional): Number of probability cuts to try. Defaults to 30.
37            pid_variable_name (str, optional): Name of the variable containing true Pid. Defaults to "Complex_pid".
38            purity_cut (float, optional): Minimal purity for automatic cuts selection. Defaults to 0..
39            save_fig (bool, optional): Saves figures (BDT cut vs efficiency and purity) to file if True. Defaults to True.
40
41        Returns:
42            Tuple[float, float, float]: Probability cut for each variable.
43        """
44        print(
45            f"Checking efficiency and purity for {int(n_steps)} probablity cuts between {start}, and {stop}..."
46        )
47        probas = np.linspace(start, stop, n_steps)
48        efficienciess_protons, efficiencies_kaons, efficiencies_pions = [], [], []
49        efficiencies = [efficienciess_protons, efficiencies_kaons, efficiencies_pions]
50        purities_protons, purities_kaons, purities_pions = [], [], []
51        purities = [purities_protons, purities_kaons, purities_pions]
52        best_cuts = [0.0, 0.0, 0.0]
53        max_efficiencies = [0.0, 0.0, 0.0]
54        max_purities = [0.0, 0.0, 0.0]
55
56        for proba in probas:
57            self.gauss_preds(proba, proba, proba)
58            # confusion matrix
59            cnf_matrix = confusion_matrix(
60                self.particles_df[self.pid_variable_name],
61                self.particles_df["Complex_gauss_preds"],
62            )
63            for pid in range(self.get_n_classes() - 1):
64                efficiency, purity = self.efficiency_stats(
65                    cnf_matrix, pid, print_output=False
66                )
67                efficiencies[pid].append(efficiency)
68                purities[pid].append(purity)
69                if purity_cut > 0.0:
70                    # Minimal purity for automatic threshold selection.
71                    # Will choose the highest efficiency for purity above this value.
72                    if purity >= purity_cut:
73                        if efficiency > max_efficiencies[pid]:
74                            best_cuts[pid] = proba
75                            max_efficiencies[pid] = efficiency
76                            max_purities[pid] = purity
77                    # If max purity is below this value, will choose the highest purity available.
78                    else:
79                        if purity > max_purities[pid]:
80                            best_cuts[pid] = proba
81                            max_efficiencies[pid] = efficiency
82                            max_purities[pid] = purity
83
84        plotting_tools.plot_efficiency_purity(probas, efficiencies, purities, save_fig)
85        if save_fig:
86            print("Plots ready!")
87        if purity_cut > 0:
88            print(f"Selected probaility cuts: {best_cuts}")
89            return (best_cuts[0], best_cuts[1], best_cuts[2])
90        else:
91            return (-1.0, -1.0, -1.0)

Method for evaluating the effect of the probability cut on efficiency and purity.

Args: start (float, optional): Lower range of probability cuts. Defaults to 0.35. stop (float, optional): Upper range of probability cuts. Defaults to 1. n_steps (int, optional): Number of probability cuts to try. Defaults to 40. pid_variable_name: not a parameter of this method; the true Pid is read from the column named by self.pid_variable_name. purity_cut (float, optional): Minimal purity for automatic cut selection. Defaults to 0. save_fig (bool, optional): Saves figures (probability cut vs efficiency and purity) to file if True. Defaults to True.

Returns: Tuple[float, float, float]: Selected probability cut for each particle species (protons, kaons, pions), or (-1.0, -1.0, -1.0) if purity_cut is not positive.

def gauss_preds(self, proba_proton: float, proba_kaon: float, proba_pion: float):
 93    def gauss_preds(self, proba_proton: float, proba_kaon: float, proba_pion: float):
 94        """Gets particle type as selected by xgboost model if above probability threshold.
 95
 96        Args:
 97            proba_proton (float): Probablity threshold to classify particle as proton.
 98            proba_kaon (float): Probablity threshold to classify particle as kaon.
 99            proba_pion (float): Probablity threshold to classify particle as pion.
100        """
101        df = self.particles_df
102        df["Complex_gauss_preds"] = df["Complex_gauss_pid"]
103
104        # setting to bckgr if smaller than probability threshold
105        proton = (df["Complex_gauss_pid"] == 0) & (df["Complex_prob_p"] > proba_proton)
106        pion = (df["Complex_gauss_pid"] == 1) & (df["Complex_prob_K"] > proba_kaon)
107        kaon = (df["Complex_gauss_pid"] == 2) & (df["Complex_prob_pi"] > proba_pion)
108        df.loc[~(proton | pion | kaon), "Complex_gauss_preds"] = 3
109
110        self.particles_df = df

Gets particle type as predicted by the Gaussian model if its probability is above the threshold; otherwise marks the particle as background (class 3).

Args: proba_proton (float): Probability threshold to classify particle as proton. proba_kaon (float): Probability threshold to classify particle as kaon. proba_pion (float): Probability threshold to classify particle as pion.

def confusion_matrix_and_stats(self, efficiency_filename: str = 'efficiency_stats.txt'):
112    def confusion_matrix_and_stats(
113        self, efficiency_filename: str = "efficiency_stats.txt"
114    ):
115        """
116        Generates confusion matrix and efficiency/purity stats.
117        """
118        cnf_matrix = confusion_matrix(
119            self.particles_df[self.pid_variable_name],
120            self.particles_df["Complex_gauss_preds"],
121        )
122        plotting_tools.plot_confusion_matrix(cnf_matrix)
123        plotting_tools.plot_confusion_matrix(cnf_matrix, normalize=True)
124        txt_file = open(efficiency_filename, "w+")
125        for pid in range(self.get_n_classes() - 1):
126            self.efficiency_stats(cnf_matrix, pid, txt_file)
127        txt_file.close()

Generates confusion matrix and efficiency/purity stats.

def remap_gauss_names(self):
129    def remap_gauss_names(self):
130        """
131        Remaps Pid of particles to output format from XGBoost Model.
132        Protons: 0; Kaons: 1; Pions, Electrons, Muons: 2; Other: 3
133
134        """
135        df = self.particles_df
136        if self.anti_particles:
137            df["Complex_gauss_pid"] = (
138                df["Complex_gauss_pid"]
139                .map(
140                    defaultdict(
141                        lambda: 3.0,
142                        {
143                            Pid.ANTI_PROTON.value: 0.0,
144                            Pid.NEG_KAON.value: 1.0,
145                            Pid.NEG_PION.value: 2.0,
146                            Pid.ELECTRON.value: 2.0,
147                            Pid.NEG_MUON.value: 2.0,
148                        },
149                    ),
150                    na_action="ignore",
151                )
152                .astype(float)
153            )
154        else:
155            df["Complex_gauss_pid"] = (
156                df["Complex_gauss_pid"]
157                .map(
158                    defaultdict(
159                        lambda: 3.0,
160                        {
161                            Pid.PROTON.value: 0.0,
162                            Pid.POS_KAON.value: 1.0,
163                            Pid.POS_PION.value: 2.0,
164                            Pid.POSITRON.value: 2.0,
165                            Pid.POS_MUON.value: 2.0,
166                        },
167                    ),
168                    na_action="ignore",
169                )
170                .astype(float)
171            )
172        self.particles_df = df

Remaps Pid of particles to output format from XGBoost Model. Protons: 0; Kaons: 1; Pions, Electrons, Muons: 2; Other: 3

def parse_args(args: List[str]) -> argparse.Namespace:
def parse_args(args: List[str]) -> argparse.Namespace:
    """
    Arguments parser for the main method.

    Args:
        args (List[str]): Arguments from the command line, should be sys.argv[1:].

    Returns:
        argparse.Namespace: argparse.Namespace containing args
    """
    parser = argparse.ArgumentParser(
        prog="ML_PID_CBM ValidateGauss",
        description="Program for validating Gaussian PID model",
    )
    parser.add_argument(
        "--config",
        "-c",
        nargs=1,
        required=True,
        type=str,
        help="Filename or path of config json file.",
    )
    parser.add_argument(
        "--nworkers",
        "-n",
        type=int,
        default=1,
        help="Max number of workers for ThreadPoolExecutor which reads Root tree with data.",
    )
    parser.add_argument(
        "--momentum",
        "-p",
        nargs=2,
        required=True,
        type=float,
        help="Lower and upper momentum limit, e.g., 1 3",
    )
    return parser.parse_args(args)

Arguments parser for the main method.

Args: args (List[str]): Arguments from the command line, should be sys.argv[1:].

Returns: argparse.Namespace: argparse.Namespace containing args