Mercurial > repos > imgteam > concat_channels
diff concat_channels.py @ 3:01c1d5af33be draft default tip
planemo upload for repository https://github.com/BMCV/galaxy-image-analysis/tree/master/tools/concat_channels/ commit a94f04c109c545a9f892a6ce7a5ffef152253201
| author | imgteam |
|---|---|
| date | Fri, 12 Dec 2025 21:15:44 +0000 |
| parents | ad1caf2331c6 |
| children |
line wrap: on
line diff
--- a/concat_channels.py Sun Dec 07 16:15:54 2025 +0000 +++ b/concat_channels.py Fri Dec 12 21:15:44 2025 +0000 @@ -1,4 +1,5 @@ import argparse +from typing import Any import giatools import numpy as np @@ -6,20 +7,19 @@ import skimage.util -normalized_axes = 'QTZYXC' - - def concat_channels( input_image_paths: list[str], output_image_path: str, axis: str, preserve_values: bool, + sort_by: str | None, ): # Create list of arrays to be concatenated - images = [] + images = list() + metadata = dict() for image_path in input_image_paths: - img = giatools.Image.read(image_path, normalize_axes=normalized_axes) + img = giatools.Image.read(image_path, normalize_axes=giatools.default_normalized_axes) arr = img.data # Preserve values: Convert to `float` dtype without changing the values @@ -30,25 +30,106 @@ else: arr = skimage.util.img_as_float(arr) + # Record the metadata + for metadata_key, metadata_value in img.metadata.items(): + metadata.setdefault(metadata_key, list()) + metadata[metadata_key].append(metadata_value) + + # Record the image data images.append(arr) + # Perform sorting, if requested + if sort_by is not None: + + # Validate that `sort_by` is available as metadata for all images + sort_keys = list( + filter( + lambda value: value is not None, + metadata.get(sort_by, list()), + ), + ) + if len(sort_keys) != len(images): + raise ValueError( + f'Requested to sort by "{sort_by}", ' + f'but this is not available for all {len(images)} images' + f' (available for only {len(sort_keys)} images)' + ) + + # Sort images by the corresponding `sort_key` metadata value + sorted_indices = sorted(range(len(images)), key=lambda i: sort_keys[i]) + images = [images[i] for i in sorted_indices] + + # Determine consensual metadata + # TODO: Convert metadata of images with different units of measurement into a common unit + final_metadata = dict() + for metadata_key, metadata_values in metadata.items(): + if (metadata_value := reduce_metadata(metadata_values)) is not None: + final_metadata[metadata_key] = metadata_value + + # Update the `z_spacing` metadata, if concatenating along the Z-axis and `z_position` is available for all images + if axis == 'Z' and len(images) >= 2 and len(z_positions := metadata.get('z_position', list())) == len(images): + z_positions = sorted(z_positions) # don't mutate the `metadata` dictionary for easier future code maintenance + final_metadata['z_spacing'] = abs(np.subtract(z_positions[1:], z_positions[:-1]).mean()) + # Do the concatenation - axis_pos = normalized_axes.index(axis) + axis_pos = giatools.default_normalized_axes.index(axis) arr = np.concatenate(images, axis_pos) - res = giatools.Image(arr, normalized_axes) + res = giatools.Image( + data=arr, + axes=giatools.default_normalized_axes, + metadata=final_metadata, + ) # Squeeze singleton axes and save - squeezed_axes = ''.join(np.array(list(res.axes))[np.array(arr.shape) > 1]) - res = res.squeeze_like(squeezed_axes) + res = res.squeeze() + print('Output TIFF shape:', res.data.shape) + print('Output TIFF axes:', res.axes) + print('Output TIFF', metadata_to_str(final_metadata)) res.write(output_image_path, backend='tifffile') +def reduce_metadata(values: list[Any]) -> Any | None: + non_none_values = list(filter(lambda value: value is not None, values)) + + # Reduction is not possible if more than one type is involved (or none) + value_types = [type(value) for value in non_none_values] + if len(frozenset(value_types)) != 1: + return None + else: + value_type = value_types[0] + + # For floating point types, reduce via arithmetic average + if np.issubdtype(value_type, np.floating): + return np.mean(non_none_values) + + # For integer types, reduce via the median + if np.issubdtype(value_type, np.integer): + return int(np.median(non_none_values)) + + # For all other types, reduction is only possible if the values are identical + if len(frozenset(non_none_values)) == 1: + return non_none_values[0] + else: + return None + + +def metadata_to_str(metadata: dict) -> str: + tokens = list() + for key in sorted(metadata.keys()): + value = metadata[key] + if isinstance(value, tuple): + value = '(' + ', '.join([f'{val}' for val in value]) + ')' + tokens.append(f'{key}: {value}') + return ', '.join(tokens) + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument('input_files', type=str, nargs='+') parser.add_argument('out_file', type=str) parser.add_argument('axis', type=str) parser.add_argument('--preserve_values', default=False, action='store_true') + parser.add_argument('--sort_by', type=str, default=None) args = parser.parse_args() concat_channels( @@ -56,4 +137,5 @@ args.out_file, args.axis, args.preserve_values, + args.sort_by, )
