ml_pid_cbm.gauss.validate_multiple_gauss

 1import argparse
 2import os
 3import sys
 4from shutil import copy2
 5from typing import List
 6
 7from ml_pid_cbm.gauss.validate_gauss import ValidateGauss
 8from ml_pid_cbm.validate_multiple_models import ValidateMultipleModels
 9
10
class ValidateMultipleGauss(ValidateMultipleModels, ValidateGauss):
    """
    Class for validating data from multiple Gauss models.

    Combines ValidateMultipleModels (loading the validated pickles of
    several models, e.g. via load_pickles) with ValidateGauss (Gauss-specific
    members such as gauss_preds and remap_gauss_names). ValidateModel
    functionality is inherited indirectly, through ValidateMultipleModels.
    The class adds no members of its own; where both bases define the same
    name, ValidateMultipleModels wins because it comes first in the MRO.
    """
16
17
def parse_args(args: List[str]) -> argparse.Namespace:
    """
    Arguments parser for the main method.

    Args:
        args (List[str]): Arguments from the command line, should be sys.argv[1:].

    Returns:
        argparse.Namespace: argparse.Namespace containing the parsed args:
            ``modelnames`` (list of str), ``config`` (one-element list of str,
            because of ``nargs=1``) and ``nworkers`` (int, default 1).
    """
    parser = argparse.ArgumentParser(
        prog="ML_PID_CBM ValidateMultipleGauss",
        description="Program for loading multiple validated PID Gauss models",
    )
    parser.add_argument(
        "--modelnames",
        "-m",
        nargs="+",
        required=True,
        type=str,
        help="Names of folders containing trained and validated ML models.",
    )
    # nargs=1 stores a one-element list; callers read it as args.config[0].
    parser.add_argument(
        "--config",
        "-c",
        nargs=1,
        required=True,
        type=str,
        help="Filename or path of config json file.",
    )
    parser.add_argument(
        "--nworkers",
        "-n",
        type=int,
        default=1,
        help="Max number of workers for ThreadPoolExecutor which reads Root tree with data.",
    )
    return parser.parse_args(args)
56
57
if __name__ == "__main__":
    args = parse_args(sys.argv[1:])
    # config arguments to be loaded from args
    json_file_name = args.config[0]
    models = args.modelnames
    n_workers = args.nworkers
    # Keep the command-line order of the models (dropping duplicates) so runs
    # are reproducible: iterating a set of str depends on hash randomization.
    # NOTE(review): assumes ValidateMultipleModels only iterates pickle_files.
    pickle_files = list(
        dict.fromkeys(f"{model}/validated_data.pickle" for model in models)
    )
    validate = ValidateMultipleGauss(json_file_name, pickle_files, n_workers)
    # new folder for all files; resolve the config path before changing directory
    json_file_path = os.path.join(os.getcwd(), json_file_name)
    os.makedirs("all_models", exist_ok=True)
    os.chdir("all_models")
    copy2(json_file_path, os.getcwd())
    # graphs
    validate.confusion_matrix_and_stats()
    print("Generating plots...")
    validate.generate_plots()
class ValidateMultipleGauss(ValidateMultipleModels, ValidateGauss):
    """
    Class for validating data from multiple Gauss models.

    Combines ValidateMultipleModels (loading the validated pickles of
    several models, e.g. via load_pickles) with ValidateGauss (Gauss-specific
    members such as gauss_preds and remap_gauss_names). ValidateModel
    functionality is inherited indirectly, through ValidateMultipleModels.
    The class adds no members of its own; where both bases define the same
    name, ValidateMultipleModels wins because it comes first in the MRO.
    """

Class for validating data from multiple Gauss models. Inherits from ValidateMultipleModels and ValidateGauss (and, through ValidateMultipleModels, from ValidateModel).

Inherited Members
ml_pid_cbm.validate_multiple_models.ValidateMultipleModels
ValidateMultipleModels
load_pickles
validate_model.ValidateModel
get_n_classes
xgb_preds
remap_names
save_df
sigma_selection
evaluate_probas
efficiency_stats
confusion_matrix_and_stats
generate_plots
parse_model_name
ml_pid_cbm.gauss.validate_gauss.ValidateGauss
gauss_preds
remap_gauss_names
def parse_args(args: List[str]) -> argparse.Namespace:
def parse_args(args: List[str]) -> argparse.Namespace:
    """
    Arguments parser for the main method.

    Args:
        args (List[str]): Arguments from the command line, should be sys.argv[1:].

    Returns:
        argparse.Namespace: argparse.Namespace containing the parsed args:
            ``modelnames`` (list of str), ``config`` (one-element list of str,
            because of ``nargs=1``) and ``nworkers`` (int, default 1).
    """
    parser = argparse.ArgumentParser(
        prog="ML_PID_CBM ValidateMultipleGauss",
        description="Program for loading multiple validated PID Gauss models",
    )
    parser.add_argument(
        "--modelnames",
        "-m",
        nargs="+",
        required=True,
        type=str,
        help="Names of folders containing trained and validated ML models.",
    )
    # nargs=1 stores a one-element list; callers read it as args.config[0].
    parser.add_argument(
        "--config",
        "-c",
        nargs=1,
        required=True,
        type=str,
        help="Filename or path of config json file.",
    )
    parser.add_argument(
        "--nworkers",
        "-n",
        type=int,
        default=1,
        help="Max number of workers for ThreadPoolExecutor which reads Root tree with data.",
    )
    return parser.parse_args(args)

Arguments parser for the main method.

Args: args (List[str]): Arguments from the command line, should be sys.argv[1:].

Returns: argparse.Namespace: argparse.Namespace containing the parsed args.