diff auto_threshold.py @ 4:7d80eb2411fb draft default tip

planemo upload for repository https://github.com/BMCV/galaxy-image-analysis/tree/master/tools/2d_auto_threshold/ commit 01343602708de3cc7fa4986af9000adc36dd0651
author imgteam
date Sat, 07 Jun 2025 18:38:16 +0000
parents 5224cc463a97
children
line wrap: on
line diff
--- a/auto_threshold.py	Sat Feb 19 15:16:52 2022 +0000
+++ b/auto_threshold.py	Sat Jun 07 18:38:16 2025 +0000
@@ -1,47 +1,117 @@
 """
-Copyright 2017-2022 Biomedical Computer Vision Group, Heidelberg University.
+Copyright 2017-2024 Biomedical Computer Vision Group, Heidelberg University.
 
 Distributed under the MIT license.
 See file LICENSE for detail or copy at https://opensource.org/licenses/MIT
-
 """
 
 import argparse
 
+import numpy as np
 import skimage.filters
-import skimage.io
 import skimage.util
-import tifffile
+from giatools.image import Image
+
+
+class DefaultThresholdingMethod:
+
+    def __init__(self, thres, accept: list[str] | None = None, **kwargs):
+        self.thres = thres
+        self.accept = accept if accept else []
+        self.kwargs = kwargs
+
+    def __call__(self, image, *args, offset=0, **kwargs):
+        accepted_kwargs = self.kwargs.copy()
+        for key, val in kwargs.items():
+            if key in self.accept:
+                accepted_kwargs[key] = val
+        thres = self.thres(image, *args, **accepted_kwargs)
+        return image > thres + offset
+
 
-thOptions = {
-    'otsu': lambda img_raw, bz: skimage.filters.threshold_otsu(img_raw),
-    'li': lambda img_raw, bz: skimage.filters.threshold_li(img_raw),
-    'yen': lambda img_raw, bz: skimage.filters.threshold_yen(img_raw),
-    'isodata': lambda img_raw, bz: skimage.filters.threshold_isodata(img_raw),
+class ManualThresholding:
+
+    def __call__(self, image, thres1: float, thres2: float | None, **kwargs):
+        if thres2 is None:
+            return image > thres1
+        else:
+            thres1, thres2 = sorted((thres1, thres2))
+            return skimage.filters.apply_hysteresis_threshold(image, thres1, thres2)
+
 
-    'loc_gaussian': lambda img_raw, bz: skimage.filters.threshold_local(img_raw, bz, method='gaussian'),
-    'loc_median': lambda img_raw, bz: skimage.filters.threshold_local(img_raw, bz, method='median'),
-    'loc_mean': lambda img_raw, bz: skimage.filters.threshold_local(img_raw, bz, method='mean')
+th_methods = {
+    'manual': ManualThresholding(),
+
+    'otsu': DefaultThresholdingMethod(skimage.filters.threshold_otsu),
+    'li': DefaultThresholdingMethod(skimage.filters.threshold_li),
+    'yen': DefaultThresholdingMethod(skimage.filters.threshold_yen),
+    'isodata': DefaultThresholdingMethod(skimage.filters.threshold_isodata),
+
+    'loc_gaussian': DefaultThresholdingMethod(skimage.filters.threshold_local, accept=['block_size'], method='gaussian'),
+    'loc_median': DefaultThresholdingMethod(skimage.filters.threshold_local, accept=['block_size'], method='median'),
+    'loc_mean': DefaultThresholdingMethod(skimage.filters.threshold_local, accept=['block_size'], method='mean'),
 }
 
 
-def auto_thresholding(in_fn, out_fn, th_method, block_size=5, dark_bg=True):
-    img = skimage.io.imread(in_fn)
-    th = thOptions[th_method](img, block_size)
-    if dark_bg:
-        res = img > th
-    else:
-        res = img <= th
-    tifffile.imwrite(out_fn, skimage.util.img_as_ubyte(res))
+def do_thresholding(
+    input_filepath: str,
+    output_filepath: str,
+    th_method: str,
+    block_size: int,
+    offset: float,
+    threshold1: float,
+    threshold2: float | None,
+    invert_output: bool,
+):
+    assert th_method in th_methods, f'Unknown method "{th_method}"'
+
+    # Load image
+    img_in = Image.read(input_filepath)
+
+    # Perform thresholding
+    result = th_methods[th_method](
+        image=img_in.data,
+        block_size=block_size,
+        offset=offset,
+        thres1=threshold1,
+        thres2=threshold2,
+    )
+    if invert_output:
+        result = np.logical_not(result)
+
+    # Convert to canonical representation for binary images
+    result = (result * 255).astype(np.uint8)
+
+    # Write result
+    Image(
+        data=skimage.util.img_as_ubyte(result),
+        axes=img_in.axes,
+    ).normalize_axes_like(
+        img_in.original_axes,
+    ).write(
+        output_filepath,
+    )
 
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description='Automatic Image Thresholding')
-    parser.add_argument('im_in', help='Path to the input image')
-    parser.add_argument('im_out', help='Path to the output image (TIFF)')
-    parser.add_argument('th_method', choices=thOptions.keys(), help='Thresholding method')
-    parser.add_argument('block_size', type=int, default=5, help='Odd size of pixel neighborhood for calculating the threshold')
-    parser.add_argument('dark_bg', default=True, type=bool, help='True if background is dark')
+    parser = argparse.ArgumentParser(description='Automatic image thresholding')
+    parser.add_argument('input', type=str, help='Path to the input image')
+    parser.add_argument('output', type=str, help='Path to the output image (uint8)')
+    parser.add_argument('th_method', choices=th_methods.keys(), help='Thresholding method')
+    parser.add_argument('block_size', type=int, help='Odd size of pixel neighborhood for calculating the threshold')
+    parser.add_argument('offset', type=float, help='Offset of automatically determined threshold value')
+    parser.add_argument('threshold1', type=float, help='Manual threshold value')
+    parser.add_argument('--threshold2', type=float, help='Second manual threshold value (for hysteresis thresholding)')
+    parser.add_argument('--invert_output', default=False, action='store_true', help='Values below/above the threshold are labeled with 0/255 by default, and with 255/0 if this argument is used')
     args = parser.parse_args()
 
-    auto_thresholding(args.im_in, args.im_out, args.th_method, args.block_size, args.dark_bg)
+    do_thresholding(
+        args.input,
+        args.output,
+        args.th_method,
+        args.block_size,
+        args.offset,
+        args.threshold1,
+        args.threshold2,
+        args.invert_output,
+    )