diff scale_image.py @ 4:5122286b700e draft default tip

planemo upload for repository https://github.com/BMCV/galaxy-image-analysis/tree/master/tools/scale_image/ commit cd908933bd7bd8756c213af57ea6343a90effc12
author imgteam
date Sat, 13 Dec 2025 22:11:13 +0000
parents ba2b1a6f1b84
children
line wrap: on
line diff
--- a/scale_image.py	Thu Oct 17 10:47:14 2024 +0000
+++ b/scale_image.py	Sat Dec 13 22:11:13 2025 +0000
@@ -1,51 +1,272 @@
 import argparse
+import json
 import sys
+from typing import (
+    Any,
+    Literal,
+)
 
 import giatools.io
 import numpy as np
 import skimage.io
 import skimage.transform
 import skimage.util
-from PIL import Image
+
+
+def get_uniform_scale(
+    img: giatools.Image,
+    axes: Literal['all', 'spatial'],
+    factor: float,
+) -> tuple[float, ...]:
+    """
+    Determine a tuple of `scale` factors for uniform or spatially uniform scaling.
+
+    Axes, that are not present in the original image data, are ignored.
+    """
+    ignored_axes = [
+        axis for axis_idx, axis in enumerate(img.axes)
+        if axis not in img.original_axes or (
+            factor < 1 and img.data.shape[axis_idx] == 1
+        )
+    ]
+    match axes:
+
+        case 'all':
+            return tuple(
+                [
+                    (factor if axis not in ignored_axes else 1)
+                    for axis in img.axes if axis != 'C'
+                ]
+            )
+
+        case 'spatial':
+            return tuple(
+                [
+                    (factor if axis in 'YXZ' and axis not in ignored_axes else 1)
+                    for axis in img.axes if axis != 'C'
+                ]
+            )
+
+        case _:
+            raise ValueError(f'Unknown axes for uniform scaling: "{axes}"')
+
+
+def get_scale_for_isotropy(
+    img: giatools.Image,
+    sample: Literal['up', 'down'],
+) -> tuple[float, ...]:
+    """
+    Determine a tuple of `scale` factors to establish spatial isotropy.
+
+    The `sample` parameter governs whether to up-sample or down-sample the image data.
+    """
+    scale = [1] * (len(img.axes) - 1)  # omit the channel axis
+    z_axis, y_axis, x_axis = [
+        img.axes.index(axis) for axis in 'ZYX'
+    ]
+
+    # Determine the pixel size of the image
+    if 'resolution' in img.metadata:
+        pixel_size = np.divide(1, img.metadata['resolution'])
+    else:
+        sys.exit('Resolution information missing in image metadata')
+
+    # Define unified transformation of pixel/voxel sizes to scale factors
+    def voxel_size_to_scale(voxel_size: np.ndarray) -> list:
+        match sample:
+            case 'up':
+                return (voxel_size / voxel_size.min()).tolist()
+            case 'down':
+                return (voxel_size / voxel_size.max()).tolist()
+            case _:
+                raise ValueError(f'Unknown value for sample: "{sample}"')
+
+    # Handle the 3-D case
+    if img.data.shape[z_axis] > 1:
+
+        # Determine the voxel depth of the image
+        if (voxel_depth := img.metadata.get('z_spacing', None)) is None:
+            sys.exit('Voxel depth information missing in image metadata')
+
+        # Determine the XYZ scale factors
+        scale[x_axis], scale[y_axis], scale[z_axis] = (
+            voxel_size_to_scale(
+                np.array([*pixel_size, voxel_depth]),
+            )
+        )
+
+    # Handle the 2-D case
+    else:
+
+        # Determine the XY scale factors
+        scale[x_axis], scale[y_axis] = (
+            voxel_size_to_scale(
+                np.array(pixel_size),
+            )
+        )
+
+    return tuple(scale)
+
+
+def get_aa_sigma_by_scale(scale: float) -> float:
+    """
+    Determine the optimal size of the Gaussian filter for anti-aliasing.
+
+    See for details: https://scikit-image.org/docs/0.25.x/api/skimage.transform.html#skimage.transform.rescale
+    """
+    return (1 / scale - 1) / 2 if scale < 1 else 0
 
 
-def scale_image(input_file, output_file, scale, order, antialias):
-    Image.MAX_IMAGE_PIXELS = 50000 * 50000
-    im = giatools.io.imread(input_file)
+def get_new_metadata(
+    old: giatools.Image,
+    scale: float | tuple[float, ...],
+    arr: np.ndarray,
+) -> dict[str, Any]:
+    """
+    Determine the result metadata (copy and adapt).
+    """
+    metadata = dict(old.metadata)
+    scales = (
+        [scale] * (len(old.axes) - 1)  # omit the channel axis
+        if isinstance(scale, float) else scale
+    )
+
+    # Determine the original pixel size
+    old_pixel_size = (
+        np.divide(1, old.metadata['resolution'])
+        if 'resolution' in old.metadata else (1, 1)
+    )
 
-    # Parse `--scale` argument
-    if ',' in scale:
-        scale = [float(s.strip()) for s in scale.split(',')]
-        assert len(scale) <= im.ndim, f'Image has {im.ndim} axes, but scale factors were given for {len(scale)} axes.'
-        scale = scale + [1] * (im.ndim - len(scale))
+    # Determine the new pixel size and update metadata
+    new_pixel_size = np.divide(
+        old_pixel_size,
+        (
+            scales[old.axes.index('X')],
+            scales[old.axes.index('Y')],
+        ),
+    )
+    metadata['resolution'] = tuple(1 / new_pixel_size)
+
+    # Update the metadata for the new voxel depth
+    old_voxel_depth = old.metadata.get('z_spacing', 1)
+    metadata['z_spacing'] = old_voxel_depth / scales[old.axes.index('Z')]
+
+    return metadata
+
+
+def metadata_to_str(metadata: dict) -> str:
+    tokens = list()
+    for key in sorted(metadata.keys()):
+        value = metadata[key]
+        if isinstance(value, tuple):
+            value = '(' + ', '.join([f'{val}' for val in value]) + ')'
+        tokens.append(f'{key}: {value}')
+    if len(metadata_str := ', '.join(tokens)) > 0:
+        return metadata_str
+    else:
+        return 'has no metadata'
+
+
+def write_output(filepath: str, img: giatools.Image):
+    """
+    Validate that the output file format is suitable for the image data, then write it.
+    """
+    print('Output shape:', img.data.shape)
+    print('Output axes:', img.axes)
+    print('Output', metadata_to_str(img.metadata))
 
-    else:
-        scale = float(scale)
+    # Validate that the output file format is suitable for the image data
+    if filepath.lower().endswith('.png'):
+        if not frozenset(img.axes) <= frozenset('YXC'):
+            sys.exit(f'Cannot write PNG file with axes "{img.axes}"')
+
+    # Write image data to the output file
+    img.write(filepath)
+
+
+def scale_image(
+    input_filepath: str,
+    output_filepath: str,
+    mode: Literal['uniform', 'explicit', 'isotropy'],
+    order: int,
+    anti_alias: bool,
+    **cfg,
+):
+    img = giatools.Image.read(input_filepath)
+    print('Input axes:', img.original_axes)
+    print('Input', metadata_to_str(img.metadata))
+
+    # Determine `scale` for scaling
+    match mode:
+
+        case 'uniform':
+            scale = get_uniform_scale(img, cfg['axes'], cfg['factor'])
 
-        # For images with 3 or more axes, the last axis is assumed to correspond to channels
-        if im.ndim >= 3:
-            scale = [scale] * (im.ndim - 1) + [1]
+        case 'explicit':
+            scale = tuple(
+                [cfg.get(f'factor_{axis.lower()}', 1) for axis in img.axes if axis != 'C']
+            )
+
+        case 'isotropy':
+            scale = get_scale_for_isotropy(img, cfg['sample'])
+
+        case _:
+            raise ValueError(f'Unknown mode: "{mode}"')
 
-    # Do the scaling
-    res = skimage.transform.rescale(im, scale, order, anti_aliasing=antialias, preserve_range=True)
+    # Assemble remaining `rescale` parameters
+    rescale_kwargs = dict(
+        scale=scale,
+        order=order,
+        preserve_range=True,
+        channel_axis=img.axes.index('C'),
+    )
+    if (anti_alias := anti_alias and (np.array(scale) < 1).any()):
+        rescale_kwargs['anti_aliasing'] = anti_alias
+        rescale_kwargs['anti_aliasing_sigma'] = tuple(
+            [
+                get_aa_sigma_by_scale(s) for s in scale
+            ] + [0]  # `skimage.transform.rescale` also expects a value for the channel axis
+        )
+    else:
+        rescale_kwargs['anti_aliasing'] = False
+
+    # Re-sample the image data to perform the scaling
+    for key, value in rescale_kwargs.items():
+        print(f'{key}: {value}')
+    arr = skimage.transform.rescale(img.data, **rescale_kwargs)
 
     # Preserve the `dtype` so that both brightness and range of values is preserved
-    if res.dtype != im.dtype:
-        if np.issubdtype(im.dtype, np.integer):
-            res = res.round()
-        res = res.astype(im.dtype)
+    if arr.dtype != img.data.dtype:
+        if np.issubdtype(img.data.dtype, np.integer):
+            arr = arr.round()
+        arr = arr.astype(img.data.dtype)
 
-    # Save result
-    skimage.io.imsave(output_file, res)
+    # Determine the result metadata and save result
+    metadata = get_new_metadata(img, scale, arr)
+    write_output(
+        output_filepath,
+        giatools.Image(
+            data=arr,
+            axes=img.axes,
+            metadata=metadata,
+        ).squeeze()
+    )
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('input_file', type=argparse.FileType('r'), default=sys.stdin)
-    parser.add_argument('out_file', type=argparse.FileType('w'), default=sys.stdin)
-    parser.add_argument('--scale', type=str, required=True)
-    parser.add_argument('--order', type=int, required=True)
-    parser.add_argument('--antialias', default=False, action='store_true')
+    parser.add_argument('input', type=str)
+    parser.add_argument('output', type=str)
+    parser.add_argument('params', type=str)
     args = parser.parse_args()
 
-    scale_image(args.input_file.name, args.out_file.name, args.scale, args.order, args.antialias)
+    # Read the config file
+    with open(args.params) as cfgf:
+        cfg = json.load(cfgf)
+
+    # Perform scaling
+    scale_image(
+        args.input,
+        args.output,
+        **cfg,
+    )