comparison: prepare_data.py @ 2:50753817983a (draft)
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit c635df659fe1835679438589ded43136b0e515c6"
| author | bgruening |
|---|---|
| date | Sat, 09 May 2020 09:38:04 +0000 |
| parents | 22ebbac136c7 |
| children | 98bc44d17561 |

| 1:275e98795e99 (old) | 2:50753817983a (new) |
|---|---|
| 8 import collections | 8 import collections |
| 9 import numpy as np | 9 import numpy as np |
| 10 import random | 10 import random |
| 11 | 11 |
| 12 import predict_tool_usage | 12 import predict_tool_usage |
| | 13 import utils |
| 13 | 14 |
| 14 main_path = os.getcwd() | 15 main_path = os.getcwd() |
| 15 | 16 |
| 16 | 17 |
| 17 class PrepareData: | 18 class PrepareData: |
| 18 | 19 |
| 19 @classmethod | |
| 20 def __init__(self, max_seq_length, test_data_share): | 20 def __init__(self, max_seq_length, test_data_share): |
| 21 """ Init method. """ | 21 """ Init method. """ |
| 22 self.max_tool_sequence_len = max_seq_length | 22 self.max_tool_sequence_len = max_seq_length |
| 23 self.test_share = test_data_share | 23 self.test_share = test_data_share |
| 24 | 24 |
| 25 @classmethod | |
| 26 def process_workflow_paths(self, workflow_paths): | 25 def process_workflow_paths(self, workflow_paths): |
| 27 """ | 26 """ |
| 28 Get all the tools and complete set of individual paths for each workflow | 27 Get all the tools and complete set of individual paths for each workflow |
| 29 """ | 28 """ |
| 30 tokens = list() | 29 tokens = list() |
| 38 tokens = list(set(tokens)) | 37 tokens = list(set(tokens)) |
| 39 tokens = np.array(tokens) | 38 tokens = np.array(tokens) |
| 40 tokens = np.reshape(tokens, [-1, ]) | 39 tokens = np.reshape(tokens, [-1, ]) |
| 41 return tokens, raw_paths | 40 return tokens, raw_paths |
| 42 | 41 |
| 43 @classmethod | |
| 44 def create_new_dict(self, new_data_dict): | 42 def create_new_dict(self, new_data_dict): |
| 45 """ | 43 """ |
| 46 Create new data dictionary | 44 Create new data dictionary |
| 47 """ | 45 """ |
| 48 reverse_dict = dict((v, k) for k, v in new_data_dict.items()) | 46 reverse_dict = dict((v, k) for k, v in new_data_dict.items()) |
| 49 return new_data_dict, reverse_dict | 47 return new_data_dict, reverse_dict |
| 50 | 48 |
| 51 @classmethod | |
| 52 def assemble_dictionary(self, new_data_dict, old_data_dictionary={}): | 49 def assemble_dictionary(self, new_data_dict, old_data_dictionary={}): |
| 53 """ | 50 """ |
| 54 Create/update tools indices in the forward and backward dictionary | 51 Create/update tools indices in the forward and backward dictionary |
| 55 """ | 52 """ |
| 56 new_data_dict, reverse_dict = self.create_new_dict(new_data_dict) | 53 new_data_dict, reverse_dict = self.create_new_dict(new_data_dict) |
| 57 return new_data_dict, reverse_dict | 54 return new_data_dict, reverse_dict |
| 58 | 55 |
| 59 @classmethod | |
| 60 def create_data_dictionary(self, words, old_data_dictionary={}): | 56 def create_data_dictionary(self, words, old_data_dictionary={}): |
| 61 """ | 57 """ |
| 62 Create two dictionaries having tools names and their indexes | 58 Create two dictionaries having tools names and their indexes |
| 63 """ | 59 """ |
| 64 count = collections.Counter(words).most_common() | 60 count = collections.Counter(words).most_common() |
| 66 for word, _ in count: | 62 for word, _ in count: |
| 67 dictionary[word] = len(dictionary) + 1 | 63 dictionary[word] = len(dictionary) + 1 |
| 68 dictionary, reverse_dictionary = self.assemble_dictionary(dictionary, old_data_dictionary) | 64 dictionary, reverse_dictionary = self.assemble_dictionary(dictionary, old_data_dictionary) |
| 69 return dictionary, reverse_dictionary | 65 return dictionary, reverse_dictionary |
| 70 | 66 |
| 71 @classmethod | |
| 72 def decompose_paths(self, paths, dictionary): | 67 def decompose_paths(self, paths, dictionary): |
| 73 """ | 68 """ |
| 74 Decompose the paths to variable length sub-paths keeping the first tool fixed | 69 Decompose the paths to variable length sub-paths keeping the first tool fixed |
| 75 """ | 70 """ |
| 76 sub_paths_pos = list() | 71 sub_paths_pos = list() |
| 84 if len(tools_pos) > 1: | 79 if len(tools_pos) > 1: |
| 85 sub_paths_pos.append(",".join(tools_pos)) | 80 sub_paths_pos.append(",".join(tools_pos)) |
| 86 sub_paths_pos = list(set(sub_paths_pos)) | 81 sub_paths_pos = list(set(sub_paths_pos)) |
| 87 return sub_paths_pos | 82 return sub_paths_pos |
| 88 | 83 |
| 89 @classmethod | |
| 90 def prepare_paths_labels_dictionary(self, dictionary, reverse_dictionary, paths, compatible_next_tools): | 84 def prepare_paths_labels_dictionary(self, dictionary, reverse_dictionary, paths, compatible_next_tools): |
| 91 """ | 85 """ |
| 92 Create a dictionary of sequences with their labels for training and test paths | 86 Create a dictionary of sequences with their labels for training and test paths |
| 93 """ | 87 """ |
| 94 paths_labels = dict() | 88 paths_labels = dict() |
| 114 paths_labels[train_tools] = composite_labels | 108 paths_labels[train_tools] = composite_labels |
| 115 for item in paths_labels: | 109 for item in paths_labels: |
| 116 paths_labels[item] = ",".join(list(set(paths_labels[item].split(",")))) | 110 paths_labels[item] = ",".join(list(set(paths_labels[item].split(",")))) |
| 117 return paths_labels | 111 return paths_labels |
| 118 | 112 |
| 119 @classmethod | 113 def pad_test_paths(self, paths_dictionary, num_classes): |
| 120 def pad_paths(self, paths_dictionary, num_classes): | |
| 121 """ | 114 """ |
| 122 Add padding to the tools sequences and create multi-hot encoded labels | 115 Add padding to the tools sequences and create multi-hot encoded labels |
| 123 """ | 116 """ |
| 124 size_data = len(paths_dictionary) | 117 size_data = len(paths_dictionary) |
| 125 data_mat = np.zeros([size_data, self.max_tool_sequence_len]) | 118 data_mat = np.zeros([size_data, self.max_tool_sequence_len]) |
| 133 for label_item in train_label.split(","): | 126 for label_item in train_label.split(","): |
| 134 label_mat[train_counter][int(label_item)] = 1.0 | 127 label_mat[train_counter][int(label_item)] = 1.0 |
| 135 train_counter += 1 | 128 train_counter += 1 |
| 136 return data_mat, label_mat | 129 return data_mat, label_mat |
| 137 | 130 |
| 138 @classmethod | 131 def pad_paths(self, paths_dictionary, num_classes, standard_connections, reverse_dictionary): |
| 132 """ | |
| 133 Add padding to the tools sequences and create multi-hot encoded labels | |
| 134 """ | |
| 135 size_data = len(paths_dictionary) | |
| 136 data_mat = np.zeros([size_data, self.max_tool_sequence_len]) | |
| 137 label_mat = np.zeros([size_data, 2 * (num_classes + 1)]) | |
| 138 pos_flag = 1.0 | |
| 139 train_counter = 0 | |
| 140 for train_seq, train_label in list(paths_dictionary.items()): | |
| 141 pub_connections = list() | |
| 142 positions = train_seq.split(",") | |
| 143 last_tool_id = positions[-1] | |
| 144 last_tool_name = reverse_dictionary[int(last_tool_id)] | |
| 145 start_pos = self.max_tool_sequence_len - len(positions) | |
| 146 for id_pos, pos in enumerate(positions): | |
| 147 data_mat[train_counter][start_pos + id_pos] = int(pos) | |
| 148 if last_tool_name in standard_connections: | |
| 149 pub_connections = standard_connections[last_tool_name] | |
| 150 for label_item in train_label.split(","): | |
| 151 label_pos = int(label_item) | |
| 152 label_row = label_mat[train_counter] | |
| 153 if reverse_dictionary[label_pos] in pub_connections: | |
| 154 label_row[label_pos] = pos_flag | |
| 155 else: | |
| 156 label_row[label_pos + num_classes + 1] = pos_flag | |
| 157 train_counter += 1 | |
| 158 return data_mat, label_mat | |
| 159 | |
| 139 def split_test_train_data(self, multilabels_paths): | 160 def split_test_train_data(self, multilabels_paths): |
| 140 """ | 161 """ |
| 141 Split into test and train data randomly for each run | 162 Split into test and train data randomly for each run |
| 142 """ | 163 """ |
| 143 train_dict = dict() | 164 train_dict = dict() |
| 150 test_dict[path] = multilabels_paths[path] | 171 test_dict[path] = multilabels_paths[path] |
| 151 else: | 172 else: |
| 152 train_dict[path] = multilabels_paths[path] | 173 train_dict[path] = multilabels_paths[path] |
| 153 return train_dict, test_dict | 174 return train_dict, test_dict |
| 154 | 175 |
| 155 @classmethod | |
| 156 def verify_overlap(self, train_paths, test_paths): | |
| 157 """ | |
| 158 Verify the overlapping of samples in train and test data | |
| 159 """ | |
| 160 intersection = list(set(train_paths).intersection(set(test_paths))) | |
| 161 print("Overlap in train and test: %d" % len(intersection)) | |
| 162 | |
| 163 @classmethod | |
| 164 def get_predicted_usage(self, data_dictionary, predicted_usage): | 176 def get_predicted_usage(self, data_dictionary, predicted_usage): |
| 165 """ | 177 """ |
| 166 Get predicted usage for tools | 178 Get predicted usage for tools |
| 167 """ | 179 """ |
| 168 usage = dict() | 180 usage = dict() |
| 178 except Exception: | 190 except Exception: |
| 179 usage[v] = epsilon | 191 usage[v] = epsilon |
| 180 continue | 192 continue |
| 181 return usage | 193 return usage |
| 182 | 194 |
| 183 @classmethod | |
| 184 def assign_class_weights(self, n_classes, predicted_usage): | 195 def assign_class_weights(self, n_classes, predicted_usage): |
| 185 """ | 196 """ |
| 186 Compute class weights using usage | 197 Compute class weights using usage |
| 187 """ | 198 """ |
| 188 class_weights = dict() | 199 class_weights = dict() |
| 189 class_weights[str(0)] = 0.0 | 200 class_weights[str(0)] = 0.0 |
| 190 for key in range(1, n_classes): | 201 for key in range(1, n_classes + 1): |
| 191 u_score = predicted_usage[key] | 202 u_score = predicted_usage[key] |
| 192 if u_score < 1.0: | 203 if u_score < 1.0: |
| 193 u_score += 1.0 | 204 u_score += 1.0 |
| 194 class_weights[key] = np.log(u_score) | 205 class_weights[key] = np.round(np.log(u_score), 6) |
| 195 return class_weights | 206 return class_weights |
| 196 | 207 |
| 197 @classmethod | 208 def get_train_last_tool_freq(self, train_paths, reverse_dictionary): |
| 198 def get_sample_weights(self, train_data, reverse_dictionary, paths_frequency): | 209 """ |
| 199 """ | 210 Get the frequency of last tool of each tool sequence |
| 200 Compute the frequency of paths in training data | 211 to estimate the frequency of tool sequences |
| 201 """ | 212 """ |
| 202 path_weights = np.zeros(len(train_data)) | 213 last_tool_freq = dict() |
| 203 for path_index, path in enumerate(train_data): | 214 inv_freq = dict() |
| 204 sample_pos = np.where(path > 0)[0] | 215 for path in train_paths: |
| 205 sample_tool_pos = path[sample_pos[0]:] | 216 last_tool = path.split(",")[-1] |
| 206 path_name = ",".join([reverse_dictionary[int(tool_pos)] for tool_pos in sample_tool_pos]) | 217 if last_tool not in last_tool_freq: |
| 207 try: | 218 last_tool_freq[last_tool] = 0 |
| 208 path_weights[path_index] = int(paths_frequency[path_name]) | 219 last_tool_freq[last_tool] += 1 |
| 209 except Exception: | 220 max_freq = max(last_tool_freq.values()) |
| 210 path_weights[path_index] = 1 | 221 for t in last_tool_freq: |
| 211 return path_weights | 222 inv_freq[t] = int(np.round(max_freq / float(last_tool_freq[t]), 0)) |
| 212 | 223 return last_tool_freq, inv_freq |
| 213 @classmethod | 224 |
| 214 def get_data_labels_matrices(self, workflow_paths, tool_usage_path, cutoff_date, compatible_next_tools, old_data_dictionary={}): | 225 def get_toolid_samples(self, train_data, l_tool_freq): |
| | 226 l_tool_tr_samples = dict() |
| | 227 for tool_id in l_tool_freq: |
| | 228 for index, tr_sample in enumerate(train_data): |
| | 229 last_tool_id = str(int(tr_sample[-1])) |
| | 230 if last_tool_id == tool_id: |
| | 231 if last_tool_id not in l_tool_tr_samples: |
| | 232 l_tool_tr_samples[last_tool_id] = list() |
| | 233 l_tool_tr_samples[last_tool_id].append(index) |
| | 234 return l_tool_tr_samples |
| | 235 |
| | 236 def get_data_labels_matrices(self, workflow_paths, tool_usage_path, cutoff_date, compatible_next_tools, standard_connections, old_data_dictionary={}): |
| 215 """ | 237 """ |
| 216 Convert the training and test paths into corresponding numpy matrices | 238 Convert the training and test paths into corresponding numpy matrices |
| 217 """ | 239 """ |
| 218 processed_data, raw_paths = self.process_workflow_paths(workflow_paths) | 240 processed_data, raw_paths = self.process_workflow_paths(workflow_paths) |
| 219 dictionary, reverse_dictionary = self.create_data_dictionary(processed_data, old_data_dictionary) | 241 dictionary, rev_dict = self.create_data_dictionary(processed_data, old_data_dictionary) |
| 220 num_classes = len(dictionary) | 242 num_classes = len(dictionary) |
| 221 | 243 |
| 222 print("Raw paths: %d" % len(raw_paths)) | 244 print("Raw paths: %d" % len(raw_paths)) |
| 223 random.shuffle(raw_paths) | 245 random.shuffle(raw_paths) |
| 224 | 246 |
| 225 print("Decomposing paths...") | 247 print("Decomposing paths...") |
| 226 all_unique_paths = self.decompose_paths(raw_paths, dictionary) | 248 all_unique_paths = self.decompose_paths(raw_paths, dictionary) |
| 227 random.shuffle(all_unique_paths) | 249 random.shuffle(all_unique_paths) |
| 228 | 250 |
| 229 print("Creating dictionaries...") | 251 print("Creating dictionaries...") |
| 230 multilabels_paths = self.prepare_paths_labels_dictionary(dictionary, reverse_dictionary, all_unique_paths, compatible_next_tools) | 252 multilabels_paths = self.prepare_paths_labels_dictionary(dictionary, rev_dict, all_unique_paths, compatible_next_tools) |
| 231 | 253 |
| 232 print("Complete data: %d" % len(multilabels_paths)) | 254 print("Complete data: %d" % len(multilabels_paths)) |
| 233 train_paths_dict, test_paths_dict = self.split_test_train_data(multilabels_paths) | 255 train_paths_dict, test_paths_dict = self.split_test_train_data(multilabels_paths) |
| 234 | 256 |
| | 257 # get sample frequency |
| | 258 l_tool_freq, inv_last_tool_freq = self.get_train_last_tool_freq(train_paths_dict, rev_dict) |
| | 259 |
| 235 print("Train data: %d" % len(train_paths_dict)) | 260 print("Train data: %d" % len(train_paths_dict)) |
| 236 print("Test data: %d" % len(test_paths_dict)) | 261 print("Test data: %d" % len(test_paths_dict)) |
| 237 | 262 |
| 238 test_data, test_labels = self.pad_paths(test_paths_dict, num_classes) | 263 print("Padding train and test data...") |
| 239 train_data, train_labels = self.pad_paths(train_paths_dict, num_classes) | 264 # pad training and test data with leading zeros |
| | 265 test_data, test_labels = self.pad_paths(test_paths_dict, num_classes, standard_connections, rev_dict) |
| | 266 train_data, train_labels = self.pad_paths(train_paths_dict, num_classes, standard_connections, rev_dict) |
| | 267 |
| | 268 l_tool_tr_samples = self.get_toolid_samples(train_data, l_tool_freq) |
| 240 | 269 |
| 241 # Predict tools usage | 270 # Predict tools usage |
| 242 print("Predicting tools' usage...") | 271 print("Predicting tools' usage...") |
| 243 usage_pred = predict_tool_usage.ToolPopularity() | 272 usage_pred = predict_tool_usage.ToolPopularity() |
| 244 usage = usage_pred.extract_tool_usage(tool_usage_path, cutoff_date, dictionary) | 273 usage = usage_pred.extract_tool_usage(tool_usage_path, cutoff_date, dictionary) |
| 245 tool_usage_prediction = usage_pred.get_pupularity_prediction(usage) | 274 tool_usage_prediction = usage_pred.get_pupularity_prediction(usage) |
| 246 tool_predicted_usage = self.get_predicted_usage(dictionary, tool_usage_prediction) | 275 t_pred_usage = self.get_predicted_usage(dictionary, tool_usage_prediction) |
| 247 | 276 |
| 248 # get class weights using the predicted usage for each tool | 277 # get class weights using the predicted usage for each tool |
| 249 class_weights = self.assign_class_weights(train_labels.shape[1], tool_predicted_usage) | 278 class_weights = self.assign_class_weights(num_classes, t_pred_usage) |
| 250 | 279 |
| 251 return train_data, train_labels, test_data, test_labels, dictionary, reverse_dictionary, class_weights, tool_predicted_usage | 280 return train_data, train_labels, test_data, test_labels, dictionary, rev_dict, class_weights, t_pred_usage, l_tool_freq, l_tool_tr_samples |
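
The diff elides most of `decompose_paths`; from its docstring and the surviving fragment (old line 84 / new line 79 onward), the following is a plausible reconstruction of what "variable length sub-paths keeping the first tool fixed" means. The path values and the `window` loop variable are illustrative, and the set is used directly where the source appends to a list and deduplicates with `list(set(...))` afterwards:

```python
# Hypothetical reconstruction: every sub-path starts at the first tool
# and grows by one tool at a time.
path = ["1", "2", "3", "4"]            # encoded tool ids of one workflow path
sub_paths_pos = set()
for window in range(1, len(path)):
    tools_pos = path[0:window + 1]     # first tool stays fixed
    if len(tools_pos) > 1:             # fragment visible in the diff
        sub_paths_pos.add(",".join(tools_pos))
print(sorted(sub_paths_pos))           # ['1,2', '1,2,3', '1,2,3,4']
```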
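The central change in this revision is the new `pad_paths` (new lines 131-158): the label matrix is widened to `2 * (num_classes + 1)` columns, so a label that is a curated ("standard") connection of the sequence's last tool is set in the first block, while every other observed next tool is shifted into the second block by `num_classes + 1`. A minimal sketch of that encoding, with a made-up three-tool vocabulary and a made-up `standard_connections` entry:

```python
import numpy as np

# Toy vocabulary and data; names and values are illustrative only.
reverse_dictionary = {1: "fastqc", 2: "trimmomatic", 3: "bowtie2"}
num_classes = len(reverse_dictionary)          # 3 tools
max_len = 5                                    # stands in for max_tool_sequence_len

# One sample: path "fastqc,trimmomatic" (ids "1,2") with possible
# next tools "trimmomatic,bowtie2" (ids "2,3").
paths_dictionary = {"1,2": "2,3"}

# Curated connection: after "trimmomatic", "bowtie2" is a published next step.
standard_connections = {"trimmomatic": ["bowtie2"]}

data_mat = np.zeros([len(paths_dictionary), max_len])
label_mat = np.zeros([len(paths_dictionary), 2 * (num_classes + 1)])

for row, (seq, labels) in enumerate(paths_dictionary.items()):
    positions = seq.split(",")
    last_tool = reverse_dictionary[int(positions[-1])]
    start = max_len - len(positions)
    for offset, tool_id in enumerate(positions):
        data_mat[row][start + offset] = int(tool_id)   # left-pad with zeros
    published = standard_connections.get(last_tool, [])
    for label in labels.split(","):
        pos = int(label)
        if reverse_dictionary[pos] in published:
            label_mat[row][pos] = 1.0                    # first (published) block
        else:
            label_mat[row][pos + num_classes + 1] = 1.0  # second block

print(data_mat)   # [[0. 0. 0. 1. 2.]]
print(label_mat)  # column 3 set (bowtie2, published); column 2 + 3 + 1 = 6 set
```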
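`assign_class_weights` (old lines 184-195 / new 195-206) now iterates over `range(1, n_classes + 1)`, so the last class is no longer skipped, and rounds the log-scaled usage to six decimals. A small worked example with hypothetical usage figures:

```python
import numpy as np

# Hypothetical predicted usage per tool id (1-based, as in the diff).
predicted_usage = {1: 250.0, 2: 0.4, 3: 12000.0}
n_classes = len(predicted_usage)

class_weights = {str(0): 0.0}                  # index 0 is the padding class
for key in range(1, n_classes + 1):            # note the inclusive upper bound
    u_score = predicted_usage[key]
    if u_score < 1.0:                          # keep log() non-negative
        u_score += 1.0
    class_weights[key] = np.round(np.log(u_score), 6)

print(class_weights)
# {'0': 0.0, 1: 5.521461, 2: 0.336472, 3: 9.392662}
```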
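`get_train_last_tool_freq` (new lines 208-223) estimates how common each training sequence is from the frequency of its last tool, then derives an inverse weight `max_freq / freq` so that sequences ending in rare tools count for more. A sketch with invented paths:

```python
import numpy as np

# Invented training paths, encoded as comma-separated tool ids.
train_paths = ["1,2", "4,2", "5,2", "1,3"]

last_tool_freq, inv_freq = dict(), dict()
for path in train_paths:
    last_tool = path.split(",")[-1]
    last_tool_freq[last_tool] = last_tool_freq.get(last_tool, 0) + 1

max_freq = max(last_tool_freq.values())        # 3 (tool id "2")
for tool, freq in last_tool_freq.items():
    # rarer last tools get proportionally larger weights
    inv_freq[tool] = int(np.round(max_freq / float(freq), 0))

print(last_tool_freq)  # {'2': 3, '3': 1}
print(inv_freq)        # {'2': 1, '3': 3}
```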
