Mercurial > repos > bgruening > sklearn_label_encoder
comparison ml_visualization_ex.py @ 0:03155260beb3 draft
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit ea12f973df4b97a2691d9e4ce6bf6fae59d57717"
| author | bgruening |
|---|---|
| date | Fri, 30 Apr 2021 23:36:38 +0000 |
| parents | |
| children | b008b609205e |
comparison
equal
deleted
inserted
replaced
| -1:000000000000 | 0:03155260beb3 |
|---|---|
| 1 import argparse | |
| 2 import json | |
| 3 import os | |
| 4 import warnings | |
| 5 | |
| 6 import matplotlib | |
| 7 import matplotlib.pyplot as plt | |
| 8 import numpy as np | |
| 9 import pandas as pd | |
| 10 import plotly | |
| 11 import plotly.graph_objs as go | |
| 12 from galaxy_ml.utils import load_model, read_columns, SafeEval | |
| 13 from keras.models import model_from_json | |
| 14 from keras.utils import plot_model | |
| 15 from sklearn.feature_selection.base import SelectorMixin | |
| 16 from sklearn.metrics import (auc, average_precision_score, confusion_matrix, | |
| 17 precision_recall_curve, roc_curve) | |
| 18 from sklearn.pipeline import Pipeline | |
| 19 | |
# galaxy_ml evaluator; main() uses it to turn the user-supplied RFECV
# "steps" string into a Python object.
safe_eval = SafeEval()

# Plotly's default qualitative color sequence, kept as (hex code, name)
# pairs for readability. Plotting code picks entries cyclically by column
# index (``index % len(default_colors)``).
_PLOTLY_DEFAULT_COLORS = (
    ("#1f77b4", "muted blue"),
    ("#ff7f0e", "safety orange"),
    ("#2ca02c", "cooked asparagus green"),
    ("#d62728", "brick red"),
    ("#9467bd", "muted purple"),
    ("#8c564b", "chestnut brown"),
    ("#e377c2", "raspberry yogurt pink"),
    ("#7f7f7f", "middle gray"),
    ("#bcbd22", "curry yellow-green"),
    ("#17becf", "blue-teal"),
)
default_colors = [hex_code for hex_code, _name in _PLOTLY_DEFAULT_COLORS]


def visualize_pr_curve_plotly(df1, df2, pos_label, title=None):
    """Render one Precision-Recall curve per column pair as a plotly HTML page.

    Columns of ``df1`` and ``df2`` are paired by position: column ``i`` of
    ``df1`` holds true labels, column ``i`` of ``df2`` the matching scores.
    Each pair becomes a line trace named ``"<i> (area = <AP>)"`` where AP is
    the average precision.

    Parameters
    ----------
    df1 : pandas.DataFrame
        True labels, one column per curve.
    df2 : pandas.DataFrame
        Predicted scores, aligned column-by-column with ``df1``.
    pos_label : str or None
        Positive-class label forwarded to ``precision_recall_curve``;
        ``average_precision_score`` receives ``pos_label or 1``.
    title : str, optional
        Plot title; "Precision-Recall Curve" when falsy.

    The page is written to ``output.html`` in the current directory and then
    renamed to ``output`` so the tool's ``from_work_dir`` output can find it.
    """
    palette_size = len(default_colors)
    traces = []
    for column in range(df1.shape[1]):
        truth = df1.iloc[:, column].values
        scores = df2.iloc[:, column].values

        precision, recall, _thresholds = precision_recall_curve(
            truth, scores, pos_label=pos_label
        )
        avg_precision = average_precision_score(
            truth, scores, pos_label=pos_label or 1
        )

        traces.append(
            go.Scatter(
                x=recall,
                y=precision,
                mode="lines",
                marker={"color": default_colors[column % palette_size]},
                name="%s (area = %.3f)" % (column, avg_precision),
            )
        )

    axis_style = {"linecolor": "lightslategray", "linewidth": 1}
    layout = go.Layout(
        xaxis={"title": "Recall", **axis_style},
        yaxis={"title": "Precision", **axis_style},
        title={
            "text": title or "Precision-Recall Curve",
            "x": 0.5,
            "y": 0.92,
            "xanchor": "center",
            "yanchor": "top",
        },
        font={"family": "sans-serif", "size": 11},
        # fully transparent plot background
        plot_bgcolor="rgba(255,255,255,0)",
    )

    figure = go.Figure(data=traces, layout=layout)
    plotly.offline.plot(figure, filename="output.html", auto_open=False)
    # fixed output name, to be discovered by `from_work_dir`
    os.rename("output.html", "output")


def visualize_pr_curve_matplotlib(df1, df2, pos_label, title=None):
    """Draw one Precision-Recall curve per column pair and save it as SVG.

    Column ``i`` of ``df1`` (true labels) is paired with column ``i`` of
    ``df2`` (scores). Every curve is drawn as a thin, translucent black
    step line. Each line carries a label with its column index and average
    precision, but no legend is drawn.

    Parameters
    ----------
    df1 : pandas.DataFrame
        True labels, one column per curve.
    df2 : pandas.DataFrame
        Predicted scores, aligned column-by-column with ``df1``.
    pos_label : str or None
        Positive-class label for ``precision_recall_curve``;
        ``average_precision_score`` receives ``pos_label or 1``.
    title : str, optional
        Plot title; "Precision-Recall Curve" when falsy.

    The image is saved as ``output.svg`` in the current working directory
    and then renamed to ``output``.
    """
    backend = matplotlib.get_backend()
    if "inline" not in backend:
        # Use the non-interactive SVG backend unless an "inline" backend
        # (presumably a notebook) is active.
        matplotlib.use("SVG")
    try:
        plt.style.use("seaborn-colorblind")
    except OSError:
        # Matplotlib 3.8 removed the old seaborn style names; since 3.6 the
        # same style ships as "seaborn-v0_8-colorblind".
        plt.style.use("seaborn-v0_8-colorblind")
    fig = plt.figure()

    for idx in range(df1.shape[1]):
        y_true = df1.iloc[:, idx].values
        y_score = df2.iloc[:, idx].values

        precision, recall, _ = precision_recall_curve(
            y_true, y_score, pos_label=pos_label
        )
        ap = average_precision_score(y_true, y_score, pos_label=pos_label or 1)

        # No fmt string: the old "r-" conflicted with color="black" (the
        # keyword always won) and only triggered a matplotlib warning.
        plt.step(
            recall,
            precision,
            color="black",
            alpha=0.3,
            lw=1,
            where="post",
            label="%s (area = %.3f)" % (idx, ap),
        )

    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.title(title or "Precision-Recall Curve")
    folder = os.getcwd()
    plt.savefig(os.path.join(folder, "output.svg"), format="svg")
    plt.close(fig)
    os.rename(os.path.join(folder, "output.svg"), os.path.join(folder, "output"))


def visualize_roc_curve_plotly(df1, df2, pos_label, drop_intermediate=True, title=None):
    """Render one ROC curve per column pair as a plotly HTML page.

    Columns of ``df1`` and ``df2`` are paired by position: column ``i`` of
    ``df1`` holds true labels, column ``i`` of ``df2`` the matching scores.
    Each pair becomes a line trace named ``"<i> (area = <AUC>)"``.

    Parameters
    ----------
    df1 : pandas.DataFrame
        True labels, one column per curve.
    df2 : pandas.DataFrame
        Predicted scores, aligned column-by-column with ``df1``.
    pos_label : str or None
        Positive-class label forwarded to ``roc_curve``.
    drop_intermediate : bool, default True
        Forwarded to ``roc_curve`` (drops some suboptimal thresholds).
    title : str, optional
        Plot title; "Receiver Operating Characteristic (ROC) Curve" when falsy.

    The page is written to ``output.html`` in the current directory and then
    renamed to ``output`` so the tool's ``from_work_dir`` output can find it.
    """
    palette_size = len(default_colors)
    traces = []
    for column in range(df1.shape[1]):
        truth = df1.iloc[:, column].values
        scores = df2.iloc[:, column].values

        fpr, tpr, _thresholds = roc_curve(
            truth, scores, pos_label=pos_label, drop_intermediate=drop_intermediate
        )
        area = auc(fpr, tpr)

        traces.append(
            go.Scatter(
                x=fpr,
                y=tpr,
                mode="lines",
                marker={"color": default_colors[column % palette_size]},
                name="%s (area = %.3f)" % (column, area),
            )
        )

    axis_style = {"linecolor": "lightslategray", "linewidth": 1}
    layout = go.Layout(
        xaxis={"title": "False Positive Rate", **axis_style},
        yaxis={"title": "True Positive Rate", **axis_style},
        title={
            "text": title or "Receiver Operating Characteristic (ROC) Curve",
            "x": 0.5,
            "y": 0.92,
            "xanchor": "center",
            "yanchor": "top",
        },
        font={"family": "sans-serif", "size": 11},
        # fully transparent plot background
        plot_bgcolor="rgba(255,255,255,0)",
    )

    figure = go.Figure(data=traces, layout=layout)
    plotly.offline.plot(figure, filename="output.html", auto_open=False)
    # fixed output name, to be discovered by `from_work_dir`
    os.rename("output.html", "output")


def visualize_roc_curve_matplotlib(
    df1, df2, pos_label, drop_intermediate=True, title=None
):
    """Draw one ROC curve per column pair and save it as SVG.

    Column ``i`` of ``df1`` (true labels) is paired with column ``i`` of
    ``df2`` (scores). Every curve is drawn as a thin, translucent black
    line. Each line carries a label with its column index and AUC, but no
    legend is drawn.

    Curves are drawn with straight segments between the ``roc_curve`` points.
    This is the same linear interpolation that ``auc`` (trapezoidal rule)
    integrates and that the plotly version draws. The previous
    ``step(where="post")`` rendering flattened diagonal segments caused by
    tied scores.

    Parameters
    ----------
    df1 : pandas.DataFrame
        True labels, one column per curve.
    df2 : pandas.DataFrame
        Predicted scores, aligned column-by-column with ``df1``.
    pos_label : str or None
        Positive-class label forwarded to ``roc_curve``.
    drop_intermediate : bool, default True
        Forwarded to ``roc_curve``.
    title : str, optional
        Plot title; "Receiver Operating Characteristic (ROC) Curve" when falsy.

    The image is saved as ``output.svg`` in the current working directory
    and then renamed to ``output``.
    """
    backend = matplotlib.get_backend()
    if "inline" not in backend:
        # Use the non-interactive SVG backend unless an "inline" backend
        # (presumably a notebook) is active.
        matplotlib.use("SVG")
    try:
        plt.style.use("seaborn-colorblind")
    except OSError:
        # Matplotlib 3.8 removed the old seaborn style names; since 3.6 the
        # same style ships as "seaborn-v0_8-colorblind".
        plt.style.use("seaborn-v0_8-colorblind")
    fig = plt.figure()

    for idx in range(df1.shape[1]):
        y_true = df1.iloc[:, idx].values
        y_score = df2.iloc[:, idx].values

        fpr, tpr, _ = roc_curve(
            y_true, y_score, pos_label=pos_label, drop_intermediate=drop_intermediate
        )
        roc_auc = auc(fpr, tpr)

        plt.plot(
            fpr,
            tpr,
            color="black",
            alpha=0.3,
            lw=1,
            label="%s (area = %.3f)" % (idx, roc_auc),
        )

    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title(title or "Receiver Operating Characteristic (ROC) Curve")
    folder = os.getcwd()
    plt.savefig(os.path.join(folder, "output.svg"), format="svg")
    plt.close(fig)
    os.rename(os.path.join(folder, "output.svg"), os.path.join(folder, "output"))


def get_dataframe(file_path, plot_selection, header_name, column_name):
    """Load the selected columns of a tab-separated file as a DataFrame.

    Parameters
    ----------
    file_path : str
        Path of the tab-separated input file.
    plot_selection : dict
        The ``plotting_selection`` section of the tool parameters.
    header_name : str
        Key in ``plot_selection`` whose truthiness decides whether the file
        is read with ``header="infer"`` (truthy) or ``header=None``.
    column_name : str
        Key in ``plot_selection`` holding the column-selector options. Its
        ``"col1"`` entry is only read for the by-index / by-header-name
        options.

    Returns
    -------
    pandas.DataFrame
        The DataFrame half of ``read_columns(..., return_df=True)``.
    """
    header = None
    if plot_selection[header_name]:
        header = "infer"

    selector = plot_selection[column_name]
    column_option = selector["selected_column_selector_option"]
    uses_col1 = column_option in (
        "by_index_number",
        "all_but_by_index_number",
        "by_header_name",
        "all_but_by_header_name",
    )
    col = selector["col1"] if uses_col1 else None

    _features, frame = read_columns(
        file_path,
        c=col,
        c_option=column_option,
        return_df=True,
        sep="\t",
        header=header,
        parse_dates=True,
    )
    return frame


def _write_plotly_output(fig):
    """Render ``fig`` to ``output.html`` and rename it to ``output``.

    The fixed name lets the tool's ``from_work_dir`` output find the file.
    """
    plotly.offline.plot(fig, filename="output.html", auto_open=False)
    os.rename("output.html", "output")


def _plot_feature_importances(selection, infile_estimator, infile1, title):
    """Bar-plot the final estimator's ``coef_`` or ``feature_importances_``."""
    with open(infile_estimator, "rb") as estimator_handler:
        estimator = load_model(estimator_handler)

    column_selector = selection["column_selector_options"]
    column_option = column_selector["selected_column_selector_option"]
    if column_option in (
        "by_index_number",
        "all_but_by_index_number",
        "by_header_name",
        "all_but_by_header_name",
    ):
        c = column_selector["col1"]
    else:
        c = None

    _, input_df = read_columns(
        infile1,
        c=c,
        c_option=column_option,
        return_df=True,
        sep="\t",
        header="infer",
        parse_dates=True,
    )
    feature_names = input_df.columns.values

    if isinstance(estimator, Pipeline):
        # Drop names removed by feature-selection steps so the remaining
        # names line up with the final estimator's coefficients.
        for step in estimator.steps[:-1]:
            if isinstance(step[-1], SelectorMixin):
                feature_names = feature_names[step[-1].get_support()]
        estimator = estimator.steps[-1][-1]

    if hasattr(estimator, "coef_"):
        coefs = estimator.coef_
    else:
        coefs = getattr(estimator, "feature_importances_", None)
        if coefs is None:
            raise RuntimeError(
                "The classifier does not expose "
                '"coef_" or "feature_importances_" '
                "attributes"
            )

    coefs = np.asarray(coefs)
    # Binary linear models store coef_ with shape (1, n_features); flatten
    # that single row so masking and sorting below index the feature axis.
    if coefs.ndim == 2 and coefs.shape[0] == 1:
        coefs = coefs[0]

    threshold = selection["threshold"]
    if threshold is not None:
        # keep features whose coefficient magnitude exceeds the threshold
        mask = (coefs > threshold) | (coefs < -threshold)
        coefs = coefs[mask]
        feature_names = feature_names[mask]

    # descending order of the (signed) coefficient value
    indices = np.argsort(coefs)[::-1]

    trace = go.Bar(x=feature_names[indices], y=coefs[indices])
    layout = go.Layout(title=title or "Feature Importances")
    _write_plotly_output(go.Figure(data=[trace], layout=layout))


def _plot_pr_or_roc_curve(selection, plot_type, plot_format, infile1, infile2, title):
    """Dispatch to the PR / ROC visualizers after optional column filtering."""
    df1 = pd.read_csv(infile1, sep="\t", header="infer")
    df2 = pd.read_csv(infile2, sep="\t", header="infer").astype(np.float32)

    minimum = selection["report_minimum_n_positives"]
    # Filter out columns whose number of positives is below the threshold.
    # The column sum counts positives only for 0/1 labels.
    if minimum:
        keep = (df1.sum(axis=0) >= minimum).values
        # Select by position: the visualizers pair df1/df2 columns by index,
        # and the two files' headers need not match (label-based .loc with
        # df1's mask raised on df2 in that case).
        df1 = df1.iloc[:, keep]
        df2 = df2.iloc[:, keep]

    pos_label = selection["pos_label"].strip() or None

    if plot_type == "pr_curve":
        if plot_format == "plotly_html":
            visualize_pr_curve_plotly(df1, df2, pos_label, title=title)
        else:
            visualize_pr_curve_matplotlib(df1, df2, pos_label, title)
        return

    # 'roc_curve'
    drop_intermediate = selection["drop_intermediate"]
    if plot_format == "plotly_html":
        visualize = visualize_roc_curve_plotly
    else:
        visualize = visualize_roc_curve_matplotlib
    visualize(
        df1,
        df2,
        pos_label,
        drop_intermediate=drop_intermediate,
        title=title,
    )


def _plot_rfecv_gridscores(selection, infile1, title):
    """Line-plot RFECV grid scores (first column of ``infile1``)."""
    input_df = pd.read_csv(infile1, sep="\t", header="infer")
    scores = input_df.iloc[:, 0]
    steps = safe_eval(selection["steps"].strip())

    data = go.Scatter(
        x=list(range(len(scores))),
        y=scores,
        text=[str(step) for step in steps] if steps else None,
        mode="lines",
    )
    layout = go.Layout(
        xaxis=dict(title="Number of features selected"),
        yaxis=dict(title="Cross validation score"),
        title=dict(text=title or None, x=0.5, y=0.92, xanchor="center", yanchor="top"),
        font=dict(family="sans-serif", size=11),
        # fully transparent plot background
        plot_bgcolor="rgba(255,255,255,0)",
    )
    _write_plotly_output(go.Figure(data=[data], layout=layout))


def _plot_learning_curve(selection, infile1, title):
    """Plot mean train/test scores against absolute training-set size."""
    input_df = pd.read_csv(infile1, sep="\t", header="infer")
    plot_std_err = selection["plot_std_err"]
    data1 = go.Scatter(
        x=input_df["train_sizes_abs"],
        y=input_df["mean_train_scores"],
        error_y=dict(array=input_df["std_train_scores"]) if plot_std_err else None,
        mode="lines",
        name="Train Scores",
    )
    data2 = go.Scatter(
        x=input_df["train_sizes_abs"],
        y=input_df["mean_test_scores"],
        error_y=dict(array=input_df["std_test_scores"]) if plot_std_err else None,
        mode="lines",
        name="Test Scores",
    )
    layout = dict(
        xaxis=dict(title="No. of samples"),
        yaxis=dict(title="Performance Score"),
        # modify these configurations to customize the image
        title=dict(
            text=title or "Learning Curve",
            x=0.5,
            y=0.92,
            xanchor="center",
            yanchor="top",
        ),
        font=dict(family="sans-serif", size=11),
        # fully transparent plot background
        plot_bgcolor="rgba(255,255,255,0)",
    )
    _write_plotly_output(go.Figure(data=[data1, data2], layout=layout))


def _plot_keras_model(model_config):
    """Render a Keras model architecture (JSON config file) to a PNG."""
    with open(model_config, "r") as f:
        model_str = f.read()
    model = model_from_json(model_str)
    plot_model(model, to_file="output.png")
    os.rename("output.png", "output")


def _plot_confusion_matrix(selection, true_labels, predicted_labels, plot_color, title):
    """Draw an annotated confusion-matrix heatmap and save it as PNG."""
    input_true = get_dataframe(
        true_labels, selection, "header_true", "column_selector_options_true"
    )
    header_predicted = "infer" if selection["header_predicted"] else None
    input_predicted = pd.read_csv(
        predicted_labels, sep="\t", parse_dates=True, header=header_predicted
    )
    true_classes = input_true.iloc[:, -1].copy()
    predicted_classes = input_predicted.iloc[:, -1].copy()

    # Fix the class order once (sorted union of both label sets) and use it
    # for both the matrix and the tick labels, so rows/columns and labels
    # always agree. Class values are tick *labels*, not tick positions.
    axis_labels = np.union1d(true_classes.values, predicted_classes.values)
    c_matrix = confusion_matrix(true_classes, predicted_classes, labels=axis_labels)

    fig, ax = plt.subplots(figsize=(7, 7))
    im = ax.imshow(c_matrix, cmap=plot_color)
    n_rows, n_cols = c_matrix.shape
    for i in range(n_rows):
        for j in range(n_cols):
            ax.text(j, i, c_matrix[i, j], ha="center", va="center", color="k")
    ax.set_ylabel("True class labels")
    ax.set_xlabel("Predicted class labels")
    ax.set_title(title)
    tick_positions = np.arange(len(axis_labels))
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(axis_labels)
    ax.set_yticks(tick_positions)
    ax.set_yticklabels(axis_labels)
    fig.colorbar(im, ax=ax)
    fig.tight_layout()
    plt.savefig("output.png", dpi=125)
    plt.close(fig)
    os.rename("output.png", "output")


def main(
    inputs,
    infile_estimator=None,
    infile1=None,
    infile2=None,
    outfile_result=None,
    outfile_object=None,
    groups=None,
    ref_seq=None,
    intervals=None,
    targets=None,
    fasta_path=None,
    model_config=None,
    true_labels=None,
    predicted_labels=None,
    plot_color=None,
    title=None,
):
    """Produce one visualization, selected by the tool's parameter file.

    The plot is written to a file named ``output`` in the current directory.

    Parameters
    ----------
    inputs : str
        File path to the Galaxy tool parameters (JSON). Its
        ``plotting_selection`` section selects ``plot_type`` / ``plot_format``
        and carries the per-plot options.

    infile_estimator : str, default is None
        File path to the estimator (feature_importances).

    infile1 : str, default is None
        File path to the dataset containing features or true labels.

    infile2 : str, default is None
        File path to the dataset containing target values or predicted
        probabilities.

    outfile_result, outfile_object, groups, ref_seq, intervals, targets, \
fasta_path : str, default is None
        Accepted for interface compatibility; not used by this function.

    model_config : str, default is None
        File path to the JSON config of a Keras model (keras_plot_model).

    true_labels : str, default is None
        File path to the dataset containing true labels.

    predicted_labels : str, default is None
        File path to the dataset containing predicted labels.

    plot_color : str, default is None
        Matplotlib colormap of the confusion matrix heatmap.

    title : str, default is None
        NOTE: always replaced by ``plotting_selection["title"]`` from the
        parameter file; the value passed here is not used.

    Returns
    -------
    int or None
        0 after a recognised plot type has been rendered, None otherwise.
    """
    warnings.simplefilter("ignore")

    with open(inputs, "r") as param_handler:
        params = json.load(param_handler)

    selection = params["plotting_selection"]
    title = selection["title"].strip()
    plot_type = selection["plot_type"]
    plot_format = selection["plot_format"]

    if plot_type == "feature_importances":
        _plot_feature_importances(selection, infile_estimator, infile1, title)
    elif plot_type in ("pr_curve", "roc_curve"):
        _plot_pr_or_roc_curve(
            selection, plot_type, plot_format, infile1, infile2, title
        )
    elif plot_type == "rfecv_gridscores":
        _plot_rfecv_gridscores(selection, infile1, title)
    elif plot_type == "learning_curve":
        _plot_learning_curve(selection, infile1, title)
    elif plot_type == "keras_plot_model":
        _plot_keras_model(model_config)
    elif plot_type == "classification_confusion_matrix":
        _plot_confusion_matrix(
            selection, true_labels, predicted_labels, plot_color, title
        )
    else:
        # unrecognised plot types produce no output, as before
        return None
    return 0


if __name__ == "__main__":
    # (short flag, long flag, dest) for every optional CLI argument, in the
    # original declaration order; each dest is a keyword parameter of main().
    optional_flags = (
        ("-e", "--estimator", "infile_estimator"),
        ("-X", "--infile1", "infile1"),
        ("-y", "--infile2", "infile2"),
        ("-O", "--outfile_result", "outfile_result"),
        ("-o", "--outfile_object", "outfile_object"),
        ("-g", "--groups", "groups"),
        ("-r", "--ref_seq", "ref_seq"),
        ("-b", "--intervals", "intervals"),
        ("-t", "--targets", "targets"),
        ("-f", "--fasta_path", "fasta_path"),
        ("-c", "--model_config", "model_config"),
        ("-tl", "--true_labels", "true_labels"),
        ("-pl", "--predicted_labels", "predicted_labels"),
        ("-pc", "--plot_color", "plot_color"),
        ("-pt", "--title", "title"),
    )

    cli = argparse.ArgumentParser()
    cli.add_argument("-i", "--inputs", dest="inputs", required=True)
    for short_flag, long_flag, dest in optional_flags:
        cli.add_argument(short_flag, long_flag, dest=dest)
    parsed = cli.parse_args()

    # every dest matches a main() parameter name, so forward them by keyword
    main(**vars(parsed))
