comparison train_test_eval.py @ 14:99dcca81a784 draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 9981e25b00de29ed881b2229a173a8c812ded9bb
| author | bgruening |
|---|---|
| date | Wed, 09 Aug 2023 11:45:18 +0000 |
| parents | 5f848056fbf8 |
| children | |
| 13:c1054de039f4 | 14:99dcca81a784 |
|---|---|
| 1 import argparse | 1 import argparse |
| 2 import json | 2 import json |
| 3 import os | 3 import os |
| 4 import pickle | |
| 5 import warnings | 4 import warnings |
| 6 from itertools import chain | 5 from itertools import chain |
| 7 | 6 |
| 8 import joblib | 7 import joblib |
| 9 import numpy as np | 8 import numpy as np |
| 10 import pandas as pd | 9 import pandas as pd |
| | 10 from galaxy_ml.model_persist import dump_model_to_h5, load_model_from_h5 |
| 11 from galaxy_ml.model_validations import train_test_split | 11 from galaxy_ml.model_validations import train_test_split |
| 12 from galaxy_ml.utils import (get_module, get_scoring, load_model, | 12 from galaxy_ml.utils import ( |
| 13 read_columns, SafeEval, try_get_attr) | 13 clean_params, |
| | 14 get_module, |
| | 15 get_scoring, |
| | 16 read_columns, |
| | 17 SafeEval, |
| | 18 try_get_attr |
| | 19 ) |
| 14 from scipy.io import mmread | 20 from scipy.io import mmread |
| 15 from sklearn import pipeline | 21 from sklearn import pipeline |
| 16 from sklearn.metrics.scorer import _check_multimetric_scoring | |
| 17 from sklearn.model_selection import _search, _validation | 22 from sklearn.model_selection import _search, _validation |
| 18 from sklearn.model_selection._validation import _score | 23 from sklearn.model_selection._validation import _score |
| 19 from sklearn.utils import indexable, safe_indexing | 24 from sklearn.utils import _safe_indexing, indexable |
| 20 | 25 |
| 21 _fit_and_score = try_get_attr("galaxy_ml.model_validations", "_fit_and_score") | 26 _fit_and_score = try_get_attr("galaxy_ml.model_validations", "_fit_and_score") |
| 22 setattr(_search, "_fit_and_score", _fit_and_score) | 27 setattr(_search, "_fit_and_score", _fit_and_score) |
| 23 setattr(_validation, "_fit_and_score", _fit_and_score) | 28 setattr(_validation, "_fit_and_score", _fit_and_score) |
| 24 | 29 |
| 91 index_arr = np.arange(n_samples) | 96 index_arr = np.arange(n_samples) |
| 92 test = index_arr[np.isin(groups, group_names)] | 97 test = index_arr[np.isin(groups, group_names)] |
| 93 train = index_arr[~np.isin(groups, group_names)] | 98 train = index_arr[~np.isin(groups, group_names)] |
| 94 rval = list( | 99 rval = list( |
| 95 chain.from_iterable( | 100 chain.from_iterable( |
| 96 (safe_indexing(a, train), safe_indexing(a, test)) for a in new_arrays | 101 (_safe_indexing(a, train), _safe_indexing(a, test)) for a in new_arrays |
| 97 ) | 102 ) |
| 98 ) | 103 ) |
| 99 else: | 104 else: |
| 100 rval = train_test_split(*new_arrays, **kwargs) | 105 rval = train_test_split(*new_arrays, **kwargs) |
| 101 | 106 |
| 162 | 167 |
| 163 with open(inputs, "r") as param_handler: | 168 with open(inputs, "r") as param_handler: |
| 164 params = json.load(param_handler) | 169 params = json.load(param_handler) |
| 165 | 170 |
| 166 # load estimator | 171 # load estimator |
| 167 with open(infile_estimator, "rb") as estimator_handler: | 172 estimator = load_model_from_h5(infile_estimator) |
| 168 estimator = load_model(estimator_handler) | 173 estimator = clean_params(estimator) |
| 169 | 174 |
| 170 # swap hyperparameter | 175 # swap hyperparameter |
| 171 swapping = params["experiment_schemes"]["hyperparams_swapping"] | 176 swapping = params["experiment_schemes"]["hyperparams_swapping"] |
| 172 swap_params = _eval_swap_params(swapping) | 177 swap_params = _eval_swap_params(swapping) |
| 173 estimator.set_params(**swap_params) | 178 estimator.set_params(**swap_params) |
| 346 secondary_scoring = scoring.get("secondary_scoring", None) | 351 secondary_scoring = scoring.get("secondary_scoring", None) |
| 347 if secondary_scoring is not None: | 352 if secondary_scoring is not None: |
| 348 # If secondary_scoring is specified, convert the list into a comma-separated string | 353 # If secondary_scoring is specified, convert the list into a comma-separated string |
| 349 scoring["secondary_scoring"] = ",".join(scoring["secondary_scoring"]) | 354 scoring["secondary_scoring"] = ",".join(scoring["secondary_scoring"]) |
| 350 scorer = get_scoring(scoring) | 355 scorer = get_scoring(scoring) |
| 351 scorer, _ = _check_multimetric_scoring(estimator, scoring=scorer) | |
| 352 | 356 |
| 353 # handle test (first) split | 357 # handle test (first) split |
| 354 test_split_options = params["experiment_schemes"]["test_split"]["split_algos"] | 358 test_split_options = params["experiment_schemes"]["test_split"]["split_algos"] |
| 355 | 359 |
| 356 if test_split_options["shuffle"] == "group": | 360 if test_split_options["shuffle"] == "group": |
| 410 if hasattr(estimator, "evaluate"): | 414 if hasattr(estimator, "evaluate"): |
| 411 scores = estimator.evaluate( | 415 scores = estimator.evaluate( |
| 412 X_test, y_test=y_test, scorer=scorer, is_multimetric=True | 416 X_test, y_test=y_test, scorer=scorer, is_multimetric=True |
| 413 ) | 417 ) |
| 414 else: | 418 else: |
| 415 scores = _score(estimator, X_test, y_test, scorer, is_multimetric=True) | 419 scores = _score(estimator, X_test, y_test, scorer) |
| 416 # handle output | 420 # handle output |
| 417 for name, score in scores.items(): | 421 for name, score in scores.items(): |
| 418 scores[name] = [score] | 422 scores[name] = [score] |
| 419 df = pd.DataFrame(scores) | 423 df = pd.DataFrame(scores) |
| 420 df = df[sorted(df.columns)] | 424 df = df[sorted(df.columns)] |
| 439 if getattr(main_est, "validation_data", None): | 443 if getattr(main_est, "validation_data", None): |
| 440 del main_est.validation_data | 444 del main_est.validation_data |
| 441 if getattr(main_est, "data_generator_", None): | 445 if getattr(main_est, "data_generator_", None): |
| 442 del main_est.data_generator_ | 446 del main_est.data_generator_ |
| 443 | 447 |
| 444 with open(outfile_object, "wb") as output_handler: | 448 dump_model_to_h5(estimator, outfile_object) |
| 445 pickle.dump(estimator, output_handler, pickle.HIGHEST_PROTOCOL) | |
| 446 | 449 |
| 447 | 450 |
| 448 if __name__ == "__main__": | 451 if __name__ == "__main__": |
| 449 aparser = argparse.ArgumentParser() | 452 aparser = argparse.ArgumentParser() |
| 450 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) | 453 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) |
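
The headline change in this revision is the move from pickle to galaxy_ml's HDF5 model persistence: `load_model` plus `pickle.dump` (old lines 167-168 and 444-445) become `load_model_from_h5`/`dump_model_to_h5`, with `clean_params` applied after loading (new lines 172-173 and 448). A minimal sketch of the new round trip, assuming galaxy_ml is installed; `"model.h5"` is a hypothetical path used only for illustration:

```python
# Sketch of the revised persistence round trip (path is hypothetical).
from galaxy_ml.model_persist import dump_model_to_h5, load_model_from_h5
from galaxy_ml.utils import clean_params

estimator = load_model_from_h5("model.h5")  # replaces load_model(open(..., "rb"))
estimator = clean_params(estimator)         # sanitize estimator params after loading

# ... fit / evaluate the estimator here ...

dump_model_to_h5(estimator, "model.h5")     # replaces pickle.dump(..., HIGHEST_PROTOCOL)
```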
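The group-holdout branch of the split helper (old lines 91-98, new lines 96-103) now uses `sklearn.utils._safe_indexing`, the renamed form of `safe_indexing` in newer scikit-learn releases. A self-contained sketch of that branch with illustrative arrays, assuming a scikit-learn version that ships `_safe_indexing`:

```python
# Hold out all samples belonging to selected groups as the test split,
# mirroring the index logic in the diff above.
from itertools import chain

import numpy as np
from sklearn.utils import _safe_indexing

X = np.arange(20).reshape(10, 2)
y = np.arange(10)
groups = np.array([0, 0, 1, 1, 2, 2, 0, 1, 2, 0])
group_names = [2]  # hypothetical: reserve group 2 for testing

index_arr = np.arange(len(y))
test = index_arr[np.isin(groups, group_names)]
train = index_arr[~np.isin(groups, group_names)]

# Flattens to (X_train, X_test, y_train, y_test), as in the tool's helper.
X_train, X_test, y_train, y_test = list(
    chain.from_iterable(
        (_safe_indexing(a, train), _safe_indexing(a, test)) for a in (X, y)
    )
)
```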
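The hyperparameter-swapping step (lines 175-178 in the new revision) applies the evaluated swap parameters with scikit-learn's standard `set_params`. A hedged sketch of what that application looks like; the pipeline and the `step__param` keys below are hypothetical, since the real keys come from the Galaxy tool's JSON config:

```python
# Nested "step__param" keys address components inside a Pipeline.
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC

estimator = Pipeline([("scale", StandardScaler()), ("clf", SVC())])
swap_params = {"clf__C": 10.0, "clf__kernel": "rbf"}  # illustrative values
estimator.set_params(**swap_params)
```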
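On the scoring side, the old code built a multimetric scorer dict with the private `sklearn.metrics.scorer._check_multimetric_scoring` (a module path removed in newer scikit-learn) and passed `is_multimetric=True` to `_score` (old lines 351 and 415); the new code relies on `get_scoring` alone and calls `_score` without that keyword (new line 419). A rough public-API equivalent of multimetric scoring plus the tool's output handling, with metric names assumed for illustration:

```python
# Score a fitted estimator with several metrics using only public API
# (check_scoring), then collect results as the tool does: a one-row
# DataFrame with sorted columns.
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import check_scoring

X, y = make_classification(random_state=0)
est = LogisticRegression(max_iter=1000).fit(X, y)

scorers = {name: check_scoring(est, scoring=name) for name in ("accuracy", "f1")}
scores = {name: [scorer(est, X, y)] for name, scorer in scorers.items()}

df = pd.DataFrame(scores)
df = df[sorted(df.columns)]
print(df)
```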
