diff binary2label.py @ 4:8f0dd9a58ec3 draft default tip

planemo upload for repository https://github.com/BMCV/galaxy-image-analysis/tree/master/tools/binary2labelimage/ commit f5a4de7535e433e3b0e96e0694e481b6643a54f8
author imgteam
date Sat, 03 Jan 2026 14:14:05 +0000
parents a041e4e9d449
children
line wrap: on
line diff
--- a/binary2label.py	Mon May 12 08:15:32 2025 +0000
+++ b/binary2label.py	Sat Jan 03 14:14:05 2026 +0000
@@ -1,27 +1,70 @@
-import argparse
+import giatools
+import numpy as np
+import scipy.ndimage as ndi
+
+# Fail early if an optional backend is not available
+giatools.require_backend('omezarr')
+
 
-import giatools
-import scipy.ndimage as ndi
-import tifffile
+def label_watershed(arr: np.ndarray, **kwargs) -> np.ndarray:
+    import skimage.util
+    from skimage.feature import peak_local_max
+    from skimage.segmentation import watershed
+    distance = ndi.distance_transform_edt(arr)
+    local_max_indices = peak_local_max(
+        distance,
+        labels=arr,
+        **kwargs,
+    )
+    local_max_mask = np.zeros(arr.shape, dtype=bool)
+    local_max_mask[tuple(local_max_indices.T)] = True
+    markers = ndi.label(local_max_mask)[0]
+    res = watershed(-distance, markers, mask=arr)
+    return skimage.util.img_as_uint(res)  # converts to uint16
 
 
-# Parse CLI parameters
-parser = argparse.ArgumentParser()
-parser.add_argument('input', type=str, help='input file')
-parser.add_argument('output', type=str, help='output file (TIFF)')
-args = parser.parse_args()
+if __name__ == '__main__':
+
+    tool = giatools.ToolBaseplate()
+    tool.add_input_image('input')
+    tool.add_output_image('output')
+    tool.parse_args()
+
+    # Validate the input image and the selected method
+    try:
+        input_image = tool.args.input_images['input']
+        if (method := tool.args.params.pop('method')) == 'watershed' and input_image.shape[input_image.axes.index('Z')] > 1:
+            raise ValueError(f'Method "{method}" is not applicable to 3-D images.')
+
+        elif input_image.shape[input_image.axes.index('C')] > 1:
+            raise ValueError('Multi-channel images are forbidden to avoid confusion with multi-channel labels (e.g., RGB labels).')
+
+        else:
+
+            # Choose the requested labeling method
+            match method:
 
-# Read the input image with the original axes
-img = giatools.Image.read(args.input)
-img = img.normalize_axes_like(
-    img.original_axes,
-)
+                case 'cca':
+                    joint_axes = 'ZYX'
+                    label = lambda input_section_bin: (  # noqa: E731
+                        ndi.label(input_section_bin, **tool.args.params)[0].astype(np.uint16)
+                    )
+
+                case 'watershed':
+                    joint_axes = 'YX'
+                    label = lambda input_section_bin: (  # noqa: E731
+                        label_watershed(input_section_bin, **tool.args.params)  # already uint16
+                    )
 
-# Make sure the image is truly binary
-img_arr_bin = (img.data > 0)
+                case _:
+                    raise ValueError(f'Unknown method: "{method}"')
 
-# Perform the labeling
-img.data = ndi.label(img_arr_bin)[0]
+        # Perform the labeling
+        for section in tool.run(joint_axes):
+            section['output'] = label(
+                section['input'].data > 0,  # ensure that the input data is truly binary
+            )
 
-# Write the result image (same axes as input image)
-tifffile.imwrite(args.output, img.data, metadata=dict(axes=img.axes))
+    # Exit and print error to stderr
+    except ValueError as err:
+        exit(err.args[0])