ml_pid_cbm.binary.train_binary_model

  1import argparse
  2import gc
  3import os
  4import sys
  5from shutil import copy2
  6from typing import List
  7
  8from hipe4ml.model_handler import ModelHandler
  9from sklearn.utils.class_weight import compute_sample_weight
 10
 11from ml_pid_cbm.tools import json_tools, plotting_tools
 12from ml_pid_cbm.tools.load_data import LoadData
 13from ml_pid_cbm.tools.particles_id import ParticlesId as Pid
 14from ml_pid_cbm.tools.prepare_model import PrepareModel
 15from ml_pid_cbm.train_model import TrainModel
 16
 17
 18class TrainBinaryModel(TrainModel):
 19    """
 20    Class for training the ml model
 21    """
 22
 23    def train_model_handler(
 24        self, train_test_data, sample_weights, model_hdl: ModelHandler = None
 25    ):
 26        """Trains model handler
 27
 28        Args:
 29            train_test_data (_type_): Train_test_data generated using a method from prepare_model module.
 30            sample_weights(List[Float]): ndarray of shape (n_samples,) Array with sample weights.
 31            To be computed with sklearn.utils.class_weight.compute_sample_weight
 32            model_hdl (ModelHandler, optional):  Hipe4ml model handler. Defaults to None.
 33        """
 34        model_hdl = model_hdl or self.model_hdl
 35        model_hdl.train_test_model(train_test_data, sample_weight=sample_weights)
 36        self.model_hdl = model_hdl
 37
 38
 39def parse_args(args: List[str]) -> argparse.Namespace:
 40    """
 41    Arguments parser for the main method.
 42
 43    Args:
 44        args (List[str]): Arguments from the command line, should be sys.argv[1:].
 45
 46    Returns:
 47        argparse.Namespace: argparse.Namespace containg args
 48    """
 49    parser = argparse.ArgumentParser(
 50        prog="ML_PID_CBM TrainModel",
 51        description="Program for training binary PID ML models",
 52    )
 53    parser.add_argument(
 54        "--config",
 55        "-c",
 56        nargs=1,
 57        required=True,
 58        type=str,
 59        help="Filename of path of config json file.",
 60    )
 61    parser.add_argument(
 62        "--momentum",
 63        "-p",
 64        nargs=2,
 65        required=True,
 66        type=float,
 67        help="Lower and upper momentum limit, e.g., 1 3",
 68    )
 69    parser.add_argument(
 70        "--antiparticles",
 71        action="store_true",
 72        help="If should train on particles instead of particles with positive charge.",
 73    )
 74    parser.add_argument(
 75        "--hyperparams",
 76        action="store_true",
 77        help="If should optimize hyper params instead of using const values from config file. Will use ranges from config file.",
 78    )
 79    parser.add_argument(
 80        "--gpu",
 81        action="store_true",
 82        help="If should use GPU for training. Remember that xgboost-gpu version is needed for this.",
 83    )
 84    parser.add_argument(
 85        "--nworkers",
 86        "-n",
 87        type=int,
 88        default=1,
 89        help="Max number of workers for ThreadPoolExecutor which reads Root tree with data.",
 90    )
 91    graphs_group = parser.add_mutually_exclusive_group()
 92    graphs_group.add_argument(
 93        "--printplots",
 94        action="store_true",
 95        help="Creates plots and prints them without saving to file.",
 96    )
 97    graphs_group.add_argument(
 98        "--saveplots",
 99        "-plots",
100        action="store_true",
101        help="Creates plots and saves them to file, without printing.",
102    )
103    parser.add_argument(
104        "--usevalidation",
105        action="store_true",
106        help="if should use validation dataset for post-training plots",
107    )
108    return parser.parse_args(args)
109
110
# main method of the training
if __name__ == "__main__":
    # parser for main class
    args = parse_args(sys.argv[1:])
    # config arguments to be loaded from args
    json_file_name = args.config[0]
    lower_p_cut, upper_p_cut = args.momentum[0], args.momentum[1]
    anti_particles = args.antiparticles
    optimize_hyper_params = args.hyperparams
    use_gpu = args.gpu
    n_workers = args.nworkers
    # plots are wanted if either printing or saving was requested
    # (redundant "or False" removed -- both flags are already booleans)
    create_plots = args.printplots or args.saveplots
    save_plots = args.saveplots
    use_validation = args.usevalidation
    # model name encodes the momentum window and the charge selection
    if anti_particles:
        model_name = f"model_{lower_p_cut:.1f}_{upper_p_cut:.1f}_anti"
    else:
        model_name = f"model_{lower_p_cut:.1f}_{upper_p_cut:.1f}_positive"
    data_file_name = json_tools.load_file_name(json_file_name, "training")

    # loading data
    loader = LoadData(
        data_file_name, json_file_name, lower_p_cut, upper_p_cut, anti_particles
    )
    tree_handler = loader.load_tree(max_workers=n_workers)
    # width (in sigmas) of the selection window used for the signal class
    NSIGMA = 5
    kaons = loader.get_particles_type(
        tree_handler, Pid.POS_KAON.value, nsigma=NSIGMA, json_file_name=json_file_name
    )
    pid_var_name = json_tools.load_var_name(json_file_name, "pid")
    # background = every entry that is not a positive kaon
    bckgr = tree_handler.get_subset(f"{pid_var_name} != {Pid.POS_KAON.value}")
    print(f"\nKaons and bckgr (everything else) loaded using file {data_file_name}\n")
    # the full tree is no longer needed; free memory before training
    del tree_handler
    gc.collect()
    # change location to specific folder for this model
    json_file_path = os.path.join(os.getcwd(), json_file_name)
    if not os.path.exists(f"{model_name}"):
        os.makedirs(f"{model_name}")
    os.chdir(f"{model_name}")
    copy2(json_file_path, os.getcwd())
    # pretraining plots
    if create_plots:
        print("Creating pre-training plots...")
        plotting_tools.tof_plot(
            # raw f-string: "\s" is an invalid escape sequence in a plain string
            kaons, json_file_name, rf"kaons ({NSIGMA}$\sigma$)", save_fig=save_plots
        )
        vars_to_draw = bckgr.get_var_names()
        plotting_tools.correlations_plot(
            vars_to_draw, [kaons, bckgr], save_fig=save_plots
        )
    # loading model handler
    model_hdl = PrepareModel(json_file_name, optimize_hyper_params, use_gpu)
    train_test_data = PrepareModel.prepare_train_test_data([kaons, bckgr])
    del kaons, bckgr
    gc.collect()
    features_for_train = json_tools.load_features_for_train(json_file_name)
    print("\nPreparing model handler...")
    model_hdl, study = model_hdl.prepare_model_handler(train_test_data=train_test_data)
    if create_plots and optimize_hyper_params:
        plotting_tools.opt_history_plot(study, save_plots)
        plotting_tools.opt_contour_plot(study, save_plots)
    # train model
    train = TrainModel(model_hdl, model_name)
    sample_weights = compute_sample_weight(
        class_weight=None,  # class_weight="balanced" deleted for now
        y=train_test_data[1],
    )
    train.train_model_handler(train_test_data, sample_weights)
    print("\nModel trained!")
    train.save_model(model_name)
    # loading validation dataset as test dataset for post-training plots
    if use_validation:
        data_file_name_test = json_tools.load_file_name(json_file_name, "test")
        loader_test = LoadData(
            data_file_name_test,
            json_file_name,
            lower_p_cut,
            upper_p_cut,
            anti_particles,
        )
        tree_handler_test = loader_test.load_tree(max_workers=n_workers)
        # use loader_test (not the training loader) so the selection belongs to
        # the validation file it was built for; both loaders share identical
        # cuts, so results are unchanged either way
        kaons_test = loader_test.get_particles_type(
            tree_handler_test,
            Pid.POS_KAON.value,
            nsigma=NSIGMA,
            json_file_name=json_file_name,
        )
        bckgr_test = tree_handler_test.get_subset(
            f"{pid_var_name} != {Pid.POS_KAON.value}"
        )
        validation_data = PrepareModel.prepare_train_test_data([kaons_test, bckgr_test])
        # keep the training split, swap in the validation split as the test part
        train_test_data = [
            train_test_data[0],
            train_test_data[1],
            validation_data[0],
            validation_data[1],
        ]
    if create_plots:
        print("Creating post-training plots")
        # only predictions on the test part are used below; the train-set
        # predictions previously computed here were dead code and were removed
        y_pred_test = model_hdl.predict(train_test_data[2], False)
        plotting_tools.output_train_test_plot(
            train.model_hdl,
            train_test_data,
            leg_labels=["kaons"],
            save_fig=save_plots,
            logscale=True,
        )

        plotting_tools.roc_plot(train_test_data[3], y_pred_test, save_fig=save_plots)
        # shapleys for each class
        # NOTE(review): an unused "feature_names" list (features with the
        # "Complex_" prefix stripped) was computed here but never passed
        # anywhere; removed as dead code -- confirm plot_shap_summary does
        # not expect display names
        plotting_tools.plot_shap_summary(
            train_test_data[0][features_for_train],
            train_test_data[1],
            model_hdl,
            features_for_train,
            n_workers,
        )
class TrainBinaryModel(ml_pid_cbm.train_model.TrainModel):
19class TrainBinaryModel(TrainModel):
20    """
21    Class for training the ml model
22    """
23
24    def train_model_handler(
25        self, train_test_data, sample_weights, model_hdl: ModelHandler = None
26    ):
27        """Trains model handler
28
29        Args:
30            train_test_data (_type_): Train_test_data generated using a method from prepare_model module.
31            sample_weights(List[Float]): ndarray of shape (n_samples,) Array with sample weights.
32            To be computed with sklearn.utils.class_weight.compute_sample_weight
33            model_hdl (ModelHandler, optional):  Hipe4ml model handler. Defaults to None.
34        """
35        model_hdl = model_hdl or self.model_hdl
36        model_hdl.train_test_model(train_test_data, sample_weight=sample_weights)
37        self.model_hdl = model_hdl

Class for training the ml model

def train_model_handler( self, train_test_data, sample_weights, model_hdl: hipe4ml.model_handler.ModelHandler = None):
24    def train_model_handler(
25        self, train_test_data, sample_weights, model_hdl: ModelHandler = None
26    ):
27        """Trains model handler
28
29        Args:
30            train_test_data (_type_): Train_test_data generated using a method from prepare_model module.
31            sample_weights(List[Float]): ndarray of shape (n_samples,) Array with sample weights.
32            To be computed with sklearn.utils.class_weight.compute_sample_weight
33            model_hdl (ModelHandler, optional):  Hipe4ml model handler. Defaults to None.
34        """
35        model_hdl = model_hdl or self.model_hdl
36        model_hdl.train_test_model(train_test_data, sample_weight=sample_weights)
37        self.model_hdl = model_hdl

Trains model handler

Args: train_test_data (_type_): Train_test_data generated using a method from prepare_model module. sample_weights(List[Float]): ndarray of shape (n_samples,) Array with sample weights. To be computed with sklearn.utils.class_weight.compute_sample_weight model_hdl (ModelHandler, optional): Hipe4ml model handler. Defaults to None.

def parse_args(args: List[str]) -> argparse.Namespace:
 40def parse_args(args: List[str]) -> argparse.Namespace:
 41    """
 42    Arguments parser for the main method.
 43
 44    Args:
 45        args (List[str]): Arguments from the command line, should be sys.argv[1:].
 46
 47    Returns:
 48        argparse.Namespace: argparse.Namespace containing args
 49    """
 50    parser = argparse.ArgumentParser(
 51        prog="ML_PID_CBM TrainModel",
 52        description="Program for training binary PID ML models",
 53    )
 54    parser.add_argument(
 55        "--config",
 56        "-c",
 57        nargs=1,
 58        required=True,
 59        type=str,
 60        help="Filename of path of config json file.",
 61    )
 62    parser.add_argument(
 63        "--momentum",
 64        "-p",
 65        nargs=2,
 66        required=True,
 67        type=float,
 68        help="Lower and upper momentum limit, e.g., 1 3",
 69    )
 70    parser.add_argument(
 71        "--antiparticles",
 72        action="store_true",
 73        help="If should train on particles instead of particles with positive charge.",
 74    )
 75    parser.add_argument(
 76        "--hyperparams",
 77        action="store_true",
 78        help="If should optimize hyper params instead of using const values from config file. Will use ranges from config file.",
 79    )
 80    parser.add_argument(
 81        "--gpu",
 82        action="store_true",
 83        help="If should use GPU for training. Remember that xgboost-gpu version is needed for this.",
 84    )
 85    parser.add_argument(
 86        "--nworkers",
 87        "-n",
 88        type=int,
 89        default=1,
 90        help="Max number of workers for ThreadPoolExecutor which reads Root tree with data.",
 91    )
 92    graphs_group = parser.add_mutually_exclusive_group()
 93    graphs_group.add_argument(
 94        "--printplots",
 95        action="store_true",
 96        help="Creates plots and prints them without saving to file.",
 97    )
 98    graphs_group.add_argument(
 99        "--saveplots",
100        "-plots",
101        action="store_true",
102        help="Creates plots and saves them to file, without printing.",
103    )
104    parser.add_argument(
105        "--usevalidation",
106        action="store_true",
107        help="if should use validation dataset for post-training plots",
108    )
109    return parser.parse_args(args)

Arguments parser for the main method.

Args: args (List[str]): Arguments from the command line, should be sys.argv[1:].

Returns: argparse.Namespace: argparse.Namespace containing args