Mercurial > repos > bgruening > keras_batch_models
comparison model_prediction.py @ 15:70846a2dd227 draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 9981e25b00de29ed881b2229a173a8c812ded9bb
| author | bgruening | 
|---|---|
| date | Wed, 09 Aug 2023 11:16:03 +0000 | 
| parents | 33af12059f42 | 
| children | | 
   comparison
  equal
  deleted
  inserted
  replaced
| 14:81f8e8d847f0 | 15:70846a2dd227 | 
|---|---|
| 2 import json | 2 import json | 
| 3 import warnings | 3 import warnings | 
| 4 | 4 | 
| 5 import numpy as np | 5 import numpy as np | 
| 6 import pandas as pd | 6 import pandas as pd | 
| 7 from galaxy_ml.utils import get_module, load_model, read_columns, try_get_attr | 7 from galaxy_ml.model_persist import load_model_from_h5 | 
| | 8 from galaxy_ml.utils import (clean_params, get_module, read_columns, | 
| | 9 try_get_attr) | 
| 8 from scipy.io import mmread | 10 from scipy.io import mmread | 
| 9 from sklearn.pipeline import Pipeline | |
| 10 | 11 | 
| 11 N_JOBS = int(__import__("os").environ.get("GALAXY_SLOTS", 1)) | 12 N_JOBS = int(__import__("os").environ.get("GALAXY_SLOTS", 1)) | 
| 12 | 13 | 
| 13 | 14 | 
| 14 def main( | 15 def main( | 
| 15 inputs, | 16 inputs, | 
| 16 infile_estimator, | 17 infile_estimator, | 
| 17 outfile_predict, | 18 outfile_predict, | 
| 18 infile_weights=None, | |
| 19 infile1=None, | 19 infile1=None, | 
| 20 fasta_path=None, | 20 fasta_path=None, | 
| 21 ref_seq=None, | 21 ref_seq=None, | 
| 22 vcf_path=None, | 22 vcf_path=None, | 
| 23 ): | 23 ): | 
| 25 Parameter | 25 Parameter | 
| 26 --------- | 26 --------- | 
| 27 inputs : str | 27 inputs : str | 
| 28 File path to galaxy tool parameter | 28 File path to galaxy tool parameter | 
| 29 | 29 | 
| 30 infile_estimator : strgit | 30 infile_estimator : str | 
| 31 File path to trained estimator input | 31 File path to trained estimator input | 
| 32 | 32 | 
| 33 outfile_predict : str | 33 outfile_predict : str | 
| 34 File path to save the prediction results, tabular | 34 File path to save the prediction results, tabular | 
| 35 | |
| 36 infile_weights : str | |
| 37 File path to weights input | |
| 38 | 35 | 
| 39 infile1 : str | 36 infile1 : str | 
| 40 File path to dataset containing features | 37 File path to dataset containing features | 
| 41 | 38 | 
| 42 fasta_path : str | 39 fasta_path : str | 
| 52 | 49 | 
| 53 with open(inputs, "r") as param_handler: | 50 with open(inputs, "r") as param_handler: | 
| 54 params = json.load(param_handler) | 51 params = json.load(param_handler) | 
| 55 | 52 | 
| 56 # load model | 53 # load model | 
| 57 with open(infile_estimator, "rb") as est_handler: | 54 estimator = load_model_from_h5(infile_estimator) | 
| 58 estimator = load_model(est_handler) | 55 estimator = clean_params(estimator) | 
| 59 | |
| 60 main_est = estimator | |
| 61 if isinstance(estimator, Pipeline): | |
| 62 main_est = estimator.steps[-1][-1] | |
| 63 if hasattr(main_est, "config") and hasattr(main_est, "load_weights"): | |
| 64 if not infile_weights or infile_weights == "None": | |
| 65 raise ValueError( | |
| 66 "The selected model skeleton asks for weights, " | |
| 67 "but dataset for weights wan not selected!" | |
| 68 ) | |
| 69 main_est.load_weights(infile_weights) | |
| 70 | 56 | 
| 71 # handle data input | 57 # handle data input | 
| 72 input_type = params["input_options"]["selected_input"] | 58 input_type = params["input_options"]["selected_input"] | 
| 73 # tabular input | 59 # tabular input | 
| 74 if input_type == "tabular": | 60 if input_type == "tabular": | 
| 219 | 205 | 
| 220 if __name__ == "__main__": | 206 if __name__ == "__main__": | 
| 221 aparser = argparse.ArgumentParser() | 207 aparser = argparse.ArgumentParser() | 
| 222 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) | 208 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) | 
| 223 aparser.add_argument("-e", "--infile_estimator", dest="infile_estimator") | 209 aparser.add_argument("-e", "--infile_estimator", dest="infile_estimator") | 
| 224 aparser.add_argument("-w", "--infile_weights", dest="infile_weights") | |
| 225 aparser.add_argument("-X", "--infile1", dest="infile1") | 210 aparser.add_argument("-X", "--infile1", dest="infile1") | 
| 226 aparser.add_argument("-O", "--outfile_predict", dest="outfile_predict") | 211 aparser.add_argument("-O", "--outfile_predict", dest="outfile_predict") | 
| 227 aparser.add_argument("-f", "--fasta_path", dest="fasta_path") | 212 aparser.add_argument("-f", "--fasta_path", dest="fasta_path") | 
| 228 aparser.add_argument("-r", "--ref_seq", dest="ref_seq") | 213 aparser.add_argument("-r", "--ref_seq", dest="ref_seq") | 
| 229 aparser.add_argument("-v", "--vcf_path", dest="vcf_path") | 214 aparser.add_argument("-v", "--vcf_path", dest="vcf_path") | 
| 231 | 216 | 
| 232 main( | 217 main( | 
| 233 args.inputs, | 218 args.inputs, | 
| 234 args.infile_estimator, | 219 args.infile_estimator, | 
| 235 args.outfile_predict, | 220 args.outfile_predict, | 
| 236 infile_weights=args.infile_weights, | |
| 237 infile1=args.infile1, | 221 infile1=args.infile1, | 
| 238 fasta_path=args.fasta_path, | 222 fasta_path=args.fasta_path, | 
| 239 ref_seq=args.ref_seq, | 223 ref_seq=args.ref_seq, | 
| 240 vcf_path=args.vcf_path, | 224 vcf_path=args.vcf_path, | 
| 241 ) | 225 ) | 
