comparison workflow/run_tomo.py @ 0:98e23dff1de2 draft default tip

planemo upload for repository https://github.com/rolfverberg/galaxytools commit f8c4bdb31c20c468045ad5e6eb255a293244bc6c-dirty
author rv43
date Tue, 21 Mar 2023 16:22:42 +0000
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:98e23dff1de2
1 #!/usr/bin/env python3
2
3 import logging
4 logger = logging.getLogger(__name__)
5
6 import numpy as np
7 try:
8 import numexpr as ne
9 except:
10 pass
11 try:
12 import scipy.ndimage as spi
13 except:
14 pass
15
16 from multiprocessing import cpu_count
17 from nexusformat.nexus import *
18 from os import mkdir
19 from os import path as os_path
20 try:
21 from skimage.transform import iradon
22 except:
23 pass
24 try:
25 from skimage.restoration import denoise_tv_chambolle
26 except:
27 pass
28 from time import time
29 try:
30 import tomopy
31 except:
32 pass
33 from yaml import safe_load, safe_dump
34
35 try:
36 from msnctools.fit import Fit
37 except:
38 from fit import Fit
39 try:
40 from msnctools.general import illegal_value, is_int, is_int_pair, is_num, is_index_range, \
41 input_int, input_num, input_yesno, input_menu, draw_mask_1d, select_image_bounds, \
42 select_one_image_bound, clear_imshow, quick_imshow, clear_plot, quick_plot
43 except:
44 from general import illegal_value, is_int, is_int_pair, is_num, is_index_range, \
45 input_int, input_num, input_yesno, input_menu, draw_mask_1d, select_image_bounds, \
46 select_one_image_bound, clear_imshow, quick_imshow, clear_plot, quick_plot
47
48 try:
49 from workflow.models import import_scanparser, FlatField, TomoField, TomoWorkflow
50 from workflow.__version__ import __version__
51 except:
52 pass
53
# Presumably an upper limit on the number of cores handed to tomopy (inferred
# from the name only; it is not referenced in this part of the file -- TODO confirm).
num_core_tomopy_limit = 24
55
def nxcopy(nxobject: NXobject, exclude_nxpaths: list[str] = None,
        nxpath_prefix: str = '') -> NXobject:
    '''Return a copy of a nexus object, optionally excluding certain child items.

    :param nxobject: the original nexus object to return a "copy" of
    :type nxobject: nexusformat.nexus.NXobject
    :param exclude_nxpaths: a list of paths (relative to `nxobject`) to child
        nexus objects that should be excluded from the returned "copy",
        defaults to `None` (nothing is excluded)
    :type exclude_nxpaths: list[str], optional
    :param nxpath_prefix: For use in recursive calls from inside this
        function only!
    :type nxpath_prefix: str
    :return: a copy of `nxobject` with some children optionally excluded.
    :rtype: NXobject
    '''
    # Avoid a shared mutable default argument
    if exclude_nxpaths is None:
        exclude_nxpaths = []

    nxobject_copy = nxobject.__class__()
    # At the top level only the 'default' attribute is carried over;
    # nested groups get all of their attributes copied.
    if not nxpath_prefix:
        if 'default' in nxobject.attrs:
            nxobject_copy.attrs['default'] = nxobject.attrs['default']
    else:
        for k, v in nxobject.attrs.items():
            nxobject_copy.attrs[k] = v

    for k, v in nxobject.items():
        nxpath = os_path.join(nxpath_prefix, k)
        if nxpath in exclude_nxpaths:
            continue
        if isinstance(v, NXgroup):
            # Recurse so that excluded paths deeper in the tree are honoured
            nxobject_copy[k] = nxcopy(v, exclude_nxpaths=exclude_nxpaths,
                    nxpath_prefix=nxpath)
        else:
            # Non-group children are assigned directly (no explicit copy here)
            nxobject_copy[k] = v

    return nxobject_copy
92
class set_numexpr_threads:
    """Context manager that temporarily sets the number of numexpr threads.

    On entry the numexpr thread count is set to `num_core`; on exit the
    previous setting is restored.

    :param num_core: requested number of threads; None, values below one or
        values above the number of available processors select all
        available processors
    :type num_core: int, optional
    """

    def __init__(self, num_core):
        if num_core is None or num_core < 1 or num_core > cpu_count():
            self.num_core = cpu_count()
        else:
            self.num_core = num_core

    def __enter__(self):
        # ne.set_num_threads returns the previous setting, restored in __exit__
        self.num_core_org = ne.set_num_threads(self.num_core)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        ne.set_num_threads(self.num_core_org)
        # Never suppress an exception raised inside the with-block
        return False
106
107 class Tomo:
108 """Processing tomography data with misalignment.
109 """
def __init__(self, galaxy_flag=False, num_core=-1, output_folder='.', save_figs=None,
        test_mode=False):
    """Initialize the tomography processor.

    :param galaxy_flag: run in Galaxy mode, in which `output_folder`,
        `test_mode` and `save_figs` are ignored and figures are only saved,
        defaults to False
    :type galaxy_flag: bool
    :param num_core: number of processors to use, -1 for all available
        processors; larger values are reduced to the number of available
        processors, defaults to -1
    :type num_core: int
    :param output_folder: folder for output files, created (including any
        missing parent folders) when it does not exist, defaults to '.'
    :type output_folder: str
    :param save_figs: 'yes', 'no' or 'only', defaults to 'no' (forced to
        'only' in galaxy or test mode)
    :type save_figs: str, optional
    :param test_mode: run in test mode, defaults to False
    :type test_mode: bool
    :raises ValueError: for an invalid galaxy_flag, test_mode, save_figs or
        num_core
    """
    if not isinstance(galaxy_flag, bool):
        raise ValueError(f'Invalid parameter galaxy_flag ({galaxy_flag})')
    self.galaxy_flag = galaxy_flag
    self.num_core = num_core
    if self.galaxy_flag:
        if output_folder != '.':
            logger.warning('Ignoring output_folder in galaxy mode')
        self.output_folder = '.'
        if test_mode != False:
            logger.warning('Ignoring test_mode in galaxy mode')
        self.test_mode = False
        if save_figs is not None:
            logger.warning('Ignoring save_figs in galaxy mode')
        save_figs = 'only'
    else:
        self.output_folder = os_path.abspath(output_folder)
        if not isinstance(test_mode, bool):
            raise ValueError(f'Invalid parameter test_mode ({test_mode})')
        self.test_mode = test_mode
        if save_figs is None:
            save_figs = 'no'
    self.test_config = {}
    if self.test_mode:
        if save_figs != 'only':
            logger.warning('Ignoring save_figs in test mode')
        save_figs = 'only'
    if save_figs == 'only':
        self.save_only = True
        self.save_figs = True
    elif save_figs == 'yes':
        self.save_only = False
        self.save_figs = True
    elif save_figs == 'no':
        self.save_only = False
        self.save_figs = False
    else:
        raise ValueError(f'Invalid parameter save_figs ({save_figs})')
    # Only block on figures when they are not save-only
    self.block = not self.save_only
    if self.num_core == -1:
        self.num_core = cpu_count()
    if not is_int(self.num_core, gt=0, log=False):
        raise ValueError(f'Invalid parameter num_core ({num_core})')
    if self.num_core > cpu_count():
        logger.warning(f'num_core = {self.num_core} is larger than the number of available '
                f'processors and reduced to {cpu_count()}')
        self.num_core = cpu_count()

    # Create the output folder only after all parameters are validated, so an
    # invalid call leaves nothing behind; makedirs also creates missing parents
    if not self.galaxy_flag:
        from os import makedirs
        makedirs(self.output_folder, exist_ok=True)
165
def read(self, filename):
    """Read a YAML configuration file or a NeXus file.

    In galaxy mode the filename extension is not relied upon: the file is
    first parsed as YAML and, if that fails, opened as a NeXus file.
    Otherwise the format is selected by the filename extension.

    :param filename: path of the file to read
    :type filename: str
    :return: the parsed YAML content or the root of the NeXus tree
    :raises ValueError: if the file cannot be opened in galaxy mode, or has
        an unsupported extension otherwise
    """
    logger.info(f'looking for {filename}')
    if self.galaxy_flag:
        try:
            with open(filename, 'r') as f:
                return safe_load(f)
        except Exception:
            # Not a readable YAML file: fall back to NeXus below
            pass
        try:
            with NXFile(filename, mode='r') as nxfile:
                return nxfile.readfile()
        except Exception as exc:
            raise ValueError(f'Unable to open ({filename})') from exc
    extension = os_path.splitext(filename)[1]
    if extension in ('.yml', '.yaml'):
        with open(filename, 'r') as f:
            return safe_load(f)
    # Accept the same NeXus extensions that `write` can produce
    if extension in ('.nxs', '.nex'):
        with NXFile(filename, mode='r') as nxfile:
            return nxfile.readfile()
    raise ValueError(f'Invalid filename extension ({extension})')
196
def write(self, data, filename):
    """Write `data` to `filename`, selecting the format by the extension.

    :param data: a YAML-serializable object for '.yml'/'.yaml', an object
        with a nexusformat-style `save` method for '.nxs'/'.nex', or an
        object with a `to_netcdf` method (presumably an xarray object) for '.nc'
    :param filename: the output path
    :type filename: str
    :raises ValueError: for an unsupported filename extension
    """
    extension = os_path.splitext(filename)[1]
    if extension in ('.yml', '.yaml'):
        with open(filename, 'w') as f:
            safe_dump(data, f)
    elif extension in ('.nxs', '.nex'):
        data.save(filename, mode='w')
    elif extension == '.nc':
        # xarray's to_netcdf takes the target file as its `path` argument
        data.to_netcdf(path=filename)
    else:
        raise ValueError(f'Invalid filename extension ({extension})')
208
def gen_reduced_data(self, data, img_x_bounds=None):
    """Generate the reduced tomography images.

    :param data: either a workflow configuration dictionary (converted to a
        NeXus tree through `TomoWorkflow`) or an existing `NXroot`
    :param img_x_bounds: optional vertical detector bounds for the image
        stack, passed on to `_set_detector_bounds`
    :return: a NeXus tree whose default entry holds the new `reduced_data`
        NXprocess (for NXroot input: a copy without the raw and any earlier
        reduced data)
    :rtype: NXroot
    :raises ValueError: if `data` is neither a dict nor an NXroot, or holds
        more than one sample map
    """
    logger.info('Generate the reduced tomography images')

    # Create plot galaxy path directory if needed
    if self.galaxy_flag and not os_path.exists('tomo_reduce_plots'):
        mkdir('tomo_reduce_plots')

    if isinstance(data, dict):
        # Create Nexus format object from input dictionary
        wf = TomoWorkflow(**data)
        if len(wf.sample_maps) > 1:
            raise ValueError('Multiple sample maps not yet implemented')
        nxroot = NXroot()
        t0 = time()
        for sample_map in wf.sample_maps:
            logger.info(f'Start constructing the {sample_map.title} map.')
            import_scanparser(sample_map.station)
            sample_map.construct_nxentry(nxroot, include_raw_data=False)
        logger.info(f'Constructed all sample maps in {time()-t0:.2f} seconds.')
        nxentry = nxroot[nxroot.attrs['default']]
        # Get test mode configuration info
        if self.test_mode:
            self.test_config = data['sample_maps'][0]['test_mode']
    elif isinstance(data, NXroot):
        nxentry = data[data.attrs['default']]
    else:
        raise ValueError(f'Invalid parameter data ({data})')

    # Create an NXprocess to store data reduction (meta)data
    reduced_data = NXprocess()

    # Generate dark field
    if 'dark_field' in nxentry['spec_scans']:
        reduced_data = self._gen_dark(nxentry, reduced_data)

    # Generate bright field
    reduced_data = self._gen_bright(nxentry, reduced_data)

    # Set vertical detector bounds for image stack
    img_x_bounds = self._set_detector_bounds(nxentry, reduced_data, img_x_bounds=img_x_bounds)
    logger.info(f'img_x_bounds = {img_x_bounds}')
    reduced_data['img_x_bounds'] = img_x_bounds

    # Set zoom and/or theta skip to reduce the memory requirement
    zoom_perc, num_theta_skip = self._set_zoom_or_skip()
    if zoom_perc is not None:
        reduced_data.attrs['zoom_perc'] = zoom_perc
    if num_theta_skip is not None:
        reduced_data.attrs['num_theta_skip'] = num_theta_skip

    # Generate reduced tomography fields
    reduced_data = self._gen_tomo(nxentry, reduced_data)

    # Create a copy of the input Nexus object and remove raw and any existing reduced data
    # (including the data/reduced_data link, which would otherwise point into the
    # removed group and clash with the new link created below)
    if isinstance(data, NXroot):
        exclude_items = [f'{nxentry._name}/{nxpath}' for nxpath in (
                'reduced_data/data',
                'instrument/detector/data',
                'instrument/detector/image_key',
                'instrument/detector/sequence_number',
                'sample/rotation_angle',
                'sample/x_translation',
                'sample/z_translation',
                'data/data',
                'data/image_key',
                'data/reduced_data',
                'data/rotation_angle',
                'data/x_translation',
                'data/z_translation')]
        nxroot = nxcopy(data, exclude_nxpaths=exclude_items)
        nxentry = nxroot[nxroot.attrs['default']]

    # Add the reduced data NXprocess
    nxentry.reduced_data = reduced_data

    if 'data' not in nxentry:
        nxentry.data = NXdata()
        nxentry.attrs['default'] = 'data'
    nxentry.data.makelink(nxentry.reduced_data.data.tomo_fields, name='reduced_data')
    nxentry.data.makelink(nxentry.reduced_data.rotation_angle, name='rotation_angle')
    nxentry.data.attrs['signal'] = 'reduced_data'

    return nxroot
293
def find_centers(self, nxroot, center_rows=None, center_stack_index=None):
    """Find the calibrated center axis info.

    The rotation-axis center offset is determined at a lower and an upper
    detector row of one reduced tomography stack.

    :param nxroot: NeXus tree holding reduced tomography data
    :type nxroot: NXroot
    :param center_rows: galaxy mode only: the rows at which to find the
        centers (lower, upper); -1 for an entry selects the first/last row
    :type center_rows: tuple[int, int], optional
    :param center_stack_index: galaxy mode only: zero-based index of the stack
        used to calibrate the center axis, defaults to the middle stack
    :type center_stack_index: int, optional
    :return: the lower/upper rows and center offsets, plus the one-based
        'center_stack_index' when there is more than one stack
    :rtype: dict
    :raises ValueError: for invalid input parameters
    :raises KeyError: when no valid reduced data is available
    """
    logger.info('Find the calibrated center axis info')

    if not isinstance(nxroot, NXroot):
        raise ValueError(f'Invalid parameter nxroot ({nxroot})')
    nxentry = nxroot[nxroot.attrs['default']]
    if not isinstance(nxentry, NXentry):
        raise ValueError(f'Invalid nxentry ({nxentry})')
    if self.galaxy_flag:
        if center_rows is not None:
            center_rows = tuple(center_rows)
            if not is_int_pair(center_rows):
                raise ValueError(f'Invalid parameter center_rows ({center_rows})')
        if center_stack_index is not None and not is_int(center_stack_index, ge=0):
            raise ValueError(f'Invalid parameter center_stack_index ({center_stack_index})')
    else:
        if center_rows is not None:
            logger.warning(f'Ignoring parameter center_rows ({center_rows})')
            center_rows = None
        if center_stack_index is not None:
            logger.warning(f'Ignoring parameter center_stack_index ({center_stack_index})')
            center_stack_index = None

    # Create plot galaxy path directory and path if needed
    if self.galaxy_flag:
        path = 'tomo_find_centers_plots'
        if not os_path.exists(path):
            mkdir(path)
    else:
        path = self.output_folder

    # Check if reduced data is available
    if 'reduced_data' not in nxentry or 'reduced_data' not in nxentry.data:
        raise KeyError(f'Unable to find valid reduced data in {nxentry}.')

    # Select the image stack to calibrate the center axis
    # reduced data axes order: stack,theta,row,column
    # Note: Nexus cannot follow a link if the data it points to is too big,
    # so get the data from the actual place, not from nxentry.data
    tomo_fields = nxentry.reduced_data.data.tomo_fields
    tomo_fields_shape = tomo_fields.shape
    if len(tomo_fields_shape) != 4 or not all(tomo_fields_shape):
        raise KeyError('Unable to load the required reduced tomography stack')
    num_tomo_stacks = tomo_fields_shape[0]
    if num_tomo_stacks == 1:
        center_stack_index = 0
        default = 'n'
    else:
        if self.test_mode:
            center_stack_index = self.test_config['center_stack_index']-1 # make offset 0
        elif self.galaxy_flag:
            if center_stack_index is None:
                center_stack_index = int(num_tomo_stacks/2)
            if center_stack_index >= num_tomo_stacks:
                raise ValueError(f'Invalid parameter center_stack_index ({center_stack_index})')
        else:
            center_stack_index = input_int('\nEnter tomography stack index to calibrate the '
                    'center axis', ge=1, le=num_tomo_stacks, default=int(1+num_tomo_stacks/2))
            center_stack_index -= 1
        default = 'y'

    # Get thetas (in degrees)
    thetas = np.asarray(nxentry.reduced_data.rotation_angle)

    # Get effective pixel_size
    # zoom_perc is stored as an attribute of reduced_data, not as a child item,
    # so look for it in .attrs
    if 'zoom_perc' in nxentry.reduced_data.attrs:
        eff_pixel_size = 100.*(nxentry.instrument.detector.x_pixel_size/
                nxentry.reduced_data.attrs['zoom_perc'])
    else:
        eff_pixel_size = nxentry.instrument.detector.x_pixel_size

    # Get cross sectional diameter
    cross_sectional_dim = tomo_fields_shape[3]*eff_pixel_size
    logger.debug(f'cross_sectional_dim = {cross_sectional_dim}')

    # Determine center offset at sample row boundaries
    logger.info('Determine center offset at sample row boundaries')

    # Lower row center
    if self.test_mode:
        lower_row = self.test_config['lower_row']
    elif self.galaxy_flag:
        if center_rows is None or center_rows[0] == -1:
            lower_row = 0
        elif center_rows[1] == -1:
            # Upper row left at its default: keep the -1 sentinel out of min()
            lower_row = center_rows[0]
        else:
            lower_row = min(center_rows)
        if not 0 <= lower_row < tomo_fields_shape[2]-1:
            raise ValueError(f'Invalid parameter center_rows ({center_rows})')
    else:
        lower_row = select_one_image_bound(
                tomo_fields[center_stack_index,0,:,:], 0, bound=0,
                title=f'theta={round(thetas[0], 2)+0}',
                bound_name='row index to find lower center', default=default, raise_error=True)
    logger.debug('Finding center...')
    t0 = time()
    lower_center_offset = self._find_center_one_plane(
            tomo_fields[center_stack_index,:,lower_row,:],
            lower_row, thetas, eff_pixel_size, cross_sectional_dim, path=path,
            num_core=self.num_core)
    logger.debug(f'... done in {time()-t0:.2f} seconds')
    logger.debug(f'lower_row = {lower_row:.2f}')
    logger.debug(f'lower_center_offset = {lower_center_offset:.2f}')

    # Upper row center
    if self.test_mode:
        upper_row = self.test_config['upper_row']
    elif self.galaxy_flag:
        if center_rows is None or center_rows[1] == -1:
            upper_row = tomo_fields_shape[2]-1
        else:
            upper_row = max(center_rows)
        if not lower_row < upper_row < tomo_fields_shape[2]:
            raise ValueError(f'Invalid parameter center_rows ({center_rows})')
    else:
        upper_row = select_one_image_bound(
                tomo_fields[center_stack_index,0,:,:], 0,
                bound=tomo_fields_shape[2]-1, title=f'theta={round(thetas[0], 2)+0}',
                bound_name='row index to find upper center', default=default, raise_error=True)
    logger.debug('Finding center...')
    t0 = time()
    upper_center_offset = self._find_center_one_plane(
            tomo_fields[center_stack_index,:,upper_row,:],
            upper_row, thetas, eff_pixel_size, cross_sectional_dim, path=path,
            num_core=self.num_core)
    logger.debug(f'... done in {time()-t0:.2f} seconds')
    logger.debug(f'upper_row = {upper_row:.2f}')
    logger.debug(f'upper_center_offset = {upper_center_offset:.2f}')

    center_config = {'lower_row': lower_row, 'lower_center_offset': lower_center_offset,
            'upper_row': upper_row, 'upper_center_offset': upper_center_offset}
    if num_tomo_stacks > 1:
        center_config['center_stack_index'] = center_stack_index+1 # save as offset 1

    # Save test data to file
    if self.test_mode:
        with open(f'{self.output_folder}/center_config.yaml', 'w') as f:
            safe_dump(center_config, f)

    return center_config
436
def reconstruct_data(self, nxroot, center_info, x_bounds=None, y_bounds=None):
    """Reconstruct the tomography data.

    :param nxroot: NeXus tree holding reduced tomography data
    :type nxroot: NXroot
    :param center_info: calibrated center axis info with (at least) the keys
        'lower_row', 'lower_center_offset', 'upper_row' and
        'upper_center_offset'
    :type center_info: dict
    :param x_bounds: galaxy mode only: (min, max) index range kept along x
    :param y_bounds: galaxy mode only: (min, max) index range kept along y
    :return: a copy of `nxroot` without the reduced image data and with a
        `reconstructed_data` NXprocess added to the default entry
    :rtype: NXroot
    :raises ValueError: for invalid parameters or unloadable stacks
    :raises KeyError: when reduced data or center axis info is missing
    """
    logger.info('Reconstruct the tomography data')

    if not isinstance(nxroot, NXroot):
        raise ValueError(f'Invalid parameter nxroot ({nxroot})')
    nxentry = nxroot[nxroot.attrs['default']]
    if not isinstance(nxentry, NXentry):
        raise ValueError(f'Invalid nxentry ({nxentry})')
    if not isinstance(center_info, dict):
        raise ValueError(f'Invalid parameter center_info ({center_info})')

    # Create plot galaxy path directory and path if needed
    if self.galaxy_flag:
        path = 'tomo_reconstruct_plots'
        if not os_path.exists(path):
            mkdir(path)
    else:
        path = self.output_folder

    # Check if reduced data is available
    if 'reduced_data' not in nxentry or 'reduced_data' not in nxentry.data:
        raise KeyError(f'Unable to find valid reduced data in {nxentry}.')

    # Create an NXprocess to store image reconstruction (meta)data
    nxprocess = NXprocess()

    # Get rotation axis rows and centers
    lower_row = center_info.get('lower_row')
    lower_center_offset = center_info.get('lower_center_offset')
    upper_row = center_info.get('upper_row')
    upper_center_offset = center_info.get('upper_center_offset')
    if (lower_row is None or lower_center_offset is None or upper_row is None or
            upper_center_offset is None):
        raise KeyError(f'Unable to find valid calibrated center axis info in {center_info}.')
    # The center offset is taken to vary linearly with the row index
    center_slope = (upper_center_offset-lower_center_offset)/(upper_row-lower_row)

    # Get thetas (in degrees)
    thetas = np.asarray(nxentry.reduced_data.rotation_angle)

    # Reconstruct tomography data
    # reduced data axes order: stack,theta,row,column
    # reconstructed data order in each stack: row/z,x,y
    # Note: Nexus cannot follow a link if the data it points to is too big,
    # so get the data from the actual place, not from nxentry.data
    # zoom_perc is stored as an attribute of reduced_data, not as a child item
    if 'zoom_perc' in nxentry.reduced_data.attrs:
        res_title = f'{nxentry.reduced_data.attrs["zoom_perc"]}p'
    else:
        res_title = 'fullres'
    num_tomo_stacks = nxentry.reduced_data.data.tomo_fields.shape[0]
    tomo_recon_stacks = num_tomo_stacks*[np.array([])]
    for i in range(num_tomo_stacks):
        # Convert reduced data stack from theta,row,column to row,theta,column
        logger.debug(f'Reading reduced data stack {i+1}...')
        t0 = time()
        tomo_stack = np.asarray(nxentry.reduced_data.data.tomo_fields[i])
        logger.debug(f'... done in {time()-t0:.2f} seconds')
        if len(tomo_stack.shape) != 3 or not all(tomo_stack.shape):
            raise ValueError(f'Unable to load tomography stack {i+1} for reconstruction')
        tomo_stack = np.swapaxes(tomo_stack, 0, 1)
        assert(len(thetas) == tomo_stack.shape[1])
        assert(0 <= lower_row < upper_row < tomo_stack.shape[0])
        # Extrapolate the linear center offset to the first and the last row
        center_offsets = [lower_center_offset-lower_row*center_slope,
                upper_center_offset+(tomo_stack.shape[0]-1-upper_row)*center_slope]
        t0 = time()
        logger.debug(f'Running _reconstruct_one_tomo_stack on {self.num_core} cores ...')
        tomo_recon_stack = self._reconstruct_one_tomo_stack(tomo_stack, thetas,
                center_offsets=center_offsets, num_core=self.num_core, algorithm='gridrec')
        logger.debug(f'... done in {time()-t0:.2f} seconds')
        logger.info(f'Reconstruction of stack {i+1} took {time()-t0:.2f} seconds')

        # Combine stacks
        tomo_recon_stacks[i] = tomo_recon_stack

    # Resize the reconstructed tomography data
    # reconstructed data order in each stack: row/z,x,y
    if self.test_mode:
        x_bounds = self.test_config.get('x_bounds')
        y_bounds = self.test_config.get('y_bounds')
        z_bounds = None
    elif self.galaxy_flag:
        recon_shape = tomo_recon_stacks[0].shape
        if x_bounds is not None and not is_int_pair(x_bounds, ge=0, lt=recon_shape[1]):
            raise ValueError(f'Invalid parameter x_bounds ({x_bounds})')
        # y is the last axis (row/z,x,y), so validate against shape[2]
        if y_bounds is not None and not is_int_pair(y_bounds, ge=0, lt=recon_shape[2]):
            raise ValueError(f'Invalid parameter y_bounds ({y_bounds})')
        z_bounds = None
    else:
        x_bounds, y_bounds, z_bounds = self._resize_reconstructed_data(tomo_recon_stacks)
    if x_bounds is None:
        x_range = (0, tomo_recon_stacks[0].shape[1])
        x_slice = int(x_range[1]/2)
    else:
        x_range = (min(x_bounds), max(x_bounds))
        x_slice = int((x_bounds[0]+x_bounds[1])/2)
    if y_bounds is None:
        y_range = (0, tomo_recon_stacks[0].shape[2])
        y_slice = int(y_range[1]/2)
    else:
        y_range = (min(y_bounds), max(y_bounds))
        y_slice = int((y_bounds[0]+y_bounds[1])/2)
    if z_bounds is None:
        z_range = (0, tomo_recon_stacks[0].shape[0])
        z_slice = int(z_range[1]/2)
    else:
        z_range = (min(z_bounds), max(z_bounds))
        z_slice = int((z_bounds[0]+z_bounds[1])/2)

    # Plot a few reconstructed image slices
    for i, stack in enumerate(tomo_recon_stacks):
        # Title each stack with its own index (the index must come from this loop)
        if num_tomo_stacks == 1:
            basetitle = 'recon'
        else:
            basetitle = f'recon stack {i+1}'
        title = f'{basetitle} {res_title} xslice{x_slice}'
        quick_imshow(stack[z_range[0]:z_range[1],x_slice,y_range[0]:y_range[1]],
                title=title, path=path, save_fig=self.save_figs, save_only=self.save_only,
                block=self.block)
        title = f'{basetitle} {res_title} yslice{y_slice}'
        quick_imshow(stack[z_range[0]:z_range[1],x_range[0]:x_range[1],y_slice],
                title=title, path=path, save_fig=self.save_figs, save_only=self.save_only,
                block=self.block)
        title = f'{basetitle} {res_title} zslice{z_slice}'
        quick_imshow(stack[z_slice,x_range[0]:x_range[1],y_range[0]:y_range[1]],
                title=title, path=path, save_fig=self.save_figs, save_only=self.save_only,
                block=self.block)

    # Save test data to file
    # reconstructed data order in each stack: row/z,x,y
    if self.test_mode:
        for i, stack in enumerate(tomo_recon_stacks):
            np.savetxt(f'{self.output_folder}/recon_stack_{i+1}.txt',
                    stack[z_slice,x_range[0]:x_range[1],y_range[0]:y_range[1]], fmt='%.6e')

    # Add image reconstruction to reconstructed data NXprocess
    # reconstructed data order in each stack: row/z,x,y
    nxprocess.data = NXdata()
    nxprocess.attrs['default'] = 'data'
    for k, v in center_info.items():
        nxprocess[k] = v
    if x_bounds is not None:
        nxprocess.x_bounds = x_bounds
    if y_bounds is not None:
        nxprocess.y_bounds = y_bounds
    if z_bounds is not None:
        nxprocess.z_bounds = z_bounds
    nxprocess.data['reconstructed_data'] = np.asarray([stack[z_range[0]:z_range[1],
            x_range[0]:x_range[1],y_range[0]:y_range[1]] for stack in tomo_recon_stacks])
    nxprocess.data.attrs['signal'] = 'reconstructed_data'

    # Create a copy of the input Nexus object and remove reduced data
    exclude_items = [f'{nxentry._name}/reduced_data/data', f'{nxentry._name}/data/reduced_data']
    nxroot_copy = nxcopy(nxroot, exclude_nxpaths=exclude_items)

    # Add the reconstructed data NXprocess to the new Nexus object
    nxentry_copy = nxroot_copy[nxroot_copy.attrs['default']]
    nxentry_copy.reconstructed_data = nxprocess
    if 'data' not in nxentry_copy:
        nxentry_copy.data = NXdata()
        nxentry_copy.attrs['default'] = 'data'
    nxentry_copy.data.makelink(nxprocess.data.reconstructed_data, name='reconstructed_data')
    nxentry_copy.data.attrs['signal'] = 'reconstructed_data'

    return nxroot_copy
604
def combine_data(self, nxroot, x_bounds=None, y_bounds=None):
    """Combine the reconstructed tomography stacks.

    The stacks are concatenated along the row/z axis.

    :param nxroot: NeXus tree holding reconstructed tomography data
    :type nxroot: NXroot
    :param x_bounds: galaxy mode only: (min, max) index range kept along x
    :param y_bounds: galaxy mode only: (min, max) index range kept along y
    :return: a copy of `nxroot` without the per-stack reconstructed data and
        with a `combined_data` NXprocess added, or None when there is only
        one stack
    :rtype: NXroot or None
    :raises ValueError: for invalid parameters
    :raises KeyError: when no valid reconstructed data is available
    """
    logger.info('Combine the reconstructed tomography stacks')

    if not isinstance(nxroot, NXroot):
        raise ValueError(f'Invalid parameter nxroot ({nxroot})')
    nxentry = nxroot[nxroot.attrs['default']]
    if not isinstance(nxentry, NXentry):
        raise ValueError(f'Invalid nxentry ({nxentry})')

    # Create plot galaxy path directory and path if needed
    if self.galaxy_flag:
        path = 'tomo_combine_plots'
        if not os_path.exists(path):
            mkdir(path)
    else:
        path = self.output_folder

    # Check if reconstructed image data is available
    if 'reconstructed_data' not in nxentry or 'reconstructed_data' not in nxentry.data:
        raise KeyError(f'Unable to find valid reconstructed image data in {nxentry}.')

    # Create an NXprocess to store combined image reconstruction (meta)data
    nxprocess = NXprocess()

    # Get the reconstructed data
    # reconstructed data order: stack,row(z),x,y
    # Note: Nexus cannot follow a link if the data it points to is too big,
    # so get the data from the actual place, not from nxentry.data
    reconstructed_data = nxentry.reconstructed_data.data.reconstructed_data
    num_tomo_stacks = reconstructed_data.shape[0]
    if num_tomo_stacks == 1:
        logger.info('Only one stack available: leaving combine_data')
        return None

    # Combine the reconstructed stacks along the row/z axis
    # (load one stack at a time to reduce risk of hitting Nexus data access limit)
    t0 = time()
    logger.debug('Combining the reconstructed stacks ...')
    tomo_recon_combined = np.concatenate(
            [np.asarray(reconstructed_data[i]) for i in range(num_tomo_stacks)])
    logger.debug(f'... done in {time()-t0:.2f} seconds')
    logger.info(f'Combining the reconstructed stacks took {time()-t0:.2f} seconds')

    # Resize the combined tomography data stacks
    # combined data order: row/z,x,y
    if self.test_mode:
        x_bounds = None
        y_bounds = None
        z_bounds = self.test_config.get('z_bounds')
    elif self.galaxy_flag:
        if x_bounds is not None and not is_int_pair(x_bounds, ge=0,
                lt=tomo_recon_combined.shape[1]):
            raise ValueError(f'Invalid parameter x_bounds ({x_bounds})')
        if y_bounds is not None and not is_int_pair(y_bounds, ge=0,
                lt=tomo_recon_combined.shape[2]):
            raise ValueError(f'Invalid parameter y_bounds ({y_bounds})')
        z_bounds = None
    else:
        x_bounds, y_bounds, z_bounds = self._resize_reconstructed_data(tomo_recon_combined,
                z_only=True)
    # Order each pair so that it can be used directly as a slice range
    if x_bounds is None:
        x_range = (0, tomo_recon_combined.shape[1])
        x_slice = int(x_range[1]/2)
    else:
        x_range = (min(x_bounds), max(x_bounds))
        x_slice = int((x_bounds[0]+x_bounds[1])/2)
    if y_bounds is None:
        y_range = (0, tomo_recon_combined.shape[2])
        y_slice = int(y_range[1]/2)
    else:
        y_range = (min(y_bounds), max(y_bounds))
        y_slice = int((y_bounds[0]+y_bounds[1])/2)
    if z_bounds is None:
        z_range = (0, tomo_recon_combined.shape[0])
        z_slice = int(z_range[1]/2)
    else:
        z_range = (min(z_bounds), max(z_bounds))
        z_slice = int((z_bounds[0]+z_bounds[1])/2)

    # Plot a few combined image slices
    quick_imshow(tomo_recon_combined[z_range[0]:z_range[1],x_slice,y_range[0]:y_range[1]],
            title=f'recon combined xslice{x_slice}', path=path,
            save_fig=self.save_figs, save_only=self.save_only, block=self.block)
    quick_imshow(tomo_recon_combined[z_range[0]:z_range[1],x_range[0]:x_range[1],y_slice],
            title=f'recon combined yslice{y_slice}', path=path,
            save_fig=self.save_figs, save_only=self.save_only, block=self.block)
    quick_imshow(tomo_recon_combined[z_slice,x_range[0]:x_range[1],y_range[0]:y_range[1]],
            title=f'recon combined zslice{z_slice}', path=path,
            save_fig=self.save_figs, save_only=self.save_only, block=self.block)

    # Save test data to file
    # combined data order: row/z,x,y
    if self.test_mode:
        np.savetxt(f'{self.output_folder}/recon_combined.txt', tomo_recon_combined[
                z_slice,x_range[0]:x_range[1],y_range[0]:y_range[1]], fmt='%.6e')

    # Add image reconstruction to reconstructed data NXprocess
    # combined data order: row/z,x,y
    nxprocess.data = NXdata()
    nxprocess.attrs['default'] = 'data'
    if x_bounds is not None:
        nxprocess.x_bounds = x_bounds
    if y_bounds is not None:
        nxprocess.y_bounds = y_bounds
    if z_bounds is not None:
        nxprocess.z_bounds = z_bounds
    nxprocess.data['combined_data'] = tomo_recon_combined
    nxprocess.data.attrs['signal'] = 'combined_data'

    # Create a copy of the input Nexus object and remove reconstructed data
    exclude_items = [f'{nxentry._name}/reconstructed_data/data',
            f'{nxentry._name}/data/reconstructed_data']
    nxroot_copy = nxcopy(nxroot, exclude_nxpaths=exclude_items)

    # Add the combined data NXprocess to the new Nexus object
    nxentry_copy = nxroot_copy[nxroot_copy.attrs['default']]
    nxentry_copy.combined_data = nxprocess
    if 'data' not in nxentry_copy:
        nxentry_copy.data = NXdata()
        nxentry_copy.attrs['default'] = 'data'
    nxentry_copy.data.makelink(nxprocess.data.combined_data, name='combined_data')
    nxentry_copy.data.attrs['signal'] = 'combined_data'

    return nxroot_copy
736
def _gen_dark(self, nxentry, reduced_data):
    """Generate the dark field and store it in `reduced_data`.

    The dark field images come either from the detector data (frames with
    image_key == 2) or from the 'dark_field' spec scans. A 3D stack is
    reduced to one image by its median. Pixels above a cutoff are replaced
    by the mean of the remaining pixels.

    :param nxentry: the NeXus entry holding the raw data
    :param reduced_data: the NXprocess collecting the reduction results
    :return: `reduced_data` with a new `data` NXdata group holding
        'dark_field'
    :raises ValueError: if the dark field stack is neither 2D nor 3D
    """
    # Get the dark field images
    image_key = nxentry.instrument.detector.get('image_key', None)
    if image_key and 'data' in nxentry.instrument.detector:
        field_indices = [index for index, key in enumerate(image_key) if key == 2]
        tdf_stack = nxentry.instrument.detector.data[field_indices,:,:]
        # RV the default NXtomo form does not accomodate bright or dark field stacks
    else:
        dark_field_scans = nxentry.spec_scans.dark_field
        dark_field = FlatField.construct_from_nxcollection(dark_field_scans)
        prefix = str(nxentry.instrument.detector.local_name)
        tdf_stack = dark_field.get_detector_data(prefix)
        if isinstance(tdf_stack, list):
            assert(len(tdf_stack) == 1) # TODO
            tdf_stack = tdf_stack[0]

    # Take median
    if tdf_stack.ndim == 2:
        # Work on a float copy: NaN cannot be stored in an integer image, and
        # the in-place edits below must not modify the source detector data
        tdf = np.array(tdf_stack, dtype=np.float64)
    elif tdf_stack.ndim == 3:
        tdf = np.median(tdf_stack, axis=0)
    else:
        raise ValueError(f'Invalid tdf_stack shape ({tdf_stack.shape})')
    del tdf_stack

    # Remove dark field intensities above the cutoff, placed at twice the
    # median's distance above the minimum (set tdf_cutoff to None to disable)
    #RV tdf_cutoff = None
    tdf_cutoff = tdf.min()+2*(np.median(tdf)-tdf.min())
    logger.debug(f'tdf_cutoff = {tdf_cutoff}')
    if tdf_cutoff is not None:
        if not is_num(tdf_cutoff, ge=0):
            logger.warning(f'Ignoring illegal value of tdf_cutoff {tdf_cutoff}')
        else:
            tdf[tdf > tdf_cutoff] = np.nan

    # Remove nans
    tdf_mean = np.nanmean(tdf)
    logger.debug(f'tdf_mean = {tdf_mean}')
    np.nan_to_num(tdf, copy=False, nan=tdf_mean, posinf=tdf_mean, neginf=0.)

    # Plot dark field
    if self.galaxy_flag:
        quick_imshow(tdf, title='dark field', path='tomo_reduce_plots', save_fig=self.save_figs,
                save_only=self.save_only)
    elif not self.test_mode:
        quick_imshow(tdf, title='dark field', path=self.output_folder, save_fig=self.save_figs,
                save_only=self.save_only)
        clear_imshow('dark field')

    # Add dark field to reduced data NXprocess
    reduced_data.data = NXdata()
    reduced_data.data['dark_field'] = tdf

    return reduced_data
795
def _gen_bright(self, nxentry, reduced_data):
    """Generate the bright field.

    Reads the bright field images from an NXtomo-style ``image_key`` layout
    when available (``image_key == 1``), otherwise from the SPEC bright
    field scans. It then reduces them to a single image with a median,
    subtracts the dark field if ``reduced_data`` already holds one, clips
    values below one, and stores the result as
    ``reduced_data.data.bright_field``.

    :param nxentry: the input nexus entry holding the raw detector data
    :param reduced_data: the NXprocess being built
    :raises ValueError: if the bright field data is neither 2D nor 3D
    :return: ``reduced_data`` with ``data.bright_field`` added
    """
    # Get the bright field images
    image_key = nxentry.instrument.detector.get('image_key', None)
    if image_key and 'data' in nxentry.instrument.detector:
        field_indices = [index for index, key in enumerate(image_key) if key == 1]
        tbf_stack = nxentry.instrument.detector.data[field_indices,:,:]
        # RV the default NXtomo form does not accomodate bright or dark field stacks
    else:
        bright_field_scans = nxentry.spec_scans.bright_field
        bright_field = FlatField.construct_from_nxcollection(bright_field_scans)
        prefix = str(nxentry.instrument.detector.local_name)
        tbf_stack = bright_field.get_detector_data(prefix)
        if isinstance(tbf_stack, list):
            assert(len(tbf_stack) == 1) # TODO
            tbf_stack = tbf_stack[0]

    # Take median if more than one image.
    # Median or mean: it may be best to try the median because of some image
    # artifacts that arise due to crinkles in the upstream kapton tape windows
    # causing some phase contrast images to appear on the detector.
    # A possible future improvement is a bright field adjustment on EACH frame
    # of the tomo based on a ROI in the corner of the frame where there is no
    # sample but there is direct X-ray beam, since the incoming beam
    # fluctuates from frame to frame.
    if tbf_stack.ndim == 2:
        # Take a float copy: the in-place dark field subtraction and clipping
        # below must neither modify the source data nor fail on an integer
        # detector dtype
        tbf = np.array(tbf_stack, dtype=np.float64)
    elif tbf_stack.ndim == 3:
        tbf = np.asarray(np.median(tbf_stack, axis=0), dtype=np.float64)
        del tbf_stack
    else:
        raise ValueError(f'Invalid tbf_stack shape ({tbf_stack.shape})')

    # Subtract dark field
    if 'data' in reduced_data and 'dark_field' in reduced_data.data:
        tbf -= np.asarray(reduced_data.data.dark_field)
    else:
        logger.warning('Dark field unavailable')

    # Set any non-positive values to one
    # (avoid negative bright field values for spikes in dark field)
    tbf[tbf < 1] = 1

    # Plot bright field
    if self.galaxy_flag:
        quick_imshow(tbf, title='bright field', path='tomo_reduce_plots',
                save_fig=self.save_figs, save_only=self.save_only)
    elif not self.test_mode:
        quick_imshow(tbf, title='bright field', path=self.output_folder,
                save_fig=self.save_figs, save_only=self.save_only)
        clear_imshow('bright field')

    # Add bright field to reduced data NXprocess
    if 'data' not in reduced_data:
        reduced_data.data = NXdata()
    reduced_data.data['bright_field'] = tbf

    return(reduced_data)
858
def _set_detector_bounds(self, nxentry, reduced_data, img_x_bounds=None):
    """Set the vertical detector bounds for each image stack.

    Right now the range is the same for each set in the image stack.

    :param nxentry: the input nexus entry holding the raw detector data
    :param reduced_data: the NXprocess being built; on the id1a3/id3a
        stations it must already contain ``data.bright_field``
    :param img_x_bounds: optional (lower, upper) row range, lower bound
        inclusive and upper bound exclusive; ``(-1, -1)`` means "not set".
        In Galaxy mode valid bounds are used as given.
    :raises ValueError: if img_x_bounds is invalid or no bounds were selected
    :raises NotImplementedError: for multiple stacks on stations other than
        id1a3/id3a
    :return: the selected (lower, upper) image bounds
    """
    if self.test_mode:
        return(tuple(self.test_config['img_x_bounds']))

    # Get the first tomography image and the reference heights
    image_key = nxentry.instrument.detector.get('image_key', None)
    if image_key and 'data' in nxentry.instrument.detector:
        # image_key == 0 marks the tomography images
        field_indices = [index for index, key in enumerate(image_key) if key == 0]
        first_image = np.asarray(nxentry.instrument.detector.data[field_indices[0],:,:])
        theta = float(nxentry.sample.rotation_angle[field_indices[0]])
        z_translation_all = nxentry.sample.z_translation[field_indices]
        vertical_shifts = sorted(list(set(z_translation_all)))
        num_tomo_stacks = len(vertical_shifts)
    else:
        tomo_field_scans = nxentry.spec_scans.tomo_fields
        tomo_fields = TomoField.construct_from_nxcollection(tomo_field_scans)
        vertical_shifts = tomo_fields.get_vertical_shifts()
        if not isinstance(vertical_shifts, list):
            vertical_shifts = [vertical_shifts]
        prefix = str(nxentry.instrument.detector.local_name)
        t0 = time()
        first_image = tomo_fields.get_detector_data(prefix, tomo_fields.scan_numbers[0], 0)
        logger.debug(f'Getting first image took {time()-t0:.2f} seconds')
        num_tomo_stacks = len(tomo_fields.scan_numbers)
        theta = tomo_fields.theta_range['start']

    # Select image bounds
    title = f'tomography image at theta={round(theta, 2)+0}'
    if img_x_bounds is not None:
        if is_int_pair(img_x_bounds) and img_x_bounds[0] == -1 and img_x_bounds[1] == -1:
            img_x_bounds = None
        elif not is_index_range(img_x_bounds, ge=0, le=first_image.shape[0]):
            raise ValueError(f'Invalid parameter img_x_bounds ({img_x_bounds})')
    if nxentry.instrument.source.attrs['station'] in ('id1a3', 'id3a'):
        pixel_size = nxentry.instrument.detector.x_pixel_size
        # Try to get a fit from the bright field: fit a rectangle with atan
        # edges to the bright field summed along each row
        tbf = np.asarray(reduced_data.data.bright_field)
        tbf_shape = tbf.shape
        x_sum = np.sum(tbf, 1)
        x_sum_min = x_sum.min()
        x_sum_max = x_sum.max()
        fit = Fit.fit_data(x_sum, 'rectangle', x=np.array(range(len(x_sum))), form='atan',
                guess=True)
        parameters = fit.best_values
        x_low_fit = parameters.get('center1', None)
        x_upp_fit = parameters.get('center2', None)
        sig_low = parameters.get('sigma1', None)
        sig_upp = parameters.get('sigma2', None)
        # Only trust a fit with both edges inside the image, in order, and
        # sharp relative to the fitted window width
        have_fit = fit.success and x_low_fit is not None and x_upp_fit is not None and \
                sig_low is not None and sig_upp is not None and \
                0 <= x_low_fit < x_upp_fit <= x_sum.size and \
                (sig_low+sig_upp)/(x_upp_fit-x_low_fit) < 0.1
        if have_fit:
            # Set a 5% margin on each side
            margin = 0.05*(x_upp_fit-x_low_fit)
            x_low_fit = max(0, x_low_fit-margin)
            x_upp_fit = min(tbf_shape[0], x_upp_fit+margin)
        if num_tomo_stacks == 1:
            if have_fit:
                # Set the default range to enclose the full fitted window
                x_low = int(x_low_fit)
                x_upp = int(x_upp_fit)
            else:
                # Center a default range of 1 mm (RV: can we get this from the slits?)
                num_x_min = int((1.0-0.5*pixel_size)/pixel_size)
                x_low = int(0.5*(tbf_shape[0]-num_x_min))
                x_upp = x_low+num_x_min
        else:
            # Get the default range from the smallest step between reference
            # heights; sort first since the scan order need not be ascending
            sorted_shifts = sorted(vertical_shifts)
            delta_z = min(upp-low for low, upp in zip(sorted_shifts, sorted_shifts[1:]))
            logger.debug(f'delta_z = {delta_z}')
            num_x_min = int((delta_z-0.5*pixel_size)/pixel_size)
            logger.debug(f'num_x_min = {num_x_min}')
            if num_x_min > tbf_shape[0]:
                logger.warning('Image bounds and pixel size prevent seamless stacking')
            if have_fit:
                # Center the default range relative to the fitted window
                x_low = int(0.5*(x_low_fit+x_upp_fit-num_x_min))
                x_upp = x_low+num_x_min
            else:
                # Center the default range
                x_low = int(0.5*(tbf_shape[0]-num_x_min))
                x_upp = x_low+num_x_min
        if self.galaxy_flag:
            # Non-interactive: use the caller's (already validated) bounds when
            # given, as the FMB branch below does; otherwise use the defaults
            if img_x_bounds is None:
                img_x_bounds = (x_low, x_upp)
        else:
            # Interactive: show the default bounds on the bright field, the
            # first image and the row sums, and let the user accept or redraw
            tmp = np.copy(tbf)
            tmp_max = tmp.max()
            tmp[x_low,:] = tmp_max
            tmp[x_upp-1,:] = tmp_max
            quick_imshow(tmp, title='bright field')
            tmp = np.copy(first_image)
            tmp_max = tmp.max()
            tmp[x_low,:] = tmp_max
            tmp[x_upp-1,:] = tmp_max
            quick_imshow(tmp, title=title)
            del tmp
            quick_plot((range(x_sum.size), x_sum),
                    ([x_low, x_low], [x_sum_min, x_sum_max], 'r-'),
                    ([x_upp, x_upp], [x_sum_min, x_sum_max], 'r-'),
                    title='sum over theta and y')
            print(f'lower bound = {x_low} (inclusive)')
            print(f'upper bound = {x_upp} (exclusive)')
            accept = input_yesno('Accept these bounds (y/n)?', 'y')
            clear_imshow('bright field')
            clear_imshow(title)
            clear_plot('sum over theta and y')
            if accept:
                img_x_bounds = (x_low, x_upp)
            else:
                while True:
                    mask, img_x_bounds = draw_mask_1d(x_sum, title='select x data range',
                            legend='sum over theta and y')
                    if len(img_x_bounds) == 1:
                        break
                    else:
                        print('Choose a single connected data range')
                img_x_bounds = tuple(img_x_bounds[0])
            if (num_tomo_stacks > 1 and img_x_bounds[1]-img_x_bounds[0]+1 <
                    int((delta_z-0.5*pixel_size)/pixel_size)):
                logger.warning('Image bounds and pixel size prevent seamless stacking')
    else:
        if num_tomo_stacks > 1:
            raise NotImplementedError('Selecting image bounds for multiple stacks on FMB')
        # For FMB: use the first tomography image to select range
        # RV: revisit if they do tomography with multiple stacks
        x_sum = np.sum(first_image, 1)
        x_sum_min = x_sum.min()
        x_sum_max = x_sum.max()
        if self.galaxy_flag:
            if img_x_bounds is None:
                img_x_bounds = (0, first_image.shape[0])
        else:
            quick_imshow(first_image, title=title)
            print('Select vertical data reduction range from first tomography image')
            img_x_bounds = select_image_bounds(first_image, 0, title=title)
            clear_imshow(title)
            if img_x_bounds is None:
                raise ValueError('Unable to select image bounds')

    # Plot results: mark the first and last selected rows on the first image
    if self.galaxy_flag:
        path = 'tomo_reduce_plots'
    else:
        path = self.output_folder
    x_low = img_x_bounds[0]
    x_upp = img_x_bounds[1]
    tmp = np.copy(first_image)
    tmp_max = tmp.max()
    tmp[x_low,:] = tmp_max
    tmp[x_upp-1,:] = tmp_max
    quick_imshow(tmp, title=title, path=path, save_fig=self.save_figs, save_only=self.save_only,
            block=self.block)
    del tmp
    quick_plot((range(x_sum.size), x_sum),
            ([x_low, x_low], [x_sum_min, x_sum_max], 'r-'),
            ([x_upp, x_upp], [x_sum_min, x_sum_max], 'r-'),
            title='sum over theta and y', path=path, save_fig=self.save_figs,
            save_only=self.save_only, block=self.block)

    return(img_x_bounds)
1025
def _set_zoom_or_skip(self):
    """Return the zoom percentage and theta skip interval for the analysis.

    Both memory-reducing options are currently disabled (the interactive
    prompts that used to set them are not active), so both values are None.

    :return: tuple ``(zoom_perc, num_theta_skip)``
    """
    zoom_perc, num_theta_skip = None, None
    logger.debug(f'zoom_perc = {zoom_perc}')
    logger.debug(f'num_theta_skip = {num_theta_skip}')
    return zoom_perc, num_theta_skip
1044
def _gen_tomo(self, nxentry, reduced_data):
    """Generate the reduced tomography fields.

    Loads the tomography images from an NXtomo-style ``image_key`` layout
    (``image_key == 0``, one stack per distinct z_translation) when
    available, otherwise from the SPEC tomography scans. It then crops them
    to the image bounds stored on ``reduced_data``, normalizes them by the
    bright field, clips them at 1e-6 and linearizes them with ``-log``.

    :param nxentry: the input nexus entry holding the raw detector data
    :param reduced_data: the NXprocess being built; must already contain
        ``data.bright_field`` and may contain ``img_x_bounds``/``img_y_bounds``
    :raises ValueError: if the image sequence numbers are not in order
    :return: ``reduced_data`` with ``rotation_angle``, ``x_translation``,
        ``z_translation`` and ``data.tomo_fields`` added
    """
    # Get full bright field
    tbf = np.asarray(reduced_data.data.bright_field)
    tbf_shape = tbf.shape

    # Get image bounds
    img_x_bounds = tuple(reduced_data.get('img_x_bounds', (0, tbf_shape[0])))
    img_y_bounds = tuple(reduced_data.get('img_y_bounds', (0, tbf_shape[1])))
    # Decide once, against the FULL bright field shape, whether to crop: the
    # bright field is cropped below, and comparing against the cropped shape
    # would skip cropping the tomography images when both bounds start at 0
    resize = img_x_bounds != (0, tbf_shape[0]) or img_y_bounds != (0, tbf_shape[1])

    # NOTE(review): dark field subtraction of the tomography images is
    # currently disabled (tdf is always None); _gen_bright subtracts the dark
    # field from the bright field only.
    tdf = None

    # Resize bright field
    if resize:
        tbf = tbf[img_x_bounds[0]:img_x_bounds[1],img_y_bounds[0]:img_y_bounds[1]]

    # Get the tomography images
    image_key = nxentry.instrument.detector.get('image_key', None)
    if image_key and 'data' in nxentry.instrument.detector:
        field_indices_all = [index for index, key in enumerate(image_key) if key == 0]
        z_translation_all = nxentry.sample.z_translation[field_indices_all]
        z_translation_levels = sorted(list(set(z_translation_all)))
        horizontal_shifts = []
        vertical_shifts = []
        thetas = None
        tomo_stacks = []
        for z_translation in z_translation_levels:
            field_indices = [field_indices_all[index]
                    for index, z in enumerate(z_translation_all) if z == z_translation]
            horizontal_shift = list(set(nxentry.sample.x_translation[field_indices]))
            assert(len(horizontal_shift) == 1)
            horizontal_shifts += horizontal_shift
            vertical_shift = list(set(nxentry.sample.z_translation[field_indices]))
            assert(len(vertical_shift) == 1)
            vertical_shifts += vertical_shift
            sequence_numbers = nxentry.instrument.detector.sequence_number[field_indices]
            if thetas is None:
                thetas = np.asarray(nxentry.sample.rotation_angle[field_indices]) \
                         [sequence_numbers]
            else:
                # Every stack must use the same rotation angles
                assert(all(thetas[j] == nxentry.sample.rotation_angle[field_indices[index]]
                        for j, index in enumerate(sequence_numbers)))
            assert(sorted(set(sequence_numbers)) == list(range(len(sequence_numbers))))
            if list(sequence_numbers) == list(range(len(sequence_numbers))):
                tomo_stack = np.asarray(nxentry.instrument.detector.data[field_indices])
            else:
                raise ValueError('Unable to load the tomography images')
            tomo_stacks.append(tomo_stack)
    else:
        tomo_field_scans = nxentry.spec_scans.tomo_fields
        tomo_fields = TomoField.construct_from_nxcollection(tomo_field_scans)
        horizontal_shifts = tomo_fields.get_horizontal_shifts()
        vertical_shifts = tomo_fields.get_vertical_shifts()
        prefix = str(nxentry.instrument.detector.local_name)
        t0 = time()
        tomo_stacks = tomo_fields.get_detector_data(prefix)
        logger.debug(f'Getting tomography images took {time()-t0:.2f} seconds')
        thetas = np.linspace(tomo_fields.theta_range['start'], tomo_fields.theta_range['end'],
                tomo_fields.theta_range['num'])
        if not isinstance(tomo_stacks, list):
            horizontal_shifts = [horizontal_shifts]
            vertical_shifts = [vertical_shifts]
            tomo_stacks = [tomo_stacks]

    reduced_tomo_stacks = []
    if self.galaxy_flag:
        path = 'tomo_reduce_plots'
    else:
        path = self.output_folder
    for i, tomo_stack in enumerate(tomo_stacks):
        # Resize the tomography images (the range is the same for each set in
        # the image stack) and make sure they are float64: the numexpr calls
        # below write floating-point results back into tomo_stack
        if resize:
            t0 = time()
            tomo_stack = tomo_stack[:,img_x_bounds[0]:img_x_bounds[1],
                    img_y_bounds[0]:img_y_bounds[1]].astype('float64')
            logger.debug(f'Resizing tomography images took {time()-t0:.2f} seconds')
        else:
            tomo_stack = np.asarray(tomo_stack).astype('float64', copy=False)

        # Subtract dark field
        if tdf is not None:
            t0 = time()
            with set_numexpr_threads(self.num_core):
                ne.evaluate('tomo_stack-tdf', out=tomo_stack)
            logger.debug(f'Subtracting dark field took {time()-t0:.2f} seconds')

        # Normalize
        t0 = time()
        with set_numexpr_threads(self.num_core):
            ne.evaluate('tomo_stack/tbf', out=tomo_stack, truediv=True)
        logger.debug(f'Normalizing took {time()-t0:.2f} seconds')

        # Remove non-positive values and linearize data
        t0 = time()
        cutoff = 1.e-6
        with set_numexpr_threads(self.num_core):
            ne.evaluate('where(tomo_stack<cutoff, cutoff, tomo_stack)', out=tomo_stack)
        with set_numexpr_threads(self.num_core):
            ne.evaluate('-log(tomo_stack)', out=tomo_stack)
        logger.debug('Removing non-positive values and linearizing data took '+
                f'{time()-t0:.2f} seconds')

        # Get rid of nans/infs that may be introduced by normalization
        # (replaced in place by 0)
        t0 = time()
        np.nan_to_num(tomo_stack, copy=False, nan=0., posinf=0., neginf=0.)
        logger.debug(f'Remove nans/infs took {time()-t0:.2f} seconds')

        # Downsize tomography stack to smaller size
        # TODO use zoom_perc and theta_skip (see _set_zoom_or_skip)
        tomo_stack = tomo_stack.astype('float32')
        if not self.test_mode:
            if len(tomo_stacks) == 1:
                title = f'red fullres theta {round(thetas[0], 2)+0}'
            else:
                title = f'red stack {i+1} fullres theta {round(thetas[0], 2)+0}'
            quick_imshow(tomo_stack[0,:,:], title=title, path=path, save_fig=self.save_figs,
                    save_only=self.save_only, block=self.block)

        # Save test data to file: the middle detector row for every theta
        if self.test_mode:
            row_index = int(tomo_stack.shape[1]/2)
            np.savetxt(f'{self.output_folder}/red_stack_{i+1}.txt', tomo_stack[:,row_index,:],
                    fmt='%.6e')

        # Combine resized stacks
        reduced_tomo_stacks.append(tomo_stack)

    # Add tomo field info to reduced data NXprocess
    reduced_data['rotation_angle'] = thetas
    reduced_data['x_translation'] = np.asarray(horizontal_shifts)
    reduced_data['z_translation'] = np.asarray(vertical_shifts)
    reduced_data.data['tomo_fields'] = np.asarray(reduced_tomo_stacks)

    if tdf is not None:
        del tdf
    del tbf

    return(reduced_data)
1215
def _find_center_one_plane(self, sinogram, row, thetas, eff_pixel_size, cross_sectional_dim,
        path=None, tol=0.1, num_core=1):
    """Find the rotation axis center for a single tomography plane.

    Uses Nghia Vo's method (``tomopy.find_center_vo``), reconstructs the
    plane at that center and plots its denoised edges. The result is
    currently accepted automatically; the interactive search further down
    is unreachable until that is re-enabled.

    :param sinogram: sinogram of the plane, index order (theta, column)
    :param row: row index of the plane, used in log messages and plot titles
    :param thetas: rotation angles, passed on to _reconstruct_one_plane
    :param eff_pixel_size: effective pixel size, passed on to _reconstruct_one_plane
    :param cross_sectional_dim: sample cross-sectional dimension, passed on to
        _reconstruct_one_plane
    :param path: output path for the plots (None: self.output_folder)
    :param tol: currently unused
    :param num_core: number of cores; capped at num_core_tomopy_limit for
        find_center_vo
    :return: center offset in pixels relative to the middle column
    """
    # sinogram index order: theta,column
    # need column,theta for iradon, so take transpose
    sinogram = np.asarray(sinogram)
    sinogram_T = sinogram.T
    center = sinogram.shape[1]/2

    # Try using Nghia Vo’s method
    t0 = time()
    ncore = min(num_core, num_core_tomopy_limit)
    logger.debug(f'Running find_center_vo on {ncore} cores ...')
    tomo_center = tomopy.find_center_vo(sinogram, ncore=ncore)
    logger.debug(f'... done in {time()-t0:.2f} seconds')
    logger.info(f'Finding the center using Nghia Vo’s method took {time()-t0:.2f} seconds')
    center_offset_vo = tomo_center-center
    logger.info(f'Center at row {row} using Nghia Vo’s method = {center_offset_vo:.2f}')
    t0 = time()
    logger.debug(f'Running _reconstruct_one_plane on {num_core} cores ...')
    recon_plane = self._reconstruct_one_plane(sinogram_T, tomo_center, thetas,
            eff_pixel_size, cross_sectional_dim, False, num_core)
    logger.debug(f'... done in {time()-t0:.2f} seconds')
    logger.info(f'Reconstructing row {row} took {time()-t0:.2f} seconds')

    title = f'edges row{row} center offset{center_offset_vo:.2f} Vo'
    self._plot_edges_one_plane(recon_plane, title, path=path)

    # A phase correlation search (tomopy.find_center_pc) and the interactive
    # "accept or continue search" prompt are currently disabled, so Vo's
    # result is accepted as is.
    if True:
        center_offset = center_offset_vo
        del sinogram_T
        del recon_plane
        return float(center_offset)

    # Interactive center finding search (currently unreachable)
    while True:
        center_offset_low = input_int('\nEnter lower bound for center offset', ge=-center,
                le=center)
        center_offset_upp = input_int('Enter upper bound for center offset',
                ge=center_offset_low, le=center)
        if center_offset_upp == center_offset_low:
            center_offset_step = 1
        else:
            center_offset_step = input_int('Enter step size for center offset search', ge=1,
                    le=center_offset_upp-center_offset_low)
        num_center_offset = 1+int((center_offset_upp-center_offset_low)/center_offset_step)
        center_offsets = np.linspace(center_offset_low, center_offset_upp, num_center_offset)
        for center_offset in center_offsets:
            if center_offset == center_offset_vo:
                continue
            t0 = time()
            logger.debug(f'Running _reconstruct_one_plane on {num_core} cores ...')
            recon_plane = self._reconstruct_one_plane(sinogram_T, center_offset+center, thetas,
                    eff_pixel_size, cross_sectional_dim, False, num_core)
            logger.debug(f'... done in {time()-t0:.2f} seconds')
            logger.info(f'Reconstructing center_offset {center_offset} took '+
                    f'{time()-t0:.2f} seconds')
            title = f'edges row{row} center_offset{center_offset:.2f}'
            self._plot_edges_one_plane(recon_plane, title, path=path)
        if input_int('\nContinue (0) or end the search (1)', ge=0, le=1):
            break

    del sinogram_T
    del recon_plane
    center_offset = input_num('    Enter chosen center offset', ge=-center, le=center)
    return float(center_offset)
1317
def _reconstruct_one_plane(self, tomo_plane_T, center, thetas, eff_pixel_size,
        cross_sectional_dim, plot_sinogram=True, num_core=1):
    """Invert the sinogram of a single tomography plane.

    The sinogram is first cropped to a row window centered on the rotation
    axis that is just wide enough for the sample plus 10% slack. It is then
    inverted with skimage's ``iradon``, Gaussian filtered, and cleaned of ring
    artifacts with tomopy.

    :param tomo_plane_T: sinogram in (column, theta) index order
    :param center: rotation axis position in columns, 0 <= center < number of columns
    :param thetas: rotation angles, passed to iradon as ``theta``
    :param eff_pixel_size: effective pixel size (assumed to be in the same
        units as cross_sectional_dim)
    :param cross_sectional_dim: cross-sectional dimension of the sample
    :param plot_sinogram: plot the cropped sinogram (never in Galaxy mode)
    :param num_core: number of cores for tomopy's remove_ring
    :return: the reconstructed plane with a leading axis of length 1
    """
    num_col = tomo_plane_T.shape[0]
    assert(0 <= center < num_col)
    center_offset = center-num_col/2
    two_offset = 2*int(np.round(center_offset))

    # Half the sample width in pixels with 10% slack to avoid edge effects,
    # capped at half the number of columns
    max_rad = int(0.55*(cross_sectional_dim/eff_pixel_size))
    if max_rad > 0.5*num_col:
        max_rad = 0.5*num_col

    # Keep rows [first, last) so the window is centered on the rotation axis;
    # at least 1 so that -dist_from_edge never becomes -0 (an empty slice)
    dist_from_edge = max(1, int(np.floor((num_col-np.abs(two_offset))/2.)-max_rad))
    if two_offset >= 0:
        first, last = two_offset+dist_from_edge, -dist_from_edge
    else:
        first, last = dist_from_edge, two_offset-dist_from_edge
    logger.debug(f'sinogram range = [{first}, {last}]')
    sinogram = tomo_plane_T[first:last,:]
    if not self.galaxy_flag and plot_sinogram:
        quick_imshow(sinogram.T, f'sinogram center offset{center_offset:.2f}', aspect='auto',
                path=self.output_folder, save_fig=self.save_figs, save_only=self.save_only,
                block=self.block)

    # Invert the sinogram
    t0 = time()
    recon_sinogram = iradon(sinogram, theta=thetas, circle=True)
    logger.debug(f'Inverting sinogram took {time()-t0:.2f} seconds')
    del sinogram

    # Gaussian filter and ring removal settings (no configuration hooked up yet)
    sigma = 1.0
    ring_width = 15
    recon_parameters = None#self.config.get('recon_parameters')
    if recon_parameters is not None:
        sigma = recon_parameters.get('gaussian_sigma', 1.0)
        if not is_num(sigma, ge=0.0):
            logger.warning(f'Invalid gaussian_sigma ({sigma}) in _reconstruct_one_plane, '+
                    'set to a default value of 1.0')
            sigma = 1.0
        ring_width = recon_parameters.get('ring_width', 15)
        if not is_int(ring_width, ge=0):
            logger.warning(f'Invalid ring_width ({ring_width}) in _reconstruct_one_plane, '+
                    'set to a default value of 15')
            ring_width = 15
    t0 = time()
    recon_clean = np.expand_dims(
            spi.gaussian_filter(recon_sinogram, sigma, mode='nearest'), axis=0)
    del recon_sinogram
    recon_clean = tomopy.misc.corr.remove_ring(recon_clean, rwidth=ring_width, ncore=num_core)
    logger.debug(f'Filtering and removing ring artifacts took {time()-t0:.2f} seconds')

    return recon_clean
1372
def _plot_edges_one_plane(self, recon_plane, title, path=None):
    """Denoise a reconstructed plane and plot its first slice twice.

    The plane is smoothed with total-variation (Chambolle) denoising. Index
    0 of the leading axis of the result is then shown with the 'coolwarm'
    colormap, and again in gray scale with limits symmetric around zero.

    :param recon_plane: reconstructed plane with a leading axis (e.g. as
        returned by _reconstruct_one_plane)
    :param title: base title; ' coolwarm' / ' gray' is appended per figure
    :param path: output path for the figures (None: self.output_folder)
    """
    # Denoising weight (no visualization parameters are configured yet)
    weight = 0.1
    vis_parameters = None#self.config.get('vis_parameters')
    if vis_parameters is not None:
        weight = vis_parameters.get('denoise_weight', 0.1)
        if not is_num(weight, ge=0.0):
            logger.warning(f'Invalid weight ({weight}) in _plot_edges_one_plane, '+
                    'set to a default value of 0.1')
            weight = 0.1
    edges = denoise_tv_chambolle(recon_plane, weight=weight)
    vmax = np.max(edges[0,:,:])
    if path is None:
        path = self.output_folder
    plot_kwargs = {'path': path, 'save_fig': self.save_figs, 'save_only': self.save_only,
            'block': self.block}
    quick_imshow(edges[0,:,:], f'{title} coolwarm', cmap='coolwarm', **plot_kwargs)
    quick_imshow(edges[0,:,:], f'{title} gray', cmap='gray', vmin=-vmax, vmax=vmax,
            **plot_kwargs)
    del edges
1393
def _reconstruct_one_tomo_stack(self, tomo_stack, thetas, center_offsets=None, num_core=1,
        algorithm='gridrec'):
    """Reconstruct a single tomography stack.

    :param tomo_stack: stack in sinogram order (row, theta, column)
    :param thetas: rotation angles in degrees
    :param center_offsets: rotation axis shift(s) in pixels relative to the
        column center. None or empty means no shift. A pair
        ``(first_row, last_row)`` is linearly interpolated over all rows.
        Otherwise give one offset per row. The argument is not modified.
    :param num_core: number of cores passed to tomopy (capped at
        num_core_tomopy_limit for stripe removal)
    :param algorithm: tomopy algorithm for the initial reconstruction
    :raises ValueError: if per-row center_offsets do not match the number of rows
    :return: the reconstructed stack with ring artifacts removed
    """
    # RV should we remove stripes?
    # https://tomopy.readthedocs.io/en/latest/api/tomopy.prep.stripe.html
    # RV should we remove rings?
    # https://tomopy.readthedocs.io/en/latest/api/tomopy.misc.corr.html
    # RV: Add an option to do (extra) secondary iterations later or to do some sort of convergence test?
    num_row = tomo_stack.shape[0]
    if center_offsets is None or not len(center_offsets):
        centers = np.zeros(num_row)
    elif len(center_offsets) == 2:
        centers = np.linspace(center_offsets[0], center_offsets[1], num_row)
    else:
        if len(center_offsets) != num_row:
            raise ValueError('center_offsets dimension mismatch in reconstruct_one_tomo_stack')
        # Take a float copy so the shift below neither modifies the caller's
        # offsets nor fails on an integer array
        centers = np.array(center_offsets, dtype=np.float64)
    # Convert offsets into absolute column positions of the rotation axis
    centers += tomo_stack.shape[2]/2

    # Get reconstruction parameters (no configuration hooked up yet)
    recon_parameters = None#self.config.get('recon_parameters')
    if recon_parameters is None:
        sigma = 2.0
        secondary_iters = 0
        ring_width = 15
    else:
        sigma = recon_parameters.get('stripe_fw_sigma', 2.0)
        if not is_num(sigma, ge=0):
            logger.warning(f'Invalid stripe_fw_sigma ({sigma}) in '+
                    '_reconstruct_one_tomo_stack, set to a default value of 2.0')
            sigma = 2.0
        secondary_iters = recon_parameters.get('secondary_iters', 0)
        if not is_int(secondary_iters, ge=0):
            logger.warning(f'Invalid secondary_iters ({secondary_iters}) in '+
                    '_reconstruct_one_tomo_stack, set to a default value of 0 (skip them)')
            secondary_iters = 0
        ring_width = recon_parameters.get('ring_width', 15)
        if not is_int(ring_width, ge=0):
            logger.warning(f'Invalid ring_width ({ring_width}) in _reconstruct_one_tomo_stack, '+
                    'set to a default value of 15')
            ring_width = 15

    # Remove horizontal stripe
    t0 = time()
    ncore = min(num_core, num_core_tomopy_limit)
    logger.debug(f'Running remove_stripe_fw on {ncore} cores ...')
    tomo_stack = tomopy.prep.stripe.remove_stripe_fw(tomo_stack, sigma=sigma, ncore=ncore)
    logger.debug(f'... tomopy.prep.stripe.remove_stripe_fw took {time()-t0:.2f} seconds')

    # Perform initial image reconstruction
    logger.debug('Performing initial image reconstruction')
    t0 = time()
    logger.debug(f'Running recon on {num_core} cores ...')
    tomo_recon_stack = tomopy.recon(tomo_stack, np.radians(thetas), centers,
            sinogram_order=True, algorithm=algorithm, ncore=num_core)
    logger.debug(f'... done in {time()-t0:.2f} seconds')
    logger.info(f'Performing initial image reconstruction took {time()-t0:.2f} seconds')

    # Run optional secondary iterations
    if secondary_iters > 0:
        logger.debug(f'Running {secondary_iters} secondary iterations')
        # Tried alternatives:
        # - SIRT_CUDA: fails with "CUDA error 803: system has unsupported
        #   display driver/cuda driver combination."
        # - SIRT: did not finish while running overnight
        options = {'method':'SART', 'proj_type':'linear', 'MinConstraint': 0,
                'num_iter':secondary_iters}
        t0 = time()
        logger.debug(f'Running recon on {num_core} cores ...')
        tomo_recon_stack = tomopy.recon(tomo_stack, np.radians(thetas), centers,
                init_recon=tomo_recon_stack, options=options, sinogram_order=True,
                algorithm=tomopy.astra, ncore=num_core)
        logger.debug(f'... done in {time()-t0:.2f} seconds')
        logger.info(f'Performing secondary iterations took {time()-t0:.2f} seconds')

    # Remove ring artifacts (in place)
    t0 = time()
    tomopy.misc.corr.remove_ring(tomo_recon_stack, rwidth=ring_width, out=tomo_recon_stack,
            ncore=num_core)
    logger.debug(f'Removing ring artifacts took {time()-t0:.2f} seconds')

    return tomo_recon_stack
1486
1487 def _resize_reconstructed_data(self, data, z_only=False):
1488 """Resize the reconstructed tomography data.
1489 """
1490 # Data order: row(z),x,y or stack,row(z),x,y
1491 if isinstance(data, list):
1492 for stack in data:
1493 assert(stack.ndim == 3)
1494 num_tomo_stacks = len(data)
1495 tomo_recon_stacks = data
1496 else:
1497 assert(data.ndim == 3)
1498 num_tomo_stacks = 1
1499 tomo_recon_stacks = [data]
1500
1501 if z_only:
1502 x_bounds = None
1503 y_bounds = None
1504 else:
1505 # Selecting x bounds (in yz-plane)
1506 tomosum = 0
1507 [tomosum := tomosum+np.sum(tomo_recon_stacks[i], axis=(0,2))
1508 for i in range(num_tomo_stacks)]
1509 select_x_bounds = input_yesno('\nDo you want to change the image x-bounds (y/n)?', 'y')
1510 if not select_x_bounds:
1511 x_bounds = None
1512 else:
1513 accept = False
1514 index_ranges = None
1515 while not accept:
1516 mask, x_bounds = draw_mask_1d(tomosum, current_index_ranges=index_ranges,
1517 title='select x data range', legend='recon stack sum yz')
1518 while len(x_bounds) != 1:
1519 print('Please select exactly one continuous range')
1520 mask, x_bounds = draw_mask_1d(tomosum, title='select x data range',
1521 legend='recon stack sum yz')
1522 x_bounds = x_bounds[0]
1523 # quick_plot(tomosum, vlines=x_bounds, title='recon stack sum yz')
1524 # print(f'x_bounds = {x_bounds} (lower bound inclusive, upper bound '+
1525 # 'exclusive)')
1526 # accept = input_yesno('Accept these bounds (y/n)?', 'y')
1527 accept = True
1528 logger.debug(f'x_bounds = {x_bounds}')
1529
1530 # Selecting y bounds (in xz-plane)
1531 tomosum = 0
1532 [tomosum := tomosum+np.sum(tomo_recon_stacks[i], axis=(0,1))
1533 for i in range(num_tomo_stacks)]
1534 select_y_bounds = input_yesno('\nDo you want to change the image y-bounds (y/n)?', 'y')
1535 if not select_y_bounds:
1536 y_bounds = None
1537 else:
1538 accept = False
1539 index_ranges = None
1540 while not accept:
1541 mask, y_bounds = draw_mask_1d(tomosum, current_index_ranges=index_ranges,
1542 title='select x data range', legend='recon stack sum xz')
1543 while len(y_bounds) != 1:
1544 print('Please select exactly one continuous range')
1545 mask, y_bounds = draw_mask_1d(tomosum, title='select x data range',
1546 legend='recon stack sum xz')
1547 y_bounds = y_bounds[0]
1548 # quick_plot(tomosum, vlines=y_bounds, title='recon stack sum xz')
1549 # print(f'y_bounds = {y_bounds} (lower bound inclusive, upper bound '+
1550 # 'exclusive)')
1551 # accept = input_yesno('Accept these bounds (y/n)?', 'y')
1552 accept = True
1553 logger.debug(f'y_bounds = {y_bounds}')
1554
1555 # Selecting z bounds (in xy-plane) (only valid for a single image stack)
1556 if num_tomo_stacks != 1:
1557 z_bounds = None
1558 else:
1559 tomosum = 0
1560 [tomosum := tomosum+np.sum(tomo_recon_stacks[i], axis=(1,2))
1561 for i in range(num_tomo_stacks)]
1562 select_z_bounds = input_yesno('Do you want to change the image z-bounds (y/n)?', 'n')
1563 if not select_z_bounds:
1564 z_bounds = None
1565 else:
1566 accept = False
1567 index_ranges = None
1568 while not accept:
1569 mask, z_bounds = draw_mask_1d(tomosum, current_index_ranges=index_ranges,
1570 title='select x data range', legend='recon stack sum xy')
1571 while len(z_bounds) != 1:
1572 print('Please select exactly one continuous range')
1573 mask, z_bounds = draw_mask_1d(tomosum, title='select x data range',
1574 legend='recon stack sum xy')
1575 z_bounds = z_bounds[0]
1576 # quick_plot(tomosum, vlines=z_bounds, title='recon stack sum xy')
1577 # print(f'z_bounds = {z_bounds} (lower bound inclusive, upper bound '+
1578 # 'exclusive)')
1579 # accept = input_yesno('Accept these bounds (y/n)?', 'y')
1580 accept = True
1581 logger.debug(f'z_bounds = {z_bounds}')
1582
1583 return(x_bounds, y_bounds, z_bounds)
1584
1585
def run_tomo(input_file:str, output_file:str, modes:list[str], center_file=None, num_core=-1,
        output_folder='.', save_figs='no', test_mode=False) -> None:
    """Run the requested tomography workflow steps on an input file.

    Steps run in the fixed order reduce_data, find_center, reconstruct_data,
    combine_data; the mode 'all' selects every step. Unless test_mode is set,
    the result is written to output_file: the rotation axis centers when
    find_center ran without reconstruct_data, otherwise the (processed) data.

    :param input_file: path of the input file, read with Tomo.read
    :param output_file: path of the output file, written with Tomo.write
    :param modes: workflow steps to run, any of 'reduce_data', 'find_center',
        'reconstruct_data', 'combine_data' or 'all'; None means ['all']
    :param center_file: file with the rotation axis centers, read when
        'reconstruct_data' runs without 'find_center'
    :param num_core: number of processors, passed to Tomo
    :param output_folder: output folder, passed to Tomo (holds tomo.log in test mode)
    :param save_figs: figure saving option, passed to Tomo
    :param test_mode: log to {output_folder}/tomo.log and skip writing output_file
    :raises ValueError: on an invalid mode, or when data must be reconstructed
        without finding centers and no center_file is given
    """
    if test_mode:
        logging_format = '%(asctime)s : %(levelname)s - %(module)s : %(funcName)s - %(message)s'
        level = logging.getLevelName('INFO')
        logging.basicConfig(filename=f'{output_folder}/tomo.log', filemode='w',
                format=logging_format, level=level, force=True)
    logger.info(f'input_file = {input_file}')
    logger.info(f'center_file = {center_file}')
    logger.info(f'output_file = {output_file}')
    logger.debug(f'modes= {modes}')
    logger.debug(f'num_core= {num_core}')
    logger.info(f'output_folder = {output_folder}')
    logger.info(f'save_figs = {save_figs}')
    logger.info(f'test_mode = {test_mode}')

    # Check for correction modes
    legal_modes = ['reduce_data', 'find_center', 'reconstruct_data', 'combine_data', 'all']
    if modes is None:
        modes = ['all']
    if not all(mode in legal_modes for mode in modes):
        raise ValueError(f'Invalid parameter modes ({modes})')

    # Resolve 'all' into the concrete set of steps to run
    if 'all' in modes:
        steps = set(legal_modes) - {'all'}
    else:
        steps = set(modes)

    # Fail early, before any expensive processing, when reconstruction has no
    # source for the rotation axis centers
    if 'reconstruct_data' in steps and 'find_center' not in steps and center_file is None:
        raise ValueError('Missing parameter center_file, required to reconstruct data '
                         'without finding the centers')

    # Instantiate Tomo object
    tomo = Tomo(num_core=num_core, output_folder=output_folder, save_figs=save_figs,
            test_mode=test_mode)

    # Read input file
    data = tomo.read(input_file)

    # Generate reduced tomography images
    if 'reduce_data' in steps:
        data = tomo.gen_reduced_data(data)

    # Find rotation axis centers for the tomography stacks
    center_data = None
    if 'find_center' in steps:
        center_data = tomo.find_centers(data)

    # Reconstruct tomography stacks
    if 'reconstruct_data' in steps:
        if center_data is None:
            # Centers were not just found: read them from the center file
            center_data = tomo.read(center_file)
        data = tomo.reconstruct_data(data, center_data)
        # Centers are consumed here; the reconstructed data is what gets written
        center_data = None

    # Combine reconstructed tomography stacks
    if 'combine_data' in steps:
        data = tomo.combine_data(data)

    # Write output file
    if data is not None and not test_mode:
        if center_data is None:
            data = tomo.write(data, output_file)
        else:
            data = tomo.write(center_data, output_file)

    logger.info(f'Completed modes: {modes}')