comparison: main.py @ changeset 2:50753817983a (draft)

"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit c635df659fe1835679438589ded43136b0e515c6"

author:   bgruening
date:     Sat, 09 May 2020 09:38:04 +0000
parents:  275e98795e99
children: 98bc44d17561

diff 1:275e98795e99 -> 2:50753817983a
 import numpy as np
 import argparse
 import time
 
 # machine learning library
+import tensorflow as tf
+from keras import backend as K
 import keras.callbacks as callbacks
 
 import extract_workflow_connections
 import prepare_data
 import optimise_hyperparameters
 import utils
 
 
 class PredictTool:
 
-    @classmethod
-    def __init__(self):
+    def __init__(self, num_cpus):
         """ Init method. """
+        # set the number of cpus
+        cpu_config = tf.ConfigProto(
+            device_count={"CPU": num_cpus},
+            intra_op_parallelism_threads=num_cpus,
+            inter_op_parallelism_threads=num_cpus,
+            allow_soft_placement=True
+        )
+        K.set_session(tf.Session(config=cpu_config))
 
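Note: tf.ConfigProto, tf.Session, and K.set_session are TensorFlow 1.x APIs, consistent with the Keras backend import added above. As a minimal sketch, assuming a later TensorFlow 2.x environment (which this changeset does not target), the same CPU limit would be expressed as:

    import tensorflow as tf

    # TF2 equivalent of the ConfigProto block above;
    # must be called before any ops are created
    tf.config.threading.set_intra_op_parallelism_threads(num_cpus)
    tf.config.threading.set_inter_op_parallelism_threads(num_cpus)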
-    @classmethod
-    def find_train_best_network(self, network_config, reverse_dictionary, train_data, train_labels, test_data, test_labels, n_epochs, class_weights, usage_pred, compatible_next_tools):
+    def find_train_best_network(self, network_config, reverse_dictionary, train_data, train_labels, test_data, test_labels, n_epochs, class_weights, usage_pred, standard_connections, l_tool_freq, l_tool_tr_samples):
         """
         Define recurrent neural network and train sequential data
         """
+        # get tools with lowest representation
+        lowest_tool_ids = utils.get_lowest_tools(l_tool_freq)
+
         print("Start hyperparameter optimisation...")
         hyper_opt = optimise_hyperparameters.HyperparameterOptimisation()
-        best_params = hyper_opt.train_model(network_config, reverse_dictionary, train_data, train_labels, test_data, test_labels, class_weights)
-
-        # retrieve the model and train on complete dataset without validation set
-        model, best_params = utils.set_recurrent_network(best_params, reverse_dictionary, class_weights)
+        best_params, best_model = hyper_opt.train_model(network_config, reverse_dictionary, train_data, train_labels, test_data, test_labels, l_tool_tr_samples, class_weights)
 
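utils.get_lowest_tools is not part of this file, so its contract is an inference from the call site: given l_tool_freq (apparently a mapping of tool id to usage frequency), it returns the ids of the least-represented tools. A minimal hypothetical sketch under those assumptions, with the cutoff fraction invented:

    def get_lowest_tools(l_tool_freq, fraction=0.25):
        # hypothetical helper: ids of the bottom `fraction` of tools by frequency
        tools_by_freq = sorted(l_tool_freq.items(), key=lambda item: item[1])
        n_lowest = int(len(tools_by_freq) * fraction)
        return [tool_id for tool_id, _ in tools_by_freq[:n_lowest]]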
         # define callbacks
-        predict_callback_test = PredictCallback(test_data, test_labels, reverse_dictionary, n_epochs, compatible_next_tools, usage_pred)
-        # tensor_board = callbacks.TensorBoard(log_dir=log_directory, histogram_freq=0, write_graph=True, write_images=True)
-        callbacks_list = [predict_callback_test]
+        early_stopping = callbacks.EarlyStopping(monitor='loss', mode='min', verbose=1, min_delta=1e-1, restore_best_weights=True)
+        predict_callback_test = PredictCallback(test_data, test_labels, reverse_dictionary, n_epochs, usage_pred, standard_connections, lowest_tool_ids)
+
+        callbacks_list = [predict_callback_test, early_stopping]
+
+        batch_size = int(best_params["batch_size"])
 
         print("Start training on the best model...")
-        model_fit = model.fit(
-            train_data,
-            train_labels,
-            batch_size=int(best_params["batch_size"]),
-            epochs=n_epochs,
-            verbose=2,
-            callbacks=callbacks_list,
-            shuffle="batch",
-            validation_data=(test_data, test_labels)
-        )
-
-        train_performance = {
-            "train_loss": np.array(model_fit.history["loss"]),
-            "model": model,
-            "best_parameters": best_params
-        }
-
-        # if there is test data, add more information
-        if len(test_data) > 0:
-            train_performance["validation_loss"] = np.array(model_fit.history["val_loss"])
-            train_performance["precision"] = predict_callback_test.precision
-            train_performance["usage_weights"] = predict_callback_test.usage_weights
+        train_performance = dict()
+        trained_model = best_model.fit_generator(
+            utils.balanced_sample_generator(
+                train_data,
+                train_labels,
+                batch_size,
+                l_tool_tr_samples
+            ),
+            steps_per_epoch=len(train_data) // batch_size,
+            epochs=n_epochs,
+            callbacks=callbacks_list,
+            validation_data=(test_data, test_labels),
+            verbose=2,
+            shuffle=True
+        )
+        train_performance["validation_loss"] = np.array(trained_model.history["val_loss"])
+        train_performance["precision"] = predict_callback_test.precision
+        train_performance["usage_weights"] = predict_callback_test.usage_weights
+        train_performance["published_precision"] = predict_callback_test.published_precision
+        train_performance["lowest_pub_precision"] = predict_callback_test.lowest_pub_precision
+        train_performance["lowest_norm_precision"] = predict_callback_test.lowest_norm_precision
+        train_performance["train_loss"] = np.array(trained_model.history["loss"])
+        train_performance["model"] = best_model
+        train_performance["best_parameters"] = best_params
         return train_performance
 
 
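utils.balanced_sample_generator is likewise defined elsewhere; from the call site it must yield an endless stream of (batch_data, batch_labels) pairs for fit_generator, and l_tool_tr_samples presumably maps each low-frequency tool id to the indices of its training samples so rare tools can be oversampled. A minimal sketch under those assumptions (both the helper body and the data layout are hypothetical):

    import numpy as np

    def balanced_sample_generator(train_data, train_labels, batch_size, l_tool_tr_samples):
        # hypothetical sketch: oversample rare tools by drawing each batch
        # element from the sample indices of a uniformly chosen rare tool
        tool_ids = list(l_tool_tr_samples.keys())
        while True:
            batch_indices = [
                np.random.choice(l_tool_tr_samples[np.random.choice(tool_ids)])
                for _ in range(batch_size)
            ]
            yield train_data[batch_indices], train_labels[batch_indices]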
 class PredictCallback(callbacks.Callback):
-    def __init__(self, test_data, test_labels, reverse_data_dictionary, n_epochs, next_compatible_tools, usg_scores):
+    def __init__(self, test_data, test_labels, reverse_data_dictionary, n_epochs, usg_scores, standard_connections, lowest_tool_ids):
         self.test_data = test_data
         self.test_labels = test_labels
         self.reverse_data_dictionary = reverse_data_dictionary
         self.precision = list()
         self.usage_weights = list()
+        self.published_precision = list()
         self.n_epochs = n_epochs
-        self.next_compatible_tools = next_compatible_tools
         self.pred_usage_scores = usg_scores
+        self.standard_connections = standard_connections
+        self.lowest_tool_ids = lowest_tool_ids
+        self.lowest_pub_precision = list()
+        self.lowest_norm_precision = list()
 
     def on_epoch_end(self, epoch, logs={}):
         """
         Compute absolute and compatible precision for test data
         """
         if len(self.test_data) > 0:
-            precision, usage_weights = utils.verify_model(self.model, self.test_data, self.test_labels, self.reverse_data_dictionary, self.next_compatible_tools, self.pred_usage_scores)
+            usage_weights, precision, precision_pub, low_pub_prec, low_norm_prec, low_num = utils.verify_model(self.model, self.test_data, self.test_labels, self.reverse_data_dictionary, self.pred_usage_scores, self.standard_connections, self.lowest_tool_ids)
             self.precision.append(precision)
             self.usage_weights.append(usage_weights)
-            print("Epoch %d precision: %s" % (epoch + 1, precision))
+            self.published_precision.append(precision_pub)
+            self.lowest_pub_precision.append(low_pub_prec)
+            self.lowest_norm_precision.append(low_norm_prec)
             print("Epoch %d usage weights: %s" % (epoch + 1, usage_weights))
+            print("Epoch %d normal precision: %s" % (epoch + 1, precision))
+            print("Epoch %d published precision: %s" % (epoch + 1, precision_pub))
+            print("Epoch %d lowest published precision: %s" % (epoch + 1, low_pub_prec))
+            print("Epoch %d lowest normal precision: %s" % (epoch + 1, low_norm_prec))
+            print("Epoch %d number of test samples with lowest tool ids: %s" % (epoch + 1, low_num))
 
 
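utils.verify_model produces the metrics unpacked above; its implementation is outside this changeset. The precision figures it reports are presumably a top-k hit rate over the test samples, with the "published" variant restricted to connections from published workflows. A rough, hypothetical sketch of the core metric (the value of k and the multi-hot label layout are assumptions, not the tool's actual code):

    import numpy as np

    def topk_precision(model, test_data, test_labels, k=10):
        # hypothetical: fraction of samples whose true next tool
        # appears among the model's k highest-scoring predictions
        predictions = model.predict(test_data)      # shape: (n_samples, n_tools)
        hits = 0
        for scores, labels in zip(predictions, test_labels):
            topk_ids = np.argsort(-scores)[:k]      # indices of the k best scores
            true_ids = np.where(labels > 0)[0]      # multi-hot ground truth
            if np.intersect1d(topk_ids, true_ids).size > 0:
                hits += 1
        return hits / len(test_data)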
 if __name__ == "__main__":
     start_time = time.time()
+
     arg_parser = argparse.ArgumentParser()
     arg_parser.add_argument("-wf", "--workflow_file", required=True, help="workflows tabular file")
     arg_parser.add_argument("-tu", "--tool_usage_file", required=True, help="tool usage file")
     arg_parser.add_argument("-om", "--output_model", required=True, help="trained model file")
     # data parameters
     arg_parser.add_argument("-cd", "--cutoff_date", required=True, help="earliest date for taking tool usage")
     arg_parser.add_argument("-pl", "--maximum_path_length", required=True, help="maximum length of tool path")
     arg_parser.add_argument("-ep", "--n_epochs", required=True, help="number of iterations to run to create model")
     arg_parser.add_argument("-oe", "--optimize_n_epochs", required=True, help="number of iterations to run to find best model parameters")
     arg_parser.add_argument("-me", "--max_evals", required=True, help="maximum number of configuration evaluations")
     arg_parser.add_argument("-ts", "--test_share", required=True, help="share of data to be used for testing")
-    arg_parser.add_argument("-vs", "--validation_share", required=True, help="share of data to be used for validation")
     # neural network parameters
     arg_parser.add_argument("-bs", "--batch_size", required=True, help="size of the training batch i.e. the number of samples per batch")
     arg_parser.add_argument("-ut", "--units", required=True, help="number of hidden recurrent units")
     arg_parser.add_argument("-es", "--embedding_size", required=True, help="size of the fixed vector learned for each tool")
     arg_parser.add_argument("-dt", "--dropout", required=True, help="percentage of neurons to be dropped")
     arg_parser.add_argument("-sd", "--spatial_dropout", required=True, help="1d dropout used for embedding layer")
     arg_parser.add_argument("-rd", "--recurrent_dropout", required=True, help="dropout for the recurrent layers")
     arg_parser.add_argument("-lr", "--learning_rate", required=True, help="learning rate")
-    arg_parser.add_argument("-ar", "--activation_recurrent", required=True, help="activation function for recurrent layers")
-    arg_parser.add_argument("-ao", "--activation_output", required=True, help="activation function for output layers")
+
     # get argument values
     args = vars(arg_parser.parse_args())
     tool_usage_path = args["tool_usage_file"]
     workflows_path = args["workflow_file"]
     cutoff_date = args["cutoff_date"]
     maximum_path_length = int(args["maximum_path_length"])
     trained_model_path = args["output_model"]
     n_epochs = int(args["n_epochs"])
     optimize_n_epochs = int(args["optimize_n_epochs"])
     max_evals = int(args["max_evals"])
     test_share = float(args["test_share"])
-    validation_share = float(args["validation_share"])
     batch_size = args["batch_size"]
     units = args["units"]
     embedding_size = args["embedding_size"]
     dropout = args["dropout"]
     spatial_dropout = args["spatial_dropout"]
     recurrent_dropout = args["recurrent_dropout"]
     learning_rate = args["learning_rate"]
-    activation_recurrent = args["activation_recurrent"]
-    activation_output = args["activation_output"]
+    num_cpus = 16
 
     config = {
         'cutoff_date': cutoff_date,
         'maximum_path_length': maximum_path_length,
         'n_epochs': n_epochs,
         'optimize_n_epochs': optimize_n_epochs,
         'max_evals': max_evals,
         'test_share': test_share,
-        'validation_share': validation_share,
         'batch_size': batch_size,
         'units': units,
         'embedding_size': embedding_size,
         'dropout': dropout,
         'spatial_dropout': spatial_dropout,
         'recurrent_dropout': recurrent_dropout,
-        'learning_rate': learning_rate,
-        'activation_recurrent': activation_recurrent,
-        'activation_output': activation_output
+        'learning_rate': learning_rate
     }
 
     # Extract and process workflows
     connections = extract_workflow_connections.ExtractWorkflowConnections()
-    workflow_paths, compatible_next_tools = connections.read_tabular_file(workflows_path)
+    workflow_paths, compatible_next_tools, standard_connections = connections.read_tabular_file(workflows_path)
     # Process the paths from workflows
     print("Dividing data...")
     data = prepare_data.PrepareData(maximum_path_length, test_share)
-    train_data, train_labels, test_data, test_labels, data_dictionary, reverse_dictionary, class_weights, usage_pred = data.get_data_labels_matrices(workflow_paths, tool_usage_path, cutoff_date, compatible_next_tools)
+    train_data, train_labels, test_data, test_labels, data_dictionary, reverse_dictionary, class_weights, usage_pred, l_tool_freq, l_tool_tr_samples = data.get_data_labels_matrices(workflow_paths, tool_usage_path, cutoff_date, compatible_next_tools, standard_connections)
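For orientation, the matrices produced here plausibly follow the encoding used throughout this tool: each row of train_data is a tool-id path padded to maximum_path_length, and each row of train_labels is a multi-hot vector over the tool vocabulary marking the observed next tools. A toy illustration (shapes and values invented, not real data):

    import numpy as np

    # invented example: vocabulary of 5 tools, maximum_path_length of 4
    train_data = np.array([[0, 0, 3, 1]])       # padded path ending in tool 1
    train_labels = np.array([[0, 0, 1, 0, 1]])  # tools 2 and 4 observed as next steps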
     # find the best model and start training
-    predict_tool = PredictTool()
+    predict_tool = PredictTool(num_cpus)
     # start training with weighted classes
     print("Training with weighted classes and samples ...")
-    results_weighted = predict_tool.find_train_best_network(config, reverse_dictionary, train_data, train_labels, test_data, test_labels, n_epochs, class_weights, usage_pred, compatible_next_tools)
-    print()
-    print("Best parameters \n")
-    print(results_weighted["best_parameters"])
-    print()
-    utils.save_model(results_weighted, data_dictionary, compatible_next_tools, trained_model_path, class_weights)
+    results_weighted = predict_tool.find_train_best_network(config, reverse_dictionary, train_data, train_labels, test_data, test_labels, n_epochs, class_weights, usage_pred, standard_connections, l_tool_freq, l_tool_tr_samples)
+    utils.save_model(results_weighted, data_dictionary, compatible_next_tools, trained_model_path, class_weights, standard_connections)
     end_time = time.time()
     print()
     print("Program finished in %s seconds" % str(end_time - start_time))
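For reference, a plausible invocation of the updated script, with placeholder file names and parameter values (only the flags themselves come from the argparse definitions above):

    python main.py \
        -wf workflow-connections.tsv \
        -tu tool-usage.tsv \
        -om trained_model.hdf5 \
        -cd 2017-12-01 \
        -pl 25 \
        -ep 10 \
        -oe 5 \
        -me 20 \
        -ts 0.2 \
        -bs 128 \
        -ut 128 \
        -es 128 \
        -dt 0.2 \
        -sd 0.2 \
        -rd 0.2 \
        -lr 0.001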