Mercurial > repos > bgruening > sklearn_model_validation
changeset 11:61844bce4115 draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit b1e5fa3170484d2cc3396f2abe99bb8cfcfa9c65
| author | bgruening | 
|---|---|
| date | Tue, 07 Aug 2018 05:44:02 -0400 | 
| parents | 7a32e580f45d | 
| children | 86b57b062f96 | 
| files | main_macros.xml model_validation.xml | 
| diffstat | 2 files changed, 64 insertions(+), 44 deletions(-) [+] | 
line wrap: on
 line diff
--- a/main_macros.xml Sat Aug 04 17:33:18 2018 -0400 +++ b/main_macros.xml Tue Aug 07 05:44:02 2018 -0400 @@ -100,53 +100,55 @@ return X, y </token> + <token name="@SAFE_EVAL_FUNCTION@"> +def safe_eval(literal): + + FROM_SCIPY_STATS = [ 'bernoulli', 'binom', 'boltzmann', 'dlaplace', 'geom', 'hypergeom', + 'logser', 'nbinom', 'planck', 'poisson', 'randint', 'skellam', 'zipf' ] + + FROM_NUMPY_RANDOM = [ 'beta', 'binomial', 'bytes', 'chisquare', 'choice', 'dirichlet', 'division', + 'exponential', 'f', 'gamma', 'geometric', 'gumbel', 'hypergeometric', + 'laplace', 'logistic', 'lognormal', 'logseries', 'mtrand', 'multinomial', + 'multivariate_normal', 'negative_binomial', 'noncentral_chisquare', 'noncentral_f', + 'normal', 'pareto', 'permutation', 'poisson', 'power', 'rand', 'randint', + 'randn', 'random', 'random_integers', 'random_sample', 'ranf', 'rayleigh', + 'sample', 'seed', 'set_state', 'shuffle', 'standard_cauchy', 'standard_exponential', + 'standard_gamma', 'standard_normal', 'standard_t', 'triangular', 'uniform', + 'vonmises', 'wald', 'weibull', 'zipf' ] + + # File opening and other unneeded functions could be dropped + UNWANTED = ['open', 'type', 'dir', 'id', 'str', 'repr'] + + # Allowed symbol table. Add more if needed. + new_syms = { + 'np_arange': getattr(np, 'arange'), + 'ensemble_ExtraTreesClassifier': getattr(ensemble, 'ExtraTreesClassifier') + } + + syms = make_symbol_table(use_numpy=False, **new_syms) + + for method in FROM_SCIPY_STATS: + syms['scipy_stats_' + method] = getattr(scipy.stats, method) + + for func in FROM_NUMPY_RANDOM: + syms['np_random_' + func] = getattr(np.random, func) + + for key in UNWANTED: + syms.pop(key, None) + + aeval = Interpreter(symtable=syms, use_numpy=False, minimal=False, + no_if=True, no_for=True, no_while=True, no_try=True, + no_functiondef=True, no_ifexp=True, no_listcomp=False, + no_augassign=False, no_assert=True, no_delete=True, + no_raise=True, no_print=True) + + return aeval(literal) + </token> + <token name="@GET_SEARCH_PARAMS_FUNCTION@"> def get_search_params(params_builder): search_params = {} - def safe_eval(literal): - - FROM_SCIPY_STATS = [ 'bernoulli', 'binom', 'boltzmann', 'dlaplace', 'geom', 'hypergeom', - 'logser', 'nbinom', 'planck', 'poisson', 'randint', 'skellam', 'zipf' ] - - FROM_NUMPY_RANDOM = [ 'beta', 'binomial', 'bytes', 'chisquare', 'choice', 'dirichlet', 'division', - 'exponential', 'f', 'gamma', 'geometric', 'gumbel', 'hypergeometric', - 'laplace', 'logistic', 'lognormal', 'logseries', 'mtrand', 'multinomial', - 'multivariate_normal', 'negative_binomial', 'noncentral_chisquare', 'noncentral_f', - 'normal', 'pareto', 'permutation', 'poisson', 'power', 'rand', 'randint', - 'randn', 'random', 'random_integers', 'random_sample', 'ranf', 'rayleigh', - 'sample', 'seed', 'set_state', 'shuffle', 'standard_cauchy', 'standard_exponential', - 'standard_gamma', 'standard_normal', 'standard_t', 'triangular', 'uniform', - 'vonmises', 'wald', 'weibull', 'zipf' ] - - # File opening and other unneeded functions could be dropped - UNWANTED = ['open', 'type', 'dir', 'id', 'str', 'repr'] - - # Allowed symbol table. Add more if needed. - new_syms = { - 'np_arange': getattr(np, 'arange'), - 'ensemble_ExtraTreesClassifier': getattr(ensemble, 'ExtraTreesClassifier') - } - - syms = make_symbol_table(use_numpy=False, **new_syms) - - for method in FROM_SCIPY_STATS: - syms['scipy_stats_' + method] = getattr(scipy.stats, method) - - for func in FROM_NUMPY_RANDOM: - syms['np_random_' + func] = getattr(np.random, func) - - for key in UNWANTED: - syms.pop(key, None) - - aeval = Interpreter(symtable=syms, use_numpy=False, minimal=False, - no_if=True, no_for=True, no_while=True, no_try=True, - no_functiondef=True, no_ifexp=True, no_listcomp=False, - no_augassign=False, no_assert=True, no_delete=True, - no_raise=True, no_print=True) - - return aeval(literal) - for p in params_builder['param_set']: search_p = p['search_param_selector']['search_p'] if search_p.strip() == '': @@ -189,6 +191,20 @@ return estimator </token> + <token name="@GET_CV_FUNCTION@"> +def get_cv(literal): + if literal == "": + return None + if re.match(r'^\d+$', literal): + return int(literal) + m = re.match(r'^(?P<method>\w+)\((?P<args>.*)\)$', literal) + if m: + my_class = getattr( model_selection, m.group('method') ) + args = safe_eval( 'dict('+ m.group('args') + ')' ) + return my_class( **args ) + sys.exit("Unsupported CV input: %s" %literal) + </token> + <xml name="python_requirements"> <requirements> <requirement type="package" version="2.7">python</requirement> @@ -1143,7 +1159,7 @@ </xml> <xml name="model_validation_common_options"> - <param argument="cv" type="integer" value="" optional="true" label="cv" help="The number of folds in a (Stratified)KFold" /> + <param argument="cv" type="text" value="" size="50" optional="true" label="cv" help="Optional. Integer or evalable splitter object, e.g., StratifiedKFold(n_splits=3, shuffle=True, random_state=10). Leave blank for default." /> <expand macro="n_jobs"/> <expand macro="verbose"/> <yield/>
--- a/model_validation.xml Sat Aug 04 17:33:18 2018 -0400 +++ b/model_validation.xml Tue Aug 07 05:44:02 2018 -0400 @@ -18,6 +18,7 @@ import sys import json import pandas +import re import ast import pickle import numpy as np @@ -28,6 +29,8 @@ @COLUMNS_FUNCTION@ @GET_ESTIMATOR_FUNCTION@ @FEATURE_SELECTOR_FUNCTION@ +@SAFE_EVAL_FUNCTION@ +@GET_CV_FUNCTION@ input_json_path = sys.argv[1] @@ -70,6 +73,7 @@ y=y.ravel() options = params["model_validation_functions"]["options"] +options['cv'] = get_cv( options['cv'] ) if 'scoring' in options and options['scoring'] == '': options['scoring'] = None if 'pre_dispatch' in options and options['pre_dispatch'] == '':
