Mercurial > repos > bgruening > sklearn_fitted_model_eval
comparison simple_model_fit.py @ 11:ed5472c523fa draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 9981e25b00de29ed881b2229a173a8c812ded9bb
| author | bgruening |
|---|---|
| date | Wed, 09 Aug 2023 11:56:03 +0000 |
| parents | c333698de5f4 |
| children | |
comparison
equal
deleted
inserted
replaced
| 10:7fbc6a108504 | 11:ed5472c523fa |
|---|---|
| 1 import argparse | 1 import argparse |
| 2 import json | 2 import json |
| 3 import pickle | |
| 4 | 3 |
| 5 import pandas as pd | 4 import pandas as pd |
| 6 from galaxy_ml.utils import load_model, read_columns | 5 from galaxy_ml.model_persist import dump_model_to_h5, load_model_from_h5 |
| | 6 from galaxy_ml.utils import read_columns |
| 7 from scipy.io import mmread | 7 from scipy.io import mmread |
| 8 from sklearn.pipeline import Pipeline | 8 from sklearn.pipeline import Pipeline |
| 9 | 9 |
| 10 N_JOBS = int(__import__("os").environ.get("GALAXY_SLOTS", 1)) | 10 N_JOBS = int(__import__("os").environ.get("GALAXY_SLOTS", 1)) |
| 11 | 11 |
| 146 """ | 146 """ |
| 147 with open(inputs, "r") as param_handler: | 147 with open(inputs, "r") as param_handler: |
| 148 params = json.load(param_handler) | 148 params = json.load(param_handler) |
| 149 | 149 |
| 150 # load model | 150 # load model |
| 151 with open(infile_estimator, "rb") as est_handler: | 151 estimator = load_model_from_h5(infile_estimator) |
| 152 estimator = load_model(est_handler) | 152 |
| 153 estimator = clean_params(estimator, n_jobs=N_JOBS) | 153 estimator = clean_params(estimator) |
| 154 | 154 |
| 155 X_train, y_train = _get_X_y(params, infile1, infile2) | 155 X_train, y_train = _get_X_y(params, infile1, infile2) |
| 156 | 156 |
| 157 estimator.fit(X_train, y_train) | 157 estimator.fit(X_train, y_train) |
| 158 | 158 |
| 168 if getattr(main_est, "validation_data", None): | 168 if getattr(main_est, "validation_data", None): |
| 169 del main_est.validation_data | 169 del main_est.validation_data |
| 170 if getattr(main_est, "data_generator_", None): | 170 if getattr(main_est, "data_generator_", None): |
| 171 del main_est.data_generator_ | 171 del main_est.data_generator_ |
| 172 | 172 |
| 173 with open(out_object, "wb") as output_handler: | 173 dump_model_to_h5(estimator, out_object) |
| 174 pickle.dump(estimator, output_handler, pickle.HIGHEST_PROTOCOL) | |
| 175 | 174 |
| 176 | 175 |
| 177 if __name__ == "__main__": | 176 if __name__ == "__main__": |
| 178 aparser = argparse.ArgumentParser() | 177 aparser = argparse.ArgumentParser() |
| 179 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) | 178 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) |
