ml_pid_cbm.validate_multiple_models

 1import argparse
 2import os
 3import sys
 4from concurrent.futures import ThreadPoolExecutor
 5from shutil import copy2
 6from typing import List, Set
 7
 8import pandas as pd
 9
10from validate_model import ValidateModel
11
12
class ValidateMultipleModels(ValidateModel):
    """
    Validates data aggregated from several trained models.

    Extends ValidateModel by loading and merging the pickled datasets
    that the validate_model module produced for each individual model.
    """

    def __init__(self, json_file_name: str, files_list: Set[str], n_workers: int = 1):
        # NOTE(review): the positional parent args (-12, 12, False, ..., None)
        # presumably set momentum limits, an anti-particle flag, and an absent
        # model handle — confirm against ValidateModel.__init__.
        super().__init__(-12, 12, False, json_file_name, None)
        self.particles_df = self.load_pickles(files_list, n_workers)

    @staticmethod
    def load_pickles(files_list: Set[str], n_workers: int = 1) -> pd.DataFrame:
        """Read every pickle produced by the validate_model module and merge them.

        Args:
            files_list (Set[str]): Paths of pickle files containing the datasets.
            n_workers (int, optional): Thread count for parallel reads. Defaults to 1.

        Returns:
            pd.DataFrame: Single dataframe with all datasets concatenated.
        """
        # Threads suffice here: pd.read_pickle is I/O-bound per file.
        with ThreadPoolExecutor(max_workers=n_workers) as pool:
            frames = list(pool.map(pd.read_pickle, files_list))
        return pd.concat(frames, ignore_index=True)
38
39
def parse_args(args: List[str]) -> argparse.Namespace:
    """
    Arguments parser for the main method.

    Args:
        args (List[str]): Arguments from the command line, should be sys.argv[1:].

    Returns:
        argparse.Namespace: argparse.Namespace containing the parsed args.
    """
    parser = argparse.ArgumentParser(
        prog="ML_PID_CBM ValidateMultipleModels",
        description="Program for loading multiple validated PID ML models",
    )
    parser.add_argument(
        "--modelnames",
        "-m",
        nargs="+",
        required=True,
        type=str,
        help="Names of folders containing trained and validated ML models.",
    )
    parser.add_argument(
        "--config",
        "-c",
        # nargs=1 keeps args.config a single-element list; callers index [0].
        nargs=1,
        required=True,
        type=str,
        # Fixed typo: "of path" -> "or path".
        help="Filename or path of config json file.",
    )
    parser.add_argument(
        "--nworkers",
        "-n",
        type=int,
        default=1,
        help="Max number of workers for ThreadPoolExecutor which reads Root tree with data.",
    )
    return parser.parse_args(args)
78
79
if __name__ == "__main__":
    args = parse_args(sys.argv[1:])
    # Unpack command-line configuration.
    json_file_name = args.config[0]  # --config uses nargs=1, so it is a 1-element list
    models = args.modelnames
    n_workers = args.nworkers
    # Each model folder is expected to hold the pickle written by validate_model.
    pickle_files = {f"{model}/validated_data.pickle" for model in models}
    validate = ValidateMultipleModels(json_file_name, pickle_files, n_workers)
    # Collect all outputs in a dedicated folder and keep a copy of the config there.
    json_file_path = os.path.join(os.getcwd(), json_file_name)
    os.makedirs("all_models", exist_ok=True)  # idiomatic replacement for exists-check
    os.chdir("all_models")
    copy2(json_file_path, os.getcwd())
    # graphs
    validate.confusion_matrix_and_stats()
    print("Generating plots...")
    validate.generate_plots()
class ValidateMultipleModels(validate_model.ValidateModel):
14class ValidateMultipleModels(ValidateModel):
15    """
16    Class for validating data from multiple models.
17    Inherits from ValidateModel
18    """
19
20    def __init__(self, json_file_name: str, files_list: Set[str], n_workers: int = 1):
21        super().__init__(-12, 12, False, json_file_name, None)
22        self.particles_df = self.load_pickles(files_list, n_workers)
23
24    @staticmethod
25    def load_pickles(files_list: Set[str], n_workers: int = 1) -> pd.DataFrame:
26        """Loads multiple pickle files produced by validate_model module.
27
28        Args:
29            files_list (Set[str]): Files list containg picle files with datasets.
30            n_workers (int, optional): Number of workers for multithreading. Defaults to 1.
31
32        Returns:
33            pd.DataFrame: Dataframe with merged datasets.
34        """
35        with ThreadPoolExecutor(max_workers=n_workers) as executor:
36            results = list(executor.map(pd.read_pickle, files_list))
37            whole_df = pd.concat(results, ignore_index=True)
38        return whole_df

Class for validating data from multiple models. Inherits from ValidateModel

ValidateMultipleModels(json_file_name: str, files_list: Set[str], n_workers: int = 1)
20    def __init__(self, json_file_name: str, files_list: Set[str], n_workers: int = 1):
21        super().__init__(-12, 12, False, json_file_name, None)
22        self.particles_df = self.load_pickles(files_list, n_workers)
@staticmethod
def load_pickles(files_list: Set[str], n_workers: int = 1) -> pandas.core.frame.DataFrame:
24    @staticmethod
25    def load_pickles(files_list: Set[str], n_workers: int = 1) -> pd.DataFrame:
26        """Loads multiple pickle files produced by validate_model module.
27
28        Args:
29            files_list (Set[str]): Files list containg picle files with datasets.
30            n_workers (int, optional): Number of workers for multithreading. Defaults to 1.
31
32        Returns:
33            pd.DataFrame: Dataframe with merged datasets.
34        """
35        with ThreadPoolExecutor(max_workers=n_workers) as executor:
36            results = list(executor.map(pd.read_pickle, files_list))
37            whole_df = pd.concat(results, ignore_index=True)
38        return whole_df

Loads multiple pickle files produced by validate_model module.

Args: files_list (Set[str]): Files list containing pickle files with datasets. n_workers (int, optional): Number of workers for multithreading. Defaults to 1.

Returns: pd.DataFrame: Dataframe with merged datasets.

Inherited Members
validate_model.ValidateModel
get_n_classes
xgb_preds
remap_names
save_df
sigma_selection
evaluate_probas
efficiency_stats
confusion_matrix_and_stats
generate_plots
parse_model_name
def parse_args(args: List[str]) -> argparse.Namespace:
41def parse_args(args: List[str]) -> argparse.Namespace:
42    """
43    Arguments parser for the main method.
44
45    Args:
46        args (List[str]): Arguments from the command line, should be sys.argv[1:].
47
48    Returns:
49        argparse.Namespace: argparse.Namespace containg args
50    """
51    parser = argparse.ArgumentParser(
52        prog="ML_PID_CBM ValidateMultipleModels",
53        description="Program for loading multiple validated PID ML models",
54    )
55    parser.add_argument(
56        "--modelnames",
57        "-m",
58        nargs="+",
59        required=True,
60        type=str,
61        help="Names of folders containing trained and validated ML models.",
62    )
63    parser.add_argument(
64        "--config",
65        "-c",
66        nargs=1,
67        required=True,
68        type=str,
69        help="Filename of path of config json file.",
70    )
71    parser.add_argument(
72        "--nworkers",
73        "-n",
74        type=int,
75        default=1,
76        help="Max number of workers for ThreadPoolExecutor which reads Root tree with data.",
77    )
78    return parser.parse_args(args)

Arguments parser for the main method.

Args: args (List[str]): Arguments from the command line, should be sys.argv[1:].

Returns: argparse.Namespace: argparse.Namespace containing the parsed args