Mercurial > repos > iuc > virhunter
diff predict.py @ 1:341bcf4d4fcd draft
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit c3685ed6a70b47012b62b95a2a3db062bd3b7475
| author | iuc |
|---|---|
| date | Thu, 05 Jan 2023 14:27:31 +0000 |
| parents | 6052fcc0d113 |
| children | 206c8054d74a |
line wrap: on
line diff
--- a/predict.py Wed Nov 09 12:18:36 2022 +0000 +++ b/predict.py Thu Jan 05 14:27:31 2023 +0000 @@ -9,7 +9,7 @@ import pandas as pd from Bio import SeqIO from joblib import load -from models import model_5, model_7 +from models import model_10, model_5, model_7 from utils import preprocess as pp os.environ["CUDA_VISIBLE_DEVICES"] = "" @@ -18,7 +18,7 @@ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' -def predict_nn(ds_path, nn_weights_path, length, batch_size=256): +def predict_nn(ds_path, nn_weights_path, length, use_10, batch_size=256): """ Breaks down contigs into fragments and uses pretrained neural networks to give predictions for fragments @@ -37,10 +37,14 @@ "pred_plant_7": [], "pred_vir_7": [], "pred_bact_7": [], - # "pred_plant_10": [], - # "pred_vir_10": [], - # "pred_bact_10": [], } + if use_10: + out_table_ = { + "pred_plant_10": [], + "pred_vir_10": [], + "pred_bact_10": [], + } + out_table.update(out_table_) if not seqs_: raise ValueError("All sequences were smaller than length of the model") test_fragments = [] @@ -56,24 +60,32 @@ out_table["fragment"].append(j) test_encoded = pp.one_hot_encode(test_fragments) test_encoded_rc = pp.one_hot_encode(test_fragments_rc) - # for model, s in zip([model_5.model(length), model_7.model(length), model_10.model(length)], [5, 7, 10]): - for model, s in zip([model_5.model(length), model_7.model(length)], [5, 7]): + if use_10: + zipped_models = zip([model_5.model(length), model_7.model(length), model_10.model(length)], [5, 7, 10]) + else: + zipped_models = zip([model_5.model(length), model_7.model(length)], [5, 7]) + for model, s in zipped_models: model.load_weights(Path(nn_weights_path, f"model_{s}_{length}.h5")) prediction = model.predict([test_encoded, test_encoded_rc], batch_size) out_table[f"pred_plant_{s}"].extend(list(prediction[..., 0])) out_table[f"pred_vir_{s}"].extend(list(prediction[..., 1])) out_table[f"pred_bact_{s}"].extend(list(prediction[..., 2])) + return pd.DataFrame(out_table) -def predict_rf(df, rf_weights_path, length): +def predict_rf(df, rf_weights_path, length, use_10): """ Using predictions by predict_nn and weights of a trained RF classifier gives a single prediction for a fragment """ clf = load(Path(rf_weights_path, f"RF_{length}.joblib")) - X = df[["pred_plant_5", "pred_vir_5", "pred_plant_7", "pred_vir_7"]] - # X = ["pred_plant_5", "pred_vir_5", "pred_plant_7", "pred_vir_7", "pred_plant_10", "pred_vir_10", ]] + if use_10: + X = df[ + ["pred_plant_5", "pred_vir_5", "pred_plant_7", "pred_vir_7", "pred_plant_10", "pred_vir_10", ]] + else: + X = df[ + ["pred_plant_5", "pred_vir_5", "pred_plant_7", "pred_vir_7", ]] y_pred = clf.predict(X) mapping = {0: "plant", 1: "virus", 2: "bacteria"} df["RF_decision"] = np.vectorize(mapping.get)(y_pred) @@ -89,12 +101,10 @@ Based on predictions of predict_rf for fragments gives a final prediction for the whole contig """ df = ( - df.groupby(["id", "length", 'RF_decision'], sort=False) - .size() - .unstack(fill_value=0) + df.groupby(["id", "length", 'RF_decision'], sort=False).size().unstack(fill_value=0) ) df = df.reset_index() - df = df.reindex(['length', 'id', 'virus', 'plant', 'bacteria'], axis=1) + df = df.reindex(['length', 'id', 'virus', 'plant', 'bacteria'], axis=1).fillna(value=0) conditions = [ (df['virus'] > df['plant']) & (df['virus'] > df['bacteria']), (df['plant'] > df['virus']) & (df['plant'] > df['bacteria']), @@ -131,7 +141,7 @@ assert Path(weights).exists(), f'{weights} does not exist' assert isinstance(limit, int), 'limit should be an integer' Path(out_path).mkdir(parents=True, exist_ok=True) - + use_10 = Path(weights, 'model_10_500.h5').exists() for ts in test_ds: dfs_fr = [] dfs_cont = [] @@ -141,12 +151,14 @@ ds_path=ts, nn_weights_path=weights, length=l_, + use_10=use_10 ) print(df) df = predict_rf( df=df, rf_weights_path=weights, length=l_, + use_10=use_10 ) df = df.round(3) dfs_fr.append(df) @@ -178,7 +190,7 @@ parser.add_argument("--weights", help="path to the folder containing weights for NN and RF modules trained on 500 and 1000 fragment lengths (str)") parser.add_argument("--out_path", help="path to the folder to store predictions (str)") parser.add_argument("--return_viral", help="whether to return contigs annotated as viral in separate fasta file (True/False)") - parser.add_argument("--limit", help="Do predictions only for contigs > l. We suggest l=750. (int)", type=int) + parser.add_argument("--limit", help="Do predictions only for contigs > l. We suggest l=750. (int)", type=int, default=750) args = parser.parse_args() if args.test_ds:
