comparison: main.py @ 5:9ec705bd11cb (draft, default, tip)
planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 24bab7a797f53fe4bcc668b18ee0326625486164
| author | bgruening |
|---|---|
| date | Sun, 16 Oct 2022 11:51:32 +0000 |
| parents | f0da532be419 |
| children | |
| 4:f0da532be419 (parent) | 5:9ec705bd11cb (this revision) |
|---|---|
| 1 """ | 1 """ |
| 2 Predict next tools in the Galaxy workflows | 2 Predict next tools in the Galaxy workflows |
| 3 using machine learning (recurrent neural network) | 3 using deep learning (Transformers) |
| 4 """ | 4 """ |
| 5 | |
| 6 import argparse | 5 import argparse |
| 7 import time | 6 import time |
| 8 | 7 |
| 9 import extract_workflow_connections | 8 import extract_workflow_connections |
| 10 import keras.callbacks as callbacks | |
| 11 import numpy as np | |
| 12 import optimise_hyperparameters | |
| 13 import prepare_data | 9 import prepare_data |
| 14 import utils | 10 import train_transformer |
| 15 | |
| 16 | |
| 17 class PredictTool: | |
| 18 def __init__(self, num_cpus): | |
| 19 """ Init method. """ | |
| 20 | |
| 21 def find_train_best_network( | |
| 22 self, | |
| 23 network_config, | |
| 24 reverse_dictionary, | |
| 25 train_data, | |
| 26 train_labels, | |
| 27 test_data, | |
| 28 test_labels, | |
| 29 n_epochs, | |
| 30 class_weights, | |
| 31 usage_pred, | |
| 32 standard_connections, | |
| 33 tool_freq, | |
| 34 tool_tr_samples, | |
| 35 ): | |
| 36 """ | |
| 37 Define recurrent neural network and train sequential data | |
| 38 """ | |
| 39 # get tools with lowest representation | |
| 40 lowest_tool_ids = utils.get_lowest_tools(tool_freq) | |
| 41 | |
| 42 print("Start hyperparameter optimisation...") | |
| 43 hyper_opt = optimise_hyperparameters.HyperparameterOptimisation() | |
| 44 best_params, best_model = hyper_opt.train_model( | |
| 45 network_config, | |
| 46 reverse_dictionary, | |
| 47 train_data, | |
| 48 train_labels, | |
| 49 test_data, | |
| 50 test_labels, | |
| 51 tool_tr_samples, | |
| 52 class_weights, | |
| 53 ) | |
| 54 | |
| 55 # define callbacks | |
| 56 early_stopping = callbacks.EarlyStopping( | |
| 57 monitor="loss", | |
| 58 mode="min", | |
| 59 verbose=1, | |
| 60 min_delta=1e-1, | |
| 61 restore_best_weights=True, | |
| 62 ) | |
| 63 predict_callback_test = PredictCallback( | |
| 64 test_data, | |
| 65 test_labels, | |
| 66 reverse_dictionary, | |
| 67 n_epochs, | |
| 68 usage_pred, | |
| 69 standard_connections, | |
| 70 lowest_tool_ids, | |
| 71 ) | |
| 72 | |
| 73 callbacks_list = [predict_callback_test, early_stopping] | |
| 74 batch_size = int(best_params["batch_size"]) | |
| 75 | |
| 76 print("Start training on the best model...") | |
| 77 train_performance = dict() | |
| 78 trained_model = best_model.fit_generator( | |
| 79 utils.balanced_sample_generator( | |
| 80 train_data, | |
| 81 train_labels, | |
| 82 batch_size, | |
| 83 tool_tr_samples, | |
| 84 reverse_dictionary, | |
| 85 ), | |
| 86 steps_per_epoch=len(train_data) // batch_size, | |
| 87 epochs=n_epochs, | |
| 88 callbacks=callbacks_list, | |
| 89 validation_data=(test_data, test_labels), | |
| 90 verbose=2, | |
| 91 shuffle=True, | |
| 92 ) | |
| 93 train_performance["validation_loss"] = np.array( | |
| 94 trained_model.history["val_loss"] | |
| 95 ) | |
| 96 train_performance["precision"] = predict_callback_test.precision | |
| 97 train_performance["usage_weights"] = predict_callback_test.usage_weights | |
| 98 train_performance[ | |
| 99 "published_precision" | |
| 100 ] = predict_callback_test.published_precision | |
| 101 train_performance[ | |
| 102 "lowest_pub_precision" | |
| 103 ] = predict_callback_test.lowest_pub_precision | |
| 104 train_performance[ | |
| 105 "lowest_norm_precision" | |
| 106 ] = predict_callback_test.lowest_norm_precision | |
| 107 train_performance["train_loss"] = np.array(trained_model.history["loss"]) | |
| 108 train_performance["model"] = best_model | |
| 109 train_performance["best_parameters"] = best_params | |
| 110 return train_performance | |
| 111 | |
| 112 | |
| 113 class PredictCallback(callbacks.Callback): | |
| 114 def __init__( | |
| 115 self, | |
| 116 test_data, | |
| 117 test_labels, | |
| 118 reverse_data_dictionary, | |
| 119 n_epochs, | |
| 120 usg_scores, | |
| 121 standard_connections, | |
| 122 lowest_tool_ids, | |
| 123 ): | |
| 124 self.test_data = test_data | |
| 125 self.test_labels = test_labels | |
| 126 self.reverse_data_dictionary = reverse_data_dictionary | |
| 127 self.precision = list() | |
| 128 self.usage_weights = list() | |
| 129 self.published_precision = list() | |
| 130 self.n_epochs = n_epochs | |
| 131 self.pred_usage_scores = usg_scores | |
| 132 self.standard_connections = standard_connections | |
| 133 self.lowest_tool_ids = lowest_tool_ids | |
| 134 self.lowest_pub_precision = list() | |
| 135 self.lowest_norm_precision = list() | |
| 136 | |
| 137 def on_epoch_end(self, epoch, logs={}): | |
| 138 """ | |
| 139 Compute absolute and compatible precision for test data | |
| 140 """ | |
| 141 if len(self.test_data) > 0: | |
| 142 ( | |
| 143 usage_weights, | |
| 144 precision, | |
| 145 precision_pub, | |
| 146 low_pub_prec, | |
| 147 low_norm_prec, | |
| 148 low_num, | |
| 149 ) = utils.verify_model( | |
| 150 self.model, | |
| 151 self.test_data, | |
| 152 self.test_labels, | |
| 153 self.reverse_data_dictionary, | |
| 154 self.pred_usage_scores, | |
| 155 self.standard_connections, | |
| 156 self.lowest_tool_ids, | |
| 157 ) | |
| 158 self.precision.append(precision) | |
| 159 self.usage_weights.append(usage_weights) | |
| 160 self.published_precision.append(precision_pub) | |
| 161 self.lowest_pub_precision.append(low_pub_prec) | |
| 162 self.lowest_norm_precision.append(low_norm_prec) | |
| 163 print("Epoch %d usage weights: %s" % (epoch + 1, usage_weights)) | |
| 164 print("Epoch %d normal precision: %s" % (epoch + 1, precision)) | |
| 165 print("Epoch %d published precision: %s" % (epoch + 1, precision_pub)) | |
| 166 print("Epoch %d lowest published precision: %s" % (epoch + 1, low_pub_prec)) | |
| 167 print("Epoch %d lowest normal precision: %s" % (epoch + 1, low_norm_prec)) | |
| 168 print( | |
| 169 "Epoch %d number of test samples with lowest tool ids: %s" | |
| 170 % (epoch + 1, low_num) | |
| 171 ) | |
| 172 | |
| 173 | 11 |
| 174 if __name__ == "__main__": | 12 if __name__ == "__main__": |
| 175 start_time = time.time() | 13 start_time = time.time() |
| 176 | 14 |
| 177 arg_parser = argparse.ArgumentParser() | 15 arg_parser = argparse.ArgumentParser() |
| 178 arg_parser.add_argument( | 16 arg_parser.add_argument("-wf", "--workflow_file", required=True, help="workflows tabular file") |
| 179 "-wf", "--workflow_file", required=True, help="workflows tabular file" | 17 arg_parser.add_argument("-tu", "--tool_usage_file", required=True, help="tool usage file") |
| 180 ) | |
| 181 arg_parser.add_argument( | |
| 182 "-tu", "--tool_usage_file", required=True, help="tool usage file" | |
| 183 ) | |
| 184 arg_parser.add_argument( | |
| 185 "-om", "--output_model", required=True, help="trained model file" | |
| 186 ) | |
| 187 # data parameters | 18 # data parameters |
| 188 arg_parser.add_argument( | 19 arg_parser.add_argument("-cd", "--cutoff_date", required=True, help="earliest date for taking tool usage") |
| 189 "-cd", | 20 arg_parser.add_argument("-pl", "--maximum_path_length", required=True, help="maximum length of tool path") |
| 190 "--cutoff_date", | 21 arg_parser.add_argument("-om", "--output_model", required=True, help="trained model path") |
| 191 required=True, | |
| 192 help="earliest date for taking tool usage", | |
| 193 ) | |
| 194 arg_parser.add_argument( | |
| 195 "-pl", | |
| 196 "--maximum_path_length", | |
| 197 required=True, | |
| 198 help="maximum length of tool path", | |
| 199 ) | |
| 200 arg_parser.add_argument( | |
| 201 "-ep", | |
| 202 "--n_epochs", | |
| 203 required=True, | |
| 204 help="number of iterations to run to create model", | |
| 205 ) | |
| 206 arg_parser.add_argument( | |
| 207 "-oe", | |
| 208 "--optimize_n_epochs", | |
| 209 required=True, | |
| 210 help="number of iterations to run to find best model parameters", | |
| 211 ) | |
| 212 arg_parser.add_argument( | |
| 213 "-me", | |
| 214 "--max_evals", | |
| 215 required=True, | |
| 216 help="maximum number of configuration evaluations", | |
| 217 ) | |
| 218 arg_parser.add_argument( | |
| 219 "-ts", | |
| 220 "--test_share", | |
| 221 required=True, | |
| 222 help="share of data to be used for testing", | |
| 223 ) | |
| 224 # neural network parameters | 22 # neural network parameters |
| 225 arg_parser.add_argument( | 23 arg_parser.add_argument("-ti", "--n_train_iter", required=True, help="Number of training iterations run to create model") |
| 226 "-bs", | 24 arg_parser.add_argument("-nhd", "--n_heads", required=True, help="Number of heads in the transformer's multi-head attention") |
| 227 "--batch_size", | 25 arg_parser.add_argument("-ed", "--n_embed_dim", required=True, help="Embedding dimension") |
| 228 required=True, | 26 arg_parser.add_argument("-fd", "--n_feed_forward_dim", required=True, help="Feed forward network dimension") |
| 229 help="size of the training batch, i.e. the number of samples per batch", | 27 arg_parser.add_argument("-dt", "--dropout", required=True, help="Percentage of neurons to be dropped") |
| 230 ) | 28 arg_parser.add_argument("-lr", "--learning_rate", required=True, help="Learning rate") |
| 231 arg_parser.add_argument( | 29 arg_parser.add_argument("-ts", "--te_share", required=True, help="Share of data to be used for testing") |
| 232 "-ut", "--units", required=True, help="number of hidden recurrent units" | 30 arg_parser.add_argument("-trbs", "--tr_batch_size", required=True, help="Train batch size") |
| 233 ) | 31 arg_parser.add_argument("-trlg", "--tr_logging_step", required=True, help="Train logging frequency") |
| 234 arg_parser.add_argument( | 32 arg_parser.add_argument("-telg", "--te_logging_step", required=True, help="Test logging frequency") |
| 235 "-es", | 33 arg_parser.add_argument("-tebs", "--te_batch_size", required=True, help="Test batch size") |
| 236 "--embedding_size", | |
| 237 required=True, | |
| 238 help="size of the fixed vector learned for each tool", | |
| 239 ) | |
| 240 arg_parser.add_argument( | |
| 241 "-dt", "--dropout", required=True, help="percentage of neurons to be dropped" | |
| 242 ) | |
| 243 arg_parser.add_argument( | |
| 244 "-sd", | |
| 245 "--spatial_dropout", | |
| 246 required=True, | |
| 247 help="1d dropout used for embedding layer", | |
| 248 ) | |
| 249 arg_parser.add_argument( | |
| 250 "-rd", | |
| 251 "--recurrent_dropout", | |
| 252 required=True, | |
| 253 help="dropout for the recurrent layers", | |
| 254 ) | |
| 255 arg_parser.add_argument( | |
| 256 "-lr", "--learning_rate", required=True, help="learning rate" | |
| 257 ) | |
| 258 | 34 |
| 259 # get argument values | 35 # get argument values |
| 260 args = vars(arg_parser.parse_args()) | 36 args = vars(arg_parser.parse_args()) |
| 261 tool_usage_path = args["tool_usage_file"] | 37 tool_usage_path = args["tool_usage_file"] |
| 262 workflows_path = args["workflow_file"] | 38 workflows_path = args["workflow_file"] |
| 263 cutoff_date = args["cutoff_date"] | 39 cutoff_date = args["cutoff_date"] |
| 264 maximum_path_length = int(args["maximum_path_length"]) | 40 maximum_path_length = int(args["maximum_path_length"]) |
| 41 | |
| 42 n_train_iter = int(args["n_train_iter"]) | |
| 43 te_share = float(args["te_share"]) | |
| 44 tr_batch_size = int(args["tr_batch_size"]) | |
| 45 te_batch_size = int(args["te_batch_size"]) | |
| 46 | |
| 47 n_heads = int(args["n_heads"]) | |
| 48 feed_forward_dim = int(args["n_feed_forward_dim"]) | |
| 49 embedding_dim = int(args["n_embed_dim"]) | |
| 50 dropout = float(args["dropout"]) | |
| 51 learning_rate = float(args["learning_rate"]) | |
| 52 te_logging_step = int(args["te_logging_step"]) | |
| 53 tr_logging_step = int(args["tr_logging_step"]) | |
| 265 trained_model_path = args["output_model"] | 54 trained_model_path = args["output_model"] |
| 266 n_epochs = int(args["n_epochs"]) | |
| 267 optimize_n_epochs = int(args["optimize_n_epochs"]) | |
| 268 max_evals = int(args["max_evals"]) | |
| 269 test_share = float(args["test_share"]) | |
| 270 batch_size = args["batch_size"] | |
| 271 units = args["units"] | |
| 272 embedding_size = args["embedding_size"] | |
| 273 dropout = args["dropout"] | |
| 274 spatial_dropout = args["spatial_dropout"] | |
| 275 recurrent_dropout = args["recurrent_dropout"] | |
| 276 learning_rate = args["learning_rate"] | |
| 277 num_cpus = 16 | |
| 278 | 55 |
| 279 config = { | 56 config = { |
| 280 "cutoff_date": cutoff_date, | 57 'cutoff_date': cutoff_date, |
| 281 "maximum_path_length": maximum_path_length, | 58 'maximum_path_length': maximum_path_length, |
| 282 "n_epochs": n_epochs, | 59 'n_train_iter': n_train_iter, |
| 283 "optimize_n_epochs": optimize_n_epochs, | 60 'n_heads': n_heads, |
| 284 "max_evals": max_evals, | 61 'feed_forward_dim': feed_forward_dim, |
| 285 "test_share": test_share, | 62 'embedding_dim': embedding_dim, |
| 286 "batch_size": batch_size, | 63 'dropout': dropout, |
| 287 "units": units, | 64 'learning_rate': learning_rate, |
| 288 "embedding_size": embedding_size, | 65 'te_share': te_share, |
| 289 "dropout": dropout, | 66 'te_logging_step': te_logging_step, |
| 290 "spatial_dropout": spatial_dropout, | 67 'tr_logging_step': tr_logging_step, |
| 291 "recurrent_dropout": recurrent_dropout, | 68 'tr_batch_size': tr_batch_size, |
| 292 "learning_rate": learning_rate, | 69 'te_batch_size': te_batch_size, |
| 70 'trained_model_path': trained_model_path | |
| 293 } | 71 } |
| 294 | 72 print("Preprocessing workflows...") |
| 295 # Extract and process workflows | 73 # Extract and process workflows |
| 296 connections = extract_workflow_connections.ExtractWorkflowConnections() | 74 connections = extract_workflow_connections.ExtractWorkflowConnections() |
| 297 ( | 75 # Process raw workflow file |
| 298 workflow_paths, | 76 wf_dataframe, usage_df = connections.process_raw_files(workflows_path, tool_usage_path, config) |
| 299 compatible_next_tools, | 77 workflow_paths, pub_conn = connections.read_tabular_file(wf_dataframe, config) |
| 300 standard_connections, | |
| 301 ) = connections.read_tabular_file(workflows_path) | |
| 302 # Process the paths from workflows | 78 # Process the paths from workflows |
| 303 print("Dividing data...") | 79 print("Dividing data...") |
| 304 data = prepare_data.PrepareData(maximum_path_length, test_share) | 80 data = prepare_data.PrepareData(maximum_path_length, te_share) |
| 305 ( | 81 train_data, train_labels, test_data, test_labels, f_dict, r_dict, c_wts, c_tools, tr_tool_freq = data.get_data_labels_matrices(workflow_paths, usage_df, cutoff_date, pub_conn) |
| 306 train_data, | 82 print(train_data.shape, train_labels.shape, test_data.shape, test_labels.shape) |
| 307 train_labels, | 83 train_transformer.create_enc_transformer(train_data, train_labels, test_data, test_labels, f_dict, r_dict, c_wts, c_tools, pub_conn, tr_tool_freq, config) |
| 308 test_data, | |
| 309 test_labels, | |
| 310 data_dictionary, | |
| 311 reverse_dictionary, | |
| 312 class_weights, | |
| 313 usage_pred, | |
| 314 train_tool_freq, | |
| 315 tool_tr_samples, | |
| 316 ) = data.get_data_labels_matrices( | |
| 317 workflow_paths, | |
| 318 tool_usage_path, | |
| 319 cutoff_date, | |
| 320 compatible_next_tools, | |
| 321 standard_connections, | |
| 322 ) | |
| 323 # find the best model and start training | |
| 324 predict_tool = PredictTool(num_cpus) | |
| 325 # start training with weighted classes | |
| 326 print("Training with weighted classes and samples ...") | |
| 327 results_weighted = predict_tool.find_train_best_network( | |
| 328 config, | |
| 329 reverse_dictionary, | |
| 330 train_data, | |
| 331 train_labels, | |
| 332 test_data, | |
| 333 test_labels, | |
| 334 n_epochs, | |
| 335 class_weights, | |
| 336 usage_pred, | |
| 337 standard_connections, | |
| 338 train_tool_freq, | |
| 339 tool_tr_samples, | |
| 340 ) | |
| 341 utils.save_model( | |
| 342 results_weighted, | |
| 343 data_dictionary, | |
| 344 compatible_next_tools, | |
| 345 trained_model_path, | |
| 346 class_weights, | |
| 347 standard_connections, | |
| 348 ) | |
| 349 end_time = time.time() | 84 end_time = time.time() |
| 350 print("Program finished in %s seconds" % str(end_time - start_time)) | 85 print("Program finished in %s seconds" % str(end_time - start_time)) |
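The right-hand column replaces the recurrent network, hyperparameter search, and Keras callbacks with a single call into `train_transformer`, and the new arguments (`--n_heads`, `--n_embed_dim`, `--n_feed_forward_dim`, `--dropout`, `--learning_rate`) are the usual knobs of a transformer encoder block. The actual implementation lives in `train_transformer.create_enc_transformer`, which this diff does not show; the following is only a minimal sketch of how such hyperparameters typically wire into a Keras-style encoder. `NUM_TOOLS`, `MAX_LEN`, the omitted positional embedding, and the multi-label sigmoid head are illustrative assumptions, not code from this repository.

```python
# Minimal sketch of a transformer encoder block driven by the new CLI
# hyperparameters. Not the repository's train_transformer module.
import tensorflow as tf
from tensorflow.keras import layers

NUM_TOOLS = 1000      # vocabulary size, i.e. number of distinct tools (assumed)
MAX_LEN = 25          # maximum_path_length (-pl)
EMBED_DIM = 128       # n_embed_dim (-ed)
N_HEADS = 4           # n_heads (-nhd)
FF_DIM = 128          # n_feed_forward_dim (-fd)
DROPOUT = 0.2         # dropout (-dt)
LEARNING_RATE = 1e-3  # learning_rate (-lr)

inputs = layers.Input(shape=(MAX_LEN,), dtype="int32")
# Token embedding for tool ids; a real encoder would also add positional embeddings.
x = layers.Embedding(NUM_TOOLS, EMBED_DIM)(inputs)
# Multi-head self-attention, then a position-wise feed-forward network,
# each wrapped in a residual connection plus layer normalisation.
attn = layers.MultiHeadAttention(num_heads=N_HEADS, key_dim=EMBED_DIM)(x, x)
x = layers.LayerNormalization(epsilon=1e-6)(x + layers.Dropout(DROPOUT)(attn))
ff = layers.Dense(FF_DIM, activation="relu")(x)
ff = layers.Dense(EMBED_DIM)(ff)
x = layers.LayerNormalization(epsilon=1e-6)(x + layers.Dropout(DROPOUT)(ff))
# Pool over the sequence and score every tool (multi-label next-tool prediction).
x = layers.GlobalAveragePooling1D()(x)
outputs = layers.Dense(NUM_TOOLS, activation="sigmoid")(x)

model = tf.keras.Model(inputs, outputs)
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),
    loss="binary_crossentropy",
)
```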
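The CLI also trades epoch-based training (`--n_epochs`, early stopping, `fit_generator`) for iteration-based settings: `--n_train_iter`, separate train/test batch sizes, and train/test logging steps. How `train_transformer` consumes them is not visible in this diff, so the loop below is only a plausible reading: random batches for `n_train_iter` steps, with `tr_logging_step` and `te_logging_step` controlling how often train and test metrics are printed. It runs against any compiled Keras model, such as the sketch above.

```python
import numpy as np

def training_loop(model, train_data, train_labels, test_data, test_labels, config):
    """Hypothetical iteration-based loop for the new CLI settings."""
    for iteration in range(int(config["n_train_iter"])):
        # Sample a random training batch rather than sweeping fixed epochs.
        idx = np.random.randint(0, train_data.shape[0], int(config["tr_batch_size"]))
        loss = model.train_on_batch(train_data[idx], train_labels[idx])
        if (iteration + 1) % int(config["tr_logging_step"]) == 0:
            print("iteration %d: train loss %s" % (iteration + 1, loss))
        if (iteration + 1) % int(config["te_logging_step"]) == 0:
            te_idx = np.random.randint(0, test_data.shape[0], int(config["te_batch_size"]))
            te_loss = model.test_on_batch(test_data[te_idx], test_labels[te_idx])
            print("iteration %d: test loss %s" % (iteration + 1, te_loss))
```

An invocation of the rewritten script might then look like `python main.py -wf workflows.tsv -tu tool_usage.tsv -om model.h5 -cd 2017-12-01 -pl 25 -ti 20 -nhd 4 -ed 128 -fd 128 -dt 0.2 -lr 1e-3 -ts 0.2 -trbs 128 -trlg 5 -telg 5 -tebs 128`, where the file names and values are illustrative only.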
