Mercurial > repos > bgruening > plotly_ml_performance_plots
Comparison (side-by-side diff) of plot_ml_performance.py at changeset 3:e73eb091612b (phase: draft; branch: default; tag: tip)
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/plotly_ml_performance_plots commit daa111fcd8391d451aab39110251864fd120edf0
| author | bgruening |
|---|---|
| date | Wed, 07 Aug 2024 10:20:05 +0000 |
| parents | 2cfa4aabda3e |
| children |
Diff legend: comparison · equal · deleted · inserted · replaced
| old: 2:2cfa4aabda3e | new: 3:e73eb091612b |
|---|---|
| 1 import argparse | 1 import argparse |
| 2 | |
| 3 import matplotlib.pyplot as plt | |
| 2 import pandas as pd | 4 import pandas as pd |
| 3 import plotly | 5 import plotly |
| 4 import pickle | |
| 5 import plotly.graph_objs as go | 6 import plotly.graph_objs as go |
| 6 from sklearn.metrics import confusion_matrix, precision_recall_fscore_support, roc_curve, auc | 7 from galaxy_ml.model_persist import load_model_from_h5 |
| 8 from galaxy_ml.utils import clean_params | |
| 9 from sklearn.metrics import ( | |
| 10 auc, | |
| 11 confusion_matrix, | |
| 12 precision_recall_fscore_support, | |
| 13 roc_curve, | |
| 14 ) | |
| 7 from sklearn.preprocessing import label_binarize | 15 from sklearn.preprocessing import label_binarize |
| 8 | 16 |
| 9 | 17 |
| 10 def main(infile_input, infile_output, infile_trained_model): | 18 def main(infile_input, infile_output, infile_trained_model): |
| 11 """ | 19 """ |
| 12 Produce an interactive confusion matrix (heatmap), precision, recall, fscore and auc plots | 20 Produce an interactive confusion matrix (heatmap), precision, recall, fscore and auc plots |
| 13 Args: | 21 Args: |
| 14 infile_input: str, input tabular file with true labels | 22 infile_input: str, input tabular file with true labels |
| 15 infile_output: str, input tabular file with predicted labels | 23 infile_output: str, input tabular file with predicted labels |
| 16 infile_trained_model: str, input trained model file (zip) | 24 infile_trained_model: str, input trained model file (h5mlm) |
| 17 """ | 25 """ |
| 18 | 26 |
| 19 df_input = pd.read_csv(infile_input, sep='\t', parse_dates=True) | 27 df_input = pd.read_csv(infile_input, sep="\t", parse_dates=True) |
| 20 df_output = pd.read_csv(infile_output, sep='\t', parse_dates=True) | 28 df_output = pd.read_csv(infile_output, sep="\t", parse_dates=True) |
| 21 true_labels = df_input.iloc[:, -1].copy() | 29 true_labels = df_input.iloc[:, -1].copy() |
| 22 predicted_labels = df_output.iloc[:, -1].copy() | 30 predicted_labels = df_output.iloc[:, -1].copy() |
| 23 axis_labels = list(set(true_labels)) | 31 axis_labels = list(set(true_labels)) |
| 24 c_matrix = confusion_matrix(true_labels, predicted_labels) | 32 c_matrix = confusion_matrix(true_labels, predicted_labels) |
| 25 data = [ | 33 fig, ax = plt.subplots(figsize=(7, 7)) |
| 26 go.Heatmap( | 34 im = plt.imshow(c_matrix, cmap="viridis") |
| 27 z=c_matrix, | 35 # add number of samples to each cell of confusion matrix plot |
| 28 x=axis_labels, | 36 for i in range(len(c_matrix)): |
| 29 y=axis_labels, | 37 for j in range(len(c_matrix)): |
| 30 colorscale='Portland', | 38 ax.text(j, i, c_matrix[i, j], ha="center", va="center", color="k") |
| 31 ) | 39 ax.set_ylabel("True class labels") |
| 32 ] | 40 ax.set_xlabel("Predicted class labels") |
| 41 ax.set_title("Confusion Matrix between true and predicted class labels") | |
| 42 ax.set_xticks(axis_labels) | |
| 43 ax.set_yticks(axis_labels) | |
| 44 fig.colorbar(im, ax=ax) | |
| 45 fig.tight_layout() | |
| 46 plt.savefig("output_confusion.png", dpi=120) | |
| 33 | 47 |
| 34 layout = go.Layout( | 48 # plot precision, recall and f_score for each class label |
| 35 title='Confusion Matrix between true and predicted class labels', | 49 precision, recall, f_score, _ = precision_recall_fscore_support( |
| 36 xaxis=dict(title='Predicted class labels'), | 50 true_labels, predicted_labels |
| 37 yaxis=dict(title='True class labels') | |
| 38 ) | 51 ) |
| 39 | 52 |
| 40 fig = go.Figure(data=data, layout=layout) | |
| 41 plotly.offline.plot(fig, filename="output_confusion.html", auto_open=False) | |
| 42 | |
| 43 # plot precision, recall and f_score for each class label | |
| 44 precision, recall, f_score, _ = precision_recall_fscore_support(true_labels, predicted_labels) | |
| 45 | |
| 46 trace_precision = go.Scatter( | 53 trace_precision = go.Scatter( |
| 47 x=axis_labels, | 54 x=axis_labels, y=precision, mode="lines+markers", name="Precision" |
| 48 y=precision, | |
| 49 mode='lines+markers', | |
| 50 name='Precision' | |
| 51 ) | 55 ) |
| 52 | 56 |
| 53 trace_recall = go.Scatter( | 57 trace_recall = go.Scatter( |
| 54 x=axis_labels, | 58 x=axis_labels, y=recall, mode="lines+markers", name="Recall" |
| 55 y=recall, | |
| 56 mode='lines+markers', | |
| 57 name='Recall' | |
| 58 ) | 59 ) |
| 59 | 60 |
| 60 trace_fscore = go.Scatter( | 61 trace_fscore = go.Scatter( |
| 61 x=axis_labels, | 62 x=axis_labels, y=f_score, mode="lines+markers", name="F-score" |
| 62 y=f_score, | |
| 63 mode='lines+markers', | |
| 64 name='F-score' | |
| 65 ) | 63 ) |
| 66 | 64 |
| 67 layout_prf = go.Layout( | 65 layout_prf = go.Layout( |
| 68 title='Precision, recall and f-score of true and predicted class labels', | 66 title="Precision, recall and f-score of true and predicted class labels", |
| 69 xaxis=dict(title='Class labels'), | 67 xaxis=dict(title="Class labels"), |
| 70 yaxis=dict(title='Precision, recall and f-score') | 68 yaxis=dict(title="Precision, recall and f-score"), |
| 71 ) | 69 ) |
| 72 | 70 |
| 73 data_prf = [trace_precision, trace_recall, trace_fscore] | 71 data_prf = [trace_precision, trace_recall, trace_fscore] |
| 74 fig_prf = go.Figure(data=data_prf, layout=layout_prf) | 72 fig_prf = go.Figure(data=data_prf, layout=layout_prf) |
| 75 plotly.offline.plot(fig_prf, filename="output_prf.html", auto_open=False) | 73 plotly.offline.plot(fig_prf, filename="output_prf.html", auto_open=False) |
| 76 | 74 |
| 77 # plot roc and auc curves for different classes | 75 # plot roc and auc curves for different classes |
| 78 with open(infile_trained_model, 'rb') as model_file: | 76 classifier_object = load_model_from_h5(infile_trained_model) |
| 79 model = pickle.load(model_file) | 77 model = clean_params(classifier_object) |
| 80 | 78 |
| 81 # remove the last column (label column) | 79 # remove the last column (label column) |
| 82 test_data = df_input.iloc[:, :-1] | 80 test_data = df_input.iloc[:, :-1] |
| 83 model_items = dir(model) | 81 model_items = dir(model) |
| 84 | 82 |
| 85 try: | 83 try: |
| 86 # find the probability estimating method | 84 # find the probability estimating method |
| 87 if 'predict_proba' in model_items: | 85 if "predict_proba" in model_items: |
| 88 y_score = model.predict_proba(test_data) | 86 y_score = model.predict_proba(test_data) |
| 89 elif 'decision_function' in model_items: | 87 elif "decision_function" in model_items: |
| 90 y_score = model.decision_function(test_data) | 88 y_score = model.decision_function(test_data) |
| 91 | 89 |
| 92 true_labels_list = true_labels.tolist() | 90 true_labels_list = true_labels.tolist() |
| 93 one_hot_labels = label_binarize(true_labels_list, classes=axis_labels) | 91 one_hot_labels = label_binarize(true_labels_list, classes=axis_labels) |
| 94 data_roc = list() | 92 data_roc = list() |
| 102 roc_auc[i] = auc(fpr[i], tpr[i]) | 100 roc_auc[i] = auc(fpr[i], tpr[i]) |
| 103 for i in range(len(axis_labels)): | 101 for i in range(len(axis_labels)): |
| 104 trace = go.Scatter( | 102 trace = go.Scatter( |
| 105 x=fpr[i], | 103 x=fpr[i], |
| 106 y=tpr[i], | 104 y=tpr[i], |
| 107 mode='lines+markers', | 105 mode="lines+markers", |
| 108 name='ROC curve of class {0} (AUC = {1:0.2f})'.format(i, roc_auc[i]) | 106 name="ROC curve of class {0} (AUC = {1:0.2f})".format( |
| 107 i, roc_auc[i] | |
| 108 ), | |
| 109 ) | 109 ) |
| 110 data_roc.append(trace) | 110 data_roc.append(trace) |
| 111 else: | 111 else: |
| 112 try: | 112 try: |
| 113 y_score_binary = y_score[:, 1] | 113 y_score_binary = y_score[:, 1] |
| 114 except: | 114 except Exception: |
| 115 y_score_binary = y_score | 115 y_score_binary = y_score |
| 116 fpr, tpr, _ = roc_curve(one_hot_labels, y_score_binary, pos_label=1) | 116 fpr, tpr, _ = roc_curve(one_hot_labels, y_score_binary, pos_label=1) |
| 117 roc_auc = auc(fpr, tpr) | 117 roc_auc = auc(fpr, tpr) |
| 118 trace = go.Scatter( | 118 trace = go.Scatter( |
| 119 x=fpr, | 119 x=fpr, |
| 120 y=tpr, | 120 y=tpr, |
| 121 mode='lines+markers', | 121 mode="lines+markers", |
| 122 name='ROC curve (AUC = {0:0.2f})'.format(roc_auc) | 122 name="ROC curve (AUC = {0:0.2f})".format(roc_auc), |
| 123 ) | 123 ) |
| 124 data_roc.append(trace) | 124 data_roc.append(trace) |
| 125 | 125 |
| 126 trace_diag = go.Scatter( | 126 trace_diag = go.Scatter(x=[0, 1], y=[0, 1], mode="lines", name="Chance") |
| 127 x=[0, 1], | |
| 128 y=[0, 1], | |
| 129 mode='lines', | |
| 130 name='Chance' | |
| 131 ) | |
| 132 data_roc.append(trace_diag) | 127 data_roc.append(trace_diag) |
| 133 layout_roc = go.Layout( | 128 layout_roc = go.Layout( |
| 134 title='Receiver operating characteristics (ROC) and area under curve (AUC)', | 129 title="Receiver operating characteristics (ROC) and area under curve (AUC)", |
| 135 xaxis=dict(title='False positive rate'), | 130 xaxis=dict(title="False positive rate"), |
| 136 yaxis=dict(title='True positive rate') | 131 yaxis=dict(title="True positive rate"), |
| 137 ) | 132 ) |
| 138 | 133 |
| 139 fig_roc = go.Figure(data=data_roc, layout=layout_roc) | 134 fig_roc = go.Figure(data=data_roc, layout=layout_roc) |
| 140 plotly.offline.plot(fig_roc, filename="output_roc.html", auto_open=False) | 135 plotly.offline.plot(fig_roc, filename="output_roc.html", auto_open=False) |
| 141 | 136 |
| 142 except Exception as exp: | 137 except Exception as exp: |
| 143 print("Plotting the ROC-AUC graph failed. This exception was raised: {}".format(exp)) | 138 print( |
| 139 "Plotting the ROC-AUC graph failed. This exception was raised: {}".format( | |
| 140 exp | |
| 141 ) | |
| 142 ) | |
| 144 | 143 |
| 145 | 144 |
| 146 if __name__ == "__main__": | 145 if __name__ == "__main__": |
| 147 aparser = argparse.ArgumentParser() | 146 aparser = argparse.ArgumentParser() |
| 148 aparser.add_argument("-i", "--input", dest="infile_input", required=True) | 147 aparser.add_argument("-i", "--input", dest="infile_input", required=True) |
