Mercurial > repos > bgruening > sklearn_generalized_linear
comparison utils.py @ 20:e426fe6b1138 draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 8cf3d813ec755166ee0bd517b4ecbbd4f84d4df1
| author | bgruening |
|---|---|
| date | Thu, 23 Aug 2018 16:12:17 -0400 |
| parents | 0a6a4da30a0a |
| children | da10fb828033 |
comparison
equal
deleted
inserted
replaced
| 19:0a6a4da30a0a | 20:e426fe6b1138 |
|---|---|
| 1 import sys | 1 import sys |
| 2 import os | 2 import os |
| 3 import pandas | 3 import pandas |
| 4 import re | 4 import re |
| 5 import pickle | 5 import cPickle as pickle |
| 6 import warnings | 6 import warnings |
| 7 import numpy as np | 7 import numpy as np |
| 8 import xgboost | 8 import xgboost |
| 9 import scipy | 9 import scipy |
| 10 import sklearn | 10 import sklearn |
| 11 import ast | 11 import ast |
| 12 from asteval import Interpreter, make_symbol_table | 12 from asteval import Interpreter, make_symbol_table |
| 13 from sklearn import metrics, model_selection, ensemble, svm, linear_model, naive_bayes, tree, neighbors | 13 from sklearn import (cluster, decomposition, ensemble, feature_extraction, feature_selection, |
| 14 gaussian_process, kernel_approximation, linear_model, metrics, | |
| 15 model_selection, naive_bayes, neighbors, pipeline, preprocessing, | |
| 16 svm, linear_model, tree, discriminant_analysis) | |
| 14 | 17 |
| 15 N_JOBS = int( os.environ.get('GALAXY_SLOTS', 1) ) | 18 N_JOBS = int( os.environ.get('GALAXY_SLOTS', 1) ) |
| 19 | |
class SafePickler(object):
    """
    Used to safely deserialize scikit-learn model objects serialized by cPickle.dump

    Globals referenced by the pickle stream are resolved through
    ``find_class``, which only hands back attributes of modules under a
    small set of trusted prefixes (plus a couple of explicitly allowed
    names) and raises ``pickle.UnpicklingError`` for anything else.

    Usage:
        eg.: SafePickler.load(pickled_file_object)
    """

    # Attribute names that are never resolved, even on a trusted module.
    _BAD_NAMES = ('and', 'as', 'assert', 'break', 'class', 'continue',
                  'def', 'del', 'elif', 'else', 'except', 'exec',
                  'finally', 'for', 'from', 'global', 'if', 'import',
                  'in', 'is', 'lambda', 'not', 'or', 'pass', 'print',
                  'raise', 'return', 'try', 'system', 'while', 'with',
                  'True', 'False', 'None', 'eval', 'execfile', '__import__',
                  '__package__', '__subclasses__', '__bases__', '__globals__',
                  '__code__', '__closure__', '__func__', '__self__', '__module__',
                  '__dict__', '__class__', '__call__', '__get__',
                  '__getattribute__', '__subclasshook__', '__new__',
                  '__init__', 'func_globals', 'func_code', 'func_closure',
                  'im_class', 'im_func', 'im_self', 'gi_code', 'gi_frame',
                  '__asteval__', 'f_locals', '__mro__')

    # Fully qualified globals that are always allowed, whatever their module.
    _GOOD_NAMES = ('copy_reg._reconstructor', '__builtin__.object')

    # Module prefixes whose attributes may be resolved (bare 'numpy' is
    # checked separately because it has no trailing dot).
    _TRUSTED_PREFIXES = ('sklearn.', 'xgboost.', 'skrebate.', 'numpy.')

    @classmethod
    def find_class(cls, module, name):
        """
        Resolve the global ``module.name`` requested by the unpickler.

        module: dotted module name taken from the pickle stream
        name: attribute name taken from the pickle stream

        Returns the attribute looked up on ``sys.modules[module]``; the
        module must therefore already be imported, otherwise the lookup
        raises KeyError.
        Raises pickle.UnpicklingError when the global is not permitted.
        """
        # Build the qualified name up front: the error message at the bottom
        # needs it even when `name` fails the identifier check below
        # (previously that path died with UnboundLocalError instead).
        fullname = module + '.' + name

        if re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$', name):
            trusted_module = (module == 'numpy'
                              or module.startswith(cls._TRUSTED_PREFIXES))
            if (fullname in cls._GOOD_NAMES
                    or (trusted_module and name not in cls._BAD_NAMES)):
                # TODO: replace with a whitelist checker
                # SK_NAMES, SKR_NAMES, XGB_NAMES and NUMPY_NAMES are expected
                # to be module-level whitelists defined elsewhere in this file.
                if fullname not in SK_NAMES + SKR_NAMES + XGB_NAMES + NUMPY_NAMES + cls._GOOD_NAMES:
                    print("Warning: global %s is not in pickler whitelist yet and will loss support soon. Contact tool author or leave a message at github.com" % fullname)
                mod = sys.modules[module]
                return getattr(mod, name)

        raise pickle.UnpicklingError("global '%s' is forbidden" % fullname)

    @classmethod
    def load(cls, file):
        """
        Unpickle and return the object stored in the open file object
        ``file``, routing every global lookup through ``find_class``.
        """
        unpickler = pickle.Unpickler(file)
        # cPickle's Unpickler consults the `find_global` attribute to resolve
        # globals, which is how the restrictions above are enforced.
        unpickler.find_global = cls.find_class
        return unpickler.load()
| 16 | 68 |
| 17 def read_columns(f, c=None, c_option='by_index_number', return_df=False, **args): | 69 def read_columns(f, c=None, c_option='by_index_number', return_df=False, **args): |
| 18 data = pandas.read_csv(f, **args) | 70 data = pandas.read_csv(f, **args) |
| 19 if c_option == 'by_index_number': | 71 if c_option == 'by_index_number': |
| 20 cols = list(map(lambda x: x - 1, c)) | 72 cols = list(map(lambda x: x - 1, c)) |
| 46 if not options['threshold'] or options['threshold'] == 'None': | 98 if not options['threshold'] or options['threshold'] == 'None': |
| 47 options['threshold'] = None | 99 options['threshold'] = None |
| 48 if inputs['model_inputter']['input_mode'] == 'prefitted': | 100 if inputs['model_inputter']['input_mode'] == 'prefitted': |
| 49 model_file = inputs['model_inputter']['fitted_estimator'] | 101 model_file = inputs['model_inputter']['fitted_estimator'] |
| 50 with open(model_file, 'rb') as model_handler: | 102 with open(model_file, 'rb') as model_handler: |
| 51 fitted_estimator = pickle.load(model_handler) | 103 fitted_estimator = SafePickler.load(model_handler) |
| 52 new_selector = selector(fitted_estimator, prefit=True, **options) | 104 new_selector = selector(fitted_estimator, prefit=True, **options) |
| 53 else: | 105 else: |
| 54 estimator_json = inputs['model_inputter']["estimator_selector"] | 106 estimator_json = inputs['model_inputter']["estimator_selector"] |
| 55 estimator = get_estimator(estimator_json) | 107 estimator = get_estimator(estimator_json) |
| 56 new_selector = selector(estimator, **options) | 108 new_selector = selector(estimator, **options) |
| 130 | 182 |
| 131 syms = make_symbol_table(use_numpy=False, **new_syms) | 183 syms = make_symbol_table(use_numpy=False, **new_syms) |
| 132 | 184 |
| 133 if load_scipy: | 185 if load_scipy: |
| 134 scipy_distributions = scipy.stats.distributions.__dict__ | 186 scipy_distributions = scipy.stats.distributions.__dict__ |
| 135 for key in scipy_distributions.keys(): | 187 for k, v in scipy_distributions.items(): |
| 136 if isinstance(scipy_distributions[key], (scipy.stats.rv_continuous, scipy.stats.rv_discrete)): | 188 if isinstance(v, (scipy.stats.rv_continuous, scipy.stats.rv_discrete)): |
| 137 syms['scipy_stats_' + key] = scipy_distributions[key] | 189 syms['scipy_stats_' + k] = v |
| 138 | 190 |
| 139 if load_numpy: | 191 if load_numpy: |
| 140 from_numpy_random = ['beta', 'binomial', 'bytes', 'chisquare', 'choice', 'dirichlet', 'division', | 192 from_numpy_random = ['beta', 'binomial', 'bytes', 'chisquare', 'choice', 'dirichlet', 'division', |
| 141 'exponential', 'f', 'gamma', 'geometric', 'gumbel', 'hypergeometric', | 193 'exponential', 'f', 'gamma', 'geometric', 'gumbel', 'hypergeometric', |
| 142 'laplace', 'logistic', 'lognormal', 'logseries', 'mtrand', 'multinomial', | 194 'laplace', 'logistic', 'lognormal', 'logseries', 'mtrand', 'multinomial', |
