ml_pid_cbm.binary.validate_multiple_binary_models
"""
Command-line entry point for loading multiple validated binary PID ML models
and producing combined statistics and plots for them.
"""
import argparse
import os
import sys
from shutil import copy2
from typing import List, Set

from ml_pid_cbm.binary.validate_binary_model import ValidateBinaryModel
from ml_pid_cbm.validate_multiple_models import ValidateMultipleModels


class ValidateMultipleBinaryModels(ValidateMultipleModels, ValidateBinaryModel):
    """
    Class for validating data from multiple binary models.

    Inherits from ValidateMultipleModels and ValidateBinaryModel; the method
    resolution order decides which base provides each inherited method.
    """

    def __init__(self, json_file_name: str, files_list: Set[str], n_workers: int):
        """
        Args:
            json_file_name (str): Name or path of the config json file.
            files_list (Set[str]): Paths of the validated-data files to load
                (the script below passes one ``validated_data.pickle`` per
                model folder).
            n_workers (int): Max number of workers, forwarded unchanged to
                the base-class constructor.
        """
        # super() resolves to ValidateMultipleModels first (leftmost base in the MRO).
        super().__init__(json_file_name, files_list, n_workers)


def parse_args(args: List[str]) -> argparse.Namespace:
    """
    Arguments parser for the main method.

    Args:
        args (List[str]): Arguments from the command line, should be sys.argv[1:].

    Returns:
        argparse.Namespace: argparse.Namespace containing the parsed arguments
            (``modelnames``, ``config`` and ``nworkers``).
    """
    parser = argparse.ArgumentParser(
        prog="ML_PID_CBM ValidateMultipleBinaryModels",
        description="Program for loading multiple validated binary PID ML models",
    )
    parser.add_argument(
        "--modelnames",
        "-m",
        nargs="+",
        required=True,
        type=str,
        help="Names of folders containing trained and validated ML models.",
    )
    parser.add_argument(
        "--config",
        "-c",
        nargs=1,  # yields a one-element list, hence args.config[0] below
        required=True,
        type=str,
        help="Filename or path of config json file.",
    )
    parser.add_argument(
        "--nworkers",
        "-n",
        type=int,
        default=1,
        help="Max number of workers for ThreadPoolExecutor which reads Root tree with data.",
    )
    return parser.parse_args(args)


if __name__ == "__main__":
    args = parse_args(sys.argv[1:])
    # config arguments to be loaded from args
    json_file_name = args.config[0]
    models = args.modelnames
    n_workers = args.nworkers
    # one validated-data pickle per model folder
    pickle_files = {f"{model}/validated_data.pickle" for model in models}
    # Construct before chdir below: the pickle paths are relative to the
    # current directory (presumably read by the constructor — verify).
    validate = ValidateMultipleBinaryModels(json_file_name, pickle_files, n_workers)
    # Resolve the config path now, while the original cwd is still active,
    # so it can be copied after changing into the output folder.
    json_file_path = os.path.abspath(json_file_name)
    # new folder for all files; exist_ok replaces the separate exists() check
    os.makedirs("all_models", exist_ok=True)
    os.chdir("all_models")
    copy2(json_file_path, os.getcwd())
    # graphs
    validate.confusion_matrix_and_stats()
    print("Generating plots...")
    validate.generate_plots()
class
ValidateMultipleBinaryModels(ml_pid_cbm.validate_multiple_models.ValidateMultipleModels, ml_pid_cbm.binary.validate_binary_model.ValidateBinaryModel):
class ValidateMultipleBinaryModels(ValidateMultipleModels, ValidateBinaryModel):
    """
    Class for validating data from multiple binary models.

    Inherits from ValidateMultipleModels and ValidateBinaryModel; the method
    resolution order decides which base provides each inherited method.
    """

    def __init__(self, json_file_name: str, files_list: Set[str], n_workers: int):
        """
        Args:
            json_file_name (str): Name or path of the config json file.
            files_list (Set[str]): Paths of the validated-data files to load.
            n_workers (int): Max number of workers, forwarded unchanged to
                the base-class constructor.
        """
        # super() resolves to ValidateMultipleModels first (leftmost base in the MRO).
        super().__init__(json_file_name, files_list, n_workers)
Class for validating data from multiple binary models. Inherits from ValidateMultipleModels and ValidateBinaryModel.
Inherited Members
- validate_model.ValidateModel
- get_n_classes
- xgb_preds
- remap_names
- save_df
- sigma_selection
- evaluate_probas
- efficiency_stats
- confusion_matrix_and_stats
- generate_plots
- parse_model_name
def
parse_args(args: List[str]) -> argparse.Namespace:
def parse_args(args: List[str]) -> argparse.Namespace:
    """
    Arguments parser for the main method.

    Args:
        args (List[str]): Arguments from the command line, should be sys.argv[1:].

    Returns:
        argparse.Namespace: argparse.Namespace containing the parsed arguments
            (``modelnames`` as a list, ``config`` as a one-element list and
            ``nworkers`` as an int defaulting to 1).
    """
    parser = argparse.ArgumentParser(
        prog="ML_PID_CBM ValidateMultipleBinaryModels",
        description="Program for loading multiple validated binary PID ML models",
    )
    parser.add_argument(
        "--modelnames",
        "-m",
        nargs="+",
        required=True,
        type=str,
        help="Names of folders containing trained and validated ML models.",
    )
    parser.add_argument(
        "--config",
        "-c",
        nargs=1,  # yields a one-element list; callers index [0]
        required=True,
        type=str,
        help="Filename or path of config json file.",
    )
    parser.add_argument(
        "--nworkers",
        "-n",
        type=int,
        default=1,
        help="Max number of workers for ThreadPoolExecutor which reads Root tree with data.",
    )
    return parser.parse_args(args)
Arguments parser for the main method.
Args: args (List[str]): Arguments from the command line, should be sys.argv[1:].
Returns: argparse.Namespace: argparse.Namespace containing the parsed arguments