Mercurial > repos > bgruening > sklearn_fitted_model_eval
changeset 13:21dccb45999c draft default tip
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 80417bf0158a9b596e485dd66408f738f405145a
| author | bgruening | 
|---|---|
| date | Mon, 02 Oct 2023 09:00:28 +0000 | 
| parents | 9d067f053602 | 
| children | |
| files | keras_train_and_eval.py ml_visualization_ex.py | 
| diffstat | 2 files changed, 97 insertions(+), 6 deletions(-) [+] | 
line wrap: on
 line diff
--- a/keras_train_and_eval.py Wed Aug 09 12:27:50 2023 +0000 +++ b/keras_train_and_eval.py Mon Oct 02 09:00:28 2023 +0000 @@ -188,6 +188,7 @@ infile1, infile2, outfile_result, + outfile_history=None, outfile_object=None, outfile_y_true=None, outfile_y_preds=None, @@ -215,6 +216,9 @@ outfile_result : str File path to save the results, either cv_results or test result. + outfile_history : str, optional + File path to save the training history. + outfile_object : str, optional File path to save searchCV object. @@ -253,9 +257,7 @@ swapping = params["experiment_schemes"]["hyperparams_swapping"] swap_params = _eval_swap_params(swapping) estimator.set_params(**swap_params) - estimator_params = estimator.get_params() - # store read dataframe object loaded_df = {} @@ -448,12 +450,20 @@ # train and eval if hasattr(estimator, "config") and hasattr(estimator, "model_type"): if exp_scheme == "train_val_test": - estimator.fit(X_train, y_train, validation_data=(X_val, y_val)) + history = estimator.fit(X_train, y_train, validation_data=(X_val, y_val)) else: - estimator.fit(X_train, y_train, validation_data=(X_test, y_test)) + history = estimator.fit(X_train, y_train, validation_data=(X_test, y_test)) else: - estimator.fit(X_train, y_train) - + history = estimator.fit(X_train, y_train) + if "callbacks" in estimator_params: + for cb in estimator_params["callbacks"]: + if cb["callback_selection"]["callback_type"] == "CSVLogger": + hist_df = pd.DataFrame(history.history) + hist_df["epoch"] = np.arange(1, estimator_params["epochs"] + 1) + epo_col = hist_df.pop('epoch') + hist_df.insert(0, 'epoch', epo_col) + hist_df.to_csv(path_or_buf=outfile_history, sep="\t", header=True, index=False) + break if isinstance(estimator, KerasGBatchClassifier): scores = {} steps = estimator.prediction_steps @@ -526,6 +536,7 @@ aparser.add_argument("-X", "--infile1", dest="infile1") aparser.add_argument("-y", "--infile2", dest="infile2") aparser.add_argument("-O", "--outfile_result", dest="outfile_result") + aparser.add_argument("-hi", "--outfile_history", dest="outfile_history") aparser.add_argument("-o", "--outfile_object", dest="outfile_object") aparser.add_argument("-l", "--outfile_y_true", dest="outfile_y_true") aparser.add_argument("-p", "--outfile_y_preds", dest="outfile_y_preds") @@ -542,6 +553,7 @@ args.infile1, args.infile2, args.outfile_result, + outfile_history=args.outfile_history, outfile_object=args.outfile_object, outfile_y_true=args.outfile_y_true, outfile_y_preds=args.outfile_y_preds,
--- a/ml_visualization_ex.py Wed Aug 09 12:27:50 2023 +0000 +++ b/ml_visualization_ex.py Mon Oct 02 09:00:28 2023 +0000 @@ -15,6 +15,7 @@ from sklearn.metrics import ( auc, average_precision_score, + confusion_matrix, precision_recall_curve, roc_curve, ) @@ -258,6 +259,30 @@ os.rename(os.path.join(folder, "output.svg"), os.path.join(folder, "output")) +def get_dataframe(file_path, plot_selection, header_name, column_name): + header = "infer" if plot_selection[header_name] else None + column_option = plot_selection[column_name]["selected_column_selector_option"] + if column_option in [ + "by_index_number", + "all_but_by_index_number", + "by_header_name", + "all_but_by_header_name", + ]: + col = plot_selection[column_name]["col1"] + else: + col = None + _, input_df = read_columns( + file_path, + c=col, + c_option=column_option, + return_df=True, + sep="\t", + header=header, + parse_dates=True, + ) + return input_df + + def main( inputs, infile_estimator=None, @@ -271,6 +296,10 @@ targets=None, fasta_path=None, model_config=None, + true_labels=None, + predicted_labels=None, + plot_color=None, + title=None, ): """ Parameter @@ -311,6 +340,18 @@ model_config : str, default is None File path to dataset containing JSON config for neural networks + + true_labels : str, default is None + File path to dataset containing true labels + + predicted_labels : str, default is None + File path to dataset containing true predicted labels + + plot_color : str, default is None + Color of the confusion matrix heatmap + + title : str, default is None + Title of the confusion matrix heatmap """ warnings.simplefilter("ignore") @@ -534,6 +575,36 @@ return 0 + elif plot_type == "classification_confusion_matrix": + plot_selection = params["plotting_selection"] + input_true = get_dataframe( + true_labels, plot_selection, "header_true", "column_selector_options_true" + ) + header_predicted = "infer" if plot_selection["header_predicted"] else None + input_predicted = pd.read_csv( + predicted_labels, sep="\t", parse_dates=True, header=header_predicted + ) + true_classes = input_true.iloc[:, -1].copy() + predicted_classes = input_predicted.iloc[:, -1].copy() + axis_labels = list(set(true_classes)) + c_matrix = confusion_matrix(true_classes, predicted_classes) + fig, ax = plt.subplots(figsize=(7, 7)) + im = plt.imshow(c_matrix, cmap=plot_color) + for i in range(len(c_matrix)): + for j in range(len(c_matrix)): + ax.text(j, i, c_matrix[i, j], ha="center", va="center", color="k") + ax.set_ylabel("True class labels") + ax.set_xlabel("Predicted class labels") + ax.set_title(title) + ax.set_xticks(axis_labels) + ax.set_yticks(axis_labels) + fig.colorbar(im, ax=ax) + fig.tight_layout() + plt.savefig("output.png", dpi=125) + os.rename("output.png", "output") + + return 0 + # save pdf file to disk # fig.write_image("image.pdf", format='pdf') # fig.write_image("image.pdf", format='pdf', width=340*2, height=226*2) @@ -553,6 +624,10 @@ aparser.add_argument("-t", "--targets", dest="targets") aparser.add_argument("-f", "--fasta_path", dest="fasta_path") aparser.add_argument("-c", "--model_config", dest="model_config") + aparser.add_argument("-tl", "--true_labels", dest="true_labels") + aparser.add_argument("-pl", "--predicted_labels", dest="predicted_labels") + aparser.add_argument("-pc", "--plot_color", dest="plot_color") + aparser.add_argument("-pt", "--title", dest="title") args = aparser.parse_args() main( @@ -568,4 +643,8 @@ targets=args.targets, fasta_path=args.fasta_path, model_config=args.model_config, + true_labels=args.true_labels, + predicted_labels=args.predicted_labels, + plot_color=args.plot_color, + title=args.title, )
