ml_pid_cbm.train_model

Module for training the model.

  1"""
  2Module for training the model.
  3
  4"""
import argparse
import gc
import os
import sys
from shutil import copy2
from typing import List, Optional

from hipe4ml.model_handler import ModelHandler
from sklearn.utils.class_weight import compute_sample_weight
from tools import json_tools, plotting_tools
from tools.load_data import LoadData
from tools.prepare_model import PrepareModel
 17
 18
 19class TrainModel:
 20    """
 21    Class for training the ml model
 22    """
 23
 24    def __init__(self, model_hdl: ModelHandler, model_name: str):
 25        self.model_hdl = model_hdl
 26        self.model_name = model_name
 27
 28    def train_model_handler(
 29        self, train_test_data, sample_weights, model_hdl: ModelHandler = None
 30    ):
 31        """Trains model handler
 32
 33        Args:
 34            train_test_data (_type_): Train_test_data generated using a method from prepare_model module.
 35            sample_weights(List[Float]): ndarray of shape (n_samples,) Array with sample weights.
 36            To be computed with sklearn.utils.class_weight.compute_sample_weight
 37            model_hdl (ModelHandler, optional):  Hipe4ml model handler. Defaults to None.
 38        """
 39        model_hdl = model_hdl or self.model_hdl
 40        model_hdl.train_test_model(
 41            train_test_data,
 42            multi_class_opt="ovo",
 43            sample_weight=sample_weights,
 44        )
 45        self.model_hdl = model_hdl
 46
 47    def save_model(self, model_name: str = None, model_hdl: ModelHandler = None):
 48        """Saves trained model handler.
 49
 50        Args:
 51            model_name (str, optional): Name of the model handler. Defaults to None.
 52            model_hdl (ModelHandler, optional): Hipe4ml model handler. Defaults to None.
 53        """
 54        model_name = model_name or self.model_name
 55        model_hdl = model_hdl or self.model_hdl
 56        model_hdl.dump_model_handler(model_name)
 57        print(f"\nModel saved as {model_name}")
 58
 59
 60def parse_args(args: List[str]) -> argparse.Namespace:
 61    """
 62    Arguments parser for the main method.
 63
 64    Args:
 65        args (List[str]): Arguments from the command line, should be sys.argv[1:].
 66
 67    Returns:
 68        argparse.Namespace: argparse.Namespace containg args
 69    """
 70    parser = argparse.ArgumentParser(
 71        prog="ML_PID_CBM TrainModel", description="Program for training PID ML models"
 72    )
 73    parser.add_argument(
 74        "--config",
 75        "-c",
 76        nargs=1,
 77        required=True,
 78        type=str,
 79        help="Filename of path of config json file.",
 80    )
 81    parser.add_argument(
 82        "--momentum",
 83        "-p",
 84        nargs=2,
 85        required=True,
 86        type=float,
 87        help="Lower and upper momentum limit, e.g., 1 3",
 88    )
 89    parser.add_argument(
 90        "--antiparticles",
 91        action="store_true",
 92        help="If should train on particles instead of particles with positive charge.",
 93    )
 94    parser.add_argument(
 95        "--hyperparams",
 96        action="store_true",
 97        help="If should optimize hyper params instead of using const values from config file. Will use ranges from config file.",
 98    )
 99    parser.add_argument(
100        "--gpu",
101        action="store_true",
102        help="If should use GPU for training. Remember that xgboost-gpu version is needed for this.",
103    )
104    parser.add_argument(
105        "--nworkers",
106        "-n",
107        type=int,
108        default=1,
109        help="Max number of workers for ThreadPoolExecutor which reads Root tree with data.",
110    )
111    graphs_group = parser.add_mutually_exclusive_group()
112    graphs_group.add_argument(
113        "--printplots",
114        action="store_true",
115        help="Creates plots and prints them without saving to file.",
116    )
117    graphs_group.add_argument(
118        "--saveplots",
119        "-plots",
120        action="store_true",
121        help="Creates plots and saves them to file, without printing.",
122    )
123    parser.add_argument(
124        "--usevalidation",
125        action="store_true",
126        help="if should use validation dataset for post-training plots",
127    )
128    return parser.parse_args(args)
129
130
# main method of the training
if __name__ == "__main__":
    # parser for main class
    args = parse_args(sys.argv[1:])
    # config  arguments to be loaded from args
    # The config path is made absolute right away: the script later changes
    # its working directory to the model folder, after which a relative path
    # containing directories would no longer resolve.
    json_file_name = os.path.abspath(args.config[0])
    lower_p_cut, upper_p_cut = args.momentum
    anti_particles = args.antiparticles
    optimize_hyper_params = args.hyperparams
    use_gpu = args.gpu
    n_workers = args.nworkers
    create_plots = args.printplots or args.saveplots
    save_plots = args.saveplots
    use_validation = args.usevalidation
    charge_tag = "anti" if anti_particles else "positive"
    model_name = f"model_{lower_p_cut:.1f}_{upper_p_cut:.1f}_{charge_tag}"
    data_file_name = json_tools.load_file_name(json_file_name, "training")

    # loading data
    loader = LoadData(
        data_file_name, json_file_name, lower_p_cut, upper_p_cut, anti_particles
    )
    tree_handler = loader.load_tree(max_workers=n_workers)
    NSIGMA_PROTON = 0
    NSIGMA_KAON = 0
    NSIGMA_PION = 0
    protons, kaons, pions = loader.get_protons_kaons_pions(
        tree_handler,
        nsigma_proton=NSIGMA_PROTON,
        nsigma_kaon=NSIGMA_KAON,
        nsigma_pion=NSIGMA_PION,
    )
    print(f"\nProtons, kaons, and pions loaded using file {data_file_name}\n")
    # the tree is no longer needed once the particle samples are extracted
    del tree_handler
    gc.collect()
    # change location to specific folder for this model
    os.makedirs(model_name, exist_ok=True)
    os.chdir(model_name)
    # keep a copy of the config next to the model it produced
    copy2(json_file_name, os.getcwd())
    # pretraining plots
    if create_plots:
        print("Creating pre-training plots...")
        # raw f-strings: "\s" is not a valid escape sequence in a normal
        # string literal; the resulting label text is unchanged
        plotting_tools.tof_plot(
            protons,
            json_file_name,
            rf"protons ({NSIGMA_PROTON}$\sigma$)",
            save_fig=save_plots,
        )
        plotting_tools.tof_plot(
            kaons,
            json_file_name,
            rf"kaons ({NSIGMA_KAON}$\sigma$)",
            save_fig=save_plots,
        )
        plotting_tools.tof_plot(
            pions,
            json_file_name,
            rf"pions, muons, electrons ({NSIGMA_PION}$\sigma$)",
            save_fig=save_plots,
        )
        vars_to_draw = protons.get_var_names()
        plotting_tools.correlations_plot(
            vars_to_draw, [protons, kaons, pions], save_fig=save_plots
        )
    # loading model handler
    model_preparer = PrepareModel(json_file_name, optimize_hyper_params, use_gpu)
    train_test_data = PrepareModel.prepare_train_test_data([protons, kaons, pions])
    # the per-species samples are now contained in train_test_data
    del protons, kaons, pions
    gc.collect()
    features_for_train = json_tools.load_features_for_train(json_file_name)
    print("\nPreparing model handler...")
    model_hdl, study = model_preparer.prepare_model_handler(
        train_test_data=train_test_data
    )
    if create_plots and optimize_hyper_params:
        plotting_tools.opt_history_plot(study, save_plots)
        plotting_tools.opt_contour_plot(study, save_plots)
    # train model
    train = TrainModel(model_hdl, model_name)
    sample_weights = compute_sample_weight(
        class_weight="balanced",  # class_weight=None or {0: 1, 1: 3, 2: 1}, deleted for now
        y=train_test_data[1],
    )
    train.train_model_handler(train_test_data, sample_weights)
    print("\nModel trained!")
    train.save_model(model_name)
    # loading validation dataset as test dataset for post-training plots
    if use_validation:
        data_file_name_test = json_tools.load_file_name(json_file_name, "test")
        loader_test = LoadData(
            data_file_name_test,
            json_file_name,
            lower_p_cut,
            upper_p_cut,
            anti_particles,
        )
        tree_handler_test = loader_test.load_tree(max_workers=n_workers)
        protons_test, kaons_test, pions_test = loader_test.get_protons_kaons_pions(
            tree_handler_test,
            nsigma_proton=NSIGMA_PROTON,
            nsigma_kaon=NSIGMA_KAON,
            nsigma_pion=NSIGMA_PION,
        )
        validation_data = PrepareModel.prepare_train_test_data(
            [protons_test, kaons_test, pions_test]
        )
        # keep the training part, swap the test part for the validation set
        train_test_data = [
            train_test_data[0],
            train_test_data[1],
            validation_data[0],
            validation_data[1],
        ]
    if create_plots:
        print("Creating post-training plots")
        y_pred_test = model_hdl.predict(train_test_data[2], False)
        plotting_tools.output_train_test_plot(
            train.model_hdl, train_test_data, save_fig=save_plots, logscale=True
        )

        plotting_tools.roc_plot(train_test_data[3], y_pred_test, save_fig=save_plots)
        # shapleys for each class
        # NOTE(review): the feature names are passed as loaded from the
        # config (including any "Complex_" prefix) - confirm this is intended.
        plotting_tools.plot_shap_summary(
            train_test_data[0][features_for_train],
            train_test_data[1],
            model_hdl,
            features_for_train,
            n_workers,
        )
class TrainModel:
class TrainModel:
    """
    Class for training the ml model.

    Holds a model handler together with the name under which the trained
    handler is saved. Both can be overridden per call.
    """

    def __init__(self, model_hdl: "ModelHandler", model_name: str):
        """Stores the model handler and the model name.

        Args:
            model_hdl (ModelHandler): Hipe4ml model handler used by default.
            model_name (str): Name used by default when saving the handler.
        """
        self.model_hdl = model_hdl
        self.model_name = model_name

    def train_model_handler(
        self,
        train_test_data,
        sample_weights,
        model_hdl: Optional["ModelHandler"] = None,
    ):
        """Trains model handler

        After training, the handler that was used is stored on the instance,
        so later calls (e.g. save_model) default to it.

        Args:
            train_test_data (_type_): Train_test_data generated using a method from prepare_model module.
            sample_weights(List[Float]): ndarray of shape (n_samples,) Array with sample weights.
            To be computed with sklearn.utils.class_weight.compute_sample_weight
            model_hdl (ModelHandler, optional):  Hipe4ml model handler. Defaults to None,
            in which case the handler stored on the instance is used.
        """
        # Explicit None check: falling back on truthiness would silently
        # discard a caller-supplied handler that happens to evaluate falsy.
        if model_hdl is None:
            model_hdl = self.model_hdl
        model_hdl.train_test_model(
            train_test_data,
            multi_class_opt="ovo",
            sample_weight=sample_weights,
        )
        self.model_hdl = model_hdl

    def save_model(
        self,
        model_name: Optional[str] = None,
        model_hdl: Optional["ModelHandler"] = None,
    ):
        """Saves trained model handler.

        Args:
            model_name (str, optional): Name of the model handler. Defaults to None,
            in which case the name stored on the instance is used.
            model_hdl (ModelHandler, optional): Hipe4ml model handler. Defaults to None,
            in which case the handler stored on the instance is used.
        """
        # An empty name is deliberately treated like a missing one.
        model_name = model_name or self.model_name
        if model_hdl is None:
            model_hdl = self.model_hdl
        model_hdl.dump_model_handler(model_name)
        print(f"\nModel saved as {model_name}")

Class for training the ml model

TrainModel(model_hdl: hipe4ml.model_handler.ModelHandler, model_name: str)
25    def __init__(self, model_hdl: ModelHandler, model_name: str):
26        self.model_hdl = model_hdl
27        self.model_name = model_name
def train_model_handler( self, train_test_data, sample_weights, model_hdl: hipe4ml.model_handler.ModelHandler = None):
29    def train_model_handler(
30        self, train_test_data, sample_weights, model_hdl: ModelHandler = None
31    ):
32        """Trains model handler
33
34        Args:
35            train_test_data (_type_): Train_test_data generated using a method from prepare_model module.
36            sample_weights(List[Float]): ndarray of shape (n_samples,) Array with sample weights.
37            To be computed with sklearn.utils.class_weight.compute_sample_weight
38            model_hdl (ModelHandler, optional):  Hipe4ml model handler. Defaults to None.
39        """
40        model_hdl = model_hdl or self.model_hdl
41        model_hdl.train_test_model(
42            train_test_data,
43            multi_class_opt="ovo",
44            sample_weight=sample_weights,
45        )
46        self.model_hdl = model_hdl

Trains model handler

Args: train_test_data (_type_): Train_test_data generated using a method from prepare_model module. sample_weights(List[Float]): ndarray of shape (n_samples,) Array with sample weights. To be computed with sklearn.utils.class_weight.compute_sample_weight model_hdl (ModelHandler, optional): Hipe4ml model handler. Defaults to None.

def save_model( self, model_name: str = None, model_hdl: hipe4ml.model_handler.ModelHandler = None):
48    def save_model(self, model_name: str = None, model_hdl: ModelHandler = None):
49        """Saves trained model handler.
50
51        Args:
52            model_name (str, optional): Name of the model handler. Defaults to None.
53            model_hdl (ModelHandler, optional): Hipe4ml model handler. Defaults to None.
54        """
55        model_name = model_name or self.model_name
56        model_hdl = model_hdl or self.model_hdl
57        model_hdl.dump_model_handler(model_name)
58        print(f"\nModel saved as {model_name}")

Saves trained model handler.

Args: model_name (str, optional): Name of the model handler. Defaults to None. model_hdl (ModelHandler, optional): Hipe4ml model handler. Defaults to None.

def parse_args(args: List[str]) -> argparse.Namespace:
 61def parse_args(args: List[str]) -> argparse.Namespace:
 62    """
 63    Arguments parser for the main method.
 64
 65    Args:
 66        args (List[str]): Arguments from the command line, should be sys.argv[1:].
 67
 68    Returns:
 69        argparse.Namespace: argparse.Namespace containg args
 70    """
 71    parser = argparse.ArgumentParser(
 72        prog="ML_PID_CBM TrainModel", description="Program for training PID ML models"
 73    )
 74    parser.add_argument(
 75        "--config",
 76        "-c",
 77        nargs=1,
 78        required=True,
 79        type=str,
 80        help="Filename of path of config json file.",
 81    )
 82    parser.add_argument(
 83        "--momentum",
 84        "-p",
 85        nargs=2,
 86        required=True,
 87        type=float,
 88        help="Lower and upper momentum limit, e.g., 1 3",
 89    )
 90    parser.add_argument(
 91        "--antiparticles",
 92        action="store_true",
 93        help="If should train on particles instead of particles with positive charge.",
 94    )
 95    parser.add_argument(
 96        "--hyperparams",
 97        action="store_true",
 98        help="If should optimize hyper params instead of using const values from config file. Will use ranges from config file.",
 99    )
100    parser.add_argument(
101        "--gpu",
102        action="store_true",
103        help="If should use GPU for training. Remember that xgboost-gpu version is needed for this.",
104    )
105    parser.add_argument(
106        "--nworkers",
107        "-n",
108        type=int,
109        default=1,
110        help="Max number of workers for ThreadPoolExecutor which reads Root tree with data.",
111    )
112    graphs_group = parser.add_mutually_exclusive_group()
113    graphs_group.add_argument(
114        "--printplots",
115        action="store_true",
116        help="Creates plots and prints them without saving to file.",
117    )
118    graphs_group.add_argument(
119        "--saveplots",
120        "-plots",
121        action="store_true",
122        help="Creates plots and saves them to file, without printing.",
123    )
124    parser.add_argument(
125        "--usevalidation",
126        action="store_true",
127        help="if should use validation dataset for post-training plots",
128    )
129    return parser.parse_args(args)

Arguments parser for the main method.

Args: args (List[str]): Arguments from the command line, should be sys.argv[1:].

Returns: argparse.Namespace: argparse.Namespace containing args