Mercurial > repos > bgruening > create_tool_recommendation_model
annotate main.py @ 2:50753817983a draft
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit c635df659fe1835679438589ded43136b0e515c6"
| author | bgruening | 
|---|---|
| date | Sat, 09 May 2020 09:38:04 +0000 | 
| parents | 275e98795e99 | 
| children | 98bc44d17561 | 
| rev | line source | 
|---|---|
| 0 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 1 """ | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 2 Predict next tools in the Galaxy workflows | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 3 using machine learning (recurrent neural network) | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 4 """ | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 5 | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 6 import numpy as np | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 7 import argparse | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 8 import time | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 9 | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 10 # machine learning library | 
| 2 
50753817983a
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit c635df659fe1835679438589ded43136b0e515c6"
 bgruening parents: 
1diff
changeset | 11 import tensorflow as tf | 
| 
50753817983a
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit c635df659fe1835679438589ded43136b0e515c6"
 bgruening parents: 
1diff
changeset | 12 from keras import backend as K | 
| 0 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 13 import keras.callbacks as callbacks | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 14 | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 15 import extract_workflow_connections | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 16 import prepare_data | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 17 import optimise_hyperparameters | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 18 import utils | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 19 | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 20 | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
class PredictTool:
    """Optimise hyperparameters and train the next-tool prediction model."""

    def __init__(self, num_cpus):
        """Create a Keras backend session limited to ``num_cpus`` CPUs.

        :param num_cpus: used as the CPU device count and as both the intra-op
            and inter-op parallelism thread counts of the TensorFlow session.
        """
        # set the number of cpus
        cpu_config = tf.ConfigProto(
            device_count={"CPU": num_cpus},
            intra_op_parallelism_threads=num_cpus,
            inter_op_parallelism_threads=num_cpus,
            allow_soft_placement=True
        )
        K.set_session(tf.Session(config=cpu_config))

    def find_train_best_network(self, network_config, reverse_dictionary, train_data, train_labels,
                                test_data, test_labels, n_epochs, class_weights, usage_pred,
                                standard_connections, l_tool_freq, l_tool_tr_samples):
        """Find the best hyperparameters, then train that model and collect metrics.

        :param network_config: network/hyperparameter configuration forwarded to
            ``HyperparameterOptimisation.train_model``.
        :param reverse_dictionary: passed to the optimiser and to ``PredictCallback``
            (presumably maps encoded ids back to tool names — verify in utils).
        :param train_data: training samples.
        :param train_labels: labels matching ``train_data``.
        :param test_data: samples used for validation and per-epoch evaluation.
        :param test_labels: labels matching ``test_data``.
        :param n_epochs: number of training epochs.
        :param class_weights: forwarded to the hyperparameter optimiser.
        :param usage_pred: usage scores handed to ``PredictCallback``.
        :param standard_connections: handed to ``PredictCallback``.
        :param l_tool_freq: input to ``utils.get_lowest_tools``.
        :param l_tool_tr_samples: forwarded to the optimiser and to
            ``utils.balanced_sample_generator``.
        :returns: dict with keys ``validation_loss``, ``precision``,
            ``usage_weights``, ``published_precision``, ``lowest_pub_precision``,
            ``lowest_norm_precision``, ``train_loss``, ``model`` and
            ``best_parameters``.
        """
        # get tools with lowest representation
        lowest_tool_ids = utils.get_lowest_tools(l_tool_freq)

        print("Start hyperparameter optimisation...")
        hyper_opt = optimise_hyperparameters.HyperparameterOptimisation()
        best_params, best_model = hyper_opt.train_model(
            network_config, reverse_dictionary, train_data, train_labels,
            test_data, test_labels, l_tool_tr_samples, class_weights
        )

        # define callbacks
        early_stopping = callbacks.EarlyStopping(monitor='loss', mode='min', verbose=1, min_delta=1e-1, restore_best_weights=True)
        predict_callback_test = PredictCallback(
            test_data, test_labels, reverse_dictionary, n_epochs,
            usage_pred, standard_connections, lowest_tool_ids
        )

        callbacks_list = [predict_callback_test, early_stopping]

        batch_size = int(best_params["batch_size"])
        # Draw at least one batch per epoch: with fewer training samples than
        # batch_size, integer division would give 0 steps, so no training batch
        # would ever be drawn and the "loss" history read below would be missing.
        steps_per_epoch = max(1, len(train_data) // batch_size)

        print("Start training on the best model...")
        train_performance = dict()
        trained_model = best_model.fit_generator(
            utils.balanced_sample_generator(
                train_data,
                train_labels,
                batch_size,
                l_tool_tr_samples
            ),
            steps_per_epoch=steps_per_epoch,
            epochs=n_epochs,
            callbacks=callbacks_list,
            validation_data=(test_data, test_labels),
            verbose=2,
            shuffle=True
        )
        train_performance["validation_loss"] = np.array(trained_model.history["val_loss"])
        train_performance["precision"] = predict_callback_test.precision
        train_performance["usage_weights"] = predict_callback_test.usage_weights
        train_performance["published_precision"] = predict_callback_test.published_precision
        train_performance["lowest_pub_precision"] = predict_callback_test.lowest_pub_precision
        train_performance["lowest_norm_precision"] = predict_callback_test.lowest_norm_precision
        train_performance["train_loss"] = np.array(trained_model.history["loss"])
        train_performance["model"] = best_model
        train_performance["best_parameters"] = best_params
        return train_performance
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 79 | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 80 | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
class PredictCallback(callbacks.Callback):
    """Keras callback that evaluates the model on held-out test data after every epoch.

    Each epoch, the metrics returned by ``utils.verify_model`` are appended to
    ``precision``, ``usage_weights``, ``published_precision``,
    ``lowest_pub_precision`` and ``lowest_norm_precision``, and printed.
    """

    def __init__(self, test_data, test_labels, reverse_data_dictionary, n_epochs, usg_scores, standard_connections, lowest_tool_ids):
        # Run the base Callback constructor so the attributes it initialises
        # exist before Keras attaches the model and params to this callback.
        super(PredictCallback, self).__init__()
        self.test_data = test_data
        self.test_labels = test_labels
        self.reverse_data_dictionary = reverse_data_dictionary
        self.precision = list()
        self.usage_weights = list()
        self.published_precision = list()
        self.n_epochs = n_epochs
        self.pred_usage_scores = usg_scores
        self.standard_connections = standard_connections
        self.lowest_tool_ids = lowest_tool_ids
        self.lowest_pub_precision = list()
        self.lowest_norm_precision = list()

    def on_epoch_end(self, epoch, logs=None):
        """
        Compute absolute and compatible precision for test data.

        :param epoch: zero-based epoch index supplied by Keras (printed as epoch + 1).
        :param logs: metrics dict supplied by Keras; not used here.
        """
        # len() rather than truthiness: test_data may be an array, whose
        # truth value is ambiguous.
        if len(self.test_data) > 0:
            usage_weights, precision, precision_pub, low_pub_prec, low_norm_prec, low_num = utils.verify_model(
                self.model, self.test_data, self.test_labels, self.reverse_data_dictionary,
                self.pred_usage_scores, self.standard_connections, self.lowest_tool_ids
            )
            self.precision.append(precision)
            self.usage_weights.append(usage_weights)
            self.published_precision.append(precision_pub)
            self.lowest_pub_precision.append(low_pub_prec)
            self.lowest_norm_precision.append(low_norm_prec)
            print("Epoch %d usage weights: %s" % (epoch + 1, usage_weights))
            print("Epoch %d normal precision: %s" % (epoch + 1, precision))
            print("Epoch %d published precision: %s" % (epoch + 1, precision_pub))
            print("Epoch %d lowest published precision: %s" % (epoch + 1, low_pub_prec))
            print("Epoch %d lowest normal precision: %s" % (epoch + 1, low_norm_prec))
            print("Epoch %d number of test samples with lowest tool ids: %s" % (epoch + 1, low_num))
| 0 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 113 | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 114 | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 115 if __name__ == "__main__": | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 116 start_time = time.time() | 
| 2 
50753817983a
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit c635df659fe1835679438589ded43136b0e515c6"
 bgruening parents: 
1diff
changeset | 117 | 
| 0 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 118 arg_parser = argparse.ArgumentParser() | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 119 arg_parser.add_argument("-wf", "--workflow_file", required=True, help="workflows tabular file") | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 120 arg_parser.add_argument("-tu", "--tool_usage_file", required=True, help="tool usage file") | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 121 arg_parser.add_argument("-om", "--output_model", required=True, help="trained model file") | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 122 # data parameters | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 123 arg_parser.add_argument("-cd", "--cutoff_date", required=True, help="earliest date for taking tool usage") | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 124 arg_parser.add_argument("-pl", "--maximum_path_length", required=True, help="maximum length of tool path") | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 125 arg_parser.add_argument("-ep", "--n_epochs", required=True, help="number of iterations to run to create model") | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 126 arg_parser.add_argument("-oe", "--optimize_n_epochs", required=True, help="number of iterations to run to find best model parameters") | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 127 arg_parser.add_argument("-me", "--max_evals", required=True, help="maximum number of configuration evaluations") | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 128 arg_parser.add_argument("-ts", "--test_share", required=True, help="share of data to be used for testing") | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 129 # neural network parameters | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 130 arg_parser.add_argument("-bs", "--batch_size", required=True, help="size of the tranining batch i.e. the number of samples per batch") | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 131 arg_parser.add_argument("-ut", "--units", required=True, help="number of hidden recurrent units") | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 132 arg_parser.add_argument("-es", "--embedding_size", required=True, help="size of the fixed vector learned for each tool") | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 133 arg_parser.add_argument("-dt", "--dropout", required=True, help="percentage of neurons to be dropped") | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 134 arg_parser.add_argument("-sd", "--spatial_dropout", required=True, help="1d dropout used for embedding layer") | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 135 arg_parser.add_argument("-rd", "--recurrent_dropout", required=True, help="dropout for the recurrent layers") | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 136 arg_parser.add_argument("-lr", "--learning_rate", required=True, help="learning rate") | 
| 2 
50753817983a
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit c635df659fe1835679438589ded43136b0e515c6"
 bgruening parents: 
1diff
changeset | 137 | 
| 0 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 138 # get argument values | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 139 args = vars(arg_parser.parse_args()) | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 140 tool_usage_path = args["tool_usage_file"] | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 141 workflows_path = args["workflow_file"] | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 142 cutoff_date = args["cutoff_date"] | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 143 maximum_path_length = int(args["maximum_path_length"]) | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 144 trained_model_path = args["output_model"] | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 145 n_epochs = int(args["n_epochs"]) | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 146 optimize_n_epochs = int(args["optimize_n_epochs"]) | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 147 max_evals = int(args["max_evals"]) | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 148 test_share = float(args["test_share"]) | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 149 batch_size = args["batch_size"] | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 150 units = args["units"] | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 151 embedding_size = args["embedding_size"] | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 152 dropout = args["dropout"] | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 153 spatial_dropout = args["spatial_dropout"] | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 154 recurrent_dropout = args["recurrent_dropout"] | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 155 learning_rate = args["learning_rate"] | 
| 2 
50753817983a
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit c635df659fe1835679438589ded43136b0e515c6"
 bgruening parents: 
1diff
changeset | 156 num_cpus = 16 | 
| 0 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 157 | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 158 config = { | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 159 'cutoff_date': cutoff_date, | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 160 'maximum_path_length': maximum_path_length, | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 161 'n_epochs': n_epochs, | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 162 'optimize_n_epochs': optimize_n_epochs, | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 163 'max_evals': max_evals, | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 164 'test_share': test_share, | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 165 'batch_size': batch_size, | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 166 'units': units, | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 167 'embedding_size': embedding_size, | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 168 'dropout': dropout, | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 169 'spatial_dropout': spatial_dropout, | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 170 'recurrent_dropout': recurrent_dropout, | 
| 2 
50753817983a
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit c635df659fe1835679438589ded43136b0e515c6"
 bgruening parents: 
1diff
changeset | 171 'learning_rate': learning_rate | 
| 0 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 172 } | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 173 | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 174 # Extract and process workflows | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 175 connections = extract_workflow_connections.ExtractWorkflowConnections() | 
| 2 
50753817983a
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit c635df659fe1835679438589ded43136b0e515c6"
 bgruening parents: 
1diff
changeset | 176 workflow_paths, compatible_next_tools, standard_connections = connections.read_tabular_file(workflows_path) | 
| 0 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 177 # Process the paths from workflows | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 178 print("Dividing data...") | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 179 data = prepare_data.PrepareData(maximum_path_length, test_share) | 
| 2 
50753817983a
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit c635df659fe1835679438589ded43136b0e515c6"
 bgruening parents: 
1diff
changeset | 180 train_data, train_labels, test_data, test_labels, data_dictionary, reverse_dictionary, class_weights, usage_pred, l_tool_freq, l_tool_tr_samples = data.get_data_labels_matrices(workflow_paths, tool_usage_path, cutoff_date, compatible_next_tools, standard_connections) | 
| 0 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 181 # find the best model and start training | 
| 2 
50753817983a
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit c635df659fe1835679438589ded43136b0e515c6"
 bgruening parents: 
1diff
changeset | 182 predict_tool = PredictTool(num_cpus) | 
| 0 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 183 # start training with weighted classes | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 184 print("Training with weighted classes and samples ...") | 
| 2 
50753817983a
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit c635df659fe1835679438589ded43136b0e515c6"
 bgruening parents: 
1diff
changeset | 185 results_weighted = predict_tool.find_train_best_network(config, reverse_dictionary, train_data, train_labels, test_data, test_labels, n_epochs, class_weights, usage_pred, standard_connections, l_tool_freq, l_tool_tr_samples) | 
| 
50753817983a
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit c635df659fe1835679438589ded43136b0e515c6"
 bgruening parents: 
1diff
changeset | 186 utils.save_model(results_weighted, data_dictionary, compatible_next_tools, trained_model_path, class_weights, standard_connections) | 
| 0 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 187 end_time = time.time() | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 188 print() | 
| 
22ebbac136c7
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
 bgruening parents: diff
changeset | 189 print("Program finished in %s seconds" % str(end_time - start_time)) | 
