view flexynesis_plot.xml @ 2:3c5d82bf6e8a draft

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/flexynesis commit 1afbaf45449e25238935e222f983da62392c067a
author bgruening
date Fri, 04 Jul 2025 14:57:52 +0000
parents bb91bf19eb40
children 52b6f2ac38c7
line wrap: on
line source

<tool id="flexynesis_plot" name="Flexynesis plot" version="@TOOL_VERSION@+galaxy@VERSION_SUFFIX@" profile="@PROFILE@">
    <description>tool for visualizing flexynesis results</description>
    <!-- macros.xml is not part of this file; the expand elements and the at-sign delimited tokens used below are presumably defined there. -->
    <macros>
        <import>macros.xml</import>
    </macros>
    <expand macro="requirements"/>
    <!-- flexynesis_plot.py has to be available to the job: the generated config script appends the tool directory to sys.path and imports its helper functions from the flexynesis_plot module. -->
    <required_files>
        <include path="flexynesis_plot.py" />
    </required_files>
    <!--
        Job command. Steps run in order and are chained so that a failure stops the job:
          1. The CHECK_NON_COMMERCIAL_USE token is substituted first (presumably defined in macros.xml, not shown here).
          2. The inputs/ and plots/ working directories are created.
          3. The labels dataset is symlinked into inputs/ as <element_identifier>.<ext>;
             for the "dimred" plot type the embeddings dataset is symlinked the same way.
          4. The rendered config script is printed to stdout with cat and then executed with python.
        NOTE(review): element_identifier is interpolated as-is into a single-quoted shell path;
        confirm that dataset names containing quotes or slashes cannot reach this command.
    -->
    <command detect_errors="exit_code"><![CDATA[
        @CHECK_NON_COMMERCIAL_USE@
        mkdir -p inputs/ plots/ &&
        ln -s '$plot_conditional.labels' 'inputs/$plot_conditional.labels.element_identifier.$plot_conditional.labels.ext' &&
        #if $plot_conditional.plot_type == "dimred":
            ln -s '$plot_conditional.embeddings' 'inputs/$plot_conditional.embeddings.element_identifier.$plot_conditional.embeddings.ext' &&
        #end if
        cat '$flexynesis_plot_config' &&
        python '$flexynesis_plot_config'
    ]]></command>
    <configfiles>
        <configfile name="flexynesis_plot_config"><![CDATA[
import sys
sys.path.append('$__tool_directory__/')

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from flexynesis import (
    get_important_features,
    plot_dim_reduced,
    plot_hazard_ratios,
    plot_kaplan_meier_curves,
    plot_pr_curves,
    plot_roc_curves,
    plot_scatter
)
from scipy.stats import kruskal, mannwhitneyu
from flexynesis_plot import (
    plot_label_concordance_heatmap,
    plot_boxplot,
    detect_color_type,
    load_labels,
    load_embeddings,
    match_samples_to_embeddings
)
#if $plot_conditional.plot_type == "dimred":
@PLOT_COMMON_CONFIG@
## Dimensionality-reduction plot.
## label_data is not assigned in this branch; it is presumably provided by the
## PLOT_COMMON_CONFIG token above (defined outside this file).
embeddings, sample_names = load_embeddings('inputs/$plot_conditional.embeddings.element_identifier.$plot_conditional.embeddings.ext')
matched_labels = match_samples_to_embeddings(sample_names, label_data)

## The data_column parameter is 1-based, the pandas column index is 0-based.
label = matched_labels.columns[$plot_conditional.label-1]
color_type = detect_color_type(matched_labels[label])

fig = plot_dim_reduced(
    matrix=embeddings,
    labels=matched_labels[label],
    method='$plot_conditional.method',
    color_type=color_type
)
## The file name is built with an f-string so that the selected column name is
## really interpolated (a plain string would keep the literal text "{label}").
## Path-hostile characters are replaced the same way the box plot branch does it.
safe_label = str(label).replace('/', '_').replace('\\', '_').replace(' ', '_').replace(':', '_')
output_path = f"plots/{safe_label}_${plot_conditional.method}.${plot_conditional.format}"
fig.save(output_path, dpi=$plot_conditional.dpi, bbox_inches='tight')
#else if $plot_conditional.plot_type == "scatter":
@PLOT_COMMON_CONFIG@

## Scatter plot of known versus predicted values.
## label_data is presumably provided by the PLOT_COMMON_CONFIG token above.
## Column parameters are 1-based, hence the -1 when indexing the columns.
column_names = label_data.columns
observed = pd.to_numeric(label_data[column_names[$plot_conditional.true_label-1]], errors='coerce')
estimated = pd.to_numeric(label_data[column_names[$plot_conditional.predicted_label-1]], errors='coerce')

## Non-numeric entries were coerced to NaN; refuse to plot when either column
## has no numeric value left at all.
observed_usable = not observed.isna().all()
estimated_usable = not estimated.isna().all()
if not (observed_usable and estimated_usable):
    raise ValueError("No valid numeric values found for known or predicted labels")

fig = plot_scatter(observed, estimated)

## The output name uses the selected column numbers, not the column names.
fig.save(
    "plots/${plot_conditional.true_label}_${plot_conditional.predicted_label}_scatter.${plot_conditional.format}",
    dpi=$plot_conditional.dpi,
    bbox_inches='tight'
)

#else if $plot_conditional.plot_type == "concordance_heatmap":
@PLOT_COMMON_CONFIG@

## Label concordance heatmap.
## label_data is presumably provided by the PLOT_COMMON_CONFIG token above.
## Column parameters are 1-based, hence the -1 when indexing the columns.
true_label = label_data.columns[$plot_conditional.true_label-1]
predicted_label = label_data.columns[$plot_conditional.predicted_label-1]

true_values = label_data[true_label].tolist()
predicted_values = label_data[predicted_label].tolist()
fig = plot_label_concordance_heatmap(true_values, predicted_values)

## Build the file name with an f-string so the column names are really
## interpolated, and replace characters that are unsafe in a file name.
safe_true, safe_predicted = [
    str(name).replace('/', '_').replace('\\', '_').replace(' ', '_').replace(':', '_')
    for name in (true_label, predicted_label)
]
output_path = f"plots/{safe_true}_{safe_predicted}_concordance_heatmap.${plot_conditional.format}"
## Write the image first and release the figure afterwards.
fig.savefig(output_path, dpi=$plot_conditional.dpi, bbox_inches='tight')
plt.close(fig)

#else if $plot_conditional.plot_type == "pr_curve":
@PR_ROC_CONFIG@
## Precision-recall curves.
## y_true_np and y_probs_np are not assigned in this branch; they are presumably
## prepared by the PR_ROC_CONFIG token above (defined outside this file).

fig = plot_pr_curves(y_true_np, y_probs_np)

## Fixed output name; only the extension depends on the chosen format.
output_path = "plots/pr_curves.${plot_conditional.format}"
fig.save(output_path, dpi=$plot_conditional.dpi, bbox_inches='tight')

#else if $plot_conditional.plot_type == "roc_curve":
@PR_ROC_CONFIG@
## ROC curves, built from the same two arrays as the precision-recall branch
## (presumably prepared by the PR_ROC_CONFIG token above).
fig = plot_roc_curves(y_true_np, y_probs_np)
## Fixed output name; only the extension depends on the chosen format.
output_path = "plots/roc_curves.${plot_conditional.format}"
fig.save(output_path, dpi=$plot_conditional.dpi, bbox_inches='tight')

#else if $plot_conditional.plot_type == "box_plot":
@PR_ROC_BOX_CONFIG@
## Box plots: one image per class, written to plots/ and collected as a list.
## label_data is presumably provided by the PR_ROC_BOX_CONFIG token above.
# Remove rows with missing data
clean_data = label_data.dropna(subset=['known_label', 'probability'])

if clean_data.empty:
    raise ValueError("    No valid data after cleaning")

# Get unique classes
classes = clean_data['class_label'].unique()

for class_label in classes:
    print(f"    Generating box plot for class: {class_label}")

    # Filter for current class
    class_data = clean_data[clean_data['class_label'] == class_label]

    # Create the box plot
    fig = plot_boxplot(
        categorical_x=class_data['known_label'],
        numerical_y=class_data['probability'],
        title_x='True Label',
        title_y=f'Predicted Probability ({class_label})',
    )

    # Build a file-system safe name from the class label
    safe_class_name = str(class_label).replace('/', '_').replace('\\', '_').replace(' ', '_').replace(':', '_')
    output_path = f"plots/box_plot_{safe_class_name}.${plot_conditional.format}"

    # Save the plot first, then release the figure before the next iteration
    fig.savefig(output_path, dpi=$plot_conditional.dpi, bbox_inches='tight')
    plt.close(fig)
#end if
        ]]></configfile>
    </configfiles>
    <!--
        Tool form. Everything except the non-commercial-use switch lives inside the
        plot_conditional block; the chosen plot_type selects which branch of the config
        script is rendered. The plots_common_param and plots_common_input macros come from
        macros.xml (not shown); judging by the data_ref attributes and the tests below they
        presumably supply the labels, format and dpi parameters.
    -->
    <inputs>
        <expand macro="commercial_use_param"/>
        <conditional name="plot_conditional">
            <param name="plot_type" type="select" label="Flexynesis plot">
                <option value="dimred">Dimensionality reduction</option>
                <option value="scatter">Scatter plot of known vs predicted labels</option>
                <option value="concordance_heatmap">Label concordance heatmap</option>
                <option value="pr_curve">Precision-recall curves</option>
                <option value="roc_curve">ROC curves</option>
                <option value="box_plot">Box plot</option>
            </param>
            <!-- dimred: additionally needs the embeddings table, the labels column used for colouring and the projection method. -->
            <when value="dimred">
                <expand macro="plots_common_param">
                    <expand macro="plots_common_input"/>
                    <param argument="--embeddings" type="data" format="tabular" label="Embeddings" help="Generated by flexynesis"/>
                    <param argument="--label" type="data_column" data_ref="labels" label="Column in the labels file to use for coloring the points in the plot"/>
                    <param name="method" type="select" label="Transformation method">
                        <option value="pca" selected="true">PCA</option>
                        <option value="umap">UMAP</option>
                    </param>
                </expand>
            </when>
            <!-- scatter: two columns of the labels file; the config script subtracts 1 from each selected column number before indexing. -->
            <when value="scatter">
                <expand macro="plots_common_param">
                    <expand macro="plots_common_input"/>
                    <param name="true_label" type="data_column" data_ref="labels" label="Column name in the labels file to use for the true labels"/>
                    <param name="predicted_label" type="data_column" data_ref="labels" label="Column name in the labels file to use for the predicted labels"/>
                </expand>
            </when>
            <!-- concordance_heatmap: same two column selectors as the scatter plot. -->
            <when value="concordance_heatmap">
                <expand macro="plots_common_param">
                    <expand macro="plots_common_input"/>
                    <param name="true_label" type="data_column" data_ref="labels" label="Column name in the labels file to use for the true labels"/>
                    <param name="predicted_label" type="data_column" data_ref="labels" label="Column name in the labels file to use for the predicted labels"/>
                </expand>
            </when>
            <!-- pr_curve, roc_curve and box_plot: no parameters beyond the shared ones. -->
            <when value="pr_curve">
                <expand macro="plots_common_param">
                    <expand macro="plots_common_input"/>
                </expand>
            </when>
            <when value="roc_curve">
                <expand macro="plots_common_param">
                    <expand macro="plots_common_input"/>
                </expand>
            </when>
            <when value="box_plot">
                <expand macro="plots_common_param">
                    <expand macro="plots_common_input"/>
                </expand>
            </when>
        </conditional>
    </inputs>
    <!-- The two filters are mutually exclusive, so exactly one of the two outputs exists for any plot_type. -->
    <outputs>
        <!-- Single image for every plot type except box_plot, taken from the plots/ directory of the job. -->
        <data name="plot_out" auto_format="true" from_work_dir="plots/*" label="${tool.name} on ${on_string}: ${plot_conditional.plot_type}">
            <filter>plot_conditional['plot_type'] != "box_plot"</filter>
        </data>
        <!-- box_plot only: every file found in plots/ becomes one list element (the config script writes one image per class). -->
        <collection name="boxplot_out" type="list" label="${tool.name} on ${on_string}: box_plot">
            <discover_datasets pattern="__name_and_ext__" directory="plots/"/>
            <filter>plot_conditional['plot_type'] == "box_plot"</filter>
        </collection>
    </outputs>
    <tests>
        <!-- test 1: dimred. PCA of embeddings.tabular coloured by column 6 of labels.tabular; the JPEG output is checked for height, width, channel count and centre of mass. -->
        <test expect_num_outputs="1">
            <param name="non_commercial_use" value="True"/>
            <conditional name="plot_conditional">
                <param name="plot_type" value="dimred"/>
                <param name="embeddings" value="embeddings.tabular"/>
                <param name="label" value="6"/>
                <param name="method" value="pca"/>
                <param name="labels" value="labels.tabular"/>
                <param name="format" value="jpg"/>
                <param name="dpi" value="300"/>
            </conditional>
            <output name="plot_out">
                <assert_contents>
                    <has_image_center_of_mass center_of_mass="970,733" eps="50"/>
                    <has_image_channels channels="3"/>
                    <has_image_height height="1461" delta="50"/>
                    <has_image_width width="1941" delta="50"/>
                </assert_contents>
            </output>
        </test>
        <!-- test 2: scatter. Columns 5 (true) and 6 (predicted) of labels_scatter.tabular; the JPEG output is checked for height, width, channel count and centre of mass. -->
        <test expect_num_outputs="1">
            <param name="non_commercial_use" value="True"/>
            <conditional name="plot_conditional">
                <param name="plot_type" value="scatter"/>
                <param name="labels" value="labels_scatter.tabular"/>
                <param name="true_label" value="5"/>
                <param name="predicted_label" value="6"/>
                <param name="format" value="jpg"/>
                <param name="dpi" value="300"/>
            </conditional>
            <output name="plot_out">
                <assert_contents>
                    <has_image_center_of_mass center_of_mass="970,733" eps="50"/>
                    <has_image_channels channels="3"/>
                    <has_image_height height="1461" delta="50"/>
                    <has_image_width width="1941" delta="50"/>
                </assert_contents>
            </output>
        </test>
        <!-- test 3: concordance_heatmap. Columns 5 (true) and 6 (predicted) of labels.tabular; the expected image is larger than in the other single-image tests. -->
        <test expect_num_outputs="1">
            <param name="non_commercial_use" value="True"/>
            <conditional name="plot_conditional">
                <param name="plot_type" value="concordance_heatmap"/>
                <param name="labels" value="labels.tabular"/>
                <param name="true_label" value="5"/>
                <param name="predicted_label" value="6"/>
                <param name="format" value="jpg"/>
                <param name="dpi" value="300"/>
            </conditional>
            <output name="plot_out">
                <assert_contents>
                    <has_image_center_of_mass center_of_mass="1450,1310" eps="50"/>
                    <has_image_channels channels="3"/>
                    <has_image_height height="2558" delta="50"/>
                    <has_image_width width="2770" delta="50"/>
                </assert_contents>
            </output>
        </test>
        <!-- test 4: pr_curve. Uses labels_pr.tabular with no column selection; the JPEG output is checked for height, width, channel count and centre of mass. -->
        <test expect_num_outputs="1">
            <param name="non_commercial_use" value="True"/>
            <conditional name="plot_conditional">
                <param name="plot_type" value="pr_curve"/>
                <param name="labels" value="labels_pr.tabular"/>
                <param name="format" value="jpg"/>
                <param name="dpi" value="300"/>
            </conditional>
            <output name="plot_out">
                <assert_contents>
                    <has_image_center_of_mass center_of_mass="970,733" eps="50"/>
                    <has_image_channels channels="3"/>
                    <has_image_height height="1461" delta="50"/>
                    <has_image_width width="1941" delta="50"/>
                </assert_contents>
            </output>
        </test>
        <!-- test 5: roc_curve. Same input file and image assertions as the pr_curve test, with plot_type switched to roc_curve. -->
        <test expect_num_outputs="1">
            <param name="non_commercial_use" value="True"/>
            <conditional name="plot_conditional">
                <param name="plot_type" value="roc_curve"/>
                <param name="labels" value="labels_pr.tabular"/>
                <param name="format" value="jpg"/>
                <param name="dpi" value="300"/>
            </conditional>
            <output name="plot_out">
                <assert_contents>
                    <has_image_center_of_mass center_of_mass="970,733" eps="50"/>
                    <has_image_channels channels="3"/>
                    <has_image_height height="1461" delta="50"/>
                    <has_image_width width="1941" delta="50"/>
                </assert_contents>
            </output>
        </test>
        <!-- test 6: box_plot. Uses labels_pr.tabular and expects a list collection of seven images, one element per class label; each element is checked for height, width, channel count and centre of mass (the Her2 image has a slightly different expected height). -->
        <test expect_num_outputs="1">
            <param name="non_commercial_use" value="True"/>
            <conditional name="plot_conditional">
                <param name="plot_type" value="box_plot"/>
                <param name="labels" value="labels_pr.tabular"/>
                <param name="format" value="jpg"/>
                <param name="dpi" value="300"/>
            </conditional>
            <output_collection name="boxplot_out" type="list" count="7">
                <element name="box_plot_Basal">
                    <assert_contents>
                        <has_image_center_of_mass center_of_mass="1485,882" eps="20"/>
                        <has_image_channels channels="3"/>
                        <has_image_height height="1783" delta="20"/>
                        <has_image_width width="2967" delta="20"/>
                    </assert_contents>
                </element>
                <element name="box_plot_Her2">
                    <assert_contents>
                        <has_image_center_of_mass center_of_mass="1485,882" eps="20"/>
                        <has_image_channels channels="3"/>
                        <has_image_height height="1765" delta="20"/>
                        <has_image_width width="2967" delta="20"/>
                    </assert_contents>
                </element>
                <element name="box_plot_LumA">
                    <assert_contents>
                        <has_image_center_of_mass center_of_mass="1485,882" eps="20"/>
                        <has_image_channels channels="3"/>
                        <has_image_height height="1783" delta="20"/>
                        <has_image_width width="2967" delta="20"/>
                    </assert_contents>
                </element>
                <element name="box_plot_LumB">
                    <assert_contents>
                        <has_image_center_of_mass center_of_mass="1485,882" eps="20"/>
                        <has_image_channels channels="3"/>
                        <has_image_height height="1783" delta="20"/>
                        <has_image_width width="2967" delta="20"/>
                    </assert_contents>
                </element>
                <element name="box_plot_NC">
                    <assert_contents>
                        <has_image_center_of_mass center_of_mass="1485,882" eps="20"/>
                        <has_image_channels channels="3"/>
                        <has_image_height height="1783" delta="20"/>
                        <has_image_width width="2967" delta="20"/>
                    </assert_contents>
                </element>
                <element name="box_plot_Normal">
                    <assert_contents>
                        <has_image_center_of_mass center_of_mass="1485,882" eps="20"/>
                        <has_image_channels channels="3"/>
                        <has_image_height height="1783" delta="20"/>
                        <has_image_width width="2967" delta="20"/>
                    </assert_contents>
                </element>
                <element name="box_plot_claudin-low">
                    <assert_contents>
                        <has_image_center_of_mass center_of_mass="1485,882" eps="20"/>
                        <has_image_channels channels="3"/>
                        <has_image_height height="1783" delta="20"/>
                        <has_image_width width="2967" delta="20"/>
                    </assert_contents>
                </element>
            </output_collection>
        </test>
    </tests>
    <help><![CDATA[
@COMMON_HELP@

Flexynesis plot is a comprehensive visualization tool designed to create various types of plots for analyzing machine learning results from the Flexynesis framework. This tool supports multiple visualization types to help researchers understand their data and model performance.

Available plot types include:

- **Dimensionality Reduction:** Visualizes high-dimensional data in a lower-dimensional space using methods like PCA or UMAP.
- **Scatter Plot:** Compares known and predicted labels to assess model performance.
- **Label Concordance Heatmap:** Displays the agreement between true and predicted labels in a heatmap format.
- **Precision-Recall Curves:** Plots precision against recall to evaluate the trade-off between these metrics.
- **ROC Curves:** Visualizes the true positive rate against the false positive rate to assess model discrimination.
- **Box Plot:** Shows the distribution of predicted probabilities across different classes, highlighting medians and quartiles.


**Input Files**

The main input, common to all plot types, is the labels file: the table of predicted labels generated by Flexynesis. It should contain the following columns:

- ``sample_id``: Unique identifier for each sample.
- ``variable``: The target variable used for the analysis.
- ``class_label``: The class labels for the samples, used in box plots.
- ``probability``: The predicted probabilities for the labels.
- ``known_label``: The true labels for the samples, used in scatter plots and concordance heatmaps.
- ``predicted_label``: The labels predicted by the Flexynesis model, used in scatter plots and concordance heatmaps.


For dimensionality reduction plots, an additional embeddings file is required, which contains the reduced-dimensional representations of the samples. This file is generated by Flexynesis.


.. class:: warningmark

PR and ROC curves are only available for classification tasks!


.. _Documentation: https://bimsbstatic.mdc-berlin.de/akalin/buyar/flexynesis/site/
.. _copyright holders: https://github.com/BIMSBbioinfo/flexynesis
    ]]></help>
    <expand macro="creator"/>
    <expand macro="citations"/>
</tool>