diff workflow/run_tomo.py @ 75:d5e1d4ea2b7e draft default tip

planemo upload for repository https://github.com/rolfverberg/galaxytools commit 6afde341a94586fe3972bdbbfbf5dabd5e8dec69
author rv43
date Thu, 23 Mar 2023 13:39:14 +0000
parents 1cf15b61cd83
children
line wrap: on
line diff
--- a/workflow/run_tomo.py	Tue Mar 21 17:40:03 2023 +0000
+++ b/workflow/run_tomo.py	Thu Mar 23 13:39:14 2023 +0000
@@ -15,7 +15,7 @@
 
 from multiprocessing import cpu_count
 from nexusformat.nexus import *
-from os import mkdir
+from os import mkdir, environ
 from os import path as os_path
 try:
     from skimage.transform import iradon
@@ -99,7 +99,7 @@
             self.num_core = num_core
 
     def __enter__(self):
-        self.num_core_org = ne.set_num_threads(self.num_core)
+        self.num_core_org = ne.set_num_threads(min(self.num_core, ne.MAX_THREADS))
 
     def __exit__(self, exc_type, exc_value, traceback):
         ne.set_num_threads(self.num_core_org)
@@ -164,28 +164,42 @@
             self.num_core= cpu_count()
 
     def read(self, filename):
-        extension = os_path.splitext(filename)[1]
-        if extension == '.yml' or extension == '.yaml':
-            with open(filename, 'r') as f:
-                config = safe_load(f)
-#            if len(config) > 1:
-#                raise ValueError(f'Multiple root entries in {filename} not yet implemented')
-#            if len(list(config.values())[0]) > 1:
-#                raise ValueError(f'Multiple sample maps in {filename} not yet implemented')
-            return(config)
-        elif extension == '.nxs':
-            with NXFile(filename, mode='r') as nxfile:
-                nxroot = nxfile.readfile()
-            return(nxroot)
+        logger.info(f'looking for {filename}')
+        if self.galaxy_flag:
+            try:
+                with open(filename, 'r') as f:
+                    config = safe_load(f)
+                return(config)
+            except:
+                try:
+                    with NXFile(filename, mode='r') as nxfile:
+                        nxroot = nxfile.readfile()
+                    return(nxroot)
+                except:
+                    raise ValueError(f'Unable to open ({filename})')
         else:
-            raise ValueError(f'Invalid filename extension ({extension})')
+            extension = os_path.splitext(filename)[1]
+            if extension == '.yml' or extension == '.yaml':
+                with open(filename, 'r') as f:
+                    config = safe_load(f)
+#                if len(config) > 1:
+#                    raise ValueError(f'Multiple root entries in {filename} not yet implemented')
+#                if len(list(config.values())[0]) > 1:
+#                    raise ValueError(f'Multiple sample maps in {filename} not yet implemented')
+                return(config)
+            elif extension == '.nxs':
+                with NXFile(filename, mode='r') as nxfile:
+                    nxroot = nxfile.readfile()
+                return(nxroot)
+            else:
+                raise ValueError(f'Invalid filename extension ({extension})')
 
     def write(self, data, filename):
         extension = os_path.splitext(filename)[1]
         if extension == '.yml' or extension == '.yaml':
             with open(filename, 'w') as f:
                 safe_dump(data, f)
-        elif extension == '.nxs':
+        elif extension == '.nxs' or extension == '.nex':
             data.save(filename, mode='w')
         elif extension == '.nc':
             data.to_netcdf(os_path=filename)
@@ -287,14 +301,18 @@
         nxentry = nxroot[nxroot.attrs['default']]
         if not isinstance(nxentry, NXentry):
             raise ValueError(f'Invalid nxentry ({nxentry})')
-        if self.galaxy_flag:
-            if center_rows is not None:
-                center_rows = tuple(center_rows)
-                if not is_int_pair(center_rows):
+        if center_rows is not None:
+            if self.galaxy_flag:
+                if not is_int_pair(center_rows, ge=-1):
                     raise ValueError(f'Invalid parameter center_rows ({center_rows})')
-        elif center_rows is not None:
-            logger.warning(f'Ignoring parameter center_rows ({center_rows})')
-            center_rows = None
+                if (center_rows[0] != -1 and center_rows[1] != -1 and
+                        center_rows[0] > center_rows[1]):
+                    center_rows = (center_rows[1], center_rows[0])
+                else:
+                    center_rows = tuple(center_rows)
+            else:
+                logger.warning(f'Ignoring parameter center_rows ({center_rows})')
+                center_rows = None
         if self.galaxy_flag:
             if center_stack_index is not None and not is_int(center_stack_index, ge=0):
                 raise ValueError(f'Invalid parameter center_stack_index ({center_stack_index})')
@@ -360,10 +378,10 @@
         if self.test_mode:
             lower_row = self.test_config['lower_row']
         elif self.galaxy_flag:
-            if center_rows is None:
+            if center_rows is None or center_rows[0] == -1:
                 lower_row = 0
             else:
-                lower_row = min(center_rows)
+                lower_row = center_rows[0]
                 if not 0 <= lower_row < tomo_fields_shape[2]-1:
                     raise ValueError(f'Invalid parameter center_rows ({center_rows})')
         else:
@@ -374,7 +392,6 @@
         logger.debug('Finding center...')
         t0 = time()
         lower_center_offset = self._find_center_one_plane(
-                #np.asarray(nxentry.reduced_data.data.tomo_fields[center_stack_index,:,lower_row,:]),
                 nxentry.reduced_data.data.tomo_fields[center_stack_index,:,lower_row,:],
                 lower_row, thetas, eff_pixel_size, cross_sectional_dim, path=path,
                 num_core=self.num_core)
@@ -386,10 +403,10 @@
         if self.test_mode:
             upper_row = self.test_config['upper_row']
         elif self.galaxy_flag:
-            if center_rows is None:
+            if center_rows is None or center_rows[1] == -1:
                 upper_row = tomo_fields_shape[2]-1
             else:
-                upper_row = max(center_rows)
+                upper_row = center_rows[1]
                 if not lower_row < upper_row < tomo_fields_shape[2]:
                     raise ValueError(f'Invalid parameter center_rows ({center_rows})')
         else:
@@ -499,15 +516,31 @@
         # Resize the reconstructed tomography data
         #   reconstructed data order in each stack: row/z,x,y
         if self.test_mode:
-            x_bounds = self.test_config.get('x_bounds')
-            y_bounds = self.test_config.get('y_bounds')
+            x_bounds = tuple(self.test_config.get('x_bounds'))
+            y_bounds = tuple(self.test_config.get('y_bounds'))
             z_bounds = None
         elif self.galaxy_flag:
-            if x_bounds is not None and not is_int_pair(x_bounds, ge=0,
-                    lt=tomo_recon_stacks[0].shape[1]):
+            x_max = tomo_recon_stacks[0].shape[1]
+            if x_bounds is None:
+                x_bounds = (0, x_max)
+            elif is_int_pair(x_bounds, ge=-1, le=x_max):
+                x_bounds = tuple(x_bounds)
+                if x_bounds[0] == -1:
+                    x_bounds = (0, x_bounds[1])
+                if x_bounds[1] == -1:
+                    x_bounds = (x_bounds[0], x_max)
+            if not is_index_range(x_bounds, ge=0, le=x_max):
                 raise ValueError(f'Invalid parameter x_bounds ({x_bounds})')
-            if y_bounds is not None and not is_int_pair(y_bounds, ge=0,
-                    lt=tomo_recon_stacks[0].shape[1]):
+            y_max = tomo_recon_stacks[0].shape[1]
+            if y_bounds is None:
+                y_bounds = (0, y_max)
+            elif is_int_pair(y_bounds, ge=-1, le=y_max):
+                y_bounds = tuple(y_bounds)
+                if y_bounds[0] == -1:
+                    y_bounds = (0, y_bounds[1])
+                if y_bounds[1] == -1:
+                    y_bounds = (y_bounds[0], y_max)
+            if not is_index_range(y_bounds, ge=0, le=y_max):
                 raise ValueError(f'Invalid parameter y_bounds ({y_bounds})')
             z_bounds = None
         else:
@@ -535,17 +568,17 @@
         if num_tomo_stacks == 1:
             basetitle = 'recon'
         else:
-            basetitle = f'recon stack {i+1}'
+            basetitle = f'recon stack'
         for i, stack in enumerate(tomo_recon_stacks):
-            title = f'{basetitle} {res_title} xslice{x_slice}'
+            title = f'{basetitle} {i+1} {res_title} xslice{x_slice}'
             quick_imshow(stack[z_range[0]:z_range[1],x_slice,y_range[0]:y_range[1]],
                     title=title, path=path, save_fig=self.save_figs, save_only=self.save_only,
                     block=self.block)
-            title = f'{basetitle} {res_title} yslice{y_slice}'
+            title = f'{basetitle} {i+1} {res_title} yslice{y_slice}'
             quick_imshow(stack[z_range[0]:z_range[1],x_range[0]:x_range[1],y_slice],
                     title=title, path=path, save_fig=self.save_figs, save_only=self.save_only,
                     block=self.block)
-            title = f'{basetitle} {res_title} zslice{z_slice}'
+            title = f'{basetitle} {i+1} {res_title} zslice{z_slice}'
             quick_imshow(stack[z_slice,x_range[0]:x_range[1],y_range[0]:y_range[1]],
                     title=title, path=path, save_fig=self.save_figs, save_only=self.save_only,
                     block=self.block)
@@ -740,7 +773,7 @@
 
         # Take median
         if tdf_stack.ndim == 2:
-            tdf = tdf_stack
+            tdf = tdf_stack.astype('float64')
         elif tdf_stack.ndim == 3:
             tdf = np.median(tdf_stack, axis=0)
             del tdf_stack
@@ -808,7 +841,7 @@
            We don’t typically account for them but potentially could.
         """
         if tbf_stack.ndim == 2:
-            tbf = tbf_stack
+            tbf = tbf_stack.astype('float64')
         elif tbf_stack.ndim == 3:
             tbf = np.median(tbf_stack, axis=0)
             del tbf_stack
@@ -823,7 +856,7 @@
 
         # Set any non-positive values to one
         # (avoid negative bright field values for spikes in dark field)
-        tbf[tbf < 1] = 1
+        tbf[tbf < 1.0] = 1.0
 
         # Plot bright field
         if self.galaxy_flag:
@@ -873,9 +906,6 @@
 
         # Select image bounds
         title = f'tomography image at theta={round(theta, 2)+0}'
-        if (img_x_bounds is not None and not is_index_range(img_x_bounds, ge=0,
-                le=first_image.shape[0])):
-            raise ValueError(f'Invalid parameter img_x_bounds ({img_x_bounds})')
         if nxentry.instrument.source.attrs['station'] in ('id1a3', 'id3a'):
             pixel_size = nxentry.instrument.detector.x_pixel_size
             # Try to get a fit from the bright field
@@ -977,6 +1007,14 @@
             if self.galaxy_flag:
                 if img_x_bounds is None:
                     img_x_bounds = (0, first_image.shape[0])
+                elif is_int_pair(img_x_bounds, ge=-1, le=first_image.shape[0]):
+                    img_x_bounds = tuple(img_x_bounds)
+                    if img_x_bounds[0] == -1:
+                        img_x_bounds = (0, img_x_bounds[1])
+                    if img_x_bounds[1] == -1:
+                        img_x_bounds = (img_x_bounds[0], first_image.shape[0])
+                if not is_index_range(img_x_bounds, ge=0, le=first_image.shape[0]):
+                    raise ValueError(f'Invalid parameter img_x_bounds ({img_x_bounds})')
             else:
                 quick_imshow(first_image, title=title)
                 print('Select vertical data reduction range from first tomography image')
@@ -1031,23 +1069,29 @@
         """
         # Get full bright field
         tbf = np.asarray(reduced_data.data.bright_field)
-        tbf_shape = tbf.shape
+        img_shape = tbf.shape
 
         # Get image bounds
-        img_x_bounds = tuple(reduced_data.get('img_x_bounds', (0, tbf_shape[0])))
-        img_y_bounds = tuple(reduced_data.get('img_y_bounds', (0, tbf_shape[1])))
+        img_x_bounds = tuple(reduced_data.get('img_x_bounds', (0, img_shape[0])))
+        img_y_bounds = tuple(reduced_data.get('img_y_bounds', (0, img_shape[1])))
+        if img_x_bounds == (0, img_shape[0]) and img_y_bounds == (0, img_shape[1]):
+            resize_flag = False
+        else:
+            resize_flag = True
 
         # Get resized dark field
-#        if 'dark_field' in data:
-#            tbf = np.asarray(reduced_data.data.dark_field[
-#                    img_x_bounds[0]:img_x_bounds[1],img_y_bounds[0]:img_y_bounds[1]])
-#        else:
-#            logger.warning('Dark field unavailable')
-#            tdf = None
-        tdf = None
+        if 'dark_field' in reduced_data.data:
+            if resize_flag:
+                tdf = np.asarray(reduced_data.data.dark_field[
+                        img_x_bounds[0]:img_x_bounds[1],img_y_bounds[0]:img_y_bounds[1]])
+            else:
+                tdf = np.asarray(reduced_data.data.dark_field)
+        else:
+            logger.warning('Dark field unavailable')
+            tdf = None
 
         # Resize bright field
-        if img_x_bounds != (0, tbf.shape[0]) or img_y_bounds != (0, tbf.shape[1]):
+        if resize_flag:
             tbf = tbf[img_x_bounds[0]:img_x_bounds[1],img_y_bounds[0]:img_y_bounds[1]]
 
         # Get the tomography images
@@ -1107,13 +1151,15 @@
         else:
             path = self.output_folder
         for i, tomo_stack in enumerate(tomo_stacks):
-            # Resize the tomography images
-            # Right now the range is the same for each set in the image stack.
-            if img_x_bounds != (0, tbf.shape[0]) or img_y_bounds != (0, tbf.shape[1]):
+            # Resize the tomography images as needed
+            # Right now the range is the same for each set in the image stack
+            if resize_flag:
                 t0 = time()
                 tomo_stack = tomo_stack[:,img_x_bounds[0]:img_x_bounds[1],
                         img_y_bounds[0]:img_y_bounds[1]].astype('float64')
                 logger.debug(f'Resizing tomography images took {time()-t0:.2f} seconds')
+            else:
+                tomo_stack = tomo_stack.astype('float64')
 
             # Subtract dark field
             if tdf is not None: