Mercurial > repos > bgruening > sklearn_stacking_ensemble_models
comparison stacking_ensembles.py @ 9:b8c92e94ac1d draft
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit e2a5eade6d0e5ddf3a47630381a0ad90d80e8a04"
| author | bgruening | 
|---|---|
| date | Tue, 13 Apr 2021 15:49:42 +0000 | 
| parents | 8b5b653ba1ed | 
| children | 2d890789ac48 | 
   comparison
  equal
  deleted
  inserted
  replaced
| 8:6430b9b00d2f | 9:b8c92e94ac1d | 
|---|---|
| 3 import json | 3 import json | 
| 4 import mlxtend.regressor | 4 import mlxtend.regressor | 
| 5 import mlxtend.classifier | 5 import mlxtend.classifier | 
| 6 import pandas as pd | 6 import pandas as pd | 
| 7 import pickle | 7 import pickle | 
| 8 import sklearn | |
| 9 import sys | 8 import sys | 
| 10 import warnings | 9 import warnings | 
| 11 from sklearn import ensemble | 10 from galaxy_ml.utils import load_model, get_cv, get_estimator, get_search_params | 
| 12 | |
| 13 from galaxy_ml.utils import (load_model, get_cv, get_estimator, | |
| 14 get_search_params) | |
| 15 | 11 | 
| 16 | 12 | 
| 17 warnings.filterwarnings('ignore') | 13 warnings.filterwarnings("ignore") | 
| 18 | 14 | 
| 19 N_JOBS = int(__import__('os').environ.get('GALAXY_SLOTS', 1)) | 15 N_JOBS = int(__import__("os").environ.get("GALAXY_SLOTS", 1)) | 
| 20 | 16 | 
| 21 | 17 | 
| 22 def main(inputs_path, output_obj, base_paths=None, meta_path=None, | 18 def main(inputs_path, output_obj, base_paths=None, meta_path=None, outfile_params=None): | 
| 23 outfile_params=None): | |
| 24 """ | 19 """ | 
| 25 Parameter | 20 Parameter | 
| 26 --------- | 21 --------- | 
| 27 inputs_path : str | 22 inputs_path : str | 
| 28 File path for Galaxy parameters | 23 File path for Galaxy parameters | 
| 37 File path | 32 File path | 
| 38 | 33 | 
| 39 outfile_params : str | 34 outfile_params : str | 
| 40 File path for params output | 35 File path for params output | 
| 41 """ | 36 """ | 
| 42 with open(inputs_path, 'r') as param_handler: | 37 with open(inputs_path, "r") as param_handler: | 
| 43 params = json.load(param_handler) | 38 params = json.load(param_handler) | 
| 44 | 39 | 
| 45 estimator_type = params['algo_selection']['estimator_type'] | 40 estimator_type = params["algo_selection"]["estimator_type"] | 
| 46 # get base estimators | 41 # get base estimators | 
| 47 base_estimators = [] | 42 base_estimators = [] | 
| 48 for idx, base_file in enumerate(base_paths.split(',')): | 43 for idx, base_file in enumerate(base_paths.split(",")): | 
| 49 if base_file and base_file != 'None': | 44 if base_file and base_file != "None": | 
| 50 with open(base_file, 'rb') as handler: | 45 with open(base_file, "rb") as handler: | 
| 51 model = load_model(handler) | 46 model = load_model(handler) | 
| 52 else: | 47 else: | 
| 53 estimator_json = (params['base_est_builder'][idx] | 48 estimator_json = params["base_est_builder"][idx]["estimator_selector"] | 
| 54 ['estimator_selector']) | |
| 55 model = get_estimator(estimator_json) | 49 model = get_estimator(estimator_json) | 
| 56 | 50 | 
| 57 if estimator_type.startswith('sklearn'): | 51 if estimator_type.startswith("sklearn"): | 
| 58 named = model.__class__.__name__.lower() | 52 named = model.__class__.__name__.lower() | 
| 59 named = 'base_%d_%s' % (idx, named) | 53 named = "base_%d_%s" % (idx, named) | 
| 60 base_estimators.append((named, model)) | 54 base_estimators.append((named, model)) | 
| 61 else: | 55 else: | 
| 62 base_estimators.append(model) | 56 base_estimators.append(model) | 
| 63 | 57 | 
| 64 # get meta estimator, if applicable | 58 # get meta estimator, if applicable | 
| 65 if estimator_type.startswith('mlxtend'): | 59 if estimator_type.startswith("mlxtend"): | 
| 66 if meta_path: | 60 if meta_path: | 
| 67 with open(meta_path, 'rb') as f: | 61 with open(meta_path, "rb") as f: | 
| 68 meta_estimator = load_model(f) | 62 meta_estimator = load_model(f) | 
| 69 else: | 63 else: | 
| 70 estimator_json = (params['algo_selection'] | 64 estimator_json = params["algo_selection"]["meta_estimator"]["estimator_selector"] | 
| 71 ['meta_estimator']['estimator_selector']) | |
| 72 meta_estimator = get_estimator(estimator_json) | 65 meta_estimator = get_estimator(estimator_json) | 
| 73 | 66 | 
| 74 options = params['algo_selection']['options'] | 67 options = params["algo_selection"]["options"] | 
| 75 | 68 | 
| 76 cv_selector = options.pop('cv_selector', None) | 69 cv_selector = options.pop("cv_selector", None) | 
| 77 if cv_selector: | 70 if cv_selector: | 
| 78 splitter, groups = get_cv(cv_selector) | 71 splitter, _groups = get_cv(cv_selector) | 
| 79 options['cv'] = splitter | 72 options["cv"] = splitter | 
| 80 # set n_jobs | 73 # set n_jobs | 
| 81 options['n_jobs'] = N_JOBS | 74 options["n_jobs"] = N_JOBS | 
| 82 | 75 | 
| 83 weights = options.pop('weights', None) | 76 weights = options.pop("weights", None) | 
| 84 if weights: | 77 if weights: | 
| 85 weights = ast.literal_eval(weights) | 78 weights = ast.literal_eval(weights) | 
| 86 if weights: | 79 if weights: | 
| 87 options['weights'] = weights | 80 options["weights"] = weights | 
| 88 | 81 | 
| 89 mod_and_name = estimator_type.split('_') | 82 mod_and_name = estimator_type.split("_") | 
| 90 mod = sys.modules[mod_and_name[0]] | 83 mod = sys.modules[mod_and_name[0]] | 
| 91 klass = getattr(mod, mod_and_name[1]) | 84 klass = getattr(mod, mod_and_name[1]) | 
| 92 | 85 | 
| 93 if estimator_type.startswith('sklearn'): | 86 if estimator_type.startswith("sklearn"): | 
| 94 options['n_jobs'] = N_JOBS | 87 options["n_jobs"] = N_JOBS | 
| 95 ensemble_estimator = klass(base_estimators, **options) | 88 ensemble_estimator = klass(base_estimators, **options) | 
| 96 | 89 | 
| 97 elif mod == mlxtend.classifier: | 90 elif mod == mlxtend.classifier: | 
| 98 ensemble_estimator = klass( | 91 ensemble_estimator = klass(classifiers=base_estimators, meta_classifier=meta_estimator, **options) | 
| 99 classifiers=base_estimators, | |
| 100 meta_classifier=meta_estimator, | |
| 101 **options) | |
| 102 | 92 | 
| 103 else: | 93 else: | 
| 104 ensemble_estimator = klass( | 94 ensemble_estimator = klass(regressors=base_estimators, meta_regressor=meta_estimator, **options) | 
| 105 regressors=base_estimators, | |
| 106 meta_regressor=meta_estimator, | |
| 107 **options) | |
| 108 | 95 | 
| 109 print(ensemble_estimator) | 96 print(ensemble_estimator) | 
| 110 for base_est in base_estimators: | 97 for base_est in base_estimators: | 
| 111 print(base_est) | 98 print(base_est) | 
| 112 | 99 | 
| 113 with open(output_obj, 'wb') as out_handler: | 100 with open(output_obj, "wb") as out_handler: | 
| 114 pickle.dump(ensemble_estimator, out_handler, pickle.HIGHEST_PROTOCOL) | 101 pickle.dump(ensemble_estimator, out_handler, pickle.HIGHEST_PROTOCOL) | 
| 115 | 102 | 
| 116 if params['get_params'] and outfile_params: | 103 if params["get_params"] and outfile_params: | 
| 117 results = get_search_params(ensemble_estimator) | 104 results = get_search_params(ensemble_estimator) | 
| 118 df = pd.DataFrame(results, columns=['', 'Parameter', 'Value']) | 105 df = pd.DataFrame(results, columns=["", "Parameter", "Value"]) | 
| 119 df.to_csv(outfile_params, sep='\t', index=False) | 106 df.to_csv(outfile_params, sep="\t", index=False) | 
| 120 | 107 | 
| 121 | 108 | 
| 122 if __name__ == '__main__': | 109 if __name__ == "__main__": | 
| 123 aparser = argparse.ArgumentParser() | 110 aparser = argparse.ArgumentParser() | 
| 124 aparser.add_argument("-b", "--bases", dest="bases") | 111 aparser.add_argument("-b", "--bases", dest="bases") | 
| 125 aparser.add_argument("-m", "--meta", dest="meta") | 112 aparser.add_argument("-m", "--meta", dest="meta") | 
| 126 aparser.add_argument("-i", "--inputs", dest="inputs") | 113 aparser.add_argument("-i", "--inputs", dest="inputs") | 
| 127 aparser.add_argument("-o", "--outfile", dest="outfile") | 114 aparser.add_argument("-o", "--outfile", dest="outfile") | 
| 128 aparser.add_argument("-p", "--outfile_params", dest="outfile_params") | 115 aparser.add_argument("-p", "--outfile_params", dest="outfile_params") | 
| 129 args = aparser.parse_args() | 116 args = aparser.parse_args() | 
| 130 | 117 | 
| 131 main(args.inputs, args.outfile, base_paths=args.bases, | 118 main( | 
| 132 meta_path=args.meta, outfile_params=args.outfile_params) | 119 args.inputs, | 
| 120 args.outfile, | |
| 121 base_paths=args.bases, | |
| 122 meta_path=args.meta, | |
| 123 outfile_params=args.outfile_params, | |
| 124 ) | 
