diff stacking_ensembles.py @ 3:0a1812986bc3 draft

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 9981e25b00de29ed881b2229a173a8c812ded9bb
author bgruening
date Wed, 09 Aug 2023 11:10:37 +0000
parents 38c4f8a98038
children
line wrap: on
line diff
--- a/stacking_ensembles.py	Mon Dec 16 10:07:37 2019 +0000
+++ b/stacking_ensembles.py	Wed Aug 09 11:10:37 2023 +0000
@@ -1,26 +1,22 @@
 import argparse
 import ast
 import json
-import mlxtend.regressor
-import mlxtend.classifier
-import pandas as pd
-import pickle
-import sklearn
 import sys
 import warnings
-from sklearn import ensemble
+from distutils.version import LooseVersion as Version
 
-from galaxy_ml.utils import (load_model, get_cv, get_estimator,
-                             get_search_params)
+import mlxtend.classifier
+import mlxtend.regressor
+from galaxy_ml import __version__ as galaxy_ml_version
+from galaxy_ml.model_persist import dump_model_to_h5, load_model_from_h5
+from galaxy_ml.utils import get_cv, get_estimator
+
+warnings.filterwarnings("ignore")
+
+N_JOBS = int(__import__("os").environ.get("GALAXY_SLOTS", 1))
 
 
-warnings.filterwarnings('ignore')
-
-N_JOBS = int(__import__('os').environ.get('GALAXY_SLOTS', 1))
-
-
-def main(inputs_path, output_obj, base_paths=None, meta_path=None,
-         outfile_params=None):
+def main(inputs_path, output_obj, base_paths=None, meta_path=None):
     """
     Parameter
     ---------
@@ -35,98 +31,85 @@
 
     meta_path : str
         File path
-
-    outfile_params : str
-        File path for params output
     """
-    with open(inputs_path, 'r') as param_handler:
+    with open(inputs_path, "r") as param_handler:
         params = json.load(param_handler)
 
-    estimator_type = params['algo_selection']['estimator_type']
+    estimator_type = params["algo_selection"]["estimator_type"]
     # get base estimators
     base_estimators = []
-    for idx, base_file in enumerate(base_paths.split(',')):
-        if base_file and base_file != 'None':
-            with open(base_file, 'rb') as handler:
-                model = load_model(handler)
+    for idx, base_file in enumerate(base_paths.split(",")):
+        if base_file and base_file != "None":
+            model = load_model_from_h5(base_file)
         else:
-            estimator_json = (params['base_est_builder'][idx]
-                              ['estimator_selector'])
+            estimator_json = params["base_est_builder"][idx]["estimator_selector"]
             model = get_estimator(estimator_json)
 
-        if estimator_type.startswith('sklearn'):
+        if estimator_type.startswith("sklearn"):
             named = model.__class__.__name__.lower()
-            named = 'base_%d_%s' % (idx, named)
+            named = "base_%d_%s" % (idx, named)
             base_estimators.append((named, model))
         else:
             base_estimators.append(model)
 
     # get meta estimator, if applicable
-    if estimator_type.startswith('mlxtend'):
+    if estimator_type.startswith("mlxtend"):
         if meta_path:
-            with open(meta_path, 'rb') as f:
-                meta_estimator = load_model(f)
+            meta_estimator = load_model_from_h5(meta_path)
         else:
-            estimator_json = (params['algo_selection']
-                              ['meta_estimator']['estimator_selector'])
+            estimator_json = params["algo_selection"]["meta_estimator"][
+                "estimator_selector"
+            ]
             meta_estimator = get_estimator(estimator_json)
 
-    options = params['algo_selection']['options']
+    options = params["algo_selection"]["options"]
 
-    cv_selector = options.pop('cv_selector', None)
+    cv_selector = options.pop("cv_selector", None)
     if cv_selector:
+        if Version(galaxy_ml_version) < Version("0.8.3"):
+            cv_selector.pop("n_stratification_bins", None)
         splitter, groups = get_cv(cv_selector)
-        options['cv'] = splitter
+        options["cv"] = splitter
         # set n_jobs
-        options['n_jobs'] = N_JOBS
+        options["n_jobs"] = N_JOBS
 
-    weights = options.pop('weights', None)
+    weights = options.pop("weights", None)
     if weights:
         weights = ast.literal_eval(weights)
         if weights:
-            options['weights'] = weights
+            options["weights"] = weights
 
-    mod_and_name = estimator_type.split('_')
+    mod_and_name = estimator_type.split("_")
     mod = sys.modules[mod_and_name[0]]
     klass = getattr(mod, mod_and_name[1])
 
-    if estimator_type.startswith('sklearn'):
-        options['n_jobs'] = N_JOBS
+    if estimator_type.startswith("sklearn"):
+        options["n_jobs"] = N_JOBS
         ensemble_estimator = klass(base_estimators, **options)
 
     elif mod == mlxtend.classifier:
         ensemble_estimator = klass(
-            classifiers=base_estimators,
-            meta_classifier=meta_estimator,
-            **options)
+            classifiers=base_estimators, meta_classifier=meta_estimator, **options
+        )
 
     else:
         ensemble_estimator = klass(
-            regressors=base_estimators,
-            meta_regressor=meta_estimator,
-            **options)
+            regressors=base_estimators, meta_regressor=meta_estimator, **options
+        )
 
     print(ensemble_estimator)
     for base_est in base_estimators:
         print(base_est)
 
-    with open(output_obj, 'wb') as out_handler:
-        pickle.dump(ensemble_estimator, out_handler, pickle.HIGHEST_PROTOCOL)
-
-    if params['get_params'] and outfile_params:
-        results = get_search_params(ensemble_estimator)
-        df = pd.DataFrame(results, columns=['', 'Parameter', 'Value'])
-        df.to_csv(outfile_params, sep='\t', index=False)
+    dump_model_to_h5(ensemble_estimator, output_obj)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     aparser = argparse.ArgumentParser()
     aparser.add_argument("-b", "--bases", dest="bases")
     aparser.add_argument("-m", "--meta", dest="meta")
     aparser.add_argument("-i", "--inputs", dest="inputs")
     aparser.add_argument("-o", "--outfile", dest="outfile")
-    aparser.add_argument("-p", "--outfile_params", dest="outfile_params")
     args = aparser.parse_args()
 
-    main(args.inputs, args.outfile, base_paths=args.bases,
-         meta_path=args.meta, outfile_params=args.outfile_params)
+    main(args.inputs, args.outfile, base_paths=args.bases, meta_path=args.meta)