Mercurial > repos > rv43 > test_tomo_reconstruct
comparison workflow/run_tomo.py @ 0:98e23dff1de2 draft default tip
planemo upload for repository https://github.com/rolfverberg/galaxytools commit f8c4bdb31c20c468045ad5e6eb255a293244bc6c-dirty
| author | rv43 |
|---|---|
| date | Tue, 21 Mar 2023 16:22:42 +0000 |
| parents | |
| children | |
Comparison legend: equal · deleted · inserted · replaced
| -1:000000000000 | 0:98e23dff1de2 |
|---|---|
| 1 #!/usr/bin/env python3 | |
| 2 | |
| 3 import logging | |
| 4 logger = logging.getLogger(__name__) | |
| 5 | |
| 6 import numpy as np | |
| 7 try: | |
| 8 import numexpr as ne | |
| 9 except: | |
| 10 pass | |
| 11 try: | |
| 12 import scipy.ndimage as spi | |
| 13 except: | |
| 14 pass | |
| 15 | |
| 16 from multiprocessing import cpu_count | |
| 17 from nexusformat.nexus import * | |
| 18 from os import mkdir | |
| 19 from os import path as os_path | |
| 20 try: | |
| 21 from skimage.transform import iradon | |
| 22 except: | |
| 23 pass | |
| 24 try: | |
| 25 from skimage.restoration import denoise_tv_chambolle | |
| 26 except: | |
| 27 pass | |
| 28 from time import time | |
| 29 try: | |
| 30 import tomopy | |
| 31 except: | |
| 32 pass | |
| 33 from yaml import safe_load, safe_dump | |
| 34 | |
| 35 try: | |
| 36 from msnctools.fit import Fit | |
| 37 except: | |
| 38 from fit import Fit | |
| 39 try: | |
| 40 from msnctools.general import illegal_value, is_int, is_int_pair, is_num, is_index_range, \ | |
| 41 input_int, input_num, input_yesno, input_menu, draw_mask_1d, select_image_bounds, \ | |
| 42 select_one_image_bound, clear_imshow, quick_imshow, clear_plot, quick_plot | |
| 43 except: | |
| 44 from general import illegal_value, is_int, is_int_pair, is_num, is_index_range, \ | |
| 45 input_int, input_num, input_yesno, input_menu, draw_mask_1d, select_image_bounds, \ | |
| 46 select_one_image_bound, clear_imshow, quick_imshow, clear_plot, quick_plot | |
| 47 | |
| 48 try: | |
| 49 from workflow.models import import_scanparser, FlatField, TomoField, TomoWorkflow | |
| 50 from workflow.__version__ import __version__ | |
| 51 except: | |
| 52 pass | |
| 53 | |
| 54 num_core_tomopy_limit = 24 | |
| 55 | |
def nxcopy(nxobject:NXobject, exclude_nxpaths:'list[str] | None'=None,
        nxpath_prefix:str='') -> NXobject:
    '''Return a copy of a nexus object, optionally excluding certain child items.

    Child groups are copied recursively; all other children are assigned
    into the new object as-is.

    :param nxobject: the original nexus object to return a "copy" of
    :type nxobject: nexusformat.nexus.NXobject
    :param exclude_nxpaths: '/'-separated paths, relative to the top-level
        `nxobject` (e.g. 'entry/data/data'), of child nexus objects that
        should be excluded from the returned "copy", defaults to `None`
        (exclude nothing)
    :type exclude_nxpaths: list[str], optional
    :param nxpath_prefix: For use in recursive calls from inside this
        function only!
    :type nxpath_prefix: str
    :return: a copy of `nxobject` with some children optionally excluded.
    :rtype: NXobject
    '''
    # Avoid a shared mutable default argument
    if exclude_nxpaths is None:
        exclude_nxpaths = []

    nxobject_copy = nxobject.__class__()
    if not nxpath_prefix:
        # Top level: only the 'default' attribute is carried over
        if 'default' in nxobject.attrs:
            nxobject_copy.attrs['default'] = nxobject.attrs['default']
    else:
        for k, v in nxobject.attrs.items():
            nxobject_copy.attrs[k] = v

    for k, v in nxobject.items():
        # NeXus paths are always '/'-separated regardless of the host OS, so
        # do not build them with os.path.join ('\\' on Windows would never
        # match the exclude paths callers pass in).
        nxpath = f'{nxpath_prefix}/{k}' if nxpath_prefix else k

        if nxpath in exclude_nxpaths:
            continue

        if isinstance(v, NXgroup):
            nxobject_copy[k] = nxcopy(v, exclude_nxpaths=exclude_nxpaths,
                nxpath_prefix=nxpath)
        else:
            nxobject_copy[k] = v

    return nxobject_copy
| 92 | |
class set_numexpr_threads:
    """Context manager that temporarily sets the number of numexpr threads.

    On entry the numexpr thread count is set to `num_core` and the value
    returned by `ne.set_num_threads` is remembered. On exit it is passed back
    to `ne.set_num_threads` to restore it.

    :param num_core: requested number of threads; `None`, values below 1 and
        values above `cpu_count()` all fall back to `cpu_count()`
    """

    def __init__(self, num_core):
        if num_core is None or num_core < 1 or num_core > cpu_count():
            self.num_core = cpu_count()
        else:
            self.num_core = num_core

    def __enter__(self):
        self.num_core_org = ne.set_num_threads(self.num_core)
        # Return the manager so `with set_numexpr_threads(n) as ctx:` works
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        ne.set_num_threads(self.num_core_org)
        # Implicitly returns None, so exceptions raised in the body propagate
| 106 | |
| 107 class Tomo: | |
| 108 """Processing tomography data with misalignment. | |
| 109 """ | |
def __init__(self, galaxy_flag=False, num_core=-1, output_folder='.', save_figs=None,
        test_mode=False):
    """Initialize the tomography processor.

    :param galaxy_flag: run in Galaxy mode: output goes to '.', test mode is
        off and figures are saved only, defaults to False
    :type galaxy_flag: bool
    :param num_core: number of processors to use; -1 means `cpu_count()`,
        larger values are reduced to `cpu_count()`, defaults to -1
    :type num_core: int
    :param output_folder: output folder, created (with any missing parents)
        if needed; ignored in Galaxy mode, defaults to '.'
    :type output_folder: str
    :param save_figs: 'yes', 'no' or 'only'; forced to 'only' in Galaxy and
        test mode, defaults to None (meaning 'no')
    :type save_figs: str, optional
    :param test_mode: run in test mode; ignored in Galaxy mode, defaults to False
    :type test_mode: bool
    :raises ValueError: for an invalid galaxy_flag, test_mode, save_figs or num_core
    """
    from os import makedirs

    if not isinstance(galaxy_flag, bool):
        raise ValueError(f'Invalid parameter galaxy_flag ({galaxy_flag})')
    self.galaxy_flag = galaxy_flag
    self.num_core = num_core
    if self.galaxy_flag:
        if output_folder != '.':
            logger.warning('Ignoring output_folder in galaxy mode')
        self.output_folder = '.'
        if test_mode != False:
            logger.warning('Ignoring test_mode in galaxy mode')
        self.test_mode = False
        if save_figs is not None:
            logger.warning('Ignoring save_figs in galaxy mode')
        save_figs = 'only'
    else:
        self.output_folder = os_path.abspath(output_folder)
        # makedirs also creates missing parent folders, where mkdir would fail
        makedirs(self.output_folder, exist_ok=True)
        if not isinstance(test_mode, bool):
            raise ValueError(f'Invalid parameter test_mode ({test_mode})')
        self.test_mode = test_mode
        if self.test_mode:
            # Only warn when the caller explicitly asked for something else
            if save_figs not in (None, 'only'):
                logger.warning('Ignoring save_figs in test mode')
            save_figs = 'only'
        elif save_figs is None:
            save_figs = 'no'
    self.test_config = {}

    if save_figs == 'only':
        self.save_only = True
        self.save_figs = True
    elif save_figs == 'yes':
        self.save_only = False
        self.save_figs = True
    elif save_figs == 'no':
        self.save_only = False
        self.save_figs = False
    else:
        raise ValueError(f'Invalid parameter save_figs ({save_figs})')
    # self.block is passed on to the plotting helpers; it is the inverse of save_only
    self.block = not self.save_only

    if self.num_core == -1:
        self.num_core = cpu_count()
    if not is_int(self.num_core, gt=0, log=False):
        raise ValueError(f'Invalid parameter num_core ({num_core})')
    if self.num_core > cpu_count():
        logger.warning(f'num_core = {self.num_core} is larger than the number of available '
            f'processors and reduced to {cpu_count()}')
        self.num_core = cpu_count()
| 165 | |
def read(self, filename):
    """Read a YAML configuration file or a NeXus file.

    In Galaxy mode the extension is not used: the file is first parsed as
    YAML and, if that fails, opened as a NeXus file (presumably because
    Galaxy inputs do not keep their original extension). Otherwise the
    format is selected by extension: '.yml'/'.yaml' or '.nxs'/'.nex'.

    :param filename: the file to read
    :type filename: str
    :raises ValueError: if the file cannot be opened (Galaxy mode) or the
        extension is not supported
    :return: the loaded YAML content or the NeXus tree
    :rtype: dict or NXroot
    """
    logger.info(f'looking for {filename}')
    if self.galaxy_flag:
        try:
            with open(filename, 'r') as f:
                return safe_load(f)
        except Exception as exc:
            # Deliberate fallback: not a (readable) YAML file, try NeXus next
            logger.debug(f'Unable to read {filename} as YAML ({exc})')
        try:
            with NXFile(filename, mode='r') as nxfile:
                return nxfile.readfile()
        except Exception as exc:
            raise ValueError(f'Unable to open ({filename})') from exc

    extension = os_path.splitext(filename)[1]
    if extension in ('.yml', '.yaml'):
        with open(filename, 'r') as f:
            return safe_load(f)
    # Accept the same NeXus extensions that write() produces
    if extension in ('.nxs', '.nex'):
        with NXFile(filename, mode='r') as nxfile:
            return nxfile.readfile()
    raise ValueError(f'Invalid filename extension ({extension})')
| 196 | |
def write(self, data, filename):
    """Write data to a file, choosing the format by filename extension.

    :param data: the object to write. For '.yml'/'.yaml' it is dumped with
        `safe_dump`. For '.nxs'/'.nex' its `save(filename, mode='w')` method
        is called. For '.nc' its `to_netcdf(path=filename)` method is called
        (presumably an xarray object).
    :param filename: the output file name
    :type filename: str
    :raises ValueError: for an unsupported filename extension
    """
    extension = os_path.splitext(filename)[1]
    if extension in ('.yml', '.yaml'):
        with open(filename, 'w') as f:
            safe_dump(data, f)
    elif extension in ('.nxs', '.nex'):
        data.save(filename, mode='w')
    elif extension == '.nc':
        # The to_netcdf keyword is `path`; `os_path` is not a parameter of it
        data.to_netcdf(path=filename)
    else:
        raise ValueError(f'Invalid filename extension ({extension})')
| 208 | |
def gen_reduced_data(self, data, img_x_bounds=None):
    """Generate the reduced tomography images.

    :param data: a workflow configuration dictionary (turned into a NeXus
        tree through `TomoWorkflow`) or an existing NeXus tree
    :type data: dict or NXroot
    :param img_x_bounds: vertical detector bounds for the image stack,
        passed on to `_set_detector_bounds`, defaults to None
    :raises ValueError: for multiple sample maps or an invalid `data` type
    :return: a NeXus tree whose default entry holds the reduced data in
        `reduced_data`, with links to it in `data`
    :rtype: NXroot
    """
    logger.info('Generate the reduced tomography images')

    # Create plot galaxy path directory if needed
    if self.galaxy_flag and not os_path.exists('tomo_reduce_plots'):
        mkdir('tomo_reduce_plots')

    if isinstance(data, dict):
        # Create Nexus format object from input dictionary
        wf = TomoWorkflow(**data)
        if len(wf.sample_maps) > 1:
            raise ValueError('Multiple sample maps not yet implemented')
        nxroot = NXroot()
        t0 = time()
        for sample_map in wf.sample_maps:
            logger.info(f'Start constructing the {sample_map.title} map.')
            import_scanparser(sample_map.station)
            sample_map.construct_nxentry(nxroot, include_raw_data=False)
        logger.info(f'Constructed all sample maps in {time()-t0:.2f} seconds.')
        nxentry = nxroot[nxroot.attrs['default']]
        # Get test mode configuration info
        if self.test_mode:
            self.test_config = data['sample_maps'][0]['test_mode']
    elif isinstance(data, NXroot):
        nxentry = data[data.attrs['default']]
    else:
        raise ValueError(f'Invalid parameter data ({data})')

    # Create an NXprocess to store data reduction (meta)data
    reduced_data = NXprocess()

    # Generate dark field
    if 'dark_field' in nxentry['spec_scans']:
        reduced_data = self._gen_dark(nxentry, reduced_data)

    # Generate bright field
    reduced_data = self._gen_bright(nxentry, reduced_data)

    # Set vertical detector bounds for image stack
    img_x_bounds = self._set_detector_bounds(nxentry, reduced_data, img_x_bounds=img_x_bounds)
    logger.info(f'img_x_bounds = {img_x_bounds}')
    reduced_data['img_x_bounds'] = img_x_bounds

    # Set zoom and/or theta skip to reduce memory the requirement
    zoom_perc, num_theta_skip = self._set_zoom_or_skip()
    if zoom_perc is not None:
        reduced_data.attrs['zoom_perc'] = zoom_perc
    if num_theta_skip is not None:
        reduced_data.attrs['num_theta_skip'] = num_theta_skip

    # Generate reduced tomography fields
    reduced_data = self._gen_tomo(nxentry, reduced_data)

    # Create a copy of the input Nexus object without the raw data, any
    # existing reduced data, and the links that are recreated below
    if isinstance(data, NXroot):
        exclude_items = [f'{nxentry._name}/{item}' for item in (
            'reduced_data/data',
            'instrument/detector/data',
            'instrument/detector/image_key',
            'instrument/detector/sequence_number',
            'sample/rotation_angle',
            'sample/x_translation',
            'sample/z_translation',
            'data/data',
            'data/image_key',
            'data/reduced_data',
            'data/rotation_angle',
            'data/x_translation',
            'data/z_translation')]
        nxroot = nxcopy(data, exclude_nxpaths=exclude_items)
        nxentry = nxroot[nxroot.attrs['default']]

    # Add the reduced data NXprocess
    nxentry.reduced_data = reduced_data

    if 'data' not in nxentry:
        nxentry.data = NXdata()
    nxentry.attrs['default'] = 'data'
    nxentry.data.makelink(nxentry.reduced_data.data.tomo_fields, name='reduced_data')
    nxentry.data.makelink(nxentry.reduced_data.rotation_angle, name='rotation_angle')
    nxentry.data.attrs['signal'] = 'reduced_data'

    return nxroot
| 293 | |
def find_centers(self, nxroot, center_rows=None, center_stack_index=None):
    """Find the calibrated center axis info.

    The rotation axis center offset is determined at a lower and an upper
    detector row of one reduced tomography stack.

    :param nxroot: NeXus tree holding the reduced data in its default entry
    :type nxroot: NXroot
    :param center_rows: (Galaxy mode only) the lower and upper row at which
        to find the center. -1 in the first position selects row 0, and -1
        in the second position selects the last row. Ignored outside Galaxy
        mode.
    :type center_rows: tuple(int, int), optional
    :param center_stack_index: (Galaxy mode only) offset-0 index of the stack
        used to calibrate the center axis, defaults to the middle stack;
        ignored outside Galaxy mode
    :type center_stack_index: int, optional
    :raises ValueError: for invalid parameters or NeXus structure
    :raises KeyError: when no valid reduced data is available
    :return: lower/upper rows and center offsets, plus the offset-1
        'center_stack_index' when there is more than one stack
    :rtype: dict
    """
    logger.info('Find the calibrated center axis info')

    if not isinstance(nxroot, NXroot):
        raise ValueError(f'Invalid parameter nxroot ({nxroot})')
    nxentry = nxroot[nxroot.attrs['default']]
    if not isinstance(nxentry, NXentry):
        raise ValueError(f'Invalid nxentry ({nxentry})')
    if self.galaxy_flag:
        if center_rows is not None:
            center_rows = tuple(center_rows)
            if not is_int_pair(center_rows):
                raise ValueError(f'Invalid parameter center_rows ({center_rows})')
        if center_stack_index is not None and not is_int(center_stack_index, ge=0):
            raise ValueError(f'Invalid parameter center_stack_index ({center_stack_index})')
    else:
        if center_rows is not None:
            logger.warning(f'Ignoring parameter center_rows ({center_rows})')
            center_rows = None
        if center_stack_index is not None:
            logger.warning(f'Ignoring parameter center_stack_index ({center_stack_index})')
            center_stack_index = None

    # Create plot galaxy path directory and path if needed
    if self.galaxy_flag:
        path = 'tomo_find_centers_plots'
        if not os_path.exists(path):
            mkdir(path)
    else:
        path = self.output_folder

    # Check if reduced data is available
    if 'reduced_data' not in nxentry or 'reduced_data' not in nxentry.data:
        raise KeyError(f'Unable to find valid reduced data in {nxentry}.')

    # Select the image stack to calibrate the center axis
    # reduced data axes order: stack,theta,row,column
    # Note: Nexus cannot follow a link if the data it points to is too big,
    # so get the data from the actual place, not from nxentry.data
    tomo_fields_shape = nxentry.reduced_data.data.tomo_fields.shape
    if len(tomo_fields_shape) != 4 or any(not dim for dim in tomo_fields_shape):
        raise KeyError('Unable to load the required reduced tomography stack')
    num_tomo_stacks = tomo_fields_shape[0]
    if num_tomo_stacks == 1:
        center_stack_index = 0
        default = 'n'
    else:
        if self.test_mode:
            center_stack_index = self.test_config['center_stack_index']-1 # make offset 0
        elif self.galaxy_flag:
            if center_stack_index is None:
                center_stack_index = int(num_tomo_stacks/2)
            if center_stack_index >= num_tomo_stacks:
                raise ValueError(f'Invalid parameter center_stack_index ({center_stack_index})')
        else:
            center_stack_index = input_int('\nEnter tomography stack index to calibrate the '
                'center axis', ge=1, le=num_tomo_stacks, default=int(1+num_tomo_stacks/2))
            center_stack_index -= 1
        default = 'y'

    # Get thetas (in degrees)
    thetas = np.asarray(nxentry.reduced_data.rotation_angle)

    # Get effective pixel_size. gen_reduced_data stores zoom_perc as an
    # attribute of reduced_data, so look in the attributes, not the children.
    if 'zoom_perc' in nxentry.reduced_data.attrs:
        eff_pixel_size = 100.*(nxentry.instrument.detector.x_pixel_size/
            nxentry.reduced_data.attrs['zoom_perc'])
    else:
        eff_pixel_size = nxentry.instrument.detector.x_pixel_size

    # Get cross sectional diameter
    cross_sectional_dim = tomo_fields_shape[3]*eff_pixel_size
    logger.debug(f'cross_sectional_dim = {cross_sectional_dim}')

    # Determine center offset at sample row boundaries
    logger.info('Determine center offset at sample row boundaries')

    # Lower row center
    if self.test_mode:
        lower_row = self.test_config['lower_row']
    elif self.galaxy_flag:
        if center_rows is None or center_rows[0] == -1:
            lower_row = 0
        elif center_rows[1] == -1:
            # Only the lower row is given; min() would pick the -1 placeholder
            lower_row = center_rows[0]
        else:
            lower_row = min(center_rows)
        if not 0 <= lower_row < tomo_fields_shape[2]-1:
            raise ValueError(f'Invalid parameter center_rows ({center_rows})')
    else:
        lower_row = select_one_image_bound(
            nxentry.reduced_data.data.tomo_fields[center_stack_index,0,:,:], 0, bound=0,
            title=f'theta={round(thetas[0], 2)+0}',
            bound_name='row index to find lower center', default=default, raise_error=True)
    logger.debug('Finding center...')
    t0 = time()
    lower_center_offset = self._find_center_one_plane(
        nxentry.reduced_data.data.tomo_fields[center_stack_index,:,lower_row,:],
        lower_row, thetas, eff_pixel_size, cross_sectional_dim, path=path,
        num_core=self.num_core)
    logger.debug(f'... done in {time()-t0:.2f} seconds')
    logger.debug(f'lower_row = {lower_row:.2f}')
    logger.debug(f'lower_center_offset = {lower_center_offset:.2f}')

    # Upper row center
    if self.test_mode:
        upper_row = self.test_config['upper_row']
    elif self.galaxy_flag:
        if center_rows is None or center_rows[1] == -1:
            upper_row = tomo_fields_shape[2]-1
        elif center_rows[0] == -1:
            upper_row = center_rows[1]
        else:
            upper_row = max(center_rows)
        if not lower_row < upper_row < tomo_fields_shape[2]:
            raise ValueError(f'Invalid parameter center_rows ({center_rows})')
    else:
        upper_row = select_one_image_bound(
            nxentry.reduced_data.data.tomo_fields[center_stack_index,0,:,:], 0,
            bound=tomo_fields_shape[2]-1, title=f'theta={round(thetas[0], 2)+0}',
            bound_name='row index to find upper center', default=default, raise_error=True)
    logger.debug('Finding center...')
    t0 = time()
    upper_center_offset = self._find_center_one_plane(
        nxentry.reduced_data.data.tomo_fields[center_stack_index,:,upper_row,:],
        upper_row, thetas, eff_pixel_size, cross_sectional_dim, path=path,
        num_core=self.num_core)
    logger.debug(f'... done in {time()-t0:.2f} seconds')
    logger.debug(f'upper_row = {upper_row:.2f}')
    logger.debug(f'upper_center_offset = {upper_center_offset:.2f}')

    center_config = {'lower_row': lower_row, 'lower_center_offset': lower_center_offset,
        'upper_row': upper_row, 'upper_center_offset': upper_center_offset}
    if num_tomo_stacks > 1:
        center_config['center_stack_index'] = center_stack_index+1 # save as offset 1

    # Save test data to file
    if self.test_mode:
        with open(f'{self.output_folder}/center_config.yaml', 'w') as f:
            safe_dump(center_config, f)

    return center_config
| 436 | |
def reconstruct_data(self, nxroot, center_info, x_bounds=None, y_bounds=None):
    """Reconstruct the tomography data.

    :param nxroot: NeXus tree holding the reduced data in its default entry
    :type nxroot: NXroot
    :param center_info: calibrated center axis info (as returned by
        `find_centers`); must provide 'lower_row', 'lower_center_offset',
        'upper_row' and 'upper_center_offset'
    :type center_info: dict
    :param x_bounds: (Galaxy mode only) index bounds along x to keep;
        otherwise replaced by the test configuration or interactive input
    :param y_bounds: (Galaxy mode only) index bounds along y to keep;
        otherwise replaced by the test configuration or interactive input
    :raises ValueError: for invalid parameters or an unreadable stack
    :raises KeyError: when reduced data or center axis info is missing
    :return: a copy of `nxroot` without the reduced data, holding the
        reconstructed data in `reconstructed_data`
    :rtype: NXroot
    """
    logger.info('Reconstruct the tomography data')

    if not isinstance(nxroot, NXroot):
        raise ValueError(f'Invalid parameter nxroot ({nxroot})')
    nxentry = nxroot[nxroot.attrs['default']]
    if not isinstance(nxentry, NXentry):
        raise ValueError(f'Invalid nxentry ({nxentry})')
    if not isinstance(center_info, dict):
        raise ValueError(f'Invalid parameter center_info ({center_info})')

    # Create plot galaxy path directory and path if needed
    if self.galaxy_flag:
        path = 'tomo_reconstruct_plots'
        if not os_path.exists(path):
            mkdir(path)
    else:
        path = self.output_folder

    # Check if reduced data is available
    if 'reduced_data' not in nxentry or 'reduced_data' not in nxentry.data:
        raise KeyError(f'Unable to find valid reduced data in {nxentry}.')

    # Create an NXprocess to store image reconstruction (meta)data
    nxprocess = NXprocess()

    # Get rotation axis rows and centers
    lower_row = center_info.get('lower_row')
    lower_center_offset = center_info.get('lower_center_offset')
    upper_row = center_info.get('upper_row')
    upper_center_offset = center_info.get('upper_center_offset')
    if (lower_row is None or lower_center_offset is None or upper_row is None or
            upper_center_offset is None):
        raise KeyError(f'Unable to find valid calibrated center axis info in {center_info}.')
    center_slope = (upper_center_offset-lower_center_offset)/(upper_row-lower_row)

    # Get thetas (in degrees)
    thetas = np.asarray(nxentry.reduced_data.rotation_angle)

    # Reconstruct tomography data
    # reduced data axes order: stack,theta,row,column
    # reconstructed data order in each stack: row/z,x,y
    # Note: Nexus cannot follow a link if the data it points to is too big,
    # so get the data from the actual place, not from nxentry.data
    # (zoom_perc is stored as an attribute of reduced_data, not as a child)
    if 'zoom_perc' in nxentry.reduced_data.attrs:
        res_title = f'{nxentry.reduced_data.attrs["zoom_perc"]}p'
    else:
        res_title = 'fullres'
    num_tomo_stacks = nxentry.reduced_data.data.tomo_fields.shape[0]
    tomo_recon_stacks = []
    for i in range(num_tomo_stacks):
        # Convert reduced data stack from theta,row,column to row,theta,column
        logger.debug(f'Reading reduced data stack {i+1}...')
        t0 = time()
        tomo_stack = np.asarray(nxentry.reduced_data.data.tomo_fields[i])
        logger.debug(f'... done in {time()-t0:.2f} seconds')
        if len(tomo_stack.shape) != 3 or any(not dim for dim in tomo_stack.shape):
            raise ValueError(f'Unable to load tomography stack {i+1} for reconstruction')
        tomo_stack = np.swapaxes(tomo_stack, 0, 1)
        assert len(thetas) == tomo_stack.shape[1]
        assert 0 <= lower_row < upper_row < tomo_stack.shape[0]
        # Extrapolate the linear center fit to the first and last row
        center_offsets = [lower_center_offset-lower_row*center_slope,
            upper_center_offset+(tomo_stack.shape[0]-1-upper_row)*center_slope]
        t0 = time()
        logger.debug(f'Running _reconstruct_one_tomo_stack on {self.num_core} cores ...')
        tomo_recon_stack = self._reconstruct_one_tomo_stack(tomo_stack, thetas,
            center_offsets=center_offsets, num_core=self.num_core, algorithm='gridrec')
        logger.debug(f'... done in {time()-t0:.2f} seconds')
        logger.info(f'Reconstruction of stack {i+1} took {time()-t0:.2f} seconds')

        # Combine stacks
        tomo_recon_stacks.append(tomo_recon_stack)

    # Resize the reconstructed tomography data
    # reconstructed data order in each stack: row/z,x,y
    if self.test_mode:
        x_bounds = self.test_config.get('x_bounds')
        y_bounds = self.test_config.get('y_bounds')
        z_bounds = None
    elif self.galaxy_flag:
        if x_bounds is not None and not is_int_pair(x_bounds, ge=0,
                lt=tomo_recon_stacks[0].shape[1]):
            raise ValueError(f'Invalid parameter x_bounds ({x_bounds})')
        # y runs along axis 2 of each reconstructed stack (row/z,x,y)
        if y_bounds is not None and not is_int_pair(y_bounds, ge=0,
                lt=tomo_recon_stacks[0].shape[2]):
            raise ValueError(f'Invalid parameter y_bounds ({y_bounds})')
        z_bounds = None
    else:
        x_bounds, y_bounds, z_bounds = self._resize_reconstructed_data(tomo_recon_stacks)
    if x_bounds is None:
        x_range = (0, tomo_recon_stacks[0].shape[1])
        x_slice = int(x_range[1]/2)
    else:
        x_range = (min(x_bounds), max(x_bounds))
        x_slice = int((x_bounds[0]+x_bounds[1])/2)
    if y_bounds is None:
        y_range = (0, tomo_recon_stacks[0].shape[2])
        y_slice = int(y_range[1]/2)
    else:
        y_range = (min(y_bounds), max(y_bounds))
        y_slice = int((y_bounds[0]+y_bounds[1])/2)
    if z_bounds is None:
        z_range = (0, tomo_recon_stacks[0].shape[0])
        z_slice = int(z_range[1]/2)
    else:
        z_range = (min(z_bounds), max(z_bounds))
        z_slice = int((z_bounds[0]+z_bounds[1])/2)

    # Plot a few reconstructed image slices. The title is built per stack
    # (it also names the saved figure), so stacks do not overwrite each other.
    for i, stack in enumerate(tomo_recon_stacks):
        if num_tomo_stacks == 1:
            basetitle = 'recon'
        else:
            basetitle = f'recon stack {i+1}'
        title = f'{basetitle} {res_title} xslice{x_slice}'
        quick_imshow(stack[z_range[0]:z_range[1],x_slice,y_range[0]:y_range[1]],
            title=title, path=path, save_fig=self.save_figs, save_only=self.save_only,
            block=self.block)
        title = f'{basetitle} {res_title} yslice{y_slice}'
        quick_imshow(stack[z_range[0]:z_range[1],x_range[0]:x_range[1],y_slice],
            title=title, path=path, save_fig=self.save_figs, save_only=self.save_only,
            block=self.block)
        title = f'{basetitle} {res_title} zslice{z_slice}'
        quick_imshow(stack[z_slice,x_range[0]:x_range[1],y_range[0]:y_range[1]],
            title=title, path=path, save_fig=self.save_figs, save_only=self.save_only,
            block=self.block)

    # Save test data to file
    # reconstructed data order in each stack: row/z,x,y
    if self.test_mode:
        for i, stack in enumerate(tomo_recon_stacks):
            np.savetxt(f'{self.output_folder}/recon_stack_{i+1}.txt',
                stack[z_slice,x_range[0]:x_range[1],y_range[0]:y_range[1]], fmt='%.6e')

    # Add image reconstruction to reconstructed data NXprocess
    # reconstructed data order in each stack: row/z,x,y
    nxprocess.data = NXdata()
    nxprocess.attrs['default'] = 'data'
    for k, v in center_info.items():
        nxprocess[k] = v
    if x_bounds is not None:
        nxprocess.x_bounds = x_bounds
    if y_bounds is not None:
        nxprocess.y_bounds = y_bounds
    if z_bounds is not None:
        nxprocess.z_bounds = z_bounds
    nxprocess.data['reconstructed_data'] = np.asarray([stack[z_range[0]:z_range[1],
        x_range[0]:x_range[1],y_range[0]:y_range[1]] for stack in tomo_recon_stacks])
    nxprocess.data.attrs['signal'] = 'reconstructed_data'

    # Create a copy of the input Nexus object and remove reduced data
    exclude_items = [f'{nxentry._name}/reduced_data/data', f'{nxentry._name}/data/reduced_data']
    nxroot_copy = nxcopy(nxroot, exclude_nxpaths=exclude_items)

    # Add the reconstructed data NXprocess to the new Nexus object
    nxentry_copy = nxroot_copy[nxroot_copy.attrs['default']]
    nxentry_copy.reconstructed_data = nxprocess
    if 'data' not in nxentry_copy:
        nxentry_copy.data = NXdata()
    nxentry_copy.attrs['default'] = 'data'
    nxentry_copy.data.makelink(nxprocess.data.reconstructed_data, name='reconstructed_data')
    nxentry_copy.data.attrs['signal'] = 'reconstructed_data'

    return nxroot_copy
| 604 | |
| 605 def combine_data(self, nxroot, x_bounds=None, y_bounds=None): | |
| 606 """Combine the reconstructed tomography stacks. | |
| 607 """ | |
| 608 logger.info('Combine the reconstructed tomography stacks') | |
| 609 | |
| 610 if not isinstance(nxroot, NXroot): | |
| 611 raise ValueError(f'Invalid parameter nxroot ({nxroot})') | |
| 612 nxentry = nxroot[nxroot.attrs['default']] | |
| 613 if not isinstance(nxentry, NXentry): | |
| 614 raise ValueError(f'Invalid nxentry ({nxentry})') | |
| 615 | |
| 616 # Create plot galaxy path directory and path if needed | |
| 617 if self.galaxy_flag: | |
| 618 if not os_path.exists('tomo_combine_plots'): | |
| 619 mkdir('tomo_combine_plots') | |
| 620 path = 'tomo_combine_plots' | |
| 621 else: | |
| 622 path = self.output_folder | |
| 623 | |
| 624 # Check if reconstructed image data is available | |
| 625 if ('reconstructed_data' not in nxentry or 'reconstructed_data' not in nxentry.data): | |
| 626 raise KeyError(f'Unable to find valid reconstructed image data in {nxentry}.') | |
| 627 | |
| 628 # Create an NXprocess to store combined image reconstruction (meta)data | |
| 629 nxprocess = NXprocess() | |
| 630 | |
| 631 # Get the reconstructed data | |
| 632 # reconstructed data order: stack,row(z),x,y | |
| 633 # Note: Nexus cannot follow a link if the data it points to is too big, | |
| 634 # so get the data from the actual place, not from nxentry.data | |
| 635 num_tomo_stacks = nxentry.reconstructed_data.data.reconstructed_data.shape[0] | |
| 636 if num_tomo_stacks == 1: | |
| 637 logger.info('Only one stack available: leaving combine_data') | |
| 638 return(None) | |
| 639 | |
| 640 # Combine the reconstructed stacks | |
| 641 # (load one stack at a time to reduce risk of hitting Nexus data access limit) | |
| 642 t0 = time() | |
| 643 logger.debug(f'Combining the reconstructed stacks ...') | |
| 644 tomo_recon_combined = np.asarray(nxentry.reconstructed_data.data.reconstructed_data[0]) | |
| 645 if num_tomo_stacks > 2: | |
| 646 tomo_recon_combined = np.concatenate([tomo_recon_combined]+ | |
| 647 [nxentry.reconstructed_data.data.reconstructed_data[i] | |
| 648 for i in range(1, num_tomo_stacks-1)]) | |
| 649 if num_tomo_stacks > 1: | |
| 650 tomo_recon_combined = np.concatenate([tomo_recon_combined]+ | |
| 651 [nxentry.reconstructed_data.data.reconstructed_data[num_tomo_stacks-1]]) | |
| 652 logger.debug(f'... done in {time()-t0:.2f} seconds') | |
| 653 logger.info(f'Combining the reconstructed stacks took {time()-t0:.2f} seconds') | |
| 654 | |
| 655 # Resize the combined tomography data stacks | |
| 656 # combined data order: row/z,x,y | |
| 657 if self.test_mode: | |
| 658 x_bounds = None | |
| 659 y_bounds = None | |
| 660 z_bounds = self.test_config.get('z_bounds') | |
| 661 elif self.galaxy_flag: | |
| 662 if x_bounds is not None and not is_int_pair(x_bounds, ge=0, | |
| 663 lt=tomo_recon_stacks[0].shape[1]): | |
| 664 raise ValueError(f'Invalid parameter x_bounds ({x_bounds})') | |
| 665 if y_bounds is not None and not is_int_pair(y_bounds, ge=0, | |
| 666 lt=tomo_recon_stacks[0].shape[1]): | |
| 667 raise ValueError(f'Invalid parameter y_bounds ({y_bounds})') | |
| 668 z_bounds = None | |
| 669 else: | |
| 670 x_bounds, y_bounds, z_bounds = self._resize_reconstructed_data(tomo_recon_combined, | |
| 671 z_only=True) | |
| 672 if x_bounds is None: | |
| 673 x_range = (0, tomo_recon_combined.shape[1]) | |
| 674 x_slice = int(x_range[1]/2) | |
| 675 else: | |
| 676 x_range = x_bounds | |
| 677 x_slice = int((x_bounds[0]+x_bounds[1])/2) | |
| 678 if y_bounds is None: | |
| 679 y_range = (0, tomo_recon_combined.shape[2]) | |
| 680 y_slice = int(y_range[1]/2) | |
| 681 else: | |
| 682 y_range = y_bounds | |
| 683 y_slice = int((y_bounds[0]+y_bounds[1])/2) | |
| 684 if z_bounds is None: | |
| 685 z_range = (0, tomo_recon_combined.shape[0]) | |
| 686 z_slice = int(z_range[1]/2) | |
| 687 else: | |
| 688 z_range = z_bounds | |
| 689 z_slice = int((z_bounds[0]+z_bounds[1])/2) | |
| 690 | |
| 691 # Plot a few combined image slices | |
| 692 quick_imshow(tomo_recon_combined[z_range[0]:z_range[1],x_slice,y_range[0]:y_range[1]], | |
| 693 title=f'recon combined xslice{x_slice}', path=path, | |
| 694 save_fig=self.save_figs, save_only=self.save_only, block=self.block) | |
| 695 quick_imshow(tomo_recon_combined[z_range[0]:z_range[1],x_range[0]:x_range[1],y_slice], | |
| 696 title=f'recon combined yslice{y_slice}', path=path, | |
| 697 save_fig=self.save_figs, save_only=self.save_only, block=self.block) | |
| 698 quick_imshow(tomo_recon_combined[z_slice,x_range[0]:x_range[1],y_range[0]:y_range[1]], | |
| 699 title=f'recon combined zslice{z_slice}', path=path, | |
| 700 save_fig=self.save_figs, save_only=self.save_only, block=self.block) | |
| 701 | |
| 702 # Save test data to file | |
| 703 # combined data order: row/z,x,y | |
| 704 if self.test_mode: | |
| 705 np.savetxt(f'{self.output_folder}/recon_combined.txt', tomo_recon_combined[ | |
| 706 z_slice,x_range[0]:x_range[1],y_range[0]:y_range[1]], fmt='%.6e') | |
| 707 | |
| 708 # Add image reconstruction to reconstructed data NXprocess | |
| 709 # combined data order: row/z,x,y | |
| 710 nxprocess.data = NXdata() | |
| 711 nxprocess.attrs['default'] = 'data' | |
| 712 if x_bounds is not None: | |
| 713 nxprocess.x_bounds = x_bounds | |
| 714 if y_bounds is not None: | |
| 715 nxprocess.y_bounds = y_bounds | |
| 716 if z_bounds is not None: | |
| 717 nxprocess.z_bounds = z_bounds | |
| 718 nxprocess.data['combined_data'] = tomo_recon_combined | |
| 719 nxprocess.data.attrs['signal'] = 'combined_data' | |
| 720 | |
| 721 # Create a copy of the input Nexus object and remove reconstructed data | |
| 722 exclude_items = [f'{nxentry._name}/reconstructed_data/data', | |
| 723 f'{nxentry._name}/data/reconstructed_data'] | |
| 724 nxroot_copy = nxcopy(nxroot, exclude_nxpaths=exclude_items) | |
| 725 | |
| 726 # Add the combined data NXprocess to the new Nexus object | |
| 727 nxentry_copy = nxroot_copy[nxroot_copy.attrs['default']] | |
| 728 nxentry_copy.combined_data = nxprocess | |
| 729 if 'data' not in nxentry_copy: | |
| 730 nxentry_copy.data = NXdata() | |
| 731 nxentry_copy.attrs['default'] = 'data' | |
| 732 nxentry_copy.data.makelink(nxprocess.data.combined_data, name='combined_data') | |
| 733 nxentry_copy.data.attrs['signal'] = 'combined_data' | |
| 734 | |
| 735 return(nxroot_copy) | |
| 736 | |
| 737 def _gen_dark(self, nxentry, reduced_data): | |
| 738 """Generate dark field. | |
| 739 """ | |
| 740 # Get the dark field images | |
| 741 image_key = nxentry.instrument.detector.get('image_key', None) | |
| 742 if image_key and 'data' in nxentry.instrument.detector: | |
| 743 field_indices = [index for index, key in enumerate(image_key) if key == 2] | |
| 744 tdf_stack = nxentry.instrument.detector.data[field_indices,:,:] | |
| 745 # RV the default NXtomo form does not accomodate bright or dark field stacks | |
| 746 else: | |
| 747 dark_field_scans = nxentry.spec_scans.dark_field | |
| 748 dark_field = FlatField.construct_from_nxcollection(dark_field_scans) | |
| 749 prefix = str(nxentry.instrument.detector.local_name) | |
| 750 tdf_stack = dark_field.get_detector_data(prefix) | |
| 751 if isinstance(tdf_stack, list): | |
| 752 assert(len(tdf_stack) == 1) # TODO | |
| 753 tdf_stack = tdf_stack[0] | |
| 754 | |
| 755 # Take median | |
| 756 if tdf_stack.ndim == 2: | |
| 757 tdf = tdf_stack | |
| 758 elif tdf_stack.ndim == 3: | |
| 759 tdf = np.median(tdf_stack, axis=0) | |
| 760 del tdf_stack | |
| 761 else: | |
| 762 raise ValueError(f'Invalid tdf_stack shape ({tdf_stack.shape})') | |
| 763 | |
| 764 # Remove dark field intensities above the cutoff | |
| 765 #RV tdf_cutoff = None | |
| 766 tdf_cutoff = tdf.min()+2*(np.median(tdf)-tdf.min()) | |
| 767 logger.debug(f'tdf_cutoff = {tdf_cutoff}') | |
| 768 if tdf_cutoff is not None: | |
| 769 if not is_num(tdf_cutoff, ge=0): | |
| 770 logger.warning(f'Ignoring illegal value of tdf_cutoff {tdf_cutoff}') | |
| 771 else: | |
| 772 tdf[tdf > tdf_cutoff] = np.nan | |
| 773 logger.debug(f'tdf_cutoff = {tdf_cutoff}') | |
| 774 | |
| 775 # Remove nans | |
| 776 tdf_mean = np.nanmean(tdf) | |
| 777 logger.debug(f'tdf_mean = {tdf_mean}') | |
| 778 np.nan_to_num(tdf, copy=False, nan=tdf_mean, posinf=tdf_mean, neginf=0.) | |
| 779 | |
| 780 # Plot dark field | |
| 781 if self.galaxy_flag: | |
| 782 quick_imshow(tdf, title='dark field', path='tomo_reduce_plots', save_fig=self.save_figs, | |
| 783 save_only=self.save_only) | |
| 784 elif not self.test_mode: | |
| 785 quick_imshow(tdf, title='dark field', path=self.output_folder, save_fig=self.save_figs, | |
| 786 save_only=self.save_only) | |
| 787 clear_imshow('dark field') | |
| 788 # quick_imshow(tdf, title='dark field', block=True) | |
| 789 | |
| 790 # Add dark field to reduced data NXprocess | |
| 791 reduced_data.data = NXdata() | |
| 792 reduced_data.data['dark_field'] = tdf | |
| 793 | |
| 794 return(reduced_data) | |
| 795 | |
| 796 def _gen_bright(self, nxentry, reduced_data): | |
| 797 """Generate bright field. | |
| 798 """ | |
| 799 # Get the bright field images | |
| 800 image_key = nxentry.instrument.detector.get('image_key', None) | |
| 801 if image_key and 'data' in nxentry.instrument.detector: | |
| 802 field_indices = [index for index, key in enumerate(image_key) if key == 1] | |
| 803 tbf_stack = nxentry.instrument.detector.data[field_indices,:,:] | |
| 804 # RV the default NXtomo form does not accomodate bright or dark field stacks | |
| 805 else: | |
| 806 bright_field_scans = nxentry.spec_scans.bright_field | |
| 807 bright_field = FlatField.construct_from_nxcollection(bright_field_scans) | |
| 808 prefix = str(nxentry.instrument.detector.local_name) | |
| 809 tbf_stack = bright_field.get_detector_data(prefix) | |
| 810 if isinstance(tbf_stack, list): | |
| 811 assert(len(tbf_stack) == 1) # TODO | |
| 812 tbf_stack = tbf_stack[0] | |
| 813 | |
| 814 # Take median if more than one image | |
| 815 """Median or mean: It may be best to try the median because of some image | |
| 816 artifacts that arise due to crinkles in the upstream kapton tape windows | |
| 817 causing some phase contrast images to appear on the detector. | |
| 818 One thing that also may be useful in a future implementation is to do a | |
| 819 brightfield adjustment on EACH frame of the tomo based on a ROI in the | |
| 820 corner of the frame where there is no sample but there is the direct X-ray | |
| 821 beam because there is frame to frame fluctuations from the incoming beam. | |
| 822 We don’t typically account for them but potentially could. | |
| 823 """ | |
| 824 if tbf_stack.ndim == 2: | |
| 825 tbf = tbf_stack | |
| 826 elif tbf_stack.ndim == 3: | |
| 827 tbf = np.median(tbf_stack, axis=0) | |
| 828 del tbf_stack | |
| 829 else: | |
| 830 raise ValueError(f'Invalid tbf_stack shape ({tbf_stacks.shape})') | |
| 831 | |
| 832 # Subtract dark field | |
| 833 if 'data' in reduced_data and 'dark_field' in reduced_data.data: | |
| 834 tbf -= reduced_data.data.dark_field | |
| 835 else: | |
| 836 logger.warning('Dark field unavailable') | |
| 837 | |
| 838 # Set any non-positive values to one | |
| 839 # (avoid negative bright field values for spikes in dark field) | |
| 840 tbf[tbf < 1] = 1 | |
| 841 | |
| 842 # Plot bright field | |
| 843 if self.galaxy_flag: | |
| 844 quick_imshow(tbf, title='bright field', path='tomo_reduce_plots', | |
| 845 save_fig=self.save_figs, save_only=self.save_only) | |
| 846 elif not self.test_mode: | |
| 847 quick_imshow(tbf, title='bright field', path=self.output_folder, | |
| 848 save_fig=self.save_figs, save_only=self.save_only) | |
| 849 clear_imshow('bright field') | |
| 850 # quick_imshow(tbf, title='bright field', block=True) | |
| 851 | |
| 852 # Add bright field to reduced data NXprocess | |
| 853 if 'data' not in reduced_data: | |
| 854 reduced_data.data = NXdata() | |
| 855 reduced_data.data['bright_field'] = tbf | |
| 856 | |
| 857 return(reduced_data) | |
| 858 | |
| 859 def _set_detector_bounds(self, nxentry, reduced_data, img_x_bounds=None): | |
| 860 """Set vertical detector bounds for each image stack. | |
| 861 Right now the range is the same for each set in the image stack. | |
| 862 """ | |
| 863 if self.test_mode: | |
| 864 return(tuple(self.test_config['img_x_bounds'])) | |
| 865 | |
| 866 # Get the first tomography image and the reference heights | |
| 867 image_key = nxentry.instrument.detector.get('image_key', None) | |
| 868 if image_key and 'data' in nxentry.instrument.detector: | |
| 869 field_indices = [index for index, key in enumerate(image_key) if key == 0] | |
| 870 first_image = np.asarray(nxentry.instrument.detector.data[field_indices[0],:,:]) | |
| 871 theta = float(nxentry.sample.rotation_angle[field_indices[0]]) | |
| 872 z_translation_all = nxentry.sample.z_translation[field_indices] | |
| 873 vertical_shifts = sorted(list(set(z_translation_all))) | |
| 874 num_tomo_stacks = len(vertical_shifts) | |
| 875 else: | |
| 876 tomo_field_scans = nxentry.spec_scans.tomo_fields | |
| 877 tomo_fields = TomoField.construct_from_nxcollection(tomo_field_scans) | |
| 878 vertical_shifts = tomo_fields.get_vertical_shifts() | |
| 879 if not isinstance(vertical_shifts, list): | |
| 880 vertical_shifts = [vertical_shifts] | |
| 881 prefix = str(nxentry.instrument.detector.local_name) | |
| 882 t0 = time() | |
| 883 first_image = tomo_fields.get_detector_data(prefix, tomo_fields.scan_numbers[0], 0) | |
| 884 logger.debug(f'Getting first image took {time()-t0:.2f} seconds') | |
| 885 num_tomo_stacks = len(tomo_fields.scan_numbers) | |
| 886 theta = tomo_fields.theta_range['start'] | |
| 887 | |
| 888 # Select image bounds | |
| 889 title = f'tomography image at theta={round(theta, 2)+0}' | |
| 890 if img_x_bounds is not None: | |
| 891 if is_int_pair(img_x_bounds) and img_x_bounds[0] == -1 and img_x_bounds[1] == -1: | |
| 892 img_x_bounds = None | |
| 893 elif not is_index_range(img_x_bounds, ge=0, le=first_image.shape[0]): | |
| 894 raise ValueError(f'Invalid parameter img_x_bounds ({img_x_bounds})') | |
| 895 if nxentry.instrument.source.attrs['station'] in ('id1a3', 'id3a'): | |
| 896 pixel_size = nxentry.instrument.detector.x_pixel_size | |
| 897 # Try to get a fit from the bright field | |
| 898 tbf = np.asarray(reduced_data.data.bright_field) | |
| 899 tbf_shape = tbf.shape | |
| 900 x_sum = np.sum(tbf, 1) | |
| 901 x_sum_min = x_sum.min() | |
| 902 x_sum_max = x_sum.max() | |
| 903 fit = Fit.fit_data(x_sum, 'rectangle', x=np.array(range(len(x_sum))), form='atan', | |
| 904 guess=True) | |
| 905 parameters = fit.best_values | |
| 906 x_low_fit = parameters.get('center1', None) | |
| 907 x_upp_fit = parameters.get('center2', None) | |
| 908 sig_low = parameters.get('sigma1', None) | |
| 909 sig_upp = parameters.get('sigma2', None) | |
| 910 have_fit = fit.success and x_low_fit is not None and x_upp_fit is not None and \ | |
| 911 sig_low is not None and sig_upp is not None and \ | |
| 912 0 <= x_low_fit < x_upp_fit <= x_sum.size and \ | |
| 913 (sig_low+sig_upp)/(x_upp_fit-x_low_fit) < 0.1 | |
| 914 if have_fit: | |
| 915 # Set a 5% margin on each side | |
| 916 margin = 0.05*(x_upp_fit-x_low_fit) | |
| 917 x_low_fit = max(0, x_low_fit-margin) | |
| 918 x_upp_fit = min(tbf_shape[0], x_upp_fit+margin) | |
| 919 if num_tomo_stacks == 1: | |
| 920 if have_fit: | |
| 921 # Set the default range to enclose the full fitted window | |
| 922 x_low = int(x_low_fit) | |
| 923 x_upp = int(x_upp_fit) | |
| 924 else: | |
| 925 # Center a default range of 1 mm (RV: can we get this from the slits?) | |
| 926 num_x_min = int((1.0-0.5*pixel_size)/pixel_size) | |
| 927 x_low = int(0.5*(tbf_shape[0]-num_x_min)) | |
| 928 x_upp = x_low+num_x_min | |
| 929 else: | |
| 930 # Get the default range from the reference heights | |
| 931 delta_z = vertical_shifts[1]-vertical_shifts[0] | |
| 932 for i in range(2, num_tomo_stacks): | |
| 933 delta_z = min(delta_z, vertical_shifts[i]-vertical_shifts[i-1]) | |
| 934 logger.debug(f'delta_z = {delta_z}') | |
| 935 num_x_min = int((delta_z-0.5*pixel_size)/pixel_size) | |
| 936 logger.debug(f'num_x_min = {num_x_min}') | |
| 937 if num_x_min > tbf_shape[0]: | |
| 938 logger.warning('Image bounds and pixel size prevent seamless stacking') | |
| 939 if have_fit: | |
| 940 # Center the default range relative to the fitted window | |
| 941 x_low = int(0.5*(x_low_fit+x_upp_fit-num_x_min)) | |
| 942 x_upp = x_low+num_x_min | |
| 943 else: | |
| 944 # Center the default range | |
| 945 x_low = int(0.5*(tbf_shape[0]-num_x_min)) | |
| 946 x_upp = x_low+num_x_min | |
| 947 if self.galaxy_flag: | |
| 948 img_x_bounds = (x_low, x_upp) | |
| 949 else: | |
| 950 tmp = np.copy(tbf) | |
| 951 tmp_max = tmp.max() | |
| 952 tmp[x_low,:] = tmp_max | |
| 953 tmp[x_upp-1,:] = tmp_max | |
| 954 quick_imshow(tmp, title='bright field') | |
| 955 tmp = np.copy(first_image) | |
| 956 tmp_max = tmp.max() | |
| 957 tmp[x_low,:] = tmp_max | |
| 958 tmp[x_upp-1,:] = tmp_max | |
| 959 quick_imshow(tmp, title=title) | |
| 960 del tmp | |
| 961 quick_plot((range(x_sum.size), x_sum), | |
| 962 ([x_low, x_low], [x_sum_min, x_sum_max], 'r-'), | |
| 963 ([x_upp, x_upp], [x_sum_min, x_sum_max], 'r-'), | |
| 964 title='sum over theta and y') | |
| 965 print(f'lower bound = {x_low} (inclusive)') | |
| 966 print(f'upper bound = {x_upp} (exclusive)]') | |
| 967 accept = input_yesno('Accept these bounds (y/n)?', 'y') | |
| 968 clear_imshow('bright field') | |
| 969 clear_imshow(title) | |
| 970 clear_plot('sum over theta and y') | |
| 971 if accept: | |
| 972 img_x_bounds = (x_low, x_upp) | |
| 973 else: | |
| 974 while True: | |
| 975 mask, img_x_bounds = draw_mask_1d(x_sum, title='select x data range', | |
| 976 legend='sum over theta and y') | |
| 977 if len(img_x_bounds) == 1: | |
| 978 break | |
| 979 else: | |
| 980 print(f'Choose a single connected data range') | |
| 981 img_x_bounds = tuple(img_x_bounds[0]) | |
| 982 if (num_tomo_stacks > 1 and img_x_bounds[1]-img_x_bounds[0]+1 < | |
| 983 int((delta_z-0.5*pixel_size)/pixel_size)): | |
| 984 logger.warning('Image bounds and pixel size prevent seamless stacking') | |
| 985 else: | |
| 986 if num_tomo_stacks > 1: | |
| 987 raise NotImplementedError('Selecting image bounds for multiple stacks on FMB') | |
| 988 # For FMB: use the first tomography image to select range | |
| 989 # RV: revisit if they do tomography with multiple stacks | |
| 990 x_sum = np.sum(first_image, 1) | |
| 991 x_sum_min = x_sum.min() | |
| 992 x_sum_max = x_sum.max() | |
| 993 if self.galaxy_flag: | |
| 994 if img_x_bounds is None: | |
| 995 img_x_bounds = (0, first_image.shape[0]) | |
| 996 else: | |
| 997 quick_imshow(first_image, title=title) | |
| 998 print('Select vertical data reduction range from first tomography image') | |
| 999 img_x_bounds = select_image_bounds(first_image, 0, title=title) | |
| 1000 clear_imshow(title) | |
| 1001 if img_x_bounds is None: | |
| 1002 raise ValueError('Unable to select image bounds') | |
| 1003 | |
| 1004 # Plot results | |
| 1005 if self.galaxy_flag: | |
| 1006 path = 'tomo_reduce_plots' | |
| 1007 else: | |
| 1008 path = self.output_folder | |
| 1009 x_low = img_x_bounds[0] | |
| 1010 x_upp = img_x_bounds[1] | |
| 1011 tmp = np.copy(first_image) | |
| 1012 tmp_max = tmp.max() | |
| 1013 tmp[x_low,:] = tmp_max | |
| 1014 tmp[x_upp-1,:] = tmp_max | |
| 1015 quick_imshow(tmp, title=title, path=path, save_fig=self.save_figs, save_only=self.save_only, | |
| 1016 block=self.block) | |
| 1017 del tmp | |
| 1018 quick_plot((range(x_sum.size), x_sum), | |
| 1019 ([x_low, x_low], [x_sum_min, x_sum_max], 'r-'), | |
| 1020 ([x_upp, x_upp], [x_sum_min, x_sum_max], 'r-'), | |
| 1021 title='sum over theta and y', path=path, save_fig=self.save_figs, | |
| 1022 save_only=self.save_only, block=self.block) | |
| 1023 | |
| 1024 return(img_x_bounds) | |
| 1025 | |
| 1026 def _set_zoom_or_skip(self): | |
| 1027 """Set zoom and/or theta skip to reduce memory the requirement for the analysis. | |
| 1028 """ | |
| 1029 # if input_yesno('\nDo you want to zoom in to reduce memory requirement (y/n)?', 'n'): | |
| 1030 # zoom_perc = input_int(' Enter zoom percentage', ge=1, le=100) | |
| 1031 # else: | |
| 1032 # zoom_perc = None | |
| 1033 zoom_perc = None | |
| 1034 # if input_yesno('Do you want to skip thetas to reduce memory requirement (y/n)?', 'n'): | |
| 1035 # num_theta_skip = input_int(' Enter the number skip theta interval', ge=0, | |
| 1036 # lt=num_theta) | |
| 1037 # else: | |
| 1038 # num_theta_skip = None | |
| 1039 num_theta_skip = None | |
| 1040 logger.debug(f'zoom_perc = {zoom_perc}') | |
| 1041 logger.debug(f'num_theta_skip = {num_theta_skip}') | |
| 1042 | |
| 1043 return(zoom_perc, num_theta_skip) | |
| 1044 | |
| 1045 def _gen_tomo(self, nxentry, reduced_data): | |
| 1046 """Generate tomography fields. | |
| 1047 """ | |
| 1048 # Get full bright field | |
| 1049 tbf = np.asarray(reduced_data.data.bright_field) | |
| 1050 tbf_shape = tbf.shape | |
| 1051 | |
| 1052 # Get image bounds | |
| 1053 img_x_bounds = tuple(reduced_data.get('img_x_bounds', (0, tbf_shape[0]))) | |
| 1054 img_y_bounds = tuple(reduced_data.get('img_y_bounds', (0, tbf_shape[1]))) | |
| 1055 | |
| 1056 # Get resized dark field | |
| 1057 # if 'dark_field' in data: | |
| 1058 # tbf = np.asarray(reduced_data.data.dark_field[ | |
| 1059 # img_x_bounds[0]:img_x_bounds[1],img_y_bounds[0]:img_y_bounds[1]]) | |
| 1060 # else: | |
| 1061 # logger.warning('Dark field unavailable') | |
| 1062 # tdf = None | |
| 1063 tdf = None | |
| 1064 | |
| 1065 # Resize bright field | |
| 1066 if img_x_bounds != (0, tbf.shape[0]) or img_y_bounds != (0, tbf.shape[1]): | |
| 1067 tbf = tbf[img_x_bounds[0]:img_x_bounds[1],img_y_bounds[0]:img_y_bounds[1]] | |
| 1068 | |
| 1069 # Get the tomography images | |
| 1070 image_key = nxentry.instrument.detector.get('image_key', None) | |
| 1071 if image_key and 'data' in nxentry.instrument.detector: | |
| 1072 field_indices_all = [index for index, key in enumerate(image_key) if key == 0] | |
| 1073 z_translation_all = nxentry.sample.z_translation[field_indices_all] | |
| 1074 z_translation_levels = sorted(list(set(z_translation_all))) | |
| 1075 num_tomo_stacks = len(z_translation_levels) | |
| 1076 tomo_stacks = num_tomo_stacks*[np.array([])] | |
| 1077 horizontal_shifts = [] | |
| 1078 vertical_shifts = [] | |
| 1079 thetas = None | |
| 1080 tomo_stacks = [] | |
| 1081 for i, z_translation in enumerate(z_translation_levels): | |
| 1082 field_indices = [field_indices_all[index] | |
| 1083 for index, z in enumerate(z_translation_all) if z == z_translation] | |
| 1084 horizontal_shift = list(set(nxentry.sample.x_translation[field_indices])) | |
| 1085 assert(len(horizontal_shift) == 1) | |
| 1086 horizontal_shifts += horizontal_shift | |
| 1087 vertical_shift = list(set(nxentry.sample.z_translation[field_indices])) | |
| 1088 assert(len(vertical_shift) == 1) | |
| 1089 vertical_shifts += vertical_shift | |
| 1090 sequence_numbers = nxentry.instrument.detector.sequence_number[field_indices] | |
| 1091 if thetas is None: | |
| 1092 thetas = np.asarray(nxentry.sample.rotation_angle[field_indices]) \ | |
| 1093 [sequence_numbers] | |
| 1094 else: | |
| 1095 assert(all(thetas[i] == nxentry.sample.rotation_angle[field_indices[index]] | |
| 1096 for i, index in enumerate(sequence_numbers))) | |
| 1097 assert(list(set(sequence_numbers)) == [i for i in range(len(sequence_numbers))]) | |
| 1098 if list(sequence_numbers) == [i for i in range(len(sequence_numbers))]: | |
| 1099 tomo_stack = np.asarray(nxentry.instrument.detector.data[field_indices]) | |
| 1100 else: | |
| 1101 raise ValueError('Unable to load the tomography images') | |
| 1102 tomo_stacks.append(tomo_stack) | |
| 1103 else: | |
| 1104 tomo_field_scans = nxentry.spec_scans.tomo_fields | |
| 1105 tomo_fields = TomoField.construct_from_nxcollection(tomo_field_scans) | |
| 1106 horizontal_shifts = tomo_fields.get_horizontal_shifts() | |
| 1107 vertical_shifts = tomo_fields.get_vertical_shifts() | |
| 1108 prefix = str(nxentry.instrument.detector.local_name) | |
| 1109 t0 = time() | |
| 1110 tomo_stacks = tomo_fields.get_detector_data(prefix) | |
| 1111 logger.debug(f'Getting tomography images took {time()-t0:.2f} seconds') | |
| 1112 logger.debug(f'Getting all images took {time()-t0:.2f} seconds') | |
| 1113 thetas = np.linspace(tomo_fields.theta_range['start'], tomo_fields.theta_range['end'], | |
| 1114 tomo_fields.theta_range['num']) | |
| 1115 if not isinstance(tomo_stacks, list): | |
| 1116 horizontal_shifts = [horizontal_shifts] | |
| 1117 vertical_shifts = [vertical_shifts] | |
| 1118 tomo_stacks = [tomo_stacks] | |
| 1119 | |
| 1120 reduced_tomo_stacks = [] | |
| 1121 if self.galaxy_flag: | |
| 1122 path = 'tomo_reduce_plots' | |
| 1123 else: | |
| 1124 path = self.output_folder | |
| 1125 for i, tomo_stack in enumerate(tomo_stacks): | |
| 1126 # Resize the tomography images | |
| 1127 # Right now the range is the same for each set in the image stack. | |
| 1128 if img_x_bounds != (0, tbf.shape[0]) or img_y_bounds != (0, tbf.shape[1]): | |
| 1129 t0 = time() | |
| 1130 tomo_stack = tomo_stack[:,img_x_bounds[0]:img_x_bounds[1], | |
| 1131 img_y_bounds[0]:img_y_bounds[1]].astype('float64') | |
| 1132 logger.debug(f'Resizing tomography images took {time()-t0:.2f} seconds') | |
| 1133 | |
| 1134 # Subtract dark field | |
| 1135 if tdf is not None: | |
| 1136 t0 = time() | |
| 1137 with set_numexpr_threads(self.num_core): | |
| 1138 ne.evaluate('tomo_stack-tdf', out=tomo_stack) | |
| 1139 logger.debug(f'Subtracting dark field took {time()-t0:.2f} seconds') | |
| 1140 | |
| 1141 # Normalize | |
| 1142 t0 = time() | |
| 1143 with set_numexpr_threads(self.num_core): | |
| 1144 ne.evaluate('tomo_stack/tbf', out=tomo_stack, truediv=True) | |
| 1145 logger.debug(f'Normalizing took {time()-t0:.2f} seconds') | |
| 1146 | |
| 1147 # Remove non-positive values and linearize data | |
| 1148 t0 = time() | |
| 1149 cutoff = 1.e-6 | |
| 1150 with set_numexpr_threads(self.num_core): | |
| 1151 ne.evaluate('where(tomo_stack<cutoff, cutoff, tomo_stack)', out=tomo_stack) | |
| 1152 with set_numexpr_threads(self.num_core): | |
| 1153 ne.evaluate('-log(tomo_stack)', out=tomo_stack) | |
| 1154 logger.debug('Removing non-positive values and linearizing data took '+ | |
| 1155 f'{time()-t0:.2f} seconds') | |
| 1156 | |
| 1157 # Get rid of nans/infs that may be introduced by normalization | |
| 1158 t0 = time() | |
| 1159 np.where(np.isfinite(tomo_stack), tomo_stack, 0.) | |
| 1160 logger.debug(f'Remove nans/infs took {time()-t0:.2f} seconds') | |
| 1161 | |
| 1162 # Downsize tomography stack to smaller size | |
| 1163 # TODO use theta_skip as well | |
| 1164 tomo_stack = tomo_stack.astype('float32') | |
| 1165 if not self.test_mode: | |
| 1166 if len(tomo_stacks) == 1: | |
| 1167 title = f'red fullres theta {round(thetas[0], 2)+0}' | |
| 1168 else: | |
| 1169 title = f'red stack {i+1} fullres theta {round(thetas[0], 2)+0}' | |
| 1170 quick_imshow(tomo_stack[0,:,:], title=title, path=path, save_fig=self.save_figs, | |
| 1171 save_only=self.save_only, block=self.block) | |
| 1172 # if not self.block: | |
| 1173 # clear_imshow(title) | |
| 1174 if False and zoom_perc != 100: | |
| 1175 t0 = time() | |
| 1176 logger.debug(f'Zooming in ...') | |
| 1177 tomo_zoom_list = [] | |
| 1178 for j in range(tomo_stack.shape[0]): | |
| 1179 tomo_zoom = spi.zoom(tomo_stack[j,:,:], 0.01*zoom_perc) | |
| 1180 tomo_zoom_list.append(tomo_zoom) | |
| 1181 tomo_stack = np.stack([tomo_zoom for tomo_zoom in tomo_zoom_list]) | |
| 1182 logger.debug(f'... done in {time()-t0:.2f} seconds') | |
| 1183 logger.info(f'Zooming in took {time()-t0:.2f} seconds') | |
| 1184 del tomo_zoom_list | |
| 1185 if not self.test_mode: | |
| 1186 title = f'red stack {zoom_perc}p theta {round(thetas[0], 2)+0}' | |
| 1187 quick_imshow(tomo_stack[0,:,:], title=title, path=path, save_fig=self.save_figs, | |
| 1188 save_only=self.save_only, block=self.block) | |
| 1189 # if not self.block: | |
| 1190 # clear_imshow(title) | |
| 1191 | |
| 1192 # Save test data to file | |
| 1193 if self.test_mode: | |
| 1194 # row_index = int(tomo_stack.shape[0]/2) | |
| 1195 # np.savetxt(f'{self.output_folder}/red_stack_{i+1}.txt', tomo_stack[row_index,:,:], | |
| 1196 # fmt='%.6e') | |
| 1197 row_index = int(tomo_stack.shape[1]/2) | |
| 1198 np.savetxt(f'{self.output_folder}/red_stack_{i+1}.txt', tomo_stack[:,row_index,:], | |
| 1199 fmt='%.6e') | |
| 1200 | |
| 1201 # Combine resized stacks | |
| 1202 reduced_tomo_stacks.append(tomo_stack) | |
| 1203 | |
| 1204 # Add tomo field info to reduced data NXprocess | |
| 1205 reduced_data['rotation_angle'] = thetas | |
| 1206 reduced_data['x_translation'] = np.asarray(horizontal_shifts) | |
| 1207 reduced_data['z_translation'] = np.asarray(vertical_shifts) | |
| 1208 reduced_data.data['tomo_fields'] = np.asarray(reduced_tomo_stacks) | |
| 1209 | |
| 1210 if tdf is not None: | |
| 1211 del tdf | |
| 1212 del tbf | |
| 1213 | |
| 1214 return(reduced_data) | |
| 1215 | |
| 1216 def _find_center_one_plane(self, sinogram, row, thetas, eff_pixel_size, cross_sectional_dim, | |
| 1217 path=None, tol=0.1, num_core=1): | |
| 1218 """Find center for a single tomography plane. | |
| 1219 """ | |
| 1220 # Try automatic center finding routines for initial value | |
| 1221 # sinogram index order: theta,column | |
| 1222 # need column,theta for iradon, so take transpose | |
| 1223 sinogram = np.asarray(sinogram) | |
| 1224 sinogram_T = sinogram.T | |
| 1225 center = sinogram.shape[1]/2 | |
| 1226 | |
| 1227 # Try using Nghia Vo’s method | |
| 1228 t0 = time() | |
| 1229 if num_core > num_core_tomopy_limit: | |
| 1230 logger.debug(f'Running find_center_vo on {num_core_tomopy_limit} cores ...') | |
| 1231 tomo_center = tomopy.find_center_vo(sinogram, ncore=num_core_tomopy_limit) | |
| 1232 else: | |
| 1233 logger.debug(f'Running find_center_vo on {num_core} cores ...') | |
| 1234 tomo_center = tomopy.find_center_vo(sinogram, ncore=num_core) | |
| 1235 logger.debug(f'... done in {time()-t0:.2f} seconds') | |
| 1236 logger.info(f'Finding the center using Nghia Vo’s method took {time()-t0:.2f} seconds') | |
| 1237 center_offset_vo = tomo_center-center | |
| 1238 logger.info(f'Center at row {row} using Nghia Vo’s method = {center_offset_vo:.2f}') | |
| 1239 t0 = time() | |
| 1240 logger.debug(f'Running _reconstruct_one_plane on {self.num_core} cores ...') | |
| 1241 recon_plane = self._reconstruct_one_plane(sinogram_T, tomo_center, thetas, | |
| 1242 eff_pixel_size, cross_sectional_dim, False, num_core) | |
| 1243 logger.debug(f'... done in {time()-t0:.2f} seconds') | |
| 1244 logger.info(f'Reconstructing row {row} took {time()-t0:.2f} seconds') | |
| 1245 | |
| 1246 title = f'edges row{row} center offset{center_offset_vo:.2f} Vo' | |
| 1247 self._plot_edges_one_plane(recon_plane, title, path=path) | |
| 1248 | |
| 1249 # Try using phase correlation method | |
| 1250 # if input_yesno('Try finding center using phase correlation (y/n)?', 'n'): | |
| 1251 # t0 = time() | |
| 1252 # logger.debug(f'Running find_center_pc ...') | |
| 1253 # tomo_center = tomopy.find_center_pc(sinogram, sinogram, tol=0.1, rotc_guess=tomo_center) | |
| 1254 # error = 1. | |
| 1255 # while error > tol: | |
| 1256 # prev = tomo_center | |
| 1257 # tomo_center = tomopy.find_center_pc(sinogram, sinogram, tol=tol, | |
| 1258 # rotc_guess=tomo_center) | |
| 1259 # error = np.abs(tomo_center-prev) | |
| 1260 # logger.debug(f'... done in {time()-t0:.2f} seconds') | |
| 1261 # logger.info('Finding the center using the phase correlation method took '+ | |
| 1262 # f'{time()-t0:.2f} seconds') | |
| 1263 # center_offset = tomo_center-center | |
| 1264 # print(f'Center at row {row} using phase correlation = {center_offset:.2f}') | |
| 1265 # t0 = time() | |
| 1266 # logger.debug(f'Running _reconstruct_one_plane on {self.num_core} cores ...') | |
| 1267 # recon_plane = self._reconstruct_one_plane(sinogram_T, tomo_center, thetas, | |
| 1268 # eff_pixel_size, cross_sectional_dim, False, num_core) | |
| 1269 # logger.debug(f'... done in {time()-t0:.2f} seconds') | |
| 1270 # logger.info(f'Reconstructing row {row} took {time()-t0:.2f} seconds') | |
| 1271 # | |
| 1272 # title = f'edges row{row} center_offset{center_offset:.2f} PC' | |
| 1273 # self._plot_edges_one_plane(recon_plane, title, path=path) | |
| 1274 | |
| 1275 # Select center location | |
| 1276 # if input_yesno('Accept a center location (y) or continue search (n)?', 'y'): | |
| 1277 if True: | |
| 1278 # center_offset = input_num(' Enter chosen center offset', ge=-center, le=center, | |
| 1279 # default=center_offset_vo) | |
| 1280 center_offset = center_offset_vo | |
| 1281 del sinogram_T | |
| 1282 del recon_plane | |
| 1283 return float(center_offset) | |
| 1284 | |
| 1285 # perform center finding search | |
| 1286 while True: | |
| 1287 center_offset_low = input_int('\nEnter lower bound for center offset', ge=-center, | |
| 1288 le=center) | |
| 1289 center_offset_upp = input_int('Enter upper bound for center offset', | |
| 1290 ge=center_offset_low, le=center) | |
| 1291 if center_offset_upp == center_offset_low: | |
| 1292 center_offset_step = 1 | |
| 1293 else: | |
| 1294 center_offset_step = input_int('Enter step size for center offset search', ge=1, | |
| 1295 le=center_offset_upp-center_offset_low) | |
| 1296 num_center_offset = 1+int((center_offset_upp-center_offset_low)/center_offset_step) | |
| 1297 center_offsets = np.linspace(center_offset_low, center_offset_upp, num_center_offset) | |
| 1298 for center_offset in center_offsets: | |
| 1299 if center_offset == center_offset_vo: | |
| 1300 continue | |
| 1301 t0 = time() | |
| 1302 logger.debug(f'Running _reconstruct_one_plane on {num_core} cores ...') | |
| 1303 recon_plane = self._reconstruct_one_plane(sinogram_T, center_offset+center, thetas, | |
| 1304 eff_pixel_size, cross_sectional_dim, False, num_core) | |
| 1305 logger.debug(f'... done in {time()-t0:.2f} seconds') | |
| 1306 logger.info(f'Reconstructing center_offset {center_offset} took '+ | |
| 1307 f'{time()-t0:.2f} seconds') | |
| 1308 title = f'edges row{row} center_offset{center_offset:.2f}' | |
| 1309 self._plot_edges_one_plane(recon_plane, title, path=path) | |
| 1310 if input_int('\nContinue (0) or end the search (1)', ge=0, le=1): | |
| 1311 break | |
| 1312 | |
| 1313 del sinogram_T | |
| 1314 del recon_plane | |
| 1315 center_offset = input_num(' Enter chosen center offset', ge=-center, le=center) | |
| 1316 return float(center_offset) | |
| 1317 | |
def _reconstruct_one_plane(self, tomo_plane_T, center, thetas, eff_pixel_size,
        cross_sectional_dim, plot_sinogram=True, num_core=1):
    """Invert the sinogram of a single tomography plane.

    The plane is cropped around the requested rotation axis, inverted with iradon,
    Gaussian filtered and cleaned of ring artifacts with tomopy.

    :param tomo_plane_T: plane data in column,theta index order.
    :param center: rotation-axis column; must satisfy 0 <= center < number of columns.
    :param thetas: projection angles, passed unchanged to iradon(theta=...).
    :param eff_pixel_size: effective pixel size; cross_sectional_dim/eff_pixel_size gives
        the sample dimension in pixels, which limits the crop radius.
    :param cross_sectional_dim: sample cross-sectional dimension (same units as
        eff_pixel_size).
    :param plot_sinogram: show/save the cropped sinogram (never when self.galaxy_flag).
    :param num_core: number of cores used by tomopy's ring removal.
    :return: filtered reconstruction with a leading axis of length 1 (np.expand_dims).
    """
    # tomo_plane_T index order: column,theta
    num_columns = tomo_plane_T.shape[0]
    assert 0 <= center < num_columns
    center_offset = center-num_columns/2
    two_offset = 2*int(np.round(center_offset))

    # Half-width of the kept region: the sample size plus 10% slack against edge effects,
    # but never more than half the detector width
    max_rad = min(int(0.55*(cross_sectional_dim/eff_pixel_size)), 0.5*num_columns)
    dist_from_edge = max(1, int(np.floor((num_columns-np.abs(two_offset))/2.)-max_rad))

    # Shift the crop window so the rotation axis ends up in the middle of the sinogram
    if two_offset >= 0:
        lower, upper = two_offset+dist_from_edge, -dist_from_edge
    else:
        lower, upper = dist_from_edge, two_offset-dist_from_edge
    logger.debug(f'sinogram range = [{lower}, {upper}]')
    sinogram = tomo_plane_T[lower:upper,:]

    if not self.galaxy_flag and plot_sinogram:
        quick_imshow(sinogram.T, f'sinogram center offset{center_offset:.2f}', aspect='auto',
                path=self.output_folder, save_fig=self.save_figs, save_only=self.save_only,
                block=self.block)

    # Inverting sinogram
    start = time()
    recon_sinogram = iradon(sinogram, theta=thetas, circle=True)
    logger.debug(f'Inverting sinogram took {time()-start:.2f} seconds')
    del sinogram

    # Filter settings; the config lookup is currently disabled, so the defaults always apply
    sigma, ring_width = 1.0, 15
    recon_parameters = None  # self.config.get('recon_parameters')
    if recon_parameters is not None:
        sigma = recon_parameters.get('gaussian_sigma', 1.0)
        if not is_num(sigma, ge=0.0):
            logger.warning(f'Invalid gaussian_sigma ({sigma}) in _reconstruct_one_plane, '
                    'set to a default value of 1.0')
            sigma = 1.0
        ring_width = recon_parameters.get('ring_width', 15)
        if not is_int(ring_width, ge=0):
            logger.warning(f'Invalid ring_width ({ring_width}) in _reconstruct_one_plane, '
                    'set to a default value of 15')
            ring_width = 15

    # Performing Gaussian filtering and removing ring artifacts
    start = time()
    recon_clean = np.expand_dims(spi.gaussian_filter(recon_sinogram, sigma, mode='nearest'),
            axis=0)
    del recon_sinogram
    recon_clean = tomopy.misc.corr.remove_ring(recon_clean, rwidth=ring_width, ncore=num_core)
    logger.debug(f'Filtering and removing ring artifacts took {time()-start:.2f} seconds')

    return recon_clean
| 1372 | |
def _plot_edges_one_plane(self, recon_plane, title, path=None):
    """Plot a total-variation denoised version of a reconstructed plane.

    The first slice of the denoised data is shown twice: once with the 'coolwarm'
    colormap and once in gray with a color range symmetric around zero.

    :param recon_plane: reconstructed data passed to denoise_tv_chambolle; indexed as
        [0,:,:] afterwards, so it is expected to have a leading axis.
    :param title: base title; ' coolwarm' / ' gray' is appended per figure.
    :param path: output folder for the figures; defaults to self.output_folder.
    """
    # Denoise weight; the config lookup is currently disabled, so the default always applies
    weight = 0.1
    vis_parameters = None  # self.config.get('vis_parameters')
    if vis_parameters is not None:
        weight = vis_parameters.get('denoise_weight', 0.1)
        if not is_num(weight, ge=0.0):
            logger.warning(f'Invalid weight ({weight}) in _plot_edges_one_plane, '
                    'set to a default value of 0.1')
            weight = 0.1

    edges = denoise_tv_chambolle(recon_plane, weight=weight)
    first_slice = edges[0,:,:]
    vmax = np.max(first_slice)
    if path is None:
        path = self.output_folder
    quick_imshow(first_slice, f'{title} coolwarm', path=path, cmap='coolwarm',
            save_fig=self.save_figs, save_only=self.save_only, block=self.block)
    # Symmetric color range so zero maps to the middle of the gray scale
    quick_imshow(first_slice, f'{title} gray', path=path, cmap='gray', vmin=-vmax, vmax=vmax,
            save_fig=self.save_figs, save_only=self.save_only, block=self.block)
    del first_slice, edges
| 1393 | |
def _reconstruct_one_tomo_stack(self, tomo_stack, thetas, center_offsets=None, num_core=1,
        algorithm='gridrec'):
    """Reconstruct a single tomography stack with tomopy.

    The steps are:
    1. Remove horizontal stripes with tomopy.prep.stripe.remove_stripe_fw.
    2. Reconstruct with tomopy.recon.
    3. Optionally run secondary astra SART iterations.
    4. Remove ring artifacts with tomopy.misc.corr.remove_ring.

    :param tomo_stack: stack in row,theta,column index order (passed with sinogram_order=True).
    :param thetas: projection angles in degrees (converted with np.radians).
    :param center_offsets: rotation-axis shift in pixels relative to the column center.
        None or empty means a zero offset for every row. Two values are linearly
        interpolated from the first to the last row. Otherwise one value per row is
        expected. The argument is never modified.
    :param num_core: number of cores for tomopy (remove_stripe_fw is capped at
        num_core_tomopy_limit).
    :param algorithm: tomopy.recon algorithm for the initial reconstruction.
    :raises ValueError: if per-row center_offsets do not match the number of rows.
    :return: the reconstructed stack returned by tomopy.recon, after ring removal.
    """
    # RV should we remove stripes?
    # https://tomopy.readthedocs.io/en/latest/api/tomopy.prep.stripe.html
    # RV should we remove rings?
    # https://tomopy.readthedocs.io/en/latest/api/tomopy.misc.corr.html
    # RV: Add an option to do (extra) secondary iterations later or to do some sort of
    # convergence test?

    # Absolute rotation-axis column for every row of the stack
    num_rows = tomo_stack.shape[0]
    if center_offsets is None or not len(center_offsets):
        centers = np.zeros((num_rows))
    elif len(center_offsets) == 2:
        centers = np.linspace(center_offsets[0], center_offsets[1], num_rows)
    else:
        # Float copy: the in-place shift below must never modify the caller's offsets
        centers = np.array(center_offsets, dtype=float)
        if centers.size != num_rows:
            raise ValueError('center_offsets dimension mismatch in reconstruct_one_tomo_stack')
    centers += tomo_stack.shape[2]/2

    # Get reconstruction parameters (the config lookup is currently disabled)
    sigma = 2.0
    secondary_iters = 0
    ring_width = 15
    recon_parameters = None  # self.config.get('recon_parameters')
    if recon_parameters is not None:
        sigma = recon_parameters.get('stripe_fw_sigma', 2.0)
        if not is_num(sigma, ge=0):
            logger.warning(f'Invalid stripe_fw_sigma ({sigma}) in '
                    '_reconstruct_one_tomo_stack, set to a default value of 2.0')
            sigma = 2.0
        secondary_iters = recon_parameters.get('secondary_iters', 0)
        if not is_int(secondary_iters, ge=0):
            logger.warning(f'Invalid secondary_iters ({secondary_iters}) in '
                    '_reconstruct_one_tomo_stack, set to a default value of 0 (skip them)')
            secondary_iters = 0
        ring_width = recon_parameters.get('ring_width', 15)
        if not is_int(ring_width, ge=0):
            logger.warning(f'Invalid ring_width ({ring_width}) in _reconstruct_one_tomo_stack, '
                    'set to a default value of 15')
            ring_width = 15

    # Remove horizontal stripe (tomopy core count capped at num_core_tomopy_limit)
    t0 = time()
    stripe_num_core = min(num_core, num_core_tomopy_limit)
    logger.debug(f'Running remove_stripe_fw on {stripe_num_core} cores ...')
    tomo_stack = tomopy.prep.stripe.remove_stripe_fw(tomo_stack, sigma=sigma,
            ncore=stripe_num_core)
    logger.debug(f'... tomopy.prep.stripe.remove_stripe_fw took {time()-t0:.2f} seconds')

    # Perform initial image reconstruction
    logger.debug('Performing initial image reconstruction')
    t0 = time()
    logger.debug(f'Running recon on {num_core} cores ...')
    tomo_recon_stack = tomopy.recon(tomo_stack, np.radians(thetas), centers,
            sinogram_order=True, algorithm=algorithm, ncore=num_core)
    logger.debug(f'... done in {time()-t0:.2f} seconds')
    logger.info(f'Performing initial image reconstruction took {time()-t0:.2f} seconds')

    # Run optional secondary iterations
    if secondary_iters > 0:
        logger.debug(f'Running {secondary_iters} secondary iterations')
        # Alternatives tried by RV: SIRT_CUDA failed with "CUDA error 803: system has
        # unsupported display driver/cuda driver combination", and CPU SIRT did not finish
        # overnight, so CPU SART is used.
        options = {'method':'SART', 'proj_type':'linear', 'MinConstraint': 0,
                'num_iter':secondary_iters}
        t0 = time()
        logger.debug(f'Running recon on {num_core} cores ...')
        tomo_recon_stack = tomopy.recon(tomo_stack, np.radians(thetas), centers,
                init_recon=tomo_recon_stack, options=options, sinogram_order=True,
                algorithm=tomopy.astra, ncore=num_core)
        logger.debug(f'... done in {time()-t0:.2f} seconds')
        logger.info(f'Performing secondary iterations took {time()-t0:.2f} seconds')

    # Remove ring artifacts; keep the returned array rather than relying only on `out=`
    t0 = time()
    tomo_recon_stack = tomopy.misc.corr.remove_ring(tomo_recon_stack, rwidth=ring_width,
            out=tomo_recon_stack, ncore=num_core)
    logger.debug(f'Removing ring artifacts took {time()-t0:.2f} seconds')

    return tomo_recon_stack
| 1486 | |
def _resize_reconstructed_data(self, data, z_only=False):
    """Interactively select x, y and z bounds for reconstructed tomography data.

    For each axis the user is asked whether to change its bounds. If so, the data is
    summed over the other two axes (and over all stacks) and shown with draw_mask_1d
    until exactly one continuous index range is selected. z bounds are only offered
    when there is a single stack.

    :param data: one 3-D array in row(z),x,y order, or a list of such arrays.
    :param z_only: skip the x and y selections (both are returned as None).
    :return: tuple (x_bounds, y_bounds, z_bounds). An entry is None when that axis was
        not changed, otherwise it is the single range returned by draw_mask_1d.
    """
    # Data order: row(z),x,y or stack,row(z),x,y
    if isinstance(data, list):
        for stack in data:
            assert(stack.ndim == 3)
        tomo_recon_stacks = data
    else:
        assert(data.ndim == 3)
        tomo_recon_stacks = [data]
    num_tomo_stacks = len(tomo_recon_stacks)

    def select_bounds(name, sum_axes, prompt, default, legend):
        """Ask for one index range along axis `name`; None if the user declines."""
        bounds = None
        if input_yesno(prompt, default):
            # 1-D profile along this axis, accumulated over all stacks
            tomosum = sum(np.sum(stack, axis=sum_axes) for stack in tomo_recon_stacks)
            title = f'select {name} data range'
            _, ranges = draw_mask_1d(tomosum, current_index_ranges=None, title=title,
                    legend=legend)
            while len(ranges) != 1:
                print('Please select exactly one continuous range')
                _, ranges = draw_mask_1d(tomosum, title=title, legend=legend)
            bounds = ranges[0]
        logger.debug(f'{name}_bounds = {bounds}')
        return bounds

    if z_only:
        x_bounds = None
        y_bounds = None
    else:
        # Selecting x bounds (in yz-plane)
        x_bounds = select_bounds('x', (0,2),
                '\nDo you want to change the image x-bounds (y/n)?', 'y', 'recon stack sum yz')
        # Selecting y bounds (in xz-plane)
        y_bounds = select_bounds('y', (0,1),
                '\nDo you want to change the image y-bounds (y/n)?', 'y', 'recon stack sum xz')

    # Selecting z bounds (in xy-plane) (only valid for a single image stack)
    if num_tomo_stacks == 1:
        z_bounds = select_bounds('z', (1,2),
                'Do you want to change the image z-bounds (y/n)?', 'n', 'recon stack sum xy')
    else:
        z_bounds = None

    return x_bounds, y_bounds, z_bounds
| 1584 | |
| 1585 | |
def run_tomo(input_file:str, output_file:str, modes:list[str], center_file=None, num_core=-1,
        output_folder='.', save_figs='no', test_mode=False) -> None:
    """Run the selected stages of the tomography workflow on one input file.

    Stages run in the order reduce_data, find_center, reconstruct_data, combine_data;
    'all' selects every stage.

    :param input_file: file read with Tomo.read to obtain the starting data.
    :param output_file: file passed to Tomo.write (nothing is written in test mode).
    :param modes: stages to run, from 'reduce_data', 'find_center', 'reconstruct_data',
        'combine_data' and 'all'; None means ['all'].
    :param center_file: previously found centers; read when 'reconstruct_data' runs without
        'find_center'.
    :param num_core: passed to Tomo.
    :param output_folder: passed to Tomo; in test mode it also receives tomo.log.
    :param save_figs: passed to Tomo.
    :param test_mode: log to {output_folder}/tomo.log and skip writing output_file.
    :raises ValueError: if modes contains an unknown stage, or if center_file is needed
        but not given.
    """
    if test_mode:
        # Redirect all logging to a fresh log file in the output folder
        logging_format = '%(asctime)s : %(levelname)s - %(module)s : %(funcName)s - %(message)s'
        logging.basicConfig(filename=f'{output_folder}/tomo.log', filemode='w',
                format=logging_format, level=logging.INFO, force=True)
    logger.info(f'input_file = {input_file}')
    logger.info(f'center_file = {center_file}')
    logger.info(f'output_file = {output_file}')
    logger.debug(f'modes= {modes}')
    logger.debug(f'num_core= {num_core}')
    logger.info(f'output_folder = {output_folder}')
    logger.info(f'save_figs = {save_figs}')
    logger.info(f'test_mode = {test_mode}')

    # Check for correction modes
    legal_modes = ['reduce_data', 'find_center', 'reconstruct_data', 'combine_data', 'all']
    if modes is None:
        modes = ['all']
    if not all(mode in legal_modes for mode in modes):
        raise ValueError(f'Invalid parameter modes ({modes})')
    run_all = 'all' in modes
    run_reduce = run_all or 'reduce_data' in modes
    run_find_center = run_all or 'find_center' in modes
    run_reconstruct = run_all or 'reconstruct_data' in modes
    run_combine = run_all or 'combine_data' in modes

    # Fail fast, before any expensive stage runs, if the centers cannot be obtained
    if run_reconstruct and not run_find_center and center_file is None:
        raise ValueError('center_file is required to reconstruct data without running '
                'find_center')

    # Instantiate Tomo object
    tomo = Tomo(num_core=num_core, output_folder=output_folder, save_figs=save_figs,
            test_mode=test_mode)

    # Read input file
    data = tomo.read(input_file)

    # Generate reduced tomography images
    if run_reduce:
        data = tomo.gen_reduced_data(data)

    # Find rotation axis centers for the tomography stacks
    center_data = None
    if run_find_center:
        center_data = tomo.find_centers(data)

    # Reconstruct tomography stacks
    if run_reconstruct:
        if center_data is None:
            center_data = tomo.read(center_file)
        data = tomo.reconstruct_data(data, center_data)
        # The centers are consumed here; from now on the reconstructed data is the output
        center_data = None

    # Combine reconstructed tomography stacks
    if run_combine:
        data = tomo.combine_data(data)

    # Write output file: the centers if they are the last product, else the data
    if data is not None and not test_mode:
        tomo.write(data if center_data is None else center_data, output_file)

    logger.info(f'Completed modes: {modes}')
