Mercurial > repos > bgruening > sklearn_stacking_ensemble_models
comparison simple_model_fit.py @ 6:aae4725f152b draft
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit d6333e7294e67be5968a41f404b66699cad4ae53"
| author | bgruening |
|---|---|
| date | Thu, 07 Nov 2019 05:15:47 -0500 |
| parents | 8b5b653ba1ed |
| children | b8c92e94ac1d |
Comparison legend: equal, deleted, inserted, replaced
| 5:8b5b653ba1ed | 6:aae4725f152b |
|---|---|
| 3 import pandas as pd | 3 import pandas as pd |
| 4 import pickle | 4 import pickle |
| 5 | 5 |
| 6 from galaxy_ml.utils import load_model, read_columns | 6 from galaxy_ml.utils import load_model, read_columns |
| 7 from sklearn.pipeline import Pipeline | 7 from sklearn.pipeline import Pipeline |
| 8 | |
| 9 | |
| 10 N_JOBS = int(__import__('os').environ.get('GALAXY_SLOTS', 1)) | |
| 11 | |
| 12 | |
| 13 # TODO import from galaxy_ml.utils in future versions | |
| 14 def clean_params(estimator, n_jobs=None): | |
| 15 """clean unwanted hyperparameter settings | |
| 16 | |
| 17 If n_jobs is not None, set it into the estimator, if applicable | |
| 18 | |
| 19 Return | |
| 20 ------ | |
| 21 Cleaned estimator object | |
| 22 """ | |
| 23 ALLOWED_CALLBACKS = ('EarlyStopping', 'TerminateOnNaN', | |
| 24 'ReduceLROnPlateau', 'CSVLogger', 'None') | |
| 25 | |
| 26 estimator_params = estimator.get_params() | |
| 27 | |
| 28 for name, p in estimator_params.items(): | |
| 29 # all potential unauthorized file write | |
| 30 if name == 'memory' or name.endswith('__memory') \ | |
| 31 or name.endswith('_path'): | |
| 32 new_p = {name: None} | |
| 33 estimator.set_params(**new_p) | |
| 34 elif n_jobs is not None and (name == 'n_jobs' or | |
| 35 name.endswith('__n_jobs')): | |
| 36 new_p = {name: n_jobs} | |
| 37 estimator.set_params(**new_p) | |
| 38 elif name.endswith('callbacks'): | |
| 39 for cb in p: | |
| 40 cb_type = cb['callback_selection']['callback_type'] | |
| 41 if cb_type not in ALLOWED_CALLBACKS: | |
| 42 raise ValueError( | |
| 43 "Prohibited callback type: %s!" % cb_type) | |
| 44 | |
| 45 return estimator | |
| 8 | 46 |
| 9 | 47 |
| 10 def _get_X_y(params, infile1, infile2): | 48 def _get_X_y(params, infile1, infile2): |
| 11 """ read from inputs and output X and y | 49 """ read from inputs and output X and y |
| 12 | 50 |
| 105 params = json.load(param_handler) | 143 params = json.load(param_handler) |
| 106 | 144 |
| 107 # load model | 145 # load model |
| 108 with open(infile_estimator, 'rb') as est_handler: | 146 with open(infile_estimator, 'rb') as est_handler: |
| 109 estimator = load_model(est_handler) | 147 estimator = load_model(est_handler) |
| 148 estimator = clean_params(estimator, n_jobs=N_JOBS) | |
| 110 | 149 |
| 111 X_train, y_train = _get_X_y(params, infile1, infile2) | 150 X_train, y_train = _get_X_y(params, infile1, infile2) |
| 112 | 151 |
| 113 estimator.fit(X_train, y_train) | 152 estimator.fit(X_train, y_train) |
| 114 | 153 |
