diff workflow/run_tomo.py @ 71:1cf15b61cd83 draft

planemo upload for repository https://github.com/rolfverberg/galaxytools commit 366e516aef0735af2998c6ff3af037181c8d5213
author rv43
date Mon, 20 Mar 2023 13:56:57 +0000
parents fba792d5f83b
children d5e1d4ea2b7e
line wrap: on
line diff
--- a/workflow/run_tomo.py	Fri Mar 10 16:39:22 2023 +0000
+++ b/workflow/run_tomo.py	Mon Mar 20 13:56:57 2023 +0000
@@ -32,13 +32,24 @@
     pass
 from yaml import safe_load, safe_dump
 
-from msnctools.fit import Fit
-from msnctools.general import illegal_value, is_int, is_int_pair, is_num, is_index_range, \
-        input_int, input_num, input_yesno, input_menu, draw_mask_1d, select_image_bounds, \
-        select_one_image_bound, clear_imshow, quick_imshow, clear_plot, quick_plot
+try:
+    from msnctools.fit import Fit
+except:
+    from fit import Fit
+try:
+    from msnctools.general import illegal_value, is_int, is_int_pair, is_num, is_index_range, \
+            input_int, input_num, input_yesno, input_menu, draw_mask_1d, select_image_bounds, \
+            select_one_image_bound, clear_imshow, quick_imshow, clear_plot, quick_plot
+except:
+    from general import illegal_value, is_int, is_int_pair, is_num, is_index_range, \
+            input_int, input_num, input_yesno, input_menu, draw_mask_1d, select_image_bounds, \
+            select_one_image_bound, clear_imshow, quick_imshow, clear_plot, quick_plot
 
-from workflow.models import import_scanparser, FlatField, TomoField, TomoWorkflow
-from workflow.__version__ import __version__
+try:
+    from workflow.models import import_scanparser, FlatField, TomoField, TomoWorkflow
+    from workflow.__version__ import __version__
+except:
+    pass
 
 num_core_tomopy_limit = 24
 
@@ -266,7 +277,7 @@
  
         return(nxroot)
 
-    def find_centers(self, nxroot, center_rows=None):
+    def find_centers(self, nxroot, center_rows=None, center_stack_index=None):
         """Find the calibrated center axis info
         """
         logger.info('Find the calibrated center axis info')
@@ -277,13 +288,19 @@
         if not isinstance(nxentry, NXentry):
             raise ValueError(f'Invalid nxentry ({nxentry})')
         if self.galaxy_flag:
-            if center_rows is None:
-                raise ValueError(f'Missing parameter center_rows ({center_rows})')
-            if not is_int_pair(center_rows):
-                raise ValueError(f'Invalid parameter center_rows ({center_rows})')
+            if center_rows is not None:
+                center_rows = tuple(center_rows)
+                if not is_int_pair(center_rows):
+                    raise ValueError(f'Invalid parameter center_rows ({center_rows})')
         elif center_rows is not None:
-            logging.warning(f'Ignoring parameter center_rows ({center_rows})')
+            logger.warning(f'Ignoring parameter center_rows ({center_rows})')
             center_rows = None
+        if self.galaxy_flag:
+            if center_stack_index is not None and not is_int(center_stack_index, ge=0):
+                raise ValueError(f'Invalid parameter center_stack_index ({center_stack_index})')
+        elif center_stack_index is not None:
+            logger.warning(f'Ignoring parameter center_stack_index ({center_stack_index})')
+            center_stack_index = None
 
         # Create plot galaxy path directory and path if needed
         if self.galaxy_flag:
@@ -298,26 +315,28 @@
             raise KeyError(f'Unable to find valid reduced data in {nxentry}.')
 
         # Select the image stack to calibrate the center axis
-        #   reduced data axes order: stack,row,theta,column
+        #   reduced data axes order: stack,theta,row,column
         #   Note: Nexus cannot follow a link if the data it points to is too big,
         #         so get the data from the actual place, not from nxentry.data
-        num_tomo_stacks = nxentry.reduced_data.data.tomo_fields.shape[0]
+        tomo_fields_shape = nxentry.reduced_data.data.tomo_fields.shape
+        if len(tomo_fields_shape) != 4 or any(True for dim in tomo_fields_shape if not dim):
+            raise KeyError('Unable to load the required reduced tomography stack')
+        num_tomo_stacks = tomo_fields_shape[0]
         if num_tomo_stacks == 1:
             center_stack_index = 0
-            center_stack = np.asarray(nxentry.reduced_data.data.tomo_fields[0])
-            if not center_stack.size:
-                raise KeyError('Unable to load the required reduced tomography stack')
             default = 'n'
         else:
             if self.test_mode:
                 center_stack_index = self.test_config['center_stack_index']-1 # make offset 0
+            elif self.galaxy_flag:
+                if center_stack_index is None:
+                    center_stack_index = int(num_tomo_stacks/2)
+                if center_stack_index >= num_tomo_stacks:
+                    raise ValueError(f'Invalid parameter center_stack_index ({center_stack_index})')
             else:
                 center_stack_index = input_int('\nEnter tomography stack index to calibrate the '
-                        'center axis', ge=0, le=num_tomo_stacks-1, default=int(num_tomo_stacks/2))
-            center_stack = \
-                    np.asarray(nxentry.reduced_data.data.tomo_fields[center_stack_index])
-            if not center_stack.size:
-                raise KeyError('Unable to load the required reduced tomography stack')
+                        'center axis', ge=1, le=num_tomo_stacks, default=int(1+num_tomo_stacks/2))
+                center_stack_index -= 1
             default = 'y'
 
         # Get thetas (in degrees)
@@ -331,26 +350,35 @@
             eff_pixel_size = nxentry.instrument.detector.x_pixel_size
 
         # Get cross sectional diameter
-        cross_sectional_dim = center_stack.shape[2]*eff_pixel_size
+        cross_sectional_dim = tomo_fields_shape[3]*eff_pixel_size
         logger.debug(f'cross_sectional_dim = {cross_sectional_dim}')
 
         # Determine center offset at sample row boundaries
         logger.info('Determine center offset at sample row boundaries')
 
         # Lower row center
-        # center_stack order: row,theta,column
         if self.test_mode:
             lower_row = self.test_config['lower_row']
         elif self.galaxy_flag:
-            lower_row = min(center_rows)
-            if not 0 <= lower_row < center_stack.shape[0]-1:
-                raise ValueError(f'Invalid parameter center_rows ({center_rows})')
+            if center_rows is None:
+                lower_row = 0
+            else:
+                lower_row = min(center_rows)
+                if not 0 <= lower_row < tomo_fields_shape[2]-1:
+                    raise ValueError(f'Invalid parameter center_rows ({center_rows})')
         else:
-            lower_row = select_one_image_bound(center_stack[:,0,:], 0, bound=0,
+            lower_row = select_one_image_bound(
+                    nxentry.reduced_data.data.tomo_fields[center_stack_index,0,:,:], 0, bound=0,
                     title=f'theta={round(thetas[0], 2)+0}',
-                    bound_name='row index to find lower center', default=default)
-        lower_center_offset = self._find_center_one_plane(center_stack[lower_row,:,:], lower_row,
-                thetas, eff_pixel_size, cross_sectional_dim, path=path, num_core=self.num_core)
+                    bound_name='row index to find lower center', default=default, raise_error=True)
+        logger.debug('Finding center...')
+        t0 = time()
+        lower_center_offset = self._find_center_one_plane(
+                #np.asarray(nxentry.reduced_data.data.tomo_fields[center_stack_index,:,lower_row,:]),
+                nxentry.reduced_data.data.tomo_fields[center_stack_index,:,lower_row,:],
+                lower_row, thetas, eff_pixel_size, cross_sectional_dim, path=path,
+                num_core=self.num_core)
+        logger.debug(f'... done in {time()-t0:.2f} seconds')
         logger.debug(f'lower_row = {lower_row:.2f}')
         logger.debug(f'lower_center_offset = {lower_center_offset:.2f}')
 
@@ -358,18 +386,27 @@
         if self.test_mode:
             upper_row = self.test_config['upper_row']
         elif self.galaxy_flag:
-            upper_row = max(center_rows)
-            if not lower_row < upper_row < center_stack.shape[0]:
-                raise ValueError(f'Invalid parameter center_rows ({center_rows})')
+            if center_rows is None:
+                upper_row = tomo_fields_shape[2]-1
+            else:
+                upper_row = max(center_rows)
+                if not lower_row < upper_row < tomo_fields_shape[2]:
+                    raise ValueError(f'Invalid parameter center_rows ({center_rows})')
         else:
-            upper_row = select_one_image_bound(center_stack[:,0,:], 0,
-                    bound=center_stack.shape[0]-1, title=f'theta={round(thetas[0], 2)+0}',
-                    bound_name='row index to find upper center', default=default)
-        upper_center_offset = self._find_center_one_plane(center_stack[upper_row,:,:], upper_row,
-                thetas, eff_pixel_size, cross_sectional_dim, path=path, num_core=self.num_core)
+            upper_row = select_one_image_bound(
+                    nxentry.reduced_data.data.tomo_fields[center_stack_index,0,:,:], 0,
+                    bound=tomo_fields_shape[2]-1, title=f'theta={round(thetas[0], 2)+0}',
+                    bound_name='row index to find upper center', default=default, raise_error=True)
+        logger.debug('Finding center...')
+        t0 = time()
+        upper_center_offset = self._find_center_one_plane(
+                #np.asarray(nxentry.reduced_data.data.tomo_fields[center_stack_index,:,upper_row,:]),
+                nxentry.reduced_data.data.tomo_fields[center_stack_index,:,upper_row,:],
+                upper_row, thetas, eff_pixel_size, cross_sectional_dim, path=path,
+                num_core=self.num_core)
+        logger.debug(f'... done in {time()-t0:.2f} seconds')
         logger.debug(f'upper_row = {upper_row:.2f}')
         logger.debug(f'upper_center_offset = {upper_center_offset:.2f}')
-        del center_stack
 
         center_config = {'lower_row': lower_row, 'lower_center_offset': lower_center_offset,
                 'upper_row': upper_row, 'upper_center_offset': upper_center_offset}
@@ -409,11 +446,6 @@
             raise KeyError(f'Unable to find valid reduced data in {nxentry}.')
 
         # Create an NXprocess to store image reconstruction (meta)data
-#        if 'reconstructed_data' in nxentry:
-#            logger.warning(f'Existing reconstructed data in {nxentry} will be overwritten.')
-#            del nxentry['reconstructed_data']
-#        if 'data' in nxentry and 'reconstructed_data' in nxentry.data:
-#            del nxentry.data['reconstructed_data']
         nxprocess = NXprocess()
 
         # Get rotation axis rows and centers
@@ -430,7 +462,7 @@
         thetas = np.asarray(nxentry.reduced_data.rotation_angle)
 
         # Reconstruct tomography data
-        #   reduced data axes order: stack,row,theta,column
+        #   reduced data axes order: stack,theta,row,column
         #   reconstructed data order in each stack: row/z,x,y
         #   Note: Nexus cannot follow a link if the data it points to is too big,
         #         so get the data from the actual place, not from nxentry.data
@@ -442,9 +474,15 @@
         num_tomo_stacks = nxentry.reduced_data.data.tomo_fields.shape[0]
         tomo_recon_stacks = num_tomo_stacks*[np.array([])]
         for i in range(num_tomo_stacks):
+            # Convert reduced data stack from theta,row,column to row,theta,column
+            logger.debug(f'Reading reduced data stack {i+1}...')
+            t0 = time()
             tomo_stack = np.asarray(nxentry.reduced_data.data.tomo_fields[i])
-            if not tomo_stack.size:
-                raise KeyError(f'Unable to load tomography stack {i} for reconstruction')
+            logger.debug(f'... done in {time()-t0:.2f} seconds')
+            if len(tomo_stack.shape) != 3 or any(True for dim in tomo_stack.shape if not dim):
+                raise ValueError(f'Unable to load tomography stack {i+1} for reconstruction')
+            tomo_stack = np.swapaxes(tomo_stack, 0, 1)
+            assert(len(thetas) == tomo_stack.shape[1])
             assert(0 <= lower_row < upper_row < tomo_stack.shape[0])
             center_offsets = [lower_center_offset-lower_row*center_slope,
                     upper_center_offset+(tomo_stack.shape[0]-1-upper_row)*center_slope]
@@ -453,7 +491,7 @@
             tomo_recon_stack = self._reconstruct_one_tomo_stack(tomo_stack, thetas,
                     center_offsets=center_offsets, num_core=self.num_core, algorithm='gridrec')
             logger.debug(f'... done in {time()-t0:.2f} seconds')
-            logger.info(f'Reconstruction of stack {i} took {time()-t0:.2f} seconds')
+            logger.info(f'Reconstruction of stack {i+1} took {time()-t0:.2f} seconds')
 
             # Combine stacks
             tomo_recon_stacks[i] = tomo_recon_stack
@@ -497,7 +535,7 @@
         if num_tomo_stacks == 1:
             basetitle = 'recon'
         else:
-            basetitle = f'recon stack {i}'
+            basetitle = f'recon stack {i+1}'
         for i, stack in enumerate(tomo_recon_stacks):
             title = f'{basetitle} {res_title} xslice{x_slice}'
             quick_imshow(stack[z_range[0]:z_range[1],x_slice,y_range[0]:y_range[1]],
@@ -550,7 +588,7 @@
 
         return(nxroot_copy)
 
-    def combine_data(self, nxroot):
+    def combine_data(self, nxroot, x_bounds=None, y_bounds=None):
         """Combine the reconstructed tomography stacks.
         """
         logger.info('Combine the reconstructed tomography stacks')
@@ -574,41 +612,39 @@
             raise KeyError(f'Unable to find valid reconstructed image data in {nxentry}.')
 
         # Create an NXprocess to store combined image reconstruction (meta)data
-#        if 'combined_data' in nxentry:
-#            logger.warning(f'Existing combined data in {nxentry} will be overwritten.')
-#            del nxentry['combined_data']
-#        if 'data' in nxentry 'combined_data' in nxentry.data:
-#            del nxentry.data['combined_data']
         nxprocess = NXprocess()
 
         # Get the reconstructed data
         #   reconstructed data order: stack,row(z),x,y
         #   Note: Nexus cannot follow a link if the data it points to is too big,
         #         so get the data from the actual place, not from nxentry.data
-        tomo_recon_stacks = np.asarray(nxentry.reconstructed_data.data.reconstructed_data)
-        num_tomo_stacks = tomo_recon_stacks.shape[0]
+        num_tomo_stacks = nxentry.reconstructed_data.data.reconstructed_data.shape[0]
         if num_tomo_stacks == 1:
-            return(nxroot)
+            logger.info('Only one stack available: leaving combine_data')
+            return(None)
+
+        # Combine the reconstructed stacks
+        # (load one stack at a time to reduce risk of hitting Nexus data access limit)
         t0 = time()
         logger.debug(f'Combining the reconstructed stacks ...')
-        tomo_recon_combined = tomo_recon_stacks[0,:,:,:]
+        tomo_recon_combined = np.asarray(nxentry.reconstructed_data.data.reconstructed_data[0])
         if num_tomo_stacks > 2:
             tomo_recon_combined = np.concatenate([tomo_recon_combined]+
-                    [tomo_recon_stacks[i,:,:,:] for i in range(1, num_tomo_stacks-1)])
+                    [nxentry.reconstructed_data.data.reconstructed_data[i]
+                    for i in range(1, num_tomo_stacks-1)])
         if num_tomo_stacks > 1:
             tomo_recon_combined = np.concatenate([tomo_recon_combined]+
-                    [tomo_recon_stacks[num_tomo_stacks-1,:,:,:]])
+                    [nxentry.reconstructed_data.data.reconstructed_data[num_tomo_stacks-1]])
         logger.debug(f'... done in {time()-t0:.2f} seconds')
         logger.info(f'Combining the reconstructed stacks took {time()-t0:.2f} seconds')
 
-        # Resize the combined tomography data set
+        # Resize the combined tomography data stacks
         #   combined data order: row/z,x,y
         if self.test_mode:
             x_bounds = None
             y_bounds = None
             z_bounds = self.test_config.get('z_bounds')
         elif self.galaxy_flag:
-            exit('TODO')
             if x_bounds is not None and not is_int_pair(x_bounds, ge=0,
                     lt=tomo_recon_stacks[0].shape[1]):
                 raise ValueError(f'Invalid parameter x_bounds ({x_bounds})')
@@ -699,7 +735,8 @@
             prefix = str(nxentry.instrument.detector.local_name)
             tdf_stack = dark_field.get_detector_data(prefix)
             if isinstance(tdf_stack, list):
-                exit('TODO')
+                assert(len(tdf_stack) == 1) # TODO
+                tdf_stack = tdf_stack[0]
 
         # Take median
         if tdf_stack.ndim == 2:
@@ -757,7 +794,8 @@
             prefix = str(nxentry.instrument.detector.local_name)
             tbf_stack = bright_field.get_detector_data(prefix)
             if isinstance(tbf_stack, list):
-                exit('TODO')
+                assert(len(tbf_stack) == 1) # TODO
+                tbf_stack = tbf_stack[0]
 
         # Take median if more than one image
         """Median or mean: It may be best to try the median because of some image 
@@ -818,8 +856,8 @@
             first_image = np.asarray(nxentry.instrument.detector.data[field_indices[0],:,:])
             theta = float(nxentry.sample.rotation_angle[field_indices[0]])
             z_translation_all = nxentry.sample.z_translation[field_indices]
-            z_translation_levels = sorted(list(set(z_translation_all)))
-            num_tomo_stacks = len(z_translation_levels)
+            vertical_shifts = sorted(list(set(z_translation_all)))
+            num_tomo_stacks = len(vertical_shifts)
         else:
             tomo_field_scans = nxentry.spec_scans.tomo_fields
             tomo_fields = TomoField.construct_from_nxcollection(tomo_field_scans)
@@ -1112,7 +1150,7 @@
                 if len(tomo_stacks) == 1:
                     title = f'red fullres theta {round(thetas[0], 2)+0}'
                 else:
-                    title = f'red stack {i} fullres theta {round(thetas[0], 2)+0}'
+                    title = f'red stack {i+1} fullres theta {round(thetas[0], 2)+0}'
                 quick_imshow(tomo_stack[0,:,:], title=title, path=path, save_fig=self.save_figs,
                         save_only=self.save_only, block=self.block)
 #                if not self.block:
@@ -1135,15 +1173,13 @@
 #                    if not self.block:
 #                        clear_imshow(title)
 
-            # Convert tomography stack from theta,row,column to row,theta,column
-            t0 = time()
-            tomo_stack = np.swapaxes(tomo_stack, 0, 1)
-            logger.debug(f'Converting coordinate order took {time()-t0:.2f} seconds')
-
             # Save test data to file
             if self.test_mode:
-                row_index = int(tomo_stack.shape[0]/2)
-                np.savetxt(f'{self.output_folder}/red_stack_{i+1}.txt', tomo_stack[row_index,:,:],
+#                row_index = int(tomo_stack.shape[0]/2)
+#                np.savetxt(f'{self.output_folder}/red_stack_{i+1}.txt', tomo_stack[row_index,:,:],
+#                        fmt='%.6e')
+                row_index = int(tomo_stack.shape[1]/2)
+                np.savetxt(f'{self.output_folder}/red_stack_{i+1}.txt', tomo_stack[:,row_index,:],
                         fmt='%.6e')
 
             # Combine resized stacks
@@ -1168,6 +1204,7 @@
         # Try automatic center finding routines for initial value
         # sinogram index order: theta,column
         # need column,theta for iradon, so take transpose
+        sinogram = np.asarray(sinogram)
         sinogram_T = sinogram.T
         center = sinogram.shape[1]/2
 
@@ -1499,7 +1536,7 @@
                     accept = True
             logger.debug(f'y_bounds = {y_bounds}')
 
-        # Selecting z bounds (in xy-plane) (only valid for a single image set)
+        # Selecting z bounds (in xy-plane) (only valid for a single image stack)
         if num_tomo_stacks != 1:
             z_bounds = None
         else:
@@ -1548,9 +1585,11 @@
     logger.info(f'test_mode = {test_mode}')
 
     # Check for correction modes
+    legal_modes = ['reduce_data', 'find_center', 'reconstruct_data', 'combine_data', 'all']
     if modes is None:
         modes = ['all']
-    logger.debug(f'modes {type(modes)} = {modes}')
+    if not all(True if mode in legal_modes else False for mode in modes):
+        raise ValueError(f'Invalid parameter modes ({modes})')
 
     # Instantiate Tomo object
     tomo = Tomo(num_core=num_core, output_folder=output_folder, save_figs=save_figs,
@@ -1581,9 +1620,10 @@
         data = tomo.combine_data(data)
 
     # Write output file
-    if not test_mode:
+    if data is not None and not test_mode:
         if center_data is None:
             data = tomo.write(data, output_file)
         else:
             data = tomo.write(center_data, output_file)
 
+    logger.info(f'Completed modes: {modes}')