Mercurial > repos > rv43 > tomo
comparison workflow/run_tomo.py @ 69:fba792d5f83b draft
planemo upload for repository https://github.com/rolfverberg/galaxytools commit ab9f412c362a4ab986d00e21d5185cfcf82485c1
| author | rv43 | 
|---|---|
| date | Fri, 10 Mar 2023 16:02:04 +0000 | 
| parents | |
| children | 1cf15b61cd83 | 
   comparison
  equal
  deleted
  inserted
  replaced
| 68:ba5866d0251d | 69:fba792d5f83b | 
|---|---|
| 1 #!/usr/bin/env python3 | |
| 2 | |
| 3 import logging | |
| 4 logger = logging.getLogger(__name__) | |
| 5 | |
| 6 import numpy as np | |
| 7 try: | |
| 8 import numexpr as ne | |
| 9 except: | |
| 10 pass | |
| 11 try: | |
| 12 import scipy.ndimage as spi | |
| 13 except: | |
| 14 pass | |
| 15 | |
| 16 from multiprocessing import cpu_count | |
| 17 from nexusformat.nexus import * | |
| 18 from os import mkdir | |
| 19 from os import path as os_path | |
| 20 try: | |
| 21 from skimage.transform import iradon | |
| 22 except: | |
| 23 pass | |
| 24 try: | |
| 25 from skimage.restoration import denoise_tv_chambolle | |
| 26 except: | |
| 27 pass | |
| 28 from time import time | |
| 29 try: | |
| 30 import tomopy | |
| 31 except: | |
| 32 pass | |
| 33 from yaml import safe_load, safe_dump | |
| 34 | |
| 35 from msnctools.fit import Fit | |
| 36 from msnctools.general import illegal_value, is_int, is_int_pair, is_num, is_index_range, \ | |
| 37 input_int, input_num, input_yesno, input_menu, draw_mask_1d, select_image_bounds, \ | |
| 38 select_one_image_bound, clear_imshow, quick_imshow, clear_plot, quick_plot | |
| 39 | |
| 40 from workflow.models import import_scanparser, FlatField, TomoField, TomoWorkflow | |
| 41 from workflow.__version__ import __version__ | |
| 42 | |
# Module-level core-count limit. Not referenced in the visible part of this file;
# NOTE(review): presumably caps the number of cores handed to tomopy — confirm.
num_core_tomopy_limit = 24
| 44 | |
def nxcopy(nxobject:NXobject, exclude_nxpaths:'list[str] | None'=None,
        nxpath_prefix:str='') -> NXobject:
    '''Function that returns a copy of a nexus object, optionally excluding certain child items.

    Child groups are copied recursively with this same function; any other
    child item is assigned as-is into the new parent.

    :param nxobject: the original nexus object to return a "copy" of
    :type nxobject: nexusformat.nexus.NXobject
    :param exclude_nxpaths: a list of paths (relative to the top-level
        `nxobject`) to child nexus objects that should be excluded from the
        returned "copy", defaults to `None` (exclude nothing)
    :type exclude_nxpaths: list[str], optional
    :param nxpath_prefix: For use in recursive calls from inside this
        function only!
    :type nxpath_prefix: str
    :return: a copy of `nxobject` with some children optionally excluded.
    :rtype: NXobject
    '''
    # Avoid a shared mutable default argument
    if exclude_nxpaths is None:
        exclude_nxpaths = []

    nxobject_copy = nxobject.__class__()
    if not nxpath_prefix:
        # Top-level object: only carry over its 'default' attribute
        if 'default' in nxobject.attrs:
            nxobject_copy.attrs['default'] = nxobject.attrs['default']
    else:
        # Nested object: carry over all of its attributes
        for k, v in nxobject.attrs.items():
            nxobject_copy.attrs[k] = v

    for k, v in nxobject.items():
        nxpath = os_path.join(nxpath_prefix, k)

        if nxpath in exclude_nxpaths:
            continue

        if isinstance(v, NXgroup):
            nxobject_copy[k] = nxcopy(v, exclude_nxpaths=exclude_nxpaths,
                    nxpath_prefix=nxpath)
        else:
            nxobject_copy[k] = v

    return(nxobject_copy)
| 81 | |
class set_numexpr_threads:
    """Context manager that temporarily changes the number of numexpr threads.

    On entry the numexpr thread count is set to `num_core`. `num_core` is
    replaced by the number of available processors when it is None, smaller
    than 1, or larger than that number. On exit the thread count returned by
    `ne.set_num_threads` on entry is restored.

    Requires numexpr to be importable as `ne`. The module-level import is
    guarded by a try/except, so `ne` may be undefined.
    """

    def __init__(self, num_core):
        max_core = cpu_count()
        in_range = num_core is not None and 1 <= num_core <= max_core
        self.num_core = num_core if in_range else max_core

    def __enter__(self):
        # Remember the previous thread count so __exit__ can restore it
        self.num_core_org = ne.set_num_threads(self.num_core)

    def __exit__(self, exc_type, exc_value, traceback):
        ne.set_num_threads(self.num_core_org)
| 95 | |
| 96 class Tomo: | |
| 97 """Processing tomography data with misalignment. | |
| 98 """ | |
| 99 def __init__(self, galaxy_flag=False, num_core=-1, output_folder='.', save_figs=None, | |
| 100 test_mode=False): | |
| 101 """Initialize with optional config input file or dictionary | |
| 102 """ | |
| 103 if not isinstance(galaxy_flag, bool): | |
| 104 raise ValueError(f'Invalid parameter galaxy_flag ({galaxy_flag})') | |
| 105 self.galaxy_flag = galaxy_flag | |
| 106 self.num_core = num_core | |
| 107 if self.galaxy_flag: | |
| 108 if output_folder != '.': | |
| 109 logger.warning('Ignoring output_folder in galaxy mode') | |
| 110 self.output_folder = '.' | |
| 111 if test_mode != False: | |
| 112 logger.warning('Ignoring test_mode in galaxy mode') | |
| 113 self.test_mode = False | |
| 114 if save_figs is not None: | |
| 115 logger.warning('Ignoring save_figs in galaxy mode') | |
| 116 save_figs = 'only' | |
| 117 else: | |
| 118 self.output_folder = os_path.abspath(output_folder) | |
| 119 if not os_path.isdir(output_folder): | |
| 120 mkdir(os_path.abspath(output_folder)) | |
| 121 if not isinstance(test_mode, bool): | |
| 122 raise ValueError(f'Invalid parameter test_mode ({test_mode})') | |
| 123 self.test_mode = test_mode | |
| 124 if save_figs is None: | |
| 125 save_figs = 'no' | |
| 126 self.test_config = {} | |
| 127 if self.test_mode: | |
| 128 if save_figs != 'only': | |
| 129 logger.warning('Ignoring save_figs in test mode') | |
| 130 save_figs = 'only' | |
| 131 if save_figs == 'only': | |
| 132 self.save_only = True | |
| 133 self.save_figs = True | |
| 134 elif save_figs == 'yes': | |
| 135 self.save_only = False | |
| 136 self.save_figs = True | |
| 137 elif save_figs == 'no': | |
| 138 self.save_only = False | |
| 139 self.save_figs = False | |
| 140 else: | |
| 141 raise ValueError(f'Invalid parameter save_figs ({save_figs})') | |
| 142 if self.save_only: | |
| 143 self.block = False | |
| 144 else: | |
| 145 self.block = True | |
| 146 if self.num_core == -1: | |
| 147 self.num_core = cpu_count() | |
| 148 if not is_int(self.num_core, gt=0, log=False): | |
| 149 raise ValueError(f'Invalid parameter num_core ({num_core})') | |
| 150 if self.num_core > cpu_count(): | |
| 151 logger.warning(f'num_core = {self.num_core} is larger than the number of available ' | |
| 152 f'processors and reduced to {cpu_count()}') | |
| 153 self.num_core= cpu_count() | |
| 154 | |
| 155 def read(self, filename): | |
| 156 extension = os_path.splitext(filename)[1] | |
| 157 if extension == '.yml' or extension == '.yaml': | |
| 158 with open(filename, 'r') as f: | |
| 159 config = safe_load(f) | |
| 160 # if len(config) > 1: | |
| 161 # raise ValueError(f'Multiple root entries in {filename} not yet implemented') | |
| 162 # if len(list(config.values())[0]) > 1: | |
| 163 # raise ValueError(f'Multiple sample maps in {filename} not yet implemented') | |
| 164 return(config) | |
| 165 elif extension == '.nxs': | |
| 166 with NXFile(filename, mode='r') as nxfile: | |
| 167 nxroot = nxfile.readfile() | |
| 168 return(nxroot) | |
| 169 else: | |
| 170 raise ValueError(f'Invalid filename extension ({extension})') | |
| 171 | |
| 172 def write(self, data, filename): | |
| 173 extension = os_path.splitext(filename)[1] | |
| 174 if extension == '.yml' or extension == '.yaml': | |
| 175 with open(filename, 'w') as f: | |
| 176 safe_dump(data, f) | |
| 177 elif extension == '.nxs': | |
| 178 data.save(filename, mode='w') | |
| 179 elif extension == '.nc': | |
| 180 data.to_netcdf(os_path=filename) | |
| 181 else: | |
| 182 raise ValueError(f'Invalid filename extension ({extension})') | |
| 183 | |
| 184 def gen_reduced_data(self, data, img_x_bounds=None): | |
| 185 """Generate the reduced tomography images. | |
| 186 """ | |
| 187 logger.info('Generate the reduced tomography images') | |
| 188 | |
| 189 # Create plot galaxy path directory if needed | |
| 190 if self.galaxy_flag and not os_path.exists('tomo_reduce_plots'): | |
| 191 mkdir('tomo_reduce_plots') | |
| 192 | |
| 193 if isinstance(data, dict): | |
| 194 # Create Nexus format object from input dictionary | |
| 195 wf = TomoWorkflow(**data) | |
| 196 if len(wf.sample_maps) > 1: | |
| 197 raise ValueError(f'Multiple sample maps not yet implemented') | |
| 198 # print(f'\nwf:\n{wf}\n') | |
| 199 nxroot = NXroot() | |
| 200 t0 = time() | |
| 201 for sample_map in wf.sample_maps: | |
| 202 logger.info(f'Start constructing the {sample_map.title} map.') | |
| 203 import_scanparser(sample_map.station) | |
| 204 sample_map.construct_nxentry(nxroot, include_raw_data=False) | |
| 205 logger.info(f'Constructed all sample maps in {time()-t0:.2f} seconds.') | |
| 206 nxentry = nxroot[nxroot.attrs['default']] | |
| 207 # Get test mode configuration info | |
| 208 if self.test_mode: | |
| 209 self.test_config = data['sample_maps'][0]['test_mode'] | |
| 210 elif isinstance(data, NXroot): | |
| 211 nxentry = data[data.attrs['default']] | |
| 212 else: | |
| 213 raise ValueError(f'Invalid parameter data ({data})') | |
| 214 | |
| 215 # Create an NXprocess to store data reduction (meta)data | |
| 216 reduced_data = NXprocess() | |
| 217 | |
| 218 # Generate dark field | |
| 219 if 'dark_field' in nxentry['spec_scans']: | |
| 220 reduced_data = self._gen_dark(nxentry, reduced_data) | |
| 221 | |
| 222 # Generate bright field | |
| 223 reduced_data = self._gen_bright(nxentry, reduced_data) | |
| 224 | |
| 225 # Set vertical detector bounds for image stack | |
| 226 img_x_bounds = self._set_detector_bounds(nxentry, reduced_data, img_x_bounds=img_x_bounds) | |
| 227 logger.info(f'img_x_bounds = {img_x_bounds}') | |
| 228 reduced_data['img_x_bounds'] = img_x_bounds | |
| 229 | |
| 230 # Set zoom and/or theta skip to reduce memory the requirement | |
| 231 zoom_perc, num_theta_skip = self._set_zoom_or_skip() | |
| 232 if zoom_perc is not None: | |
| 233 reduced_data.attrs['zoom_perc'] = zoom_perc | |
| 234 if num_theta_skip is not None: | |
| 235 reduced_data.attrs['num_theta_skip'] = num_theta_skip | |
| 236 | |
| 237 # Generate reduced tomography fields | |
| 238 reduced_data = self._gen_tomo(nxentry, reduced_data) | |
| 239 | |
| 240 # Create a copy of the input Nexus object and remove raw and any existing reduced data | |
| 241 if isinstance(data, NXroot): | |
| 242 exclude_items = [f'{nxentry._name}/reduced_data/data', | |
| 243 f'{nxentry._name}/instrument/detector/data', | |
| 244 f'{nxentry._name}/instrument/detector/image_key', | |
| 245 f'{nxentry._name}/instrument/detector/sequence_number', | |
| 246 f'{nxentry._name}/sample/rotation_angle', | |
| 247 f'{nxentry._name}/sample/x_translation', | |
| 248 f'{nxentry._name}/sample/z_translation', | |
| 249 f'{nxentry._name}/data/data', | |
| 250 f'{nxentry._name}/data/image_key', | |
| 251 f'{nxentry._name}/data/rotation_angle', | |
| 252 f'{nxentry._name}/data/x_translation', | |
| 253 f'{nxentry._name}/data/z_translation'] | |
| 254 nxroot = nxcopy(data, exclude_nxpaths=exclude_items) | |
| 255 nxentry = nxroot[nxroot.attrs['default']] | |
| 256 | |
| 257 # Add the reduced data NXprocess | |
| 258 nxentry.reduced_data = reduced_data | |
| 259 | |
| 260 if 'data' not in nxentry: | |
| 261 nxentry.data = NXdata() | |
| 262 nxentry.attrs['default'] = 'data' | |
| 263 nxentry.data.makelink(nxentry.reduced_data.data.tomo_fields, name='reduced_data') | |
| 264 nxentry.data.makelink(nxentry.reduced_data.rotation_angle, name='rotation_angle') | |
| 265 nxentry.data.attrs['signal'] = 'reduced_data' | |
| 266 | |
| 267 return(nxroot) | |
| 268 | |
| 269 def find_centers(self, nxroot, center_rows=None): | |
| 270 """Find the calibrated center axis info | |
| 271 """ | |
| 272 logger.info('Find the calibrated center axis info') | |
| 273 | |
| 274 if not isinstance(nxroot, NXroot): | |
| 275 raise ValueError(f'Invalid parameter nxroot ({nxroot})') | |
| 276 nxentry = nxroot[nxroot.attrs['default']] | |
| 277 if not isinstance(nxentry, NXentry): | |
| 278 raise ValueError(f'Invalid nxentry ({nxentry})') | |
| 279 if self.galaxy_flag: | |
| 280 if center_rows is None: | |
| 281 raise ValueError(f'Missing parameter center_rows ({center_rows})') | |
| 282 if not is_int_pair(center_rows): | |
| 283 raise ValueError(f'Invalid parameter center_rows ({center_rows})') | |
| 284 elif center_rows is not None: | |
| 285 logging.warning(f'Ignoring parameter center_rows ({center_rows})') | |
| 286 center_rows = None | |
| 287 | |
| 288 # Create plot galaxy path directory and path if needed | |
| 289 if self.galaxy_flag: | |
| 290 if not os_path.exists('tomo_find_centers_plots'): | |
| 291 mkdir('tomo_find_centers_plots') | |
| 292 path = 'tomo_find_centers_plots' | |
| 293 else: | |
| 294 path = self.output_folder | |
| 295 | |
| 296 # Check if reduced data is available | |
| 297 if ('reduced_data' not in nxentry or 'reduced_data' not in nxentry.data): | |
| 298 raise KeyError(f'Unable to find valid reduced data in {nxentry}.') | |
| 299 | |
| 300 # Select the image stack to calibrate the center axis | |
| 301 # reduced data axes order: stack,row,theta,column | |
| 302 # Note: Nexus cannot follow a link if the data it points to is too big, | |
| 303 # so get the data from the actual place, not from nxentry.data | |
| 304 num_tomo_stacks = nxentry.reduced_data.data.tomo_fields.shape[0] | |
| 305 if num_tomo_stacks == 1: | |
| 306 center_stack_index = 0 | |
| 307 center_stack = np.asarray(nxentry.reduced_data.data.tomo_fields[0]) | |
| 308 if not center_stack.size: | |
| 309 raise KeyError('Unable to load the required reduced tomography stack') | |
| 310 default = 'n' | |
| 311 else: | |
| 312 if self.test_mode: | |
| 313 center_stack_index = self.test_config['center_stack_index']-1 # make offset 0 | |
| 314 else: | |
| 315 center_stack_index = input_int('\nEnter tomography stack index to calibrate the ' | |
| 316 'center axis', ge=0, le=num_tomo_stacks-1, default=int(num_tomo_stacks/2)) | |
| 317 center_stack = \ | |
| 318 np.asarray(nxentry.reduced_data.data.tomo_fields[center_stack_index]) | |
| 319 if not center_stack.size: | |
| 320 raise KeyError('Unable to load the required reduced tomography stack') | |
| 321 default = 'y' | |
| 322 | |
| 323 # Get thetas (in degrees) | |
| 324 thetas = np.asarray(nxentry.reduced_data.rotation_angle) | |
| 325 | |
| 326 # Get effective pixel_size | |
| 327 if 'zoom_perc' in nxentry.reduced_data: | |
| 328 eff_pixel_size = 100.*(nxentry.instrument.detector.x_pixel_size/ | |
| 329 nxentry.reduced_data.attrs['zoom_perc']) | |
| 330 else: | |
| 331 eff_pixel_size = nxentry.instrument.detector.x_pixel_size | |
| 332 | |
| 333 # Get cross sectional diameter | |
| 334 cross_sectional_dim = center_stack.shape[2]*eff_pixel_size | |
| 335 logger.debug(f'cross_sectional_dim = {cross_sectional_dim}') | |
| 336 | |
| 337 # Determine center offset at sample row boundaries | |
| 338 logger.info('Determine center offset at sample row boundaries') | |
| 339 | |
| 340 # Lower row center | |
| 341 # center_stack order: row,theta,column | |
| 342 if self.test_mode: | |
| 343 lower_row = self.test_config['lower_row'] | |
| 344 elif self.galaxy_flag: | |
| 345 lower_row = min(center_rows) | |
| 346 if not 0 <= lower_row < center_stack.shape[0]-1: | |
| 347 raise ValueError(f'Invalid parameter center_rows ({center_rows})') | |
| 348 else: | |
| 349 lower_row = select_one_image_bound(center_stack[:,0,:], 0, bound=0, | |
| 350 title=f'theta={round(thetas[0], 2)+0}', | |
| 351 bound_name='row index to find lower center', default=default) | |
| 352 lower_center_offset = self._find_center_one_plane(center_stack[lower_row,:,:], lower_row, | |
| 353 thetas, eff_pixel_size, cross_sectional_dim, path=path, num_core=self.num_core) | |
| 354 logger.debug(f'lower_row = {lower_row:.2f}') | |
| 355 logger.debug(f'lower_center_offset = {lower_center_offset:.2f}') | |
| 356 | |
| 357 # Upper row center | |
| 358 if self.test_mode: | |
| 359 upper_row = self.test_config['upper_row'] | |
| 360 elif self.galaxy_flag: | |
| 361 upper_row = max(center_rows) | |
| 362 if not lower_row < upper_row < center_stack.shape[0]: | |
| 363 raise ValueError(f'Invalid parameter center_rows ({center_rows})') | |
| 364 else: | |
| 365 upper_row = select_one_image_bound(center_stack[:,0,:], 0, | |
| 366 bound=center_stack.shape[0]-1, title=f'theta={round(thetas[0], 2)+0}', | |
| 367 bound_name='row index to find upper center', default=default) | |
| 368 upper_center_offset = self._find_center_one_plane(center_stack[upper_row,:,:], upper_row, | |
| 369 thetas, eff_pixel_size, cross_sectional_dim, path=path, num_core=self.num_core) | |
| 370 logger.debug(f'upper_row = {upper_row:.2f}') | |
| 371 logger.debug(f'upper_center_offset = {upper_center_offset:.2f}') | |
| 372 del center_stack | |
| 373 | |
| 374 center_config = {'lower_row': lower_row, 'lower_center_offset': lower_center_offset, | |
| 375 'upper_row': upper_row, 'upper_center_offset': upper_center_offset} | |
| 376 if num_tomo_stacks > 1: | |
| 377 center_config['center_stack_index'] = center_stack_index+1 # save as offset 1 | |
| 378 | |
| 379 # Save test data to file | |
| 380 if self.test_mode: | |
| 381 with open(f'{self.output_folder}/center_config.yaml', 'w') as f: | |
| 382 safe_dump(center_config, f) | |
| 383 | |
| 384 return(center_config) | |
| 385 | |
| 386 def reconstruct_data(self, nxroot, center_info, x_bounds=None, y_bounds=None): | |
| 387 """Reconstruct the tomography data. | |
| 388 """ | |
| 389 logger.info('Reconstruct the tomography data') | |
| 390 | |
| 391 if not isinstance(nxroot, NXroot): | |
| 392 raise ValueError(f'Invalid parameter nxroot ({nxroot})') | |
| 393 nxentry = nxroot[nxroot.attrs['default']] | |
| 394 if not isinstance(nxentry, NXentry): | |
| 395 raise ValueError(f'Invalid nxentry ({nxentry})') | |
| 396 if not isinstance(center_info, dict): | |
| 397 raise ValueError(f'Invalid parameter center_info ({center_info})') | |
| 398 | |
| 399 # Create plot galaxy path directory and path if needed | |
| 400 if self.galaxy_flag: | |
| 401 if not os_path.exists('tomo_reconstruct_plots'): | |
| 402 mkdir('tomo_reconstruct_plots') | |
| 403 path = 'tomo_reconstruct_plots' | |
| 404 else: | |
| 405 path = self.output_folder | |
| 406 | |
| 407 # Check if reduced data is available | |
| 408 if ('reduced_data' not in nxentry or 'reduced_data' not in nxentry.data): | |
| 409 raise KeyError(f'Unable to find valid reduced data in {nxentry}.') | |
| 410 | |
| 411 # Create an NXprocess to store image reconstruction (meta)data | |
| 412 # if 'reconstructed_data' in nxentry: | |
| 413 # logger.warning(f'Existing reconstructed data in {nxentry} will be overwritten.') | |
| 414 # del nxentry['reconstructed_data'] | |
| 415 # if 'data' in nxentry and 'reconstructed_data' in nxentry.data: | |
| 416 # del nxentry.data['reconstructed_data'] | |
| 417 nxprocess = NXprocess() | |
| 418 | |
| 419 # Get rotation axis rows and centers | |
| 420 lower_row = center_info.get('lower_row') | |
| 421 lower_center_offset = center_info.get('lower_center_offset') | |
| 422 upper_row = center_info.get('upper_row') | |
| 423 upper_center_offset = center_info.get('upper_center_offset') | |
| 424 if (lower_row is None or lower_center_offset is None or upper_row is None or | |
| 425 upper_center_offset is None): | |
| 426 raise KeyError(f'Unable to find valid calibrated center axis info in {center_info}.') | |
| 427 center_slope = (upper_center_offset-lower_center_offset)/(upper_row-lower_row) | |
| 428 | |
| 429 # Get thetas (in degrees) | |
| 430 thetas = np.asarray(nxentry.reduced_data.rotation_angle) | |
| 431 | |
| 432 # Reconstruct tomography data | |
| 433 # reduced data axes order: stack,row,theta,column | |
| 434 # reconstructed data order in each stack: row/z,x,y | |
| 435 # Note: Nexus cannot follow a link if the data it points to is too big, | |
| 436 # so get the data from the actual place, not from nxentry.data | |
| 437 if 'zoom_perc' in nxentry.reduced_data: | |
| 438 res_title = f'{nxentry.reduced_data.attrs["zoom_perc"]}p' | |
| 439 else: | |
| 440 res_title = 'fullres' | |
| 441 load_error = False | |
| 442 num_tomo_stacks = nxentry.reduced_data.data.tomo_fields.shape[0] | |
| 443 tomo_recon_stacks = num_tomo_stacks*[np.array([])] | |
| 444 for i in range(num_tomo_stacks): | |
| 445 tomo_stack = np.asarray(nxentry.reduced_data.data.tomo_fields[i]) | |
| 446 if not tomo_stack.size: | |
| 447 raise KeyError(f'Unable to load tomography stack {i} for reconstruction') | |
| 448 assert(0 <= lower_row < upper_row < tomo_stack.shape[0]) | |
| 449 center_offsets = [lower_center_offset-lower_row*center_slope, | |
| 450 upper_center_offset+(tomo_stack.shape[0]-1-upper_row)*center_slope] | |
| 451 t0 = time() | |
| 452 logger.debug(f'Running _reconstruct_one_tomo_stack on {self.num_core} cores ...') | |
| 453 tomo_recon_stack = self._reconstruct_one_tomo_stack(tomo_stack, thetas, | |
| 454 center_offsets=center_offsets, num_core=self.num_core, algorithm='gridrec') | |
| 455 logger.debug(f'... done in {time()-t0:.2f} seconds') | |
| 456 logger.info(f'Reconstruction of stack {i} took {time()-t0:.2f} seconds') | |
| 457 | |
| 458 # Combine stacks | |
| 459 tomo_recon_stacks[i] = tomo_recon_stack | |
| 460 | |
| 461 # Resize the reconstructed tomography data | |
| 462 # reconstructed data order in each stack: row/z,x,y | |
| 463 if self.test_mode: | |
| 464 x_bounds = self.test_config.get('x_bounds') | |
| 465 y_bounds = self.test_config.get('y_bounds') | |
| 466 z_bounds = None | |
| 467 elif self.galaxy_flag: | |
| 468 if x_bounds is not None and not is_int_pair(x_bounds, ge=0, | |
| 469 lt=tomo_recon_stacks[0].shape[1]): | |
| 470 raise ValueError(f'Invalid parameter x_bounds ({x_bounds})') | |
| 471 if y_bounds is not None and not is_int_pair(y_bounds, ge=0, | |
| 472 lt=tomo_recon_stacks[0].shape[1]): | |
| 473 raise ValueError(f'Invalid parameter y_bounds ({y_bounds})') | |
| 474 z_bounds = None | |
| 475 else: | |
| 476 x_bounds, y_bounds, z_bounds = self._resize_reconstructed_data(tomo_recon_stacks) | |
| 477 if x_bounds is None: | |
| 478 x_range = (0, tomo_recon_stacks[0].shape[1]) | |
| 479 x_slice = int(x_range[1]/2) | |
| 480 else: | |
| 481 x_range = (min(x_bounds), max(x_bounds)) | |
| 482 x_slice = int((x_bounds[0]+x_bounds[1])/2) | |
| 483 if y_bounds is None: | |
| 484 y_range = (0, tomo_recon_stacks[0].shape[2]) | |
| 485 y_slice = int(y_range[1]/2) | |
| 486 else: | |
| 487 y_range = (min(y_bounds), max(y_bounds)) | |
| 488 y_slice = int((y_bounds[0]+y_bounds[1])/2) | |
| 489 if z_bounds is None: | |
| 490 z_range = (0, tomo_recon_stacks[0].shape[0]) | |
| 491 z_slice = int(z_range[1]/2) | |
| 492 else: | |
| 493 z_range = (min(z_bounds), max(z_bounds)) | |
| 494 z_slice = int((z_bounds[0]+z_bounds[1])/2) | |
| 495 | |
| 496 # Plot a few reconstructed image slices | |
| 497 if num_tomo_stacks == 1: | |
| 498 basetitle = 'recon' | |
| 499 else: | |
| 500 basetitle = f'recon stack {i}' | |
| 501 for i, stack in enumerate(tomo_recon_stacks): | |
| 502 title = f'{basetitle} {res_title} xslice{x_slice}' | |
| 503 quick_imshow(stack[z_range[0]:z_range[1],x_slice,y_range[0]:y_range[1]], | |
| 504 title=title, path=path, save_fig=self.save_figs, save_only=self.save_only, | |
| 505 block=self.block) | |
| 506 title = f'{basetitle} {res_title} yslice{y_slice}' | |
| 507 quick_imshow(stack[z_range[0]:z_range[1],x_range[0]:x_range[1],y_slice], | |
| 508 title=title, path=path, save_fig=self.save_figs, save_only=self.save_only, | |
| 509 block=self.block) | |
| 510 title = f'{basetitle} {res_title} zslice{z_slice}' | |
| 511 quick_imshow(stack[z_slice,x_range[0]:x_range[1],y_range[0]:y_range[1]], | |
| 512 title=title, path=path, save_fig=self.save_figs, save_only=self.save_only, | |
| 513 block=self.block) | |
| 514 | |
| 515 # Save test data to file | |
| 516 # reconstructed data order in each stack: row/z,x,y | |
| 517 if self.test_mode: | |
| 518 for i, stack in enumerate(tomo_recon_stacks): | |
| 519 np.savetxt(f'{self.output_folder}/recon_stack_{i+1}.txt', | |
| 520 stack[z_slice,x_range[0]:x_range[1],y_range[0]:y_range[1]], fmt='%.6e') | |
| 521 | |
| 522 # Add image reconstruction to reconstructed data NXprocess | |
| 523 # reconstructed data order in each stack: row/z,x,y | |
| 524 nxprocess.data = NXdata() | |
| 525 nxprocess.attrs['default'] = 'data' | |
| 526 for k, v in center_info.items(): | |
| 527 nxprocess[k] = v | |
| 528 if x_bounds is not None: | |
| 529 nxprocess.x_bounds = x_bounds | |
| 530 if y_bounds is not None: | |
| 531 nxprocess.y_bounds = y_bounds | |
| 532 if z_bounds is not None: | |
| 533 nxprocess.z_bounds = z_bounds | |
| 534 nxprocess.data['reconstructed_data'] = np.asarray([stack[z_range[0]:z_range[1], | |
| 535 x_range[0]:x_range[1],y_range[0]:y_range[1]] for stack in tomo_recon_stacks]) | |
| 536 nxprocess.data.attrs['signal'] = 'reconstructed_data' | |
| 537 | |
| 538 # Create a copy of the input Nexus object and remove reduced data | |
| 539 exclude_items = [f'{nxentry._name}/reduced_data/data', f'{nxentry._name}/data/reduced_data'] | |
| 540 nxroot_copy = nxcopy(nxroot, exclude_nxpaths=exclude_items) | |
| 541 | |
| 542 # Add the reconstructed data NXprocess to the new Nexus object | |
| 543 nxentry_copy = nxroot_copy[nxroot_copy.attrs['default']] | |
| 544 nxentry_copy.reconstructed_data = nxprocess | |
| 545 if 'data' not in nxentry_copy: | |
| 546 nxentry_copy.data = NXdata() | |
| 547 nxentry_copy.attrs['default'] = 'data' | |
| 548 nxentry_copy.data.makelink(nxprocess.data.reconstructed_data, name='reconstructed_data') | |
| 549 nxentry_copy.data.attrs['signal'] = 'reconstructed_data' | |
| 550 | |
| 551 return(nxroot_copy) | |
| 552 | |
| 553 def combine_data(self, nxroot): | |
| 554 """Combine the reconstructed tomography stacks. | |
| 555 """ | |
| 556 logger.info('Combine the reconstructed tomography stacks') | |
| 557 | |
| 558 if not isinstance(nxroot, NXroot): | |
| 559 raise ValueError(f'Invalid parameter nxroot ({nxroot})') | |
| 560 nxentry = nxroot[nxroot.attrs['default']] | |
| 561 if not isinstance(nxentry, NXentry): | |
| 562 raise ValueError(f'Invalid nxentry ({nxentry})') | |
| 563 | |
| 564 # Create plot galaxy path directory and path if needed | |
| 565 if self.galaxy_flag: | |
| 566 if not os_path.exists('tomo_combine_plots'): | |
| 567 mkdir('tomo_combine_plots') | |
| 568 path = 'tomo_combine_plots' | |
| 569 else: | |
| 570 path = self.output_folder | |
| 571 | |
| 572 # Check if reconstructed image data is available | |
| 573 if ('reconstructed_data' not in nxentry or 'reconstructed_data' not in nxentry.data): | |
| 574 raise KeyError(f'Unable to find valid reconstructed image data in {nxentry}.') | |
| 575 | |
| 576 # Create an NXprocess to store combined image reconstruction (meta)data | |
| 577 # if 'combined_data' in nxentry: | |
| 578 # logger.warning(f'Existing combined data in {nxentry} will be overwritten.') | |
| 579 # del nxentry['combined_data'] | |
| 580 # if 'data' in nxentry 'combined_data' in nxentry.data: | |
| 581 # del nxentry.data['combined_data'] | |
| 582 nxprocess = NXprocess() | |
| 583 | |
| 584 # Get the reconstructed data | |
| 585 # reconstructed data order: stack,row(z),x,y | |
| 586 # Note: Nexus cannot follow a link if the data it points to is too big, | |
| 587 # so get the data from the actual place, not from nxentry.data | |
| 588 tomo_recon_stacks = np.asarray(nxentry.reconstructed_data.data.reconstructed_data) | |
| 589 num_tomo_stacks = tomo_recon_stacks.shape[0] | |
| 590 if num_tomo_stacks == 1: | |
| 591 return(nxroot) | |
| 592 t0 = time() | |
| 593 logger.debug(f'Combining the reconstructed stacks ...') | |
| 594 tomo_recon_combined = tomo_recon_stacks[0,:,:,:] | |
| 595 if num_tomo_stacks > 2: | |
| 596 tomo_recon_combined = np.concatenate([tomo_recon_combined]+ | |
| 597 [tomo_recon_stacks[i,:,:,:] for i in range(1, num_tomo_stacks-1)]) | |
| 598 if num_tomo_stacks > 1: | |
| 599 tomo_recon_combined = np.concatenate([tomo_recon_combined]+ | |
| 600 [tomo_recon_stacks[num_tomo_stacks-1,:,:,:]]) | |
| 601 logger.debug(f'... done in {time()-t0:.2f} seconds') | |
| 602 logger.info(f'Combining the reconstructed stacks took {time()-t0:.2f} seconds') | |
| 603 | |
| 604 # Resize the combined tomography data set | |
| 605 # combined data order: row/z,x,y | |
| 606 if self.test_mode: | |
| 607 x_bounds = None | |
| 608 y_bounds = None | |
| 609 z_bounds = self.test_config.get('z_bounds') | |
| 610 elif self.galaxy_flag: | |
| 611 exit('TODO') | |
| 612 if x_bounds is not None and not is_int_pair(x_bounds, ge=0, | |
| 613 lt=tomo_recon_stacks[0].shape[1]): | |
| 614 raise ValueError(f'Invalid parameter x_bounds ({x_bounds})') | |
| 615 if y_bounds is not None and not is_int_pair(y_bounds, ge=0, | |
| 616 lt=tomo_recon_stacks[0].shape[1]): | |
| 617 raise ValueError(f'Invalid parameter y_bounds ({y_bounds})') | |
| 618 z_bounds = None | |
| 619 else: | |
| 620 x_bounds, y_bounds, z_bounds = self._resize_reconstructed_data(tomo_recon_combined, | |
| 621 z_only=True) | |
| 622 if x_bounds is None: | |
| 623 x_range = (0, tomo_recon_combined.shape[1]) | |
| 624 x_slice = int(x_range[1]/2) | |
| 625 else: | |
| 626 x_range = x_bounds | |
| 627 x_slice = int((x_bounds[0]+x_bounds[1])/2) | |
| 628 if y_bounds is None: | |
| 629 y_range = (0, tomo_recon_combined.shape[2]) | |
| 630 y_slice = int(y_range[1]/2) | |
| 631 else: | |
| 632 y_range = y_bounds | |
| 633 y_slice = int((y_bounds[0]+y_bounds[1])/2) | |
| 634 if z_bounds is None: | |
| 635 z_range = (0, tomo_recon_combined.shape[0]) | |
| 636 z_slice = int(z_range[1]/2) | |
| 637 else: | |
| 638 z_range = z_bounds | |
| 639 z_slice = int((z_bounds[0]+z_bounds[1])/2) | |
| 640 | |
| 641 # Plot a few combined image slices | |
| 642 quick_imshow(tomo_recon_combined[z_range[0]:z_range[1],x_slice,y_range[0]:y_range[1]], | |
| 643 title=f'recon combined xslice{x_slice}', path=path, | |
| 644 save_fig=self.save_figs, save_only=self.save_only, block=self.block) | |
| 645 quick_imshow(tomo_recon_combined[z_range[0]:z_range[1],x_range[0]:x_range[1],y_slice], | |
| 646 title=f'recon combined yslice{y_slice}', path=path, | |
| 647 save_fig=self.save_figs, save_only=self.save_only, block=self.block) | |
| 648 quick_imshow(tomo_recon_combined[z_slice,x_range[0]:x_range[1],y_range[0]:y_range[1]], | |
| 649 title=f'recon combined zslice{z_slice}', path=path, | |
| 650 save_fig=self.save_figs, save_only=self.save_only, block=self.block) | |
| 651 | |
| 652 # Save test data to file | |
| 653 # combined data order: row/z,x,y | |
| 654 if self.test_mode: | |
| 655 np.savetxt(f'{self.output_folder}/recon_combined.txt', tomo_recon_combined[ | |
| 656 z_slice,x_range[0]:x_range[1],y_range[0]:y_range[1]], fmt='%.6e') | |
| 657 | |
| 658 # Add image reconstruction to reconstructed data NXprocess | |
| 659 # combined data order: row/z,x,y | |
| 660 nxprocess.data = NXdata() | |
| 661 nxprocess.attrs['default'] = 'data' | |
| 662 if x_bounds is not None: | |
| 663 nxprocess.x_bounds = x_bounds | |
| 664 if y_bounds is not None: | |
| 665 nxprocess.y_bounds = y_bounds | |
| 666 if z_bounds is not None: | |
| 667 nxprocess.z_bounds = z_bounds | |
| 668 nxprocess.data['combined_data'] = tomo_recon_combined | |
| 669 nxprocess.data.attrs['signal'] = 'combined_data' | |
| 670 | |
| 671 # Create a copy of the input Nexus object and remove reconstructed data | |
| 672 exclude_items = [f'{nxentry._name}/reconstructed_data/data', | |
| 673 f'{nxentry._name}/data/reconstructed_data'] | |
| 674 nxroot_copy = nxcopy(nxroot, exclude_nxpaths=exclude_items) | |
| 675 | |
| 676 # Add the combined data NXprocess to the new Nexus object | |
| 677 nxentry_copy = nxroot_copy[nxroot_copy.attrs['default']] | |
| 678 nxentry_copy.combined_data = nxprocess | |
| 679 if 'data' not in nxentry_copy: | |
| 680 nxentry_copy.data = NXdata() | |
| 681 nxentry_copy.attrs['default'] = 'data' | |
| 682 nxentry_copy.data.makelink(nxprocess.data.combined_data, name='combined_data') | |
| 683 nxentry_copy.data.attrs['signal'] = 'combined_data' | |
| 684 | |
| 685 return(nxroot_copy) | |
| 686 | |
def _gen_dark(self, nxentry, reduced_data):
    """Generate the dark field and add it to the reduced data NXprocess.

    The dark field images are taken from the detector frames whose
    ``image_key`` equals 2 when NXtomo-style detector data is available,
    otherwise from the dark field SPEC scans. A 3D stack is reduced to a
    single image by taking the median over the first axis.

    Pixels above a cutoff of ``min + 2*(median - min)`` are replaced by the
    mean of the remaining pixels. The result is stored as ``data/dark_field``.

    :param nxentry: the input NXentry
    :param reduced_data: the reduced data NXprocess; its ``data`` group is
        (re)created here
    :raises NotImplementedError: if the dark field scans return a list of
        image stacks
    :raises ValueError: if the dark field stack is not 2D or 3D
    :return: the updated reduced data NXprocess
    """
    detector = nxentry.instrument.detector

    # Get the dark field images
    image_key = detector.get('image_key', None)
    if image_key and 'data' in detector:
        field_indices = [index for index, key in enumerate(image_key) if key == 2]
        tdf_stack = detector.data[field_indices,:,:]
        # RV the default NXtomo form does not accommodate bright or dark field stacks
    else:
        dark_field_scans = nxentry.spec_scans.dark_field
        dark_field = FlatField.construct_from_nxcollection(dark_field_scans)
        prefix = str(detector.local_name)
        tdf_stack = dark_field.get_detector_data(prefix)
        if isinstance(tdf_stack, list):
            raise NotImplementedError('Multiple dark field image stacks are not supported')

    # Take the median over the stack.
    # Always work on a float64 array. Assigning NaN below fails for integer
    # data, and a single 2D image must not be modified in place.
    if tdf_stack.ndim == 2:
        tdf = np.array(tdf_stack, dtype=np.float64)
    elif tdf_stack.ndim == 3:
        tdf = np.median(np.asarray(tdf_stack), axis=0).astype(np.float64, copy=False)
        del tdf_stack
    else:
        raise ValueError(f'Invalid tdf_stack shape ({tdf_stack.shape})')

    # Mask dark field intensities above the cutoff
    tdf_cutoff = tdf.min()+2*(np.median(tdf)-tdf.min())
    logger.debug(f'tdf_cutoff = {tdf_cutoff}')
    if not is_num(tdf_cutoff, ge=0):
        logger.warning(f'Ignoring illegal value of tdf_cutoff {tdf_cutoff}')
    else:
        tdf[tdf > tdf_cutoff] = np.nan

    # Replace NaN (the masked pixels) and +inf by the mean, and -inf by zero
    tdf_mean = np.nanmean(tdf)
    logger.debug(f'tdf_mean = {tdf_mean}')
    np.nan_to_num(tdf, copy=False, nan=tdf_mean, posinf=tdf_mean, neginf=0.)

    # Plot dark field
    if self.galaxy_flag:
        quick_imshow(tdf, title='dark field', path='tomo_reduce_plots', save_fig=self.save_figs,
                save_only=self.save_only)
    elif not self.test_mode:
        quick_imshow(tdf, title='dark field', path=self.output_folder, save_fig=self.save_figs,
                save_only=self.save_only)
        clear_imshow('dark field')

    # Add dark field to reduced data NXprocess
    reduced_data.data = NXdata()
    reduced_data.data['dark_field'] = tdf

    return(reduced_data)
| 744 | |
def _gen_bright(self, nxentry, reduced_data):
    """Generate the bright field and add it to the reduced data NXprocess.

    The bright field images are taken from the detector frames whose
    ``image_key`` equals 1 when NXtomo-style detector data is available,
    otherwise from the bright field SPEC scans. A 3D stack is reduced to a
    single image by taking the median over the first axis.

    The dark field is subtracted when present in ``reduced_data``, and all
    values below one are set to one. The result is stored as
    ``data/bright_field``.

    :param nxentry: the input NXentry
    :param reduced_data: the reduced data NXprocess, optionally containing
        ``data/dark_field``
    :raises NotImplementedError: if the bright field scans return a list of
        image stacks
    :raises ValueError: if the bright field stack is not 2D or 3D
    :return: the updated reduced data NXprocess
    """
    detector = nxentry.instrument.detector

    # Get the bright field images
    image_key = detector.get('image_key', None)
    if image_key and 'data' in detector:
        field_indices = [index for index, key in enumerate(image_key) if key == 1]
        tbf_stack = detector.data[field_indices,:,:]
        # RV the default NXtomo form does not accommodate bright or dark field stacks
    else:
        bright_field_scans = nxentry.spec_scans.bright_field
        bright_field = FlatField.construct_from_nxcollection(bright_field_scans)
        prefix = str(detector.local_name)
        tbf_stack = bright_field.get_detector_data(prefix)
        if isinstance(tbf_stack, list):
            raise NotImplementedError('Multiple bright field image stacks are not supported')

    # Take the median if more than one image.
    # Median or mean: It may be best to try the median because of some image
    # artifacts that arise due to crinkles in the upstream kapton tape windows
    # causing some phase contrast images to appear on the detector.
    # One thing that also may be useful in a future implementation is to do a
    # brightfield adjustment on EACH frame of the tomo based on a ROI in the
    # corner of the frame where there is no sample but there is the direct X-ray
    # beam because there is frame to frame fluctuations from the incoming beam.
    # We don't typically account for them but potentially could.
    # Always work on a float64 array. The in-place dark field subtraction below
    # fails for integer data, and a single 2D image must not be modified in place.
    if tbf_stack.ndim == 2:
        tbf = np.array(tbf_stack, dtype=np.float64)
    elif tbf_stack.ndim == 3:
        tbf = np.median(np.asarray(tbf_stack), axis=0).astype(np.float64, copy=False)
        del tbf_stack
    else:
        raise ValueError(f'Invalid tbf_stack shape ({tbf_stack.shape})')

    # Subtract dark field
    if 'data' in reduced_data and 'dark_field' in reduced_data.data:
        tbf -= np.asarray(reduced_data.data.dark_field)
    else:
        logger.warning('Dark field unavailable')

    # Set any values below one to one
    # (avoid negative bright field values for spikes in dark field)
    tbf[tbf < 1] = 1

    # Plot bright field
    if self.galaxy_flag:
        quick_imshow(tbf, title='bright field', path='tomo_reduce_plots',
                save_fig=self.save_figs, save_only=self.save_only)
    elif not self.test_mode:
        quick_imshow(tbf, title='bright field', path=self.output_folder,
                save_fig=self.save_figs, save_only=self.save_only)
        clear_imshow('bright field')

    # Add bright field to reduced data NXprocess
    if 'data' not in reduced_data:
        reduced_data.data = NXdata()
    reduced_data.data['bright_field'] = tbf

    return(reduced_data)
| 806 | |
def _set_detector_bounds(self, nxentry, reduced_data, img_x_bounds=None):
    """Set the vertical detector bounds for the tomography image stacks.

    Right now the range is the same for each set in the image stack.

    For stations id1a3 and id3a, a default range is derived from a
    'rectangle' fit to the bright field row sums. It can also come from the
    smallest vertical shift between stacks. The default is then either
    accepted directly (Galaxy mode) or confirmed/redrawn interactively.

    For other stations, the range is selected interactively on the first
    tomography image. In Galaxy mode it defaults to the full image height.

    NOTE(review): for stations id1a3/id3a a supplied ``img_x_bounds`` is
    validated but then replaced by the default/selected range.

    :param nxentry: the input NXentry
    :param reduced_data: the reduced data NXprocess containing the bright field
    :param img_x_bounds: optional (lower, upper) row bounds, validated against
        the height of the first tomography image
    :raises ValueError: for an invalid ``img_x_bounds`` or a failed selection
    :raises NotImplementedError: for multiple stacks on stations other than
        id1a3/id3a
    :return: the (lower inclusive, upper exclusive) vertical image bounds
    """
    if self.test_mode:
        return(tuple(self.test_config['img_x_bounds']))

    # Get the first tomography image and the reference heights
    image_key = nxentry.instrument.detector.get('image_key', None)
    if image_key and 'data' in nxentry.instrument.detector:
        field_indices = [index for index, key in enumerate(image_key) if key == 0]
        first_image = np.asarray(nxentry.instrument.detector.data[field_indices[0],:,:])
        theta = float(nxentry.sample.rotation_angle[field_indices[0]])
        z_translation_all = nxentry.sample.z_translation[field_indices]
        z_translation_levels = sorted(list(set(z_translation_all)))
        num_tomo_stacks = len(z_translation_levels)
        # The distinct z translations serve as the stack reference heights.
        # Without this, vertical_shifts would be undefined below for multiple stacks.
        vertical_shifts = z_translation_levels
    else:
        tomo_field_scans = nxentry.spec_scans.tomo_fields
        tomo_fields = TomoField.construct_from_nxcollection(tomo_field_scans)
        vertical_shifts = tomo_fields.get_vertical_shifts()
        if not isinstance(vertical_shifts, list):
            vertical_shifts = [vertical_shifts]
        prefix = str(nxentry.instrument.detector.local_name)
        t0 = time()
        first_image = tomo_fields.get_detector_data(prefix, tomo_fields.scan_numbers[0], 0)
        logger.debug(f'Getting first image took {time()-t0:.2f} seconds')
        num_tomo_stacks = len(tomo_fields.scan_numbers)
        theta = tomo_fields.theta_range['start']

    # Select image bounds
    title = f'tomography image at theta={round(theta, 2)+0}'
    if (img_x_bounds is not None and not is_index_range(img_x_bounds, ge=0,
            le=first_image.shape[0])):
        raise ValueError(f'Invalid parameter img_x_bounds ({img_x_bounds})')
    if nxentry.instrument.source.attrs['station'] in ('id1a3', 'id3a'):
        pixel_size = nxentry.instrument.detector.x_pixel_size
        # Try to get a fit from the bright field
        tbf = np.asarray(reduced_data.data.bright_field)
        tbf_shape = tbf.shape
        x_sum = np.sum(tbf, 1)
        x_sum_min = x_sum.min()
        x_sum_max = x_sum.max()
        fit = Fit.fit_data(x_sum, 'rectangle', x=np.array(range(len(x_sum))), form='atan',
                guess=True)
        parameters = fit.best_values
        x_low_fit = parameters.get('center1', None)
        x_upp_fit = parameters.get('center2', None)
        sig_low = parameters.get('sigma1', None)
        sig_upp = parameters.get('sigma2', None)
        # Only trust a successful fit with both edges inside the image and
        # edge widths small compared to the window width
        have_fit = fit.success and x_low_fit is not None and x_upp_fit is not None and \
                sig_low is not None and sig_upp is not None and \
                0 <= x_low_fit < x_upp_fit <= x_sum.size and \
                (sig_low+sig_upp)/(x_upp_fit-x_low_fit) < 0.1
        if have_fit:
            # Set a 5% margin on each side
            margin = 0.05*(x_upp_fit-x_low_fit)
            x_low_fit = max(0, x_low_fit-margin)
            x_upp_fit = min(tbf_shape[0], x_upp_fit+margin)
        if num_tomo_stacks == 1:
            if have_fit:
                # Set the default range to enclose the full fitted window
                x_low = int(x_low_fit)
                x_upp = int(x_upp_fit)
            else:
                # Center a default range of 1 mm (RV: can we get this from the slits?)
                num_x_min = int((1.0-0.5*pixel_size)/pixel_size)
                x_low = int(0.5*(tbf_shape[0]-num_x_min))
                x_upp = x_low+num_x_min
        else:
            # Get the default range from the smallest step between reference heights
            delta_z = vertical_shifts[1]-vertical_shifts[0]
            for i in range(2, num_tomo_stacks):
                delta_z = min(delta_z, vertical_shifts[i]-vertical_shifts[i-1])
            logger.debug(f'delta_z = {delta_z}')
            num_x_min = int((delta_z-0.5*pixel_size)/pixel_size)
            logger.debug(f'num_x_min = {num_x_min}')
            if num_x_min > tbf_shape[0]:
                logger.warning('Image bounds and pixel size prevent seamless stacking')
            if have_fit:
                # Center the default range relative to the fitted window
                x_low = int(0.5*(x_low_fit+x_upp_fit-num_x_min))
                x_upp = x_low+num_x_min
            else:
                # Center the default range
                x_low = int(0.5*(tbf_shape[0]-num_x_min))
                x_upp = x_low+num_x_min
        # Keep the default range inside the image. A centered range wider than
        # the image would otherwise give a negative lower bound, which silently
        # indexes from the end of the array.
        x_low = max(0, x_low)
        x_upp = min(tbf_shape[0], x_upp)
        if self.galaxy_flag:
            img_x_bounds = (x_low, x_upp)
        else:
            tmp = np.copy(tbf)
            tmp_max = tmp.max()
            tmp[x_low,:] = tmp_max
            tmp[x_upp-1,:] = tmp_max
            quick_imshow(tmp, title='bright field')
            tmp = np.copy(first_image)
            tmp_max = tmp.max()
            tmp[x_low,:] = tmp_max
            tmp[x_upp-1,:] = tmp_max
            quick_imshow(tmp, title=title)
            del tmp
            quick_plot((range(x_sum.size), x_sum),
                    ([x_low, x_low], [x_sum_min, x_sum_max], 'r-'),
                    ([x_upp, x_upp], [x_sum_min, x_sum_max], 'r-'),
                    title='sum over theta and y')
            print(f'lower bound = {x_low} (inclusive)')
            print(f'upper bound = {x_upp} (exclusive)]')
            accept = input_yesno('Accept these bounds (y/n)?', 'y')
            clear_imshow('bright field')
            clear_imshow(title)
            clear_plot('sum over theta and y')
            if accept:
                img_x_bounds = (x_low, x_upp)
            else:
                while True:
                    mask, img_x_bounds = draw_mask_1d(x_sum, title='select x data range',
                            legend='sum over theta and y')
                    if len(img_x_bounds) == 1:
                        break
                    print('Choose a single connected data range')
                img_x_bounds = tuple(img_x_bounds[0])
            if (num_tomo_stacks > 1 and img_x_bounds[1]-img_x_bounds[0]+1 <
                    int((delta_z-0.5*pixel_size)/pixel_size)):
                logger.warning('Image bounds and pixel size prevent seamless stacking')
    else:
        if num_tomo_stacks > 1:
            raise NotImplementedError('Selecting image bounds for multiple stacks on FMB')
        # For FMB: use the first tomography image to select range
        # RV: revisit if they do tomography with multiple stacks
        x_sum = np.sum(first_image, 1)
        x_sum_min = x_sum.min()
        x_sum_max = x_sum.max()
        if self.galaxy_flag:
            if img_x_bounds is None:
                img_x_bounds = (0, first_image.shape[0])
        else:
            quick_imshow(first_image, title=title)
            print('Select vertical data reduction range from first tomography image')
            img_x_bounds = select_image_bounds(first_image, 0, title=title)
            clear_imshow(title)
            if img_x_bounds is None:
                raise ValueError('Unable to select image bounds')

    # Plot results
    if self.galaxy_flag:
        path = 'tomo_reduce_plots'
    else:
        path = self.output_folder
    x_low = img_x_bounds[0]
    x_upp = img_x_bounds[1]
    tmp = np.copy(first_image)
    tmp_max = tmp.max()
    tmp[x_low,:] = tmp_max
    tmp[x_upp-1,:] = tmp_max
    quick_imshow(tmp, title=title, path=path, save_fig=self.save_figs, save_only=self.save_only,
            block=self.block)
    del tmp
    quick_plot((range(x_sum.size), x_sum),
            ([x_low, x_low], [x_sum_min, x_sum_max], 'r-'),
            ([x_upp, x_upp], [x_sum_min, x_sum_max], 'r-'),
            title='sum over theta and y', path=path, save_fig=self.save_figs,
            save_only=self.save_only, block=self.block)

    return(img_x_bounds)
| 971 | |
def _set_zoom_or_skip(self):
    """Return the zoom percentage and theta skip used to limit memory usage.

    Interactive selection of either option is currently disabled, so both
    values are always ``None``: no zoom and no skipped thetas.

    :return: the tuple ``(zoom_perc, num_theta_skip)``
    """
    zoom_perc = num_theta_skip = None
    logger.debug(f'zoom_perc = {zoom_perc}')
    logger.debug(f'num_theta_skip = {num_theta_skip}')
    return (zoom_perc, num_theta_skip)
| 990 | |
def _gen_tomo(self, nxentry, reduced_data):
    """Generate the reduced tomography fields and add them to the reduced data NXprocess.

    Each tomography image stack goes through these steps:
    - cropped to the image bounds stored in ``reduced_data`` (the full
      bright field size by default);
    - normalized by the cropped bright field;
    - clipped below at 1e-6 and linearized with -log;
    - cleaned of nans/infs;
    - converted to float32;
    - reordered from theta,row,column to row,theta,column.

    :param nxentry: the input NXentry
    :param reduced_data: the reduced data NXprocess containing ``data/bright_field``
        and optionally ``img_x_bounds``/``img_y_bounds``
    :raises ValueError: if NXtomo-style tomography images are not stored in
        sequence-number order
    :return: the updated reduced data NXprocess with ``rotation_angle``,
        ``x_translation``, ``z_translation`` and ``data/tomo_fields`` set
    """
    # Get full bright field
    tbf = np.asarray(reduced_data.data.bright_field)
    tbf_shape = tbf.shape

    # Get image bounds
    img_x_bounds = tuple(reduced_data.get('img_x_bounds', (0, tbf_shape[0])))
    img_y_bounds = tuple(reduced_data.get('img_y_bounds', (0, tbf_shape[1])))

    # Decide once whether to crop, against the FULL bright field shape.
    # Comparing against the already cropped bright field would skip cropping
    # the images whenever a bound starts at 0, leaving them mismatched with tbf.
    resize = img_x_bounds != (0, tbf_shape[0]) or img_y_bounds != (0, tbf_shape[1])

    # Dark field subtraction of the tomography images is currently disabled
    tdf = None

    # Resize bright field
    if resize:
        tbf = tbf[img_x_bounds[0]:img_x_bounds[1],img_y_bounds[0]:img_y_bounds[1]]

    # Get the tomography images
    image_key = nxentry.instrument.detector.get('image_key', None)
    if image_key and 'data' in nxentry.instrument.detector:
        field_indices_all = [index for index, key in enumerate(image_key) if key == 0]
        z_translation_all = nxentry.sample.z_translation[field_indices_all]
        z_translation_levels = sorted(list(set(z_translation_all)))
        horizontal_shifts = []
        vertical_shifts = []
        thetas = None
        tomo_stacks = []
        # One stack per distinct z translation
        for z_translation in z_translation_levels:
            field_indices = [field_indices_all[index]
                    for index, z in enumerate(z_translation_all) if z == z_translation]
            horizontal_shift = list(set(nxentry.sample.x_translation[field_indices]))
            assert(len(horizontal_shift) == 1)
            horizontal_shifts += horizontal_shift
            vertical_shift = list(set(nxentry.sample.z_translation[field_indices]))
            assert(len(vertical_shift) == 1)
            vertical_shifts += vertical_shift
            sequence_numbers = nxentry.instrument.detector.sequence_number[field_indices]
            if thetas is None:
                thetas = np.asarray(nxentry.sample.rotation_angle[field_indices]) \
                        [sequence_numbers]
            else:
                # All stacks must share the same rotation angles
                assert(all(thetas[i] == nxentry.sample.rotation_angle[field_indices[index]]
                        for i, index in enumerate(sequence_numbers)))
            expected_sequence = list(range(len(sequence_numbers)))
            assert(sorted(set(sequence_numbers)) == expected_sequence)
            if list(sequence_numbers) == expected_sequence:
                tomo_stack = np.asarray(nxentry.instrument.detector.data[field_indices])
            else:
                raise ValueError('Unable to load the tomography images')
            tomo_stacks.append(tomo_stack)
    else:
        tomo_field_scans = nxentry.spec_scans.tomo_fields
        tomo_fields = TomoField.construct_from_nxcollection(tomo_field_scans)
        horizontal_shifts = tomo_fields.get_horizontal_shifts()
        vertical_shifts = tomo_fields.get_vertical_shifts()
        prefix = str(nxentry.instrument.detector.local_name)
        t0 = time()
        tomo_stacks = tomo_fields.get_detector_data(prefix)
        logger.debug(f'Getting tomography images took {time()-t0:.2f} seconds')
        thetas = np.linspace(tomo_fields.theta_range['start'], tomo_fields.theta_range['end'],
                tomo_fields.theta_range['num'])
        if not isinstance(tomo_stacks, list):
            horizontal_shifts = [horizontal_shifts]
            vertical_shifts = [vertical_shifts]
            tomo_stacks = [tomo_stacks]

    reduced_tomo_stacks = []
    if self.galaxy_flag:
        path = 'tomo_reduce_plots'
    else:
        path = self.output_folder
    for i, tomo_stack in enumerate(tomo_stacks):
        # Resize the tomography images
        # Right now the range is the same for each set in the image stack.
        t0 = time()
        if resize:
            tomo_stack = tomo_stack[:,img_x_bounds[0]:img_x_bounds[1],
                    img_y_bounds[0]:img_y_bounds[1]]
        # Always convert to float64: the in-place numexpr operations below
        # cannot write float results into integer detector data
        tomo_stack = np.asarray(tomo_stack, dtype=np.float64)
        logger.debug(f'Resizing tomography images took {time()-t0:.2f} seconds')

        # Subtract dark field
        if tdf is not None:
            t0 = time()
            with set_numexpr_threads(self.num_core):
                ne.evaluate('tomo_stack-tdf', out=tomo_stack)
            logger.debug(f'Subtracting dark field took {time()-t0:.2f} seconds')

        # Normalize
        t0 = time()
        with set_numexpr_threads(self.num_core):
            ne.evaluate('tomo_stack/tbf', out=tomo_stack, truediv=True)
        logger.debug(f'Normalizing took {time()-t0:.2f} seconds')

        # Remove non-positive values and linearize data
        t0 = time()
        cutoff = 1.e-6
        with set_numexpr_threads(self.num_core):
            ne.evaluate('where(tomo_stack<cutoff, cutoff, tomo_stack)', out=tomo_stack)
        with set_numexpr_threads(self.num_core):
            ne.evaluate('-log(tomo_stack)', out=tomo_stack)
        logger.debug('Removing non-positive values and linearizing data took '+
                f'{time()-t0:.2f} seconds')

        # Get rid of nans/infs that may be introduced by normalization
        # (in place; the previous np.where result was discarded)
        t0 = time()
        np.nan_to_num(tomo_stack, copy=False, nan=0., posinf=0., neginf=0.)
        logger.debug(f'Remove nans/infs took {time()-t0:.2f} seconds')

        # Downsize tomography stack to smaller size
        # TODO use zoom_perc and theta_skip from _set_zoom_or_skip
        tomo_stack = tomo_stack.astype('float32')
        if not self.test_mode:
            if len(tomo_stacks) == 1:
                title = f'red fullres theta {round(thetas[0], 2)+0}'
            else:
                title = f'red stack {i} fullres theta {round(thetas[0], 2)+0}'
            quick_imshow(tomo_stack[0,:,:], title=title, path=path, save_fig=self.save_figs,
                    save_only=self.save_only, block=self.block)

        # Convert tomography stack from theta,row,column to row,theta,column
        t0 = time()
        tomo_stack = np.swapaxes(tomo_stack, 0, 1)
        logger.debug(f'Converting coordinate order took {time()-t0:.2f} seconds')

        # Save test data to file
        if self.test_mode:
            row_index = int(tomo_stack.shape[0]/2)
            np.savetxt(f'{self.output_folder}/red_stack_{i+1}.txt', tomo_stack[row_index,:,:],
                    fmt='%.6e')

        # Combine resized stacks
        reduced_tomo_stacks.append(tomo_stack)

    # Add tomo field info to reduced data NXprocess
    reduced_data['rotation_angle'] = thetas
    reduced_data['x_translation'] = np.asarray(horizontal_shifts)
    reduced_data['z_translation'] = np.asarray(vertical_shifts)
    reduced_data.data['tomo_fields'] = np.asarray(reduced_tomo_stacks)

    if tdf is not None:
        del tdf
    del tbf

    return(reduced_data)
| 1163 | |
def _find_center_one_plane(self, sinogram, row, thetas, eff_pixel_size, cross_sectional_dim,
        path=None, tol=0.1, num_core=1):
    """Find the rotation axis center offset for a single tomography plane.

    Uses Nghia Vo's method (``tomopy.find_center_vo``). The plane is then
    reconstructed at that center and its denoised image is plotted for
    visual inspection. Interactive center selection/search is not performed;
    the Vo result is always returned.

    :param sinogram: the sinogram, index order theta,column
    :param row: the row index of the plane (used in log messages and plot titles)
    :param thetas: the rotation angles passed on to the reconstruction
    :param eff_pixel_size: passed on to ``_reconstruct_one_plane``
    :param cross_sectional_dim: passed on to ``_reconstruct_one_plane``
    :param path: the output path for the plots
    :param tol: not used by the automatic center search
    :param num_core: the number of cores to use; capped at
        ``num_core_tomopy_limit`` for ``find_center_vo``
    :return: the center offset in pixels relative to the sinogram column center
    """
    # sinogram index order: theta,column
    # need column,theta for iradon, so take transpose
    sinogram_T = sinogram.T
    center = sinogram.shape[1]/2

    # Try using Nghia Vo’s method
    t0 = time()
    num_core_vo = min(num_core, num_core_tomopy_limit)
    logger.debug(f'Running find_center_vo on {num_core_vo} cores ...')
    tomo_center = tomopy.find_center_vo(sinogram, ncore=num_core_vo)
    logger.debug(f'... done in {time()-t0:.2f} seconds')
    logger.info(f'Finding the center using Nghia Vo’s method took {time()-t0:.2f} seconds')
    center_offset_vo = tomo_center-center
    logger.info(f'Center at row {row} using Nghia Vo’s method = {center_offset_vo:.2f}')

    # Reconstruct the plane at the found center for visual inspection
    t0 = time()
    logger.debug(f'Running _reconstruct_one_plane on {num_core} cores ...')
    recon_plane = self._reconstruct_one_plane(sinogram_T, tomo_center, thetas,
            eff_pixel_size, cross_sectional_dim, False, num_core)
    logger.debug(f'... done in {time()-t0:.2f} seconds')
    logger.info(f'Reconstructing row {row} took {time()-t0:.2f} seconds')

    title = f'edges row{row} center offset{center_offset_vo:.2f} Vo'
    self._plot_edges_one_plane(recon_plane, title, path=path)

    del sinogram_T
    del recon_plane

    # Accept the center found with Vo's method
    return float(center_offset_vo)
| 1264 | |
def _reconstruct_one_plane(self, tomo_plane_T, center, thetas, eff_pixel_size,
        cross_sectional_dim, plot_sinogram=True, num_core=1):
    """Reconstruct a single tomography plane by inverting its sinogram.

    The processing steps are:
    - The sinogram rows are cropped around ``center`` to a radius of
      ``0.55*cross_sectional_dim/eff_pixel_size``, capped at half the number
      of columns. At least one row is dropped at each edge.
    - The cropped sinogram is inverted with ``iradon(..., circle=True)``.
    - The result is smoothed with a Gaussian filter (sigma 1.0).
    - Ring artifacts are removed with ``tomopy.misc.corr.remove_ring``
      (ring width 15).

    :param tomo_plane_T: the transposed sinogram, index order column,theta
    :param center: the rotation axis position in columns; must satisfy
        ``0 <= center < tomo_plane_T.shape[0]``
    :param thetas: the rotation angles passed to ``iradon``
    :param eff_pixel_size: divisor converting ``cross_sectional_dim`` to pixels
    :param cross_sectional_dim: the sample cross sectional dimension
    :param plot_sinogram: plot the cropped sinogram (only outside Galaxy mode)
    :param num_core: the number of cores passed to ``remove_ring``
    :return: the reconstructed plane with a leading axis of length 1
    """
    num_columns = tomo_plane_T.shape[0]
    assert(0 <= center < num_columns)
    center_offset = center-num_columns/2

    # Crop around the rotation axis; the 0.55 factor gives 10% slack on the
    # radius to avoid edge effects
    two_offset = 2*int(np.round(center_offset))
    max_rad = min(int(0.55*(cross_sectional_dim/eff_pixel_size)), 0.5*num_columns)
    dist_from_edge = max(1, int(np.floor((num_columns-abs(two_offset))/2.)-max_rad))
    if two_offset >= 0:
        logger.debug(f'sinogram range = [{two_offset+dist_from_edge}, {-dist_from_edge}]')
        kept_rows = slice(two_offset+dist_from_edge, -dist_from_edge)
    else:
        logger.debug(f'sinogram range = [{dist_from_edge}, {two_offset-dist_from_edge}]')
        kept_rows = slice(dist_from_edge, two_offset-dist_from_edge)
    sinogram = tomo_plane_T[kept_rows,:]
    if not self.galaxy_flag and plot_sinogram:
        quick_imshow(sinogram.T, f'sinogram center offset{center_offset:.2f}', aspect='auto',
                path=self.output_folder, save_fig=self.save_figs, save_only=self.save_only,
                block=self.block)

    # Invert the sinogram
    t0 = time()
    recon_plane = iradon(sinogram, theta=thetas, circle=True)
    logger.debug(f'Inverting sinogram took {time()-t0:.2f} seconds')
    del sinogram

    # Gaussian filter and ring removal settings (configuration hook not yet wired up)
    recon_parameters = None#self.config.get('recon_parameters')
    sigma, ring_width = 1.0, 15
    if recon_parameters is not None:
        sigma = recon_parameters.get('gaussian_sigma', 1.0)
        if not is_num(sigma, ge=0.0):
            logger.warning(f'Invalid gaussian_sigma ({sigma}) in _reconstruct_one_plane, '+
                    'set to a default value of 1.0')
            sigma = 1.0
        ring_width = recon_parameters.get('ring_width', 15)
        if not is_int(ring_width, ge=0):
            logger.warning(f'Invalid ring_width ({ring_width}) in _reconstruct_one_plane, '+
                    'set to a default value of 15')
            ring_width = 15

    # Smooth and remove ring artifacts
    t0 = time()
    recon_plane = spi.gaussian_filter(recon_plane, sigma, mode='nearest')
    recon_clean = tomopy.misc.corr.remove_ring(np.expand_dims(recon_plane, axis=0),
            rwidth=ring_width, ncore=num_core)
    del recon_plane
    logger.debug(f'Filtering and removing ring artifacts took {time()-t0:.2f} seconds')

    return recon_clean
| 1319 | |
def _plot_edges_one_plane(self, recon_plane, title, path=None):
    """Plot the total-variation denoised version of a reconstructed plane.

    The first plane of the ``denoise_tv_chambolle`` result is shown twice:
    - once with the 'coolwarm' colormap;
    - once in gray with a color range symmetric about zero, set by its maximum.

    :param recon_plane: the reconstructed plane; indexed as ``[0,:,:]``
        after denoising
    :param title: the base plot title; ' coolwarm' / ' gray' is appended
    :param path: the output path for the plots, defaults to ``self.output_folder``
    """
    # Denoising weight (configuration hook not yet wired up)
    vis_parameters = None#self.config.get('vis_parameters')
    weight = 0.1
    if vis_parameters is not None:
        weight = vis_parameters.get('denoise_weight', 0.1)
        if not is_num(weight, ge=0.0):
            logger.warning(f'Invalid weight ({weight}) in _plot_edges_one_plane, '+
                    'set to a default value of 0.1')
            weight = 0.1

    edges = denoise_tv_chambolle(recon_plane, weight=weight)
    first_plane = edges[0,:,:]
    vmax = np.max(first_plane)
    output_path = self.output_folder if path is None else path
    quick_imshow(first_plane, f'{title} coolwarm', path=output_path, cmap='coolwarm',
            save_fig=self.save_figs, save_only=self.save_only, block=self.block)
    quick_imshow(first_plane, f'{title} gray', path=output_path, cmap='gray', vmin=-vmax,
            vmax=vmax, save_fig=self.save_figs, save_only=self.save_only, block=self.block)
    del first_plane, edges
| 1340 | |
def _reconstruct_one_tomo_stack(self, tomo_stack, thetas, center_offsets=None, num_core=1,
        algorithm='gridrec'):
    """Reconstruct a single tomography stack with tomopy.

    Steps: Fourier-wavelet stripe removal, an initial reconstruction with
    ``algorithm``, optional ASTRA SART secondary iterations, and ring-artifact
    removal.

    :param tomo_stack: sinogram-ordered stack with axis order (row, theta, column);
        passed to tomopy with ``sinogram_order=True``
    :param thetas: rotation angles in degrees (converted with ``np.radians``)
    :param center_offsets: rotation-axis shift in pixels relative to the column
        center. ``None`` or empty: no shift for any row. Two values: linearly
        interpolated from the first to the last row. Otherwise: one value per row.
        The caller's sequence/array is never modified.
    :param num_core: number of cores passed to tomopy; the stripe removal is capped
        at ``num_core_tomopy_limit``
    :param algorithm: tomopy algorithm used for the initial reconstruction
    :return: the stack returned by ``tomopy.recon`` (after optional secondary
        iterations), with ring artifacts removed in place
    :raises ValueError: if per-row center offsets do not match the number of rows
    """
    # RV should we remove stripes?
    # https://tomopy.readthedocs.io/en/latest/api/tomopy.prep.stripe.html
    # RV should we remove rings?
    # https://tomopy.readthedocs.io/en/latest/api/tomopy.misc.corr.html
    # RV: Add an option to do (extra) secondary iterations later or to do some sort
    #     of convergence test?
    num_rows = tomo_stack.shape[0]
    if center_offsets is None or len(center_offsets) == 0:
        centers = np.zeros(num_rows)
    elif len(center_offsets) == 2:
        centers = np.linspace(center_offsets[0], center_offsets[1], num_rows)
    else:
        # Work on a float copy. Aliasing the input would let the += below
        # mutate the caller's array, and it would fail for lists or int arrays.
        centers = np.array(center_offsets, dtype=float)
        if centers.size != num_rows:
            raise ValueError('center_offsets dimension mismatch in reconstruct_one_tomo_stack')
    # Convert offsets relative to the column center into absolute column positions
    centers += tomo_stack.shape[2]/2

    # Get reconstruction parameters (configurable parameters disabled for now)
    recon_parameters = None  # RV: self.config.get('recon_parameters')
    sigma = 2.0
    secondary_iters = 0
    ring_width = 15
    if recon_parameters is not None:
        sigma = recon_parameters.get('stripe_fw_sigma', 2.0)
        if not is_num(sigma, ge=0):
            logger.warning(f'Invalid stripe_fw_sigma ({sigma}) in '+
                    '_reconstruct_one_tomo_stack, set to a default value of 2.0')
            sigma = 2.0
        secondary_iters = recon_parameters.get('secondary_iters', 0)
        if not is_int(secondary_iters, ge=0):
            logger.warning(f'Invalid secondary_iters ({secondary_iters}) in '+
                    '_reconstruct_one_tomo_stack, set to a default value of 0 (skip them)')
            secondary_iters = 0
        ring_width = recon_parameters.get('ring_width', 15)
        if not is_int(ring_width, ge=0):
            logger.warning(f'Invalid ring_width ({ring_width}) in '+
                    '_reconstruct_one_tomo_stack, set to a default value of 15')
            ring_width = 15

    # Remove horizontal stripes (core count capped for this tomopy routine)
    t0 = time()
    stripe_num_core = min(num_core, num_core_tomopy_limit)
    logger.debug(f'Running remove_stripe_fw on {stripe_num_core} cores ...')
    tomo_stack = tomopy.prep.stripe.remove_stripe_fw(tomo_stack, sigma=sigma,
            ncore=stripe_num_core)
    logger.debug(f'... tomopy.prep.stripe.remove_stripe_fw took {time()-t0:.2f} seconds')

    # Perform initial image reconstruction
    logger.debug('Performing initial image reconstruction')
    t0 = time()
    logger.debug(f'Running recon on {num_core} cores ...')
    tomo_recon_stack = tomopy.recon(tomo_stack, np.radians(thetas), centers,
            sinogram_order=True, algorithm=algorithm, ncore=num_core)
    elapsed = time()-t0
    logger.debug(f'... done in {elapsed:.2f} seconds')
    logger.info(f'Performing initial image reconstruction took {elapsed:.2f} seconds')

    # Run optional secondary iterations, seeded with the initial reconstruction
    if secondary_iters > 0:
        logger.debug(f'Running {secondary_iters} secondary iterations')
        # Alternatives tried:
        #   SIRT_CUDA/cuda: "Error: CUDA error 803: system has unsupported display
        #                    driver/cuda driver combination."
        #   SIRT/linear: did not finish while running overnight
        options = {'method':'SART', 'proj_type':'linear', 'MinConstraint': 0,
                'num_iter':secondary_iters}
        t0 = time()
        logger.debug(f'Running recon on {num_core} cores ...')
        tomo_recon_stack = tomopy.recon(tomo_stack, np.radians(thetas), centers,
                init_recon=tomo_recon_stack, options=options, sinogram_order=True,
                algorithm=tomopy.astra, ncore=num_core)
        elapsed = time()-t0
        logger.debug(f'... done in {elapsed:.2f} seconds')
        logger.info(f'Performing secondary iterations took {elapsed:.2f} seconds')

    # Remove ring artifacts (in place via out=)
    t0 = time()
    tomopy.misc.corr.remove_ring(tomo_recon_stack, rwidth=ring_width, out=tomo_recon_stack,
            ncore=num_core)
    logger.debug(f'Removing ring artifacts took {time()-t0:.2f} seconds')

    return tomo_recon_stack
| 1433 | |
def _resize_reconstructed_data(self, data, z_only=False):
    """Interactively select x, y and z bounds for reconstructed tomography data.

    Data order is row(z),x,y for a single stack; ``data`` may also be a list
    (or tuple) of such stacks, whose summed profiles are used for selection.

    :param data: a single 3D reconstructed stack, or a list/tuple of 3D stacks
    :param z_only: if True, skip the x- and y-bound selection (both are None)
    :return: (x_bounds, y_bounds, z_bounds). Each entry is None when not selected,
        otherwise the single range picked with ``draw_mask_1d`` (presumably lower
        bound inclusive, upper bound exclusive; TODO confirm). z bounds are only
        offered when exactly one stack is given.
    :raises ValueError: if any stack is not three-dimensional
    """
    def _stack_sum(axis):
        """Sum every stack over ``axis`` and add the results (0 if there are no stacks)."""
        return sum(np.sum(stack, axis=axis) for stack in tomo_recon_stacks)

    def _select_bounds(tomosum, prompt, default, title, legend):
        """Ask whether to change one pair of bounds and, if so, let the user draw
        exactly one continuous range on ``tomosum``. Returns None when declined."""
        if not input_yesno(prompt, default):
            return None
        _, index_ranges = draw_mask_1d(tomosum, current_index_ranges=None, title=title,
                legend=legend)
        while len(index_ranges) != 1:
            print('Please select exactly one continuous range')
            _, index_ranges = draw_mask_1d(tomosum, title=title, legend=legend)
        return index_ranges[0]

    if isinstance(data, (list, tuple)):
        tomo_recon_stacks = list(data)
    else:
        tomo_recon_stacks = [data]
    for stack in tomo_recon_stacks:
        if stack.ndim != 3:
            raise ValueError(
                    f'Expected 3D reconstructed stacks, got a stack with ndim={stack.ndim}')
    num_tomo_stacks = len(tomo_recon_stacks)

    if z_only:
        x_bounds = None
        y_bounds = None
    else:
        # x bounds: profile along x, summed over z and y
        x_bounds = _select_bounds(_stack_sum((0, 2)),
                '\nDo you want to change the image x-bounds (y/n)?', 'y',
                'select x data range', 'recon stack sum yz')
        logger.debug(f'x_bounds = {x_bounds}')

        # y bounds: profile along y, summed over z and x
        y_bounds = _select_bounds(_stack_sum((0, 1)),
                '\nDo you want to change the image y-bounds (y/n)?', 'y',
                'select y data range', 'recon stack sum xz')
        logger.debug(f'y_bounds = {y_bounds}')

    # z bounds (profile along z, summed over x and y) only for a single image set
    if num_tomo_stacks != 1:
        z_bounds = None
    else:
        z_bounds = _select_bounds(_stack_sum((1, 2)),
                'Do you want to change the image z-bounds (y/n)?', 'n',
                'select z data range', 'recon stack sum xy')
        logger.debug(f'z_bounds = {z_bounds}')

    return x_bounds, y_bounds, z_bounds
| 1531 | |
| 1532 | |
def run_tomo(input_file:str, output_file:str, modes:list[str], center_file=None, num_core=-1,
        output_folder='.', save_figs='no', test_mode=False) -> None:
    """Run the requested steps of the tomography workflow.

    :param input_file: file read with ``Tomo.read`` as the initial data
    :param output_file: file the final result is written to with ``Tomo.write``
        (not written in test mode)
    :param modes: steps to run: 'reduce_data', 'find_center', 'reconstruct_data',
        'combine_data', or 'all'; None is treated as ['all']
    :param center_file: file with the rotation axis centers. It is read when
        reconstructing without running 'find_center' in the same call.
    :param num_core: number of cores, passed to ``Tomo``
    :param output_folder: folder for figures and, in test mode, the 'tomo.log' file
    :param save_figs: figure saving option, passed to ``Tomo``
    :param test_mode: if True, log at INFO level to '{output_folder}/tomo.log'
        and skip writing ``output_file``
    :raises ValueError: if reconstruction is requested without 'find_center'
        and no ``center_file`` is given
    """
    if test_mode:
        logging_format = '%(asctime)s : %(levelname)s - %(module)s : %(funcName)s - %(message)s'
        logging.basicConfig(filename=f'{output_folder}/tomo.log', filemode='w',
                format=logging_format, level=logging.INFO, force=True)
    logger.info(f'input_file = {input_file}')
    logger.info(f'center_file = {center_file}')
    logger.info(f'output_file = {output_file}')
    logger.debug(f'modes= {modes}')
    logger.debug(f'num_core= {num_core}')
    logger.info(f'output_folder = {output_folder}')
    logger.info(f'save_figs = {save_figs}')
    logger.info(f'test_mode = {test_mode}')

    # Check for correction modes
    if modes is None:
        modes = ['all']
    logger.debug(f'modes {type(modes)} = {modes}')
    run_all = 'all' in modes

    # Fail fast, before any expensive step, when reconstruction has no source
    # of rotation axis centers (neither find_center nor a center file)
    if ('reconstruct_data' in modes and not run_all and 'find_center' not in modes
            and center_file is None):
        raise ValueError('center_file is required to reconstruct data without '
                'running find_center')

    # Instantiate Tomo object
    tomo = Tomo(num_core=num_core, output_folder=output_folder, save_figs=save_figs,
            test_mode=test_mode)

    # Read input file
    data = tomo.read(input_file)

    # Generate reduced tomography images
    if run_all or 'reduce_data' in modes:
        data = tomo.gen_reduced_data(data)

    # Find rotation axis centers for the tomography stacks
    center_data = None
    if run_all or 'find_center' in modes:
        center_data = tomo.find_centers(data)

    # Reconstruct tomography stacks
    if run_all or 'reconstruct_data' in modes:
        if center_data is None:
            # Centers were not computed in this run: read them from file
            center_data = tomo.read(center_file)
        data = tomo.reconstruct_data(data, center_data)
        # Centers are consumed, so the reconstructed data is what gets written
        center_data = None

    # Combine reconstructed tomography stacks
    if run_all or 'combine_data' in modes:
        data = tomo.combine_data(data)

    # Write output file: the centers if they were the last product, else the data
    if not test_mode:
        if center_data is None:
            tomo.write(data, output_file)
        else:
            tomo.write(center_data, output_file)
| 1589 | 
