diff crop_image.py @ 0:f8bfa85cac4c draft default tip

planemo upload for repository https://github.com/BMCV/galaxy-image-analysis/tree/master/tools/crop_image/ commit 7a5037206d267aa7d9b7e5e062327c3464942471
author imgteam
date Fri, 06 Jun 2025 12:46:50 +0000
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/crop_image.py	Fri Jun 06 12:46:50 2025 +0000
@@ -0,0 +1,68 @@
+import argparse
+import os
+
+import numpy as np
+from giatools.image import Image
+
+
+def crop_image(
+    image_filepath: str,
+    labelmap_filepath: str,
+    output_ext: str,
+    output_dir: str,
+    skip_labels: frozenset[int],
+):
+    image = Image.read(image_filepath)
+    labelmap = Image.read(labelmap_filepath)
+
+    if image.axes != labelmap.axes:
+        raise ValueError(f'Axes mismatch between image ({image.axes}) and label map ({labelmap.axes}).')
+
+    if image.data.shape != labelmap.data.shape:
+        raise ValueError(f'Shape mismatch between image ({image.data.shape}) and label map ({labelmap.data.shape}).')
+
+    for label in np.unique(labelmap.data):
+        if label in skip_labels:
+            continue
+        roi_mask = (labelmap.data == label)
+        roi = crop_image_to_mask(image.data, roi_mask)
+        roi_image = Image(roi, image.axes).normalize_axes_like(image.original_axes)
+        roi_image.write(os.path.join(output_dir, f'{label}.{output_ext}'))
+
+
+def crop_image_to_mask(data: np.ndarray, mask: np.ndarray) -> np.ndarray:
+    """
+    Crop the `data` array to the minimal bounding box in `mask`.
+
+    The arguments are not modified.
+    """
+    assert data.shape == mask.shape
+
+    # Crop `data` to the convex hull of the mask in each dimension
+    for dim in range(data.ndim):
+        mask1d = mask.any(axis=tuple(i for i in range(mask.ndim) if i != dim))
+        mask1d_indices = np.where(mask1d)[0]
+        mask1d_indices_cvxhull = np.arange(min(mask1d_indices), max(mask1d_indices) + 1)
+        data = data.take(axis=dim, indices=mask1d_indices_cvxhull)
+
+    return data
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('image', type=str)
+    parser.add_argument('labelmap', type=str)
+    parser.add_argument('skip_labels', type=str)
+    parser.add_argument('output_ext', type=str)
+    parser.add_argument('output_dir', type=str)
+    args = parser.parse_args()
+
+    crop_image(
+        image_filepath=args.image,
+        labelmap_filepath=args.labelmap,
+        output_ext=args.output_ext,
+        output_dir=args.output_dir,
+        skip_labels=frozenset(
+            int(label.strip()) for label in args.skip_labels.split(',') if label.strip()
+        ) if args.skip_labels.strip() else frozenset(),
+    )