Mercurial > repos > bgruening > sklearn_generalized_linear
comparison simple_model_fit.py @ 30:c0e3e32f0801 draft
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit d6333e7294e67be5968a41f404b66699cad4ae53"
| author | bgruening | 
|---|---|
| date | Thu, 07 Nov 2019 05:02:24 -0500 | 
| parents | d3496640fec0 | 
| children | b9a8876452cf | 
   comparison
  equal
  deleted
  inserted
  replaced
| 29:d3496640fec0 | 30:c0e3e32f0801 | 
|---|---|
| 3 import pandas as pd | 3 import pandas as pd | 
| 4 import pickle | 4 import pickle | 
| 5 | 5 | 
| 6 from galaxy_ml.utils import load_model, read_columns | 6 from galaxy_ml.utils import load_model, read_columns | 
| 7 from sklearn.pipeline import Pipeline | 7 from sklearn.pipeline import Pipeline | 
| | 8 |
| | 9 |
| | 10 N_JOBS = int(__import__('os').environ.get('GALAXY_SLOTS', 1)) |
| | 11 |
| | 12 |
| | 13 # TODO import from galaxy_ml.utils in future versions |
| | 14 def clean_params(estimator, n_jobs=None): |
| | 15 """clean unwanted hyperparameter settings |
| | 16 |
| | 17 If n_jobs is not None, set it into the estimator, if applicable |
| | 18 |
| | 19 Return |
| | 20 ------ |
| | 21 Cleaned estimator object |
| | 22 """ |
| | 23 ALLOWED_CALLBACKS = ('EarlyStopping', 'TerminateOnNaN', |
| | 24 'ReduceLROnPlateau', 'CSVLogger', 'None') |
| | 25 |
| | 26 estimator_params = estimator.get_params() |
| | 27 |
| | 28 for name, p in estimator_params.items(): |
| | 29 # all potential unauthorized file write |
| | 30 if name == 'memory' or name.endswith('__memory') \ |
| | 31 or name.endswith('_path'): |
| | 32 new_p = {name: None} |
| | 33 estimator.set_params(**new_p) |
| | 34 elif n_jobs is not None and (name == 'n_jobs' or |
| | 35 name.endswith('__n_jobs')): |
| | 36 new_p = {name: n_jobs} |
| | 37 estimator.set_params(**new_p) |
| | 38 elif name.endswith('callbacks'): |
| | 39 for cb in p: |
| | 40 cb_type = cb['callback_selection']['callback_type'] |
| | 41 if cb_type not in ALLOWED_CALLBACKS: |
| | 42 raise ValueError( |
| | 43 "Prohibited callback type: %s!" % cb_type) |
| | 44 |
| | 45 return estimator |
| 8 | 46 | 
| 9 | 47 | 
| 10 def _get_X_y(params, infile1, infile2): | 48 def _get_X_y(params, infile1, infile2): | 
| 11 """ read from inputs and output X and y | 49 """ read from inputs and output X and y | 
| 12 | 50 | 
| 105 params = json.load(param_handler) | 143 params = json.load(param_handler) | 
| 106 | 144 | 
| 107 # load model | 145 # load model | 
| 108 with open(infile_estimator, 'rb') as est_handler: | 146 with open(infile_estimator, 'rb') as est_handler: | 
| 109 estimator = load_model(est_handler) | 147 estimator = load_model(est_handler) | 
| | 148 estimator = clean_params(estimator, n_jobs=N_JOBS) |
| 110 | 149 | 
| 111 X_train, y_train = _get_X_y(params, infile1, infile2) | 150 X_train, y_train = _get_X_y(params, infile1, infile2) | 
| 112 | 151 | 
| 113 estimator.fit(X_train, y_train) | 152 estimator.fit(X_train, y_train) | 
| 114 | 153 | 
