Mercurial > repos > rv43 > tomo
view workflow/run_tomo.py @ 69:fba792d5f83b draft
planemo upload for repository https://github.com/rolfverberg/galaxytools commit ab9f412c362a4ab986d00e21d5185cfcf82485c1
author | rv43 |
---|---|
date | Fri, 10 Mar 2023 16:02:04 +0000 |
parents | |
children | 1cf15b61cd83 |
line wrap: on
line source
#!/usr/bin/env python3 import logging logger = logging.getLogger(__name__) import numpy as np try: import numexpr as ne except: pass try: import scipy.ndimage as spi except: pass from multiprocessing import cpu_count from nexusformat.nexus import * from os import mkdir from os import path as os_path try: from skimage.transform import iradon except: pass try: from skimage.restoration import denoise_tv_chambolle except: pass from time import time try: import tomopy except: pass from yaml import safe_load, safe_dump from msnctools.fit import Fit from msnctools.general import illegal_value, is_int, is_int_pair, is_num, is_index_range, \ input_int, input_num, input_yesno, input_menu, draw_mask_1d, select_image_bounds, \ select_one_image_bound, clear_imshow, quick_imshow, clear_plot, quick_plot from workflow.models import import_scanparser, FlatField, TomoField, TomoWorkflow from workflow.__version__ import __version__ num_core_tomopy_limit = 24 def nxcopy(nxobject:NXobject, exclude_nxpaths:list[str]=[], nxpath_prefix:str='') -> NXobject: '''Function that returns a copy of a nexus object, optionally exluding certain child items. :param nxobject: the original nexus object to return a "copy" of :type nxobject: nexusformat.nexus.NXobject :param exlude_nxpaths: a list of paths to child nexus objects that should be exluded from the returned "copy", defaults to `[]` :type exclude_nxpaths: list[str], optional :param nxpath_prefix: For use in recursive calls from inside this function only! :type nxpath_prefix: str :return: a copy of `nxobject` with some children optionally exluded. :rtype: NXobject ''' nxobject_copy = nxobject.__class__() if not len(nxpath_prefix): if 'default' in nxobject.attrs: nxobject_copy.attrs['default'] = nxobject.attrs['default'] else: for k, v in nxobject.attrs.items(): nxobject_copy.attrs[k] = v for k, v in nxobject.items(): nxpath = os_path.join(nxpath_prefix, k) if nxpath in exclude_nxpaths: continue if isinstance(v, NXgroup): nxobject_copy[k] = nxcopy(v, exclude_nxpaths=exclude_nxpaths, nxpath_prefix=os_path.join(nxpath_prefix, k)) else: nxobject_copy[k] = v return(nxobject_copy) class set_numexpr_threads: def __init__(self, num_core): if num_core is None or num_core < 1 or num_core > cpu_count(): self.num_core = cpu_count() else: self.num_core = num_core def __enter__(self): self.num_core_org = ne.set_num_threads(self.num_core) def __exit__(self, exc_type, exc_value, traceback): ne.set_num_threads(self.num_core_org) class Tomo: """Processing tomography data with misalignment. """ def __init__(self, galaxy_flag=False, num_core=-1, output_folder='.', save_figs=None, test_mode=False): """Initialize with optional config input file or dictionary """ if not isinstance(galaxy_flag, bool): raise ValueError(f'Invalid parameter galaxy_flag ({galaxy_flag})') self.galaxy_flag = galaxy_flag self.num_core = num_core if self.galaxy_flag: if output_folder != '.': logger.warning('Ignoring output_folder in galaxy mode') self.output_folder = '.' if test_mode != False: logger.warning('Ignoring test_mode in galaxy mode') self.test_mode = False if save_figs is not None: logger.warning('Ignoring save_figs in galaxy mode') save_figs = 'only' else: self.output_folder = os_path.abspath(output_folder) if not os_path.isdir(output_folder): mkdir(os_path.abspath(output_folder)) if not isinstance(test_mode, bool): raise ValueError(f'Invalid parameter test_mode ({test_mode})') self.test_mode = test_mode if save_figs is None: save_figs = 'no' self.test_config = {} if self.test_mode: if save_figs != 'only': logger.warning('Ignoring save_figs in test mode') save_figs = 'only' if save_figs == 'only': self.save_only = True self.save_figs = True elif save_figs == 'yes': self.save_only = False self.save_figs = True elif save_figs == 'no': self.save_only = False self.save_figs = False else: raise ValueError(f'Invalid parameter save_figs ({save_figs})') if self.save_only: self.block = False else: self.block = True if self.num_core == -1: self.num_core = cpu_count() if not is_int(self.num_core, gt=0, log=False): raise ValueError(f'Invalid parameter num_core ({num_core})') if self.num_core > cpu_count(): logger.warning(f'num_core = {self.num_core} is larger than the number of available ' f'processors and reduced to {cpu_count()}') self.num_core= cpu_count() def read(self, filename): extension = os_path.splitext(filename)[1] if extension == '.yml' or extension == '.yaml': with open(filename, 'r') as f: config = safe_load(f) # if len(config) > 1: # raise ValueError(f'Multiple root entries in {filename} not yet implemented') # if len(list(config.values())[0]) > 1: # raise ValueError(f'Multiple sample maps in {filename} not yet implemented') return(config) elif extension == '.nxs': with NXFile(filename, mode='r') as nxfile: nxroot = nxfile.readfile() return(nxroot) else: raise ValueError(f'Invalid filename extension ({extension})') def write(self, data, filename): extension = os_path.splitext(filename)[1] if extension == '.yml' or extension == '.yaml': with open(filename, 'w') as f: safe_dump(data, f) elif extension == '.nxs': data.save(filename, mode='w') elif extension == '.nc': data.to_netcdf(os_path=filename) else: raise ValueError(f'Invalid filename extension ({extension})') def gen_reduced_data(self, data, img_x_bounds=None): """Generate the reduced tomography images. """ logger.info('Generate the reduced tomography images') # Create plot galaxy path directory if needed if self.galaxy_flag and not os_path.exists('tomo_reduce_plots'): mkdir('tomo_reduce_plots') if isinstance(data, dict): # Create Nexus format object from input dictionary wf = TomoWorkflow(**data) if len(wf.sample_maps) > 1: raise ValueError(f'Multiple sample maps not yet implemented') # print(f'\nwf:\n{wf}\n') nxroot = NXroot() t0 = time() for sample_map in wf.sample_maps: logger.info(f'Start constructing the {sample_map.title} map.') import_scanparser(sample_map.station) sample_map.construct_nxentry(nxroot, include_raw_data=False) logger.info(f'Constructed all sample maps in {time()-t0:.2f} seconds.') nxentry = nxroot[nxroot.attrs['default']] # Get test mode configuration info if self.test_mode: self.test_config = data['sample_maps'][0]['test_mode'] elif isinstance(data, NXroot): nxentry = data[data.attrs['default']] else: raise ValueError(f'Invalid parameter data ({data})') # Create an NXprocess to store data reduction (meta)data reduced_data = NXprocess() # Generate dark field if 'dark_field' in nxentry['spec_scans']: reduced_data = self._gen_dark(nxentry, reduced_data) # Generate bright field reduced_data = self._gen_bright(nxentry, reduced_data) # Set vertical detector bounds for image stack img_x_bounds = self._set_detector_bounds(nxentry, reduced_data, img_x_bounds=img_x_bounds) logger.info(f'img_x_bounds = {img_x_bounds}') reduced_data['img_x_bounds'] = img_x_bounds # Set zoom and/or theta skip to reduce memory the requirement zoom_perc, num_theta_skip = self._set_zoom_or_skip() if zoom_perc is not None: reduced_data.attrs['zoom_perc'] = zoom_perc if num_theta_skip is not None: reduced_data.attrs['num_theta_skip'] = num_theta_skip # Generate reduced tomography fields reduced_data = self._gen_tomo(nxentry, reduced_data) # Create a copy of the input Nexus object and remove raw and any existing reduced data if isinstance(data, NXroot): exclude_items = [f'{nxentry._name}/reduced_data/data', f'{nxentry._name}/instrument/detector/data', f'{nxentry._name}/instrument/detector/image_key', f'{nxentry._name}/instrument/detector/sequence_number', f'{nxentry._name}/sample/rotation_angle', f'{nxentry._name}/sample/x_translation', f'{nxentry._name}/sample/z_translation', f'{nxentry._name}/data/data', f'{nxentry._name}/data/image_key', f'{nxentry._name}/data/rotation_angle', f'{nxentry._name}/data/x_translation', f'{nxentry._name}/data/z_translation'] nxroot = nxcopy(data, exclude_nxpaths=exclude_items) nxentry = nxroot[nxroot.attrs['default']] # Add the reduced data NXprocess nxentry.reduced_data = reduced_data if 'data' not in nxentry: nxentry.data = NXdata() nxentry.attrs['default'] = 'data' nxentry.data.makelink(nxentry.reduced_data.data.tomo_fields, name='reduced_data') nxentry.data.makelink(nxentry.reduced_data.rotation_angle, name='rotation_angle') nxentry.data.attrs['signal'] = 'reduced_data' return(nxroot) def find_centers(self, nxroot, center_rows=None): """Find the calibrated center axis info """ logger.info('Find the calibrated center axis info') if not isinstance(nxroot, NXroot): raise ValueError(f'Invalid parameter nxroot ({nxroot})') nxentry = nxroot[nxroot.attrs['default']] if not isinstance(nxentry, NXentry): raise ValueError(f'Invalid nxentry ({nxentry})') if self.galaxy_flag: if center_rows is None: raise ValueError(f'Missing parameter center_rows ({center_rows})') if not is_int_pair(center_rows): raise ValueError(f'Invalid parameter center_rows ({center_rows})') elif center_rows is not None: logging.warning(f'Ignoring parameter center_rows ({center_rows})') center_rows = None # Create plot galaxy path directory and path if needed if self.galaxy_flag: if not os_path.exists('tomo_find_centers_plots'): mkdir('tomo_find_centers_plots') path = 'tomo_find_centers_plots' else: path = self.output_folder # Check if reduced data is available if ('reduced_data' not in nxentry or 'reduced_data' not in nxentry.data): raise KeyError(f'Unable to find valid reduced data in {nxentry}.') # Select the image stack to calibrate the center axis # reduced data axes order: stack,row,theta,column # Note: Nexus cannot follow a link if the data it points to is too big, # so get the data from the actual place, not from nxentry.data num_tomo_stacks = nxentry.reduced_data.data.tomo_fields.shape[0] if num_tomo_stacks == 1: center_stack_index = 0 center_stack = np.asarray(nxentry.reduced_data.data.tomo_fields[0]) if not center_stack.size: raise KeyError('Unable to load the required reduced tomography stack') default = 'n' else: if self.test_mode: center_stack_index = self.test_config['center_stack_index']-1 # make offset 0 else: center_stack_index = input_int('\nEnter tomography stack index to calibrate the ' 'center axis', ge=0, le=num_tomo_stacks-1, default=int(num_tomo_stacks/2)) center_stack = \ np.asarray(nxentry.reduced_data.data.tomo_fields[center_stack_index]) if not center_stack.size: raise KeyError('Unable to load the required reduced tomography stack') default = 'y' # Get thetas (in degrees) thetas = np.asarray(nxentry.reduced_data.rotation_angle) # Get effective pixel_size if 'zoom_perc' in nxentry.reduced_data: eff_pixel_size = 100.*(nxentry.instrument.detector.x_pixel_size/ nxentry.reduced_data.attrs['zoom_perc']) else: eff_pixel_size = nxentry.instrument.detector.x_pixel_size # Get cross sectional diameter cross_sectional_dim = center_stack.shape[2]*eff_pixel_size logger.debug(f'cross_sectional_dim = {cross_sectional_dim}') # Determine center offset at sample row boundaries logger.info('Determine center offset at sample row boundaries') # Lower row center # center_stack order: row,theta,column if self.test_mode: lower_row = self.test_config['lower_row'] elif self.galaxy_flag: lower_row = min(center_rows) if not 0 <= lower_row < center_stack.shape[0]-1: raise ValueError(f'Invalid parameter center_rows ({center_rows})') else: lower_row = select_one_image_bound(center_stack[:,0,:], 0, bound=0, title=f'theta={round(thetas[0], 2)+0}', bound_name='row index to find lower center', default=default) lower_center_offset = self._find_center_one_plane(center_stack[lower_row,:,:], lower_row, thetas, eff_pixel_size, cross_sectional_dim, path=path, num_core=self.num_core) logger.debug(f'lower_row = {lower_row:.2f}') logger.debug(f'lower_center_offset = {lower_center_offset:.2f}') # Upper row center if self.test_mode: upper_row = self.test_config['upper_row'] elif self.galaxy_flag: upper_row = max(center_rows) if not lower_row < upper_row < center_stack.shape[0]: raise ValueError(f'Invalid parameter center_rows ({center_rows})') else: upper_row = select_one_image_bound(center_stack[:,0,:], 0, bound=center_stack.shape[0]-1, title=f'theta={round(thetas[0], 2)+0}', bound_name='row index to find upper center', default=default) upper_center_offset = self._find_center_one_plane(center_stack[upper_row,:,:], upper_row, thetas, eff_pixel_size, cross_sectional_dim, path=path, num_core=self.num_core) logger.debug(f'upper_row = {upper_row:.2f}') logger.debug(f'upper_center_offset = {upper_center_offset:.2f}') del center_stack center_config = {'lower_row': lower_row, 'lower_center_offset': lower_center_offset, 'upper_row': upper_row, 'upper_center_offset': upper_center_offset} if num_tomo_stacks > 1: center_config['center_stack_index'] = center_stack_index+1 # save as offset 1 # Save test data to file if self.test_mode: with open(f'{self.output_folder}/center_config.yaml', 'w') as f: safe_dump(center_config, f) return(center_config) def reconstruct_data(self, nxroot, center_info, x_bounds=None, y_bounds=None): """Reconstruct the tomography data. """ logger.info('Reconstruct the tomography data') if not isinstance(nxroot, NXroot): raise ValueError(f'Invalid parameter nxroot ({nxroot})') nxentry = nxroot[nxroot.attrs['default']] if not isinstance(nxentry, NXentry): raise ValueError(f'Invalid nxentry ({nxentry})') if not isinstance(center_info, dict): raise ValueError(f'Invalid parameter center_info ({center_info})') # Create plot galaxy path directory and path if needed if self.galaxy_flag: if not os_path.exists('tomo_reconstruct_plots'): mkdir('tomo_reconstruct_plots') path = 'tomo_reconstruct_plots' else: path = self.output_folder # Check if reduced data is available if ('reduced_data' not in nxentry or 'reduced_data' not in nxentry.data): raise KeyError(f'Unable to find valid reduced data in {nxentry}.') # Create an NXprocess to store image reconstruction (meta)data # if 'reconstructed_data' in nxentry: # logger.warning(f'Existing reconstructed data in {nxentry} will be overwritten.') # del nxentry['reconstructed_data'] # if 'data' in nxentry and 'reconstructed_data' in nxentry.data: # del nxentry.data['reconstructed_data'] nxprocess = NXprocess() # Get rotation axis rows and centers lower_row = center_info.get('lower_row') lower_center_offset = center_info.get('lower_center_offset') upper_row = center_info.get('upper_row') upper_center_offset = center_info.get('upper_center_offset') if (lower_row is None or lower_center_offset is None or upper_row is None or upper_center_offset is None): raise KeyError(f'Unable to find valid calibrated center axis info in {center_info}.') center_slope = (upper_center_offset-lower_center_offset)/(upper_row-lower_row) # Get thetas (in degrees) thetas = np.asarray(nxentry.reduced_data.rotation_angle) # Reconstruct tomography data # reduced data axes order: stack,row,theta,column # reconstructed data order in each stack: row/z,x,y # Note: Nexus cannot follow a link if the data it points to is too big, # so get the data from the actual place, not from nxentry.data if 'zoom_perc' in nxentry.reduced_data: res_title = f'{nxentry.reduced_data.attrs["zoom_perc"]}p' else: res_title = 'fullres' load_error = False num_tomo_stacks = nxentry.reduced_data.data.tomo_fields.shape[0] tomo_recon_stacks = num_tomo_stacks*[np.array([])] for i in range(num_tomo_stacks): tomo_stack = np.asarray(nxentry.reduced_data.data.tomo_fields[i]) if not tomo_stack.size: raise KeyError(f'Unable to load tomography stack {i} for reconstruction') assert(0 <= lower_row < upper_row < tomo_stack.shape[0]) center_offsets = [lower_center_offset-lower_row*center_slope, upper_center_offset+(tomo_stack.shape[0]-1-upper_row)*center_slope] t0 = time() logger.debug(f'Running _reconstruct_one_tomo_stack on {self.num_core} cores ...') tomo_recon_stack = self._reconstruct_one_tomo_stack(tomo_stack, thetas, center_offsets=center_offsets, num_core=self.num_core, algorithm='gridrec') logger.debug(f'... done in {time()-t0:.2f} seconds') logger.info(f'Reconstruction of stack {i} took {time()-t0:.2f} seconds') # Combine stacks tomo_recon_stacks[i] = tomo_recon_stack # Resize the reconstructed tomography data # reconstructed data order in each stack: row/z,x,y if self.test_mode: x_bounds = self.test_config.get('x_bounds') y_bounds = self.test_config.get('y_bounds') z_bounds = None elif self.galaxy_flag: if x_bounds is not None and not is_int_pair(x_bounds, ge=0, lt=tomo_recon_stacks[0].shape[1]): raise ValueError(f'Invalid parameter x_bounds ({x_bounds})') if y_bounds is not None and not is_int_pair(y_bounds, ge=0, lt=tomo_recon_stacks[0].shape[1]): raise ValueError(f'Invalid parameter y_bounds ({y_bounds})') z_bounds = None else: x_bounds, y_bounds, z_bounds = self._resize_reconstructed_data(tomo_recon_stacks) if x_bounds is None: x_range = (0, tomo_recon_stacks[0].shape[1]) x_slice = int(x_range[1]/2) else: x_range = (min(x_bounds), max(x_bounds)) x_slice = int((x_bounds[0]+x_bounds[1])/2) if y_bounds is None: y_range = (0, tomo_recon_stacks[0].shape[2]) y_slice = int(y_range[1]/2) else: y_range = (min(y_bounds), max(y_bounds)) y_slice = int((y_bounds[0]+y_bounds[1])/2) if z_bounds is None: z_range = (0, tomo_recon_stacks[0].shape[0]) z_slice = int(z_range[1]/2) else: z_range = (min(z_bounds), max(z_bounds)) z_slice = int((z_bounds[0]+z_bounds[1])/2) # Plot a few reconstructed image slices if num_tomo_stacks == 1: basetitle = 'recon' else: basetitle = f'recon stack {i}' for i, stack in enumerate(tomo_recon_stacks): title = f'{basetitle} {res_title} xslice{x_slice}' quick_imshow(stack[z_range[0]:z_range[1],x_slice,y_range[0]:y_range[1]], title=title, path=path, save_fig=self.save_figs, save_only=self.save_only, block=self.block) title = f'{basetitle} {res_title} yslice{y_slice}' quick_imshow(stack[z_range[0]:z_range[1],x_range[0]:x_range[1],y_slice], title=title, path=path, save_fig=self.save_figs, save_only=self.save_only, block=self.block) title = f'{basetitle} {res_title} zslice{z_slice}' quick_imshow(stack[z_slice,x_range[0]:x_range[1],y_range[0]:y_range[1]], title=title, path=path, save_fig=self.save_figs, save_only=self.save_only, block=self.block) # Save test data to file # reconstructed data order in each stack: row/z,x,y if self.test_mode: for i, stack in enumerate(tomo_recon_stacks): np.savetxt(f'{self.output_folder}/recon_stack_{i+1}.txt', stack[z_slice,x_range[0]:x_range[1],y_range[0]:y_range[1]], fmt='%.6e') # Add image reconstruction to reconstructed data NXprocess # reconstructed data order in each stack: row/z,x,y nxprocess.data = NXdata() nxprocess.attrs['default'] = 'data' for k, v in center_info.items(): nxprocess[k] = v if x_bounds is not None: nxprocess.x_bounds = x_bounds if y_bounds is not None: nxprocess.y_bounds = y_bounds if z_bounds is not None: nxprocess.z_bounds = z_bounds nxprocess.data['reconstructed_data'] = np.asarray([stack[z_range[0]:z_range[1], x_range[0]:x_range[1],y_range[0]:y_range[1]] for stack in tomo_recon_stacks]) nxprocess.data.attrs['signal'] = 'reconstructed_data' # Create a copy of the input Nexus object and remove reduced data exclude_items = [f'{nxentry._name}/reduced_data/data', f'{nxentry._name}/data/reduced_data'] nxroot_copy = nxcopy(nxroot, exclude_nxpaths=exclude_items) # Add the reconstructed data NXprocess to the new Nexus object nxentry_copy = nxroot_copy[nxroot_copy.attrs['default']] nxentry_copy.reconstructed_data = nxprocess if 'data' not in nxentry_copy: nxentry_copy.data = NXdata() nxentry_copy.attrs['default'] = 'data' nxentry_copy.data.makelink(nxprocess.data.reconstructed_data, name='reconstructed_data') nxentry_copy.data.attrs['signal'] = 'reconstructed_data' return(nxroot_copy) def combine_data(self, nxroot): """Combine the reconstructed tomography stacks. """ logger.info('Combine the reconstructed tomography stacks') if not isinstance(nxroot, NXroot): raise ValueError(f'Invalid parameter nxroot ({nxroot})') nxentry = nxroot[nxroot.attrs['default']] if not isinstance(nxentry, NXentry): raise ValueError(f'Invalid nxentry ({nxentry})') # Create plot galaxy path directory and path if needed if self.galaxy_flag: if not os_path.exists('tomo_combine_plots'): mkdir('tomo_combine_plots') path = 'tomo_combine_plots' else: path = self.output_folder # Check if reconstructed image data is available if ('reconstructed_data' not in nxentry or 'reconstructed_data' not in nxentry.data): raise KeyError(f'Unable to find valid reconstructed image data in {nxentry}.') # Create an NXprocess to store combined image reconstruction (meta)data # if 'combined_data' in nxentry: # logger.warning(f'Existing combined data in {nxentry} will be overwritten.') # del nxentry['combined_data'] # if 'data' in nxentry 'combined_data' in nxentry.data: # del nxentry.data['combined_data'] nxprocess = NXprocess() # Get the reconstructed data # reconstructed data order: stack,row(z),x,y # Note: Nexus cannot follow a link if the data it points to is too big, # so get the data from the actual place, not from nxentry.data tomo_recon_stacks = np.asarray(nxentry.reconstructed_data.data.reconstructed_data) num_tomo_stacks = tomo_recon_stacks.shape[0] if num_tomo_stacks == 1: return(nxroot) t0 = time() logger.debug(f'Combining the reconstructed stacks ...') tomo_recon_combined = tomo_recon_stacks[0,:,:,:] if num_tomo_stacks > 2: tomo_recon_combined = np.concatenate([tomo_recon_combined]+ [tomo_recon_stacks[i,:,:,:] for i in range(1, num_tomo_stacks-1)]) if num_tomo_stacks > 1: tomo_recon_combined = np.concatenate([tomo_recon_combined]+ [tomo_recon_stacks[num_tomo_stacks-1,:,:,:]]) logger.debug(f'... done in {time()-t0:.2f} seconds') logger.info(f'Combining the reconstructed stacks took {time()-t0:.2f} seconds') # Resize the combined tomography data set # combined data order: row/z,x,y if self.test_mode: x_bounds = None y_bounds = None z_bounds = self.test_config.get('z_bounds') elif self.galaxy_flag: exit('TODO') if x_bounds is not None and not is_int_pair(x_bounds, ge=0, lt=tomo_recon_stacks[0].shape[1]): raise ValueError(f'Invalid parameter x_bounds ({x_bounds})') if y_bounds is not None and not is_int_pair(y_bounds, ge=0, lt=tomo_recon_stacks[0].shape[1]): raise ValueError(f'Invalid parameter y_bounds ({y_bounds})') z_bounds = None else: x_bounds, y_bounds, z_bounds = self._resize_reconstructed_data(tomo_recon_combined, z_only=True) if x_bounds is None: x_range = (0, tomo_recon_combined.shape[1]) x_slice = int(x_range[1]/2) else: x_range = x_bounds x_slice = int((x_bounds[0]+x_bounds[1])/2) if y_bounds is None: y_range = (0, tomo_recon_combined.shape[2]) y_slice = int(y_range[1]/2) else: y_range = y_bounds y_slice = int((y_bounds[0]+y_bounds[1])/2) if z_bounds is None: z_range = (0, tomo_recon_combined.shape[0]) z_slice = int(z_range[1]/2) else: z_range = z_bounds z_slice = int((z_bounds[0]+z_bounds[1])/2) # Plot a few combined image slices quick_imshow(tomo_recon_combined[z_range[0]:z_range[1],x_slice,y_range[0]:y_range[1]], title=f'recon combined xslice{x_slice}', path=path, save_fig=self.save_figs, save_only=self.save_only, block=self.block) quick_imshow(tomo_recon_combined[z_range[0]:z_range[1],x_range[0]:x_range[1],y_slice], title=f'recon combined yslice{y_slice}', path=path, save_fig=self.save_figs, save_only=self.save_only, block=self.block) quick_imshow(tomo_recon_combined[z_slice,x_range[0]:x_range[1],y_range[0]:y_range[1]], title=f'recon combined zslice{z_slice}', path=path, save_fig=self.save_figs, save_only=self.save_only, block=self.block) # Save test data to file # combined data order: row/z,x,y if self.test_mode: np.savetxt(f'{self.output_folder}/recon_combined.txt', tomo_recon_combined[ z_slice,x_range[0]:x_range[1],y_range[0]:y_range[1]], fmt='%.6e') # Add image reconstruction to reconstructed data NXprocess # combined data order: row/z,x,y nxprocess.data = NXdata() nxprocess.attrs['default'] = 'data' if x_bounds is not None: nxprocess.x_bounds = x_bounds if y_bounds is not None: nxprocess.y_bounds = y_bounds if z_bounds is not None: nxprocess.z_bounds = z_bounds nxprocess.data['combined_data'] = tomo_recon_combined nxprocess.data.attrs['signal'] = 'combined_data' # Create a copy of the input Nexus object and remove reconstructed data exclude_items = [f'{nxentry._name}/reconstructed_data/data', f'{nxentry._name}/data/reconstructed_data'] nxroot_copy = nxcopy(nxroot, exclude_nxpaths=exclude_items) # Add the combined data NXprocess to the new Nexus object nxentry_copy = nxroot_copy[nxroot_copy.attrs['default']] nxentry_copy.combined_data = nxprocess if 'data' not in nxentry_copy: nxentry_copy.data = NXdata() nxentry_copy.attrs['default'] = 'data' nxentry_copy.data.makelink(nxprocess.data.combined_data, name='combined_data') nxentry_copy.data.attrs['signal'] = 'combined_data' return(nxroot_copy) def _gen_dark(self, nxentry, reduced_data): """Generate dark field. """ # Get the dark field images image_key = nxentry.instrument.detector.get('image_key', None) if image_key and 'data' in nxentry.instrument.detector: field_indices = [index for index, key in enumerate(image_key) if key == 2] tdf_stack = nxentry.instrument.detector.data[field_indices,:,:] # RV the default NXtomo form does not accomodate bright or dark field stacks else: dark_field_scans = nxentry.spec_scans.dark_field dark_field = FlatField.construct_from_nxcollection(dark_field_scans) prefix = str(nxentry.instrument.detector.local_name) tdf_stack = dark_field.get_detector_data(prefix) if isinstance(tdf_stack, list): exit('TODO') # Take median if tdf_stack.ndim == 2: tdf = tdf_stack elif tdf_stack.ndim == 3: tdf = np.median(tdf_stack, axis=0) del tdf_stack else: raise ValueError(f'Invalid tdf_stack shape ({tdf_stack.shape})') # Remove dark field intensities above the cutoff #RV tdf_cutoff = None tdf_cutoff = tdf.min()+2*(np.median(tdf)-tdf.min()) logger.debug(f'tdf_cutoff = {tdf_cutoff}') if tdf_cutoff is not None: if not is_num(tdf_cutoff, ge=0): logger.warning(f'Ignoring illegal value of tdf_cutoff {tdf_cutoff}') else: tdf[tdf > tdf_cutoff] = np.nan logger.debug(f'tdf_cutoff = {tdf_cutoff}') # Remove nans tdf_mean = np.nanmean(tdf) logger.debug(f'tdf_mean = {tdf_mean}') np.nan_to_num(tdf, copy=False, nan=tdf_mean, posinf=tdf_mean, neginf=0.) # Plot dark field if self.galaxy_flag: quick_imshow(tdf, title='dark field', path='tomo_reduce_plots', save_fig=self.save_figs, save_only=self.save_only) elif not self.test_mode: quick_imshow(tdf, title='dark field', path=self.output_folder, save_fig=self.save_figs, save_only=self.save_only) clear_imshow('dark field') # quick_imshow(tdf, title='dark field', block=True) # Add dark field to reduced data NXprocess reduced_data.data = NXdata() reduced_data.data['dark_field'] = tdf return(reduced_data) def _gen_bright(self, nxentry, reduced_data): """Generate bright field. """ # Get the bright field images image_key = nxentry.instrument.detector.get('image_key', None) if image_key and 'data' in nxentry.instrument.detector: field_indices = [index for index, key in enumerate(image_key) if key == 1] tbf_stack = nxentry.instrument.detector.data[field_indices,:,:] # RV the default NXtomo form does not accomodate bright or dark field stacks else: bright_field_scans = nxentry.spec_scans.bright_field bright_field = FlatField.construct_from_nxcollection(bright_field_scans) prefix = str(nxentry.instrument.detector.local_name) tbf_stack = bright_field.get_detector_data(prefix) if isinstance(tbf_stack, list): exit('TODO') # Take median if more than one image """Median or mean: It may be best to try the median because of some image artifacts that arise due to crinkles in the upstream kapton tape windows causing some phase contrast images to appear on the detector. One thing that also may be useful in a future implementation is to do a brightfield adjustment on EACH frame of the tomo based on a ROI in the corner of the frame where there is no sample but there is the direct X-ray beam because there is frame to frame fluctuations from the incoming beam. We don’t typically account for them but potentially could. """ if tbf_stack.ndim == 2: tbf = tbf_stack elif tbf_stack.ndim == 3: tbf = np.median(tbf_stack, axis=0) del tbf_stack else: raise ValueError(f'Invalid tbf_stack shape ({tbf_stacks.shape})') # Subtract dark field if 'data' in reduced_data and 'dark_field' in reduced_data.data: tbf -= reduced_data.data.dark_field else: logger.warning('Dark field unavailable') # Set any non-positive values to one # (avoid negative bright field values for spikes in dark field) tbf[tbf < 1] = 1 # Plot bright field if self.galaxy_flag: quick_imshow(tbf, title='bright field', path='tomo_reduce_plots', save_fig=self.save_figs, save_only=self.save_only) elif not self.test_mode: quick_imshow(tbf, title='bright field', path=self.output_folder, save_fig=self.save_figs, save_only=self.save_only) clear_imshow('bright field') # quick_imshow(tbf, title='bright field', block=True) # Add bright field to reduced data NXprocess if 'data' not in reduced_data: reduced_data.data = NXdata() reduced_data.data['bright_field'] = tbf return(reduced_data) def _set_detector_bounds(self, nxentry, reduced_data, img_x_bounds=None): """Set vertical detector bounds for each image stack. Right now the range is the same for each set in the image stack. """ if self.test_mode: return(tuple(self.test_config['img_x_bounds'])) # Get the first tomography image and the reference heights image_key = nxentry.instrument.detector.get('image_key', None) if image_key and 'data' in nxentry.instrument.detector: field_indices = [index for index, key in enumerate(image_key) if key == 0] first_image = np.asarray(nxentry.instrument.detector.data[field_indices[0],:,:]) theta = float(nxentry.sample.rotation_angle[field_indices[0]]) z_translation_all = nxentry.sample.z_translation[field_indices] z_translation_levels = sorted(list(set(z_translation_all))) num_tomo_stacks = len(z_translation_levels) else: tomo_field_scans = nxentry.spec_scans.tomo_fields tomo_fields = TomoField.construct_from_nxcollection(tomo_field_scans) vertical_shifts = tomo_fields.get_vertical_shifts() if not isinstance(vertical_shifts, list): vertical_shifts = [vertical_shifts] prefix = str(nxentry.instrument.detector.local_name) t0 = time() first_image = tomo_fields.get_detector_data(prefix, tomo_fields.scan_numbers[0], 0) logger.debug(f'Getting first image took {time()-t0:.2f} seconds') num_tomo_stacks = len(tomo_fields.scan_numbers) theta = tomo_fields.theta_range['start'] # Select image bounds title = f'tomography image at theta={round(theta, 2)+0}' if (img_x_bounds is not None and not is_index_range(img_x_bounds, ge=0, le=first_image.shape[0])): raise ValueError(f'Invalid parameter img_x_bounds ({img_x_bounds})') if nxentry.instrument.source.attrs['station'] in ('id1a3', 'id3a'): pixel_size = nxentry.instrument.detector.x_pixel_size # Try to get a fit from the bright field tbf = np.asarray(reduced_data.data.bright_field) tbf_shape = tbf.shape x_sum = np.sum(tbf, 1) x_sum_min = x_sum.min() x_sum_max = x_sum.max() fit = Fit.fit_data(x_sum, 'rectangle', x=np.array(range(len(x_sum))), form='atan', guess=True) parameters = fit.best_values x_low_fit = parameters.get('center1', None) x_upp_fit = parameters.get('center2', None) sig_low = parameters.get('sigma1', None) sig_upp = parameters.get('sigma2', None) have_fit = fit.success and x_low_fit is not None and x_upp_fit is not None and \ sig_low is not None and sig_upp is not None and \ 0 <= x_low_fit < x_upp_fit <= x_sum.size and \ (sig_low+sig_upp)/(x_upp_fit-x_low_fit) < 0.1 if have_fit: # Set a 5% margin on each side margin = 0.05*(x_upp_fit-x_low_fit) x_low_fit = max(0, x_low_fit-margin) x_upp_fit = min(tbf_shape[0], x_upp_fit+margin) if num_tomo_stacks == 1: if have_fit: # Set the default range to enclose the full fitted window x_low = int(x_low_fit) x_upp = int(x_upp_fit) else: # Center a default range of 1 mm (RV: can we get this from the slits?) num_x_min = int((1.0-0.5*pixel_size)/pixel_size) x_low = int(0.5*(tbf_shape[0]-num_x_min)) x_upp = x_low+num_x_min else: # Get the default range from the reference heights delta_z = vertical_shifts[1]-vertical_shifts[0] for i in range(2, num_tomo_stacks): delta_z = min(delta_z, vertical_shifts[i]-vertical_shifts[i-1]) logger.debug(f'delta_z = {delta_z}') num_x_min = int((delta_z-0.5*pixel_size)/pixel_size) logger.debug(f'num_x_min = {num_x_min}') if num_x_min > tbf_shape[0]: logger.warning('Image bounds and pixel size prevent seamless stacking') if have_fit: # Center the default range relative to the fitted window x_low = int(0.5*(x_low_fit+x_upp_fit-num_x_min)) x_upp = x_low+num_x_min else: # Center the default range x_low = int(0.5*(tbf_shape[0]-num_x_min)) x_upp = x_low+num_x_min if self.galaxy_flag: img_x_bounds = (x_low, x_upp) else: tmp = np.copy(tbf) tmp_max = tmp.max() tmp[x_low,:] = tmp_max tmp[x_upp-1,:] = tmp_max quick_imshow(tmp, title='bright field') tmp = np.copy(first_image) tmp_max = tmp.max() tmp[x_low,:] = tmp_max tmp[x_upp-1,:] = tmp_max quick_imshow(tmp, title=title) del tmp quick_plot((range(x_sum.size), x_sum), ([x_low, x_low], [x_sum_min, x_sum_max], 'r-'), ([x_upp, x_upp], [x_sum_min, x_sum_max], 'r-'), title='sum over theta and y') print(f'lower bound = {x_low} (inclusive)') print(f'upper bound = {x_upp} (exclusive)]') accept = input_yesno('Accept these bounds (y/n)?', 'y') clear_imshow('bright field') clear_imshow(title) clear_plot('sum over theta and y') if accept: img_x_bounds = (x_low, x_upp) else: while True: mask, img_x_bounds = draw_mask_1d(x_sum, title='select x data range', legend='sum over theta and y') if len(img_x_bounds) == 1: break else: print(f'Choose a single connected data range') img_x_bounds = tuple(img_x_bounds[0]) if (num_tomo_stacks > 1 and img_x_bounds[1]-img_x_bounds[0]+1 < int((delta_z-0.5*pixel_size)/pixel_size)): logger.warning('Image bounds and pixel size prevent seamless stacking') else: if num_tomo_stacks > 1: raise NotImplementedError('Selecting image bounds for multiple stacks on FMB') # For FMB: use the first tomography image to select range # RV: revisit if they do tomography with multiple stacks x_sum = np.sum(first_image, 1) x_sum_min = x_sum.min() x_sum_max = x_sum.max() if self.galaxy_flag: if img_x_bounds is None: img_x_bounds = (0, first_image.shape[0]) else: quick_imshow(first_image, title=title) print('Select vertical data reduction range from first tomography image') img_x_bounds = select_image_bounds(first_image, 0, title=title) clear_imshow(title) if img_x_bounds is None: raise ValueError('Unable to select image bounds') # Plot results if self.galaxy_flag: path = 'tomo_reduce_plots' else: path = self.output_folder x_low = img_x_bounds[0] x_upp = img_x_bounds[1] tmp = np.copy(first_image) tmp_max = tmp.max() tmp[x_low,:] = tmp_max tmp[x_upp-1,:] = tmp_max quick_imshow(tmp, title=title, path=path, save_fig=self.save_figs, save_only=self.save_only, block=self.block) del tmp quick_plot((range(x_sum.size), x_sum), ([x_low, x_low], [x_sum_min, x_sum_max], 'r-'), ([x_upp, x_upp], [x_sum_min, x_sum_max], 'r-'), title='sum over theta and y', path=path, save_fig=self.save_figs, save_only=self.save_only, block=self.block) return(img_x_bounds) def _set_zoom_or_skip(self): """Set zoom and/or theta skip to reduce memory the requirement for the analysis. """ # if input_yesno('\nDo you want to zoom in to reduce memory requirement (y/n)?', 'n'): # zoom_perc = input_int(' Enter zoom percentage', ge=1, le=100) # else: # zoom_perc = None zoom_perc = None # if input_yesno('Do you want to skip thetas to reduce memory requirement (y/n)?', 'n'): # num_theta_skip = input_int(' Enter the number skip theta interval', ge=0, # lt=num_theta) # else: # num_theta_skip = None num_theta_skip = None logger.debug(f'zoom_perc = {zoom_perc}') logger.debug(f'num_theta_skip = {num_theta_skip}') return(zoom_perc, num_theta_skip) def _gen_tomo(self, nxentry, reduced_data): """Generate tomography fields. """ # Get full bright field tbf = np.asarray(reduced_data.data.bright_field) tbf_shape = tbf.shape # Get image bounds img_x_bounds = tuple(reduced_data.get('img_x_bounds', (0, tbf_shape[0]))) img_y_bounds = tuple(reduced_data.get('img_y_bounds', (0, tbf_shape[1]))) # Get resized dark field # if 'dark_field' in data: # tbf = np.asarray(reduced_data.data.dark_field[ # img_x_bounds[0]:img_x_bounds[1],img_y_bounds[0]:img_y_bounds[1]]) # else: # logger.warning('Dark field unavailable') # tdf = None tdf = None # Resize bright field if img_x_bounds != (0, tbf.shape[0]) or img_y_bounds != (0, tbf.shape[1]): tbf = tbf[img_x_bounds[0]:img_x_bounds[1],img_y_bounds[0]:img_y_bounds[1]] # Get the tomography images image_key = nxentry.instrument.detector.get('image_key', None) if image_key and 'data' in nxentry.instrument.detector: field_indices_all = [index for index, key in enumerate(image_key) if key == 0] z_translation_all = nxentry.sample.z_translation[field_indices_all] z_translation_levels = sorted(list(set(z_translation_all))) num_tomo_stacks = len(z_translation_levels) tomo_stacks = num_tomo_stacks*[np.array([])] horizontal_shifts = [] vertical_shifts = [] thetas = None tomo_stacks = [] for i, z_translation in enumerate(z_translation_levels): field_indices = [field_indices_all[index] for index, z in enumerate(z_translation_all) if z == z_translation] horizontal_shift = list(set(nxentry.sample.x_translation[field_indices])) assert(len(horizontal_shift) == 1) horizontal_shifts += horizontal_shift vertical_shift = list(set(nxentry.sample.z_translation[field_indices])) assert(len(vertical_shift) == 1) vertical_shifts += vertical_shift sequence_numbers = nxentry.instrument.detector.sequence_number[field_indices] if thetas is None: thetas = np.asarray(nxentry.sample.rotation_angle[field_indices]) \ [sequence_numbers] else: assert(all(thetas[i] == nxentry.sample.rotation_angle[field_indices[index]] for i, index in enumerate(sequence_numbers))) assert(list(set(sequence_numbers)) == [i for i in range(len(sequence_numbers))]) if list(sequence_numbers) == [i for i in range(len(sequence_numbers))]: tomo_stack = np.asarray(nxentry.instrument.detector.data[field_indices]) else: raise ValueError('Unable to load the tomography images') tomo_stacks.append(tomo_stack) else: tomo_field_scans = nxentry.spec_scans.tomo_fields tomo_fields = TomoField.construct_from_nxcollection(tomo_field_scans) horizontal_shifts = tomo_fields.get_horizontal_shifts() vertical_shifts = tomo_fields.get_vertical_shifts() prefix = str(nxentry.instrument.detector.local_name) t0 = time() tomo_stacks = tomo_fields.get_detector_data(prefix) logger.debug(f'Getting tomography images took {time()-t0:.2f} seconds') logger.debug(f'Getting all images took {time()-t0:.2f} seconds') thetas = np.linspace(tomo_fields.theta_range['start'], tomo_fields.theta_range['end'], tomo_fields.theta_range['num']) if not isinstance(tomo_stacks, list): horizontal_shifts = [horizontal_shifts] vertical_shifts = [vertical_shifts] tomo_stacks = [tomo_stacks] reduced_tomo_stacks = [] if self.galaxy_flag: path = 'tomo_reduce_plots' else: path = self.output_folder for i, tomo_stack in enumerate(tomo_stacks): # Resize the tomography images # Right now the range is the same for each set in the image stack. if img_x_bounds != (0, tbf.shape[0]) or img_y_bounds != (0, tbf.shape[1]): t0 = time() tomo_stack = tomo_stack[:,img_x_bounds[0]:img_x_bounds[1], img_y_bounds[0]:img_y_bounds[1]].astype('float64') logger.debug(f'Resizing tomography images took {time()-t0:.2f} seconds') # Subtract dark field if tdf is not None: t0 = time() with set_numexpr_threads(self.num_core): ne.evaluate('tomo_stack-tdf', out=tomo_stack) logger.debug(f'Subtracting dark field took {time()-t0:.2f} seconds') # Normalize t0 = time() with set_numexpr_threads(self.num_core): ne.evaluate('tomo_stack/tbf', out=tomo_stack, truediv=True) logger.debug(f'Normalizing took {time()-t0:.2f} seconds') # Remove non-positive values and linearize data t0 = time() cutoff = 1.e-6 with set_numexpr_threads(self.num_core): ne.evaluate('where(tomo_stack<cutoff, cutoff, tomo_stack)', out=tomo_stack) with set_numexpr_threads(self.num_core): ne.evaluate('-log(tomo_stack)', out=tomo_stack) logger.debug('Removing non-positive values and linearizing data took '+ f'{time()-t0:.2f} seconds') # Get rid of nans/infs that may be introduced by normalization t0 = time() np.where(np.isfinite(tomo_stack), tomo_stack, 0.) logger.debug(f'Remove nans/infs took {time()-t0:.2f} seconds') # Downsize tomography stack to smaller size # TODO use theta_skip as well tomo_stack = tomo_stack.astype('float32') if not self.test_mode: if len(tomo_stacks) == 1: title = f'red fullres theta {round(thetas[0], 2)+0}' else: title = f'red stack {i} fullres theta {round(thetas[0], 2)+0}' quick_imshow(tomo_stack[0,:,:], title=title, path=path, save_fig=self.save_figs, save_only=self.save_only, block=self.block) # if not self.block: # clear_imshow(title) if False and zoom_perc != 100: t0 = time() logger.debug(f'Zooming in ...') tomo_zoom_list = [] for j in range(tomo_stack.shape[0]): tomo_zoom = spi.zoom(tomo_stack[j,:,:], 0.01*zoom_perc) tomo_zoom_list.append(tomo_zoom) tomo_stack = np.stack([tomo_zoom for tomo_zoom in tomo_zoom_list]) logger.debug(f'... done in {time()-t0:.2f} seconds') logger.info(f'Zooming in took {time()-t0:.2f} seconds') del tomo_zoom_list if not self.test_mode: title = f'red stack {zoom_perc}p theta {round(thetas[0], 2)+0}' quick_imshow(tomo_stack[0,:,:], title=title, path=path, save_fig=self.save_figs, save_only=self.save_only, block=self.block) # if not self.block: # clear_imshow(title) # Convert tomography stack from theta,row,column to row,theta,column t0 = time() tomo_stack = np.swapaxes(tomo_stack, 0, 1) logger.debug(f'Converting coordinate order took {time()-t0:.2f} seconds') # Save test data to file if self.test_mode: row_index = int(tomo_stack.shape[0]/2) np.savetxt(f'{self.output_folder}/red_stack_{i+1}.txt', tomo_stack[row_index,:,:], fmt='%.6e') # Combine resized stacks reduced_tomo_stacks.append(tomo_stack) # Add tomo field info to reduced data NXprocess reduced_data['rotation_angle'] = thetas reduced_data['x_translation'] = np.asarray(horizontal_shifts) reduced_data['z_translation'] = np.asarray(vertical_shifts) reduced_data.data['tomo_fields'] = np.asarray(reduced_tomo_stacks) if tdf is not None: del tdf del tbf return(reduced_data) def _find_center_one_plane(self, sinogram, row, thetas, eff_pixel_size, cross_sectional_dim, path=None, tol=0.1, num_core=1): """Find center for a single tomography plane. """ # Try automatic center finding routines for initial value # sinogram index order: theta,column # need column,theta for iradon, so take transpose sinogram_T = sinogram.T center = sinogram.shape[1]/2 # Try using Nghia Vo’s method t0 = time() if num_core > num_core_tomopy_limit: logger.debug(f'Running find_center_vo on {num_core_tomopy_limit} cores ...') tomo_center = tomopy.find_center_vo(sinogram, ncore=num_core_tomopy_limit) else: logger.debug(f'Running find_center_vo on {num_core} cores ...') tomo_center = tomopy.find_center_vo(sinogram, ncore=num_core) logger.debug(f'... done in {time()-t0:.2f} seconds') logger.info(f'Finding the center using Nghia Vo’s method took {time()-t0:.2f} seconds') center_offset_vo = tomo_center-center logger.info(f'Center at row {row} using Nghia Vo’s method = {center_offset_vo:.2f}') t0 = time() logger.debug(f'Running _reconstruct_one_plane on {self.num_core} cores ...') recon_plane = self._reconstruct_one_plane(sinogram_T, tomo_center, thetas, eff_pixel_size, cross_sectional_dim, False, num_core) logger.debug(f'... done in {time()-t0:.2f} seconds') logger.info(f'Reconstructing row {row} took {time()-t0:.2f} seconds') title = f'edges row{row} center offset{center_offset_vo:.2f} Vo' self._plot_edges_one_plane(recon_plane, title, path=path) # Try using phase correlation method # if input_yesno('Try finding center using phase correlation (y/n)?', 'n'): # t0 = time() # logger.debug(f'Running find_center_pc ...') # tomo_center = tomopy.find_center_pc(sinogram, sinogram, tol=0.1, rotc_guess=tomo_center) # error = 1. # while error > tol: # prev = tomo_center # tomo_center = tomopy.find_center_pc(sinogram, sinogram, tol=tol, # rotc_guess=tomo_center) # error = np.abs(tomo_center-prev) # logger.debug(f'... done in {time()-t0:.2f} seconds') # logger.info('Finding the center using the phase correlation method took '+ # f'{time()-t0:.2f} seconds') # center_offset = tomo_center-center # print(f'Center at row {row} using phase correlation = {center_offset:.2f}') # t0 = time() # logger.debug(f'Running _reconstruct_one_plane on {self.num_core} cores ...') # recon_plane = self._reconstruct_one_plane(sinogram_T, tomo_center, thetas, # eff_pixel_size, cross_sectional_dim, False, num_core) # logger.debug(f'... done in {time()-t0:.2f} seconds') # logger.info(f'Reconstructing row {row} took {time()-t0:.2f} seconds') # # title = f'edges row{row} center_offset{center_offset:.2f} PC' # self._plot_edges_one_plane(recon_plane, title, path=path) # Select center location # if input_yesno('Accept a center location (y) or continue search (n)?', 'y'): if True: # center_offset = input_num(' Enter chosen center offset', ge=-center, le=center, # default=center_offset_vo) center_offset = center_offset_vo del sinogram_T del recon_plane return float(center_offset) # perform center finding search while True: center_offset_low = input_int('\nEnter lower bound for center offset', ge=-center, le=center) center_offset_upp = input_int('Enter upper bound for center offset', ge=center_offset_low, le=center) if center_offset_upp == center_offset_low: center_offset_step = 1 else: center_offset_step = input_int('Enter step size for center offset search', ge=1, le=center_offset_upp-center_offset_low) num_center_offset = 1+int((center_offset_upp-center_offset_low)/center_offset_step) center_offsets = np.linspace(center_offset_low, center_offset_upp, num_center_offset) for center_offset in center_offsets: if center_offset == center_offset_vo: continue t0 = time() logger.debug(f'Running _reconstruct_one_plane on {num_core} cores ...') recon_plane = self._reconstruct_one_plane(sinogram_T, center_offset+center, thetas, eff_pixel_size, cross_sectional_dim, False, num_core) logger.debug(f'... done in {time()-t0:.2f} seconds') logger.info(f'Reconstructing center_offset {center_offset} took '+ f'{time()-t0:.2f} seconds') title = f'edges row{row} center_offset{center_offset:.2f}' self._plot_edges_one_plane(recon_plane, title, path=path) if input_int('\nContinue (0) or end the search (1)', ge=0, le=1): break del sinogram_T del recon_plane center_offset = input_num(' Enter chosen center offset', ge=-center, le=center) return float(center_offset) def _reconstruct_one_plane(self, tomo_plane_T, center, thetas, eff_pixel_size, cross_sectional_dim, plot_sinogram=True, num_core=1): """Invert the sinogram for a single tomography plane. """ # tomo_plane_T index order: column,theta assert(0 <= center < tomo_plane_T.shape[0]) center_offset = center-tomo_plane_T.shape[0]/2 two_offset = 2*int(np.round(center_offset)) two_offset_abs = np.abs(two_offset) max_rad = int(0.55*(cross_sectional_dim/eff_pixel_size)) # 10% slack to avoid edge effects if max_rad > 0.5*tomo_plane_T.shape[0]: max_rad = 0.5*tomo_plane_T.shape[0] dist_from_edge = max(1, int(np.floor((tomo_plane_T.shape[0]-two_offset_abs)/2.)-max_rad)) if two_offset >= 0: logger.debug(f'sinogram range = [{two_offset+dist_from_edge}, {-dist_from_edge}]') sinogram = tomo_plane_T[two_offset+dist_from_edge:-dist_from_edge,:] else: logger.debug(f'sinogram range = [{dist_from_edge}, {two_offset-dist_from_edge}]') sinogram = tomo_plane_T[dist_from_edge:two_offset-dist_from_edge,:] if not self.galaxy_flag and plot_sinogram: quick_imshow(sinogram.T, f'sinogram center offset{center_offset:.2f}', aspect='auto', path=self.output_folder, save_fig=self.save_figs, save_only=self.save_only, block=self.block) # Inverting sinogram t0 = time() recon_sinogram = iradon(sinogram, theta=thetas, circle=True) logger.debug(f'Inverting sinogram took {time()-t0:.2f} seconds') del sinogram # Performing Gaussian filtering and removing ring artifacts recon_parameters = None#self.config.get('recon_parameters') if recon_parameters is None: sigma = 1.0 ring_width = 15 else: sigma = recon_parameters.get('gaussian_sigma', 1.0) if not is_num(sigma, ge=0.0): logger.warning(f'Invalid gaussian_sigma ({sigma}) in _reconstruct_one_plane, '+ 'set to a default value of 1.0') sigma = 1.0 ring_width = recon_parameters.get('ring_width', 15) if not is_int(ring_width, ge=0): logger.warning(f'Invalid ring_width ({ring_width}) in _reconstruct_one_plane, '+ 'set to a default value of 15') ring_width = 15 t0 = time() recon_sinogram = spi.gaussian_filter(recon_sinogram, sigma, mode='nearest') recon_clean = np.expand_dims(recon_sinogram, axis=0) del recon_sinogram recon_clean = tomopy.misc.corr.remove_ring(recon_clean, rwidth=ring_width, ncore=num_core) logger.debug(f'Filtering and removing ring artifacts took {time()-t0:.2f} seconds') return recon_clean def _plot_edges_one_plane(self, recon_plane, title, path=None): vis_parameters = None#self.config.get('vis_parameters') if vis_parameters is None: weight = 0.1 else: weight = vis_parameters.get('denoise_weight', 0.1) if not is_num(weight, ge=0.0): logger.warning(f'Invalid weight ({weight}) in _plot_edges_one_plane, '+ 'set to a default value of 0.1') weight = 0.1 edges = denoise_tv_chambolle(recon_plane, weight=weight) vmax = np.max(edges[0,:,:]) vmin = -vmax if path is None: path = self.output_folder quick_imshow(edges[0,:,:], f'{title} coolwarm', path=path, cmap='coolwarm', save_fig=self.save_figs, save_only=self.save_only, block=self.block) quick_imshow(edges[0,:,:], f'{title} gray', path=path, cmap='gray', vmin=vmin, vmax=vmax, save_fig=self.save_figs, save_only=self.save_only, block=self.block) del edges def _reconstruct_one_tomo_stack(self, tomo_stack, thetas, center_offsets=[], num_core=1, algorithm='gridrec'): """Reconstruct a single tomography stack. """ # tomo_stack order: row,theta,column # input thetas must be in degrees # centers_offset: tomography axis shift in pixels relative to column center # RV should we remove stripes? # https://tomopy.readthedocs.io/en/latest/api/tomopy.prep.stripe.html # RV should we remove rings? # https://tomopy.readthedocs.io/en/latest/api/tomopy.misc.corr.html # RV: Add an option to do (extra) secondary iterations later or to do some sort of convergence test? if not len(center_offsets): centers = np.zeros((tomo_stack.shape[0])) elif len(center_offsets) == 2: centers = np.linspace(center_offsets[0], center_offsets[1], tomo_stack.shape[0]) else: if center_offsets.size != tomo_stack.shape[0]: raise ValueError('center_offsets dimension mismatch in reconstruct_one_tomo_stack') centers = center_offsets centers += tomo_stack.shape[2]/2 # Get reconstruction parameters recon_parameters = None#self.config.get('recon_parameters') if recon_parameters is None: sigma = 2.0 secondary_iters = 0 ring_width = 15 else: sigma = recon_parameters.get('stripe_fw_sigma', 2.0) if not is_num(sigma, ge=0): logger.warning(f'Invalid stripe_fw_sigma ({sigma}) in '+ '_reconstruct_one_tomo_stack, set to a default value of 2.0') ring_width = 15 secondary_iters = recon_parameters.get('secondary_iters', 0) if not is_int(secondary_iters, ge=0): logger.warning(f'Invalid secondary_iters ({secondary_iters}) in '+ '_reconstruct_one_tomo_stack, set to a default value of 0 (skip them)') ring_width = 0 ring_width = recon_parameters.get('ring_width', 15) if not is_int(ring_width, ge=0): logger.warning(f'Invalid ring_width ({ring_width}) in _reconstruct_one_plane, '+ 'set to a default value of 15') ring_width = 15 # Remove horizontal stripe t0 = time() if num_core > num_core_tomopy_limit: logger.debug('Running remove_stripe_fw on {num_core_tomopy_limit} cores ...') tomo_stack = tomopy.prep.stripe.remove_stripe_fw(tomo_stack, sigma=sigma, ncore=num_core_tomopy_limit) else: logger.debug(f'Running remove_stripe_fw on {num_core} cores ...') tomo_stack = tomopy.prep.stripe.remove_stripe_fw(tomo_stack, sigma=sigma, ncore=num_core) logger.debug(f'... tomopy.prep.stripe.remove_stripe_fw took {time()-t0:.2f} seconds') # Perform initial image reconstruction logger.debug('Performing initial image reconstruction') t0 = time() logger.debug(f'Running recon on {num_core} cores ...') tomo_recon_stack = tomopy.recon(tomo_stack, np.radians(thetas), centers, sinogram_order=True, algorithm=algorithm, ncore=num_core) logger.debug(f'... done in {time()-t0:.2f} seconds') logger.info(f'Performing initial image reconstruction took {time()-t0:.2f} seconds') # Run optional secondary iterations if secondary_iters > 0: logger.debug(f'Running {secondary_iters} secondary iterations') #options = {'method':'SIRT_CUDA', 'proj_type':'cuda', 'num_iter':secondary_iters} #RV: doesn't work for me: #"Error: CUDA error 803: system has unsupported display driver/cuda driver combination." #options = {'method':'SIRT', 'proj_type':'linear', 'MinConstraint': 0, 'num_iter':secondary_iters} #SIRT did not finish while running overnight #options = {'method':'SART', 'proj_type':'linear', 'num_iter':secondary_iters} options = {'method':'SART', 'proj_type':'linear', 'MinConstraint': 0, 'num_iter':secondary_iters} t0 = time() logger.debug(f'Running recon on {num_core} cores ...') tomo_recon_stack = tomopy.recon(tomo_stack, np.radians(thetas), centers, init_recon=tomo_recon_stack, options=options, sinogram_order=True, algorithm=tomopy.astra, ncore=num_core) logger.debug(f'... done in {time()-t0:.2f} seconds') logger.info(f'Performing secondary iterations took {time()-t0:.2f} seconds') # Remove ring artifacts t0 = time() tomopy.misc.corr.remove_ring(tomo_recon_stack, rwidth=ring_width, out=tomo_recon_stack, ncore=num_core) logger.debug(f'Removing ring artifacts took {time()-t0:.2f} seconds') return tomo_recon_stack def _resize_reconstructed_data(self, data, z_only=False): """Resize the reconstructed tomography data. """ # Data order: row(z),x,y or stack,row(z),x,y if isinstance(data, list): for stack in data: assert(stack.ndim == 3) num_tomo_stacks = len(data) tomo_recon_stacks = data else: assert(data.ndim == 3) num_tomo_stacks = 1 tomo_recon_stacks = [data] if z_only: x_bounds = None y_bounds = None else: # Selecting x bounds (in yz-plane) tomosum = 0 [tomosum := tomosum+np.sum(tomo_recon_stacks[i], axis=(0,2)) for i in range(num_tomo_stacks)] select_x_bounds = input_yesno('\nDo you want to change the image x-bounds (y/n)?', 'y') if not select_x_bounds: x_bounds = None else: accept = False index_ranges = None while not accept: mask, x_bounds = draw_mask_1d(tomosum, current_index_ranges=index_ranges, title='select x data range', legend='recon stack sum yz') while len(x_bounds) != 1: print('Please select exactly one continuous range') mask, x_bounds = draw_mask_1d(tomosum, title='select x data range', legend='recon stack sum yz') x_bounds = x_bounds[0] # quick_plot(tomosum, vlines=x_bounds, title='recon stack sum yz') # print(f'x_bounds = {x_bounds} (lower bound inclusive, upper bound '+ # 'exclusive)') # accept = input_yesno('Accept these bounds (y/n)?', 'y') accept = True logger.debug(f'x_bounds = {x_bounds}') # Selecting y bounds (in xz-plane) tomosum = 0 [tomosum := tomosum+np.sum(tomo_recon_stacks[i], axis=(0,1)) for i in range(num_tomo_stacks)] select_y_bounds = input_yesno('\nDo you want to change the image y-bounds (y/n)?', 'y') if not select_y_bounds: y_bounds = None else: accept = False index_ranges = None while not accept: mask, y_bounds = draw_mask_1d(tomosum, current_index_ranges=index_ranges, title='select x data range', legend='recon stack sum xz') while len(y_bounds) != 1: print('Please select exactly one continuous range') mask, y_bounds = draw_mask_1d(tomosum, title='select x data range', legend='recon stack sum xz') y_bounds = y_bounds[0] # quick_plot(tomosum, vlines=y_bounds, title='recon stack sum xz') # print(f'y_bounds = {y_bounds} (lower bound inclusive, upper bound '+ # 'exclusive)') # accept = input_yesno('Accept these bounds (y/n)?', 'y') accept = True logger.debug(f'y_bounds = {y_bounds}') # Selecting z bounds (in xy-plane) (only valid for a single image set) if num_tomo_stacks != 1: z_bounds = None else: tomosum = 0 [tomosum := tomosum+np.sum(tomo_recon_stacks[i], axis=(1,2)) for i in range(num_tomo_stacks)] select_z_bounds = input_yesno('Do you want to change the image z-bounds (y/n)?', 'n') if not select_z_bounds: z_bounds = None else: accept = False index_ranges = None while not accept: mask, z_bounds = draw_mask_1d(tomosum, current_index_ranges=index_ranges, title='select x data range', legend='recon stack sum xy') while len(z_bounds) != 1: print('Please select exactly one continuous range') mask, z_bounds = draw_mask_1d(tomosum, title='select x data range', legend='recon stack sum xy') z_bounds = z_bounds[0] # quick_plot(tomosum, vlines=z_bounds, title='recon stack sum xy') # print(f'z_bounds = {z_bounds} (lower bound inclusive, upper bound '+ # 'exclusive)') # accept = input_yesno('Accept these bounds (y/n)?', 'y') accept = True logger.debug(f'z_bounds = {z_bounds}') return(x_bounds, y_bounds, z_bounds) def run_tomo(input_file:str, output_file:str, modes:list[str], center_file=None, num_core=-1, output_folder='.', save_figs='no', test_mode=False) -> None: if test_mode: logging_format = '%(asctime)s : %(levelname)s - %(module)s : %(funcName)s - %(message)s' level = logging.getLevelName('INFO') logging.basicConfig(filename=f'{output_folder}/tomo.log', filemode='w', format=logging_format, level=level, force=True) logger.info(f'input_file = {input_file}') logger.info(f'center_file = {center_file}') logger.info(f'output_file = {output_file}') logger.debug(f'modes= {modes}') logger.debug(f'num_core= {num_core}') logger.info(f'output_folder = {output_folder}') logger.info(f'save_figs = {save_figs}') logger.info(f'test_mode = {test_mode}') # Check for correction modes if modes is None: modes = ['all'] logger.debug(f'modes {type(modes)} = {modes}') # Instantiate Tomo object tomo = Tomo(num_core=num_core, output_folder=output_folder, save_figs=save_figs, test_mode=test_mode) # Read input file data = tomo.read(input_file) # Generate reduced tomography images if 'reduce_data' in modes or 'all' in modes: data = tomo.gen_reduced_data(data) # Find rotation axis centers for the tomography stacks. center_data = None if 'find_center' in modes or 'all' in modes: center_data = tomo.find_centers(data) # Reconstruct tomography stacks if 'reconstruct_data' in modes or 'all' in modes: if center_data is None: # Read input file center_data = tomo.read(center_file) data = tomo.reconstruct_data(data, center_data) center_data = None # Combine reconstructed tomography stacks if 'combine_data' in modes or 'all' in modes: data = tomo.combine_data(data) # Write output file if not test_mode: if center_data is None: data = tomo.write(data, output_file) else: data = tomo.write(center_data, output_file)