ml_pid_cbm.binary.validate_multiple_binary_models

 1import argparse
 2import os
 3import sys
 4from shutil import copy2
 5from typing import List, Set
 6
 7from ml_pid_cbm.binary.validate_binary_model import ValidateBinaryModel
 8from ml_pid_cbm.validate_multiple_models import ValidateMultipleModels
 9
10
class ValidateMultipleBinaryModels(ValidateMultipleModels, ValidateBinaryModel):
    """
    Class for validating data from multiple binary models.

    Inherits from ValidateMultipleModels and ValidateBinaryModel. Because
    ValidateMultipleModels is listed first, Python's method resolution order
    looks attributes up on it before ValidateBinaryModel.
    """

    def __init__(self, json_file_name: str, files_list: Set[str], n_workers: int):
        """
        Forward all arguments unchanged to the next ``__init__`` in the MRO.

        Args:
            json_file_name (str): Filename or path of the config json file.
            files_list (Set[str]): Paths of the pickle files holding validated
                data (the script entry point passes
                ``<model>/validated_data.pickle`` for every model folder).
            n_workers (int): Max number of workers, passed through to the base
                class initializer.
        """
        super().__init__(json_file_name, files_list, n_workers)


def parse_args(args: List[str]) -> argparse.Namespace:
    """
    Arguments parser for the main method.

    Args:
        args (List[str]): Arguments from the command line, should be sys.argv[1:].

    Returns:
        argparse.Namespace: argparse.Namespace containing args. ``modelnames``
        is a list of one or more folder names, ``config`` is a one-element list
        (``nargs=1``) and ``nworkers`` is an int (default 1).
    """
    parser = argparse.ArgumentParser(
        prog="ML_PID_CBM ValidateMultipleBinaryModels",
        description="Program for loading multiple validated binary PID ML models",
    )
    parser.add_argument(
        "--modelnames",
        "-m",
        nargs="+",
        required=True,
        type=str,
        help="Names of folders containing trained and validated ML models.",
    )
    parser.add_argument(
        "--config",
        "-c",
        # nargs=1 yields a one-element list; the script entry point reads
        # args.config[0], so this is kept for compatibility.
        nargs=1,
        required=True,
        type=str,
        help="Filename or path of config json file.",
    )
    parser.add_argument(
        "--nworkers",
        "-n",
        type=int,
        default=1,
        help="Max number of workers for ThreadPoolExecutor which reads Root tree with data.",
    )
    return parser.parse_args(args)


if __name__ == "__main__":
    args = parse_args(sys.argv[1:])
    # config arguments to be loaded from args
    json_file_name = args.config[0]
    models = args.modelnames
    n_workers = args.nworkers
    # os.path.join normalizes a trailing separator (e.g. "model/" from shell
    # completion), so the same folder given twice does not yield two entries.
    pickle_files = {os.path.join(model, "validated_data.pickle") for model in models}
    validate = ValidateMultipleBinaryModels(json_file_name, pickle_files, n_workers)
    # Resolve the config path before changing directory so the copy still finds it.
    json_file_path = os.path.abspath(json_file_name)
    # new folder for all files; exist_ok avoids the check-then-create race
    os.makedirs("all_models", exist_ok=True)
    os.chdir("all_models")
    copy2(json_file_path, os.getcwd())
    # graphs
    validate.confusion_matrix_and_stats()
    print("Generating plots...")
    validate.generate_plots()
class ValidateMultipleBinaryModels(ValidateMultipleModels, ValidateBinaryModel):
    """
    Class for validating data from multiple binary models.

    Inherits from ValidateMultipleModels and ValidateBinaryModel. Because
    ValidateMultipleModels is listed first, Python's method resolution order
    looks attributes up on it before ValidateBinaryModel.
    """

    def __init__(self, json_file_name: str, files_list: Set[str], n_workers: int):
        """
        Forward all arguments unchanged to the next ``__init__`` in the MRO.

        Args:
            json_file_name (str): Filename or path of the config json file.
            files_list (Set[str]): Paths of the pickle files holding validated
                data (the script entry point passes
                ``<model>/validated_data.pickle`` for every model folder).
            n_workers (int): Max number of workers, passed through to the base
                class initializer.
        """
        super().__init__(json_file_name, files_list, n_workers)

Class for validating data from multiple binary models. Inherits from ValidateMultipleModels and ValidateBinaryModel.

ValidateMultipleBinaryModels(json_file_name: str, files_list: Set[str], n_workers: int)
17    def __init__(self, json_file_name: str, files_list: Set[str], n_workers: int):
18        super().__init__(json_file_name, files_list, n_workers)
Inherited Members
ml_pid_cbm.validate_multiple_models.ValidateMultipleModels
load_pickles
validate_model.ValidateModel
get_n_classes
xgb_preds
remap_names
save_df
sigma_selection
evaluate_probas
efficiency_stats
confusion_matrix_and_stats
generate_plots
parse_model_name
def parse_args(args: List[str]) -> argparse.Namespace:
def parse_args(args: List[str]) -> argparse.Namespace:
    """
    Arguments parser for the main method.

    Args:
        args (List[str]): Arguments from the command line, should be sys.argv[1:].

    Returns:
        argparse.Namespace: argparse.Namespace containing args. ``modelnames``
        is a list of one or more folder names, ``config`` is a one-element list
        (``nargs=1``) and ``nworkers`` is an int (default 1).
    """
    parser = argparse.ArgumentParser(
        prog="ML_PID_CBM ValidateMultipleBinaryModels",
        description="Program for loading multiple validated binary PID ML models",
    )
    parser.add_argument(
        "--modelnames",
        "-m",
        nargs="+",
        required=True,
        type=str,
        help="Names of folders containing trained and validated ML models.",
    )
    parser.add_argument(
        "--config",
        "-c",
        # nargs=1 yields a one-element list; the script entry point reads
        # args.config[0], so this is kept for compatibility.
        nargs=1,
        required=True,
        type=str,
        help="Filename or path of config json file.",
    )
    parser.add_argument(
        "--nworkers",
        "-n",
        type=int,
        default=1,
        help="Max number of workers for ThreadPoolExecutor which reads Root tree with data.",
    )
    return parser.parse_args(args)

Arguments parser for the main method.

Args: args (List[str]): Arguments from the command line, should be sys.argv[1:].

Returns: argparse.Namespace: argparse.Namespace containing args