ml_pid_cbm.validate_model

  1import argparse
  2import io
  3import os
  4import re
  5import sys
  6from collections import defaultdict
  7from typing import List, Tuple
  8
  9import numpy as np
 10import pandas as pd
 11from hipe4ml.model_handler import ModelHandler
 12from sklearn.metrics import confusion_matrix
 13
 14from tools import json_tools, plotting_tools
 15from tools.load_data import LoadData
 16from tools.particles_id import ParticlesId as Pid
 17
 18
class ValidateModel:
    """
    Validates a trained PID ML model on a test dataset.

    Holds the test-particle dataframe and provides methods to apply the
    model's probability cuts, compute efficiency/purity statistics, and
    produce validation plots.
    """

    def __init__(
        self,
        lower_p_cut: float,
        upper_p_cut: float,
        anti_particles: bool,
        json_file_name: str,
        particles_df: pd.DataFrame,
    ):
        """
        Args:
            lower_p_cut (float): Lower momentum cut the model was trained for.
            upper_p_cut (float): Upper momentum cut the model was trained for.
            anti_particles (bool): True if the model was trained for anti-particles.
            json_file_name (str): Path of the json config file.
            particles_df (pd.DataFrame): Dataframe with test particles and model outputs.
        """
        self.lower_p_cut = lower_p_cut
        self.upper_p_cut = upper_p_cut
        self.anti_particles = anti_particles
        self.json_file_name = json_file_name
        self.particles_df = particles_df
        # column names of the true Pid and mass^2 variables, read from the json config
        self.pid_variable_name = json_tools.load_var_name(self.json_file_name, "pid")
        self.mass2_variable_name = json_tools.load_var_name(
            self.json_file_name, "mass2"
        )
        # index order matches the XGBoost class ids: 0=protons, 1=kaons, 2=pions, 3=background
        self.classes_names = ["protons", "kaons", "pions", "bckgr"]

    def get_n_classes(self) -> int:
        """Returns the number of classes, background included."""
        return len(self.classes_names)
 45
 46    def xgb_preds(self, proba_proton: float, proba_kaon: float, proba_pion: float):
 47        """Gets particle type as selected by xgboost model if above probability threshold.
 48
 49        Args:
 50            proba_proton (float): Probablity threshold to classify particle as proton.
 51            proba_kaon (float): Probablity threshold to classify particle as kaon.
 52            proba_pion (float): Probablity threshold to classify particle as pion.
 53        """
 54        df = self.particles_df
 55        df["xgb_preds"] = (
 56            df[["model_output_0", "model_output_1", "model_output_2"]]
 57            .idxmax(axis=1)
 58            .map(lambda x: x.lstrip("model_output_"))
 59            .astype(int)
 60        )
 61        # setting to bckgr if smaller than probability threshold
 62        proton = (df["xgb_preds"] == 0) & (df["model_output_0"] > proba_proton)
 63        pion = (df["xgb_preds"] == 1) & (df["model_output_1"] > proba_kaon)
 64        kaon = (df["xgb_preds"] == 2) & (df["model_output_2"] > proba_pion)
 65        df.loc[~(proton | pion | kaon), "xgb_preds"] = 3
 66
 67        self.particles_df = df
 68
 69    def remap_names(self):
 70        """
 71        Remaps Pid of particles to output format from XGBoost Model.
 72        Protons: 0; Kaons: 1; Pions, Electrons, Muons: 2; Other: 3
 73
 74        """
 75        df = self.particles_df
 76        if self.anti_particles:
 77            df[self.pid_variable_name] = (
 78                df[self.pid_variable_name]
 79                .map(
 80                    defaultdict(
 81                        lambda: 3.0,
 82                        {
 83                            Pid.ANTI_PROTON.value: 0.0,
 84                            Pid.NEG_KAON.value: 1.0,
 85                            Pid.NEG_PION.value: 2.0,
 86                            Pid.ELECTRON.value: 2.0,
 87                            Pid.NEG_MUON.value: 2.0,
 88                        },
 89                    ),
 90                    na_action="ignore",
 91                )
 92                .astype(float)
 93            )
 94        else:
 95            df[self.pid_variable_name] = (
 96                df[self.pid_variable_name]
 97                .map(
 98                    defaultdict(
 99                        lambda: 3.0,
100                        {
101                            Pid.PROTON.value: 0.0,
102                            Pid.POS_KAON.value: 1.0,
103                            Pid.POS_PION.value: 2.0,
104                            Pid.POSITRON.value: 2.0,
105                            Pid.POS_MUON.value: 2.0,
106                        },
107                    ),
108                    na_action="ignore",
109                )
110                .astype(float)
111            )
112        self.particles_df = df
113
114    def save_df(self):
115        """
116        Saves dataframe with validated data into pickle format.
117        """
118        self.particles_df.to_pickle("validated_data.pickle")
119
120    def sigma_selection(self, pid: float, nsigma: float = 5, info: bool = False):
121        """Sigma selection for dataframe to remove systmatically (not by the ML model) mismatched particles.
122
123        Args:
124            pid (float): Pid of particle for this selection
125            nsigma (float, optional): _description_. Defaults to 5.
126            info (bool, optional): _description_. Defaults to False.
127        """
128        df = self.particles_df
129        # for selected pid
130        mean = df[df[self.pid_variable_name] == pid][self.mass2_variable_name].mean()
131        std = df[df[self.pid_variable_name] == pid][self.mass2_variable_name].std()
132        outside_sigma = (df[self.pid_variable_name] == pid) & (
133            (df[self.mass2_variable_name] < (mean - nsigma * std))
134            | (df[self.mass2_variable_name] > (mean + nsigma * std))
135        )
136        df_sigma_selected = df[~outside_sigma]
137        if info:
138            df_len = len(df)
139            df1_len = len(df_sigma_selected)
140            print(
141                "we get rid of "
142                + str(round((df_len - df1_len) / df_len * 100, 2))
143                + " % of pid = "
144                + str(pid)
145                + " particle entries"
146            )
147        self.particles_df = df_sigma_selected
148
149    def evaluate_probas(
150        self,
151        start: float = 0.3,
152        stop: float = 0.98,
153        n_steps: int = 30,
154        purity_cut: float = 0.0,
155        save_fig: bool = True,
156    ) -> Tuple[float, float, float]:
157        """Method for evaluating probability (BDT) cut effect on efficency and purity.
158
159        Args:
160            start (float, optional): Lower range of probablity cuts. Defaults to 0.3.
161            stop (float, optional): Upper range of probablity cuts. Defaults to 0.98.
162            n_steps (int, optional): Number of probability cuts to try. Defaults to 30.
163            pid_variable_name (str, optional): Name of the variable containing true Pid. Defaults to "Complex_pid".
164            purity_cut (float, optional): Minimal purity for automatic cuts selection. Defaults to 0..
165            save_fig (bool, optional): Saves figures (BDT cut vs efficiency and purity) to file if True. Defaults to True.
166
167        Returns:
168            Tuple[float, float, float]: Probability cut for each variable.
169        """
170        print(
171            f"Checking efficiency and purity for {int(n_steps)} probablity cuts between {start}, and {stop}..."
172        )
173        probas = np.linspace(start, stop, n_steps)
174        efficienciess_protons, efficiencies_kaons, efficiencies_pions = [], [], []
175        efficiencies = [efficienciess_protons, efficiencies_kaons, efficiencies_pions]
176        purities_protons, purities_kaons, purities_pions = [], [], []
177        purities = [purities_protons, purities_kaons, purities_pions]
178        best_cuts = [0.0, 0.0, 0.0]
179        max_efficiencies = [0.0, 0.0, 0.0]
180        max_purities = [0.0, 0.0, 0.0]
181
182        for proba in probas:
183            self.xgb_preds(proba, proba, proba)
184            # confusion matrix
185            cnf_matrix = confusion_matrix(
186                self.particles_df[self.pid_variable_name],
187                self.particles_df["xgb_preds"],
188            )
189            for pid in range(self.get_n_classes() - 1):
190                efficiency, purity = self.efficiency_stats(
191                    cnf_matrix, pid, print_output=False
192                )
193                efficiencies[pid].append(efficiency)
194                purities[pid].append(purity)
195                if purity_cut > 0.0:
196                    # Minimal purity for automatic threshold selection.
197                    # Will choose the highest efficiency for purity above this value.
198                    if purity >= purity_cut:
199                        if efficiency > max_efficiencies[pid]:
200                            best_cuts[pid] = proba
201                            max_efficiencies[pid] = efficiency
202                            max_purities[pid] = purity
203                    # If max purity is below this value, will choose the highest purity available.
204                    else:
205                        if purity > max_purities[pid]:
206                            best_cuts[pid] = proba
207                            max_efficiencies[pid] = efficiency
208                            max_purities[pid] = purity
209
210        plotting_tools.plot_efficiency_purity(probas, efficiencies, purities, save_fig)
211        if save_fig:
212            print("Plots ready!")
213        if purity_cut > 0:
214            print(f"Selected probaility cuts: {best_cuts}")
215            return (best_cuts[0], best_cuts[1], best_cuts[2])
216        else:
217            return (-1.0, -1.0, -1.0)
218    
219    @staticmethod
220    def efficiency_stats(
221    cnf_matrix: np.ndarray,
222    pid: int,
223    txt_tile: io.TextIOWrapper = None,
224    print_output: bool = True,
225    ) -> Tuple[float, float]:
226        """
227        Prints efficiency stats from confusion matrix into efficiency_stats.txt file and stdout.
228        Efficiency is calculated as correctly identified X / all true simulated X
229        Purity is calculated as correctly identified X / all identified X
230
231        Args:
232            cnf_matrix (np.ndarray): Confusion matrix generated by sklearn.metrics.confusion_matrix.
233            pid (int): Pid of particles to print efficiency stats.
234            txt_tile (io.TextIOWrapper): Text file to write the output. Defaults to None.
235            print_output (bool): Whether to print the output to stdout. Defaults to True.
236
237        Returns:
238            Tuple[float, float]: Tuple with efficiency and purity
239        """
240        all_simulated_signal = cnf_matrix[pid].sum()
241        true_signal = cnf_matrix[pid][pid]
242        false_signal = cnf_matrix[:, pid].sum() - true_signal
243        reconstructed_signals = true_signal + false_signal
244
245        efficiency = (true_signal / all_simulated_signal) * 100
246        purity = (true_signal / reconstructed_signals) * 100
247
248        stats = f"""
249        For particle ID = {pid}: 
250        Efficiency: {efficiency:.2f}%
251        Purity: {purity:.2f}%
252        """
253
254        if print_output:
255            print(stats)
256
257        if txt_tile is not None:
258            txt_tile.writelines(stats)
259
260        return (efficiency, purity)
261
262    def confusion_matrix_and_stats(
263        self, efficiency_filename: str = "efficiency_stats.txt"
264    ):
265        """
266        Generates confusion matrix and efficiency/purity stats.
267        """
268        cnf_matrix = confusion_matrix(
269            self.particles_df[self.pid_variable_name], self.particles_df["xgb_preds"]
270        )
271        plotting_tools.plot_confusion_matrix(cnf_matrix)
272        plotting_tools.plot_confusion_matrix(cnf_matrix, normalize=True)
273        txt_file = open(efficiency_filename, "w+")
274        for pid in range(self.get_n_classes() - 1):
275            self.efficiency_stats(cnf_matrix, pid, txt_file)
276        txt_file.close()
277
278    def generate_plots(self):
279        """
280        Generate tof, mass2, vars, and pT-rapidity plots
281        """
282        self._tof_plots()
283        self._mass2_plots()
284        self._vars_distributions_plots()
285
286    def _tof_plots(self):
287        """
288        Generates tof plots.
289        """
290        for pid, particle_name in enumerate(self.classes_names):
291            # simulated:
292            try:
293                plotting_tools.tof_plot(
294                    self.particles_df[self.particles_df[self.pid_variable_name] == pid],
295                    self.json_file_name,
296                    f"{particle_name} (all simulated)",
297                )
298            except ValueError:
299                print(f"No simulated {particle_name}s")
300            # xgb selected
301            try:
302                plotting_tools.tof_plot(
303                    self.particles_df[self.particles_df["xgb_preds"] == pid],
304                    self.json_file_name,
305                    f"{particle_name} (XGB-selected)",
306                )
307            except ValueError:
308                print(f"No XGB-selected {particle_name}s")
309
310    def _mass2_plots(self):
311        """
312        Generates mass2 plots.
313        """
314        protons_range = (-0.2, 1.8)
315        kaons_range = (-0.2, 0.6)
316        pions_range = (-0.3, 0.3)
317        ranges = [protons_range, kaons_range, pions_range, pions_range]
318        for pid, particle_name in enumerate(self.classes_names):
319            plotting_tools.plot_mass2(
320                self.particles_df[self.particles_df["xgb_preds"] == pid][
321                    self.mass2_variable_name
322                ],
323                self.particles_df[self.particles_df[self.pid_variable_name] == pid][
324                    self.mass2_variable_name
325                ],
326                particle_name,
327                ranges[pid],
328            )
329            plotting_tools.plot_all_particles_mass2(
330                self.particles_df[self.particles_df["xgb_preds"] == pid],
331                self.mass2_variable_name,
332                self.pid_variable_name,
333                particle_name,
334                ranges[pid],
335            )
336
337    def _vars_distributions_plots(self):
338        """
339        Generates distributions of variables and pT-rapidity graphs.
340        """
341        vars_to_draw = json_tools.load_vars_to_draw(self.json_file_name)
342        for pid, particle_name in enumerate(self.classes_names):
343            plotting_tools.var_distributions_plot(
344                vars_to_draw,
345                [
346                    self.particles_df[
347                        (self.particles_df[self.pid_variable_name] == pid)
348                    ],
349                    self.particles_df[
350                        (
351                            (self.particles_df[self.pid_variable_name] == pid)
352                            & (self.particles_df["xgb_preds"] == pid)
353                        )
354                    ],
355                    self.particles_df[
356                        (
357                            (self.particles_df[self.pid_variable_name] != pid)
358                            & (self.particles_df["xgb_preds"] == pid)
359                        )
360                    ],
361                ],
362                [
363                    f"true MC {particle_name}",
364                    f"true selected {particle_name}",
365                    f"false selected {particle_name}",
366                ],
367                filename=f"vars_dist_{particle_name}",
368            )
369            plotting_tools.plot_eff_pT_rap(self.particles_df, pid)
370            plotting_tools.plot_pt_rapidity(self.particles_df, pid)
371
372    @staticmethod
373    def parse_model_name(
374        name: str,
375        pattern: str = r"model_([\d.]+)_([\d.]+)_(anti)|model_([\d.]+)_([\d.]+)_([a-zA-Z]+)",
376    ) -> Tuple[float, float, bool]:
377        """Parser model name to get info about lower momentum cut, upper momentum cut, and if model is trained for anti_particles.
378
379        Args:
380            name (str): Name of the model.
381            pattern (_type_, optional): Pattern of model name.
382             Defaults to r"model_([\d.]+)_([\d.]+)_(anti)|model_([\d.]+)_([\d.]+)_([a-zA-Z]+)".
383
384        Raises:
385            ValueError: Raises error if model name incorrect.
386
387        Returns:
388            Tuple[float, float, bool]: Tuple containing lower_p_cut, upper_p_cut, is_anti
389        """
390        match = re.match(pattern, name)
391        if match:
392            if match.group(3):
393                lower_p_cut = float(match.group(1))
394                upper_p_cut = float(match.group(2))
395                is_anti = True
396            else:
397                lower_p_cut = float(match.group(4))
398                upper_p_cut = float(match.group(5))
399                is_anti = False
400        else:
401            raise ValueError("Incorrect model name, regex not found.")
402        return (lower_p_cut, upper_p_cut, is_anti)
403
404
def parse_args(args: List[str]) -> argparse.Namespace:
    """
    Builds the command-line parser for this script and parses the given arguments.

    Args:
        args (List[str]): Arguments from the command line, should be sys.argv[1:].

    Returns:
        argparse.Namespace: argparse.Namespace containing the parsed args.
    """
    arg_parser = argparse.ArgumentParser(
        prog="ML_PID_CBM ValidateModel",
        description="Program for validating PID ML models",
    )
    # mandatory inputs: config json and trained-model folder
    arg_parser.add_argument(
        "--config",
        "-c",
        type=str,
        nargs=1,
        required=True,
        help="Filename of path of config json file.",
    )
    arg_parser.add_argument(
        "--modelname",
        "-m",
        type=str,
        nargs=1,
        required=True,
        help="Name of folder containing trained ml model.",
    )
    # exactly one way of choosing probability cuts must be given
    cuts_group = arg_parser.add_mutually_exclusive_group(required=True)
    cuts_group.add_argument(
        "--probabilitycuts",
        "-p",
        type=float,
        nargs=3,
        help="Probability cut value for respectively protons, kaons, and pions. E.g., 0.9 0.95 0.9",
    )
    cuts_group.add_argument(
        "--evaluateproba",
        "-e",
        type=float,
        nargs=3,
        help="Minimal probability cut, maximal, and number of steps to investigate.",
    )
    arg_parser.add_argument(
        "--nworkers",
        "-n",
        type=int,
        default=1,
        help="Max number of workers for ThreadPoolExecutor which reads Root tree with data.",
    )
    # how the final cuts are picked after an evaluation run
    selection_group = arg_parser.add_mutually_exclusive_group()
    selection_group.add_argument(
        "--interactive",
        "-i",
        action="store_true",
        help="Interactive mode allows selection of probability cuts after evaluating them.",
    )
    selection_group.add_argument(
        "--automatic",
        "-a",
        type=float,
        nargs=1,
        help="""Minimal purity for automatic threshold selection (in percent) e.g., 90.
        Will choose the highest efficiency for purity above this value.
        If max purity is below this value, will choose the highest purity available.""",
    )
    return arg_parser.parse_args(args)
474
475
if __name__ == "__main__":
    # command-line arguments: config json, model folder, probability-cut mode
    args = parse_args(sys.argv[1:])
    # config arguments to be loaded from args
    json_file_name = args.config[0]
    model_name = args.modelname[0]
    # explicit probability cuts from -p, or sentinel -1s when -e mode is used
    proba_proton, proba_kaon, proba_pion = (
        (args.probabilitycuts[0], args.probabilitycuts[1], args.probabilitycuts[2])
        if args.probabilitycuts is not None
        else (-1.0, -1.0, -1.0)
    )

    n_workers = args.nworkers
    # minimal purity (percent) for automatic cut selection; 0.0 disables it
    purity_cut = args.automatic[0] if args.automatic is not None else 0.0
    # momentum range and particle/anti-particle flavour encoded in the model name
    lower_p, upper_p, is_anti = ValidateModel.parse_model_name(model_name)
    # loading test data
    data_file_name = json_tools.load_file_name(json_file_name, "test")

    loader = LoadData(data_file_name, json_file_name, lower_p, upper_p, is_anti)
    # loading model handler and applying on dataset
    print(
        f"\nLoading data from {data_file_name}\nApplying model handler from {model_name}"
    )
    # NOTE(review): chdir into the model folder, then the handler is loaded by the
    # same (now relative) name — presumably the folder contains a file named like
    # the folder itself; confirm against ModelHandler.load_model_handler.
    os.chdir(f"{model_name}")
    model_hdl = ModelHandler()
    model_hdl.load_model_handler(model_name)
    test_particles = loader.load_tree(model_handler=model_hdl, max_workers=n_workers)
    # validate model object
    validate = ValidateModel(
        lower_p, upper_p, is_anti, json_file_name, test_particles.get_data_frame()
    )
    # remap Pid to match output XGBoost format
    validate.remap_names()
    pid_variable_name = json_tools.load_var_name(json_file_name, "pid")
    # set probability cuts: scan a range of cuts when -e was given
    if args.evaluateproba is not None:
        proba_proton, proba_kaon, proba_pion = validate.evaluate_probas(
            args.evaluateproba[0],
            args.evaluateproba[1],
            int(args.evaluateproba[2]),
            purity_cut,
            not args.interactive,
        )
        # interactive mode (-i excludes -a, so purity_cut is 0 and evaluate_probas
        # returned sentinel -1s): ask for each threshold until it lies inside [0, 1]
        if args.interactive:
            while proba_proton < 0 or proba_proton > 1:
                proba_proton = float(
                    input(
                        "Enter the probability threshold for proton (between 0 and 1): "
                    )
                )

            while proba_kaon < 0 or proba_kaon > 1:
                proba_kaon = float(
                    input(
                        "Enter the probability threshold for kaon (between 0 and 1): "
                    )
                )

            while proba_pion < 0 or proba_pion > 1:
                proba_pion = float(
                    input(
                        "Enter the probability threshold for pion (between 0 and 1): "
                    )
                )
    # probabilities are set by now
    # apply probability cuts
    print(
        f"\nApplying probability cuts.\nFor protons: {proba_proton}\nFor kaons: {proba_kaon}\nFor pions: {proba_pion}"
    )
    validate.xgb_preds(proba_proton, proba_kaon, proba_pion)
    # graphs
    validate.confusion_matrix_and_stats()
    print("Generating plots...")
    validate.generate_plots()
    # save validated dataset
    validate.save_df()
class ValidateModel:
 20class ValidateModel:
 21    """
 22    Class for testing the ml model
 23    """
 24
 25    def __init__(
 26        self,
 27        lower_p_cut: float,
 28        upper_p_cut: float,
 29        anti_particles: bool,
 30        json_file_name: str,
 31        particles_df: pd.DataFrame,
 32    ):
 33        self.lower_p_cut = lower_p_cut
 34        self.upper_p_cut = upper_p_cut
 35        self.anti_particles = anti_particles
 36        self.json_file_name = json_file_name
 37        self.particles_df = particles_df
 38        self.pid_variable_name = json_tools.load_var_name(self.json_file_name, "pid")
 39        self.mass2_variable_name = json_tools.load_var_name(
 40            self.json_file_name, "mass2"
 41        )
 42        self.classes_names = ["protons", "kaons", "pions", "bckgr"]
 43
 44    def get_n_classes(self):
 45        return len(self.classes_names)
 46
 47    def xgb_preds(self, proba_proton: float, proba_kaon: float, proba_pion: float):
 48        """Gets particle type as selected by xgboost model if above probability threshold.
 49
 50        Args:
 51            proba_proton (float): Probablity threshold to classify particle as proton.
 52            proba_kaon (float): Probablity threshold to classify particle as kaon.
 53            proba_pion (float): Probablity threshold to classify particle as pion.
 54        """
 55        df = self.particles_df
 56        df["xgb_preds"] = (
 57            df[["model_output_0", "model_output_1", "model_output_2"]]
 58            .idxmax(axis=1)
 59            .map(lambda x: x.lstrip("model_output_"))
 60            .astype(int)
 61        )
 62        # setting to bckgr if smaller than probability threshold
 63        proton = (df["xgb_preds"] == 0) & (df["model_output_0"] > proba_proton)
 64        pion = (df["xgb_preds"] == 1) & (df["model_output_1"] > proba_kaon)
 65        kaon = (df["xgb_preds"] == 2) & (df["model_output_2"] > proba_pion)
 66        df.loc[~(proton | pion | kaon), "xgb_preds"] = 3
 67
 68        self.particles_df = df
 69
 70    def remap_names(self):
 71        """
 72        Remaps Pid of particles to output format from XGBoost Model.
 73        Protons: 0; Kaons: 1; Pions, Electrons, Muons: 2; Other: 3
 74
 75        """
 76        df = self.particles_df
 77        if self.anti_particles:
 78            df[self.pid_variable_name] = (
 79                df[self.pid_variable_name]
 80                .map(
 81                    defaultdict(
 82                        lambda: 3.0,
 83                        {
 84                            Pid.ANTI_PROTON.value: 0.0,
 85                            Pid.NEG_KAON.value: 1.0,
 86                            Pid.NEG_PION.value: 2.0,
 87                            Pid.ELECTRON.value: 2.0,
 88                            Pid.NEG_MUON.value: 2.0,
 89                        },
 90                    ),
 91                    na_action="ignore",
 92                )
 93                .astype(float)
 94            )
 95        else:
 96            df[self.pid_variable_name] = (
 97                df[self.pid_variable_name]
 98                .map(
 99                    defaultdict(
100                        lambda: 3.0,
101                        {
102                            Pid.PROTON.value: 0.0,
103                            Pid.POS_KAON.value: 1.0,
104                            Pid.POS_PION.value: 2.0,
105                            Pid.POSITRON.value: 2.0,
106                            Pid.POS_MUON.value: 2.0,
107                        },
108                    ),
109                    na_action="ignore",
110                )
111                .astype(float)
112            )
113        self.particles_df = df
114
115    def save_df(self):
116        """
117        Saves dataframe with validated data into pickle format.
118        """
119        self.particles_df.to_pickle("validated_data.pickle")
120
121    def sigma_selection(self, pid: float, nsigma: float = 5, info: bool = False):
122        """Sigma selection for dataframe to remove systmatically (not by the ML model) mismatched particles.
123
124        Args:
125            pid (float): Pid of particle for this selection
126            nsigma (float, optional): _description_. Defaults to 5.
127            info (bool, optional): _description_. Defaults to False.
128        """
129        df = self.particles_df
130        # for selected pid
131        mean = df[df[self.pid_variable_name] == pid][self.mass2_variable_name].mean()
132        std = df[df[self.pid_variable_name] == pid][self.mass2_variable_name].std()
133        outside_sigma = (df[self.pid_variable_name] == pid) & (
134            (df[self.mass2_variable_name] < (mean - nsigma * std))
135            | (df[self.mass2_variable_name] > (mean + nsigma * std))
136        )
137        df_sigma_selected = df[~outside_sigma]
138        if info:
139            df_len = len(df)
140            df1_len = len(df_sigma_selected)
141            print(
142                "we get rid of "
143                + str(round((df_len - df1_len) / df_len * 100, 2))
144                + " % of pid = "
145                + str(pid)
146                + " particle entries"
147            )
148        self.particles_df = df_sigma_selected
149
150    def evaluate_probas(
151        self,
152        start: float = 0.3,
153        stop: float = 0.98,
154        n_steps: int = 30,
155        purity_cut: float = 0.0,
156        save_fig: bool = True,
157    ) -> Tuple[float, float, float]:
158        """Method for evaluating probability (BDT) cut effect on efficency and purity.
159
160        Args:
161            start (float, optional): Lower range of probablity cuts. Defaults to 0.3.
162            stop (float, optional): Upper range of probablity cuts. Defaults to 0.98.
163            n_steps (int, optional): Number of probability cuts to try. Defaults to 30.
164            pid_variable_name (str, optional): Name of the variable containing true Pid. Defaults to "Complex_pid".
165            purity_cut (float, optional): Minimal purity for automatic cuts selection. Defaults to 0..
166            save_fig (bool, optional): Saves figures (BDT cut vs efficiency and purity) to file if True. Defaults to True.
167
168        Returns:
169            Tuple[float, float, float]: Probability cut for each variable.
170        """
171        print(
172            f"Checking efficiency and purity for {int(n_steps)} probablity cuts between {start}, and {stop}..."
173        )
174        probas = np.linspace(start, stop, n_steps)
175        efficienciess_protons, efficiencies_kaons, efficiencies_pions = [], [], []
176        efficiencies = [efficienciess_protons, efficiencies_kaons, efficiencies_pions]
177        purities_protons, purities_kaons, purities_pions = [], [], []
178        purities = [purities_protons, purities_kaons, purities_pions]
179        best_cuts = [0.0, 0.0, 0.0]
180        max_efficiencies = [0.0, 0.0, 0.0]
181        max_purities = [0.0, 0.0, 0.0]
182
183        for proba in probas:
184            self.xgb_preds(proba, proba, proba)
185            # confusion matrix
186            cnf_matrix = confusion_matrix(
187                self.particles_df[self.pid_variable_name],
188                self.particles_df["xgb_preds"],
189            )
190            for pid in range(self.get_n_classes() - 1):
191                efficiency, purity = self.efficiency_stats(
192                    cnf_matrix, pid, print_output=False
193                )
194                efficiencies[pid].append(efficiency)
195                purities[pid].append(purity)
196                if purity_cut > 0.0:
197                    # Minimal purity for automatic threshold selection.
198                    # Will choose the highest efficiency for purity above this value.
199                    if purity >= purity_cut:
200                        if efficiency > max_efficiencies[pid]:
201                            best_cuts[pid] = proba
202                            max_efficiencies[pid] = efficiency
203                            max_purities[pid] = purity
204                    # If max purity is below this value, will choose the highest purity available.
205                    else:
206                        if purity > max_purities[pid]:
207                            best_cuts[pid] = proba
208                            max_efficiencies[pid] = efficiency
209                            max_purities[pid] = purity
210
211        plotting_tools.plot_efficiency_purity(probas, efficiencies, purities, save_fig)
212        if save_fig:
213            print("Plots ready!")
214        if purity_cut > 0:
215            print(f"Selected probaility cuts: {best_cuts}")
216            return (best_cuts[0], best_cuts[1], best_cuts[2])
217        else:
218            return (-1.0, -1.0, -1.0)
219    
220    @staticmethod
221    def efficiency_stats(
222    cnf_matrix: np.ndarray,
223    pid: int,
224    txt_tile: io.TextIOWrapper = None,
225    print_output: bool = True,
226    ) -> Tuple[float, float]:
227        """
228        Prints efficiency stats from confusion matrix into efficiency_stats.txt file and stdout.
229        Efficiency is calculated as correctly identified X / all true simulated X
230        Purity is calculated as correctly identified X / all identified X
231
232        Args:
233            cnf_matrix (np.ndarray): Confusion matrix generated by sklearn.metrics.confusion_matrix.
234            pid (int): Pid of particles to print efficiency stats.
235            txt_tile (io.TextIOWrapper): Text file to write the output. Defaults to None.
236            print_output (bool): Whether to print the output to stdout. Defaults to True.
237
238        Returns:
239            Tuple[float, float]: Tuple with efficiency and purity
240        """
241        all_simulated_signal = cnf_matrix[pid].sum()
242        true_signal = cnf_matrix[pid][pid]
243        false_signal = cnf_matrix[:, pid].sum() - true_signal
244        reconstructed_signals = true_signal + false_signal
245
246        efficiency = (true_signal / all_simulated_signal) * 100
247        purity = (true_signal / reconstructed_signals) * 100
248
249        stats = f"""
250        For particle ID = {pid}: 
251        Efficiency: {efficiency:.2f}%
252        Purity: {purity:.2f}%
253        """
254
255        if print_output:
256            print(stats)
257
258        if txt_tile is not None:
259            txt_tile.writelines(stats)
260
261        return (efficiency, purity)
262
263    def confusion_matrix_and_stats(
264        self, efficiency_filename: str = "efficiency_stats.txt"
265    ):
266        """
267        Generates confusion matrix and efficiency/purity stats.
268        """
269        cnf_matrix = confusion_matrix(
270            self.particles_df[self.pid_variable_name], self.particles_df["xgb_preds"]
271        )
272        plotting_tools.plot_confusion_matrix(cnf_matrix)
273        plotting_tools.plot_confusion_matrix(cnf_matrix, normalize=True)
274        txt_file = open(efficiency_filename, "w+")
275        for pid in range(self.get_n_classes() - 1):
276            self.efficiency_stats(cnf_matrix, pid, txt_file)
277        txt_file.close()
278
279    def generate_plots(self):
280        """
281        Generate tof, mass2, vars, and pT-rapidity plots
282        """
283        self._tof_plots()
284        self._mass2_plots()
285        self._vars_distributions_plots()
286
287    def _tof_plots(self):
288        """
289        Generates tof plots.
290        """
291        for pid, particle_name in enumerate(self.classes_names):
292            # simulated:
293            try:
294                plotting_tools.tof_plot(
295                    self.particles_df[self.particles_df[self.pid_variable_name] == pid],
296                    self.json_file_name,
297                    f"{particle_name} (all simulated)",
298                )
299            except ValueError:
300                print(f"No simulated {particle_name}s")
301            # xgb selected
302            try:
303                plotting_tools.tof_plot(
304                    self.particles_df[self.particles_df["xgb_preds"] == pid],
305                    self.json_file_name,
306                    f"{particle_name} (XGB-selected)",
307                )
308            except ValueError:
309                print(f"No XGB-selected {particle_name}s")
310
311    def _mass2_plots(self):
312        """
313        Generates mass2 plots.
314        """
315        protons_range = (-0.2, 1.8)
316        kaons_range = (-0.2, 0.6)
317        pions_range = (-0.3, 0.3)
318        ranges = [protons_range, kaons_range, pions_range, pions_range]
319        for pid, particle_name in enumerate(self.classes_names):
320            plotting_tools.plot_mass2(
321                self.particles_df[self.particles_df["xgb_preds"] == pid][
322                    self.mass2_variable_name
323                ],
324                self.particles_df[self.particles_df[self.pid_variable_name] == pid][
325                    self.mass2_variable_name
326                ],
327                particle_name,
328                ranges[pid],
329            )
330            plotting_tools.plot_all_particles_mass2(
331                self.particles_df[self.particles_df["xgb_preds"] == pid],
332                self.mass2_variable_name,
333                self.pid_variable_name,
334                particle_name,
335                ranges[pid],
336            )
337
338    def _vars_distributions_plots(self):
339        """
340        Generates distributions of variables and pT-rapidity graphs.
341        """
342        vars_to_draw = json_tools.load_vars_to_draw(self.json_file_name)
343        for pid, particle_name in enumerate(self.classes_names):
344            plotting_tools.var_distributions_plot(
345                vars_to_draw,
346                [
347                    self.particles_df[
348                        (self.particles_df[self.pid_variable_name] == pid)
349                    ],
350                    self.particles_df[
351                        (
352                            (self.particles_df[self.pid_variable_name] == pid)
353                            & (self.particles_df["xgb_preds"] == pid)
354                        )
355                    ],
356                    self.particles_df[
357                        (
358                            (self.particles_df[self.pid_variable_name] != pid)
359                            & (self.particles_df["xgb_preds"] == pid)
360                        )
361                    ],
362                ],
363                [
364                    f"true MC {particle_name}",
365                    f"true selected {particle_name}",
366                    f"false selected {particle_name}",
367                ],
368                filename=f"vars_dist_{particle_name}",
369            )
370            plotting_tools.plot_eff_pT_rap(self.particles_df, pid)
371            plotting_tools.plot_pt_rapidity(self.particles_df, pid)
372
373    @staticmethod
374    def parse_model_name(
375        name: str,
376        pattern: str = r"model_([\d.]+)_([\d.]+)_(anti)|model_([\d.]+)_([\d.]+)_([a-zA-Z]+)",
377    ) -> Tuple[float, float, bool]:
378        """Parser model name to get info about lower momentum cut, upper momentum cut, and if model is trained for anti_particles.
379
380        Args:
381            name (str): Name of the model.
382            pattern (_type_, optional): Pattern of model name.
383             Defaults to r"model_([\d.]+)_([\d.]+)_(anti)|model_([\d.]+)_([\d.]+)_([a-zA-Z]+)".
384
385        Raises:
386            ValueError: Raises error if model name incorrect.
387
388        Returns:
389            Tuple[float, float, bool]: Tuple containing lower_p_cut, upper_p_cut, is_anti
390        """
391        match = re.match(pattern, name)
392        if match:
393            if match.group(3):
394                lower_p_cut = float(match.group(1))
395                upper_p_cut = float(match.group(2))
396                is_anti = True
397            else:
398                lower_p_cut = float(match.group(4))
399                upper_p_cut = float(match.group(5))
400                is_anti = False
401        else:
402            raise ValueError("Incorrect model name, regex not found.")
403        return (lower_p_cut, upper_p_cut, is_anti)

Class for testing the ml model

ValidateModel( lower_p_cut: float, upper_p_cut: float, anti_particles: bool, json_file_name: str, particles_df: pandas.core.frame.DataFrame)
25    def __init__(
26        self,
27        lower_p_cut: float,
28        upper_p_cut: float,
29        anti_particles: bool,
30        json_file_name: str,
31        particles_df: pd.DataFrame,
32    ):
33        self.lower_p_cut = lower_p_cut
34        self.upper_p_cut = upper_p_cut
35        self.anti_particles = anti_particles
36        self.json_file_name = json_file_name
37        self.particles_df = particles_df
38        self.pid_variable_name = json_tools.load_var_name(self.json_file_name, "pid")
39        self.mass2_variable_name = json_tools.load_var_name(
40            self.json_file_name, "mass2"
41        )
42        self.classes_names = ["protons", "kaons", "pions", "bckgr"]
def get_n_classes(self):
44    def get_n_classes(self):
45        return len(self.classes_names)
def xgb_preds(self, proba_proton: float, proba_kaon: float, proba_pion: float):
47    def xgb_preds(self, proba_proton: float, proba_kaon: float, proba_pion: float):
48        """Gets particle type as selected by xgboost model if above probability threshold.
49
50        Args:
51            proba_proton (float): Probablity threshold to classify particle as proton.
52            proba_kaon (float): Probablity threshold to classify particle as kaon.
53            proba_pion (float): Probablity threshold to classify particle as pion.
54        """
55        df = self.particles_df
56        df["xgb_preds"] = (
57            df[["model_output_0", "model_output_1", "model_output_2"]]
58            .idxmax(axis=1)
59            .map(lambda x: x.lstrip("model_output_"))
60            .astype(int)
61        )
62        # setting to bckgr if smaller than probability threshold
63        proton = (df["xgb_preds"] == 0) & (df["model_output_0"] > proba_proton)
64        pion = (df["xgb_preds"] == 1) & (df["model_output_1"] > proba_kaon)
65        kaon = (df["xgb_preds"] == 2) & (df["model_output_2"] > proba_pion)
66        df.loc[~(proton | pion | kaon), "xgb_preds"] = 3
67
68        self.particles_df = df

Gets particle type as selected by xgboost model if above probability threshold.

Args: proba_proton (float): Probability threshold to classify particle as proton. proba_kaon (float): Probability threshold to classify particle as kaon. proba_pion (float): Probability threshold to classify particle as pion.

def remap_names(self):
 70    def remap_names(self):
 71        """
 72        Remaps Pid of particles to output format from XGBoost Model.
 73        Protons: 0; Kaons: 1; Pions, Electrons, Muons: 2; Other: 3
 74
 75        """
 76        df = self.particles_df
 77        if self.anti_particles:
 78            df[self.pid_variable_name] = (
 79                df[self.pid_variable_name]
 80                .map(
 81                    defaultdict(
 82                        lambda: 3.0,
 83                        {
 84                            Pid.ANTI_PROTON.value: 0.0,
 85                            Pid.NEG_KAON.value: 1.0,
 86                            Pid.NEG_PION.value: 2.0,
 87                            Pid.ELECTRON.value: 2.0,
 88                            Pid.NEG_MUON.value: 2.0,
 89                        },
 90                    ),
 91                    na_action="ignore",
 92                )
 93                .astype(float)
 94            )
 95        else:
 96            df[self.pid_variable_name] = (
 97                df[self.pid_variable_name]
 98                .map(
 99                    defaultdict(
100                        lambda: 3.0,
101                        {
102                            Pid.PROTON.value: 0.0,
103                            Pid.POS_KAON.value: 1.0,
104                            Pid.POS_PION.value: 2.0,
105                            Pid.POSITRON.value: 2.0,
106                            Pid.POS_MUON.value: 2.0,
107                        },
108                    ),
109                    na_action="ignore",
110                )
111                .astype(float)
112            )
113        self.particles_df = df

Remaps Pid of particles to output format from XGBoost Model. Protons: 0; Kaons: 1; Pions, Electrons, Muons: 2; Other: 3

def save_df(self):
115    def save_df(self):
116        """
117        Saves dataframe with validated data into pickle format.
118        """
119        self.particles_df.to_pickle("validated_data.pickle")

Saves dataframe with validated data into pickle format.

def sigma_selection(self, pid: float, nsigma: float = 5, info: bool = False):
121    def sigma_selection(self, pid: float, nsigma: float = 5, info: bool = False):
122        """Sigma selection for dataframe to remove systmatically (not by the ML model) mismatched particles.
123
124        Args:
125            pid (float): Pid of particle for this selection
126            nsigma (float, optional): _description_. Defaults to 5.
127            info (bool, optional): _description_. Defaults to False.
128        """
129        df = self.particles_df
130        # for selected pid
131        mean = df[df[self.pid_variable_name] == pid][self.mass2_variable_name].mean()
132        std = df[df[self.pid_variable_name] == pid][self.mass2_variable_name].std()
133        outside_sigma = (df[self.pid_variable_name] == pid) & (
134            (df[self.mass2_variable_name] < (mean - nsigma * std))
135            | (df[self.mass2_variable_name] > (mean + nsigma * std))
136        )
137        df_sigma_selected = df[~outside_sigma]
138        if info:
139            df_len = len(df)
140            df1_len = len(df_sigma_selected)
141            print(
142                "we get rid of "
143                + str(round((df_len - df1_len) / df_len * 100, 2))
144                + " % of pid = "
145                + str(pid)
146                + " particle entries"
147            )
148        self.particles_df = df_sigma_selected

Sigma selection for dataframe to remove systematically (not by the ML model) mismatched particles.

Args: pid (float): Pid of particle for this selection nsigma (float, optional): _description_. Defaults to 5. info (bool, optional): _description_. Defaults to False.

def evaluate_probas( self, start: float = 0.3, stop: float = 0.98, n_steps: int = 30, purity_cut: float = 0.0, save_fig: bool = True) -> Tuple[float, float, float]:
150    def evaluate_probas(
151        self,
152        start: float = 0.3,
153        stop: float = 0.98,
154        n_steps: int = 30,
155        purity_cut: float = 0.0,
156        save_fig: bool = True,
157    ) -> Tuple[float, float, float]:
158        """Method for evaluating probability (BDT) cut effect on efficency and purity.
159
160        Args:
161            start (float, optional): Lower range of probablity cuts. Defaults to 0.3.
162            stop (float, optional): Upper range of probablity cuts. Defaults to 0.98.
163            n_steps (int, optional): Number of probability cuts to try. Defaults to 30.
164            pid_variable_name (str, optional): Name of the variable containing true Pid. Defaults to "Complex_pid".
165            purity_cut (float, optional): Minimal purity for automatic cuts selection. Defaults to 0..
166            save_fig (bool, optional): Saves figures (BDT cut vs efficiency and purity) to file if True. Defaults to True.
167
168        Returns:
169            Tuple[float, float, float]: Probability cut for each variable.
170        """
171        print(
172            f"Checking efficiency and purity for {int(n_steps)} probablity cuts between {start}, and {stop}..."
173        )
174        probas = np.linspace(start, stop, n_steps)
175        efficienciess_protons, efficiencies_kaons, efficiencies_pions = [], [], []
176        efficiencies = [efficienciess_protons, efficiencies_kaons, efficiencies_pions]
177        purities_protons, purities_kaons, purities_pions = [], [], []
178        purities = [purities_protons, purities_kaons, purities_pions]
179        best_cuts = [0.0, 0.0, 0.0]
180        max_efficiencies = [0.0, 0.0, 0.0]
181        max_purities = [0.0, 0.0, 0.0]
182
183        for proba in probas:
184            self.xgb_preds(proba, proba, proba)
185            # confusion matrix
186            cnf_matrix = confusion_matrix(
187                self.particles_df[self.pid_variable_name],
188                self.particles_df["xgb_preds"],
189            )
190            for pid in range(self.get_n_classes() - 1):
191                efficiency, purity = self.efficiency_stats(
192                    cnf_matrix, pid, print_output=False
193                )
194                efficiencies[pid].append(efficiency)
195                purities[pid].append(purity)
196                if purity_cut > 0.0:
197                    # Minimal purity for automatic threshold selection.
198                    # Will choose the highest efficiency for purity above this value.
199                    if purity >= purity_cut:
200                        if efficiency > max_efficiencies[pid]:
201                            best_cuts[pid] = proba
202                            max_efficiencies[pid] = efficiency
203                            max_purities[pid] = purity
204                    # If max purity is below this value, will choose the highest purity available.
205                    else:
206                        if purity > max_purities[pid]:
207                            best_cuts[pid] = proba
208                            max_efficiencies[pid] = efficiency
209                            max_purities[pid] = purity
210
211        plotting_tools.plot_efficiency_purity(probas, efficiencies, purities, save_fig)
212        if save_fig:
213            print("Plots ready!")
214        if purity_cut > 0:
215            print(f"Selected probaility cuts: {best_cuts}")
216            return (best_cuts[0], best_cuts[1], best_cuts[2])
217        else:
218            return (-1.0, -1.0, -1.0)

Method for evaluating probability (BDT) cut effect on efficiency and purity.

Args: start (float, optional): Lower range of probability cuts. Defaults to 0.3. stop (float, optional): Upper range of probability cuts. Defaults to 0.98. n_steps (int, optional): Number of probability cuts to try. Defaults to 30. pid_variable_name (str, optional): Name of the variable containing true Pid. Defaults to "Complex_pid". purity_cut (float, optional): Minimal purity for automatic cuts selection. Defaults to 0.0. save_fig (bool, optional): Saves figures (BDT cut vs efficiency and purity) to file if True. Defaults to True.

Returns: Tuple[float, float, float]: Probability cut for each variable.

@staticmethod
def efficiency_stats( cnf_matrix: numpy.ndarray, pid: int, txt_tile: _io.TextIOWrapper = None, print_output: bool = True) -> Tuple[float, float]:
220    @staticmethod
221    def efficiency_stats(
222    cnf_matrix: np.ndarray,
223    pid: int,
224    txt_tile: io.TextIOWrapper = None,
225    print_output: bool = True,
226    ) -> Tuple[float, float]:
227        """
228        Prints efficiency stats from confusion matrix into efficiency_stats.txt file and stdout.
229        Efficiency is calculated as correctly identified X / all true simulated X
230        Purity is calculated as correctly identified X / all identified X
231
232        Args:
233            cnf_matrix (np.ndarray): Confusion matrix generated by sklearn.metrics.confusion_matrix.
234            pid (int): Pid of particles to print efficiency stats.
235            txt_tile (io.TextIOWrapper): Text file to write the output. Defaults to None.
236            print_output (bool): Whether to print the output to stdout. Defaults to True.
237
238        Returns:
239            Tuple[float, float]: Tuple with efficiency and purity
240        """
241        all_simulated_signal = cnf_matrix[pid].sum()
242        true_signal = cnf_matrix[pid][pid]
243        false_signal = cnf_matrix[:, pid].sum() - true_signal
244        reconstructed_signals = true_signal + false_signal
245
246        efficiency = (true_signal / all_simulated_signal) * 100
247        purity = (true_signal / reconstructed_signals) * 100
248
249        stats = f"""
250        For particle ID = {pid}: 
251        Efficiency: {efficiency:.2f}%
252        Purity: {purity:.2f}%
253        """
254
255        if print_output:
256            print(stats)
257
258        if txt_tile is not None:
259            txt_tile.writelines(stats)
260
261        return (efficiency, purity)

Prints efficiency stats from confusion matrix into efficiency_stats.txt file and stdout. Efficiency is calculated as correctly identified X / all true simulated X Purity is calculated as correctly identified X / all identified X

Args: cnf_matrix (np.ndarray): Confusion matrix generated by sklearn.metrics.confusion_matrix. pid (int): Pid of particles to print efficiency stats. txt_tile (io.TextIOWrapper): Text file to write the output. Defaults to None. print_output (bool): Whether to print the output to stdout. Defaults to True.

Returns: Tuple[float, float]: Tuple with efficiency and purity

def confusion_matrix_and_stats(self, efficiency_filename: str = 'efficiency_stats.txt'):
263    def confusion_matrix_and_stats(
264        self, efficiency_filename: str = "efficiency_stats.txt"
265    ):
266        """
267        Generates confusion matrix and efficiency/purity stats.
268        """
269        cnf_matrix = confusion_matrix(
270            self.particles_df[self.pid_variable_name], self.particles_df["xgb_preds"]
271        )
272        plotting_tools.plot_confusion_matrix(cnf_matrix)
273        plotting_tools.plot_confusion_matrix(cnf_matrix, normalize=True)
274        txt_file = open(efficiency_filename, "w+")
275        for pid in range(self.get_n_classes() - 1):
276            self.efficiency_stats(cnf_matrix, pid, txt_file)
277        txt_file.close()

Generates confusion matrix and efficiency/purity stats.

def generate_plots(self):
279    def generate_plots(self):
280        """
281        Generate tof, mass2, vars, and pT-rapidity plots
282        """
283        self._tof_plots()
284        self._mass2_plots()
285        self._vars_distributions_plots()

Generate tof, mass2, vars, and pT-rapidity plots

@staticmethod
def parse_model_name( name: str, pattern: str = 'model_([\\d.]+)_([\\d.]+)_(anti)|model_([\\d.]+)_([\\d.]+)_([a-zA-Z]+)') -> Tuple[float, float, bool]:
373    @staticmethod
374    def parse_model_name(
375        name: str,
376        pattern: str = r"model_([\d.]+)_([\d.]+)_(anti)|model_([\d.]+)_([\d.]+)_([a-zA-Z]+)",
377    ) -> Tuple[float, float, bool]:
378        """Parser model name to get info about lower momentum cut, upper momentum cut, and if model is trained for anti_particles.
379
380        Args:
381            name (str): Name of the model.
382            pattern (_type_, optional): Pattern of model name.
383             Defaults to r"model_([\d.]+)_([\d.]+)_(anti)|model_([\d.]+)_([\d.]+)_([a-zA-Z]+)".
384
385        Raises:
386            ValueError: Raises error if model name incorrect.
387
388        Returns:
389            Tuple[float, float, bool]: Tuple containing lower_p_cut, upper_p_cut, is_anti
390        """
391        match = re.match(pattern, name)
392        if match:
393            if match.group(3):
394                lower_p_cut = float(match.group(1))
395                upper_p_cut = float(match.group(2))
396                is_anti = True
397            else:
398                lower_p_cut = float(match.group(4))
399                upper_p_cut = float(match.group(5))
400                is_anti = False
401        else:
402            raise ValueError("Incorrect model name, regex not found.")
403        return (lower_p_cut, upper_p_cut, is_anti)

Parse model name to get info about lower momentum cut, upper momentum cut, and if model is trained for anti_particles.

Args: name (str): Name of the model. pattern (_type_, optional): Pattern of model name. Defaults to r"model_([\d.]+)_([\d.]+)_(anti)|model_([\d.]+)_([\d.]+)_([a-zA-Z]+)".

Raises: ValueError: Raises error if model name incorrect.

Returns: Tuple[float, float, bool]: Tuple containing lower_p_cut, upper_p_cut, is_anti

def parse_args(args: List[str]) -> argparse.Namespace:
406def parse_args(args: List[str]) -> argparse.Namespace:
407    """
408    Arguments parser for the main method.
409
410    Args:
411        args (List[str]): Arguments from the command line, should be sys.argv[1:].
412
413    Returns:
414        argparse.Namespace: argparse.Namespace containg args
415    """
416    parser = argparse.ArgumentParser(
417        prog="ML_PID_CBM ValidateModel",
418        description="Program for validating PID ML models",
419    )
420    parser.add_argument(
421        "--config",
422        "-c",
423        nargs=1,
424        required=True,
425        type=str,
426        help="Filename of path of config json file.",
427    )
428    parser.add_argument(
429        "--modelname",
430        "-m",
431        nargs=1,
432        required=True,
433        type=str,
434        help="Name of folder containing trained ml model.",
435    )
436    proba_group = parser.add_mutually_exclusive_group(required=True)
437    proba_group.add_argument(
438        "--probabilitycuts",
439        "-p",
440        nargs=3,
441        type=float,
442        help="Probability cut value for respectively protons, kaons, and pions. E.g., 0.9 0.95 0.9",
443    )
444    proba_group.add_argument(
445        "--evaluateproba",
446        "-e",
447        nargs=3,
448        type=float,
449        help="Minimal probability cut, maximal, and number of steps to investigate.",
450    )
451    parser.add_argument(
452        "--nworkers",
453        "-n",
454        type=int,
455        default=1,
456        help="Max number of workers for ThreadPoolExecutor which reads Root tree with data.",
457    )
458    decision_group = parser.add_mutually_exclusive_group()
459    decision_group.add_argument(
460        "--interactive",
461        "-i",
462        action="store_true",
463        help="Interactive mode allows selection of probability cuts after evaluating them.",
464    )
465    decision_group.add_argument(
466        "--automatic",
467        "-a",
468        nargs=1,
469        type=float,
470        help="""Minimal purity for automatic threshold selection (in percent) e.g., 90.
471        Will choose the highest efficiency for purity above this value.
472        If max purity is below this value, will choose the highest purity available.""",
473    )
474    return parser.parse_args(args)

Arguments parser for the main method.

Args: args (List[str]): Arguments from the command line, should be sys.argv[1:].

Returns: argparse.Namespace: argparse.Namespace containing args