comparison predict_tool_usage.py @ 4:f0da532be419 draft
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 5eebc0cb44e71f581d548b7e842002705dd155eb"
| author | bgruening |
|---|---|
| date | Fri, 06 May 2022 09:04:44 +0000 |
| parents | 50753817983a |
| children | 9ec705bd11cb |
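For orientation before the line-by-line comparison: `fit_model()` in this file fits an SVR over (time step, usage count) pairs with `GridSearchCV` and predicts usage at the next time step. Below is a minimal, self-contained sketch of that pattern only; the usage numbers and `cv=3` are invented for illustration and are not taken from the repository.

```python
import numpy as np
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline
from sklearn.svm import SVR

# Hypothetical monthly usage counts for one tool, oldest to newest (made-up data)
usage = np.array([12, 15, 22, 30, 41, 55], dtype=float)
x = np.arange(len(usage)).reshape(-1, 1)  # time steps as the single feature
y = usage.reshape(-1, 1)

# Same pipeline and parameter grid as fit_model(); cv=3 is an arbitrary choice here
pipe = Pipeline(steps=[("regressor", SVR(gamma="scale"))])
param_grid = {
    "regressor__kernel": ["rbf", "poly", "linear"],
    "regressor__degree": [2, 3],
}
search = GridSearchCV(pipe, param_grid, cv=3, scoring="neg_mean_absolute_error")
search.fit(x, y.ravel())

# Predict usage at the next (unseen) time step, clamping at zero as the script does
next_point = np.reshape([x[-1][0] + 1], (1, 1))
prediction = max(search.best_estimator_.predict(next_point)[0], 0.0)
print(prediction)
```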
| 3:98bc44d17561 | 4:f0da532be419 |
|---|---|
| 1 """ | 1 """ |
| 2 Predict tool usage to weigh the predicted tools | 2 Predict tool usage to weigh the predicted tools |
| 3 """ | 3 """ |
| 4 | 4 |
| | 5 import collections |
| | 6 import csv |
| 5 import os | 7 import os |
| | 8 import warnings |
| | 9 |
| 6 import numpy as np | 10 import numpy as np |
| 7 import warnings | 11 import utils |
| 8 import csv | |
| 9 import collections | |
| 10 | |
| 11 from sklearn.svm import SVR | |
| 12 from sklearn.model_selection import GridSearchCV | 12 from sklearn.model_selection import GridSearchCV |
| 13 from sklearn.pipeline import Pipeline | 13 from sklearn.pipeline import Pipeline |
| 14 | 14 from sklearn.svm import SVR |
| 15 import utils | |
| 16 | 15 |
| 17 warnings.filterwarnings("ignore") | 16 warnings.filterwarnings("ignore") |
| 18 | 17 |
| 19 main_path = os.getcwd() | 18 main_path = os.getcwd() |
| 20 | 19 |
| 21 | 20 |
| 22 class ToolPopularity: | 21 class ToolPopularity: |
| 23 | |
| 24 def __init__(self): | 22 def __init__(self): |
| 25 """ Init method. """ | 23 """ Init method. """ |
| 26 | 24 |
| 27 def extract_tool_usage(self, tool_usage_file, cutoff_date, dictionary): | 25 def extract_tool_usage(self, tool_usage_file, cutoff_date, dictionary): |
| 28 """ | 26 """ |
| 29 Extract the tool usage over time for each tool | 27 Extract the tool usage over time for each tool |
| 30 """ | 28 """ |
| 31 tool_usage_dict = dict() | 29 tool_usage_dict = dict() |
| 32 all_dates = list() | 30 all_dates = list() |
| 33 all_tool_list = list(dictionary.keys()) | 31 all_tool_list = list(dictionary.keys()) |
| 34 with open(tool_usage_file, 'rt') as usage_file: | 32 with open(tool_usage_file, "rt") as usage_file: |
| 35 tool_usage = csv.reader(usage_file, delimiter='\t') | 33 tool_usage = csv.reader(usage_file, delimiter="\t") |
| 36 for index, row in enumerate(tool_usage): | 34 for index, row in enumerate(tool_usage): |
| | 35 row = [item.strip() for item in row] |
| 37 if (str(row[1]) > cutoff_date) is True: | 36 if (str(row[1]).strip() > cutoff_date) is True: |
| 38 tool_id = utils.format_tool_id(row[0]) | 37 tool_id = utils.format_tool_id(row[0]) |
| 39 if tool_id in all_tool_list: | 38 if tool_id in all_tool_list: |
| 40 all_dates.append(row[1]) | 39 all_dates.append(row[1]) |
| 41 if tool_id not in tool_usage_dict: | 40 if tool_id not in tool_usage_dict: |
| 42 tool_usage_dict[tool_id] = dict() | 41 tool_usage_dict[tool_id] = dict() |
| 65 """ | 64 """ |
| 66 Fit a curve for the tool usage over time to predict future tool usage | 65 Fit a curve for the tool usage over time to predict future tool usage |
| 67 """ | 66 """ |
| 68 epsilon = 0.0 | 67 epsilon = 0.0 |
| 69 cv = 5 | 68 cv = 5 |
| 70 s_typ = 'neg_mean_absolute_error' | 69 s_typ = "neg_mean_absolute_error" |
| 71 n_jobs = 4 | 70 n_jobs = 4 |
| 72 s_error = 1 | 71 s_error = 1 |
| 73 iid = True | |
| 74 tr_score = False | 72 tr_score = False |
| 75 try: | 73 try: |
| 76 pipe = Pipeline(steps=[('regressor', SVR(gamma='scale'))]) | 74 pipe = Pipeline(steps=[("regressor", SVR(gamma="scale"))]) |
| 77 param_grid = { | 75 param_grid = { |
| 78 'regressor__kernel': ['rbf', 'poly', 'linear'], | 76 "regressor__kernel": ["rbf", "poly", "linear"], |
| 79 'regressor__degree': [2, 3] | 77 "regressor__degree": [2, 3], |
| 80 } | 78 } |
| 81 search = GridSearchCV(pipe, param_grid, iid=iid, cv=cv, scoring=s_typ, n_jobs=n_jobs, error_score=s_error, return_train_score=tr_score) | 79 search = GridSearchCV( |
| | 80 pipe, |
| | 81 param_grid, |
| | 82 cv=cv, |
| | 83 scoring=s_typ, |
| | 84 n_jobs=n_jobs, |
| | 85 error_score=s_error, |
| | 86 return_train_score=tr_score, |
| | 87 ) |
| 82 search.fit(x_reshaped, y_reshaped.ravel()) | 88 search.fit(x_reshaped, y_reshaped.ravel()) |
| 83 model = search.best_estimator_ | 89 model = search.best_estimator_ |
| 84 # set the next time point to get prediction for | 90 # set the next time point to get prediction for |
| 85 prediction_point = np.reshape([x_reshaped[-1][0] + 1], (1, 1)) | 91 prediction_point = np.reshape([x_reshaped[-1][0] + 1], (1, 1)) |
| 86 prediction = model.predict(prediction_point) | 92 prediction = model.predict(prediction_point) |
| 87 if prediction < epsilon: | 93 if prediction < epsilon: |
| 88 prediction = [epsilon] | 94 prediction = [epsilon] |
| 89 return prediction[0] | 95 return prediction[0] |
| 90 except Exception: | 96 except Exception as e: |
| | 97 print(e) |
| 91 return epsilon | 98 return epsilon |
| 92 | 99 |
| 93 def get_pupularity_prediction(self, tools_usage): | 100 def get_pupularity_prediction(self, tools_usage): |
| 94 """ | 101 """ |
| 95 Get the popularity prediction for each tool | 102 Get the popularity prediction for each tool |
