ml_pid_cbm.validate_model
"""Validation of a trained PID ML model on a test dataset.

Loads the test data, applies a trained XGBoost model handler, remaps true
Pids to the model-output format, applies (or scans for) probability cuts,
and produces efficiency/purity statistics plus validation plots.
"""
import argparse
import io
import os
import re
import sys
from collections import defaultdict
from typing import List, Tuple

import numpy as np
import pandas as pd
from hipe4ml.model_handler import ModelHandler
from sklearn.metrics import confusion_matrix

from tools import json_tools, plotting_tools
from tools.load_data import LoadData
from tools.particles_id import ParticlesId as Pid


class ValidateModel:
    """
    Class for testing the ml model.
    """

    def __init__(
        self,
        lower_p_cut: float,
        upper_p_cut: float,
        anti_particles: bool,
        json_file_name: str,
        particles_df: pd.DataFrame,
    ):
        """
        Args:
            lower_p_cut (float): Lower momentum cut used to train the model.
            upper_p_cut (float): Upper momentum cut used to train the model.
            anti_particles (bool): Whether the model was trained for anti-particles.
            json_file_name (str): Filename (or path) of the config json file.
            particles_df (pd.DataFrame): Dataframe with the test particles.
        """
        self.lower_p_cut = lower_p_cut
        self.upper_p_cut = upper_p_cut
        self.anti_particles = anti_particles
        self.json_file_name = json_file_name
        self.particles_df = particles_df
        # variable names (true pid, mass^2) are defined in the json config
        self.pid_variable_name = json_tools.load_var_name(self.json_file_name, "pid")
        self.mass2_variable_name = json_tools.load_var_name(
            self.json_file_name, "mass2"
        )
        # index order matches the model outputs; the last entry is background
        self.classes_names = ["protons", "kaons", "pions", "bckgr"]

    def get_n_classes(self) -> int:
        """Returns the number of classes, including background."""
        return len(self.classes_names)

    def xgb_preds(self, proba_proton: float, proba_kaon: float, proba_pion: float):
        """Gets particle type as selected by xgboost model if above probability threshold.

        Stores the result in the ``xgb_preds`` column:
        0 = proton, 1 = kaon, 2 = pion, 3 = background
        (background when the winning probability is below its threshold).

        Args:
            proba_proton (float): Probablity threshold to classify particle as proton.
            proba_kaon (float): Probablity threshold to classify particle as kaon.
            proba_pion (float): Probablity threshold to classify particle as pion.
        """
        df = self.particles_df
        # class with the highest BDT probability; extract N from "model_output_N".
        # (str.lstrip strips a character *set*, not a prefix — it only worked by
        # accident for digits 0-2, so split the column name instead)
        df["xgb_preds"] = (
            df[["model_output_0", "model_output_1", "model_output_2"]]
            .idxmax(axis=1)
            .map(lambda column_name: int(column_name.rsplit("_", 1)[-1]))
        )
        # setting to bckgr if smaller than probability threshold
        # (class indices: 0 = protons, 1 = kaons, 2 = pions)
        is_proton = (df["xgb_preds"] == 0) & (df["model_output_0"] > proba_proton)
        is_kaon = (df["xgb_preds"] == 1) & (df["model_output_1"] > proba_kaon)
        is_pion = (df["xgb_preds"] == 2) & (df["model_output_2"] > proba_pion)
        df.loc[~(is_proton | is_kaon | is_pion), "xgb_preds"] = 3

        self.particles_df = df

    def remap_names(self):
        """
        Remaps Pid of particles to output format from XGBoost Model.
        Protons: 0; Kaons: 1; Pions, Electrons, Muons: 2; Other: 3
        """
        # single mapping table chosen once — the two branches previously
        # duplicated the whole map/astype pipeline
        if self.anti_particles:
            pid_map = {
                Pid.ANTI_PROTON.value: 0.0,
                Pid.NEG_KAON.value: 1.0,
                Pid.NEG_PION.value: 2.0,
                Pid.ELECTRON.value: 2.0,
                Pid.NEG_MUON.value: 2.0,
            }
        else:
            pid_map = {
                Pid.PROTON.value: 0.0,
                Pid.POS_KAON.value: 1.0,
                Pid.POS_PION.value: 2.0,
                Pid.POSITRON.value: 2.0,
                Pid.POS_MUON.value: 2.0,
            }
        df = self.particles_df
        # unknown pids fall through to background (3.0); NaNs are left untouched
        df[self.pid_variable_name] = (
            df[self.pid_variable_name]
            .map(defaultdict(lambda: 3.0, pid_map), na_action="ignore")
            .astype(float)
        )
        self.particles_df = df

    def save_df(self):
        """
        Saves dataframe with validated data into pickle format.
        """
        self.particles_df.to_pickle("validated_data.pickle")

    def sigma_selection(self, pid: float, nsigma: float = 5, info: bool = False):
        """Sigma selection for dataframe to remove systematically (not by the ML model) mismatched particles.

        Keeps, for the given pid, only entries whose mass2 lies within
        mean +- nsigma * std of that pid's mass2 distribution.

        Args:
            pid (float): Pid of particle for this selection.
            nsigma (float, optional): Width of the accepted band in standard deviations. Defaults to 5.
            info (bool, optional): Print the percentage of removed entries. Defaults to False.
        """
        df = self.particles_df
        # statistics computed from the selected pid only
        mass2 = df[self.mass2_variable_name]
        same_pid = df[self.pid_variable_name] == pid
        mean = mass2[same_pid].mean()
        std = mass2[same_pid].std()
        outside_sigma = same_pid & (
            (mass2 < (mean - nsigma * std)) | (mass2 > (mean + nsigma * std))
        )
        df_sigma_selected = df[~outside_sigma]
        if info:
            df_len = len(df)
            df1_len = len(df_sigma_selected)
            print(
                "we get rid of "
                + str(round((df_len - df1_len) / df_len * 100, 2))
                + " % of pid = "
                + str(pid)
                + " particle entries"
            )
        self.particles_df = df_sigma_selected

    def evaluate_probas(
        self,
        start: float = 0.3,
        stop: float = 0.98,
        n_steps: int = 30,
        purity_cut: float = 0.0,
        save_fig: bool = True,
    ) -> Tuple[float, float, float]:
        """Method for evaluating probability (BDT) cut effect on efficency and purity.

        Args:
            start (float, optional): Lower range of probablity cuts. Defaults to 0.3.
            stop (float, optional): Upper range of probablity cuts. Defaults to 0.98.
            n_steps (int, optional): Number of probability cuts to try. Defaults to 30.
            purity_cut (float, optional): Minimal purity (percent) for automatic cuts selection. Defaults to 0.
            save_fig (bool, optional): Saves figures (BDT cut vs efficiency and purity) to file if True. Defaults to True.

        Returns:
            Tuple[float, float, float]: Selected probability cut per class, or
                (-1., -1., -1.) when purity_cut is not set.
        """
        print(
            f"Checking efficiency and purity for {int(n_steps)} probablity cuts between {start}, and {stop}..."
        )
        probas = np.linspace(start, stop, n_steps)
        # one efficiency/purity history per signal class (protons, kaons, pions)
        efficiencies = [[], [], []]
        purities = [[], [], []]
        best_cuts = [0.0, 0.0, 0.0]
        max_efficiencies = [0.0, 0.0, 0.0]
        max_purities = [0.0, 0.0, 0.0]

        for proba in probas:
            # same threshold applied to all three classes during the scan
            self.xgb_preds(proba, proba, proba)
            cnf_matrix = confusion_matrix(
                self.particles_df[self.pid_variable_name],
                self.particles_df["xgb_preds"],
            )
            for pid in range(self.get_n_classes() - 1):
                efficiency, purity = self.efficiency_stats(
                    cnf_matrix, pid, print_output=False
                )
                efficiencies[pid].append(efficiency)
                purities[pid].append(purity)
                if purity_cut > 0.0:
                    # Automatic threshold selection:
                    # choose the highest efficiency among cuts whose purity is
                    # above purity_cut; if no cut reaches it, fall back to the
                    # highest purity available. The fallback trackers must not
                    # block a later qualifying cut (bug fix: previously a
                    # high-efficiency below-cut point recorded early prevented
                    # any above-cut point from ever being selected).
                    had_qualifying = max_purities[pid] >= purity_cut
                    if purity >= purity_cut:
                        if not had_qualifying or efficiency > max_efficiencies[pid]:
                            best_cuts[pid] = proba
                            max_efficiencies[pid] = efficiency
                            max_purities[pid] = purity
                    elif not had_qualifying and purity > max_purities[pid]:
                        best_cuts[pid] = proba
                        max_efficiencies[pid] = efficiency
                        max_purities[pid] = purity

        plotting_tools.plot_efficiency_purity(probas, efficiencies, purities, save_fig)
        if save_fig:
            print("Plots ready!")
        if purity_cut > 0:
            print(f"Selected probaility cuts: {best_cuts}")
            return (best_cuts[0], best_cuts[1], best_cuts[2])
        return (-1.0, -1.0, -1.0)

    @staticmethod
    def efficiency_stats(
        cnf_matrix: np.ndarray,
        pid: int,
        txt_tile: io.TextIOWrapper = None,
        print_output: bool = True,
    ) -> Tuple[float, float]:
        """
        Calculates efficiency and purity stats from a confusion matrix.
        Efficiency is calculated as correctly identified X / all true simulated X.
        Purity is calculated as correctly identified X / all identified X.

        Args:
            cnf_matrix (np.ndarray): Confusion matrix generated by sklearn.metrics.confusion_matrix.
            pid (int): Pid of particles to print efficiency stats.
            txt_tile (io.TextIOWrapper): Text file to write the output. Defaults to None.
                NOTE(review): name kept for backward compatibility (likely a typo of txt_file).
            print_output (bool): Whether to print the output to stdout. Defaults to True.

        Returns:
            Tuple[float, float]: Tuple with efficiency and purity (in percent).
        """
        # row pid = all simulated entries of the class; column pid = all
        # entries reconstructed as that class
        all_simulated_signal = cnf_matrix[pid].sum()
        true_signal = cnf_matrix[pid][pid]
        false_signal = cnf_matrix[:, pid].sum() - true_signal
        reconstructed_signals = true_signal + false_signal

        # NOTE(review): produces nan (with a numpy warning) if a class is absent
        efficiency = (true_signal / all_simulated_signal) * 100
        purity = (true_signal / reconstructed_signals) * 100

        stats = f"""
        For particle ID = {pid}:
        Efficiency: {efficiency:.2f}%
        Purity: {purity:.2f}%
        """

        if print_output:
            print(stats)

        if txt_tile is not None:
            txt_tile.writelines(stats)

        return (efficiency, purity)

    def confusion_matrix_and_stats(
        self, efficiency_filename: str = "efficiency_stats.txt"
    ):
        """
        Generates confusion matrix plots and efficiency/purity stats.

        Args:
            efficiency_filename (str, optional): Output text file for the stats.
                Defaults to "efficiency_stats.txt".
        """
        cnf_matrix = confusion_matrix(
            self.particles_df[self.pid_variable_name], self.particles_df["xgb_preds"]
        )
        plotting_tools.plot_confusion_matrix(cnf_matrix)
        plotting_tools.plot_confusion_matrix(cnf_matrix, normalize=True)
        # context manager guarantees the file is closed even if stats raise
        with open(efficiency_filename, "w+") as txt_file:
            for pid in range(self.get_n_classes() - 1):
                self.efficiency_stats(cnf_matrix, pid, txt_file)

    def generate_plots(self):
        """
        Generate tof, mass2, vars, and pT-rapidity plots.
        """
        self._tof_plots()
        self._mass2_plots()
        self._vars_distributions_plots()

    def _tof_plots(self):
        """
        Generates tof plots for each class, both for all simulated particles
        and for XGB-selected ones.
        """
        for pid, particle_name in enumerate(self.classes_names):
            # simulated:
            try:
                plotting_tools.tof_plot(
                    self.particles_df[self.particles_df[self.pid_variable_name] == pid],
                    self.json_file_name,
                    f"{particle_name} (all simulated)",
                )
            except ValueError:
                # tof_plot raises ValueError on an empty selection
                print(f"No simulated {particle_name}s")
            # xgb selected
            try:
                plotting_tools.tof_plot(
                    self.particles_df[self.particles_df["xgb_preds"] == pid],
                    self.json_file_name,
                    f"{particle_name} (XGB-selected)",
                )
            except ValueError:
                print(f"No XGB-selected {particle_name}s")

    def _mass2_plots(self):
        """
        Generates mass2 plots for each class within a class-specific range.
        """
        protons_range = (-0.2, 1.8)
        kaons_range = (-0.2, 0.6)
        pions_range = (-0.3, 0.3)
        # background reuses the pion range
        ranges = [protons_range, kaons_range, pions_range, pions_range]
        for pid, particle_name in enumerate(self.classes_names):
            plotting_tools.plot_mass2(
                self.particles_df[self.particles_df["xgb_preds"] == pid][
                    self.mass2_variable_name
                ],
                self.particles_df[self.particles_df[self.pid_variable_name] == pid][
                    self.mass2_variable_name
                ],
                particle_name,
                ranges[pid],
            )
            plotting_tools.plot_all_particles_mass2(
                self.particles_df[self.particles_df["xgb_preds"] == pid],
                self.mass2_variable_name,
                self.pid_variable_name,
                particle_name,
                ranges[pid],
            )

    def _vars_distributions_plots(self):
        """
        Generates distributions of variables and pT-rapidity graphs for
        true MC / true selected / false selected particles of each class.
        """
        vars_to_draw = json_tools.load_vars_to_draw(self.json_file_name)
        for pid, particle_name in enumerate(self.classes_names):
            plotting_tools.var_distributions_plot(
                vars_to_draw,
                [
                    self.particles_df[
                        (self.particles_df[self.pid_variable_name] == pid)
                    ],
                    self.particles_df[
                        (
                            (self.particles_df[self.pid_variable_name] == pid)
                            & (self.particles_df["xgb_preds"] == pid)
                        )
                    ],
                    self.particles_df[
                        (
                            (self.particles_df[self.pid_variable_name] != pid)
                            & (self.particles_df["xgb_preds"] == pid)
                        )
                    ],
                ],
                [
                    f"true MC {particle_name}",
                    f"true selected {particle_name}",
                    f"false selected {particle_name}",
                ],
                filename=f"vars_dist_{particle_name}",
            )
            plotting_tools.plot_eff_pT_rap(self.particles_df, pid)
            plotting_tools.plot_pt_rapidity(self.particles_df, pid)

    @staticmethod
    def parse_model_name(
        name: str,
        pattern: str = r"model_([\d.]+)_([\d.]+)_(anti)|model_([\d.]+)_([\d.]+)_([a-zA-Z]+)",
    ) -> Tuple[float, float, bool]:
        r"""Parses model name to get info about lower momentum cut, upper momentum cut, and if model is trained for anti_particles.

        Args:
            name (str): Name of the model.
            pattern (str, optional): Pattern of model name.
                Defaults to r"model_([\d.]+)_([\d.]+)_(anti)|model_([\d.]+)_([\d.]+)_([a-zA-Z]+)".

        Raises:
            ValueError: Raises error if model name incorrect.

        Returns:
            Tuple[float, float, bool]: Tuple containing lower_p_cut, upper_p_cut, is_anti
        """
        match = re.match(pattern, name)
        if match:
            # group 3 is only set by the first ("anti") alternative
            if match.group(3):
                lower_p_cut = float(match.group(1))
                upper_p_cut = float(match.group(2))
                is_anti = True
            else:
                lower_p_cut = float(match.group(4))
                upper_p_cut = float(match.group(5))
                is_anti = False
        else:
            raise ValueError("Incorrect model name, regex not found.")
        return (lower_p_cut, upper_p_cut, is_anti)


def parse_args(args: List[str]) -> argparse.Namespace:
    """
    Arguments parser for the main method.

    Args:
        args (List[str]): Arguments from the command line, should be sys.argv[1:].

    Returns:
        argparse.Namespace: argparse.Namespace containing args
    """
    parser = argparse.ArgumentParser(
        prog="ML_PID_CBM ValidateModel",
        description="Program for validating PID ML models",
    )
    parser.add_argument(
        "--config",
        "-c",
        nargs=1,
        required=True,
        type=str,
        help="Filename of path of config json file.",
    )
    parser.add_argument(
        "--modelname",
        "-m",
        nargs=1,
        required=True,
        type=str,
        help="Name of folder containing trained ml model.",
    )
    # either fixed cuts are supplied, or a scan range to evaluate them
    proba_group = parser.add_mutually_exclusive_group(required=True)
    proba_group.add_argument(
        "--probabilitycuts",
        "-p",
        nargs=3,
        type=float,
        help="Probability cut value for respectively protons, kaons, and pions. E.g., 0.9 0.95 0.9",
    )
    proba_group.add_argument(
        "--evaluateproba",
        "-e",
        nargs=3,
        type=float,
        help="Minimal probability cut, maximal, and number of steps to investigate.",
    )
    parser.add_argument(
        "--nworkers",
        "-n",
        type=int,
        default=1,
        help="Max number of workers for ThreadPoolExecutor which reads Root tree with data.",
    )
    # after a scan, cuts are chosen either interactively or automatically
    decision_group = parser.add_mutually_exclusive_group()
    decision_group.add_argument(
        "--interactive",
        "-i",
        action="store_true",
        help="Interactive mode allows selection of probability cuts after evaluating them.",
    )
    decision_group.add_argument(
        "--automatic",
        "-a",
        nargs=1,
        type=float,
        help="""Minimal purity for automatic threshold selection (in percent) e.g., 90.
        Will choose the highest efficiency for purity above this value.
        If max purity is below this value, will choose the highest purity available.""",
    )
    return parser.parse_args(args)


if __name__ == "__main__":
    # parser for main class
    args = parse_args(sys.argv[1:])
    # config arguments to be loaded from args
    json_file_name = args.config[0]
    model_name = args.modelname[0]
    proba_proton, proba_kaon, proba_pion = (
        (args.probabilitycuts[0], args.probabilitycuts[1], args.probabilitycuts[2])
        if args.probabilitycuts is not None
        else (-1.0, -1.0, -1.0)
    )

    n_workers = args.nworkers
    purity_cut = args.automatic[0] if args.automatic is not None else 0.0
    # momentum range and particle/anti-particle flag are encoded in the model name
    lower_p, upper_p, is_anti = ValidateModel.parse_model_name(model_name)
    # loading test data
    data_file_name = json_tools.load_file_name(json_file_name, "test")

    loader = LoadData(data_file_name, json_file_name, lower_p, upper_p, is_anti)
    # loading model handler and applying on dataset
    print(
        f"\nLoading data from {data_file_name}\nApplying model handler from {model_name}"
    )
    # all outputs are written into the model's folder
    os.chdir(f"{model_name}")
    model_hdl = ModelHandler()
    model_hdl.load_model_handler(model_name)
    test_particles = loader.load_tree(model_handler=model_hdl, max_workers=n_workers)
    # validate model object
    validate = ValidateModel(
        lower_p, upper_p, is_anti, json_file_name, test_particles.get_data_frame()
    )
    # remap Pid to match output XGBoost format
    validate.remap_names()
    pid_variable_name = json_tools.load_var_name(json_file_name, "pid")
    # set probability cuts
    if args.evaluateproba is not None:
        proba_proton, proba_kaon, proba_pion = validate.evaluate_probas(
            args.evaluateproba[0],
            args.evaluateproba[1],
            int(args.evaluateproba[2]),
            purity_cut,
            not args.interactive,
        )
    # in interactive mode probas stay at -1 after the scan, so prompt for them
    if args.interactive:
        while proba_proton < 0 or proba_proton > 1:
            proba_proton = float(
                input("Enter the probability threshold for proton (between 0 and 1): ")
            )

        while proba_kaon < 0 or proba_kaon > 1:
            proba_kaon = float(
                input("Enter the probability threshold for kaon (between 0 and 1): ")
            )

        while proba_pion < 0 or proba_pion > 1:
            proba_pion = float(
                input("Enter the probability threshold for pion (between 0 and 1): ")
            )
    # if probabilites are set
    # apply probabilty cuts
    print(
        f"\nApplying probability cuts.\nFor protons: {proba_proton}\nFor kaons: {proba_kaon}\nFor pions: {proba_pion}"
    )
    validate.xgb_preds(proba_proton, proba_kaon, proba_pion)
    # graphs
    validate.confusion_matrix_and_stats()
    print("Generating plots...")
    validate.generate_plots()
    # save validated dataset
    validate.save_df()
class ValidateModel:
    """
    Class for testing the ml model.
    """

    def __init__(
        self,
        lower_p_cut: float,
        upper_p_cut: float,
        anti_particles: bool,
        json_file_name: str,
        particles_df: pd.DataFrame,
    ):
        """
        Args:
            lower_p_cut (float): Lower momentum cut used to train the model.
            upper_p_cut (float): Upper momentum cut used to train the model.
            anti_particles (bool): Whether the model was trained for anti-particles.
            json_file_name (str): Filename (or path) of the config json file.
            particles_df (pd.DataFrame): Dataframe with the test particles.
        """
        self.lower_p_cut = lower_p_cut
        self.upper_p_cut = upper_p_cut
        self.anti_particles = anti_particles
        self.json_file_name = json_file_name
        self.particles_df = particles_df
        # variable names (true pid, mass^2) are defined in the json config
        self.pid_variable_name = json_tools.load_var_name(self.json_file_name, "pid")
        self.mass2_variable_name = json_tools.load_var_name(
            self.json_file_name, "mass2"
        )
        # index order matches the model outputs; the last entry is background
        self.classes_names = ["protons", "kaons", "pions", "bckgr"]

    def get_n_classes(self) -> int:
        """Returns the number of classes, including background."""
        return len(self.classes_names)

    def xgb_preds(self, proba_proton: float, proba_kaon: float, proba_pion: float):
        """Gets particle type as selected by xgboost model if above probability threshold.

        Stores the result in the ``xgb_preds`` column:
        0 = proton, 1 = kaon, 2 = pion, 3 = background
        (background when the winning probability is below its threshold).

        Args:
            proba_proton (float): Probablity threshold to classify particle as proton.
            proba_kaon (float): Probablity threshold to classify particle as kaon.
            proba_pion (float): Probablity threshold to classify particle as pion.
        """
        df = self.particles_df
        # class with the highest BDT probability; extract N from "model_output_N".
        # (str.lstrip strips a character *set*, not a prefix — it only worked by
        # accident for digits 0-2, so split the column name instead)
        df["xgb_preds"] = (
            df[["model_output_0", "model_output_1", "model_output_2"]]
            .idxmax(axis=1)
            .map(lambda column_name: int(column_name.rsplit("_", 1)[-1]))
        )
        # setting to bckgr if smaller than probability threshold
        # (class indices: 0 = protons, 1 = kaons, 2 = pions)
        is_proton = (df["xgb_preds"] == 0) & (df["model_output_0"] > proba_proton)
        is_kaon = (df["xgb_preds"] == 1) & (df["model_output_1"] > proba_kaon)
        is_pion = (df["xgb_preds"] == 2) & (df["model_output_2"] > proba_pion)
        df.loc[~(is_proton | is_kaon | is_pion), "xgb_preds"] = 3

        self.particles_df = df

    def remap_names(self):
        """
        Remaps Pid of particles to output format from XGBoost Model.
        Protons: 0; Kaons: 1; Pions, Electrons, Muons: 2; Other: 3
        """
        # single mapping table chosen once — the two branches previously
        # duplicated the whole map/astype pipeline
        if self.anti_particles:
            pid_map = {
                Pid.ANTI_PROTON.value: 0.0,
                Pid.NEG_KAON.value: 1.0,
                Pid.NEG_PION.value: 2.0,
                Pid.ELECTRON.value: 2.0,
                Pid.NEG_MUON.value: 2.0,
            }
        else:
            pid_map = {
                Pid.PROTON.value: 0.0,
                Pid.POS_KAON.value: 1.0,
                Pid.POS_PION.value: 2.0,
                Pid.POSITRON.value: 2.0,
                Pid.POS_MUON.value: 2.0,
            }
        df = self.particles_df
        # unknown pids fall through to background (3.0); NaNs are left untouched
        df[self.pid_variable_name] = (
            df[self.pid_variable_name]
            .map(defaultdict(lambda: 3.0, pid_map), na_action="ignore")
            .astype(float)
        )
        self.particles_df = df

    def save_df(self):
        """
        Saves dataframe with validated data into pickle format.
        """
        self.particles_df.to_pickle("validated_data.pickle")

    def sigma_selection(self, pid: float, nsigma: float = 5, info: bool = False):
        """Sigma selection for dataframe to remove systematically (not by the ML model) mismatched particles.

        Keeps, for the given pid, only entries whose mass2 lies within
        mean +- nsigma * std of that pid's mass2 distribution.

        Args:
            pid (float): Pid of particle for this selection.
            nsigma (float, optional): Width of the accepted band in standard deviations. Defaults to 5.
            info (bool, optional): Print the percentage of removed entries. Defaults to False.
        """
        df = self.particles_df
        # statistics computed from the selected pid only
        mass2 = df[self.mass2_variable_name]
        same_pid = df[self.pid_variable_name] == pid
        mean = mass2[same_pid].mean()
        std = mass2[same_pid].std()
        outside_sigma = same_pid & (
            (mass2 < (mean - nsigma * std)) | (mass2 > (mean + nsigma * std))
        )
        df_sigma_selected = df[~outside_sigma]
        if info:
            df_len = len(df)
            df1_len = len(df_sigma_selected)
            print(
                "we get rid of "
                + str(round((df_len - df1_len) / df_len * 100, 2))
                + " % of pid = "
                + str(pid)
                + " particle entries"
            )
        self.particles_df = df_sigma_selected

    def evaluate_probas(
        self,
        start: float = 0.3,
        stop: float = 0.98,
        n_steps: int = 30,
        purity_cut: float = 0.0,
        save_fig: bool = True,
    ) -> Tuple[float, float, float]:
        """Method for evaluating probability (BDT) cut effect on efficency and purity.

        Args:
            start (float, optional): Lower range of probablity cuts. Defaults to 0.3.
            stop (float, optional): Upper range of probablity cuts. Defaults to 0.98.
            n_steps (int, optional): Number of probability cuts to try. Defaults to 30.
            purity_cut (float, optional): Minimal purity (percent) for automatic cuts selection. Defaults to 0.
            save_fig (bool, optional): Saves figures (BDT cut vs efficiency and purity) to file if True. Defaults to True.

        Returns:
            Tuple[float, float, float]: Selected probability cut per class, or
                (-1., -1., -1.) when purity_cut is not set.
        """
        print(
            f"Checking efficiency and purity for {int(n_steps)} probablity cuts between {start}, and {stop}..."
        )
        probas = np.linspace(start, stop, n_steps)
        # one efficiency/purity history per signal class (protons, kaons, pions)
        efficiencies = [[], [], []]
        purities = [[], [], []]
        best_cuts = [0.0, 0.0, 0.0]
        max_efficiencies = [0.0, 0.0, 0.0]
        max_purities = [0.0, 0.0, 0.0]

        for proba in probas:
            # same threshold applied to all three classes during the scan
            self.xgb_preds(proba, proba, proba)
            cnf_matrix = confusion_matrix(
                self.particles_df[self.pid_variable_name],
                self.particles_df["xgb_preds"],
            )
            for pid in range(self.get_n_classes() - 1):
                efficiency, purity = self.efficiency_stats(
                    cnf_matrix, pid, print_output=False
                )
                efficiencies[pid].append(efficiency)
                purities[pid].append(purity)
                if purity_cut > 0.0:
                    # Automatic threshold selection:
                    # choose the highest efficiency among cuts whose purity is
                    # above purity_cut; if no cut reaches it, fall back to the
                    # highest purity available. The fallback trackers must not
                    # block a later qualifying cut (bug fix: previously a
                    # high-efficiency below-cut point recorded early prevented
                    # any above-cut point from ever being selected).
                    had_qualifying = max_purities[pid] >= purity_cut
                    if purity >= purity_cut:
                        if not had_qualifying or efficiency > max_efficiencies[pid]:
                            best_cuts[pid] = proba
                            max_efficiencies[pid] = efficiency
                            max_purities[pid] = purity
                    elif not had_qualifying and purity > max_purities[pid]:
                        best_cuts[pid] = proba
                        max_efficiencies[pid] = efficiency
                        max_purities[pid] = purity

        plotting_tools.plot_efficiency_purity(probas, efficiencies, purities, save_fig)
        if save_fig:
            print("Plots ready!")
        if purity_cut > 0:
            print(f"Selected probaility cuts: {best_cuts}")
            return (best_cuts[0], best_cuts[1], best_cuts[2])
        return (-1.0, -1.0, -1.0)

    @staticmethod
    def efficiency_stats(
        cnf_matrix: np.ndarray,
        pid: int,
        txt_tile: io.TextIOWrapper = None,
        print_output: bool = True,
    ) -> Tuple[float, float]:
        """
        Calculates efficiency and purity stats from a confusion matrix.
        Efficiency is calculated as correctly identified X / all true simulated X.
        Purity is calculated as correctly identified X / all identified X.

        Args:
            cnf_matrix (np.ndarray): Confusion matrix generated by sklearn.metrics.confusion_matrix.
            pid (int): Pid of particles to print efficiency stats.
            txt_tile (io.TextIOWrapper): Text file to write the output. Defaults to None.
                NOTE(review): name kept for backward compatibility (likely a typo of txt_file).
            print_output (bool): Whether to print the output to stdout. Defaults to True.

        Returns:
            Tuple[float, float]: Tuple with efficiency and purity (in percent).
        """
        # row pid = all simulated entries of the class; column pid = all
        # entries reconstructed as that class
        all_simulated_signal = cnf_matrix[pid].sum()
        true_signal = cnf_matrix[pid][pid]
        false_signal = cnf_matrix[:, pid].sum() - true_signal
        reconstructed_signals = true_signal + false_signal

        # NOTE(review): produces nan (with a numpy warning) if a class is absent
        efficiency = (true_signal / all_simulated_signal) * 100
        purity = (true_signal / reconstructed_signals) * 100

        stats = f"""
        For particle ID = {pid}:
        Efficiency: {efficiency:.2f}%
        Purity: {purity:.2f}%
        """

        if print_output:
            print(stats)

        if txt_tile is not None:
            txt_tile.writelines(stats)

        return (efficiency, purity)

    def confusion_matrix_and_stats(
        self, efficiency_filename: str = "efficiency_stats.txt"
    ):
        """
        Generates confusion matrix plots and efficiency/purity stats.

        Args:
            efficiency_filename (str, optional): Output text file for the stats.
                Defaults to "efficiency_stats.txt".
        """
        cnf_matrix = confusion_matrix(
            self.particles_df[self.pid_variable_name], self.particles_df["xgb_preds"]
        )
        plotting_tools.plot_confusion_matrix(cnf_matrix)
        plotting_tools.plot_confusion_matrix(cnf_matrix, normalize=True)
        # context manager guarantees the file is closed even if stats raise
        with open(efficiency_filename, "w+") as txt_file:
            for pid in range(self.get_n_classes() - 1):
                self.efficiency_stats(cnf_matrix, pid, txt_file)

    def generate_plots(self):
        """
        Generate tof, mass2, vars, and pT-rapidity plots.
        """
        self._tof_plots()
        self._mass2_plots()
        self._vars_distributions_plots()

    def _tof_plots(self):
        """
        Generates tof plots for each class, both for all simulated particles
        and for XGB-selected ones.
        """
        for pid, particle_name in enumerate(self.classes_names):
            # simulated:
            try:
                plotting_tools.tof_plot(
                    self.particles_df[self.particles_df[self.pid_variable_name] == pid],
                    self.json_file_name,
                    f"{particle_name} (all simulated)",
                )
            except ValueError:
                # tof_plot raises ValueError on an empty selection
                print(f"No simulated {particle_name}s")
            # xgb selected
            try:
                plotting_tools.tof_plot(
                    self.particles_df[self.particles_df["xgb_preds"] == pid],
                    self.json_file_name,
                    f"{particle_name} (XGB-selected)",
                )
            except ValueError:
                print(f"No XGB-selected {particle_name}s")

    def _mass2_plots(self):
        """
        Generates mass2 plots for each class within a class-specific range.
        """
        protons_range = (-0.2, 1.8)
        kaons_range = (-0.2, 0.6)
        pions_range = (-0.3, 0.3)
        # background reuses the pion range
        ranges = [protons_range, kaons_range, pions_range, pions_range]
        for pid, particle_name in enumerate(self.classes_names):
            plotting_tools.plot_mass2(
                self.particles_df[self.particles_df["xgb_preds"] == pid][
                    self.mass2_variable_name
                ],
                self.particles_df[self.particles_df[self.pid_variable_name] == pid][
                    self.mass2_variable_name
                ],
                particle_name,
                ranges[pid],
            )
            plotting_tools.plot_all_particles_mass2(
                self.particles_df[self.particles_df["xgb_preds"] == pid],
                self.mass2_variable_name,
                self.pid_variable_name,
                particle_name,
                ranges[pid],
            )

    def _vars_distributions_plots(self):
        """
        Generates distributions of variables and pT-rapidity graphs for
        true MC / true selected / false selected particles of each class.
        """
        vars_to_draw = json_tools.load_vars_to_draw(self.json_file_name)
        for pid, particle_name in enumerate(self.classes_names):
            plotting_tools.var_distributions_plot(
                vars_to_draw,
                [
                    self.particles_df[
                        (self.particles_df[self.pid_variable_name] == pid)
                    ],
                    self.particles_df[
                        (
                            (self.particles_df[self.pid_variable_name] == pid)
                            & (self.particles_df["xgb_preds"] == pid)
                        )
                    ],
                    self.particles_df[
                        (
                            (self.particles_df[self.pid_variable_name] != pid)
                            & (self.particles_df["xgb_preds"] == pid)
                        )
                    ],
                ],
                [
                    f"true MC {particle_name}",
                    f"true selected {particle_name}",
                    f"false selected {particle_name}",
                ],
                filename=f"vars_dist_{particle_name}",
            )
            plotting_tools.plot_eff_pT_rap(self.particles_df, pid)
            plotting_tools.plot_pt_rapidity(self.particles_df, pid)

    @staticmethod
    def parse_model_name(
        name: str,
        pattern: str = r"model_([\d.]+)_([\d.]+)_(anti)|model_([\d.]+)_([\d.]+)_([a-zA-Z]+)",
    ) -> Tuple[float, float, bool]:
        r"""Parses model name to get info about lower momentum cut, upper momentum cut, and if model is trained for anti_particles.

        Args:
            name (str): Name of the model.
            pattern (str, optional): Pattern of model name.
                Defaults to r"model_([\d.]+)_([\d.]+)_(anti)|model_([\d.]+)_([\d.]+)_([a-zA-Z]+)".

        Raises:
            ValueError: Raises error if model name incorrect.

        Returns:
            Tuple[float, float, bool]: Tuple containing lower_p_cut, upper_p_cut, is_anti
        """
        match = re.match(pattern, name)
        if match:
            # group 3 is only set by the first ("anti") alternative
            if match.group(3):
                lower_p_cut = float(match.group(1))
                upper_p_cut = float(match.group(2))
                is_anti = True
            else:
                lower_p_cut = float(match.group(4))
                upper_p_cut = float(match.group(5))
                is_anti = False
        else:
            raise ValueError("Incorrect model name, regex not found.")
        return (lower_p_cut, upper_p_cut, is_anti)
Class for testing the ml model
25 def __init__( 26 self, 27 lower_p_cut: float, 28 upper_p_cut: float, 29 anti_particles: bool, 30 json_file_name: str, 31 particles_df: pd.DataFrame, 32 ): 33 self.lower_p_cut = lower_p_cut 34 self.upper_p_cut = upper_p_cut 35 self.anti_particles = anti_particles 36 self.json_file_name = json_file_name 37 self.particles_df = particles_df 38 self.pid_variable_name = json_tools.load_var_name(self.json_file_name, "pid") 39 self.mass2_variable_name = json_tools.load_var_name( 40 self.json_file_name, "mass2" 41 ) 42 self.classes_names = ["protons", "kaons", "pions", "bckgr"]
47 def xgb_preds(self, proba_proton: float, proba_kaon: float, proba_pion: float): 48 """Gets particle type as selected by xgboost model if above probability threshold. 49 50 Args: 51 proba_proton (float): Probablity threshold to classify particle as proton. 52 proba_kaon (float): Probablity threshold to classify particle as kaon. 53 proba_pion (float): Probablity threshold to classify particle as pion. 54 """ 55 df = self.particles_df 56 df["xgb_preds"] = ( 57 df[["model_output_0", "model_output_1", "model_output_2"]] 58 .idxmax(axis=1) 59 .map(lambda x: x.lstrip("model_output_")) 60 .astype(int) 61 ) 62 # setting to bckgr if smaller than probability threshold 63 proton = (df["xgb_preds"] == 0) & (df["model_output_0"] > proba_proton) 64 pion = (df["xgb_preds"] == 1) & (df["model_output_1"] > proba_kaon) 65 kaon = (df["xgb_preds"] == 2) & (df["model_output_2"] > proba_pion) 66 df.loc[~(proton | pion | kaon), "xgb_preds"] = 3 67 68 self.particles_df = df
Gets particle type as selected by xgboost model if above probability threshold.
Args: proba_proton (float): Probablity threshold to classify particle as proton. proba_kaon (float): Probablity threshold to classify particle as kaon. proba_pion (float): Probablity threshold to classify particle as pion.
70 def remap_names(self): 71 """ 72 Remaps Pid of particles to output format from XGBoost Model. 73 Protons: 0; Kaons: 1; Pions, Electrons, Muons: 2; Other: 3 74 75 """ 76 df = self.particles_df 77 if self.anti_particles: 78 df[self.pid_variable_name] = ( 79 df[self.pid_variable_name] 80 .map( 81 defaultdict( 82 lambda: 3.0, 83 { 84 Pid.ANTI_PROTON.value: 0.0, 85 Pid.NEG_KAON.value: 1.0, 86 Pid.NEG_PION.value: 2.0, 87 Pid.ELECTRON.value: 2.0, 88 Pid.NEG_MUON.value: 2.0, 89 }, 90 ), 91 na_action="ignore", 92 ) 93 .astype(float) 94 ) 95 else: 96 df[self.pid_variable_name] = ( 97 df[self.pid_variable_name] 98 .map( 99 defaultdict( 100 lambda: 3.0, 101 { 102 Pid.PROTON.value: 0.0, 103 Pid.POS_KAON.value: 1.0, 104 Pid.POS_PION.value: 2.0, 105 Pid.POSITRON.value: 2.0, 106 Pid.POS_MUON.value: 2.0, 107 }, 108 ), 109 na_action="ignore", 110 ) 111 .astype(float) 112 ) 113 self.particles_df = df
Remaps Pid of particles to output format from XGBoost Model. Protons: 0; Kaons: 1; Pions, Electrons, Muons: 2; Other: 3
115 def save_df(self): 116 """ 117 Saves dataframe with validated data into pickle format. 118 """ 119 self.particles_df.to_pickle("validated_data.pickle")
Saves dataframe with validated data into pickle format.
121 def sigma_selection(self, pid: float, nsigma: float = 5, info: bool = False): 122 """Sigma selection for dataframe to remove systmatically (not by the ML model) mismatched particles. 123 124 Args: 125 pid (float): Pid of particle for this selection 126 nsigma (float, optional): _description_. Defaults to 5. 127 info (bool, optional): _description_. Defaults to False. 128 """ 129 df = self.particles_df 130 # for selected pid 131 mean = df[df[self.pid_variable_name] == pid][self.mass2_variable_name].mean() 132 std = df[df[self.pid_variable_name] == pid][self.mass2_variable_name].std() 133 outside_sigma = (df[self.pid_variable_name] == pid) & ( 134 (df[self.mass2_variable_name] < (mean - nsigma * std)) 135 | (df[self.mass2_variable_name] > (mean + nsigma * std)) 136 ) 137 df_sigma_selected = df[~outside_sigma] 138 if info: 139 df_len = len(df) 140 df1_len = len(df_sigma_selected) 141 print( 142 "we get rid of " 143 + str(round((df_len - df1_len) / df_len * 100, 2)) 144 + " % of pid = " 145 + str(pid) 146 + " particle entries" 147 ) 148 self.particles_df = df_sigma_selected
Sigma selection for dataframe to remove systematically (not by the ML model) mismatched particles.
Args: pid (float): Pid of particle for this selection nsigma (float, optional): _description_. Defaults to 5. info (bool, optional): _description_. Defaults to False.
150 def evaluate_probas( 151 self, 152 start: float = 0.3, 153 stop: float = 0.98, 154 n_steps: int = 30, 155 purity_cut: float = 0.0, 156 save_fig: bool = True, 157 ) -> Tuple[float, float, float]: 158 """Method for evaluating probability (BDT) cut effect on efficency and purity. 159 160 Args: 161 start (float, optional): Lower range of probablity cuts. Defaults to 0.3. 162 stop (float, optional): Upper range of probablity cuts. Defaults to 0.98. 163 n_steps (int, optional): Number of probability cuts to try. Defaults to 30. 164 pid_variable_name (str, optional): Name of the variable containing true Pid. Defaults to "Complex_pid". 165 purity_cut (float, optional): Minimal purity for automatic cuts selection. Defaults to 0.. 166 save_fig (bool, optional): Saves figures (BDT cut vs efficiency and purity) to file if True. Defaults to True. 167 168 Returns: 169 Tuple[float, float, float]: Probability cut for each variable. 170 """ 171 print( 172 f"Checking efficiency and purity for {int(n_steps)} probablity cuts between {start}, and {stop}..." 
173 ) 174 probas = np.linspace(start, stop, n_steps) 175 efficienciess_protons, efficiencies_kaons, efficiencies_pions = [], [], [] 176 efficiencies = [efficienciess_protons, efficiencies_kaons, efficiencies_pions] 177 purities_protons, purities_kaons, purities_pions = [], [], [] 178 purities = [purities_protons, purities_kaons, purities_pions] 179 best_cuts = [0.0, 0.0, 0.0] 180 max_efficiencies = [0.0, 0.0, 0.0] 181 max_purities = [0.0, 0.0, 0.0] 182 183 for proba in probas: 184 self.xgb_preds(proba, proba, proba) 185 # confusion matrix 186 cnf_matrix = confusion_matrix( 187 self.particles_df[self.pid_variable_name], 188 self.particles_df["xgb_preds"], 189 ) 190 for pid in range(self.get_n_classes() - 1): 191 efficiency, purity = self.efficiency_stats( 192 cnf_matrix, pid, print_output=False 193 ) 194 efficiencies[pid].append(efficiency) 195 purities[pid].append(purity) 196 if purity_cut > 0.0: 197 # Minimal purity for automatic threshold selection. 198 # Will choose the highest efficiency for purity above this value. 199 if purity >= purity_cut: 200 if efficiency > max_efficiencies[pid]: 201 best_cuts[pid] = proba 202 max_efficiencies[pid] = efficiency 203 max_purities[pid] = purity 204 # If max purity is below this value, will choose the highest purity available. 205 else: 206 if purity > max_purities[pid]: 207 best_cuts[pid] = proba 208 max_efficiencies[pid] = efficiency 209 max_purities[pid] = purity 210 211 plotting_tools.plot_efficiency_purity(probas, efficiencies, purities, save_fig) 212 if save_fig: 213 print("Plots ready!") 214 if purity_cut > 0: 215 print(f"Selected probaility cuts: {best_cuts}") 216 return (best_cuts[0], best_cuts[1], best_cuts[2]) 217 else: 218 return (-1.0, -1.0, -1.0)
Method for evaluating probability (BDT) cut effect on efficiency and purity.
Args: start (float, optional): Lower range of probability cuts. Defaults to 0.3. stop (float, optional): Upper range of probability cuts. Defaults to 0.98. n_steps (int, optional): Number of probability cuts to try. Defaults to 30. pid_variable_name (str, optional): Name of the variable containing true Pid. Defaults to "Complex_pid". purity_cut (float, optional): Minimal purity for automatic cuts selection. Defaults to 0.0. save_fig (bool, optional): Saves figures (BDT cut vs efficiency and purity) to file if True. Defaults to True.
Returns: Tuple[float, float, float]: Probability cut for each variable.
220 @staticmethod 221 def efficiency_stats( 222 cnf_matrix: np.ndarray, 223 pid: int, 224 txt_tile: io.TextIOWrapper = None, 225 print_output: bool = True, 226 ) -> Tuple[float, float]: 227 """ 228 Prints efficiency stats from confusion matrix into efficiency_stats.txt file and stdout. 229 Efficiency is calculated as correctly identified X / all true simulated X 230 Purity is calculated as correctly identified X / all identified X 231 232 Args: 233 cnf_matrix (np.ndarray): Confusion matrix generated by sklearn.metrics.confusion_matrix. 234 pid (int): Pid of particles to print efficiency stats. 235 txt_tile (io.TextIOWrapper): Text file to write the output. Defaults to None. 236 print_output (bool): Whether to print the output to stdout. Defaults to True. 237 238 Returns: 239 Tuple[float, float]: Tuple with efficiency and purity 240 """ 241 all_simulated_signal = cnf_matrix[pid].sum() 242 true_signal = cnf_matrix[pid][pid] 243 false_signal = cnf_matrix[:, pid].sum() - true_signal 244 reconstructed_signals = true_signal + false_signal 245 246 efficiency = (true_signal / all_simulated_signal) * 100 247 purity = (true_signal / reconstructed_signals) * 100 248 249 stats = f""" 250 For particle ID = {pid}: 251 Efficiency: {efficiency:.2f}% 252 Purity: {purity:.2f}% 253 """ 254 255 if print_output: 256 print(stats) 257 258 if txt_tile is not None: 259 txt_tile.writelines(stats) 260 261 return (efficiency, purity)
Prints efficiency stats from confusion matrix into efficiency_stats.txt file and stdout. Efficiency is calculated as correctly identified X / all true simulated X Purity is calculated as correctly identified X / all identified X
Args: cnf_matrix (np.ndarray): Confusion matrix generated by sklearn.metrics.confusion_matrix. pid (int): Pid of particles to print efficiency stats. txt_tile (io.TextIOWrapper): Text file to write the output. Defaults to None. print_output (bool): Whether to print the output to stdout. Defaults to True.
Returns: Tuple[float, float]: Tuple with efficiency and purity
263 def confusion_matrix_and_stats( 264 self, efficiency_filename: str = "efficiency_stats.txt" 265 ): 266 """ 267 Generates confusion matrix and efficiency/purity stats. 268 """ 269 cnf_matrix = confusion_matrix( 270 self.particles_df[self.pid_variable_name], self.particles_df["xgb_preds"] 271 ) 272 plotting_tools.plot_confusion_matrix(cnf_matrix) 273 plotting_tools.plot_confusion_matrix(cnf_matrix, normalize=True) 274 txt_file = open(efficiency_filename, "w+") 275 for pid in range(self.get_n_classes() - 1): 276 self.efficiency_stats(cnf_matrix, pid, txt_file) 277 txt_file.close()
Generates confusion matrix and efficiency/purity stats.
    def generate_plots(self):
        """
        Generate tof, mass2, and variable-distribution plots.

        NOTE(review): the original summary also mentioned pT-rapidity plots,
        but no corresponding call is visible here -- confirm whether one of
        the helpers below produces it.
        """
        self._tof_plots()
        self._mass2_plots()
        self._vars_distributions_plots()
Generate tof, mass2, vars, and pT-rapidity plots
373 @staticmethod 374 def parse_model_name( 375 name: str, 376 pattern: str = r"model_([\d.]+)_([\d.]+)_(anti)|model_([\d.]+)_([\d.]+)_([a-zA-Z]+)", 377 ) -> Tuple[float, float, bool]: 378 """Parser model name to get info about lower momentum cut, upper momentum cut, and if model is trained for anti_particles. 379 380 Args: 381 name (str): Name of the model. 382 pattern (_type_, optional): Pattern of model name. 383 Defaults to r"model_([\d.]+)_([\d.]+)_(anti)|model_([\d.]+)_([\d.]+)_([a-zA-Z]+)". 384 385 Raises: 386 ValueError: Raises error if model name incorrect. 387 388 Returns: 389 Tuple[float, float, bool]: Tuple containing lower_p_cut, upper_p_cut, is_anti 390 """ 391 match = re.match(pattern, name) 392 if match: 393 if match.group(3): 394 lower_p_cut = float(match.group(1)) 395 upper_p_cut = float(match.group(2)) 396 is_anti = True 397 else: 398 lower_p_cut = float(match.group(4)) 399 upper_p_cut = float(match.group(5)) 400 is_anti = False 401 else: 402 raise ValueError("Incorrect model name, regex not found.") 403 return (lower_p_cut, upper_p_cut, is_anti)
Parses model name to get info about lower momentum cut, upper momentum cut, and whether the model is trained for anti_particles.
Args: name (str): Name of the model. pattern (_type_, optional): Pattern of model name. Defaults to r"model_([\d.]+)_([\d.]+)_(anti)|model_([\d.]+)_([\d.]+)_([a-zA-Z]+)".
Raises: ValueError: Raises error if model name incorrect.
Returns: Tuple[float, float, bool]: Tuple containing lower_p_cut, upper_p_cut, is_anti
def parse_args(args: List[str]) -> argparse.Namespace:
    """
    Arguments parser for the main method.

    Args:
        args (List[str]): Arguments from the command line, should be sys.argv[1:].

    Returns:
        argparse.Namespace: argparse.Namespace containing args
    """
    parser = argparse.ArgumentParser(
        prog="ML_PID_CBM ValidateModel",
        description="Program for validating PID ML models",
    )
    parser.add_argument(
        "--config",
        "-c",
        nargs=1,
        required=True,
        type=str,
        help="Filename or path of config json file.",
    )
    parser.add_argument(
        "--modelname",
        "-m",
        nargs=1,
        required=True,
        type=str,
        help="Name of folder containing trained ml model.",
    )
    # exactly one of -p / -e must be given
    proba_group = parser.add_mutually_exclusive_group(required=True)
    proba_group.add_argument(
        "--probabilitycuts",
        "-p",
        nargs=3,
        type=float,
        help="Probability cut value for respectively protons, kaons, and pions. E.g., 0.9 0.95 0.9",
    )
    proba_group.add_argument(
        "--evaluateproba",
        "-e",
        nargs=3,
        type=float,
        help="Minimal probability cut, maximal, and number of steps to investigate.",
    )
    parser.add_argument(
        "--nworkers",
        "-n",
        type=int,
        default=1,
        help="Max number of workers for ThreadPoolExecutor which reads Root tree with data.",
    )
    # -i and -a are optional and mutually exclusive
    decision_group = parser.add_mutually_exclusive_group()
    decision_group.add_argument(
        "--interactive",
        "-i",
        action="store_true",
        help="Interactive mode allows selection of probability cuts after evaluating them.",
    )
    decision_group.add_argument(
        "--automatic",
        "-a",
        nargs=1,
        type=float,
        help="""Minimal purity for automatic threshold selection (in percent) e.g., 90.
        Will choose the highest efficiency for purity above this value.
        If max purity is below this value, will choose the highest purity available.""",
    )
    return parser.parse_args(args)
Arguments parser for the main method.
Args: args (List[str]): Arguments from the command line, should be sys.argv[1:].
Returns: argparse.Namespace: argparse.Namespace containing args