Mercurial > repos > bgruening > sklearn_train_test_split
Comparison of model_prediction.py @ 5:c0ed68e280a7 (draft)
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 208a8d348e7c7a182cfbe1b6f17868146428a7e2"
| author | bgruening | 
|---|---|
| date | Tue, 13 Apr 2021 21:13:13 +0000 | 
| parents | ddb10352d9bb | 
| children | 81ab4951f2a3 | 
Comparison legend: equal | deleted | inserted | replaced
| 4:76f69baa9f33 | 5:c0ed68e280a7 | 
|---|---|
| 1 import argparse | 1 import argparse | 
| 2 import json | 2 import json | 
| 3 import warnings | |
| 4 | |
| 3 import numpy as np | 5 import numpy as np | 
| 4 import pandas as pd | 6 import pandas as pd | 
| 5 import warnings | 7 from galaxy_ml.utils import get_module, load_model, read_columns, try_get_attr | 
| 6 | |
| 7 from scipy.io import mmread | 8 from scipy.io import mmread | 
| 8 from sklearn.pipeline import Pipeline | 9 from sklearn.pipeline import Pipeline | 
| 9 | 10 | 
| 10 from galaxy_ml.utils import (load_model, read_columns, | 11 N_JOBS = int(__import__("os").environ.get("GALAXY_SLOTS", 1)) | 
| 11 get_module, try_get_attr) | 12 | 
| 12 | 13 | 
| 13 | 14 def main( | 
| 14 N_JOBS = int(__import__('os').environ.get('GALAXY_SLOTS', 1)) | 15 inputs, | 
| 15 | 16 infile_estimator, | 
| 16 | 17 outfile_predict, | 
| 17 def main(inputs, infile_estimator, outfile_predict, | 18 infile_weights=None, | 
| 18 infile_weights=None, infile1=None, | 19 infile1=None, | 
| 19 fasta_path=None, ref_seq=None, | 20 fasta_path=None, | 
| 20 vcf_path=None): | 21 ref_seq=None, | 
| 22 vcf_path=None, | |
| 23 ): | |
| 21 """ | 24 """ | 
| 22 Parameter | 25 Parameter | 
| 23 --------- | 26 --------- | 
| 24 inputs : str | 27 inputs : str | 
| 25 File path to galaxy tool parameter | 28 File path to galaxy tool parameter | 
| 43 File path to dataset containing the reference genome sequence. | 46 File path to dataset containing the reference genome sequence. | 
| 44 | 47 | 
| 45 vcf_path : str | 48 vcf_path : str | 
| 46 File path to dataset containing variants info. | 49 File path to dataset containing variants info. | 
| 47 """ | 50 """ | 
| 48 warnings.filterwarnings('ignore') | 51 warnings.filterwarnings("ignore") | 
| 49 | 52 | 
| 50 with open(inputs, 'r') as param_handler: | 53 with open(inputs, "r") as param_handler: | 
| 51 params = json.load(param_handler) | 54 params = json.load(param_handler) | 
| 52 | 55 | 
| 53 # load model | 56 # load model | 
| 54 with open(infile_estimator, 'rb') as est_handler: | 57 with open(infile_estimator, "rb") as est_handler: | 
| 55 estimator = load_model(est_handler) | 58 estimator = load_model(est_handler) | 
| 56 | 59 | 
| 57 main_est = estimator | 60 main_est = estimator | 
| 58 if isinstance(estimator, Pipeline): | 61 if isinstance(estimator, Pipeline): | 
| 59 main_est = estimator.steps[-1][-1] | 62 main_est = estimator.steps[-1][-1] | 
| 60 if hasattr(main_est, 'config') and hasattr(main_est, 'load_weights'): | 63 if hasattr(main_est, "config") and hasattr(main_est, "load_weights"): | 
| 61 if not infile_weights or infile_weights == 'None': | 64 if not infile_weights or infile_weights == "None": | 
| 62 raise ValueError("The selected model skeleton asks for weights, " | 65 raise ValueError( | 
| 63 "but dataset for weights wan not selected!") | 66 "The selected model skeleton asks for weights, " "but dataset for weights wan not selected!" | 
| 67 ) | |
| 64 main_est.load_weights(infile_weights) | 68 main_est.load_weights(infile_weights) | 
| 65 | 69 | 
| 66 # handle data input | 70 # handle data input | 
| 67 input_type = params['input_options']['selected_input'] | 71 input_type = params["input_options"]["selected_input"] | 
| 68 # tabular input | 72 # tabular input | 
| 69 if input_type == 'tabular': | 73 if input_type == "tabular": | 
| 70 header = 'infer' if params['input_options']['header1'] else None | 74 header = "infer" if params["input_options"]["header1"] else None | 
| 71 column_option = (params['input_options'] | 75 column_option = params["input_options"]["column_selector_options_1"]["selected_column_selector_option"] | 
| 72 ['column_selector_options_1'] | 76 if column_option in [ | 
| 73 ['selected_column_selector_option']) | 77 "by_index_number", | 
| 74 if column_option in ['by_index_number', 'all_but_by_index_number', | 78 "all_but_by_index_number", | 
| 75 'by_header_name', 'all_but_by_header_name']: | 79 "by_header_name", | 
| 76 c = params['input_options']['column_selector_options_1']['col1'] | 80 "all_but_by_header_name", | 
| 81 ]: | |
| 82 c = params["input_options"]["column_selector_options_1"]["col1"] | |
| 77 else: | 83 else: | 
| 78 c = None | 84 c = None | 
| 79 | 85 | 
| 80 df = pd.read_csv(infile1, sep='\t', header=header, parse_dates=True) | 86 df = pd.read_csv(infile1, sep="\t", header=header, parse_dates=True) | 
| 81 | 87 | 
| 82 X = read_columns(df, c=c, c_option=column_option).astype(float) | 88 X = read_columns(df, c=c, c_option=column_option).astype(float) | 
| 83 | 89 | 
| 84 if params['method'] == 'predict': | 90 if params["method"] == "predict": | 
| 85 preds = estimator.predict(X) | 91 preds = estimator.predict(X) | 
| 86 else: | 92 else: | 
| 87 preds = estimator.predict_proba(X) | 93 preds = estimator.predict_proba(X) | 
| 88 | 94 | 
| 89 # sparse input | 95 # sparse input | 
| 90 elif input_type == 'sparse': | 96 elif input_type == "sparse": | 
| 91 X = mmread(open(infile1, 'r')) | 97 X = mmread(open(infile1, "r")) | 
| 92 if params['method'] == 'predict': | 98 if params["method"] == "predict": | 
| 93 preds = estimator.predict(X) | 99 preds = estimator.predict(X) | 
| 94 else: | 100 else: | 
| 95 preds = estimator.predict_proba(X) | 101 preds = estimator.predict_proba(X) | 
| 96 | 102 | 
| 97 # fasta input | 103 # fasta input | 
| 98 elif input_type == 'seq_fasta': | 104 elif input_type == "seq_fasta": | 
| 99 if not hasattr(estimator, 'data_batch_generator'): | 105 if not hasattr(estimator, "data_batch_generator"): | 
| 100 raise ValueError( | 106 raise ValueError( | 
| 101 "To do prediction on sequences in fasta input, " | 107 "To do prediction on sequences in fasta input, " | 
| 102 "the estimator must be a `KerasGBatchClassifier`" | 108 "the estimator must be a `KerasGBatchClassifier`" | 
| 103 "equipped with data_batch_generator!") | 109 "equipped with data_batch_generator!" | 
| 104 pyfaidx = get_module('pyfaidx') | 110 ) | 
| 111 pyfaidx = get_module("pyfaidx") | |
| 105 sequences = pyfaidx.Fasta(fasta_path) | 112 sequences = pyfaidx.Fasta(fasta_path) | 
| 106 n_seqs = len(sequences.keys()) | 113 n_seqs = len(sequences.keys()) | 
| 107 X = np.arange(n_seqs)[:, np.newaxis] | 114 X = np.arange(n_seqs)[:, np.newaxis] | 
| 108 seq_length = estimator.data_batch_generator.seq_length | 115 seq_length = estimator.data_batch_generator.seq_length | 
| 109 batch_size = getattr(estimator, 'batch_size', 32) | 116 batch_size = getattr(estimator, "batch_size", 32) | 
| 110 steps = (n_seqs + batch_size - 1) // batch_size | 117 steps = (n_seqs + batch_size - 1) // batch_size | 
| 111 | 118 | 
| 112 seq_type = params['input_options']['seq_type'] | 119 seq_type = params["input_options"]["seq_type"] | 
| 113 klass = try_get_attr( | 120 klass = try_get_attr("galaxy_ml.preprocessors", seq_type) | 
| 114 'galaxy_ml.preprocessors', seq_type) | 121 | 
| 115 | 122 pred_data_generator = klass(fasta_path, seq_length=seq_length) | 
| 116 pred_data_generator = klass( | 123 | 
| 117 fasta_path, seq_length=seq_length) | 124 if params["method"] == "predict": | 
| 118 | 125 preds = estimator.predict(X, data_generator=pred_data_generator, steps=steps) | 
| 119 if params['method'] == 'predict': | 126 else: | 
| 120 preds = estimator.predict( | 127 preds = estimator.predict_proba(X, data_generator=pred_data_generator, steps=steps) | 
| 121 X, data_generator=pred_data_generator, steps=steps) | |
| 122 else: | |
| 123 preds = estimator.predict_proba( | |
| 124 X, data_generator=pred_data_generator, steps=steps) | |
| 125 | 128 | 
| 126 # vcf input | 129 # vcf input | 
| 127 elif input_type == 'variant_effect': | 130 elif input_type == "variant_effect": | 
| 128 klass = try_get_attr('galaxy_ml.preprocessors', | 131 klass = try_get_attr("galaxy_ml.preprocessors", "GenomicVariantBatchGenerator") | 
| 129 'GenomicVariantBatchGenerator') | 132 | 
| 130 | 133 options = params["input_options"] | 
| 131 options = params['input_options'] | 134 options.pop("selected_input") | 
| 132 options.pop('selected_input') | 135 if options["blacklist_regions"] == "none": | 
| 133 if options['blacklist_regions'] == 'none': | 136 options["blacklist_regions"] = None | 
| 134 options['blacklist_regions'] = None | 137 | 
| 135 | 138 pred_data_generator = klass(ref_genome_path=ref_seq, vcf_path=vcf_path, **options) | 
| 136 pred_data_generator = klass( | |
| 137 ref_genome_path=ref_seq, vcf_path=vcf_path, **options) | |
| 138 | 139 | 
| 139 pred_data_generator.set_processing_attrs() | 140 pred_data_generator.set_processing_attrs() | 
| 140 | 141 | 
| 141 variants = pred_data_generator.variants | 142 variants = pred_data_generator.variants | 
| 142 | 143 | 
| 143 # predict 1600 sample at once then write to file | 144 # predict 1600 sample at once then write to file | 
| 144 gen_flow = pred_data_generator.flow(batch_size=1600) | 145 gen_flow = pred_data_generator.flow(batch_size=1600) | 
| 145 | 146 | 
| 146 file_writer = open(outfile_predict, 'w') | 147 file_writer = open(outfile_predict, "w") | 
| 147 header_row = '\t'.join(['chrom', 'pos', 'name', 'ref', | 148 header_row = "\t".join(["chrom", "pos", "name", "ref", "alt", "strand"]) | 
| 148 'alt', 'strand']) | |
| 149 file_writer.write(header_row) | 149 file_writer.write(header_row) | 
| 150 header_done = False | 150 header_done = False | 
| 151 | 151 | 
| 152 steps_done = 0 | 152 steps_done = 0 | 
| 153 | 153 | 
| 154 # TODO: multiple threading | 154 # TODO: multiple threading | 
| 155 try: | 155 try: | 
| 156 while steps_done < len(gen_flow): | 156 while steps_done < len(gen_flow): | 
| 157 index_array = next(gen_flow.index_generator) | 157 index_array = next(gen_flow.index_generator) | 
| 158 batch_X = gen_flow._get_batches_of_transformed_samples( | 158 batch_X = gen_flow._get_batches_of_transformed_samples(index_array) | 
| 159 index_array) | 159 | 
| 160 | 160 if params["method"] == "predict": | 
| 161 if params['method'] == 'predict': | |
| 162 batch_preds = estimator.predict( | 161 batch_preds = estimator.predict( | 
| 163 batch_X, | 162 batch_X, | 
| 164 # The presence of `pred_data_generator` below is to | 163 # The presence of `pred_data_generator` below is to | 
| 165 # override model carrying data_generator if there | 164 # override model carrying data_generator if there | 
| 166 # is any. | 165 # is any. | 
| 167 data_generator=pred_data_generator) | 166 data_generator=pred_data_generator, | 
| 167 ) | |
| 168 else: | 168 else: | 
| 169 batch_preds = estimator.predict_proba( | 169 batch_preds = estimator.predict_proba( | 
| 170 batch_X, | 170 batch_X, | 
| 171 # The presence of `pred_data_generator` below is to | 171 # The presence of `pred_data_generator` below is to | 
| 172 # override model carrying data_generator if there | 172 # override model carrying data_generator if there | 
| 173 # is any. | 173 # is any. | 
| 174 data_generator=pred_data_generator) | 174 data_generator=pred_data_generator, | 
| 175 ) | |
| 175 | 176 | 
| 176 if batch_preds.ndim == 1: | 177 if batch_preds.ndim == 1: | 
| 177 batch_preds = batch_preds[:, np.newaxis] | 178 batch_preds = batch_preds[:, np.newaxis] | 
| 178 | 179 | 
| 179 batch_meta = variants[index_array] | 180 batch_meta = variants[index_array] | 
| 180 batch_out = np.column_stack([batch_meta, batch_preds]) | 181 batch_out = np.column_stack([batch_meta, batch_preds]) | 
| 181 | 182 | 
| 182 if not header_done: | 183 if not header_done: | 
| 183 heads = np.arange(batch_preds.shape[-1]).astype(str) | 184 heads = np.arange(batch_preds.shape[-1]).astype(str) | 
| 184 heads_str = '\t'.join(heads) | 185 heads_str = "\t".join(heads) | 
| 185 file_writer.write("\t%s\n" % heads_str) | 186 file_writer.write("\t%s\n" % heads_str) | 
| 186 header_done = True | 187 header_done = True | 
| 187 | 188 | 
| 188 for row in batch_out: | 189 for row in batch_out: | 
| 189 row_str = '\t'.join(row) | 190 row_str = "\t".join(row) | 
| 190 file_writer.write("%s\n" % row_str) | 191 file_writer.write("%s\n" % row_str) | 
| 191 | 192 | 
| 192 steps_done += 1 | 193 steps_done += 1 | 
| 193 | 194 | 
| 194 finally: | 195 finally: | 
| 198 return 0 | 199 return 0 | 
| 199 # end input | 200 # end input | 
| 200 | 201 | 
| 201 # output | 202 # output | 
| 202 if len(preds.shape) == 1: | 203 if len(preds.shape) == 1: | 
| 203 rval = pd.DataFrame(preds, columns=['Predicted']) | 204 rval = pd.DataFrame(preds, columns=["Predicted"]) | 
| 204 else: | 205 else: | 
| 205 rval = pd.DataFrame(preds) | 206 rval = pd.DataFrame(preds) | 
| 206 | 207 | 
| 207 rval.to_csv(outfile_predict, sep='\t', header=True, index=False) | 208 rval.to_csv(outfile_predict, sep="\t", header=True, index=False) | 
| 208 | 209 | 
| 209 | 210 | 
| 210 if __name__ == '__main__': | 211 if __name__ == "__main__": | 
| 211 aparser = argparse.ArgumentParser() | 212 aparser = argparse.ArgumentParser() | 
| 212 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) | 213 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) | 
| 213 aparser.add_argument("-e", "--infile_estimator", dest="infile_estimator") | 214 aparser.add_argument("-e", "--infile_estimator", dest="infile_estimator") | 
| 214 aparser.add_argument("-w", "--infile_weights", dest="infile_weights") | 215 aparser.add_argument("-w", "--infile_weights", dest="infile_weights") | 
| 215 aparser.add_argument("-X", "--infile1", dest="infile1") | 216 aparser.add_argument("-X", "--infile1", dest="infile1") | 
| 217 aparser.add_argument("-f", "--fasta_path", dest="fasta_path") | 218 aparser.add_argument("-f", "--fasta_path", dest="fasta_path") | 
| 218 aparser.add_argument("-r", "--ref_seq", dest="ref_seq") | 219 aparser.add_argument("-r", "--ref_seq", dest="ref_seq") | 
| 219 aparser.add_argument("-v", "--vcf_path", dest="vcf_path") | 220 aparser.add_argument("-v", "--vcf_path", dest="vcf_path") | 
| 220 args = aparser.parse_args() | 221 args = aparser.parse_args() | 
| 221 | 222 | 
| 222 main(args.inputs, args.infile_estimator, args.outfile_predict, | 223 main( | 
| 223 infile_weights=args.infile_weights, infile1=args.infile1, | 224 args.inputs, | 
| 224 fasta_path=args.fasta_path, ref_seq=args.ref_seq, | 225 args.infile_estimator, | 
| 225 vcf_path=args.vcf_path) | 226 args.outfile_predict, | 
| 227 infile_weights=args.infile_weights, | |
| 228 infile1=args.infile1, | |
| 229 fasta_path=args.fasta_path, | |
| 230 ref_seq=args.ref_seq, | |
| 231 vcf_path=args.vcf_path, | |
| 232 ) | 
