diff model_prediction.py @ 3:0a1812986bc3 draft

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 9981e25b00de29ed881b2229a173a8c812ded9bb
author bgruening
date Wed, 09 Aug 2023 11:10:37 +0000
parents 38c4f8a98038
children
line wrap: on
line diff
--- a/model_prediction.py	Mon Dec 16 10:07:37 2019 +0000
+++ b/model_prediction.py	Wed Aug 09 11:10:37 2023 +0000
@@ -1,38 +1,38 @@
 import argparse
 import json
-import numpy as np
-import pandas as pd
 import warnings
 
+import numpy as np
+import pandas as pd
+from galaxy_ml.model_persist import load_model_from_h5
+from galaxy_ml.utils import (clean_params, get_module, read_columns,
+                             try_get_attr)
 from scipy.io import mmread
-from sklearn.pipeline import Pipeline
 
-from galaxy_ml.utils import (load_model, read_columns,
-                             get_module, try_get_attr)
+N_JOBS = int(__import__("os").environ.get("GALAXY_SLOTS", 1))
 
 
-N_JOBS = int(__import__('os').environ.get('GALAXY_SLOTS', 1))
-
-
-def main(inputs, infile_estimator, outfile_predict,
-         infile_weights=None, infile1=None,
-         fasta_path=None, ref_seq=None,
-         vcf_path=None):
+def main(
+    inputs,
+    infile_estimator,
+    outfile_predict,
+    infile1=None,
+    fasta_path=None,
+    ref_seq=None,
+    vcf_path=None,
+):
     """
     Parameter
     ---------
     inputs : str
         File path to galaxy tool parameter
 
-    infile_estimator : strgit
+    infile_estimator : str
         File path to trained estimator input
 
     outfile_predict : str
         File path to save the prediction results, tabular
 
-    infile_weights : str
-        File path to weights input
-
     infile1 : str
         File path to dataset containing features
 
@@ -45,96 +45,92 @@
     vcf_path : str
         File path to dataset containing variants info.
     """
-    warnings.filterwarnings('ignore')
+    warnings.filterwarnings("ignore")
 
-    with open(inputs, 'r') as param_handler:
+    with open(inputs, "r") as param_handler:
         params = json.load(param_handler)
 
     # load model
-    with open(infile_estimator, 'rb') as est_handler:
-        estimator = load_model(est_handler)
-
-    main_est = estimator
-    if isinstance(estimator, Pipeline):
-        main_est = estimator.steps[-1][-1]
-    if hasattr(main_est, 'config') and hasattr(main_est, 'load_weights'):
-        if not infile_weights or infile_weights == 'None':
-            raise ValueError("The selected model skeleton asks for weights, "
-                             "but dataset for weights wan not selected!")
-        main_est.load_weights(infile_weights)
+    estimator = load_model_from_h5(infile_estimator)
+    estimator = clean_params(estimator)
 
     # handle data input
-    input_type = params['input_options']['selected_input']
+    input_type = params["input_options"]["selected_input"]
     # tabular input
-    if input_type == 'tabular':
-        header = 'infer' if params['input_options']['header1'] else None
-        column_option = (params['input_options']
-                               ['column_selector_options_1']
-                               ['selected_column_selector_option'])
-        if column_option in ['by_index_number', 'all_but_by_index_number',
-                             'by_header_name', 'all_but_by_header_name']:
-            c = params['input_options']['column_selector_options_1']['col1']
+    if input_type == "tabular":
+        header = "infer" if params["input_options"]["header1"] else None
+        column_option = params["input_options"]["column_selector_options_1"][
+            "selected_column_selector_option"
+        ]
+        if column_option in [
+            "by_index_number",
+            "all_but_by_index_number",
+            "by_header_name",
+            "all_but_by_header_name",
+        ]:
+            c = params["input_options"]["column_selector_options_1"]["col1"]
         else:
             c = None
 
-        df = pd.read_csv(infile1, sep='\t', header=header, parse_dates=True)
+        df = pd.read_csv(infile1, sep="\t", header=header, parse_dates=True)
 
         X = read_columns(df, c=c, c_option=column_option).astype(float)
 
-        if params['method'] == 'predict':
+        if params["method"] == "predict":
             preds = estimator.predict(X)
         else:
             preds = estimator.predict_proba(X)
 
     # sparse input
-    elif input_type == 'sparse':
-        X = mmread(open(infile1, 'r'))
-        if params['method'] == 'predict':
+    elif input_type == "sparse":
+        X = mmread(open(infile1, "r"))
+        if params["method"] == "predict":
             preds = estimator.predict(X)
         else:
             preds = estimator.predict_proba(X)
 
     # fasta input
-    elif input_type == 'seq_fasta':
-        if not hasattr(estimator, 'data_batch_generator'):
+    elif input_type == "seq_fasta":
+        if not hasattr(estimator, "data_batch_generator"):
             raise ValueError(
                 "To do prediction on sequences in fasta input, "
                 "the estimator must be a `KerasGBatchClassifier`"
-                "equipped with data_batch_generator!")
-        pyfaidx = get_module('pyfaidx')
+                "equipped with data_batch_generator!"
+            )
+        pyfaidx = get_module("pyfaidx")
         sequences = pyfaidx.Fasta(fasta_path)
         n_seqs = len(sequences.keys())
         X = np.arange(n_seqs)[:, np.newaxis]
         seq_length = estimator.data_batch_generator.seq_length
-        batch_size = getattr(estimator, 'batch_size', 32)
+        batch_size = getattr(estimator, "batch_size", 32)
         steps = (n_seqs + batch_size - 1) // batch_size
 
-        seq_type = params['input_options']['seq_type']
-        klass = try_get_attr(
-            'galaxy_ml.preprocessors', seq_type)
+        seq_type = params["input_options"]["seq_type"]
+        klass = try_get_attr("galaxy_ml.preprocessors", seq_type)
+
+        pred_data_generator = klass(fasta_path, seq_length=seq_length)
 
-        pred_data_generator = klass(
-            fasta_path, seq_length=seq_length)
-
-        if params['method'] == 'predict':
+        if params["method"] == "predict":
             preds = estimator.predict(
-                X, data_generator=pred_data_generator, steps=steps)
+                X, data_generator=pred_data_generator, steps=steps
+            )
         else:
             preds = estimator.predict_proba(
-                X, data_generator=pred_data_generator, steps=steps)
+                X, data_generator=pred_data_generator, steps=steps
+            )
 
     # vcf input
-    elif input_type == 'variant_effect':
-        klass = try_get_attr('galaxy_ml.preprocessors',
-                             'GenomicVariantBatchGenerator')
+    elif input_type == "variant_effect":
+        klass = try_get_attr("galaxy_ml.preprocessors", "GenomicVariantBatchGenerator")
 
-        options = params['input_options']
-        options.pop('selected_input')
-        if options['blacklist_regions'] == 'none':
-            options['blacklist_regions'] = None
+        options = params["input_options"]
+        options.pop("selected_input")
+        if options["blacklist_regions"] == "none":
+            options["blacklist_regions"] = None
 
         pred_data_generator = klass(
-            ref_genome_path=ref_seq, vcf_path=vcf_path, **options)
+            ref_genome_path=ref_seq, vcf_path=vcf_path, **options
+        )
 
         pred_data_generator.set_processing_attrs()
 
@@ -143,9 +139,8 @@
         # predict 1600 sample at once then write to file
         gen_flow = pred_data_generator.flow(batch_size=1600)
 
-        file_writer = open(outfile_predict, 'w')
-        header_row = '\t'.join(['chrom', 'pos', 'name', 'ref',
-                                'alt', 'strand'])
+        file_writer = open(outfile_predict, "w")
+        header_row = "\t".join(["chrom", "pos", "name", "ref", "alt", "strand"])
         file_writer.write(header_row)
         header_done = False
 
@@ -155,23 +150,24 @@
         try:
             while steps_done < len(gen_flow):
                 index_array = next(gen_flow.index_generator)
-                batch_X = gen_flow._get_batches_of_transformed_samples(
-                    index_array)
+                batch_X = gen_flow._get_batches_of_transformed_samples(index_array)
 
-                if params['method'] == 'predict':
+                if params["method"] == "predict":
                     batch_preds = estimator.predict(
                         batch_X,
                         # The presence of `pred_data_generator` below is to
                         # override model carrying data_generator if there
                         # is any.
-                        data_generator=pred_data_generator)
+                        data_generator=pred_data_generator,
+                    )
                 else:
                     batch_preds = estimator.predict_proba(
                         batch_X,
                         # The presence of `pred_data_generator` below is to
                         # override model carrying data_generator if there
                         # is any.
-                        data_generator=pred_data_generator)
+                        data_generator=pred_data_generator,
+                    )
 
                 if batch_preds.ndim == 1:
                     batch_preds = batch_preds[:, np.newaxis]
@@ -181,12 +177,12 @@
 
                 if not header_done:
                     heads = np.arange(batch_preds.shape[-1]).astype(str)
-                    heads_str = '\t'.join(heads)
+                    heads_str = "\t".join(heads)
                     file_writer.write("\t%s\n" % heads_str)
                     header_done = True
 
                 for row in batch_out:
-                    row_str = '\t'.join(row)
+                    row_str = "\t".join(row)
                     file_writer.write("%s\n" % row_str)
 
                 steps_done += 1
@@ -200,18 +196,17 @@
 
     # output
     if len(preds.shape) == 1:
-        rval = pd.DataFrame(preds, columns=['Predicted'])
+        rval = pd.DataFrame(preds, columns=["Predicted"])
     else:
         rval = pd.DataFrame(preds)
 
-    rval.to_csv(outfile_predict, sep='\t', header=True, index=False)
+    rval.to_csv(outfile_predict, sep="\t", header=True, index=False)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     aparser = argparse.ArgumentParser()
     aparser.add_argument("-i", "--inputs", dest="inputs", required=True)
     aparser.add_argument("-e", "--infile_estimator", dest="infile_estimator")
-    aparser.add_argument("-w", "--infile_weights", dest="infile_weights")
     aparser.add_argument("-X", "--infile1", dest="infile1")
     aparser.add_argument("-O", "--outfile_predict", dest="outfile_predict")
     aparser.add_argument("-f", "--fasta_path", dest="fasta_path")
@@ -219,7 +214,12 @@
     aparser.add_argument("-v", "--vcf_path", dest="vcf_path")
     args = aparser.parse_args()
 
-    main(args.inputs, args.infile_estimator, args.outfile_predict,
-         infile_weights=args.infile_weights, infile1=args.infile1,
-         fasta_path=args.fasta_path, ref_seq=args.ref_seq,
-         vcf_path=args.vcf_path)
+    main(
+        args.inputs,
+        args.infile_estimator,
+        args.outfile_predict,
+        infile1=args.infile1,
+        fasta_path=args.fasta_path,
+        ref_seq=args.ref_seq,
+        vcf_path=args.vcf_path,
+    )