Mercurial > repos > bgruening > sklearn_fitted_model_eval
comparison stacking_ensembles.py @ 11:ed5472c523fa draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 9981e25b00de29ed881b2229a173a8c812ded9bb
| author | bgruening | 
|---|---|
| date | Wed, 09 Aug 2023 11:56:03 +0000 | 
| parents | c333698de5f4 | 
| children | 
   comparison
  equal
  deleted
  inserted
  replaced
| 10:7fbc6a108504 | 11:ed5472c523fa | 
|---|---|
| 1 import argparse | 1 import argparse | 
| 2 import ast | 2 import ast | 
| 3 import json | 3 import json | 
| 4 import pickle | |
| 5 import sys | 4 import sys | 
| 6 import warnings | 5 import warnings | 
| | 6 from distutils.version import LooseVersion as Version | 
| 7 | 7 | 
| 8 import mlxtend.classifier | 8 import mlxtend.classifier | 
| 9 import mlxtend.regressor | 9 import mlxtend.regressor | 
| 10 import pandas as pd | 10 from galaxy_ml import __version__ as galaxy_ml_version | 
| 11 from galaxy_ml.utils import (get_cv, get_estimator, get_search_params, | 11 from galaxy_ml.model_persist import dump_model_to_h5, load_model_from_h5 | 
| 12 load_model) | 12 from galaxy_ml.utils import get_cv, get_estimator | 
| 13 | 13 | 
| 14 warnings.filterwarnings("ignore") | 14 warnings.filterwarnings("ignore") | 
| 15 | 15 | 
| 16 N_JOBS = int(__import__("os").environ.get("GALAXY_SLOTS", 1)) | 16 N_JOBS = int(__import__("os").environ.get("GALAXY_SLOTS", 1)) | 
| 17 | 17 | 
| 18 | 18 | 
| 19 def main(inputs_path, output_obj, base_paths=None, meta_path=None, outfile_params=None): | 19 def main(inputs_path, output_obj, base_paths=None, meta_path=None): | 
| 20 """ | 20 """ | 
| 21 Parameter | 21 Parameter | 
| 22 --------- | 22 --------- | 
| 23 inputs_path : str | 23 inputs_path : str | 
| 24 File path for Galaxy parameters | 24 File path for Galaxy parameters | 
| 29 base_paths : str | 29 base_paths : str | 
| 30 File path or paths concatenated by comma. | 30 File path or paths concatenated by comma. | 
| 31 | 31 | 
| 32 meta_path : str | 32 meta_path : str | 
| 33 File path | 33 File path | 
| 34 | |
| 35 outfile_params : str | |
| 36 File path for params output | |
| 37 """ | 34 """ | 
| 38 with open(inputs_path, "r") as param_handler: | 35 with open(inputs_path, "r") as param_handler: | 
| 39 params = json.load(param_handler) | 36 params = json.load(param_handler) | 
| 40 | 37 | 
| 41 estimator_type = params["algo_selection"]["estimator_type"] | 38 estimator_type = params["algo_selection"]["estimator_type"] | 
| 42 # get base estimators | 39 # get base estimators | 
| 43 base_estimators = [] | 40 base_estimators = [] | 
| 44 for idx, base_file in enumerate(base_paths.split(",")): | 41 for idx, base_file in enumerate(base_paths.split(",")): | 
| 45 if base_file and base_file != "None": | 42 if base_file and base_file != "None": | 
| 46 with open(base_file, "rb") as handler: | 43 model = load_model_from_h5(base_file) | 
| 47 model = load_model(handler) | |
| 48 else: | 44 else: | 
| 49 estimator_json = params["base_est_builder"][idx]["estimator_selector"] | 45 estimator_json = params["base_est_builder"][idx]["estimator_selector"] | 
| 50 model = get_estimator(estimator_json) | 46 model = get_estimator(estimator_json) | 
| 51 | 47 | 
| 52 if estimator_type.startswith("sklearn"): | 48 if estimator_type.startswith("sklearn"): | 
| 57 base_estimators.append(model) | 53 base_estimators.append(model) | 
| 58 | 54 | 
| 59 # get meta estimator, if applicable | 55 # get meta estimator, if applicable | 
| 60 if estimator_type.startswith("mlxtend"): | 56 if estimator_type.startswith("mlxtend"): | 
| 61 if meta_path: | 57 if meta_path: | 
| 62 with open(meta_path, "rb") as f: | 58 meta_estimator = load_model_from_h5(meta_path) | 
| 63 meta_estimator = load_model(f) | |
| 64 else: | 59 else: | 
| 65 estimator_json = params["algo_selection"]["meta_estimator"][ | 60 estimator_json = params["algo_selection"]["meta_estimator"][ | 
| 66 "estimator_selector" | 61 "estimator_selector" | 
| 67 ] | 62 ] | 
| 68 meta_estimator = get_estimator(estimator_json) | 63 meta_estimator = get_estimator(estimator_json) | 
| 69 | 64 | 
| 70 options = params["algo_selection"]["options"] | 65 options = params["algo_selection"]["options"] | 
| 71 | 66 | 
| 72 cv_selector = options.pop("cv_selector", None) | 67 cv_selector = options.pop("cv_selector", None) | 
| 73 if cv_selector: | 68 if cv_selector: | 
| 74 splitter, _groups = get_cv(cv_selector) | 69 if Version(galaxy_ml_version) < Version("0.8.3"): | 
| | 70 cv_selector.pop("n_stratification_bins", None) | 
| | 71 splitter, groups = get_cv(cv_selector) | 
| 75 options["cv"] = splitter | 72 options["cv"] = splitter | 
| 76 # set n_jobs | 73 # set n_jobs | 
| 77 options["n_jobs"] = N_JOBS | 74 options["n_jobs"] = N_JOBS | 
| 78 | 75 | 
| 79 weights = options.pop("weights", None) | 76 weights = options.pop("weights", None) | 
| 102 | 99 | 
| 103 print(ensemble_estimator) | 100 print(ensemble_estimator) | 
| 104 for base_est in base_estimators: | 101 for base_est in base_estimators: | 
| 105 print(base_est) | 102 print(base_est) | 
| 106 | 103 | 
| 107 with open(output_obj, "wb") as out_handler: | 104 dump_model_to_h5(ensemble_estimator, output_obj) | 
| 108 pickle.dump(ensemble_estimator, out_handler, pickle.HIGHEST_PROTOCOL) | |
| 109 | |
| 110 if params["get_params"] and outfile_params: | |
| 111 results = get_search_params(ensemble_estimator) | |
| 112 df = pd.DataFrame(results, columns=["", "Parameter", "Value"]) | |
| 113 df.to_csv(outfile_params, sep="\t", index=False) | |
| 114 | 105 | 
| 115 | 106 | 
| 116 if __name__ == "__main__": | 107 if __name__ == "__main__": | 
| 117 aparser = argparse.ArgumentParser() | 108 aparser = argparse.ArgumentParser() | 
| 118 aparser.add_argument("-b", "--bases", dest="bases") | 109 aparser.add_argument("-b", "--bases", dest="bases") | 
| 119 aparser.add_argument("-m", "--meta", dest="meta") | 110 aparser.add_argument("-m", "--meta", dest="meta") | 
| 120 aparser.add_argument("-i", "--inputs", dest="inputs") | 111 aparser.add_argument("-i", "--inputs", dest="inputs") | 
| 121 aparser.add_argument("-o", "--outfile", dest="outfile") | 112 aparser.add_argument("-o", "--outfile", dest="outfile") | 
| 122 aparser.add_argument("-p", "--outfile_params", dest="outfile_params") | |
| 123 args = aparser.parse_args() | 113 args = aparser.parse_args() | 
| 124 | 114 | 
| 125 main( | 115 main(args.inputs, args.outfile, base_paths=args.bases, meta_path=args.meta) | 
| 126 args.inputs, | |
| 127 args.outfile, | |
| 128 base_paths=args.bases, | |
| 129 meta_path=args.meta, | |
| 130 outfile_params=args.outfile_params, | |
| 131 ) | | 
