Mercurial > repos > bgruening > keras_model_builder
comparison stacking_ensembles.py @ 0:ac8bef635fcb draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 60f0fbc0eafd7c11bc60fb6c77f2937782efd8a9-dirty
| author | bgruening |
|---|---|
| date | Fri, 09 Aug 2019 06:22:23 -0400 |
| parents | |
| children | 25d4cbb56e1a |
comparison
equal
deleted
inserted
replaced
| -1:000000000000 | 0:ac8bef635fcb |
|---|---|
import argparse
import ast
import json
import os
import pickle
import sys
import warnings

import mlxtend.classifier
import mlxtend.regressor
import pandas as pd
import sklearn
from sklearn import ensemble

from galaxy_ml.utils import (load_model, get_cv, get_estimator,
                             get_search_params)
| 15 | |
| 16 | |
# Suppress all warnings for the whole process.
warnings.filterwarnings('ignore')

# Number of parallel jobs, read from the GALAXY_SLOTS environment variable;
# defaults to a single job when the variable is not set.
N_JOBS = int(os.environ.get('GALAXY_SLOTS', 1))
| 20 | |
| 21 | |
def _load_base_estimators(params, base_paths, estimator_type):
    """Collect the base estimators of the ensemble.

    Parameters
    ----------
    params : dict
        Galaxy parameters; ``params['base_est_builder'][idx]
        ['estimator_selector']`` is used for entries without a saved model.
    base_paths : str or None
        Comma-separated model file paths. An empty entry or the literal
        string ``'None'`` means "build this estimator from params". When
        ``base_paths`` is None, every base estimator is built from
        ``params['base_est_builder']``.
    estimator_type : str
        Ensemble type; values starting with ``'sklearn'`` get named
        estimators.

    Returns
    -------
    list
        ``(name, estimator)`` tuples for sklearn ensembles, otherwise a plain
        list of estimators.
    """
    if base_paths is not None:
        base_files = base_paths.split(',')
    else:
        base_files = [None] * len(params.get('base_est_builder', []))

    base_estimators = []
    for idx, base_file in enumerate(base_files):
        if base_file and base_file != 'None':
            with open(base_file, 'rb') as handler:
                model = load_model(handler)
        else:
            estimator_json = (params['base_est_builder'][idx]
                              ['estimator_selector'])
            model = get_estimator(estimator_json)

        if estimator_type.startswith('sklearn'):
            # sklearn ensembles built here receive (name, estimator) pairs;
            # the index prefix keeps names unique if a class repeats.
            named = 'base_%d_%s' % (idx, model.__class__.__name__.lower())
            base_estimators.append((named, model))
        else:
            base_estimators.append(model)
    return base_estimators


def _load_meta_estimator(params, meta_path):
    """Load the meta estimator from ``meta_path`` or build it from params.

    ``meta_path`` values that are empty or the literal string ``'None'`` are
    treated as "not provided", matching the base-estimator path handling.
    """
    if meta_path and meta_path != 'None':
        with open(meta_path, 'rb') as handler:
            return load_model(handler)
    estimator_json = (params['algo_selection']
                      ['meta_estimator']['estimator_selector'])
    return get_estimator(estimator_json)


def main(inputs_path, output_obj, base_paths=None, meta_path=None,
         outfile_params=None):
    """Build an ensemble estimator from Galaxy parameters and pickle it.

    Parameters
    ----------
    inputs_path : str
        File path for Galaxy parameters (JSON).

    output_obj : str
        File path for the pickled ensemble estimator output.

    base_paths : str or None
        File path or paths concatenated by comma. Empty entries or 'None'
        entries are built from ``params['base_est_builder']``; if None,
        all base estimators are built from params.

    meta_path : str or None
        File path of a saved meta estimator (used for mlxtend ensembles
        only). If not given, the meta estimator is built from params.

    outfile_params : str or None
        File path for the tab-separated params output, written only when
        ``params['get_params']`` is truthy.
    """
    with open(inputs_path, 'r') as param_handler:
        params = json.load(param_handler)

    estimator_type = params['algo_selection']['estimator_type']
    base_estimators = _load_base_estimators(params, base_paths,
                                            estimator_type)

    # get meta estimator, if applicable (mlxtend stacking only)
    meta_estimator = None
    if estimator_type.startswith('mlxtend'):
        meta_estimator = _load_meta_estimator(params, meta_path)

    options = params['algo_selection']['options']

    cv_selector = options.pop('cv_selector', None)
    if cv_selector:
        splitter, _groups = get_cv(cv_selector)
        options['cv'] = splitter
        # NOTE(review): n_jobs is set only together with cv, on the
        # assumption that every estimator accepting `cv` also accepts
        # `n_jobs` -- confirm against the supported mlxtend classes.
        options['n_jobs'] = N_JOBS

    weights = options.pop('weights', None)
    if weights:
        # weights arrive as a Python literal string; parse safely.
        options['weights'] = ast.literal_eval(weights)

    # estimator_type is expected to look like '<module>_<ClassName>'. Split
    # on the LAST underscore so module paths containing '_' still resolve.
    # The module must already be imported to be found in sys.modules.
    mod_name, klass_name = estimator_type.rsplit('_', 1)
    mod = sys.modules[mod_name]
    klass = getattr(mod, klass_name)

    if estimator_type.startswith('sklearn'):
        options['n_jobs'] = N_JOBS
        ensemble_estimator = klass(base_estimators, **options)

    elif mod is mlxtend.classifier:
        ensemble_estimator = klass(
            classifiers=base_estimators,
            meta_classifier=meta_estimator,
            **options)

    else:
        ensemble_estimator = klass(
            regressors=base_estimators,
            meta_regressor=meta_estimator,
            **options)

    print(ensemble_estimator)
    for base_est in base_estimators:
        print(base_est)

    with open(output_obj, 'wb') as out_handler:
        pickle.dump(ensemble_estimator, out_handler, pickle.HIGHEST_PROTOCOL)

    if params.get('get_params') and outfile_params:
        results = get_search_params(ensemble_estimator)
        df = pd.DataFrame(results, columns=['', 'Parameter', 'Value'])
        df.to_csv(outfile_params, sep='\t', index=False)
| 118 | |
| 119 | |
if __name__ == '__main__':
    # Command-line entry point: each option maps onto one argument of main().
    parser = argparse.ArgumentParser()
    for short_flag, long_flag, dest_name in (
            ("-b", "--bases", "bases"),
            ("-m", "--meta", "meta"),
            ("-i", "--inputs", "inputs"),
            ("-o", "--outfile", "outfile"),
            ("-p", "--outfile_params", "outfile_params")):
        parser.add_argument(short_flag, long_flag, dest=dest_name)
    cli_args = parser.parse_args()

    main(cli_args.inputs,
         cli_args.outfile,
         base_paths=cli_args.bases,
         meta_path=cli_args.meta,
         outfile_params=cli_args.outfile_params)
