ml_pid_cbm.gauss.validate_multiple_gauss
"""Compute confusion-matrix statistics and plots for several validated PID Gauss models.

Intended to be run as a script; all outputs are written into an ``all_models``
subdirectory of the current working directory.
"""
import argparse
import os
import sys
from shutil import copy2
from typing import List

from ml_pid_cbm.gauss.validate_gauss import ValidateGauss
from ml_pid_cbm.validate_multiple_models import ValidateMultipleModels


class ValidateMultipleGauss(ValidateMultipleModels, ValidateGauss):
    """
    Class for validating data from multiple Gauss models.

    Combines ValidateMultipleModels and ValidateGauss. It defines no members
    of its own, so every method is inherited from those bases following
    Python's MRO (ValidateMultipleModels is consulted first).
    """


def parse_args(args: List[str]) -> argparse.Namespace:
    """
    Arguments parser for the main method.

    Args:
        args (List[str]): Arguments from the command line, should be sys.argv[1:].

    Returns:
        argparse.Namespace: Parsed arguments with attributes ``modelnames``
        (list of str), ``config`` (one-element list of str, because of
        ``nargs=1``) and ``nworkers`` (int, default 1).

    Raises:
        SystemExit: if the arguments are invalid or ``-h``/``--help`` is given
        (standard argparse behavior).
    """
    parser = argparse.ArgumentParser(
        prog="ML_PID_CBM ValidateMultipleGauss",
        description="Program for loading multiple validated PID Gauss models",
    )
    parser.add_argument(
        "--modelnames",
        "-m",
        nargs="+",
        required=True,
        type=str,
        help="Names of folders containing trained and validated ML models.",
    )
    parser.add_argument(
        "--config",
        "-c",
        nargs=1,
        required=True,
        type=str,
        help="Filename or path of config json file.",
    )
    parser.add_argument(
        "--nworkers",
        "-n",
        type=int,
        default=1,
        help="Max number of workers for ThreadPoolExecutor which reads Root tree with data.",
    )
    return parser.parse_args(args)


if __name__ == "__main__":
    args = parse_args(sys.argv[1:])
    # config arguments to be loaded from args
    json_file_name = args.config[0]
    models = args.modelnames
    n_workers = args.nworkers
    # Keep the command-line order of the models while dropping duplicates.
    # A set comprehension would iterate in a hash-randomized order that can
    # change between runs, making the processing order non-reproducible.
    pickle_files = list(
        dict.fromkeys(f"{model}/validated_data.pickle" for model in models)
    )
    validate = ValidateMultipleGauss(json_file_name, pickle_files, n_workers)
    # Resolve the config path before changing directory so it can still be copied.
    json_file_path = os.path.abspath(json_file_name)
    # new folder for all output files; exist_ok avoids a check-then-create race
    os.makedirs("all_models", exist_ok=True)
    os.chdir("all_models")
    copy2(json_file_path, os.getcwd())
    # graphs
    validate.confusion_matrix_and_stats()
    print("Generating plots...")
    validate.generate_plots()
class
ValidateMultipleGauss(ml_pid_cbm.validate_multiple_models.ValidateMultipleModels, ml_pid_cbm.gauss.validate_gauss.ValidateGauss):
class ValidateMultipleGauss(ValidateMultipleModels, ValidateGauss):
    """
    Class for validating data from multiple Gauss models.

    Combines ValidateMultipleModels and ValidateGauss. It defines no members
    of its own, so every method is inherited from those bases following
    Python's MRO (ValidateMultipleModels is consulted first).
    """
Class for validating data from multiple Gauss models. Inherits from ValidateMultipleModels and ValidateGauss.
Inherited Members
- validate_model.ValidateModel
- get_n_classes
- xgb_preds
- remap_names
- save_df
- sigma_selection
- evaluate_probas
- efficiency_stats
- confusion_matrix_and_stats
- generate_plots
- parse_model_name
def
parse_args(args: List[str]) -> argparse.Namespace:
def parse_args(args: List[str]) -> argparse.Namespace:
    """
    Arguments parser for the main method.

    Args:
        args (List[str]): Arguments from the command line, should be sys.argv[1:].

    Returns:
        argparse.Namespace: Parsed arguments with attributes ``modelnames``
        (list of str), ``config`` (one-element list of str, because of
        ``nargs=1``) and ``nworkers`` (int, default 1).

    Raises:
        SystemExit: if the arguments are invalid or ``-h``/``--help`` is given
        (standard argparse behavior).
    """
    parser = argparse.ArgumentParser(
        prog="ML_PID_CBM ValidateMultipleGauss",
        description="Program for loading multiple validated PID Gauss models",
    )
    parser.add_argument(
        "--modelnames",
        "-m",
        nargs="+",
        required=True,
        type=str,
        help="Names of folders containing trained and validated ML models.",
    )
    parser.add_argument(
        "--config",
        "-c",
        nargs=1,
        required=True,
        type=str,
        help="Filename or path of config json file.",
    )
    parser.add_argument(
        "--nworkers",
        "-n",
        type=int,
        default=1,
        help="Max number of workers for ThreadPoolExecutor which reads Root tree with data.",
    )
    return parser.parse_args(args)
Arguments parser for the main method.
Args: args (List[str]): Arguments from the command line, should be sys.argv[1:].
Returns: argparse.Namespace: argparse.Namespace containing the parsed args