ml_pid_cbm.binary.train_binary_model
"""Training script for a binary (kaon vs. everything-else) PID ML model.

Loads a ROOT tree for a given momentum window, splits it into kaons and
background, optionally optimizes hyperparameters, trains a hipe4ml model
handler with per-sample weights, and produces pre-/post-training plots.
"""
import argparse
import gc
import os
import sys
from shutil import copy2
from typing import List

from hipe4ml.model_handler import ModelHandler
from sklearn.utils.class_weight import compute_sample_weight

from ml_pid_cbm.tools import json_tools, plotting_tools
from ml_pid_cbm.tools.load_data import LoadData
from ml_pid_cbm.tools.particles_id import ParticlesId as Pid
from ml_pid_cbm.tools.prepare_model import PrepareModel
from ml_pid_cbm.train_model import TrainModel


class TrainBinaryModel(TrainModel):
    """
    Class for training the ml model
    """

    def train_model_handler(
        self, train_test_data, sample_weights, model_hdl: ModelHandler = None
    ):
        """Trains model handler

        Args:
            train_test_data (_type_): Train_test_data generated using a method from prepare_model module.
            sample_weights (List[float]): ndarray of shape (n_samples,) with sample weights.
                To be computed with sklearn.utils.class_weight.compute_sample_weight.
            model_hdl (ModelHandler, optional): Hipe4ml model handler.
                Defaults to None, which falls back to self.model_hdl.
        """
        model_hdl = model_hdl or self.model_hdl
        model_hdl.train_test_model(train_test_data, sample_weight=sample_weights)
        # keep the trained handler on the instance so callers can save/predict
        self.model_hdl = model_hdl


def parse_args(args: List[str]) -> argparse.Namespace:
    """
    Arguments parser for the main method.

    Args:
        args (List[str]): Arguments from the command line, should be sys.argv[1:].

    Returns:
        argparse.Namespace: argparse.Namespace containing args
    """
    parser = argparse.ArgumentParser(
        prog="ML_PID_CBM TrainModel",
        description="Program for training binary PID ML models",
    )
    parser.add_argument(
        "--config",
        "-c",
        nargs=1,
        required=True,
        type=str,
        help="Filename or path of config json file.",
    )
    parser.add_argument(
        "--momentum",
        "-p",
        nargs=2,
        required=True,
        type=float,
        help="Lower and upper momentum limit, e.g., 1 3",
    )
    parser.add_argument(
        "--antiparticles",
        action="store_true",
        help="If should train on particles instead of particles with positive charge.",
    )
    parser.add_argument(
        "--hyperparams",
        action="store_true",
        help="If should optimize hyper params instead of using const values from config file. Will use ranges from config file.",
    )
    parser.add_argument(
        "--gpu",
        action="store_true",
        help="If should use GPU for training. Remember that xgboost-gpu version is needed for this.",
    )
    parser.add_argument(
        "--nworkers",
        "-n",
        type=int,
        default=1,
        help="Max number of workers for ThreadPoolExecutor which reads Root tree with data.",
    )
    # --printplots and --saveplots are mutually exclusive by design
    graphs_group = parser.add_mutually_exclusive_group()
    graphs_group.add_argument(
        "--printplots",
        action="store_true",
        help="Creates plots and prints them without saving to file.",
    )
    graphs_group.add_argument(
        "--saveplots",
        "-plots",
        action="store_true",
        help="Creates plots and saves them to file, without printing.",
    )
    parser.add_argument(
        "--usevalidation",
        action="store_true",
        help="if should use validation dataset for post-training plots",
    )
    return parser.parse_args(args)


# main method of the training
if __name__ == "__main__":
    # parser for main class
    args = parse_args(sys.argv[1:])
    # config arguments to be loaded from args
    json_file_name = args.config[0]
    lower_p_cut, upper_p_cut = args.momentum[0], args.momentum[1]
    anti_particles = args.antiparticles
    optimize_hyper_params = args.hyperparams
    use_gpu = args.gpu
    n_workers = args.nworkers
    create_plots = args.printplots or args.saveplots
    save_plots = args.saveplots
    use_validation = args.usevalidation
    # model directory/file name encodes the momentum window and charge selection
    if anti_particles:
        model_name = f"model_{lower_p_cut:.1f}_{upper_p_cut:.1f}_anti"
    else:
        model_name = f"model_{lower_p_cut:.1f}_{upper_p_cut:.1f}_positive"
    data_file_name = json_tools.load_file_name(json_file_name, "training")

    # loading data
    loader = LoadData(
        data_file_name, json_file_name, lower_p_cut, upper_p_cut, anti_particles
    )
    tree_handler = loader.load_tree(max_workers=n_workers)
    NSIGMA = 5
    # signal: kaons selected within NSIGMA; background: every other PID value
    kaons = loader.get_particles_type(
        tree_handler, Pid.POS_KAON.value, nsigma=NSIGMA, json_file_name=json_file_name
    )
    pid_var_name = json_tools.load_var_name(json_file_name, "pid")
    bckgr = tree_handler.get_subset(f"{pid_var_name} != {Pid.POS_KAON.value}")
    print(f"\nKaons and bckgr (everything else) loaded using file {data_file_name}\n")
    # free the full tree as early as possible to keep memory in check
    del tree_handler
    gc.collect()
    # change location to specific folder for this model
    json_file_path = os.path.join(os.getcwd(), json_file_name)
    os.makedirs(model_name, exist_ok=True)
    os.chdir(model_name)
    # keep a copy of the config next to the trained model for reproducibility
    copy2(json_file_path, os.getcwd())
    # pretraining plots
    if create_plots:
        print("Creating pre-training plots...")
        plotting_tools.tof_plot(
            kaons, json_file_name, rf"kaons ({NSIGMA}$\sigma$)", save_fig=save_plots
        )
        vars_to_draw = bckgr.get_var_names()
        plotting_tools.correlations_plot(
            vars_to_draw, [kaons, bckgr], save_fig=save_plots
        )
    # loading model handler
    model_preparer = PrepareModel(json_file_name, optimize_hyper_params, use_gpu)
    train_test_data = PrepareModel.prepare_train_test_data([kaons, bckgr])
    del kaons, bckgr
    gc.collect()
    features_for_train = json_tools.load_features_for_train(json_file_name)
    print("\nPreparing model handler...")
    model_hdl, study = model_preparer.prepare_model_handler(
        train_test_data=train_test_data
    )
    if create_plots and optimize_hyper_params:
        plotting_tools.opt_history_plot(study, save_plots)
        plotting_tools.opt_contour_plot(study, save_plots)
    # train model
    # NOTE: must be TrainBinaryModel — its train_model_handler accepts
    # sample_weights, which the base TrainModel signature does not.
    train = TrainBinaryModel(model_hdl, model_name)
    sample_weights = compute_sample_weight(
        class_weight=None,  # class_weight="balanced" deleted for now
        y=train_test_data[1],
    )
    train.train_model_handler(train_test_data, sample_weights)
    print("\nModel trained!")
    train.save_model(model_name)
    # loading validation dataset as test dataset for post-training plots
    if use_validation:
        data_file_name_test = json_tools.load_file_name(json_file_name, "test")
        loader_test = LoadData(
            data_file_name_test,
            json_file_name,
            lower_p_cut,
            upper_p_cut,
            anti_particles,
        )
        tree_handler_test = loader_test.load_tree(max_workers=n_workers)
        # use loader_test (not loader) so the selection matches the test file
        kaons_test = loader_test.get_particles_type(
            tree_handler_test,
            Pid.POS_KAON.value,
            nsigma=NSIGMA,
            json_file_name=json_file_name,
        )
        bckgr_test = tree_handler_test.get_subset(
            f"{pid_var_name} != {Pid.POS_KAON.value}"
        )
        validation_data = PrepareModel.prepare_train_test_data(
            [kaons_test, bckgr_test]
        )
        # keep the training split, swap in the validation split as "test"
        train_test_data = [
            train_test_data[0],
            train_test_data[1],
            validation_data[0],
            validation_data[1],
        ]
    if create_plots:
        print("Creating post-training plots")
        y_pred_train = model_hdl.predict(train_test_data[0], False)
        y_pred_test = model_hdl.predict(train_test_data[2], False)
        plotting_tools.output_train_test_plot(
            train.model_hdl,
            train_test_data,
            leg_labels=["kaons"],
            save_fig=save_plots,
            logscale=True,
        )

        plotting_tools.roc_plot(train_test_data[3], y_pred_test, save_fig=save_plots)
        # shapleys for each class
        plotting_tools.plot_shap_summary(
            train_test_data[0][features_for_train],
            train_test_data[1],
            model_hdl,
            features_for_train,
            n_workers,
        )
19class TrainBinaryModel(TrainModel): 20 """ 21 Class for training the ml model 22 """ 23 24 def train_model_handler( 25 self, train_test_data, sample_weights, model_hdl: ModelHandler = None 26 ): 27 """Trains model handler 28 29 Args: 30 train_test_data (_type_): Train_test_data generated using a method from prepare_model module. 31 sample_weights(List[Float]): ndarray of shape (n_samples,) Array with sample weights. 32 To be computed with sklearn.utils.class_weight.compute_sample_weight 33 model_hdl (ModelHandler, optional): Hipe4ml model handler. Defaults to None. 34 """ 35 model_hdl = model_hdl or self.model_hdl 36 model_hdl.train_test_model(train_test_data, sample_weight=sample_weights) 37 self.model_hdl = model_hdl
Class for training the ml model
def
train_model_handler( self, train_test_data, sample_weights, model_hdl: hipe4ml.model_handler.ModelHandler = None):
24 def train_model_handler( 25 self, train_test_data, sample_weights, model_hdl: ModelHandler = None 26 ): 27 """Trains model handler 28 29 Args: 30 train_test_data (_type_): Train_test_data generated using a method from prepare_model module. 31 sample_weights(List[Float]): ndarray of shape (n_samples,) Array with sample weights. 32 To be computed with sklearn.utils.class_weight.compute_sample_weight 33 model_hdl (ModelHandler, optional): Hipe4ml model handler. Defaults to None. 34 """ 35 model_hdl = model_hdl or self.model_hdl 36 model_hdl.train_test_model(train_test_data, sample_weight=sample_weights) 37 self.model_hdl = model_hdl
Trains model handler
Args: train_test_data (_type_): Train_test_data generated using a method from prepare_model module. sample_weights(List[Float]): ndarray of shape (n_samples,) Array with sample weights. To be computed with sklearn.utils.class_weight.compute_sample_weight model_hdl (ModelHandler, optional): Hipe4ml model handler. Defaults to None.
Inherited Members
def
parse_args(args: List[str]) -> argparse.Namespace:
40def parse_args(args: List[str]) -> argparse.Namespace: 41 """ 42 Arguments parser for the main method. 43 44 Args: 45 args (List[str]): Arguments from the command line, should be sys.argv[1:]. 46 47 Returns: 48 argparse.Namespace: argparse.Namespace containg args 49 """ 50 parser = argparse.ArgumentParser( 51 prog="ML_PID_CBM TrainModel", 52 description="Program for training binary PID ML models", 53 ) 54 parser.add_argument( 55 "--config", 56 "-c", 57 nargs=1, 58 required=True, 59 type=str, 60 help="Filename of path of config json file.", 61 ) 62 parser.add_argument( 63 "--momentum", 64 "-p", 65 nargs=2, 66 required=True, 67 type=float, 68 help="Lower and upper momentum limit, e.g., 1 3", 69 ) 70 parser.add_argument( 71 "--antiparticles", 72 action="store_true", 73 help="If should train on particles instead of particles with positive charge.", 74 ) 75 parser.add_argument( 76 "--hyperparams", 77 action="store_true", 78 help="If should optimize hyper params instead of using const values from config file. Will use ranges from config file.", 79 ) 80 parser.add_argument( 81 "--gpu", 82 action="store_true", 83 help="If should use GPU for training. Remember that xgboost-gpu version is needed for this.", 84 ) 85 parser.add_argument( 86 "--nworkers", 87 "-n", 88 type=int, 89 default=1, 90 help="Max number of workers for ThreadPoolExecutor which reads Root tree with data.", 91 ) 92 graphs_group = parser.add_mutually_exclusive_group() 93 graphs_group.add_argument( 94 "--printplots", 95 action="store_true", 96 help="Creates plots and prints them without saving to file.", 97 ) 98 graphs_group.add_argument( 99 "--saveplots", 100 "-plots", 101 action="store_true", 102 help="Creates plots and saves them to file, without printing.", 103 ) 104 parser.add_argument( 105 "--usevalidation", 106 action="store_true", 107 help="if should use validation dataset for post-training plots", 108 ) 109 return parser.parse_args(args)
Arguments parser for the main method.
Args: args (List[str]): Arguments from the command line, should be sys.argv[1:].
Returns: argparse.Namespace: argparse.Namespace containing args