ml_pid_cbm.validate_multiple_models
"""Merge the validated datasets of several models and evaluate them jointly.

Run as a script, this module reads ``<model>/validated_data.pickle`` for every
model folder given on the command line, concatenates them into one dataframe,
and then produces the confusion matrix, statistics and plots inside an
``all_models`` directory.
"""
import argparse
import os
import sys
from concurrent.futures import ThreadPoolExecutor
from shutil import copy2
from typing import List, Set

import pandas as pd

from validate_model import ValidateModel


class ValidateMultipleModels(ValidateModel):
    """
    Class for validating data from multiple models.
    Inherits from ValidateModel
    """

    def __init__(self, json_file_name: str, files_list: Set[str], n_workers: int = 1):
        """Initialise the parent validator and load the merged dataset.

        Args:
            json_file_name (str): Name or path of the config json file.
            files_list (Set[str]): Pickle files with validated datasets.
            n_workers (int, optional): Number of threads used for loading. Defaults to 1.
        """
        # These positional arguments are forwarded unchanged; their meaning is
        # defined by ValidateModel.__init__, which is not part of this module.
        super().__init__(-12, 12, False, json_file_name, None)
        self.particles_df = self.load_pickles(files_list, n_workers)

    @staticmethod
    def load_pickles(files_list: Set[str], n_workers: int = 1) -> pd.DataFrame:
        """Loads multiple pickle files produced by validate_model module.

        Args:
            files_list (Set[str]): Paths of the pickle files with datasets.
            n_workers (int, optional): Number of workers for multithreading. Defaults to 1.

        Raises:
            ValueError: If files_list is empty, or if n_workers is smaller than 1
                (the latter is raised by ThreadPoolExecutor).

        Returns:
            pd.DataFrame: Dataframe with merged datasets.
        """
        # Set iteration order of strings changes between interpreter runs, so
        # sort the paths to make the row order of the merged frame reproducible.
        ordered_files = sorted(files_list)
        if not ordered_files:
            raise ValueError("files_list is empty: no pickle files to load")
        # NOTE: unpickling executes arbitrary code; only load trusted files.
        with ThreadPoolExecutor(max_workers=n_workers) as executor:
            # executor.map yields results in the order of its input.
            results = list(executor.map(pd.read_pickle, ordered_files))
        return pd.concat(results, ignore_index=True)


def _positive_int(value: str) -> int:
    """Convert a command-line token to an int and reject values below 1."""
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {value!r}") from None
    if number < 1:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {number}")
    return number


def parse_args(args: List[str]) -> argparse.Namespace:
    """
    Arguments parser for the main method.

    Args:
        args (List[str]): Arguments from the command line, should be sys.argv[1:].

    Returns:
        argparse.Namespace: Namespace with ``modelnames`` (list of str),
        ``config`` (one-element list of str) and ``nworkers`` (int).
    """
    parser = argparse.ArgumentParser(
        prog="ML_PID_CBM ValidateMultipleModels",
        description="Program for loading multiple validated PID ML models",
    )
    parser.add_argument(
        "--modelnames",
        "-m",
        nargs="+",
        required=True,
        type=str,
        help="Names of folders containing trained and validated ML models.",
    )
    parser.add_argument(
        "--config",
        "-c",
        nargs=1,  # kept as a one-element list; callers read args.config[0]
        required=True,
        type=str,
        help="Filename or path of config json file.",
    )
    parser.add_argument(
        "--nworkers",
        "-n",
        type=_positive_int,  # fail at parse time instead of inside the executor
        default=1,
        help="Max number of workers for ThreadPoolExecutor which reads pickle files with validated data.",
    )
    return parser.parse_args(args)


if __name__ == "__main__":
    args = parse_args(sys.argv[1:])
    # config arguments to be loaded from args
    json_file_name = args.config[0]
    models = args.modelnames
    n_workers = args.nworkers
    pickle_files = {f"{model}/validated_data.pickle" for model in models}
    # The relative pickle paths are resolved here, before the chdir below.
    validate = ValidateMultipleModels(json_file_name, pickle_files, n_workers)
    # Resolve the config path before changing directory so the copy still finds it.
    json_file_path = os.path.abspath(json_file_name)
    # new folder for all files
    os.makedirs("all_models", exist_ok=True)
    os.chdir("all_models")
    copy2(json_file_path, os.getcwd())
    # graphs
    validate.confusion_matrix_and_stats()
    print("Generating plots...")
    validate.generate_plots()
class
ValidateMultipleModels(validate_model.ValidateModel):
class ValidateMultipleModels(ValidateModel):
    """
    Class for validating data from multiple models.
    Inherits from ValidateModel
    """

    def __init__(self, json_file_name: str, files_list: Set[str], n_workers: int = 1):
        """Initialise the parent validator and load the merged dataset.

        Args:
            json_file_name (str): Name or path of the config json file.
            files_list (Set[str]): Pickle files with validated datasets.
            n_workers (int, optional): Number of threads used for loading. Defaults to 1.
        """
        # These positional arguments are forwarded unchanged; their meaning is
        # defined by ValidateModel.__init__, which is not shown here.
        super().__init__(-12, 12, False, json_file_name, None)
        self.particles_df = self.load_pickles(files_list, n_workers)

    @staticmethod
    def load_pickles(files_list: Set[str], n_workers: int = 1) -> pd.DataFrame:
        """Loads multiple pickle files produced by validate_model module.

        Args:
            files_list (Set[str]): Paths of the pickle files with datasets.
            n_workers (int, optional): Number of workers for multithreading. Defaults to 1.

        Raises:
            ValueError: If files_list is empty, or if n_workers is smaller than 1
                (the latter is raised by ThreadPoolExecutor).

        Returns:
            pd.DataFrame: Dataframe with merged datasets.
        """
        # Set iteration order of strings changes between interpreter runs, so
        # sort the paths to make the row order of the merged frame reproducible.
        ordered_files = sorted(files_list)
        if not ordered_files:
            raise ValueError("files_list is empty: no pickle files to load")
        # NOTE: unpickling executes arbitrary code; only load trusted files.
        with ThreadPoolExecutor(max_workers=n_workers) as executor:
            # executor.map yields results in the order of its input.
            results = list(executor.map(pd.read_pickle, ordered_files))
        return pd.concat(results, ignore_index=True)
Class for validating data from multiple models. Inherits from ValidateModel.
@staticmethod
def
load_pickles(files_list: Set[str], n_workers: int = 1) -> pandas.core.frame.DataFrame:
@staticmethod
def load_pickles(files_list: Set[str], n_workers: int = 1) -> pd.DataFrame:
    """Loads multiple pickle files produced by validate_model module.

    Args:
        files_list (Set[str]): Paths of the pickle files with datasets.
        n_workers (int, optional): Number of workers for multithreading. Defaults to 1.

    Raises:
        ValueError: If files_list is empty, or if n_workers is smaller than 1
            (the latter is raised by ThreadPoolExecutor).

    Returns:
        pd.DataFrame: Dataframe with merged datasets.
    """
    # Set iteration order of strings changes between interpreter runs, so
    # sort the paths to make the row order of the merged frame reproducible.
    ordered_files = sorted(files_list)
    if not ordered_files:
        raise ValueError("files_list is empty: no pickle files to load")
    # NOTE: unpickling executes arbitrary code; only load trusted files.
    with ThreadPoolExecutor(max_workers=n_workers) as executor:
        # executor.map yields results in the order of its input.
        results = list(executor.map(pd.read_pickle, ordered_files))
    return pd.concat(results, ignore_index=True)
Loads multiple pickle files produced by validate_model module.
Args: files_list (Set[str]): Files list containing pickle files with datasets. n_workers (int, optional): Number of workers for multithreading. Defaults to 1.
Returns: pd.DataFrame: Dataframe with merged datasets.
Inherited Members
- validate_model.ValidateModel
- get_n_classes
- xgb_preds
- remap_names
- save_df
- sigma_selection
- evaluate_probas
- efficiency_stats
- confusion_matrix_and_stats
- generate_plots
- parse_model_name
def
parse_args(args: List[str]) -> argparse.Namespace:
def _positive_int(value: str) -> int:
    """Convert a command-line token to an int and reject values below 1."""
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {value!r}") from None
    if number < 1:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {number}")
    return number


def parse_args(args: List[str]) -> argparse.Namespace:
    """
    Arguments parser for the main method.

    Args:
        args (List[str]): Arguments from the command line, should be sys.argv[1:].

    Returns:
        argparse.Namespace: Namespace with ``modelnames`` (list of str),
        ``config`` (one-element list of str) and ``nworkers`` (int).
    """
    parser = argparse.ArgumentParser(
        prog="ML_PID_CBM ValidateMultipleModels",
        description="Program for loading multiple validated PID ML models",
    )
    parser.add_argument(
        "--modelnames",
        "-m",
        nargs="+",
        required=True,
        type=str,
        help="Names of folders containing trained and validated ML models.",
    )
    parser.add_argument(
        "--config",
        "-c",
        nargs=1,  # kept as a one-element list; callers read args.config[0]
        required=True,
        type=str,
        help="Filename or path of config json file.",
    )
    parser.add_argument(
        "--nworkers",
        "-n",
        type=_positive_int,  # fail at parse time instead of inside the executor
        default=1,
        help="Max number of workers for ThreadPoolExecutor which reads pickle files with validated data.",
    )
    return parser.parse_args(args)
Arguments parser for the main method.
Args: args (List[str]): Arguments from the command line, should be sys.argv[1:].
Returns: argparse.Namespace: argparse.Namespace containing the parsed arguments