Mercurial > repos > bgruening > create_tool_recommendation_model
comparison main.py @ 3:98bc44d17561 draft
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 65d36f271296a38deeceb0d0e8d471b2898ee8f4"
| author | bgruening | 
|---|---|
| date | Tue, 07 Jul 2020 07:24:21 +0000 | 
| parents | 50753817983a | 
| children | f0da532be419 | 
   comparison
  equal
  deleted
  inserted
  replaced
| 2:50753817983a | 3:98bc44d17561 | 
|---|---|
| 29 inter_op_parallelism_threads=num_cpus, | 29 inter_op_parallelism_threads=num_cpus, | 
| 30 allow_soft_placement=True | 30 allow_soft_placement=True | 
| 31 ) | 31 ) | 
| 32 K.set_session(tf.Session(config=cpu_config)) | 32 K.set_session(tf.Session(config=cpu_config)) | 
| 33 | 33 | 
| 34 def find_train_best_network(self, network_config, reverse_dictionary, train_data, train_labels, test_data, test_labels, n_epochs, class_weights, usage_pred, standard_connections, l_tool_freq, l_tool_tr_samples): | 34 def find_train_best_network(self, network_config, reverse_dictionary, train_data, train_labels, test_data, test_labels, n_epochs, class_weights, usage_pred, standard_connections, tool_freq, tool_tr_samples): | 
| 35 """ | 35 """ | 
| 36 Define recurrent neural network and train sequential data | 36 Define recurrent neural network and train sequential data | 
| 37 """ | 37 """ | 
| 38 # get tools with lowest representation | 38 # get tools with lowest representation | 
| 39 lowest_tool_ids = utils.get_lowest_tools(l_tool_freq) | 39 lowest_tool_ids = utils.get_lowest_tools(tool_freq) | 
| 40 | 40 | 
| 41 print("Start hyperparameter optimisation...") | 41 print("Start hyperparameter optimisation...") | 
| 42 hyper_opt = optimise_hyperparameters.HyperparameterOptimisation() | 42 hyper_opt = optimise_hyperparameters.HyperparameterOptimisation() | 
| 43 best_params, best_model = hyper_opt.train_model(network_config, reverse_dictionary, train_data, train_labels, test_data, test_labels, l_tool_tr_samples, class_weights) | 43 best_params, best_model = hyper_opt.train_model(network_config, reverse_dictionary, train_data, train_labels, test_data, test_labels, tool_tr_samples, class_weights) | 
| 44 | 44 | 
| 45 # define callbacks | 45 # define callbacks | 
| 46 early_stopping = callbacks.EarlyStopping(monitor='loss', mode='min', verbose=1, min_delta=1e-1, restore_best_weights=True) | 46 early_stopping = callbacks.EarlyStopping(monitor='loss', mode='min', verbose=1, min_delta=1e-1, restore_best_weights=True) | 
| 47 predict_callback_test = PredictCallback(test_data, test_labels, reverse_dictionary, n_epochs, usage_pred, standard_connections, lowest_tool_ids) | 47 predict_callback_test = PredictCallback(test_data, test_labels, reverse_dictionary, n_epochs, usage_pred, standard_connections, lowest_tool_ids) | 
| 48 | 48 | 
| 49 callbacks_list = [predict_callback_test, early_stopping] | 49 callbacks_list = [predict_callback_test, early_stopping] | 
| 50 | |
| 51 batch_size = int(best_params["batch_size"]) | 50 batch_size = int(best_params["batch_size"]) | 
| 52 | 51 | 
| 53 print("Start training on the best model...") | 52 print("Start training on the best model...") | 
| 54 train_performance = dict() | 53 train_performance = dict() | 
| 55 trained_model = best_model.fit_generator( | 54 trained_model = best_model.fit_generator( | 
| 56 utils.balanced_sample_generator( | 55 utils.balanced_sample_generator( | 
| 57 train_data, | 56 train_data, | 
| 58 train_labels, | 57 train_labels, | 
| 59 batch_size, | 58 batch_size, | 
| 60 l_tool_tr_samples | 59 tool_tr_samples, | 
| | 60 reverse_dictionary |
| 61 ), | 61 ), | 
| 62 steps_per_epoch=len(train_data) // batch_size, | 62 steps_per_epoch=len(train_data) // batch_size, | 
| 63 epochs=n_epochs, | 63 epochs=n_epochs, | 
| 64 callbacks=callbacks_list, | 64 callbacks=callbacks_list, | 
| 65 validation_data=(test_data, test_labels), | 65 validation_data=(test_data, test_labels), | 
| 175 connections = extract_workflow_connections.ExtractWorkflowConnections() | 175 connections = extract_workflow_connections.ExtractWorkflowConnections() | 
| 176 workflow_paths, compatible_next_tools, standard_connections = connections.read_tabular_file(workflows_path) | 176 workflow_paths, compatible_next_tools, standard_connections = connections.read_tabular_file(workflows_path) | 
| 177 # Process the paths from workflows | 177 # Process the paths from workflows | 
| 178 print("Dividing data...") | 178 print("Dividing data...") | 
| 179 data = prepare_data.PrepareData(maximum_path_length, test_share) | 179 data = prepare_data.PrepareData(maximum_path_length, test_share) | 
| 180 train_data, train_labels, test_data, test_labels, data_dictionary, reverse_dictionary, class_weights, usage_pred, l_tool_freq, l_tool_tr_samples = data.get_data_labels_matrices(workflow_paths, tool_usage_path, cutoff_date, compatible_next_tools, standard_connections) | 180 train_data, train_labels, test_data, test_labels, data_dictionary, reverse_dictionary, class_weights, usage_pred, train_tool_freq, tool_tr_samples = data.get_data_labels_matrices(workflow_paths, tool_usage_path, cutoff_date, compatible_next_tools, standard_connections) | 
| 181 # find the best model and start training | 181 # find the best model and start training | 
| 182 predict_tool = PredictTool(num_cpus) | 182 predict_tool = PredictTool(num_cpus) | 
| 183 # start training with weighted classes | 183 # start training with weighted classes | 
| 184 print("Training with weighted classes and samples ...") | 184 print("Training with weighted classes and samples ...") | 
| 185 results_weighted = predict_tool.find_train_best_network(config, reverse_dictionary, train_data, train_labels, test_data, test_labels, n_epochs, class_weights, usage_pred, standard_connections, l_tool_freq, l_tool_tr_samples) | 185 results_weighted = predict_tool.find_train_best_network(config, reverse_dictionary, train_data, train_labels, test_data, test_labels, n_epochs, class_weights, usage_pred, standard_connections, train_tool_freq, tool_tr_samples) | 
| 186 utils.save_model(results_weighted, data_dictionary, compatible_next_tools, trained_model_path, class_weights, standard_connections) | 186 utils.save_model(results_weighted, data_dictionary, compatible_next_tools, trained_model_path, class_weights, standard_connections) | 
| 187 end_time = time.time() | 187 end_time = time.time() | 
| 188 print() | |
| 189 print("Program finished in %s seconds" % str(end_time - start_time)) | 188 print("Program finished in %s seconds" % str(end_time - start_time)) | 
