Mercurial > repos > rv43 > chess_tomo
comparison workflow/run_tomo.py @ 4:9aa288729b9a draft
planemo upload for repository https://github.com/rolfverberg/galaxytools commit f8c4bdb31c20c468045ad5e6eb255a293244bc6c-dirty
| author | rv43 |
|---|---|
| date | Mon, 20 Mar 2023 18:44:23 +0000 |
| parents | |
| children | 543dba81eb15 |
comparison
equal
deleted
inserted
replaced
| 3:fc38431f257f | 4:9aa288729b9a |
|---|---|
| 1 #!/usr/bin/env python3 | |
| 2 | |
| 3 import logging | |
| 4 logger = logging.getLogger(__name__) | |
| 5 | |
| 6 import numpy as np | |
| 7 try: | |
| 8 import numexpr as ne | |
| 9 except: | |
| 10 pass | |
| 11 try: | |
| 12 import scipy.ndimage as spi | |
| 13 except: | |
| 14 pass | |
| 15 | |
| 16 from multiprocessing import cpu_count | |
| 17 from nexusformat.nexus import * | |
| 18 from os import mkdir | |
| 19 from os import path as os_path | |
| 20 try: | |
| 21 from skimage.transform import iradon | |
| 22 except: | |
| 23 pass | |
| 24 try: | |
| 25 from skimage.restoration import denoise_tv_chambolle | |
| 26 except: | |
| 27 pass | |
| 28 from time import time | |
| 29 try: | |
| 30 import tomopy | |
| 31 except: | |
| 32 pass | |
| 33 from yaml import safe_load, safe_dump | |
| 34 | |
| 35 try: | |
| 36 from msnctools.fit import Fit | |
| 37 except: | |
| 38 from fit import Fit | |
| 39 try: | |
| 40 from msnctools.general import illegal_value, is_int, is_int_pair, is_num, is_index_range, \ | |
| 41 input_int, input_num, input_yesno, input_menu, draw_mask_1d, select_image_bounds, \ | |
| 42 select_one_image_bound, clear_imshow, quick_imshow, clear_plot, quick_plot | |
| 43 except: | |
| 44 from general import illegal_value, is_int, is_int_pair, is_num, is_index_range, \ | |
| 45 input_int, input_num, input_yesno, input_menu, draw_mask_1d, select_image_bounds, \ | |
| 46 select_one_image_bound, clear_imshow, quick_imshow, clear_plot, quick_plot | |
| 47 | |
| 48 try: | |
| 49 from workflow.models import import_scanparser, FlatField, TomoField, TomoWorkflow | |
| 50 from workflow.__version__ import __version__ | |
| 51 except: | |
| 52 pass | |
| 53 | |
| 54 num_core_tomopy_limit = 24 | |
| 55 | |
def nxcopy(nxobject:NXobject, exclude_nxpaths:list[str]=None, nxpath_prefix:str='') -> NXobject:
    '''Function that returns a copy of a nexus object, optionally excluding certain child items.

    At the top level (empty `nxpath_prefix`) only the `default` attribute of
    `nxobject` is carried over; for nested groups all attributes are copied.

    :param nxobject: the original nexus object to return a "copy" of
    :type nxobject: nexusformat.nexus.NXobject
    :param exclude_nxpaths: a list of paths to child nexus objects that
        should be excluded from the returned "copy", given relative to
        `nxobject` and separated by `/`; defaults to `None` (exclude nothing)
    :type exclude_nxpaths: list[str], optional
    :param nxpath_prefix: For use in recursive calls from inside this
        function only!
    :type nxpath_prefix: str
    :return: a copy of `nxobject` with some children optionally excluded.
    :rtype: NXobject
    '''
    # Avoid a shared mutable default argument
    if exclude_nxpaths is None:
        exclude_nxpaths = []

    nxobject_copy = nxobject.__class__()
    if not nxpath_prefix:
        if 'default' in nxobject.attrs:
            nxobject_copy.attrs['default'] = nxobject.attrs['default']
    else:
        for k, v in nxobject.attrs.items():
            nxobject_copy.attrs[k] = v

    for k, v in nxobject.items():
        # NeXus paths are always '/'-separated, so do not build them with the
        # OS-dependent os.path.join
        nxpath = f'{nxpath_prefix}/{k}' if nxpath_prefix else k

        if nxpath in exclude_nxpaths:
            continue

        if isinstance(v, NXgroup):
            nxobject_copy[k] = nxcopy(v, exclude_nxpaths=exclude_nxpaths,
                    nxpath_prefix=nxpath)
        else:
            nxobject_copy[k] = v

    return(nxobject_copy)
| 92 | |
class set_numexpr_threads:
    """Context manager that sets the numexpr thread count for the duration
    of a `with` block.

    The requested number of cores is used when it lies between 1 and the
    number of available processors; otherwise (including `None`) the number
    of available processors is used.
    """

    def __init__(self, num_core):
        max_core = cpu_count()
        if num_core is not None and not (num_core < 1 or num_core > max_core):
            self.num_core = num_core
        else:
            self.num_core = max_core

    def __enter__(self):
        # Keep whatever set_num_threads returns so __exit__ can hand it back
        # (presumably the previous thread count -- see the numexpr docs)
        previous = ne.set_num_threads(self.num_core)
        self.num_core_org = previous

    def __exit__(self, exc_type, exc_value, traceback):
        ne.set_num_threads(self.num_core_org)
| 106 | |
| 107 class Tomo: | |
| 108 """Processing tomography data with misalignment. | |
| 109 """ | |
def __init__(self, galaxy_flag=False, num_core=-1, output_folder='.', save_figs=None,
        test_mode=False):
    """Initialize the tomography processor.

    :param galaxy_flag: run in Galaxy mode; `output_folder`, `test_mode`
        and `save_figs` are then ignored (with a warning when set),
        defaults to `False`
    :type galaxy_flag: bool
    :param num_core: number of processors to use, `-1` selects all
        available processors; values above the available number are
        reduced to it, defaults to `-1`
    :type num_core: int
    :param output_folder: folder for output files, created if it does
        not exist yet, defaults to `'.'`
    :type output_folder: str
    :param save_figs: one of `'yes'`, `'no'` or `'only'`; `None` selects
        `'no'` outside Galaxy mode, and `'only'` is enforced in Galaxy
        and test mode
    :type save_figs: str, optional
    :param test_mode: run in test mode, defaults to `False`
    :type test_mode: bool
    :raises ValueError: for an invalid `galaxy_flag`, `test_mode`,
        `save_figs` or `num_core`
    """
    # Imported locally so the fix does not rely on a new module-level import
    from os import makedirs

    if not isinstance(galaxy_flag, bool):
        raise ValueError(f'Invalid parameter galaxy_flag ({galaxy_flag})')
    self.galaxy_flag = galaxy_flag
    self.num_core = num_core
    if self.galaxy_flag:
        if output_folder != '.':
            logger.warning('Ignoring output_folder in galaxy mode')
        self.output_folder = '.'
        if test_mode != False:
            logger.warning('Ignoring test_mode in galaxy mode')
        self.test_mode = False
        if save_figs is not None:
            logger.warning('Ignoring save_figs in galaxy mode')
        save_figs = 'only'
    else:
        self.output_folder = os_path.abspath(output_folder)
        # Create the folder including any missing parents; exist_ok avoids
        # the check-then-create race of isdir() followed by mkdir()
        makedirs(self.output_folder, exist_ok=True)
        if not isinstance(test_mode, bool):
            raise ValueError(f'Invalid parameter test_mode ({test_mode})')
        self.test_mode = test_mode
        if save_figs is None:
            save_figs = 'no'
    self.test_config = {}
    if self.test_mode:
        if save_figs != 'only':
            logger.warning('Ignoring save_figs in test mode')
        save_figs = 'only'
    if save_figs == 'only':
        self.save_only = True
        self.save_figs = True
    elif save_figs == 'yes':
        self.save_only = False
        self.save_figs = True
    elif save_figs == 'no':
        self.save_only = False
        self.save_figs = False
    else:
        raise ValueError(f'Invalid parameter save_figs ({save_figs})')
    # Only block on figures when they are actually shown
    self.block = not self.save_only
    if self.num_core == -1:
        self.num_core = cpu_count()
    if not is_int(self.num_core, gt=0, log=False):
        raise ValueError(f'Invalid parameter num_core ({num_core})')
    if self.num_core > cpu_count():
        logger.warning(f'num_core = {self.num_core} is larger than the number of available '
                f'processors and reduced to {cpu_count()}')
        self.num_core = cpu_count()
| 165 | |
def read(self, filename):
    """Read a workflow configuration or a Nexus file.

    :param filename: name of the input file; the extension selects the
        reader (`.yml`/`.yaml` for YAML, `.nxs` for Nexus)
    :type filename: str
    :return: the object returned by `safe_load` for a YAML file, or by
        `NXFile.readfile` for a Nexus file
    :raises ValueError: for any other filename extension
    """
    _, extension = os_path.splitext(filename)
    if extension in ('.yml', '.yaml'):
        with open(filename, 'r') as yaml_file:
            loaded = safe_load(yaml_file)
        return(loaded)
    if extension == '.nxs':
        with NXFile(filename, mode='r') as nexus_file:
            loaded = nexus_file.readfile()
        return(loaded)
    raise ValueError(f'Invalid filename extension ({extension})')
| 182 | |
def write(self, data, filename):
    """Write data to file in the format selected by the filename extension.

    :param data: the object to write: anything `safe_dump` accepts for
        `.yml`/`.yaml`, an object with a `save` method for `.nxs`, or an
        object with a `to_netcdf` method for `.nc`
    :param filename: name of the output file
    :type filename: str
    :raises ValueError: for any other filename extension
    """
    extension = os_path.splitext(filename)[1]
    if extension in ('.yml', '.yaml'):
        with open(filename, 'w') as f:
            safe_dump(data, f)
    elif extension == '.nxs':
        data.save(filename, mode='w')
    elif extension == '.nc':
        # The keyword is `path`, not `os_path` (the latter is only this
        # module's alias for os.path) -- presumably data is an xarray
        # object; verify against the callers
        data.to_netcdf(path=filename)
    else:
        raise ValueError(f'Invalid filename extension ({extension})')
| 194 | |
def gen_reduced_data(self, data, img_x_bounds=None):
    """Generate the reduced tomography images.

    :param data: either a workflow configuration dictionary, from which
        a Nexus object is constructed first, or an existing NXroot object
    :type data: dict or NXroot
    :param img_x_bounds: passed on to `_set_detector_bounds`, defaults
        to `None`
    :return: a Nexus object with the reduced data added as an NXprocess
        named `reduced_data` and linked into the default NXdata group
    :rtype: NXroot
    :raises ValueError: for an invalid `data` type or a configuration
        with more than one sample map
    """
    logger.info('Generate the reduced tomography images')

    # In Galaxy mode the plots go to a dedicated folder
    plot_dir = 'tomo_reduce_plots'
    if self.galaxy_flag and not os_path.exists(plot_dir):
        mkdir(plot_dir)

    if isinstance(data, dict):
        # Build the Nexus object from the workflow configuration
        workflow = TomoWorkflow(**data)
        if len(workflow.sample_maps) > 1:
            raise ValueError('Multiple sample maps not yet implemented')
        nxroot = NXroot()
        t_start = time()
        for sample_map in workflow.sample_maps:
            logger.info(f'Start constructing the {sample_map.title} map.')
            import_scanparser(sample_map.station)
            sample_map.construct_nxentry(nxroot, include_raw_data=False)
        logger.info(f'Constructed all sample maps in {time()-t_start:.2f} seconds.')
        nxentry = nxroot[nxroot.attrs['default']]
        # The test mode settings live in the first sample map
        if self.test_mode:
            self.test_config = data['sample_maps'][0]['test_mode']
    elif isinstance(data, NXroot):
        nxentry = data[data.attrs['default']]
    else:
        raise ValueError(f'Invalid parameter data ({data})')

    # All data reduction (meta)data is collected in one NXprocess
    reduced_data = NXprocess()

    # Dark field (only when dark field scans are present), then bright field
    if 'dark_field' in nxentry['spec_scans']:
        reduced_data = self._gen_dark(nxentry, reduced_data)
    reduced_data = self._gen_bright(nxentry, reduced_data)

    # Vertical detector bounds for the image stack
    img_x_bounds = self._set_detector_bounds(nxentry, reduced_data, img_x_bounds=img_x_bounds)
    logger.info(f'img_x_bounds = {img_x_bounds}')
    reduced_data['img_x_bounds'] = img_x_bounds

    # Optional zoom and/or theta skip to reduce the memory requirement
    zoom_perc, num_theta_skip = self._set_zoom_or_skip()
    for attr_name, attr_value in (('zoom_perc', zoom_perc),
            ('num_theta_skip', num_theta_skip)):
        if attr_value is not None:
            reduced_data.attrs[attr_name] = attr_value

    # Reduced tomography fields
    reduced_data = self._gen_tomo(nxentry, reduced_data)

    # For Nexus input: continue with a copy that has the raw data and any
    # previously reduced data removed
    if isinstance(data, NXroot):
        entry_name = nxentry._name
        excluded_subpaths = (
            'reduced_data/data',
            'instrument/detector/data',
            'instrument/detector/image_key',
            'instrument/detector/sequence_number',
            'sample/rotation_angle',
            'sample/x_translation',
            'sample/z_translation',
            'data/data',
            'data/image_key',
            'data/rotation_angle',
            'data/x_translation',
            'data/z_translation',
        )
        exclude_items = [f'{entry_name}/{subpath}' for subpath in excluded_subpaths]
        nxroot = nxcopy(data, exclude_nxpaths=exclude_items)
        nxentry = nxroot[nxroot.attrs['default']]

    # Attach the reduced data and make it the default plottable signal
    nxentry.reduced_data = reduced_data
    if 'data' not in nxentry:
        nxentry.data = NXdata()
    nxentry.attrs['default'] = 'data'
    nxentry.data.makelink(nxentry.reduced_data.data.tomo_fields, name='reduced_data')
    nxentry.data.makelink(nxentry.reduced_data.rotation_angle, name='rotation_angle')
    nxentry.data.attrs['signal'] = 'reduced_data'

    return(nxroot)
| 279 | |
def find_centers(self, nxroot, center_rows=None, center_stack_index=None):
    """Find the calibrated center axis info

    The center offset is determined with `_find_center_one_plane` at a
    lower and an upper detector row of one tomography stack.

    :param nxroot: Nexus object containing the reduced data
    :type nxroot: NXroot
    :param center_rows: the two row indices at which to find the center;
        only used in Galaxy mode (ignored with a warning otherwise),
        defaults to the first and last row
    :type center_rows: sequence of two ints, optional
    :param center_stack_index: zero-based index of the stack used for
        the calibration; only used in Galaxy mode (ignored with a
        warning otherwise), defaults to the middle stack
    :type center_stack_index: int, optional
    :return: dictionary with `lower_row`, `lower_center_offset`,
        `upper_row` and `upper_center_offset`, plus the one-based
        `center_stack_index` when there is more than one stack
    :rtype: dict
    :raises ValueError: for invalid input parameters
    :raises KeyError: when no valid reduced data is available
    """
    logger.info('Find the calibrated center axis info')

    if not isinstance(nxroot, NXroot):
        raise ValueError(f'Invalid parameter nxroot ({nxroot})')
    nxentry = nxroot[nxroot.attrs['default']]
    if not isinstance(nxentry, NXentry):
        raise ValueError(f'Invalid nxentry ({nxentry})')
    if self.galaxy_flag:
        if center_rows is not None:
            center_rows = tuple(center_rows)
            if not is_int_pair(center_rows):
                raise ValueError(f'Invalid parameter center_rows ({center_rows})')
        if center_stack_index is not None and not is_int(center_stack_index, ge=0):
            raise ValueError(f'Invalid parameter center_stack_index ({center_stack_index})')
    else:
        # Outside Galaxy mode both values come from the test
        # configuration or from interactive input
        if center_rows is not None:
            logger.warning(f'Ignoring parameter center_rows ({center_rows})')
            center_rows = None
        if center_stack_index is not None:
            logger.warning(f'Ignoring parameter center_stack_index ({center_stack_index})')
            center_stack_index = None

    # Create plot galaxy path directory and path if needed
    if self.galaxy_flag:
        if not os_path.exists('tomo_find_centers_plots'):
            mkdir('tomo_find_centers_plots')
        path = 'tomo_find_centers_plots'
    else:
        path = self.output_folder

    # Check if reduced data is available
    if ('reduced_data' not in nxentry or 'reduced_data' not in nxentry.data):
        raise KeyError(f'Unable to find valid reduced data in {nxentry}.')

    # Select the image stack to calibrate the center axis
    #   reduced data axes order: stack,theta,row,column
    # Note: Nexus cannot follow a link if the data it points to is too big,
    #       so get the data from the actual place, not from nxentry.data
    tomo_fields = nxentry.reduced_data.data.tomo_fields
    tomo_fields_shape = tomo_fields.shape
    if len(tomo_fields_shape) != 4 or not all(tomo_fields_shape):
        raise KeyError('Unable to load the required reduced tomography stack')
    num_tomo_stacks = tomo_fields_shape[0]
    num_rows = tomo_fields_shape[2]
    if num_tomo_stacks == 1:
        center_stack_index = 0
        default = 'n'
    else:
        if self.test_mode:
            center_stack_index = self.test_config['center_stack_index']-1 # make offset 0
        elif self.galaxy_flag:
            if center_stack_index is None:
                center_stack_index = int(num_tomo_stacks/2)
            if center_stack_index >= num_tomo_stacks:
                raise ValueError(f'Invalid parameter center_stack_index ({center_stack_index})')
        else:
            center_stack_index = input_int('\nEnter tomography stack index to calibrate the '
                    'center axis', ge=1, le=num_tomo_stacks, default=int(1+num_tomo_stacks/2))
            center_stack_index -= 1
        default = 'y'

    # Get thetas (in degrees)
    thetas = np.asarray(nxentry.reduced_data.rotation_angle)

    # Get effective pixel_size
    # zoom_perc is stored as an attribute of reduced_data, so it has to be
    # looked up in attrs; a membership test on the group itself only sees
    # its child entries and would never find it
    if 'zoom_perc' in nxentry.reduced_data.attrs:
        eff_pixel_size = 100.*(nxentry.instrument.detector.x_pixel_size/
            nxentry.reduced_data.attrs['zoom_perc'])
    else:
        eff_pixel_size = nxentry.instrument.detector.x_pixel_size

    # Get cross sectional diameter
    cross_sectional_dim = tomo_fields_shape[3]*eff_pixel_size
    logger.debug(f'cross_sectional_dim = {cross_sectional_dim}')

    # Determine center offset at sample row boundaries
    logger.info('Determine center offset at sample row boundaries')

    # Lower row center
    if self.test_mode:
        lower_row = self.test_config['lower_row']
    elif self.galaxy_flag:
        if center_rows is None:
            lower_row = 0
        else:
            lower_row = min(center_rows)
        if not 0 <= lower_row < num_rows-1:
            raise ValueError(f'Invalid parameter center_rows ({center_rows})')
    else:
        lower_row = select_one_image_bound(
                tomo_fields[center_stack_index,0,:,:], 0, bound=0,
                title=f'theta={round(thetas[0], 2)+0}',
                bound_name='row index to find lower center', default=default, raise_error=True)
    logger.debug('Finding center...')
    t0 = time()
    lower_center_offset = self._find_center_one_plane(
            tomo_fields[center_stack_index,:,lower_row,:],
            lower_row, thetas, eff_pixel_size, cross_sectional_dim, path=path,
            num_core=self.num_core)
    logger.debug(f'... done in {time()-t0:.2f} seconds')
    logger.debug(f'lower_row = {lower_row:.2f}')
    logger.debug(f'lower_center_offset = {lower_center_offset:.2f}')

    # Upper row center
    if self.test_mode:
        upper_row = self.test_config['upper_row']
    elif self.galaxy_flag:
        if center_rows is None:
            upper_row = num_rows-1
        else:
            upper_row = max(center_rows)
        if not lower_row < upper_row < num_rows:
            raise ValueError(f'Invalid parameter center_rows ({center_rows})')
    else:
        upper_row = select_one_image_bound(
                tomo_fields[center_stack_index,0,:,:], 0,
                bound=num_rows-1, title=f'theta={round(thetas[0], 2)+0}',
                bound_name='row index to find upper center', default=default, raise_error=True)
    logger.debug('Finding center...')
    t0 = time()
    upper_center_offset = self._find_center_one_plane(
            tomo_fields[center_stack_index,:,upper_row,:],
            upper_row, thetas, eff_pixel_size, cross_sectional_dim, path=path,
            num_core=self.num_core)
    logger.debug(f'... done in {time()-t0:.2f} seconds')
    logger.debug(f'upper_row = {upper_row:.2f}')
    logger.debug(f'upper_center_offset = {upper_center_offset:.2f}')

    center_config = {'lower_row': lower_row, 'lower_center_offset': lower_center_offset,
            'upper_row': upper_row, 'upper_center_offset': upper_center_offset}
    if num_tomo_stacks > 1:
        center_config['center_stack_index'] = center_stack_index+1 # save as offset 1

    # Save test data to file
    if self.test_mode:
        with open(f'{self.output_folder}/center_config.yaml', 'w') as f:
            safe_dump(center_config, f)

    return(center_config)
| 422 | |
def reconstruct_data(self, nxroot, center_info, x_bounds=None, y_bounds=None):
    """Reconstruct the tomography data.

    :param nxroot: Nexus object containing the reduced data
    :type nxroot: NXroot
    :param center_info: calibrated center axis info with the keys
        `lower_row`, `lower_center_offset`, `upper_row` and
        `upper_center_offset`
    :type center_info: dict
    :param x_bounds: x-range of the reconstructed data to keep; only
        validated in Galaxy mode and replaced by the test configuration
        or interactive input otherwise, defaults to `None` (full range)
    :type x_bounds: pair of ints, optional
    :param y_bounds: y-range of the reconstructed data to keep, handled
        like `x_bounds`
    :type y_bounds: pair of ints, optional
    :return: a copy of `nxroot` without the reduced data and with the
        reconstructed data added as an NXprocess named
        `reconstructed_data`
    :rtype: NXroot
    :raises ValueError: for invalid input parameters or data
    :raises KeyError: when the reduced data or the center axis info is
        missing
    """
    logger.info('Reconstruct the tomography data')

    if not isinstance(nxroot, NXroot):
        raise ValueError(f'Invalid parameter nxroot ({nxroot})')
    nxentry = nxroot[nxroot.attrs['default']]
    if not isinstance(nxentry, NXentry):
        raise ValueError(f'Invalid nxentry ({nxentry})')
    if not isinstance(center_info, dict):
        raise ValueError(f'Invalid parameter center_info ({center_info})')

    # Create plot galaxy path directory and path if needed
    if self.galaxy_flag:
        if not os_path.exists('tomo_reconstruct_plots'):
            mkdir('tomo_reconstruct_plots')
        path = 'tomo_reconstruct_plots'
    else:
        path = self.output_folder

    # Check if reduced data is available
    if ('reduced_data' not in nxentry or 'reduced_data' not in nxentry.data):
        raise KeyError(f'Unable to find valid reduced data in {nxentry}.')

    # Create an NXprocess to store image reconstruction (meta)data
    nxprocess = NXprocess()

    # Get rotation axis rows and centers
    lower_row = center_info.get('lower_row')
    lower_center_offset = center_info.get('lower_center_offset')
    upper_row = center_info.get('upper_row')
    upper_center_offset = center_info.get('upper_center_offset')
    if (lower_row is None or lower_center_offset is None or upper_row is None or
            upper_center_offset is None):
        raise KeyError(f'Unable to find valid calibrated center axis info in {center_info}.')
    # Guard the division below against equal (or swapped) rows
    if not lower_row < upper_row:
        raise ValueError(f'Invalid center axis rows in {center_info}: lower_row must be '
                'smaller than upper_row')
    center_slope = (upper_center_offset-lower_center_offset)/(upper_row-lower_row)

    # Get thetas (in degrees)
    thetas = np.asarray(nxentry.reduced_data.rotation_angle)

    # Reconstruct tomography data
    #   reduced data axes order: stack,theta,row,column
    #   reconstructed data order in each stack: row/z,x,y
    # Note: Nexus cannot follow a link if the data it points to is too big,
    #       so get the data from the actual place, not from nxentry.data
    # zoom_perc is stored as an attribute of reduced_data, so look it up
    # in attrs rather than among the group's child entries
    if 'zoom_perc' in nxentry.reduced_data.attrs:
        res_title = f'{nxentry.reduced_data.attrs["zoom_perc"]}p'
    else:
        res_title = 'fullres'
    tomo_fields = nxentry.reduced_data.data.tomo_fields
    num_tomo_stacks = tomo_fields.shape[0]
    tomo_recon_stacks = num_tomo_stacks*[np.array([])]
    for i in range(num_tomo_stacks):
        # Convert reduced data stack from theta,row,column to row,theta,column
        logger.debug(f'Reading reduced data stack {i+1}...')
        t0 = time()
        tomo_stack = np.asarray(tomo_fields[i])
        logger.debug(f'... done in {time()-t0:.2f} seconds')
        if len(tomo_stack.shape) != 3 or not all(tomo_stack.shape):
            raise ValueError(f'Unable to load tomography stack {i+1} for reconstruction')
        tomo_stack = np.swapaxes(tomo_stack, 0, 1)
        # Explicit checks instead of asserts, which are stripped under -O
        if len(thetas) != tomo_stack.shape[1]:
            raise ValueError(f'Number of rotation angles ({len(thetas)}) does not match '
                    f'tomography stack {i+1} ({tomo_stack.shape[1]})')
        if not 0 <= lower_row < upper_row < tomo_stack.shape[0]:
            raise ValueError(f'Center axis rows ({lower_row}, {upper_row}) out of range '
                    f'for tomography stack {i+1}')
        # Extrapolate the center offset linearly to the first and last row
        center_offsets = [lower_center_offset-lower_row*center_slope,
                upper_center_offset+(tomo_stack.shape[0]-1-upper_row)*center_slope]
        t0 = time()
        logger.debug(f'Running _reconstruct_one_tomo_stack on {self.num_core} cores ...')
        tomo_recon_stack = self._reconstruct_one_tomo_stack(tomo_stack, thetas,
                center_offsets=center_offsets, num_core=self.num_core, algorithm='gridrec')
        logger.debug(f'... done in {time()-t0:.2f} seconds')
        logger.info(f'Reconstruction of stack {i+1} took {time()-t0:.2f} seconds')

        # Combine stacks
        tomo_recon_stacks[i] = tomo_recon_stack

    # Resize the reconstructed tomography data
    #   reconstructed data order in each stack: row/z,x,y
    if self.test_mode:
        x_bounds = self.test_config.get('x_bounds')
        y_bounds = self.test_config.get('y_bounds')
        z_bounds = None
    elif self.galaxy_flag:
        if x_bounds is not None and not is_int_pair(x_bounds, ge=0,
                lt=tomo_recon_stacks[0].shape[1]):
            raise ValueError(f'Invalid parameter x_bounds ({x_bounds})')
        # y is the third axis of a reconstructed stack
        if y_bounds is not None and not is_int_pair(y_bounds, ge=0,
                lt=tomo_recon_stacks[0].shape[2]):
            raise ValueError(f'Invalid parameter y_bounds ({y_bounds})')
        z_bounds = None
    else:
        x_bounds, y_bounds, z_bounds = self._resize_reconstructed_data(tomo_recon_stacks)
    if x_bounds is None:
        x_range = (0, tomo_recon_stacks[0].shape[1])
        x_slice = int(x_range[1]/2)
    else:
        x_range = (min(x_bounds), max(x_bounds))
        x_slice = int((x_bounds[0]+x_bounds[1])/2)
    if y_bounds is None:
        y_range = (0, tomo_recon_stacks[0].shape[2])
        y_slice = int(y_range[1]/2)
    else:
        y_range = (min(y_bounds), max(y_bounds))
        y_slice = int((y_bounds[0]+y_bounds[1])/2)
    if z_bounds is None:
        z_range = (0, tomo_recon_stacks[0].shape[0])
        z_slice = int(z_range[1]/2)
    else:
        z_range = (min(z_bounds), max(z_bounds))
        z_slice = int((z_bounds[0]+z_bounds[1])/2)

    # Plot a few reconstructed image slices
    for i, stack in enumerate(tomo_recon_stacks):
        # The title must be built per stack: a title shared by all stacks
        # would make the saved figures overwrite each other
        if num_tomo_stacks == 1:
            basetitle = 'recon'
        else:
            basetitle = f'recon stack {i+1}'
        title = f'{basetitle} {res_title} xslice{x_slice}'
        quick_imshow(stack[z_range[0]:z_range[1],x_slice,y_range[0]:y_range[1]],
                title=title, path=path, save_fig=self.save_figs, save_only=self.save_only,
                block=self.block)
        title = f'{basetitle} {res_title} yslice{y_slice}'
        quick_imshow(stack[z_range[0]:z_range[1],x_range[0]:x_range[1],y_slice],
                title=title, path=path, save_fig=self.save_figs, save_only=self.save_only,
                block=self.block)
        title = f'{basetitle} {res_title} zslice{z_slice}'
        quick_imshow(stack[z_slice,x_range[0]:x_range[1],y_range[0]:y_range[1]],
                title=title, path=path, save_fig=self.save_figs, save_only=self.save_only,
                block=self.block)

    # Save test data to file
    #   reconstructed data order in each stack: row/z,x,y
    if self.test_mode:
        for i, stack in enumerate(tomo_recon_stacks):
            np.savetxt(f'{self.output_folder}/recon_stack_{i+1}.txt',
                    stack[z_slice,x_range[0]:x_range[1],y_range[0]:y_range[1]], fmt='%.6e')

    # Add image reconstruction to reconstructed data NXprocess
    #   reconstructed data order in each stack: row/z,x,y
    nxprocess.data = NXdata()
    nxprocess.attrs['default'] = 'data'
    for k, v in center_info.items():
        nxprocess[k] = v
    if x_bounds is not None:
        nxprocess.x_bounds = x_bounds
    if y_bounds is not None:
        nxprocess.y_bounds = y_bounds
    if z_bounds is not None:
        nxprocess.z_bounds = z_bounds
    nxprocess.data['reconstructed_data'] = np.asarray([stack[z_range[0]:z_range[1],
            x_range[0]:x_range[1],y_range[0]:y_range[1]] for stack in tomo_recon_stacks])
    nxprocess.data.attrs['signal'] = 'reconstructed_data'

    # Create a copy of the input Nexus object and remove reduced data
    exclude_items = [f'{nxentry._name}/reduced_data/data', f'{nxentry._name}/data/reduced_data']
    nxroot_copy = nxcopy(nxroot, exclude_nxpaths=exclude_items)

    # Add the reconstructed data NXprocess to the new Nexus object
    nxentry_copy = nxroot_copy[nxroot_copy.attrs['default']]
    nxentry_copy.reconstructed_data = nxprocess
    if 'data' not in nxentry_copy:
        nxentry_copy.data = NXdata()
    nxentry_copy.attrs['default'] = 'data'
    nxentry_copy.data.makelink(nxprocess.data.reconstructed_data, name='reconstructed_data')
    nxentry_copy.data.attrs['signal'] = 'reconstructed_data'

    return(nxroot_copy)
| 590 | |
def combine_data(self, nxroot, x_bounds=None, y_bounds=None):
    """Combine the reconstructed tomography stacks.

    :param nxroot: Nexus object containing the reconstructed data
    :type nxroot: NXroot
    :param x_bounds: x-range used for the plots and stored with the
        combined data; only used in Galaxy mode, defaults to `None`
        (full range)
    :type x_bounds: pair of ints, optional
    :param y_bounds: y-range, handled like `x_bounds`
    :type y_bounds: pair of ints, optional
    :return: a copy of `nxroot` without the reconstructed data and with
        the combined data added as an NXprocess named `combined_data`,
        or `None` when there is only one stack
    :rtype: NXroot or None
    :raises ValueError: for invalid input parameters
    :raises KeyError: when no reconstructed data is available
    """
    logger.info('Combine the reconstructed tomography stacks')

    if not isinstance(nxroot, NXroot):
        raise ValueError(f'Invalid parameter nxroot ({nxroot})')
    nxentry = nxroot[nxroot.attrs['default']]
    if not isinstance(nxentry, NXentry):
        raise ValueError(f'Invalid nxentry ({nxentry})')

    # Create plot galaxy path directory and path if needed
    if self.galaxy_flag:
        if not os_path.exists('tomo_combine_plots'):
            mkdir('tomo_combine_plots')
        path = 'tomo_combine_plots'
    else:
        path = self.output_folder

    # Check if reconstructed image data is available
    if ('reconstructed_data' not in nxentry or 'reconstructed_data' not in nxentry.data):
        raise KeyError(f'Unable to find valid reconstructed image data in {nxentry}.')

    # Create an NXprocess to store combined image reconstruction (meta)data
    nxprocess = NXprocess()

    # Get the reconstructed data
    #   reconstructed data order: stack,row(z),x,y
    # Note: Nexus cannot follow a link if the data it points to is too big,
    #       so get the data from the actual place, not from nxentry.data
    recon_field = nxentry.reconstructed_data.data.reconstructed_data
    num_tomo_stacks = recon_field.shape[0]
    if num_tomo_stacks == 1:
        logger.info('Only one stack available: leaving combine_data')
        return(None)

    # Combine the reconstructed stacks along the row/z axis
    # (load one stack at a time to reduce risk of hitting Nexus data access limit)
    t0 = time()
    logger.debug(f'Combining the reconstructed stacks ...')
    tomo_recon_combined = np.concatenate(
            [np.asarray(recon_field[i]) for i in range(num_tomo_stacks)])
    logger.debug(f'... done in {time()-t0:.2f} seconds')
    logger.info(f'Combining the reconstructed stacks took {time()-t0:.2f} seconds')

    # Resize the combined tomography data stacks
    #   combined data order: row/z,x,y
    if self.test_mode:
        x_bounds = None
        y_bounds = None
        z_bounds = self.test_config.get('z_bounds')
    elif self.galaxy_flag:
        # Validate against the combined data: x is its second and y its
        # third axis
        if x_bounds is not None and not is_int_pair(x_bounds, ge=0,
                lt=tomo_recon_combined.shape[1]):
            raise ValueError(f'Invalid parameter x_bounds ({x_bounds})')
        if y_bounds is not None and not is_int_pair(y_bounds, ge=0,
                lt=tomo_recon_combined.shape[2]):
            raise ValueError(f'Invalid parameter y_bounds ({y_bounds})')
        z_bounds = None
    else:
        x_bounds, y_bounds, z_bounds = self._resize_reconstructed_data(tomo_recon_combined,
                z_only=True)
    # Order each range so that a reversed pair does not give an empty slice
    if x_bounds is None:
        x_range = (0, tomo_recon_combined.shape[1])
        x_slice = int(x_range[1]/2)
    else:
        x_range = (min(x_bounds), max(x_bounds))
        x_slice = int((x_bounds[0]+x_bounds[1])/2)
    if y_bounds is None:
        y_range = (0, tomo_recon_combined.shape[2])
        y_slice = int(y_range[1]/2)
    else:
        y_range = (min(y_bounds), max(y_bounds))
        y_slice = int((y_bounds[0]+y_bounds[1])/2)
    if z_bounds is None:
        z_range = (0, tomo_recon_combined.shape[0])
        z_slice = int(z_range[1]/2)
    else:
        z_range = (min(z_bounds), max(z_bounds))
        z_slice = int((z_bounds[0]+z_bounds[1])/2)

    # Plot a few combined image slices
    quick_imshow(tomo_recon_combined[z_range[0]:z_range[1],x_slice,y_range[0]:y_range[1]],
            title=f'recon combined xslice{x_slice}', path=path,
            save_fig=self.save_figs, save_only=self.save_only, block=self.block)
    quick_imshow(tomo_recon_combined[z_range[0]:z_range[1],x_range[0]:x_range[1],y_slice],
            title=f'recon combined yslice{y_slice}', path=path,
            save_fig=self.save_figs, save_only=self.save_only, block=self.block)
    quick_imshow(tomo_recon_combined[z_slice,x_range[0]:x_range[1],y_range[0]:y_range[1]],
            title=f'recon combined zslice{z_slice}', path=path,
            save_fig=self.save_figs, save_only=self.save_only, block=self.block)

    # Save test data to file
    #   combined data order: row/z,x,y
    if self.test_mode:
        np.savetxt(f'{self.output_folder}/recon_combined.txt', tomo_recon_combined[
                z_slice,x_range[0]:x_range[1],y_range[0]:y_range[1]], fmt='%.6e')

    # Add image reconstruction to reconstructed data NXprocess
    #   combined data order: row/z,x,y
    nxprocess.data = NXdata()
    nxprocess.attrs['default'] = 'data'
    if x_bounds is not None:
        nxprocess.x_bounds = x_bounds
    if y_bounds is not None:
        nxprocess.y_bounds = y_bounds
    if z_bounds is not None:
        nxprocess.z_bounds = z_bounds
    # NOTE(review): the full combined array is stored here, not the
    # x/y/z-cropped one as in reconstruct_data -- confirm this is intended
    nxprocess.data['combined_data'] = tomo_recon_combined
    nxprocess.data.attrs['signal'] = 'combined_data'

    # Create a copy of the input Nexus object and remove reconstructed data
    exclude_items = [f'{nxentry._name}/reconstructed_data/data',
            f'{nxentry._name}/data/reconstructed_data']
    nxroot_copy = nxcopy(nxroot, exclude_nxpaths=exclude_items)

    # Add the combined data NXprocess to the new Nexus object
    nxentry_copy = nxroot_copy[nxroot_copy.attrs['default']]
    nxentry_copy.combined_data = nxprocess
    if 'data' not in nxentry_copy:
        nxentry_copy.data = NXdata()
    nxentry_copy.attrs['default'] = 'data'
    nxentry_copy.data.makelink(nxprocess.data.combined_data, name='combined_data')
    nxentry_copy.data.attrs['signal'] = 'combined_data'

    return(nxroot_copy)
| 722 | |
def _gen_dark(self, nxentry, reduced_data):
    """Generate dark field.

    The dark field images are taken from the detector data (frames whose
    `image_key` equals 2) when available, otherwise from the dark field scans
    under `nxentry.spec_scans`. A stack of images is reduced to its pixel-wise
    median. Intensities above a cutoff are replaced by the mean of the remaining
    pixels and the result is stored as `reduced_data.data.dark_field`.

    :param nxentry: Nexus entry providing `instrument.detector` and, for the
        fallback path, `spec_scans.dark_field`.
    :param reduced_data: object receiving the dark field; its `data` group is
        (re)created here.
    :return: `reduced_data` with `data['dark_field']` set.
    :raises ValueError: if the dark field data is neither 2D nor 3D.
    """
    # Get the dark field images
    detector = nxentry.instrument.detector
    image_key = detector.get('image_key', None)
    if image_key and 'data' in detector:
        field_indices = [index for index, key in enumerate(image_key) if key == 2]
        tdf_stack = detector.data[field_indices,:,:]
        # RV the default NXtomo form does not accommodate bright or dark field stacks
    else:
        dark_field_scans = nxentry.spec_scans.dark_field
        dark_field = FlatField.construct_from_nxcollection(dark_field_scans)
        prefix = str(detector.local_name)
        tdf_stack = dark_field.get_detector_data(prefix)
        if isinstance(tdf_stack, list):
            assert(len(tdf_stack) == 1) # TODO
            tdf_stack = tdf_stack[0]

    # Take median
    if tdf_stack.ndim == 2:
        # Work on a floating point copy: the cutoff below writes nan into the
        # array, which fails for integer detector data, and the input image
        # must not be modified in place
        tdf = np.array(tdf_stack, dtype=np.float64)
    elif tdf_stack.ndim == 3:
        tdf = np.median(tdf_stack, axis=0)
    else:
        raise ValueError(f'Invalid tdf_stack shape ({tdf_stack.shape})')
    del tdf_stack

    # Remove dark field intensities above the cutoff
    # (cutoff is twice the distance from the minimum to the median)
    #RV tdf_cutoff = None
    tdf_min = tdf.min()
    tdf_cutoff = tdf_min+2*(np.median(tdf)-tdf_min)
    logger.debug(f'tdf_cutoff = {tdf_cutoff}')
    if tdf_cutoff is not None:
        if not is_num(tdf_cutoff, ge=0):
            logger.warning(f'Ignoring illegal value of tdf_cutoff {tdf_cutoff}')
        else:
            tdf[tdf > tdf_cutoff] = np.nan

    # Remove nans
    # (nan and +inf become the mean of the valid pixels, -inf becomes zero)
    tdf_mean = np.nanmean(tdf)
    logger.debug(f'tdf_mean = {tdf_mean}')
    np.nan_to_num(tdf, copy=False, nan=tdf_mean, posinf=tdf_mean, neginf=0.)

    # Plot dark field
    if self.galaxy_flag:
        quick_imshow(tdf, title='dark field', path='tomo_reduce_plots', save_fig=self.save_figs,
                save_only=self.save_only)
    elif not self.test_mode:
        quick_imshow(tdf, title='dark field', path=self.output_folder, save_fig=self.save_figs,
                save_only=self.save_only)
        clear_imshow('dark field')

    # Add dark field to reduced data NXprocess
    reduced_data.data = NXdata()
    reduced_data.data['dark_field'] = tdf

    return(reduced_data)
| 781 | |
def _gen_bright(self, nxentry, reduced_data):
    """Generate bright field.

    The bright field images are taken from the detector data (frames whose
    `image_key` equals 1) when available, otherwise from the bright field scans
    under `nxentry.spec_scans`. A stack of images is reduced to its pixel-wise
    median, the dark field (if present in `reduced_data`) is subtracted and all
    values below one are set to one. The result is stored as
    `reduced_data.data.bright_field`.

    :param nxentry: Nexus entry providing `instrument.detector` and, for the
        fallback path, `spec_scans.bright_field`.
    :param reduced_data: object receiving the bright field; a `data` group is
        created when it does not exist yet.
    :return: `reduced_data` with `data['bright_field']` set.
    :raises ValueError: if the bright field data is neither 2D nor 3D.
    """
    # Get the bright field images
    detector = nxentry.instrument.detector
    image_key = detector.get('image_key', None)
    if image_key and 'data' in detector:
        field_indices = [index for index, key in enumerate(image_key) if key == 1]
        tbf_stack = detector.data[field_indices,:,:]
        # RV the default NXtomo form does not accommodate bright or dark field stacks
    else:
        bright_field_scans = nxentry.spec_scans.bright_field
        bright_field = FlatField.construct_from_nxcollection(bright_field_scans)
        prefix = str(detector.local_name)
        tbf_stack = bright_field.get_detector_data(prefix)
        if isinstance(tbf_stack, list):
            assert(len(tbf_stack) == 1) # TODO
            tbf_stack = tbf_stack[0]

    # Take median if more than one image
    # Median or mean: It may be best to try the median because of some image
    # artifacts that arise due to crinkles in the upstream kapton tape windows
    # causing some phase contrast images to appear on the detector.
    # One thing that also may be useful in a future implementation is to do a
    # brightfield adjustment on EACH frame of the tomo based on a ROI in the
    # corner of the frame where there is no sample but there is the direct X-ray
    # beam because there are frame to frame fluctuations from the incoming beam.
    # We don't typically account for them but potentially could.
    if tbf_stack.ndim == 2:
        # Work on a floating point copy: the in-place dark field subtraction
        # below fails for integer detector data, and the input image must not
        # be modified in place
        tbf = np.array(tbf_stack, dtype=np.float64)
    elif tbf_stack.ndim == 3:
        tbf = np.median(tbf_stack, axis=0)
    else:
        raise ValueError(f'Invalid tbf_stack shape ({tbf_stack.shape})')
    del tbf_stack

    # Subtract dark field
    if 'data' in reduced_data and 'dark_field' in reduced_data.data:
        tbf -= np.asarray(reduced_data.data.dark_field)
    else:
        logger.warning('Dark field unavailable')

    # Set any non-positive values to one
    # (avoid negative bright field values for spikes in dark field)
    tbf[tbf < 1] = 1

    # Plot bright field
    if self.galaxy_flag:
        quick_imshow(tbf, title='bright field', path='tomo_reduce_plots',
                save_fig=self.save_figs, save_only=self.save_only)
    elif not self.test_mode:
        quick_imshow(tbf, title='bright field', path=self.output_folder,
                save_fig=self.save_figs, save_only=self.save_only)
        clear_imshow('bright field')

    # Add bright field to reduced data NXprocess
    if 'data' not in reduced_data:
        reduced_data.data = NXdata()
    reduced_data.data['bright_field'] = tbf

    return(reduced_data)
| 844 | |
def _set_detector_bounds(self, nxentry, reduced_data, img_x_bounds=None):
    """Set vertical detector bounds for each image stack.
    Right now the range is the same for each set in the image stack.

    In test mode the bounds are read from `self.test_config`. Otherwise a
    default range is derived (from a fit to the bright field and/or the
    spacing between the stacks for stations 'id1a3' and 'id3a') and, unless
    running under Galaxy, confirmed or replaced interactively.

    :param nxentry: Nexus entry providing the detector, sample and source
        information.
    :param reduced_data: reduced data object; its `data.bright_field` is read
        for stations 'id1a3' and 'id3a'.
    :param img_x_bounds: optional (lower, upper) index pair along the first
        image axis; validated against the first tomography image.
    :return: the selected (lower, upper) bounds.
    :raises ValueError: for an invalid `img_x_bounds` or when no bounds could
        be selected interactively.
    :raises NotImplementedError: for multiple stacks on any other station.
    """
    if self.test_mode:
        return(tuple(self.test_config['img_x_bounds']))

    # Get the first tomography image and the reference heights
    image_key = nxentry.instrument.detector.get('image_key', None)
    if image_key and 'data' in nxentry.instrument.detector:
        # Tomography frames are those with image_key equal to zero
        field_indices = [index for index, key in enumerate(image_key) if key == 0]
        first_image = np.asarray(nxentry.instrument.detector.data[field_indices[0],:,:])
        theta = float(nxentry.sample.rotation_angle[field_indices[0]])
        z_translation_all = nxentry.sample.z_translation[field_indices]
        # One stack per distinct z translation
        vertical_shifts = sorted(list(set(z_translation_all)))
        num_tomo_stacks = len(vertical_shifts)
    else:
        tomo_field_scans = nxentry.spec_scans.tomo_fields
        tomo_fields = TomoField.construct_from_nxcollection(tomo_field_scans)
        vertical_shifts = tomo_fields.get_vertical_shifts()
        if not isinstance(vertical_shifts, list):
            vertical_shifts = [vertical_shifts]
        prefix = str(nxentry.instrument.detector.local_name)
        t0 = time()
        first_image = tomo_fields.get_detector_data(prefix, tomo_fields.scan_numbers[0], 0)
        logger.debug(f'Getting first image took {time()-t0:.2f} seconds')
        num_tomo_stacks = len(tomo_fields.scan_numbers)
        theta = tomo_fields.theta_range['start']

    # Select image bounds
    # (the +0 presumably turns a rounded -0.0 into 0.0 in the title)
    title = f'tomography image at theta={round(theta, 2)+0}'
    if (img_x_bounds is not None and not is_index_range(img_x_bounds, ge=0,
            le=first_image.shape[0])):
        raise ValueError(f'Invalid parameter img_x_bounds ({img_x_bounds})')
    if nxentry.instrument.source.attrs['station'] in ('id1a3', 'id3a'):
        # NOTE(review): in this branch a caller supplied img_x_bounds is
        # validated above but then always overwritten below -- confirm intent
        pixel_size = nxentry.instrument.detector.x_pixel_size
        # Try to get a fit from the bright field
        tbf = np.asarray(reduced_data.data.bright_field)
        tbf_shape = tbf.shape
        # Intensity profile along the first image axis (summed over axis 1)
        x_sum = np.sum(tbf, 1)
        x_sum_min = x_sum.min()
        x_sum_max = x_sum.max()
        fit = Fit.fit_data(x_sum, 'rectangle', x=np.array(range(len(x_sum))), form='atan',
                guess=True)
        parameters = fit.best_values
        x_low_fit = parameters.get('center1', None)
        x_upp_fit = parameters.get('center2', None)
        sig_low = parameters.get('sigma1', None)
        sig_upp = parameters.get('sigma2', None)
        # Accept the fit only if it succeeded, both edges lie inside the image
        # in the right order and the summed sigmas are less than 10% of the
        # fitted width
        have_fit = fit.success and x_low_fit is not None and x_upp_fit is not None and \
                sig_low is not None and sig_upp is not None and \
                0 <= x_low_fit < x_upp_fit <= x_sum.size and \
                (sig_low+sig_upp)/(x_upp_fit-x_low_fit) < 0.1
        if have_fit:
            # Set a 5% margin on each side
            margin = 0.05*(x_upp_fit-x_low_fit)
            x_low_fit = max(0, x_low_fit-margin)
            x_upp_fit = min(tbf_shape[0], x_upp_fit+margin)
        if num_tomo_stacks == 1:
            if have_fit:
                # Set the default range to enclose the full fitted window
                x_low = int(x_low_fit)
                x_upp = int(x_upp_fit)
            else:
                # Center a default range of 1 mm (RV: can we get this from the slits?)
                # (assumes pixel_size is given in mm -- TODO confirm)
                num_x_min = int((1.0-0.5*pixel_size)/pixel_size)
                x_low = int(0.5*(tbf_shape[0]-num_x_min))
                x_upp = x_low+num_x_min
        else:
            # Get the default range from the reference heights
            # (smallest spacing between consecutive vertical shifts)
            delta_z = vertical_shifts[1]-vertical_shifts[0]
            for i in range(2, num_tomo_stacks):
                delta_z = min(delta_z, vertical_shifts[i]-vertical_shifts[i-1])
            logger.debug(f'delta_z = {delta_z}')
            num_x_min = int((delta_z-0.5*pixel_size)/pixel_size)
            logger.debug(f'num_x_min = {num_x_min}')
            if num_x_min > tbf_shape[0]:
                # NOTE(review): only a warning is issued here; the default
                # range computed below can then extend beyond the image
                logger.warning('Image bounds and pixel size prevent seamless stacking')
            if have_fit:
                # Center the default range relative to the fitted window
                x_low = int(0.5*(x_low_fit+x_upp_fit-num_x_min))
                x_upp = x_low+num_x_min
            else:
                # Center the default range
                x_low = int(0.5*(tbf_shape[0]-num_x_min))
                x_upp = x_low+num_x_min
        if self.galaxy_flag:
            # No user interaction under Galaxy: use the default range
            img_x_bounds = (x_low, x_upp)
        else:
            # Show the default bounds as bright lines on the bright field and
            # on the first tomography image, plus on the summed profile
            tmp = np.copy(tbf)
            tmp_max = tmp.max()
            tmp[x_low,:] = tmp_max
            tmp[x_upp-1,:] = tmp_max
            quick_imshow(tmp, title='bright field')
            tmp = np.copy(first_image)
            tmp_max = tmp.max()
            tmp[x_low,:] = tmp_max
            tmp[x_upp-1,:] = tmp_max
            quick_imshow(tmp, title=title)
            del tmp
            quick_plot((range(x_sum.size), x_sum),
                    ([x_low, x_low], [x_sum_min, x_sum_max], 'r-'),
                    ([x_upp, x_upp], [x_sum_min, x_sum_max], 'r-'),
                    title='sum over theta and y')
            print(f'lower bound = {x_low} (inclusive)')
            # NOTE(review): the trailing ']' in the message below looks like a typo
            print(f'upper bound = {x_upp} (exclusive)]')
            accept = input_yesno('Accept these bounds (y/n)?', 'y')
            clear_imshow('bright field')
            clear_imshow(title)
            clear_plot('sum over theta and y')
            if accept:
                img_x_bounds = (x_low, x_upp)
            else:
                # Let the user draw the range; repeat until exactly one
                # range is returned
                while True:
                    mask, img_x_bounds = draw_mask_1d(x_sum, title='select x data range',
                            legend='sum over theta and y')
                    if len(img_x_bounds) == 1:
                        break
                    else:
                        print(f'Choose a single connected data range')
                img_x_bounds = tuple(img_x_bounds[0])
        # NOTE(review): the +1 treats the bounds as inclusive, whereas the
        # default bounds above are printed as (inclusive, exclusive) -- confirm
        if (num_tomo_stacks > 1 and img_x_bounds[1]-img_x_bounds[0]+1 <
                int((delta_z-0.5*pixel_size)/pixel_size)):
            logger.warning('Image bounds and pixel size prevent seamless stacking')
    else:
        if num_tomo_stacks > 1:
            raise NotImplementedError('Selecting image bounds for multiple stacks on FMB')
        # For FMB: use the first tomography image to select range
        # RV: revisit if they do tomography with multiple stacks
        x_sum = np.sum(first_image, 1)
        x_sum_min = x_sum.min()
        x_sum_max = x_sum.max()
        if self.galaxy_flag:
            # Keep the bounds supplied by the caller, default to the full image
            if img_x_bounds is None:
                img_x_bounds = (0, first_image.shape[0])
        else:
            quick_imshow(first_image, title=title)
            print('Select vertical data reduction range from first tomography image')
            img_x_bounds = select_image_bounds(first_image, 0, title=title)
            clear_imshow(title)
            if img_x_bounds is None:
                raise ValueError('Unable to select image bounds')

    # Plot results
    if self.galaxy_flag:
        path = 'tomo_reduce_plots'
    else:
        path = self.output_folder
    x_low = img_x_bounds[0]
    x_upp = img_x_bounds[1]
    # Mark the selected bounds as bright lines on a copy of the first image
    tmp = np.copy(first_image)
    tmp_max = tmp.max()
    tmp[x_low,:] = tmp_max
    tmp[x_upp-1,:] = tmp_max
    quick_imshow(tmp, title=title, path=path, save_fig=self.save_figs, save_only=self.save_only,
            block=self.block)
    del tmp
    quick_plot((range(x_sum.size), x_sum),
            ([x_low, x_low], [x_sum_min, x_sum_max], 'r-'),
            ([x_upp, x_upp], [x_sum_min, x_sum_max], 'r-'),
            title='sum over theta and y', path=path, save_fig=self.save_figs,
            save_only=self.save_only, block=self.block)

    return(img_x_bounds)
| 1009 | |
def _set_zoom_or_skip(self):
    """Set zoom and/or theta skip to reduce the memory requirement for the analysis.

    The interactive prompts for a zoom percentage and a theta skip interval
    are currently disabled, so both settings are always None.

    :return: the tuple (zoom_perc, num_theta_skip).
    """
    # Interactive selection disabled: neither zooming nor theta skipping is used
    zoom_perc, num_theta_skip = None, None

    logger.debug(f'zoom_perc = {zoom_perc}')
    logger.debug(f'num_theta_skip = {num_theta_skip}')

    settings = (zoom_perc, num_theta_skip)
    return settings
| 1028 | |
def _gen_tomo(self, nxentry, reduced_data):
    """Generate tomography fields.

    Loads the tomography image stacks, crops them to the image bounds stored
    in `reduced_data` (full image by default), normalizes them by the bright
    field, clips the result from below at a small positive cutoff, takes the
    negative logarithm and replaces any remaining non-finite values by zero.
    The reduced stacks are stored as float32 in
    `reduced_data.data.tomo_fields`, together with the rotation angles and
    the horizontal/vertical translations of each stack.

    :param nxentry: Nexus entry providing the detector and sample information
        or, for the fallback path, `spec_scans.tomo_fields`.
    :param reduced_data: reduced data object holding `data.bright_field` and
        optionally `img_x_bounds`/`img_y_bounds`.
    :return: `reduced_data` with the tomography fields added.
    :raises ValueError: if the detector frames of a stack are not stored in
        sequence order.
    """
    # Get full bright field
    tbf = np.asarray(reduced_data.data.bright_field)
    tbf_shape = tbf.shape

    # Get image bounds
    img_x_bounds = tuple(reduced_data.get('img_x_bounds', (0, tbf_shape[0])))
    img_y_bounds = tuple(reduced_data.get('img_y_bounds', (0, tbf_shape[1])))
    # Decide once, against the FULL image shape, whether cropping is needed.
    # (Comparing against the shape of the already cropped bright field would
    # skip the crop of the tomography images for bounds that start at zero.)
    resize = (img_x_bounds != (0, tbf_shape[0]) or img_y_bounds != (0, tbf_shape[1]))

    # The dark field is already subtracted from the bright field; subtracting
    # it from the tomography images is currently disabled
    tdf = None

    # Resize bright field
    if resize:
        tbf = tbf[img_x_bounds[0]:img_x_bounds[1],img_y_bounds[0]:img_y_bounds[1]]

    # Get the tomography images
    image_key = nxentry.instrument.detector.get('image_key', None)
    if image_key and 'data' in nxentry.instrument.detector:
        # Tomography frames are those with image_key equal to zero;
        # one stack per distinct z translation
        field_indices_all = [index for index, key in enumerate(image_key) if key == 0]
        z_translation_all = nxentry.sample.z_translation[field_indices_all]
        z_translation_levels = sorted(list(set(z_translation_all)))
        horizontal_shifts = []
        vertical_shifts = []
        thetas = None
        tomo_stacks = []
        for z_translation in z_translation_levels:
            field_indices = [field_indices_all[index]
                    for index, z in enumerate(z_translation_all) if z == z_translation]
            horizontal_shift = list(set(nxentry.sample.x_translation[field_indices]))
            assert(len(horizontal_shift) == 1)
            horizontal_shifts += horizontal_shift
            vertical_shift = list(set(nxentry.sample.z_translation[field_indices]))
            assert(len(vertical_shift) == 1)
            vertical_shifts += vertical_shift
            sequence_numbers = nxentry.instrument.detector.sequence_number[field_indices]
            if thetas is None:
                thetas = np.asarray(nxentry.sample.rotation_angle[field_indices]) \
                        [sequence_numbers]
            else:
                # All stacks must share the same rotation angles
                assert(all(thetas[i] == nxentry.sample.rotation_angle[field_indices[index]]
                        for i, index in enumerate(sequence_numbers)))
            expected_sequence = list(range(len(sequence_numbers)))
            assert(list(set(sequence_numbers)) == expected_sequence)
            if list(sequence_numbers) == expected_sequence:
                tomo_stack = np.asarray(nxentry.instrument.detector.data[field_indices])
            else:
                raise ValueError('Unable to load the tomography images')
            tomo_stacks.append(tomo_stack)
    else:
        tomo_field_scans = nxentry.spec_scans.tomo_fields
        tomo_fields = TomoField.construct_from_nxcollection(tomo_field_scans)
        horizontal_shifts = tomo_fields.get_horizontal_shifts()
        vertical_shifts = tomo_fields.get_vertical_shifts()
        prefix = str(nxentry.instrument.detector.local_name)
        t0 = time()
        tomo_stacks = tomo_fields.get_detector_data(prefix)
        logger.debug(f'Getting tomography images took {time()-t0:.2f} seconds')
        thetas = np.linspace(tomo_fields.theta_range['start'], tomo_fields.theta_range['end'],
                tomo_fields.theta_range['num'])
        if not isinstance(tomo_stacks, list):
            horizontal_shifts = [horizontal_shifts]
            vertical_shifts = [vertical_shifts]
            tomo_stacks = [tomo_stacks]

    reduced_tomo_stacks = []
    if self.galaxy_flag:
        path = 'tomo_reduce_plots'
    else:
        path = self.output_folder
    for i, tomo_stack in enumerate(tomo_stacks):
        # Resize the tomography images
        # Right now the range is the same for each set in the image stack.
        # The stack is always converted to a float64 copy, so that the
        # in-place numexpr operations below work for integer detector data
        # and never modify the input images.
        t0 = time()
        if resize:
            tomo_stack = tomo_stack[:,img_x_bounds[0]:img_x_bounds[1],
                    img_y_bounds[0]:img_y_bounds[1]]
        tomo_stack = tomo_stack.astype('float64')
        logger.debug(f'Resizing tomography images took {time()-t0:.2f} seconds')

        # NOTE: numexpr resolves the names used in the expressions below
        # (tomo_stack, tdf, tbf, cutoff) from the local variables here

        # Subtract dark field
        if tdf is not None:
            t0 = time()
            with set_numexpr_threads(self.num_core):
                ne.evaluate('tomo_stack-tdf', out=tomo_stack)
            logger.debug(f'Subtracting dark field took {time()-t0:.2f} seconds')

        # Normalize
        t0 = time()
        with set_numexpr_threads(self.num_core):
            ne.evaluate('tomo_stack/tbf', out=tomo_stack, truediv=True)
        logger.debug(f'Normalizing took {time()-t0:.2f} seconds')

        # Remove non-positive values and linearize data
        t0 = time()
        cutoff = 1.e-6
        with set_numexpr_threads(self.num_core):
            ne.evaluate('where(tomo_stack<cutoff, cutoff, tomo_stack)', out=tomo_stack)
        with set_numexpr_threads(self.num_core):
            ne.evaluate('-log(tomo_stack)', out=tomo_stack)
        logger.debug('Removing non-positive values and linearizing data took '+
                f'{time()-t0:.2f} seconds')

        # Get rid of nans/infs that may be introduced by normalization
        # (done in place; a bare np.where() call would discard its result)
        t0 = time()
        np.nan_to_num(tomo_stack, copy=False, nan=0., posinf=0., neginf=0.)
        logger.debug(f'Remove nans/infs took {time()-t0:.2f} seconds')

        # Downsize tomography stack to smaller size
        # TODO zooming and theta skipping are not implemented yet
        tomo_stack = tomo_stack.astype('float32')
        if not self.test_mode:
            if len(tomo_stacks) == 1:
                title = f'red fullres theta {round(thetas[0], 2)+0}'
            else:
                title = f'red stack {i+1} fullres theta {round(thetas[0], 2)+0}'
            quick_imshow(tomo_stack[0,:,:], title=title, path=path, save_fig=self.save_figs,
                    save_only=self.save_only, block=self.block)

        # Save test data to file
        if self.test_mode:
            # Save the central slice along the second axis
            row_index = int(tomo_stack.shape[1]/2)
            np.savetxt(f'{self.output_folder}/red_stack_{i+1}.txt', tomo_stack[:,row_index,:],
                    fmt='%.6e')

        # Combine resized stacks
        reduced_tomo_stacks.append(tomo_stack)

    # Add tomo field info to reduced data NXprocess
    reduced_data['rotation_angle'] = thetas
    reduced_data['x_translation'] = np.asarray(horizontal_shifts)
    reduced_data['z_translation'] = np.asarray(vertical_shifts)
    reduced_data.data['tomo_fields'] = np.asarray(reduced_tomo_stacks)

    if tdf is not None:
        del tdf
    del tbf

    return(reduced_data)
| 1199 | |
def _find_center_one_plane(self, sinogram, row, thetas, eff_pixel_size, cross_sectional_dim,
        path=None, tol=0.1, num_core=1):
    """Find center for a single tomography plane.

    The center is determined with Nghia Vo's method (`tomopy.find_center_vo`),
    the plane is reconstructed at that center and its edges are plotted.
    The result is currently always accepted; the interactive search over a
    range of center offsets at the end of this method is disabled.

    :param sinogram: sinogram of the plane, index order theta,column.
    :param row: row index of the plane, used in log messages and plot titles.
    :param thetas: rotation angles passed on to the reconstruction.
    :param eff_pixel_size: effective pixel size passed on to the reconstruction.
    :param cross_sectional_dim: cross sectional dimension passed on to the
        reconstruction.
    :param path: output path for the edge plots.
    :param tol: currently unused (tolerance of the disabled phase correlation
        search).
    :param num_core: number of cores to use; capped at `num_core_tomopy_limit`
        for `find_center_vo`.
    :return: the center offset relative to the column center, as a float.
    """
    # Try automatic center finding routines for initial value
    # sinogram index order: theta,column
    # need column,theta for iradon, so take transpose
    sinogram = np.asarray(sinogram)
    sinogram_T = sinogram.T
    center = sinogram.shape[1]/2

    # Try using Nghia Vo’s method
    t0 = time()
    num_core_vo = min(num_core, num_core_tomopy_limit)
    logger.debug(f'Running find_center_vo on {num_core_vo} cores ...')
    tomo_center = tomopy.find_center_vo(sinogram, ncore=num_core_vo)
    logger.debug(f'... done in {time()-t0:.2f} seconds')
    logger.info(f'Finding the center using Nghia Vo’s method took {time()-t0:.2f} seconds')
    center_offset_vo = tomo_center-center
    logger.info(f'Center at row {row} using Nghia Vo’s method = {center_offset_vo:.2f}')
    t0 = time()
    # Report the number of cores actually passed on (num_core, not self.num_core)
    logger.debug(f'Running _reconstruct_one_plane on {num_core} cores ...')
    recon_plane = self._reconstruct_one_plane(sinogram_T, tomo_center, thetas,
            eff_pixel_size, cross_sectional_dim, False, num_core)
    logger.debug(f'... done in {time()-t0:.2f} seconds')
    logger.info(f'Reconstructing row {row} took {time()-t0:.2f} seconds')

    title = f'edges row{row} center offset{center_offset_vo:.2f} Vo'
    self._plot_edges_one_plane(recon_plane, title, path=path)

    # NOTE: an optional refinement with the phase correlation method
    # (tomopy.find_center_pc, iterated until the change drops below tol)
    # is disabled.

    # Select center location
    # The prompt to accept the center or continue the search is disabled:
    # the center offset from Vo's method is always accepted.
    if True:
        center_offset = center_offset_vo
        del sinogram_T
        del recon_plane
        return float(center_offset)

    # Perform center finding search
    # (unreachable while the result above is accepted unconditionally)
    while True:
        center_offset_low = input_int('\nEnter lower bound for center offset', ge=-center,
                le=center)
        center_offset_upp = input_int('Enter upper bound for center offset',
                ge=center_offset_low, le=center)
        if center_offset_upp == center_offset_low:
            center_offset_step = 1
        else:
            center_offset_step = input_int('Enter step size for center offset search', ge=1,
                    le=center_offset_upp-center_offset_low)
        num_center_offset = 1+int((center_offset_upp-center_offset_low)/center_offset_step)
        center_offsets = np.linspace(center_offset_low, center_offset_upp, num_center_offset)
        for center_offset in center_offsets:
            # The plane at Vo's center offset is already reconstructed and plotted
            if center_offset == center_offset_vo:
                continue
            t0 = time()
            logger.debug(f'Running _reconstruct_one_plane on {num_core} cores ...')
            recon_plane = self._reconstruct_one_plane(sinogram_T, center_offset+center, thetas,
                    eff_pixel_size, cross_sectional_dim, False, num_core)
            logger.debug(f'... done in {time()-t0:.2f} seconds')
            logger.info(f'Reconstructing center_offset {center_offset} took '+
                    f'{time()-t0:.2f} seconds')
            title = f'edges row{row} center_offset{center_offset:.2f}'
            self._plot_edges_one_plane(recon_plane, title, path=path)
        if input_int('\nContinue (0) or end the search (1)', ge=0, le=1):
            break

    del sinogram_T
    del recon_plane
    center_offset = input_num('    Enter chosen center offset', ge=-center, le=center)
    return float(center_offset)
| 1301 | |
def _reconstruct_one_plane(self, tomo_plane_T, center, thetas, eff_pixel_size,
        cross_sectional_dim, plot_sinogram=True, num_core=1):
    """Invert the sinogram for a single tomography plane.

    The sinogram is trimmed so that the requested center ends up in the
    middle of the retained columns, inverted with `iradon`, smoothed with a
    Gaussian filter and cleaned with tomopy's ring removal.

    :param tomo_plane_T: sinogram of the plane, index order column,theta.
    :param center: center location in columns; must lie inside the sinogram.
    :param thetas: rotation angles passed to `iradon`.
    :param eff_pixel_size: effective pixel size.
    :param cross_sectional_dim: cross sectional dimension; together with the
        pixel size it limits the reconstructed radius.
    :param plot_sinogram: plot the trimmed sinogram (not under Galaxy).
    :param num_core: number of cores used for the ring removal.
    :return: the cleaned reconstruction with a leading axis of length one.
    """
    # tomo_plane_T index order: column,theta
    num_column = tomo_plane_T.shape[0]
    assert(0 <= center < num_column)
    center_offset = center-num_column/2
    two_offset = 2*int(np.round(center_offset))
    two_offset_abs = np.abs(two_offset)
    # 10% slack to avoid edge effects, limited to half the sinogram width
    max_rad = min(int(0.55*(cross_sectional_dim/eff_pixel_size)), 0.5*num_column)
    dist_from_edge = max(1, int(np.floor((num_column-two_offset_abs)/2.)-max_rad))
    # Trim more columns on the side the center is shifted away from
    if two_offset < 0:
        logger.debug(f'sinogram range = [{dist_from_edge}, {two_offset-dist_from_edge}]')
        sinogram = tomo_plane_T[dist_from_edge:two_offset-dist_from_edge,:]
    else:
        logger.debug(f'sinogram range = [{two_offset+dist_from_edge}, {-dist_from_edge}]')
        sinogram = tomo_plane_T[two_offset+dist_from_edge:-dist_from_edge,:]
    if plot_sinogram and not self.galaxy_flag:
        quick_imshow(sinogram.T, f'sinogram center offset{center_offset:.2f}', aspect='auto',
                path=self.output_folder, save_fig=self.save_figs, save_only=self.save_only,
                block=self.block)

    # Inverting sinogram
    t_start = time()
    recon_sinogram = iradon(sinogram, theta=thetas, circle=True)
    logger.debug(f'Inverting sinogram took {time()-t_start:.2f} seconds')
    del sinogram

    # Performing Gaussian filtering and removing ring artifacts
    # Start from the defaults; user supplied values are validated below
    sigma = 1.0
    ring_width = 15
    recon_parameters = None#self.config.get('recon_parameters')
    if recon_parameters is not None:
        sigma = recon_parameters.get('gaussian_sigma', 1.0)
        if not is_num(sigma, ge=0.0):
            logger.warning(f'Invalid gaussian_sigma ({sigma}) in _reconstruct_one_plane, '+
                    'set to a default value of 1.0')
            sigma = 1.0
        ring_width = recon_parameters.get('ring_width', 15)
        if not is_int(ring_width, ge=0):
            logger.warning(f'Invalid ring_width ({ring_width}) in _reconstruct_one_plane, '+
                    'set to a default value of 15')
            ring_width = 15
    t_start = time()
    recon_sinogram = spi.gaussian_filter(recon_sinogram, sigma, mode='nearest')
    # remove_ring is called on a stack, so add a leading axis of length one
    recon_clean = np.expand_dims(recon_sinogram, axis=0)
    del recon_sinogram
    recon_clean = tomopy.misc.corr.remove_ring(recon_clean, rwidth=ring_width, ncore=num_core)
    logger.debug(f'Filtering and removing ring artifacts took {time()-t_start:.2f} seconds')

    return recon_clean
| 1356 | |
def _plot_edges_one_plane(self, recon_plane, title, path=None):
    """Denoise a reconstructed plane and display it with two colormaps.

    The plane is passed through `denoise_tv_chambolle` and the first slice
    of the result is shown twice with `quick_imshow`: once with the
    'coolwarm' colormap on its default color scale and once with the 'gray'
    colormap on a color scale symmetric around zero.

    :param recon_plane: reconstructed data; the denoised result is indexed
        as [0,:,:], so a leading axis of length >= 1 is required
    :param title: figure title prefix (' coolwarm' / ' gray' is appended)
    :param path: output folder for the figures, defaults to
        self.output_folder when None
    """
    # Denoising weight (the configuration lookup is currently disabled)
    vis_parameters = None#self.config.get('vis_parameters')
    weight = 0.1
    if vis_parameters is not None:
        weight = vis_parameters.get('denoise_weight', 0.1)
        if not is_num(weight, ge=0.0):
            logger.warning(f'Invalid weight ({weight}) in _plot_edges_one_plane, '+
                    'set to a default value of 0.1')
            weight = 0.1

    denoised = denoise_tv_chambolle(recon_plane, weight=weight)
    first_plane = denoised[0,:,:]

    # Symmetric color limits for the gray scale image
    upper = np.max(first_plane)
    lower = -upper

    out_path = self.output_folder if path is None else path
    quick_imshow(first_plane, f'{title} coolwarm', path=out_path, cmap='coolwarm',
            save_fig=self.save_figs, save_only=self.save_only, block=self.block)
    quick_imshow(first_plane, f'{title} gray', path=out_path, cmap='gray', vmin=lower,
            vmax=upper, save_fig=self.save_figs, save_only=self.save_only, block=self.block)
| 1377 | |
def _reconstruct_one_tomo_stack(self, tomo_stack, thetas, center_offsets=None, num_core=1,
        algorithm='gridrec'):
    """Reconstruct a single tomography stack.

    Removes stripes from the sinograms, runs a tomopy reconstruction
    (optionally followed by secondary ASTRA/SART iterations) and finally
    removes ring artifacts from the reconstructed stack.

    :param tomo_stack: tomography stack in sinogram order: row,theta,column
    :param thetas: rotation angles in degrees
    :param center_offsets: tomography axis shift in pixels relative to the
        column center: empty/None for no shift, two values to linearly
        interpolate between the first and the last row, or one value per row
    :param num_core: number of cores passed to tomopy
    :param algorithm: tomopy algorithm for the initial reconstruction
    :raises ValueError: if center_offsets has neither 0, 2 nor
        tomo_stack.shape[0] entries
    :return: the reconstructed stack as returned by tomopy.recon, with ring
        artifacts removed in place
    """
    # RV should we remove stripes?
    # https://tomopy.readthedocs.io/en/latest/api/tomopy.prep.stripe.html
    # RV should we remove rings?
    # https://tomopy.readthedocs.io/en/latest/api/tomopy.misc.corr.html
    # RV: Add an option to do (extra) secondary iterations later or to do some sort of convergence test?

    # Work on a private float copy: the caller's offsets must not be shifted in place and
    # plain lists/tuples as well as integer arrays must be accepted
    offsets = np.array([] if center_offsets is None else center_offsets, dtype=float)
    num_rows = tomo_stack.shape[0]
    if not offsets.size:
        centers = np.zeros(num_rows)
    elif offsets.size == 2:
        centers = np.linspace(offsets[0], offsets[1], num_rows)
    else:
        if offsets.size != num_rows:
            raise ValueError('center_offsets dimension mismatch in reconstruct_one_tomo_stack')
        centers = offsets
    # Convert the offsets relative to the column center into absolute column positions
    centers = centers+tomo_stack.shape[2]/2

    # Get reconstruction parameters (the configuration lookup is currently disabled)
    recon_parameters = None#self.config.get('recon_parameters')
    sigma = 2.0
    secondary_iters = 0
    ring_width = 15
    if recon_parameters is not None:
        sigma = recon_parameters.get('stripe_fw_sigma', 2.0)
        if not is_num(sigma, ge=0):
            logger.warning(f'Invalid stripe_fw_sigma ({sigma}) in '+
                    '_reconstruct_one_tomo_stack, set to a default value of 2.0')
            sigma = 2.0
        secondary_iters = recon_parameters.get('secondary_iters', 0)
        if not is_int(secondary_iters, ge=0):
            logger.warning(f'Invalid secondary_iters ({secondary_iters}) in '+
                    '_reconstruct_one_tomo_stack, set to a default value of 0 (skip them)')
            secondary_iters = 0
        ring_width = recon_parameters.get('ring_width', 15)
        if not is_int(ring_width, ge=0):
            logger.warning(f'Invalid ring_width ({ring_width}) in '+
                    '_reconstruct_one_tomo_stack, set to a default value of 15')
            ring_width = 15

    # Remove horizontal stripe (never on more than num_core_tomopy_limit cores)
    t0 = time()
    stripe_cores = min(num_core, num_core_tomopy_limit)
    logger.debug(f'Running remove_stripe_fw on {stripe_cores} cores ...')
    tomo_stack = tomopy.prep.stripe.remove_stripe_fw(tomo_stack, sigma=sigma,
            ncore=stripe_cores)
    logger.debug(f'... tomopy.prep.stripe.remove_stripe_fw took {time()-t0:.2f} seconds')

    # Perform initial image reconstruction
    logger.debug('Performing initial image reconstruction')
    t0 = time()
    logger.debug(f'Running recon on {num_core} cores ...')
    tomo_recon_stack = tomopy.recon(tomo_stack, np.radians(thetas), centers,
            sinogram_order=True, algorithm=algorithm, ncore=num_core)
    logger.debug(f'... done in {time()-t0:.2f} seconds')
    logger.info(f'Performing initial image reconstruction took {time()-t0:.2f} seconds')

    # Run optional secondary iterations
    if secondary_iters > 0:
        logger.debug(f'Running {secondary_iters} secondary iterations')
        #options = {'method':'SIRT_CUDA', 'proj_type':'cuda', 'num_iter':secondary_iters}
        #RV: doesn't work for me:
        #"Error: CUDA error 803: system has unsupported display driver/cuda driver combination."
        #options = {'method':'SIRT', 'proj_type':'linear', 'MinConstraint': 0, 'num_iter':secondary_iters}
        #SIRT did not finish while running overnight
        #options = {'method':'SART', 'proj_type':'linear', 'num_iter':secondary_iters}
        options = {'method':'SART', 'proj_type':'linear', 'MinConstraint': 0,
                'num_iter':secondary_iters}
        t0 = time()
        logger.debug(f'Running recon on {num_core} cores ...')
        tomo_recon_stack = tomopy.recon(tomo_stack, np.radians(thetas), centers,
                init_recon=tomo_recon_stack, options=options, sinogram_order=True,
                algorithm=tomopy.astra, ncore=num_core)
        logger.debug(f'... done in {time()-t0:.2f} seconds')
        logger.info(f'Performing secondary iterations took {time()-t0:.2f} seconds')

    # Remove ring artifacts (in place, the result is written back into tomo_recon_stack)
    t0 = time()
    tomopy.misc.corr.remove_ring(tomo_recon_stack, rwidth=ring_width, out=tomo_recon_stack,
            ncore=num_core)
    logger.debug(f'Removing ring artifacts took {time()-t0:.2f} seconds')

    return tomo_recon_stack
| 1470 | |
| 1471 def _resize_reconstructed_data(self, data, z_only=False): | |
| 1472 """Resize the reconstructed tomography data. | |
| 1473 """ | |
| 1474 # Data order: row(z),x,y or stack,row(z),x,y | |
| 1475 if isinstance(data, list): | |
| 1476 for stack in data: | |
| 1477 assert(stack.ndim == 3) | |
| 1478 num_tomo_stacks = len(data) | |
| 1479 tomo_recon_stacks = data | |
| 1480 else: | |
| 1481 assert(data.ndim == 3) | |
| 1482 num_tomo_stacks = 1 | |
| 1483 tomo_recon_stacks = [data] | |
| 1484 | |
| 1485 if z_only: | |
| 1486 x_bounds = None | |
| 1487 y_bounds = None | |
| 1488 else: | |
| 1489 # Selecting x bounds (in yz-plane) | |
| 1490 tomosum = 0 | |
| 1491 [tomosum := tomosum+np.sum(tomo_recon_stacks[i], axis=(0,2)) | |
| 1492 for i in range(num_tomo_stacks)] | |
| 1493 select_x_bounds = input_yesno('\nDo you want to change the image x-bounds (y/n)?', 'y') | |
| 1494 if not select_x_bounds: | |
| 1495 x_bounds = None | |
| 1496 else: | |
| 1497 accept = False | |
| 1498 index_ranges = None | |
| 1499 while not accept: | |
| 1500 mask, x_bounds = draw_mask_1d(tomosum, current_index_ranges=index_ranges, | |
| 1501 title='select x data range', legend='recon stack sum yz') | |
| 1502 while len(x_bounds) != 1: | |
| 1503 print('Please select exactly one continuous range') | |
| 1504 mask, x_bounds = draw_mask_1d(tomosum, title='select x data range', | |
| 1505 legend='recon stack sum yz') | |
| 1506 x_bounds = x_bounds[0] | |
| 1507 # quick_plot(tomosum, vlines=x_bounds, title='recon stack sum yz') | |
| 1508 # print(f'x_bounds = {x_bounds} (lower bound inclusive, upper bound '+ | |
| 1509 # 'exclusive)') | |
| 1510 # accept = input_yesno('Accept these bounds (y/n)?', 'y') | |
| 1511 accept = True | |
| 1512 logger.debug(f'x_bounds = {x_bounds}') | |
| 1513 | |
| 1514 # Selecting y bounds (in xz-plane) | |
| 1515 tomosum = 0 | |
| 1516 [tomosum := tomosum+np.sum(tomo_recon_stacks[i], axis=(0,1)) | |
| 1517 for i in range(num_tomo_stacks)] | |
| 1518 select_y_bounds = input_yesno('\nDo you want to change the image y-bounds (y/n)?', 'y') | |
| 1519 if not select_y_bounds: | |
| 1520 y_bounds = None | |
| 1521 else: | |
| 1522 accept = False | |
| 1523 index_ranges = None | |
| 1524 while not accept: | |
| 1525 mask, y_bounds = draw_mask_1d(tomosum, current_index_ranges=index_ranges, | |
| 1526 title='select x data range', legend='recon stack sum xz') | |
| 1527 while len(y_bounds) != 1: | |
| 1528 print('Please select exactly one continuous range') | |
| 1529 mask, y_bounds = draw_mask_1d(tomosum, title='select x data range', | |
| 1530 legend='recon stack sum xz') | |
| 1531 y_bounds = y_bounds[0] | |
| 1532 # quick_plot(tomosum, vlines=y_bounds, title='recon stack sum xz') | |
| 1533 # print(f'y_bounds = {y_bounds} (lower bound inclusive, upper bound '+ | |
| 1534 # 'exclusive)') | |
| 1535 # accept = input_yesno('Accept these bounds (y/n)?', 'y') | |
| 1536 accept = True | |
| 1537 logger.debug(f'y_bounds = {y_bounds}') | |
| 1538 | |
| 1539 # Selecting z bounds (in xy-plane) (only valid for a single image stack) | |
| 1540 if num_tomo_stacks != 1: | |
| 1541 z_bounds = None | |
| 1542 else: | |
| 1543 tomosum = 0 | |
| 1544 [tomosum := tomosum+np.sum(tomo_recon_stacks[i], axis=(1,2)) | |
| 1545 for i in range(num_tomo_stacks)] | |
| 1546 select_z_bounds = input_yesno('Do you want to change the image z-bounds (y/n)?', 'n') | |
| 1547 if not select_z_bounds: | |
| 1548 z_bounds = None | |
| 1549 else: | |
| 1550 accept = False | |
| 1551 index_ranges = None | |
| 1552 while not accept: | |
| 1553 mask, z_bounds = draw_mask_1d(tomosum, current_index_ranges=index_ranges, | |
| 1554 title='select x data range', legend='recon stack sum xy') | |
| 1555 while len(z_bounds) != 1: | |
| 1556 print('Please select exactly one continuous range') | |
| 1557 mask, z_bounds = draw_mask_1d(tomosum, title='select x data range', | |
| 1558 legend='recon stack sum xy') | |
| 1559 z_bounds = z_bounds[0] | |
| 1560 # quick_plot(tomosum, vlines=z_bounds, title='recon stack sum xy') | |
| 1561 # print(f'z_bounds = {z_bounds} (lower bound inclusive, upper bound '+ | |
| 1562 # 'exclusive)') | |
| 1563 # accept = input_yesno('Accept these bounds (y/n)?', 'y') | |
| 1564 accept = True | |
| 1565 logger.debug(f'z_bounds = {z_bounds}') | |
| 1566 | |
| 1567 return(x_bounds, y_bounds, z_bounds) | |
| 1568 | |
| 1569 | |
def run_tomo(input_file:str, output_file:str, modes:list[str], center_file=None, num_core=-1,
        output_folder='.', save_figs='no', test_mode=False) -> None:
    """Run the requested steps of the tomography workflow.

    The steps run in a fixed order (reduce_data, find_center,
    reconstruct_data, combine_data), each one only when it is listed in
    modes or when modes contains 'all'. The result is written to
    output_file unless test_mode is set.

    :param input_file: file read with Tomo.read as the workflow input
    :param output_file: file the final result is written to
    :param modes: list of steps to run (a single string is also accepted),
        None selects ['all']
    :param center_file: file read with Tomo.read for the center data when
        reconstructing without running find_center in the same call
    :param num_core: passed on to the Tomo constructor
    :param output_folder: passed on to the Tomo constructor; in test mode
        the log file tomo.log is also created here
    :param save_figs: passed on to the Tomo constructor
    :param test_mode: log to <output_folder>/tomo.log and skip writing the
        output file
    :raises ValueError: if modes contains an unknown step
    """
    if test_mode:
        logging_format = '%(asctime)s : %(levelname)s - %(module)s : %(funcName)s - %(message)s'
        logging.basicConfig(filename=f'{output_folder}/tomo.log', filemode='w',
                format=logging_format, level=logging.INFO, force=True)
    logger.info(f'input_file = {input_file}')
    logger.info(f'center_file = {center_file}')
    logger.info(f'output_file = {output_file}')
    logger.debug(f'modes= {modes}')
    logger.debug(f'num_core= {num_core}')
    logger.info(f'output_folder = {output_folder}')
    logger.info(f'save_figs = {save_figs}')
    logger.info(f'test_mode = {test_mode}')

    # Check for correction modes
    legal_modes = ['reduce_data', 'find_center', 'reconstruct_data', 'combine_data', 'all']
    if modes is None:
        modes = ['all']
    elif isinstance(modes, str):
        # A bare string would otherwise be checked character by character
        modes = [modes]
    if not all(mode in legal_modes for mode in modes):
        raise ValueError(f'Invalid parameter modes ({modes})')
    run_all = 'all' in modes

    # Instantiate Tomo object
    tomo = Tomo(num_core=num_core, output_folder=output_folder, save_figs=save_figs,
            test_mode=test_mode)

    # Read input file
    data = tomo.read(input_file)

    # Generate reduced tomography images
    if run_all or 'reduce_data' in modes:
        data = tomo.gen_reduced_data(data)

    # Find rotation axis centers for the tomography stacks.
    center_data = None
    if run_all or 'find_center' in modes:
        center_data = tomo.find_centers(data)

    # Reconstruct tomography stacks
    if run_all or 'reconstruct_data' in modes:
        if center_data is None:
            # No centers were found in this call: read them from the center file
            center_data = tomo.read(center_file)
        data = tomo.reconstruct_data(data, center_data)
        # The centers are consumed by the reconstruction: write data, not center_data, below
        center_data = None

    # Combine reconstructed tomography stacks
    if run_all or 'combine_data' in modes:
        data = tomo.combine_data(data)

    # Write output file (the center data when find_center was the last step that ran)
    if data is not None and not test_mode:
        if center_data is None:
            tomo.write(data, output_file)
        else:
            tomo.write(center_data, output_file)

    logger.info(f'Completed modes: {modes}')
