ml_pid_cbm.train_model
Module for training the model.
1""" 2Module for training the model. 3 4""" 5import argparse 6import gc 7import os 8import sys 9from shutil import copy2 10from typing import List 11 12from hipe4ml.model_handler import ModelHandler 13from sklearn.utils.class_weight import compute_sample_weight 14from tools import json_tools, plotting_tools 15from tools.load_data import LoadData 16from tools.prepare_model import PrepareModel 17 18 19class TrainModel: 20 """ 21 Class for training the ml model 22 """ 23 24 def __init__(self, model_hdl: ModelHandler, model_name: str): 25 self.model_hdl = model_hdl 26 self.model_name = model_name 27 28 def train_model_handler( 29 self, train_test_data, sample_weights, model_hdl: ModelHandler = None 30 ): 31 """Trains model handler 32 33 Args: 34 train_test_data (_type_): Train_test_data generated using a method from prepare_model module. 35 sample_weights(List[Float]): ndarray of shape (n_samples,) Array with sample weights. 36 To be computed with sklearn.utils.class_weight.compute_sample_weight 37 model_hdl (ModelHandler, optional): Hipe4ml model handler. Defaults to None. 38 """ 39 model_hdl = model_hdl or self.model_hdl 40 model_hdl.train_test_model( 41 train_test_data, 42 multi_class_opt="ovo", 43 sample_weight=sample_weights, 44 ) 45 self.model_hdl = model_hdl 46 47 def save_model(self, model_name: str = None, model_hdl: ModelHandler = None): 48 """Saves trained model handler. 49 50 Args: 51 model_name (str, optional): Name of the model handler. Defaults to None. 52 model_hdl (ModelHandler, optional): Hipe4ml model handler. Defaults to None. 53 """ 54 model_name = model_name or self.model_name 55 model_hdl = model_hdl or self.model_hdl 56 model_hdl.dump_model_handler(model_name) 57 print(f"\nModel saved as {model_name}") 58 59 60def parse_args(args: List[str]) -> argparse.Namespace: 61 """ 62 Arguments parser for the main method. 63 64 Args: 65 args (List[str]): Arguments from the command line, should be sys.argv[1:]. 
66 67 Returns: 68 argparse.Namespace: argparse.Namespace containg args 69 """ 70 parser = argparse.ArgumentParser( 71 prog="ML_PID_CBM TrainModel", description="Program for training PID ML models" 72 ) 73 parser.add_argument( 74 "--config", 75 "-c", 76 nargs=1, 77 required=True, 78 type=str, 79 help="Filename of path of config json file.", 80 ) 81 parser.add_argument( 82 "--momentum", 83 "-p", 84 nargs=2, 85 required=True, 86 type=float, 87 help="Lower and upper momentum limit, e.g., 1 3", 88 ) 89 parser.add_argument( 90 "--antiparticles", 91 action="store_true", 92 help="If should train on particles instead of particles with positive charge.", 93 ) 94 parser.add_argument( 95 "--hyperparams", 96 action="store_true", 97 help="If should optimize hyper params instead of using const values from config file. Will use ranges from config file.", 98 ) 99 parser.add_argument( 100 "--gpu", 101 action="store_true", 102 help="If should use GPU for training. Remember that xgboost-gpu version is needed for this.", 103 ) 104 parser.add_argument( 105 "--nworkers", 106 "-n", 107 type=int, 108 default=1, 109 help="Max number of workers for ThreadPoolExecutor which reads Root tree with data.", 110 ) 111 graphs_group = parser.add_mutually_exclusive_group() 112 graphs_group.add_argument( 113 "--printplots", 114 action="store_true", 115 help="Creates plots and prints them without saving to file.", 116 ) 117 graphs_group.add_argument( 118 "--saveplots", 119 "-plots", 120 action="store_true", 121 help="Creates plots and saves them to file, without printing.", 122 ) 123 parser.add_argument( 124 "--usevalidation", 125 action="store_true", 126 help="if should use validation dataset for post-training plots", 127 ) 128 return parser.parse_args(args) 129 130 131# main method of the training 132if __name__ == "__main__": 133 # parser for main class 134 args = parse_args(sys.argv[1:]) 135 # config arguments to be loaded from args 136 json_file_name = args.config[0] 137 lower_p_cut, upper_p_cut = 
args.momentum[0], args.momentum[1] 138 anti_particles = args.antiparticles 139 optimize_hyper_params = args.hyperparams 140 use_gpu = args.gpu 141 n_workers = args.nworkers 142 create_plots = args.printplots or args.saveplots or False 143 save_plots = args.saveplots 144 use_validation = args.usevalidation 145 if anti_particles: 146 model_name = f"model_{lower_p_cut:.1f}_{upper_p_cut:.1f}_anti" 147 else: 148 model_name = f"model_{lower_p_cut:.1f}_{upper_p_cut:.1f}_positive" 149 data_file_name = json_tools.load_file_name(json_file_name, "training") 150 151 # loading data 152 loader = LoadData( 153 data_file_name, json_file_name, lower_p_cut, upper_p_cut, anti_particles 154 ) 155 tree_handler = loader.load_tree(max_workers=n_workers) 156 NSIGMA_PROTON = 0 157 NSIGMA_KAON = 0 158 NSIGMA_PION = 0 159 protons, kaons, pions = loader.get_protons_kaons_pions( 160 tree_handler, 161 nsigma_proton=NSIGMA_PROTON, 162 nsigma_kaon=NSIGMA_KAON, 163 nsigma_pion=NSIGMA_PION, 164 ) 165 print(f"\nProtons, kaons, and pions loaded using file {data_file_name}\n") 166 del tree_handler 167 gc.collect() 168 # change location to specific folder for this model 169 json_file_path = os.path.join(os.getcwd(), json_file_name) 170 if not os.path.exists(f"{model_name}"): 171 os.makedirs(f"{model_name}") 172 os.chdir(f"{model_name}") 173 copy2(json_file_path, os.getcwd()) 174 # pretraining plots 175 if create_plots: 176 print("Creating pre-training plots...") 177 plotting_tools.tof_plot( 178 protons, 179 json_file_name, 180 f"protons ({NSIGMA_PROTON}$\sigma$)", 181 save_fig=save_plots, 182 ) 183 plotting_tools.tof_plot( 184 kaons, json_file_name, f"kaons ({NSIGMA_KAON}$\sigma$)", save_fig=save_plots 185 ) 186 plotting_tools.tof_plot( 187 pions, 188 json_file_name, 189 f"pions, muons, electrons ({NSIGMA_PION}$\sigma$)", 190 save_fig=save_plots, 191 ) 192 vars_to_draw = protons.get_var_names() 193 plotting_tools.correlations_plot( 194 vars_to_draw, [protons, kaons, pions], save_fig=save_plots 195 ) 
196 # loading model handler 197 model_hdl = PrepareModel(json_file_name, optimize_hyper_params, use_gpu) 198 train_test_data = PrepareModel.prepare_train_test_data([protons, kaons, pions]) 199 del protons, kaons, pions 200 gc.collect() 201 features_for_train = json_tools.load_features_for_train(json_file_name) 202 print("\nPreparing model handler...") 203 model_hdl, study = model_hdl.prepare_model_handler(train_test_data=train_test_data) 204 if create_plots and optimize_hyper_params: 205 plotting_tools.opt_history_plot(study, save_plots) 206 plotting_tools.opt_contour_plot(study, save_plots) 207 # train model 208 train = TrainModel(model_hdl, model_name) 209 sample_weights = compute_sample_weight( 210 class_weight="balanced", # class_weight=None or {0: 1, 1: 3, 2: 1}, deleted for now 211 y=train_test_data[1], 212 ) 213 train.train_model_handler(train_test_data, sample_weights) 214 print("\nModel trained!") 215 train.save_model(model_name) 216 # loading validation dataset as test dataset for pos-training plots 217 if use_validation: 218 data_file_name_test = json_tools.load_file_name(json_file_name, "test") 219 loader_test = LoadData( 220 data_file_name_test, 221 json_file_name, 222 lower_p_cut, 223 upper_p_cut, 224 anti_particles, 225 ) 226 tree_handler_test = loader_test.load_tree(max_workers=n_workers) 227 protons_test, kaons_test, pions_test = loader_test.get_protons_kaons_pions( 228 tree_handler_test, 229 nsigma_proton=NSIGMA_PROTON, 230 nsigma_kaon=NSIGMA_KAON, 231 nsigma_pion=NSIGMA_PION, 232 ) 233 validation_data = PrepareModel.prepare_train_test_data( 234 [protons_test, kaons_test, pions_test] 235 ) 236 train_test_data = [ 237 train_test_data[0], 238 train_test_data[1], 239 validation_data[0], 240 validation_data[1], 241 ] 242 if create_plots: 243 print("Creating post-training plots") 244 y_pred_train = model_hdl.predict(train_test_data[0], False) 245 y_pred_test = model_hdl.predict(train_test_data[2], False) 246 plotting_tools.output_train_test_plot( 247 
train.model_hdl, train_test_data, save_fig=save_plots, logscale=True 248 ) 249 250 plotting_tools.roc_plot(train_test_data[3], y_pred_test, save_fig=save_plots) 251 # shapleys for each class 252 feature_names = [item.replace("Complex_", "") for item in features_for_train] 253 plotting_tools.plot_shap_summary( 254 train_test_data[0][features_for_train], 255 train_test_data[1], 256 model_hdl, 257 features_for_train, 258 n_workers, 259 )
class TrainModel:
    """Wraps a hipe4ml ModelHandler: trains it and persists it to disk."""

    def __init__(self, model_hdl: ModelHandler, model_name: str):
        # handler and name act as defaults for the methods below
        self.model_hdl = model_hdl
        self.model_name = model_name

    def train_model_handler(
        self, train_test_data, sample_weights, model_hdl: ModelHandler = None
    ):
        """Train the model handler on the prepared data split.

        Args:
            train_test_data: Split produced by prepare_model's
                prepare_train_test_data method.
            sample_weights: Array of shape (n_samples,) with per-sample
                weights, as returned by
                sklearn.utils.class_weight.compute_sample_weight.
            model_hdl (ModelHandler, optional): Handler to train; when None,
                the instance's stored handler is used. Defaults to None.
        """
        active = model_hdl if model_hdl else self.model_hdl
        active.train_test_model(
            train_test_data,
            multi_class_opt="ovo",
            sample_weight=sample_weights,
        )
        # remember the freshly trained handler on the instance
        self.model_hdl = active

    def save_model(self, model_name: str = None, model_hdl: ModelHandler = None):
        """Dump the trained model handler to disk.

        Args:
            model_name (str, optional): Output name; when None, the
                instance's stored name is used. Defaults to None.
            model_hdl (ModelHandler, optional): Handler to dump; when None,
                the instance's stored handler is used. Defaults to None.
        """
        model_name = model_name if model_name else self.model_name
        model_hdl = model_hdl if model_hdl else self.model_hdl
        model_hdl.dump_model_handler(model_name)
        print(f"\nModel saved as {model_name}")
Class for training the ml model
def train_model_handler(
    self, train_test_data, sample_weights, model_hdl: ModelHandler = None
):
    """Train a model handler and store it on the instance.

    Args:
        train_test_data: Data split from prepare_model's
            prepare_train_test_data method.
        sample_weights: Array of shape (n_samples,) with per-sample weights,
            computed via sklearn.utils.class_weight.compute_sample_weight.
        model_hdl (ModelHandler, optional): Handler to train; falls back to
            the handler stored on the instance. Defaults to None.
    """
    chosen = model_hdl if model_hdl else self.model_hdl
    chosen.train_test_model(
        train_test_data, multi_class_opt="ovo", sample_weight=sample_weights
    )
    # the trained handler becomes the instance's current handler
    self.model_hdl = chosen
Trains model handler
Args: train_test_data (_type_): Train_test_data generated using a method from prepare_model module. sample_weights (List[float]): ndarray of shape (n_samples,) with sample weights. To be computed with sklearn.utils.class_weight.compute_sample_weight model_hdl (ModelHandler, optional): Hipe4ml model handler. Defaults to None.
def save_model(self, model_name: str = None, model_hdl: ModelHandler = None):
    """Persist the trained model handler under the given name.

    Args:
        model_name (str, optional): Output name; falls back to the name
            stored on the instance. Defaults to None.
        model_hdl (ModelHandler, optional): Handler to dump; falls back to
            the handler stored on the instance. Defaults to None.
    """
    if not model_name:
        model_name = self.model_name
    if not model_hdl:
        model_hdl = self.model_hdl
    model_hdl.dump_model_handler(model_name)
    print(f"\nModel saved as {model_name}")
Saves trained model handler.
Args: model_name (str, optional): Name of the model handler. Defaults to None. model_hdl (ModelHandler, optional): Hipe4ml model handler. Defaults to None.
def parse_args(args: List[str]) -> argparse.Namespace:
    """Parse command-line arguments for the training entry point.

    Args:
        args (List[str]): Raw arguments, typically sys.argv[1:].

    Returns:
        argparse.Namespace: Namespace containing the parsed arguments.
    """
    parser = argparse.ArgumentParser(
        prog="ML_PID_CBM TrainModel", description="Program for training PID ML models"
    )
    add = parser.add_argument
    add(
        "--config", "-c", nargs=1, required=True, type=str,
        help="Filename of path of config json file.",
    )
    add(
        "--momentum", "-p", nargs=2, required=True, type=float,
        help="Lower and upper momentum limit, e.g., 1 3",
    )
    add(
        "--antiparticles", action="store_true",
        help="If should train on particles instead of particles with positive charge.",
    )
    add(
        "--hyperparams", action="store_true",
        help="If should optimize hyper params instead of using const values from config file. Will use ranges from config file.",
    )
    add(
        "--gpu", action="store_true",
        help="If should use GPU for training. Remember that xgboost-gpu version is needed for this.",
    )
    add(
        "--nworkers", "-n", type=int, default=1,
        help="Max number of workers for ThreadPoolExecutor which reads Root tree with data.",
    )
    # print-to-screen and save-to-file are mutually exclusive plot modes
    plot_mode = parser.add_mutually_exclusive_group()
    plot_mode.add_argument(
        "--printplots", action="store_true",
        help="Creates plots and prints them without saving to file.",
    )
    plot_mode.add_argument(
        "--saveplots", "-plots", action="store_true",
        help="Creates plots and saves them to file, without printing.",
    )
    add(
        "--usevalidation", action="store_true",
        help="if should use validation dataset for post-training plots",
    )
    return parser.parse_args(args)
Arguments parser for the main method.
Args: args (List[str]): Arguments from the command line, should be sys.argv[1:].
Returns: argparse.Namespace: argparse.Namespace containing args