Mercurial > repos > bgruening > create_tool_recommendation_model
comparison predict_tool_usage.py @ 0:22ebbac136c7 draft
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
| author | bgruening |
|---|---|
| date | Wed, 28 Aug 2019 07:19:13 -0400 |
| parents | |
| children | 50753817983a |
Comparison legend: equal / deleted / inserted / replaced
| -1:000000000000 | 0:22ebbac136c7 |
|---|---|
| 1 """ | |
| 2 Predict tool usage to weigh the predicted tools | |
| 3 """ | |
| 4 | |
| 5 import os | |
| 6 import numpy as np | |
| 7 import warnings | |
| 8 import csv | |
| 9 import collections | |
| 10 | |
| 11 from sklearn.svm import SVR | |
| 12 from sklearn.model_selection import GridSearchCV | |
| 13 from sklearn.pipeline import Pipeline | |
| 14 | |
| 15 import utils | |
| 16 | |
| 17 warnings.filterwarnings("ignore") | |
| 18 | |
| 19 main_path = os.getcwd() | |
| 20 | |
| 21 | |
class ToolPopularity:
    """
    Estimate future tool popularity by fitting a regression curve (SVR)
    to each tool's historical usage counts.
    """

    @classmethod
    def __init__(self):
        """ Init method. """

    @classmethod
    def extract_tool_usage(self, tool_usage_file, cutoff_date, dictionary):
        """
        Extract the usage-over-time series for each known tool.

        Parameters
        ----------
        tool_usage_file : str
            Path to a tab-separated file whose rows are (tool_id, date, count).
        cutoff_date : str
            ISO-formatted date (YYYY-MM-DD); only rows strictly after this date
            are kept (lexicographic comparison is correct for ISO dates).
        dictionary : dict
            Mapping whose keys are the valid (known) tool ids.

        Returns
        -------
        dict
            tool_id -> OrderedDict(date -> usage count), sorted by date, with
            missing dates imputed as 0 so all tools share the same date axis.
        """
        tool_usage_dict = dict()
        all_dates = list()
        # set gives O(1) membership tests; the original list scan was O(n) per row
        known_tools = set(dictionary.keys())
        with open(tool_usage_file, 'rt') as usage_file:
            tool_usage = csv.reader(usage_file, delimiter='\t')
            for row in tool_usage:
                # csv.reader already yields strings, so no str() cast is needed
                if row[1] > cutoff_date:
                    tool_id = utils.format_tool_id(row[0])
                    if tool_id in known_tools:
                        curr_date = row[1]
                        all_dates.append(curr_date)
                        usage = tool_usage_dict.setdefault(tool_id, dict())
                        # merge the usage of different versions of a tool into one
                        if curr_date in usage:
                            usage[curr_date] += int(row[2])
                        else:
                            usage[curr_date] = int(row[2])
        # get unique dates across all tools
        unique_dates = set(all_dates)
        for tool in tool_usage_dict:
            usage = tool_usage_dict[tool]
            # impute 0 for dates on which this tool has no recorded usage;
            # usage keys are always a subset of unique_dates, so a plain set
            # difference replaces the original symmetric difference safely
            for dt in unique_dates - set(usage):
                usage[dt] = 0
            # sort the usage series chronologically
            tool_usage_dict[tool] = collections.OrderedDict(sorted(usage.items()))
        return tool_usage_dict

    @classmethod
    def learn_tool_popularity(self, x_reshaped, y_reshaped):
        """
        Fit an SVR curve to one tool's usage series and predict the usage at
        the next time point.

        Parameters
        ----------
        x_reshaped : array of shape (n, 1)
            Integer time-step indices.
        y_reshaped : array of shape (n, 1)
            Usage counts for each time step.

        Returns
        -------
        float
            Predicted usage, clamped to be >= 0.0; 0.0 on any fitting error
            (deliberate best-effort behavior).
        """
        import inspect

        epsilon = 0.0
        cv = 5
        s_typ = 'neg_mean_absolute_error'
        n_jobs = 4
        s_error = 1
        tr_score = False
        try:
            pipe = Pipeline(steps=[('regressor', SVR(gamma='scale'))])
            param_grid = {
                'regressor__kernel': ['rbf', 'poly', 'linear'],
                'regressor__degree': [2, 3]
            }
            search_params = dict(cv=cv, scoring=s_typ, n_jobs=n_jobs,
                                 error_score=s_error, return_train_score=tr_score)
            # BUG FIX: 'iid' was deprecated in scikit-learn 0.22 and removed in
            # 0.24. Passing it unconditionally makes GridSearchCV raise
            # TypeError, which the except below silently turned into a constant
            # 0.0 prediction for every tool. Pass it only when the installed
            # scikit-learn version still accepts it.
            if 'iid' in inspect.signature(GridSearchCV.__init__).parameters:
                search_params['iid'] = True
            search = GridSearchCV(pipe, param_grid, **search_params)
            search.fit(x_reshaped, y_reshaped.ravel())
            model = search.best_estimator_
            # set the next time point to get a prediction for
            prediction_point = np.reshape([x_reshaped[-1][0] + 1], (1, 1))
            prediction = model.predict(prediction_point)
            # clamp negative predictions to the floor value
            if prediction < epsilon:
                prediction = [epsilon]
            return prediction[0]
        except Exception:
            # deliberate best-effort: any fitting failure yields the floor value
            return epsilon

    @classmethod
    def get_pupularity_prediction(self, tools_usage):
        """
        Predict the next-time-point popularity for every tool.

        Parameters
        ----------
        tools_usage : dict
            tool_name -> ordered mapping of date -> usage count
            (as produced by extract_tool_usage).

        Returns
        -------
        dict
            tool_name -> predicted usage, rounded to 8 decimal places.

        NOTE(review): 'pupularity' is a typo, but the name is kept unchanged
        because external callers depend on it.
        """
        usage_prediction = dict()
        for tool_name, usage in tools_usage.items():
            counts = list(usage.values())
            # encode the (already chronologically sorted) dates as consecutive
            # integer time steps 0..n-1
            x_reshaped = np.arange(len(counts)).reshape(len(counts), 1)
            y_reshaped = np.reshape(counts, (len(counts), 1))
            prediction = np.round(self.learn_tool_popularity(x_reshaped, y_reshaped), 8)
            usage_prediction[tool_name] = prediction
        return usage_prediction
