view split.py @ 3:d45a07063da1 draft default tip

planemo upload for repository https://github.com/BMCV/galaxy-image-analysis/tree/master/tools/split_image/ commit df96ae15da34285b0a9d435a48924665fff37d6a
author imgteam
date Sat, 04 Apr 2026 21:22:07 +0000
parents 227e8928af6e
children
line wrap: on
line source

import argparse
import math
import pathlib

import giatools
import giatools.io
import numpy as np


class OutputWriter:

    def __init__(
        self,
        dir_path: pathlib.Path,
        num_images: int,
        squeeze: bool,
        verbose: bool,
        offset: int = 0,
        step: int = 1,
        count: int | None = None,
    ):
        self.positions = np.arange(num_images)[offset::step] + 1
        if count is not None:
            self.positions = self.positions[:count]

        print(f'Writing {len(self.positions)} out of {num_images} image(s)')
        decimals = math.ceil(math.log10(1 + num_images))
        self.output_filepath_pattern = str(dir_path / f'%0{decimals}d.tiff')
        self.last_pos = 0
        self.squeeze = squeeze
        self.verbose = verbose

    def write(self, img: giatools.Image):
        self.last_pos += 1
        if self.last_pos in self.positions:
            if self.squeeze:
                img = img.squeeze()
            if self.last_pos == self.positions[0] or self.verbose:
                prefix = f'Output {self.last_pos}' if self.verbose else 'Output'
                print(f'{prefix} axes:', img.axes)
                print(f'{prefix} shape:', img.data.shape)
            img.write(self.output_filepath_pattern % self.last_pos)


def _parse_args() -> argparse.Namespace:
    """Parse and validate the command line of the split tool."""
    parser = argparse.ArgumentParser()
    parser.add_argument('input', type=pathlib.Path)
    parser.add_argument('axis', type=str, choices=list(giatools.default_normalized_axes) + ['S', ''])
    parser.add_argument('output', type=pathlib.Path)
    parser.add_argument('--squeeze', action='store_true', default=False)
    parser.add_argument('offset', type=int)
    parser.add_argument('step', type=int)
    parser.add_argument('--count', type=int)
    args = parser.parse_args()

    # A zero step would otherwise surface as a NumPy slicing error
    # ("slice step cannot be zero") deep inside OutputWriter
    if args.step == 0:
        parser.error('argument step: must not be zero')
    return args


def _split_images_in_file(args: argparse.Namespace) -> None:
    """Write the individual images contained in a multi-image file (``axis == ''``)."""
    # Peek the number of images
    num_images = giatools.io.peek_num_images_in_file(args.input)
    print(f'Found {num_images} image(s) in file')

    # Extract the individual images
    output = OutputWriter(
        dir_path=args.output,
        num_images=num_images,
        squeeze=args.squeeze,
        verbose=(num_images > 1),
        offset=args.offset,
        step=args.step,
        count=args.count,
    )
    # NOTE(review): every image is read, including those not selected by
    # offset/step/count, because OutputWriter derives the position from the
    # number of write() calls. Skipping reads would need an OutputWriter API change.
    for position in range(num_images):
        img = giatools.Image.read(args.input, position=position, normalize_axes=None)
        output.write(img)


def _split_along_axis(args: argparse.Namespace) -> None:
    """Split a single image into slices along ``args.axis`` and write them."""
    # Validate and normalize input parameters ('S' is mapped onto 'C')
    axis = args.axis.replace('S', 'C')

    # Read input image with normalized axes
    img_in = giatools.Image.read(args.input)
    print('Input image axes:', img_in.original_axes)
    print('Input image shape:', img_in.squeeze_like(img_in.original_axes).data.shape)

    # Determine the axis to split along
    axis_pos = img_in.axes.index(axis)

    # Move the split axis to the front, so that iterating yields the slices
    arr = np.moveaxis(img_in.data, axis_pos, 0)
    output = OutputWriter(
        dir_path=args.output,
        num_images=arr.shape[0],
        squeeze=args.squeeze,
        verbose=False,
        offset=args.offset,
        step=args.step,
        count=args.count,
    )
    for img_slice in arr:
        # Re-insert the split axis (now of length 1) at its original position
        img_data = np.moveaxis(img_slice[None], 0, axis_pos)

        # Construct the output image, remove axes added by normalization
        img_out = giatools.Image(
            data=img_data,
            axes=img_in.axes,
            metadata=img_in.metadata,
        ).squeeze_like(
            img_in.original_axes,
        )

        # Save the result (OutputWriter prints axes/shape for the first selected slice)
        output.write(img_out)


def main() -> None:
    """Split the input file into its images, or split an image along an axis."""
    args = _parse_args()

    # An empty axis means: split a file that contains multiple images
    if args.axis == '':
        _split_images_in_file(args)

    # Otherwise: split along an image axis
    else:
        _split_along_axis(args)


if __name__ == '__main__':
    main()