Mercurial > repos > bgruening > sklearn_stacking_ensemble_models
comparison ml_visualization_ex.py @ 8:6430b9b00d2f draft
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 9e28f4466084464d38d3f8db2aff07974be4ba69"
| author | bgruening |
|---|---|
| date | Wed, 11 Mar 2020 17:24:17 +0000 |
| parents | 00819b7f2f55 |
| children | b8c92e94ac1d |
comparison legend: equal | deleted | inserted | replaced
| 7:00819b7f2f55 | 8:6430b9b00d2f |
|---|---|
| 11 | 11 |
| 12 from keras.models import model_from_json | 12 from keras.models import model_from_json |
| 13 from keras.utils import plot_model | 13 from keras.utils import plot_model |
| 14 from sklearn.feature_selection.base import SelectorMixin | 14 from sklearn.feature_selection.base import SelectorMixin |
| 15 from sklearn.metrics import precision_recall_curve, average_precision_score | 15 from sklearn.metrics import precision_recall_curve, average_precision_score |
| 16 from sklearn.metrics import roc_curve, auc | 16 from sklearn.metrics import roc_curve, auc, confusion_matrix |
| 17 from sklearn.pipeline import Pipeline | 17 from sklearn.pipeline import Pipeline |
| 18 from galaxy_ml.utils import load_model, read_columns, SafeEval | 18 from galaxy_ml.utils import load_model, read_columns, SafeEval |
| 19 | 19 |
| 20 | 20 |
| 21 safe_eval = SafeEval() | 21 safe_eval = SafeEval() |
| 264 plt.savefig(os.path.join(folder, "output.svg"), format="svg") | 264 plt.savefig(os.path.join(folder, "output.svg"), format="svg") |
| 265 os.rename(os.path.join(folder, "output.svg"), | 265 os.rename(os.path.join(folder, "output.svg"), |
| 266 os.path.join(folder, "output")) | 266 os.path.join(folder, "output")) |
| 267 | 267 |
| 268 | 268 |
def get_dataframe(file_path, plot_selection, header_name, column_name):
    """Read a tab-separated file via ``read_columns`` using the header and
    column-selection settings stored in ``plot_selection``.

    Parameters
    ----------
    file_path : str
        Path to the tab-separated input file.
    plot_selection : dict
        Plotting-selection parameters; indexed by ``header_name`` and
        ``column_name`` below.
    header_name : str
        Key into ``plot_selection``. A truthy value means the header row is
        inferred (``header='infer'``); otherwise ``header=None``.
    column_name : str
        Key into ``plot_selection`` for a dict holding
        ``selected_column_selector_option`` and, for index/header-name
        based options, ``col1``.

    Returns
    -------
    The second element of ``read_columns(..., return_df=True)``. Callers
    use it as a pandas DataFrame (``.iloc``).
    """
    # Header setting is resolved first, matching the original lookup order.
    if plot_selection[header_name]:
        header = 'infer'
    else:
        header = None

    selector = plot_selection[column_name]
    option = selector["selected_column_selector_option"]

    # Only these selector modes carry an explicit column specification.
    explicit_modes = ("by_index_number", "all_but_by_index_number",
                      "by_header_name", "all_but_by_header_name")
    columns = selector["col1"] if option in explicit_modes else None

    _, frame = read_columns(
        file_path,
        c=columns,
        c_option=option,
        return_df=True,
        sep='\t',
        header=header,
        parse_dates=True,
    )
    return frame
| 282 | |
| 283 | |
| 269 def main(inputs, infile_estimator=None, infile1=None, | 284 def main(inputs, infile_estimator=None, infile1=None, |
| 270 infile2=None, outfile_result=None, | 285 infile2=None, outfile_result=None, |
| 271 outfile_object=None, groups=None, | 286 outfile_object=None, groups=None, |
| 272 ref_seq=None, intervals=None, | 287 ref_seq=None, intervals=None, |
| 273 targets=None, fasta_path=None, | 288 targets=None, fasta_path=None, |
| 274 model_config=None): | 289 model_config=None, true_labels=None, |
| 290 predicted_labels=None, plot_color=None, | |
| 291 title=None): | |
| 275 """ | 292 """ |
| 276 Parameter | 293 Parameter |
| 277 --------- | 294 --------- |
| 278 inputs : str | 295 inputs : str |
| 279 File path to galaxy tool parameter | 296 File path to galaxy tool parameter |
| 309 fasta_path : str, default is None | 326 fasta_path : str, default is None |
| 310 File path to dataset containing fasta file | 327 File path to dataset containing fasta file |
| 311 | 328 |
| 312 model_config : str, default is None | 329 model_config : str, default is None |
| 313 File path to dataset containing JSON config for neural networks | 330 File path to dataset containing JSON config for neural networks |
| 331 | |
| 332 true_labels : str, default is None | |
| 333 File path to dataset containing true labels | |
| 334 | |
| 335 predicted_labels : str, default is None | |
| 336 File path to dataset containing true predicted labels | |
| 337 | |
| 338 plot_color : str, default is None | |
| 339 Color of the confusion matrix heatmap | |
| 340 | |
| 341 title : str, default is None | |
| 342 Title of the confusion matrix heatmap | |
| 314 """ | 343 """ |
| 315 warnings.simplefilter('ignore') | 344 warnings.simplefilter('ignore') |
| 316 | 345 |
| 317 with open(inputs, 'r') as param_handler: | 346 with open(inputs, 'r') as param_handler: |
| 318 params = json.load(param_handler) | 347 params = json.load(param_handler) |
| 541 plot_model(model, to_file="output.png") | 570 plot_model(model, to_file="output.png") |
| 542 os.rename('output.png', 'output') | 571 os.rename('output.png', 'output') |
| 543 | 572 |
| 544 return 0 | 573 return 0 |
| 545 | 574 |
| 575 elif plot_type == 'classification_confusion_matrix': | |
| 576 plot_selection = params["plotting_selection"] | |
| 577 input_true = get_dataframe(true_labels, plot_selection, "header_true", "column_selector_options_true") | |
| 578 header_predicted = 'infer' if plot_selection["header_predicted"] else None | |
| 579 input_predicted = pd.read_csv(predicted_labels, sep='\t', parse_dates=True, header=header_predicted) | |
| 580 true_classes = input_true.iloc[:, -1].copy() | |
| 581 predicted_classes = input_predicted.iloc[:, -1].copy() | |
| 582 axis_labels = list(set(true_classes)) | |
| 583 c_matrix = confusion_matrix(true_classes, predicted_classes) | |
| 584 fig, ax = plt.subplots(figsize=(7, 7)) | |
| 585 im = plt.imshow(c_matrix, cmap=plot_color) | |
| 586 for i in range(len(c_matrix)): | |
| 587 for j in range(len(c_matrix)): | |
| 588 ax.text(j, i, c_matrix[i, j], ha="center", va="center", color="k") | |
| 589 ax.set_ylabel('True class labels') | |
| 590 ax.set_xlabel('Predicted class labels') | |
| 591 ax.set_title(title) | |
| 592 ax.set_xticks(axis_labels) | |
| 593 ax.set_yticks(axis_labels) | |
| 594 fig.colorbar(im, ax=ax) | |
| 595 fig.tight_layout() | |
| 596 plt.savefig("output.png", dpi=125) | |
| 597 os.rename('output.png', 'output') | |
| 598 | |
| 599 return 0 | |
| 600 | |
| 546 # save pdf file to disk | 601 # save pdf file to disk |
| 547 # fig.write_image("image.pdf", format='pdf') | 602 # fig.write_image("image.pdf", format='pdf') |
| 548 # fig.write_image("image.pdf", format='pdf', width=340*2, height=226*2) | 603 # fig.write_image("image.pdf", format='pdf', width=340*2, height=226*2) |
| 549 | 604 |
| 550 | 605 |
| 560 aparser.add_argument("-r", "--ref_seq", dest="ref_seq") | 615 aparser.add_argument("-r", "--ref_seq", dest="ref_seq") |
| 561 aparser.add_argument("-b", "--intervals", dest="intervals") | 616 aparser.add_argument("-b", "--intervals", dest="intervals") |
| 562 aparser.add_argument("-t", "--targets", dest="targets") | 617 aparser.add_argument("-t", "--targets", dest="targets") |
| 563 aparser.add_argument("-f", "--fasta_path", dest="fasta_path") | 618 aparser.add_argument("-f", "--fasta_path", dest="fasta_path") |
| 564 aparser.add_argument("-c", "--model_config", dest="model_config") | 619 aparser.add_argument("-c", "--model_config", dest="model_config") |
| 620 aparser.add_argument("-tl", "--true_labels", dest="true_labels") | |
| 621 aparser.add_argument("-pl", "--predicted_labels", dest="predicted_labels") | |
| 622 aparser.add_argument("-pc", "--plot_color", dest="plot_color") | |
| 623 aparser.add_argument("-pt", "--title", dest="title") | |
| 565 args = aparser.parse_args() | 624 args = aparser.parse_args() |
| 566 | 625 |
| 567 main(args.inputs, args.infile_estimator, args.infile1, args.infile2, | 626 main(args.inputs, args.infile_estimator, args.infile1, args.infile2, |
| 568 args.outfile_result, outfile_object=args.outfile_object, | 627 args.outfile_result, outfile_object=args.outfile_object, |
| 569 groups=args.groups, ref_seq=args.ref_seq, intervals=args.intervals, | 628 groups=args.groups, ref_seq=args.ref_seq, intervals=args.intervals, |
| 570 targets=args.targets, fasta_path=args.fasta_path, | 629 targets=args.targets, fasta_path=args.fasta_path, |
| 571 model_config=args.model_config) | 630 model_config=args.model_config, true_labels=args.true_labels, |
| 631 predicted_labels=args.predicted_labels, | |
| 632 plot_color=args.plot_color, | |
| 633 title=args.title) |
