Mercurial > repos > bgruening > sklearn_label_encoder
comparison model_prediction.py @ 0:03155260beb3 draft
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit ea12f973df4b97a2691d9e4ce6bf6fae59d57717"
| author | bgruening |
|---|---|
| date | Fri, 30 Apr 2021 23:36:38 +0000 |
| parents | |
| children | b008b609205e |
comparison
equal
deleted
inserted
replaced
| -1:000000000000 | 0:03155260beb3 |
|---|---|
| 1 import argparse | |
| 2 import json | |
| 3 import warnings | |
| 4 | |
| 5 import numpy as np | |
| 6 import pandas as pd | |
| 7 from galaxy_ml.utils import get_module, load_model, read_columns, try_get_attr | |
| 8 from scipy.io import mmread | |
| 9 from sklearn.pipeline import Pipeline | |
| 10 | |
# Parallelism hint taken from the GALAXY_SLOTS environment variable;
# falls back to 1 when the variable is not set.
N_JOBS = int(__import__("os").getenv("GALAXY_SLOTS", 1))
| 12 | |
| 13 | |
def _run_prediction(estimator, X, method, **predict_kwargs):
    """Dispatch to ``predict`` or ``predict_proba`` on *estimator*.

    ``method == "predict"`` calls ``estimator.predict``; any other value calls
    ``estimator.predict_proba``. ``predict_kwargs`` are forwarded unchanged.
    """
    if method == "predict":
        return estimator.predict(X, **predict_kwargs)
    return estimator.predict_proba(X, **predict_kwargs)


def _load_estimator(infile_estimator, infile_weights):
    """Load the trained estimator with ``load_model`` and attach weights if needed.

    Weights are required when the estimator (or, for a ``Pipeline``, its last
    step) exposes both ``config`` and ``load_weights``.

    Raises
    ------
    ValueError
        If weights are required but no weights file was given.
    """
    with open(infile_estimator, "rb") as est_handler:
        estimator = load_model(est_handler)

    main_est = estimator
    if isinstance(estimator, Pipeline):
        main_est = estimator.steps[-1][-1]
    if hasattr(main_est, "config") and hasattr(main_est, "load_weights"):
        # An unset weights dataset may arrive as None, "" or the string "None".
        if not infile_weights or infile_weights == "None":
            raise ValueError(
                "The selected model skeleton asks for weights, "
                "but dataset for weights was not selected!"
            )
        main_est.load_weights(infile_weights)
    return estimator


def _read_tabular_features(infile1, input_options):
    """Read the tab-separated feature table; return the selected columns as floats."""
    header = "infer" if input_options["header1"] else None
    selector = input_options["column_selector_options_1"]
    column_option = selector["selected_column_selector_option"]
    # Only these selector options read a column specification from ``col1``.
    if column_option in (
        "by_index_number",
        "all_but_by_index_number",
        "by_header_name",
        "all_but_by_header_name",
    ):
        c = selector["col1"]
    else:
        c = None

    df = pd.read_csv(infile1, sep="\t", header=header, parse_dates=True)
    return read_columns(df, c=c, c_option=column_option).astype(float)


def _predict_fasta(estimator, method, input_options, fasta_path):
    """Predict on every sequence of *fasta_path* via the estimator's batch generator.

    Raises
    ------
    ValueError
        If the estimator has no ``data_batch_generator`` attribute.
    """
    if not hasattr(estimator, "data_batch_generator"):
        raise ValueError(
            "To do prediction on sequences in fasta input, "
            "the estimator must be a `KerasGBatchClassifier` "
            "equipped with data_batch_generator!"
        )
    pyfaidx = get_module("pyfaidx")
    sequences = pyfaidx.Fasta(fasta_path)
    n_seqs = len(sequences.keys())
    # X only carries sequence indices (one column); the generator fetches data.
    X = np.arange(n_seqs)[:, np.newaxis]
    seq_length = estimator.data_batch_generator.seq_length
    batch_size = getattr(estimator, "batch_size", 32)
    steps = (n_seqs + batch_size - 1) // batch_size  # ceil(n_seqs / batch_size)

    klass = try_get_attr("galaxy_ml.preprocessors", input_options["seq_type"])
    pred_data_generator = klass(fasta_path, seq_length=seq_length)

    return _run_prediction(
        estimator, X, method, data_generator=pred_data_generator, steps=steps
    )


def _predict_variant_effect(
    estimator, method, input_options, ref_seq, vcf_path, outfile_predict
):
    """Stream variant-effect predictions, batch by batch, into *outfile_predict*.

    Each output row is the variant metadata followed by one column per
    prediction output. The data generator is closed even if prediction fails.
    """
    klass = try_get_attr("galaxy_ml.preprocessors", "GenomicVariantBatchGenerator")

    # All remaining tool options are forwarded to the generator as keyword
    # arguments; work on a copy so the caller's params dict is not mutated.
    options = dict(input_options)
    options.pop("selected_input")
    if options["blacklist_regions"] == "none":
        options["blacklist_regions"] = None

    pred_data_generator = klass(ref_genome_path=ref_seq, vcf_path=vcf_path, **options)
    pred_data_generator.set_processing_attrs()
    variants = pred_data_generator.variants

    # predict 1600 samples at once then write to file
    gen_flow = pred_data_generator.flow(batch_size=1600)
    header_row = "\t".join(["chrom", "pos", "name", "ref", "alt", "strand"])

    # TODO: multiple threading
    try:
        with open(outfile_predict, "w") as file_writer:
            file_writer.write(header_row)
            # The prediction column names are only known after the first batch.
            header_done = False

            for _ in range(len(gen_flow)):
                index_array = next(gen_flow.index_generator)
                batch_X = gen_flow._get_batches_of_transformed_samples(index_array)

                # Passing `pred_data_generator` overrides any data_generator
                # the model itself may carry.
                batch_preds = _run_prediction(
                    estimator, batch_X, method, data_generator=pred_data_generator
                )
                if batch_preds.ndim == 1:
                    batch_preds = batch_preds[:, np.newaxis]

                batch_meta = variants[index_array]
                batch_out = np.column_stack([batch_meta, batch_preds])

                if not header_done:
                    heads = np.arange(batch_preds.shape[-1]).astype(str)
                    file_writer.write("\t%s\n" % "\t".join(heads))
                    header_done = True

                # map(str, ...) keeps this working if cells are not plain str.
                file_writer.writelines(
                    "%s\n" % "\t".join(map(str, row)) for row in batch_out
                )

            if not header_done:
                # No batches: still terminate the header line.
                file_writer.write("\n")
    finally:
        # TODO: make api `pred_data_generator.close()`
        pred_data_generator.close()


def main(
    inputs,
    infile_estimator,
    outfile_predict,
    infile_weights=None,
    infile1=None,
    fasta_path=None,
    ref_seq=None,
    vcf_path=None,
):
    """Run a trained estimator on the selected input and write predictions.

    Parameters
    ----------
    inputs : str
        File path to galaxy tool parameter (JSON)

    infile_estimator : str
        File path to trained estimator input

    outfile_predict : str
        File path to save the prediction results, tabular

    infile_weights : str
        File path to weights input

    infile1 : str
        File path to dataset containing features

    fasta_path : str
        File path to dataset containing fasta file

    ref_seq : str
        File path to dataset containing the reference genome sequence.

    vcf_path : str
        File path to dataset containing variants info.

    Returns
    -------
    int or None
        ``0`` for ``variant_effect`` input (which streams its own output),
        ``None`` for every other input type.

    Raises
    ------
    ValueError
        If required weights are missing, if fasta prediction is requested for
        an estimator without ``data_batch_generator``, or if the selected
        input type is not supported.
    """
    warnings.filterwarnings("ignore")

    with open(inputs, "r") as param_handler:
        params = json.load(param_handler)

    estimator = _load_estimator(infile_estimator, infile_weights)

    method = params["method"]
    input_options = params["input_options"]
    input_type = input_options["selected_input"]

    if input_type == "tabular":
        X = _read_tabular_features(infile1, input_options)
        preds = _run_prediction(estimator, X, method)
    elif input_type == "sparse":
        with open(infile1, "r") as sparse_handler:
            X = mmread(sparse_handler)
        preds = _run_prediction(estimator, X, method)
    elif input_type == "seq_fasta":
        preds = _predict_fasta(estimator, method, input_options, fasta_path)
    elif input_type == "variant_effect":
        _predict_variant_effect(
            estimator, method, input_options, ref_seq, vcf_path, outfile_predict
        )
        return 0
    else:
        # Previously this fell through to an UnboundLocalError on `preds`.
        raise ValueError("Unsupported input type: %r" % (input_type,))

    # output
    if len(preds.shape) == 1:
        rval = pd.DataFrame(preds, columns=["Predicted"])
    else:
        rval = pd.DataFrame(preds)

    rval.to_csv(outfile_predict, sep="\t", header=True, index=False)
| 218 | |
| 219 | |
if __name__ == "__main__":
    aparser = argparse.ArgumentParser()
    aparser.add_argument("-i", "--inputs", dest="inputs", required=True)
    # The estimator and output paths are mandatory for every input type:
    # without them main() either crashes on open(None) or, for to_csv(None),
    # silently writes nothing.
    aparser.add_argument(
        "-e", "--infile_estimator", dest="infile_estimator", required=True
    )
    aparser.add_argument("-w", "--infile_weights", dest="infile_weights")
    aparser.add_argument("-X", "--infile1", dest="infile1")
    aparser.add_argument(
        "-O", "--outfile_predict", dest="outfile_predict", required=True
    )
    aparser.add_argument("-f", "--fasta_path", dest="fasta_path")
    aparser.add_argument("-r", "--ref_seq", dest="ref_seq")
    aparser.add_argument("-v", "--vcf_path", dest="vcf_path")
    args = aparser.parse_args()

    main(
        args.inputs,
        args.infile_estimator,
        args.outfile_predict,
        infile_weights=args.infile_weights,
        infile1=args.infile1,
        fasta_path=args.fasta_path,
        ref_seq=args.ref_seq,
        vcf_path=args.vcf_path,
    )
