comparison workflow/run_tomo.py @ 4:9aa288729b9a draft

planemo upload for repository https://github.com/rolfverberg/galaxytools commit f8c4bdb31c20c468045ad5e6eb255a293244bc6c-dirty
author rv43
date Mon, 20 Mar 2023 18:44:23 +0000
parents
children 543dba81eb15
comparison
equal deleted inserted replaced
3:fc38431f257f 4:9aa288729b9a
1 #!/usr/bin/env python3
2
3 import logging
4 logger = logging.getLogger(__name__)
5
6 import numpy as np
7 try:
8 import numexpr as ne
9 except:
10 pass
11 try:
12 import scipy.ndimage as spi
13 except:
14 pass
15
16 from multiprocessing import cpu_count
17 from nexusformat.nexus import *
18 from os import mkdir
19 from os import path as os_path
20 try:
21 from skimage.transform import iradon
22 except:
23 pass
24 try:
25 from skimage.restoration import denoise_tv_chambolle
26 except:
27 pass
28 from time import time
29 try:
30 import tomopy
31 except:
32 pass
33 from yaml import safe_load, safe_dump
34
35 try:
36 from msnctools.fit import Fit
37 except:
38 from fit import Fit
39 try:
40 from msnctools.general import illegal_value, is_int, is_int_pair, is_num, is_index_range, \
41 input_int, input_num, input_yesno, input_menu, draw_mask_1d, select_image_bounds, \
42 select_one_image_bound, clear_imshow, quick_imshow, clear_plot, quick_plot
43 except:
44 from general import illegal_value, is_int, is_int_pair, is_num, is_index_range, \
45 input_int, input_num, input_yesno, input_menu, draw_mask_1d, select_image_bounds, \
46 select_one_image_bound, clear_imshow, quick_imshow, clear_plot, quick_plot
47
48 try:
49 from workflow.models import import_scanparser, FlatField, TomoField, TomoWorkflow
50 from workflow.__version__ import __version__
51 except:
52 pass
53
# Upper limit on the number of processor cores; judging by the name it caps the
# cores handed to tomopy (its use is outside this view) -- NOTE(review): confirm.
num_core_tomopy_limit: int = 24
55
def nxcopy(nxobject:NXobject, exclude_nxpaths:'list[str] | None'=None,
        nxpath_prefix:str='') -> NXobject:
    '''Function that returns a copy of a nexus object, optionally excluding certain child items.

    :param nxobject: the original nexus object to return a "copy" of
    :type nxobject: nexusformat.nexus.NXobject
    :param exclude_nxpaths: a list of '/'-separated paths (relative to
        `nxobject`) of child nexus objects that should be excluded from the
        returned "copy", defaults to `None` (exclude nothing)
    :type exclude_nxpaths: list[str], optional
    :param nxpath_prefix: For use in recursive calls from inside this
        function only!
    :type nxpath_prefix: str
    :return: a copy of `nxobject` with some children optionally excluded.
    :rtype: NXobject
    '''
    if exclude_nxpaths is None:
        exclude_nxpaths = []

    nxobject_copy = nxobject.__class__()
    if not nxpath_prefix:
        # At the top level only the 'default' attribute is carried over
        if 'default' in nxobject.attrs:
            nxobject_copy.attrs['default'] = nxobject.attrs['default']
    else:
        # Nested objects keep all of their attributes
        for k, v in nxobject.attrs.items():
            nxobject_copy.attrs[k] = v

    for k, v in nxobject.items():
        # NeXus paths always use '/', independent of the host OS path
        # separator (os.path.join would use '\\' on Windows and the
        # exclusion list would never match)
        nxpath = f'{nxpath_prefix}/{k}' if nxpath_prefix else k

        if nxpath in exclude_nxpaths:
            continue

        if isinstance(v, NXgroup):
            # Recurse so that excluded paths deeper in the tree are honoured
            nxobject_copy[k] = nxcopy(v, exclude_nxpaths=exclude_nxpaths,
                nxpath_prefix=nxpath)
        else:
            nxobject_copy[k] = v

    return(nxobject_copy)
92
class set_numexpr_threads:
    """Context manager that temporarily changes the numexpr thread count.

    The requested `num_core` is used when it lies in [1, cpu_count()];
    otherwise (including `None`) all available processors are used. The
    value returned by `ne.set_num_threads` on entry is restored on exit.
    """

    def __init__(self, num_core):
        available = cpu_count()
        in_range = num_core is not None and 1 <= num_core <= available
        self.num_core = num_core if in_range else available

    def __enter__(self):
        # Remember the setting reported by numexpr so __exit__ can restore it
        self.num_core_org = ne.set_num_threads(self.num_core)

    def __exit__(self, exc_type, exc_value, traceback):
        ne.set_num_threads(self.num_core_org)
106
107 class Tomo:
108 """Processing tomography data with misalignment.
109 """
def __init__(self, galaxy_flag=False, num_core=-1, output_folder='.', save_figs=None,
        test_mode=False):
    """Initialize the tomography processing settings.

    :param galaxy_flag: run in Galaxy mode; output goes to the current
        directory, test mode is disabled and figures are only saved
    :type galaxy_flag: bool
    :param num_core: number of processor cores to use; -1 selects all
        available cores and larger values are reduced to cpu_count()
    :type num_core: int
    :param output_folder: output directory, created when missing (ignored
        in Galaxy mode)
    :type output_folder: str
    :param save_figs: 'yes', 'no' or 'only'; defaults to 'no' and is
        forced to 'only' in Galaxy and test mode
    :type save_figs: str, optional
    :param test_mode: run in test mode (ignored in Galaxy mode)
    :type test_mode: bool
    :raises ValueError: for an invalid parameter value
    """
    if not isinstance(galaxy_flag, bool):
        raise ValueError(f'Invalid parameter galaxy_flag ({galaxy_flag})')
    self.galaxy_flag = galaxy_flag
    self.num_core = num_core

    if not galaxy_flag:
        # Stand-alone mode: honour the caller's folder, test and figure choices
        self.output_folder = os_path.abspath(output_folder)
        if not os_path.isdir(output_folder):
            mkdir(os_path.abspath(output_folder))
        if not isinstance(test_mode, bool):
            raise ValueError(f'Invalid parameter test_mode ({test_mode})')
        self.test_mode = test_mode
        if save_figs is None:
            save_figs = 'no'
    else:
        # Galaxy mode overrides the output folder, test mode and figure handling
        if output_folder != '.':
            logger.warning('Ignoring output_folder in galaxy mode')
        self.output_folder = '.'
        if test_mode != False:
            logger.warning('Ignoring test_mode in galaxy mode')
        self.test_mode = False
        if save_figs is not None:
            logger.warning('Ignoring save_figs in galaxy mode')
        save_figs = 'only'

    self.test_config = {}
    if self.test_mode:
        if save_figs != 'only':
            logger.warning('Ignoring save_figs in test mode')
        save_figs = 'only'

    # Translate the save_figs option into the (save_only, save_figs) flags
    if save_figs == 'only':
        self.save_only, self.save_figs = True, True
    elif save_figs == 'yes':
        self.save_only, self.save_figs = False, True
    elif save_figs == 'no':
        self.save_only, self.save_figs = False, False
    else:
        raise ValueError(f'Invalid parameter save_figs ({save_figs})')
    # Figures block the interactive session unless they are only saved
    self.block = not self.save_only

    if self.num_core == -1:
        self.num_core = cpu_count()
    if not is_int(self.num_core, gt=0, log=False):
        raise ValueError(f'Invalid parameter num_core ({num_core})')
    if self.num_core > cpu_count():
        logger.warning(f'num_core = {self.num_core} is larger than the number of available '
            f'processors and reduced to {cpu_count()}')
        self.num_core = cpu_count()
165
def read(self, filename):
    """Read a YAML configuration or a NeXus file.

    :param filename: path whose extension selects the reader:
        '.yml'/'.yaml' (yaml.safe_load) or '.nxs' (NXFile.readfile)
    :return: the loaded YAML content or the NeXus root object
    :raises ValueError: for any other filename extension
    """
    extension = os_path.splitext(filename)[1]
    if extension in ('.yml', '.yaml'):
        with open(filename, 'r') as f:
            return(safe_load(f))
    if extension == '.nxs':
        with NXFile(filename, mode='r') as nxfile:
            return(nxfile.readfile())
    raise ValueError(f'Invalid filename extension ({extension})')
182
def write(self, data, filename):
    """Write `data` to `filename` in the format selected by its extension.

    - '.yml'/'.yaml': `data` is dumped with yaml.safe_dump
    - '.nxs': `data.save(filename, mode='w')` (NeXus object)
    - '.nc': `data.to_netcdf(filename)` (presumably an xarray object)

    :raises ValueError: for any other filename extension
    """
    extension = os_path.splitext(filename)[1]
    if extension in ('.yml', '.yaml'):
        with open(filename, 'w') as f:
            safe_dump(data, f)
    elif extension == '.nxs':
        data.save(filename, mode='w')
    elif extension == '.nc':
        # The target path is to_netcdf's first argument; the former
        # `os_path=` keyword does not exist and raised a TypeError.
        data.to_netcdf(filename)
    else:
        raise ValueError(f'Invalid filename extension ({extension})')
194
def gen_reduced_data(self, data, img_x_bounds=None):
    """Generate the reduced tomography images.

    :param data: either a workflow configuration dictionary (turned into a
        NeXus tree through TomoWorkflow) or an existing NXroot
    :param img_x_bounds: optional vertical detector bounds, passed on to
        `_set_detector_bounds`
    :return: an NXroot whose default entry holds the reduction results in
        an NXprocess named `reduced_data`
    :raises ValueError: for unsupported input
    """
    logger.info('Generate the reduced tomography images')

    # Galaxy collects the plots in a dedicated directory
    if self.galaxy_flag and not os_path.exists('tomo_reduce_plots'):
        mkdir('tomo_reduce_plots')

    if isinstance(data, dict):
        # Build the NeXus tree from the configuration dictionary
        workflow = TomoWorkflow(**data)
        if len(workflow.sample_maps) > 1:
            raise ValueError('Multiple sample maps not yet implemented')
        nxroot = NXroot()
        t_start = time()
        for sample_map in workflow.sample_maps:
            logger.info(f'Start constructing the {sample_map.title} map.')
            import_scanparser(sample_map.station)
            sample_map.construct_nxentry(nxroot, include_raw_data=False)
        logger.info(f'Constructed all sample maps in {time()-t_start:.2f} seconds.')
        nxentry = nxroot[nxroot.attrs['default']]
        if self.test_mode:
            self.test_config = data['sample_maps'][0]['test_mode']
    elif isinstance(data, NXroot):
        nxentry = data[data.attrs['default']]
    else:
        raise ValueError(f'Invalid parameter data ({data})')

    # All data reduction (meta)data is collected in one NXprocess
    reduced_data = NXprocess()

    # Dark field (optional), then bright field
    if 'dark_field' in nxentry['spec_scans']:
        reduced_data = self._gen_dark(nxentry, reduced_data)
    reduced_data = self._gen_bright(nxentry, reduced_data)

    # Vertical detector bounds for the image stack
    img_x_bounds = self._set_detector_bounds(nxentry, reduced_data, img_x_bounds=img_x_bounds)
    logger.info(f'img_x_bounds = {img_x_bounds}')
    reduced_data['img_x_bounds'] = img_x_bounds

    # Optional zoom and/or theta skip to reduce the memory requirement
    zoom_perc, num_theta_skip = self._set_zoom_or_skip()
    if zoom_perc is not None:
        reduced_data.attrs['zoom_perc'] = zoom_perc
    if num_theta_skip is not None:
        reduced_data.attrs['num_theta_skip'] = num_theta_skip

    # Reduced tomography fields
    reduced_data = self._gen_tomo(nxentry, reduced_data)

    # For NXroot input, work on a copy stripped of raw and earlier reduced data
    if isinstance(data, NXroot):
        entry_name = nxentry._name
        exclude_items = [f'{entry_name}/{item}' for item in (
            'reduced_data/data',
            'instrument/detector/data',
            'instrument/detector/image_key',
            'instrument/detector/sequence_number',
            'sample/rotation_angle',
            'sample/x_translation',
            'sample/z_translation',
            'data/data',
            'data/image_key',
            'data/rotation_angle',
            'data/x_translation',
            'data/z_translation')]
        nxroot = nxcopy(data, exclude_nxpaths=exclude_items)
        nxentry = nxroot[nxroot.attrs['default']]

    nxentry.reduced_data = reduced_data

    # Expose the reduced images and angles through the entry's default NXdata
    if 'data' not in nxentry:
        nxentry.data = NXdata()
        nxentry.attrs['default'] = 'data'
    nxentry.data.makelink(nxentry.reduced_data.data.tomo_fields, name='reduced_data')
    nxentry.data.makelink(nxentry.reduced_data.rotation_angle, name='rotation_angle')
    nxentry.data.attrs['signal'] = 'reduced_data'

    return(nxroot)
279
def find_centers(self, nxroot, center_rows=None, center_stack_index=None):
    """Find the calibrated center axis info.

    Determines the rotation-axis center offset at a lower and an upper
    detector row of one reduced tomography stack.

    :param nxroot: NXroot holding reduced data (as produced by gen_reduced_data)
    :param center_rows: (galaxy mode only) pair of row indices at which the
        center is determined; ignored with a warning otherwise
    :param center_stack_index: (galaxy mode only) zero-based index of the
        stack used for the calibration; ignored with a warning otherwise
    :return: dict with lower_row, lower_center_offset, upper_row and
        upper_center_offset, plus center_stack_index (stored one-based)
        when there is more than one stack
    :raises ValueError: on invalid parameters
    :raises KeyError: when no valid reduced data is available
    """
    logger.info('Find the calibrated center axis info')

    if not isinstance(nxroot, NXroot):
        raise ValueError(f'Invalid parameter nxroot ({nxroot})')
    nxentry = nxroot[nxroot.attrs['default']]
    if not isinstance(nxentry, NXentry):
        raise ValueError(f'Invalid nxentry ({nxentry})')

    # center_rows and center_stack_index are only honoured in galaxy mode
    if self.galaxy_flag:
        if center_rows is not None:
            center_rows = tuple(center_rows)
            if not is_int_pair(center_rows):
                raise ValueError(f'Invalid parameter center_rows ({center_rows})')
        if center_stack_index is not None and not is_int(center_stack_index, ge=0):
            raise ValueError(f'Invalid parameter center_stack_index ({center_stack_index})')
    else:
        if center_rows is not None:
            logger.warning(f'Ignoring parameter center_rows ({center_rows})')
            center_rows = None
        if center_stack_index is not None:
            logger.warning(f'Ignoring parameter center_stack_index ({center_stack_index})')
            center_stack_index = None

    # Create plot galaxy path directory and path if needed
    if self.galaxy_flag:
        if not os_path.exists('tomo_find_centers_plots'):
            mkdir('tomo_find_centers_plots')
        path = 'tomo_find_centers_plots'
    else:
        path = self.output_folder

    # Check if reduced data is available
    if ('reduced_data' not in nxentry or 'reduced_data' not in nxentry.data):
        raise KeyError(f'Unable to find valid reduced data in {nxentry}.')

    # Select the image stack to calibrate the center axis
    # reduced data axes order: stack,theta,row,column
    # Note: Nexus cannot follow a link if the data it points to is too big,
    # so get the data from the actual place, not from nxentry.data
    tomo_fields = nxentry.reduced_data.data.tomo_fields
    tomo_fields_shape = tomo_fields.shape
    if len(tomo_fields_shape) != 4 or not all(tomo_fields_shape):
        raise KeyError('Unable to load the required reduced tomography stack')
    num_tomo_stacks = tomo_fields_shape[0]
    if num_tomo_stacks == 1:
        center_stack_index = 0
        default = 'n'
    else:
        if self.test_mode:
            center_stack_index = self.test_config['center_stack_index']-1 # make offset 0
        elif self.galaxy_flag:
            if center_stack_index is None:
                center_stack_index = int(num_tomo_stacks/2)
            if center_stack_index >= num_tomo_stacks:
                raise ValueError(f'Invalid parameter center_stack_index ({center_stack_index})')
        else:
            center_stack_index = input_int('\nEnter tomography stack index to calibrate the '
                'center axis', ge=1, le=num_tomo_stacks, default=int(1+num_tomo_stacks/2))
            center_stack_index -= 1
        default = 'y'

    # Get thetas (in degrees)
    thetas = np.asarray(nxentry.reduced_data.rotation_angle)

    # Get effective pixel_size
    # zoom_perc is stored as an attribute of the reduced_data NXprocess, so it
    # must be looked up in .attrs (a plain `in` on the group only checks its
    # child entries and never finds it)
    reduced_attrs = nxentry.reduced_data.attrs
    if 'zoom_perc' in reduced_attrs:
        eff_pixel_size = 100.*(nxentry.instrument.detector.x_pixel_size/
            reduced_attrs['zoom_perc'])
    else:
        eff_pixel_size = nxentry.instrument.detector.x_pixel_size

    # Get cross sectional diameter
    cross_sectional_dim = tomo_fields_shape[3]*eff_pixel_size
    logger.debug(f'cross_sectional_dim = {cross_sectional_dim}')

    # Determine center offset at sample row boundaries
    logger.info('Determine center offset at sample row boundaries')

    # Lower row center
    if self.test_mode:
        lower_row = self.test_config['lower_row']
    elif self.galaxy_flag:
        if center_rows is None:
            lower_row = 0
        else:
            lower_row = min(center_rows)
            if not 0 <= lower_row < tomo_fields_shape[2]-1:
                raise ValueError(f'Invalid parameter center_rows ({center_rows})')
    else:
        lower_row = select_one_image_bound(
            tomo_fields[center_stack_index,0,:,:], 0, bound=0,
            title=f'theta={round(thetas[0], 2)+0}',
            bound_name='row index to find lower center', default=default, raise_error=True)
    logger.debug('Finding center...')
    t0 = time()
    lower_center_offset = self._find_center_one_plane(
        tomo_fields[center_stack_index,:,lower_row,:],
        lower_row, thetas, eff_pixel_size, cross_sectional_dim, path=path,
        num_core=self.num_core)
    logger.debug(f'... done in {time()-t0:.2f} seconds')
    logger.debug(f'lower_row = {lower_row:.2f}')
    logger.debug(f'lower_center_offset = {lower_center_offset:.2f}')

    # Upper row center
    if self.test_mode:
        upper_row = self.test_config['upper_row']
    elif self.galaxy_flag:
        if center_rows is None:
            upper_row = tomo_fields_shape[2]-1
        else:
            upper_row = max(center_rows)
            if not lower_row < upper_row < tomo_fields_shape[2]:
                raise ValueError(f'Invalid parameter center_rows ({center_rows})')
    else:
        upper_row = select_one_image_bound(
            tomo_fields[center_stack_index,0,:,:], 0,
            bound=tomo_fields_shape[2]-1, title=f'theta={round(thetas[0], 2)+0}',
            bound_name='row index to find upper center', default=default, raise_error=True)
    logger.debug('Finding center...')
    t0 = time()
    upper_center_offset = self._find_center_one_plane(
        tomo_fields[center_stack_index,:,upper_row,:],
        upper_row, thetas, eff_pixel_size, cross_sectional_dim, path=path,
        num_core=self.num_core)
    logger.debug(f'... done in {time()-t0:.2f} seconds')
    logger.debug(f'upper_row = {upper_row:.2f}')
    logger.debug(f'upper_center_offset = {upper_center_offset:.2f}')

    center_config = {'lower_row': lower_row, 'lower_center_offset': lower_center_offset,
        'upper_row': upper_row, 'upper_center_offset': upper_center_offset}
    if num_tomo_stacks > 1:
        center_config['center_stack_index'] = center_stack_index+1 # save as offset 1

    # Save test data to file
    if self.test_mode:
        with open(f'{self.output_folder}/center_config.yaml', 'w') as f:
            safe_dump(center_config, f)

    return(center_config)
422
def reconstruct_data(self, nxroot, center_info, x_bounds=None, y_bounds=None):
    """Reconstruct the tomography data.

    :param nxroot: NXroot holding reduced data (as produced by gen_reduced_data)
    :param center_info: calibrated center axis info as returned by
        find_centers (lower_row, lower_center_offset, upper_row,
        upper_center_offset and optionally center_stack_index)
    :param x_bounds: (galaxy mode only) optional index pair bounding the
        reconstructed images along x
    :param y_bounds: (galaxy mode only) optional index pair bounding the
        reconstructed images along y
    :return: a copy of nxroot without the reduced image data and with an
        NXprocess named `reconstructed_data` added to the default entry
    :raises ValueError: on invalid parameters or unreadable stacks
    :raises KeyError: when reduced data or center axis info is missing
    """
    logger.info('Reconstruct the tomography data')

    if not isinstance(nxroot, NXroot):
        raise ValueError(f'Invalid parameter nxroot ({nxroot})')
    nxentry = nxroot[nxroot.attrs['default']]
    if not isinstance(nxentry, NXentry):
        raise ValueError(f'Invalid nxentry ({nxentry})')
    if not isinstance(center_info, dict):
        raise ValueError(f'Invalid parameter center_info ({center_info})')

    # Create plot galaxy path directory and path if needed
    if self.galaxy_flag:
        if not os_path.exists('tomo_reconstruct_plots'):
            mkdir('tomo_reconstruct_plots')
        path = 'tomo_reconstruct_plots'
    else:
        path = self.output_folder

    # Check if reduced data is available
    if ('reduced_data' not in nxentry or 'reduced_data' not in nxentry.data):
        raise KeyError(f'Unable to find valid reduced data in {nxentry}.')

    # Create an NXprocess to store image reconstruction (meta)data
    nxprocess = NXprocess()

    # Get rotation axis rows and centers
    lower_row = center_info.get('lower_row')
    lower_center_offset = center_info.get('lower_center_offset')
    upper_row = center_info.get('upper_row')
    upper_center_offset = center_info.get('upper_center_offset')
    if (lower_row is None or lower_center_offset is None or upper_row is None or
            upper_center_offset is None):
        raise KeyError(f'Unable to find valid calibrated center axis info in {center_info}.')
    # Validate before dividing to avoid a ZeroDivisionError for equal rows
    if not lower_row < upper_row:
        raise ValueError(f'Invalid center rows in center_info ({center_info})')
    center_slope = (upper_center_offset-lower_center_offset)/(upper_row-lower_row)

    # Get thetas (in degrees)
    thetas = np.asarray(nxentry.reduced_data.rotation_angle)

    # Reconstruct tomography data
    # reduced data axes order: stack,theta,row,column
    # reconstructed data order in each stack: row/z,x,y
    # Note: Nexus cannot follow a link if the data it points to is too big,
    # so get the data from the actual place, not from nxentry.data
    # zoom_perc is stored as an attribute of reduced_data, not as a child entry
    if 'zoom_perc' in nxentry.reduced_data.attrs:
        res_title = f'{nxentry.reduced_data.attrs["zoom_perc"]}p'
    else:
        res_title = 'fullres'
    tomo_fields = nxentry.reduced_data.data.tomo_fields
    num_tomo_stacks = tomo_fields.shape[0]
    tomo_recon_stacks = []
    for i in range(num_tomo_stacks):
        logger.debug(f'Reading reduced data stack {i+1}...')
        t0 = time()
        tomo_stack = np.asarray(tomo_fields[i])
        logger.debug(f'... done in {time()-t0:.2f} seconds')
        if len(tomo_stack.shape) != 3 or not all(tomo_stack.shape):
            raise ValueError(f'Unable to load tomography stack {i+1} for reconstruction')
        # Convert reduced data stack from theta,row,column to row,theta,column
        tomo_stack = np.swapaxes(tomo_stack, 0, 1)
        if len(thetas) != tomo_stack.shape[1]:
            raise ValueError(f'Number of rotation angles ({len(thetas)}) does not match '
                f'tomography stack {i+1} ({tomo_stack.shape[1]})')
        if not 0 <= lower_row < upper_row < tomo_stack.shape[0]:
            raise ValueError(f'Invalid center rows ({lower_row}, {upper_row}) for '
                f'tomography stack {i+1} with {tomo_stack.shape[0]} rows')
        # Extrapolate the linear center-offset model to the first and last row
        center_offsets = [lower_center_offset-lower_row*center_slope,
            upper_center_offset+(tomo_stack.shape[0]-1-upper_row)*center_slope]
        t0 = time()
        logger.debug(f'Running _reconstruct_one_tomo_stack on {self.num_core} cores ...')
        tomo_recon_stack = self._reconstruct_one_tomo_stack(tomo_stack, thetas,
            center_offsets=center_offsets, num_core=self.num_core, algorithm='gridrec')
        logger.debug(f'... done in {time()-t0:.2f} seconds')
        logger.info(f'Reconstruction of stack {i+1} took {time()-t0:.2f} seconds')

        # Combine stacks
        tomo_recon_stacks.append(tomo_recon_stack)

    # Resize the reconstructed tomography data
    # reconstructed data order in each stack: row/z,x,y
    if self.test_mode:
        x_bounds = self.test_config.get('x_bounds')
        y_bounds = self.test_config.get('y_bounds')
        z_bounds = None
    elif self.galaxy_flag:
        if x_bounds is not None and not is_int_pair(x_bounds, ge=0,
                lt=tomo_recon_stacks[0].shape[1]):
            raise ValueError(f'Invalid parameter x_bounds ({x_bounds})')
        # y is the last axis of each reconstructed stack
        if y_bounds is not None and not is_int_pair(y_bounds, ge=0,
                lt=tomo_recon_stacks[0].shape[2]):
            raise ValueError(f'Invalid parameter y_bounds ({y_bounds})')
        z_bounds = None
    else:
        x_bounds, y_bounds, z_bounds = self._resize_reconstructed_data(tomo_recon_stacks)
    if x_bounds is None:
        x_range = (0, tomo_recon_stacks[0].shape[1])
        x_slice = int(x_range[1]/2)
    else:
        x_range = (min(x_bounds), max(x_bounds))
        x_slice = int((x_bounds[0]+x_bounds[1])/2)
    if y_bounds is None:
        y_range = (0, tomo_recon_stacks[0].shape[2])
        y_slice = int(y_range[1]/2)
    else:
        y_range = (min(y_bounds), max(y_bounds))
        y_slice = int((y_bounds[0]+y_bounds[1])/2)
    if z_bounds is None:
        z_range = (0, tomo_recon_stacks[0].shape[0])
        z_slice = int(z_range[1]/2)
    else:
        z_range = (min(z_bounds), max(z_bounds))
        z_slice = int((z_bounds[0]+z_bounds[1])/2)

    # Plot a few reconstructed image slices
    for i, stack in enumerate(tomo_recon_stacks):
        # Title per stack so that saved figures of different stacks differ
        if num_tomo_stacks == 1:
            basetitle = 'recon'
        else:
            basetitle = f'recon stack {i+1}'
        title = f'{basetitle} {res_title} xslice{x_slice}'
        quick_imshow(stack[z_range[0]:z_range[1],x_slice,y_range[0]:y_range[1]],
            title=title, path=path, save_fig=self.save_figs, save_only=self.save_only,
            block=self.block)
        title = f'{basetitle} {res_title} yslice{y_slice}'
        quick_imshow(stack[z_range[0]:z_range[1],x_range[0]:x_range[1],y_slice],
            title=title, path=path, save_fig=self.save_figs, save_only=self.save_only,
            block=self.block)
        title = f'{basetitle} {res_title} zslice{z_slice}'
        quick_imshow(stack[z_slice,x_range[0]:x_range[1],y_range[0]:y_range[1]],
            title=title, path=path, save_fig=self.save_figs, save_only=self.save_only,
            block=self.block)

    # Save test data to file
    # reconstructed data order in each stack: row/z,x,y
    if self.test_mode:
        for i, stack in enumerate(tomo_recon_stacks):
            np.savetxt(f'{self.output_folder}/recon_stack_{i+1}.txt',
                stack[z_slice,x_range[0]:x_range[1],y_range[0]:y_range[1]], fmt='%.6e')

    # Add image reconstruction to reconstructed data NXprocess
    # reconstructed data order in each stack: row/z,x,y
    nxprocess.data = NXdata()
    nxprocess.attrs['default'] = 'data'
    for k, v in center_info.items():
        nxprocess[k] = v
    if x_bounds is not None:
        nxprocess.x_bounds = x_bounds
    if y_bounds is not None:
        nxprocess.y_bounds = y_bounds
    if z_bounds is not None:
        nxprocess.z_bounds = z_bounds
    nxprocess.data['reconstructed_data'] = np.asarray([stack[z_range[0]:z_range[1],
        x_range[0]:x_range[1],y_range[0]:y_range[1]] for stack in tomo_recon_stacks])
    nxprocess.data.attrs['signal'] = 'reconstructed_data'

    # Create a copy of the input Nexus object and remove reduced data
    exclude_items = [f'{nxentry._name}/reduced_data/data', f'{nxentry._name}/data/reduced_data']
    nxroot_copy = nxcopy(nxroot, exclude_nxpaths=exclude_items)

    # Add the reconstructed data NXprocess to the new Nexus object
    nxentry_copy = nxroot_copy[nxroot_copy.attrs['default']]
    nxentry_copy.reconstructed_data = nxprocess
    if 'data' not in nxentry_copy:
        nxentry_copy.data = NXdata()
        nxentry_copy.attrs['default'] = 'data'
    nxentry_copy.data.makelink(nxprocess.data.reconstructed_data, name='reconstructed_data')
    nxentry_copy.data.attrs['signal'] = 'reconstructed_data'

    return(nxroot_copy)
590
def combine_data(self, nxroot, x_bounds=None, y_bounds=None):
    """Combine the reconstructed tomography stacks.

    Concatenates all reconstructed stacks along the row/z axis into one
    volume, optionally crops it, plots a few slices and stores the result.

    :param nxroot: NXroot holding reconstructed data (as produced by
        reconstruct_data)
    :param x_bounds: (galaxy mode only) optional index pair bounding the
        combined volume along x
    :param y_bounds: (galaxy mode only) optional index pair bounding the
        combined volume along y
    :return: a copy of nxroot without the reconstructed image data and with
        an NXprocess named `combined_data` added to the default entry, or
        None when only a single stack is available
    :raises ValueError: on invalid parameters
    :raises KeyError: when no reconstructed data is available
    """
    logger.info('Combine the reconstructed tomography stacks')

    if not isinstance(nxroot, NXroot):
        raise ValueError(f'Invalid parameter nxroot ({nxroot})')
    nxentry = nxroot[nxroot.attrs['default']]
    if not isinstance(nxentry, NXentry):
        raise ValueError(f'Invalid nxentry ({nxentry})')

    # Create plot galaxy path directory and path if needed
    if self.galaxy_flag:
        if not os_path.exists('tomo_combine_plots'):
            mkdir('tomo_combine_plots')
        path = 'tomo_combine_plots'
    else:
        path = self.output_folder

    # Check if reconstructed image data is available
    if ('reconstructed_data' not in nxentry or 'reconstructed_data' not in nxentry.data):
        raise KeyError(f'Unable to find valid reconstructed image data in {nxentry}.')

    # Create an NXprocess to store combined image reconstruction (meta)data
    nxprocess = NXprocess()

    # Get the reconstructed data
    # reconstructed data order: stack,row(z),x,y
    # Note: Nexus cannot follow a link if the data it points to is too big,
    # so get the data from the actual place, not from nxentry.data
    recon_data = nxentry.reconstructed_data.data.reconstructed_data
    num_tomo_stacks = recon_data.shape[0]
    if num_tomo_stacks == 1:
        logger.info('Only one stack available: leaving combine_data')
        return(None)

    # Combine the reconstructed stacks along row/z
    # (load one stack at a time to reduce risk of hitting Nexus data access limit)
    t0 = time()
    logger.debug('Combining the reconstructed stacks ...')
    tomo_recon_combined = np.concatenate(
        [np.asarray(recon_data[i]) for i in range(num_tomo_stacks)])
    logger.debug(f'... done in {time()-t0:.2f} seconds')
    logger.info(f'Combining the reconstructed stacks took {time()-t0:.2f} seconds')

    # Resize the combined tomography data stacks
    # combined data order: row/z,x,y
    if self.test_mode:
        x_bounds = None
        y_bounds = None
        z_bounds = self.test_config.get('z_bounds')
    elif self.galaxy_flag:
        # Validate against the combined volume (previously referenced an
        # undefined name and always raised NameError)
        if x_bounds is not None and not is_int_pair(x_bounds, ge=0,
                lt=tomo_recon_combined.shape[1]):
            raise ValueError(f'Invalid parameter x_bounds ({x_bounds})')
        if y_bounds is not None and not is_int_pair(y_bounds, ge=0,
                lt=tomo_recon_combined.shape[2]):
            raise ValueError(f'Invalid parameter y_bounds ({y_bounds})')
        z_bounds = None
    else:
        x_bounds, y_bounds, z_bounds = self._resize_reconstructed_data(tomo_recon_combined,
            z_only=True)
    if x_bounds is None:
        x_range = (0, tomo_recon_combined.shape[1])
        x_slice = int(x_range[1]/2)
    else:
        x_range = (min(x_bounds), max(x_bounds))
        x_slice = int((x_bounds[0]+x_bounds[1])/2)
    if y_bounds is None:
        y_range = (0, tomo_recon_combined.shape[2])
        y_slice = int(y_range[1]/2)
    else:
        y_range = (min(y_bounds), max(y_bounds))
        y_slice = int((y_bounds[0]+y_bounds[1])/2)
    if z_bounds is None:
        z_range = (0, tomo_recon_combined.shape[0])
        z_slice = int(z_range[1]/2)
    else:
        z_range = (min(z_bounds), max(z_bounds))
        z_slice = int((z_bounds[0]+z_bounds[1])/2)

    # Plot a few combined image slices
    quick_imshow(tomo_recon_combined[z_range[0]:z_range[1],x_slice,y_range[0]:y_range[1]],
        title=f'recon combined xslice{x_slice}', path=path,
        save_fig=self.save_figs, save_only=self.save_only, block=self.block)
    quick_imshow(tomo_recon_combined[z_range[0]:z_range[1],x_range[0]:x_range[1],y_slice],
        title=f'recon combined yslice{y_slice}', path=path,
        save_fig=self.save_figs, save_only=self.save_only, block=self.block)
    quick_imshow(tomo_recon_combined[z_slice,x_range[0]:x_range[1],y_range[0]:y_range[1]],
        title=f'recon combined zslice{z_slice}', path=path,
        save_fig=self.save_figs, save_only=self.save_only, block=self.block)

    # Save test data to file
    # combined data order: row/z,x,y
    if self.test_mode:
        np.savetxt(f'{self.output_folder}/recon_combined.txt', tomo_recon_combined[
            z_slice,x_range[0]:x_range[1],y_range[0]:y_range[1]], fmt='%.6e')

    # Add image reconstruction to reconstructed data NXprocess
    # combined data order: row/z,x,y
    nxprocess.data = NXdata()
    nxprocess.attrs['default'] = 'data'
    if x_bounds is not None:
        nxprocess.x_bounds = x_bounds
    if y_bounds is not None:
        nxprocess.y_bounds = y_bounds
    if z_bounds is not None:
        nxprocess.z_bounds = z_bounds
    nxprocess.data['combined_data'] = tomo_recon_combined
    nxprocess.data.attrs['signal'] = 'combined_data'

    # Create a copy of the input Nexus object and remove reconstructed data
    exclude_items = [f'{nxentry._name}/reconstructed_data/data',
        f'{nxentry._name}/data/reconstructed_data']
    nxroot_copy = nxcopy(nxroot, exclude_nxpaths=exclude_items)

    # Add the combined data NXprocess to the new Nexus object
    nxentry_copy = nxroot_copy[nxroot_copy.attrs['default']]
    nxentry_copy.combined_data = nxprocess
    if 'data' not in nxentry_copy:
        nxentry_copy.data = NXdata()
        nxentry_copy.attrs['default'] = 'data'
    nxentry_copy.data.makelink(nxprocess.data.combined_data, name='combined_data')
    nxentry_copy.data.attrs['signal'] = 'combined_data'

    return(nxroot_copy)
722
def _gen_dark(self, nxentry, reduced_data):
    """Generate dark field.

    The dark field is taken from the detector frames with image_key == 2
    when the detector has both `image_key` and `data`, otherwise from
    `nxentry.spec_scans.dark_field`. A 3D stack is reduced by its median
    over the first axis. Pixels above min + 2*(median - min) are replaced,
    together with any NaN/+inf, by the mean of the remaining pixels.

    :param nxentry: NeXus entry holding the raw data
    :param reduced_data: NXprocess that receives the result
    :return: `reduced_data` with a new NXdata group `data` containing
        `dark_field` (any existing `data` group is replaced)
    :raises ValueError: when the dark field images are neither 2D nor 3D
    """
    # Get the dark field images
    image_key = nxentry.instrument.detector.get('image_key', None)
    if image_key and 'data' in nxentry.instrument.detector:
        field_indices = [index for index, key in enumerate(image_key) if key == 2]
        tdf_stack = nxentry.instrument.detector.data[field_indices,:,:]
        # RV the default NXtomo form does not accomodate bright or dark field stacks
    else:
        dark_field_scans = nxentry.spec_scans.dark_field
        dark_field = FlatField.construct_from_nxcollection(dark_field_scans)
        prefix = str(nxentry.instrument.detector.local_name)
        tdf_stack = dark_field.get_detector_data(prefix)
        if isinstance(tdf_stack, list):
            assert(len(tdf_stack) == 1) # TODO
            tdf_stack = tdf_stack[0]

    # Take median
    if tdf_stack.ndim == 2:
        # Work on a copy: the cutoff masking below writes NaN in place, which
        # would modify the caller's image and fails for integer detector data
        tdf = np.array(tdf_stack)
        if not np.issubdtype(tdf.dtype, np.floating):
            tdf = tdf.astype(np.float64)
    elif tdf_stack.ndim == 3:
        tdf = np.median(tdf_stack, axis=0)
        del tdf_stack
    else:
        raise ValueError(f'Invalid tdf_stack shape ({tdf_stack.shape})')

    # Remove dark field intensities above the cutoff
    #RV tdf_cutoff = None
    tdf_cutoff = tdf.min()+2*(np.median(tdf)-tdf.min())
    logger.debug(f'tdf_cutoff = {tdf_cutoff}')
    if tdf_cutoff is not None:
        if not is_num(tdf_cutoff, ge=0):
            logger.warning(f'Ignoring illegal value of tdf_cutoff {tdf_cutoff}')
        else:
            tdf[tdf > tdf_cutoff] = np.nan
            logger.debug(f'tdf_cutoff = {tdf_cutoff}')

    # Remove nans
    tdf_mean = np.nanmean(tdf)
    logger.debug(f'tdf_mean = {tdf_mean}')
    np.nan_to_num(tdf, copy=False, nan=tdf_mean, posinf=tdf_mean, neginf=0.)

    # Plot dark field
    if self.galaxy_flag:
        quick_imshow(tdf, title='dark field', path='tomo_reduce_plots', save_fig=self.save_figs,
            save_only=self.save_only)
    elif not self.test_mode:
        quick_imshow(tdf, title='dark field', path=self.output_folder, save_fig=self.save_figs,
            save_only=self.save_only)
        clear_imshow('dark field')

    # Add dark field to reduced data NXprocess
    reduced_data.data = NXdata()
    reduced_data.data['dark_field'] = tdf

    return(reduced_data)
781
def _gen_bright(self, nxentry, reduced_data):
    """Generate the bright field and add it to the reduced data.

    The bright field images are taken from the detector data stack (frames
    with image_key == 1) when image keys are available, otherwise from the
    SPEC bright field scans. A stack of images is reduced to its per-pixel
    median, the dark field stored in reduced_data (if any) is subtracted, and
    all values below one are set to one.

    :param nxentry: NeXus entry with the instrument/detector information
        (and spec_scans.bright_field when there are no image keys).
    :param reduced_data: NXprocess that receives data.bright_field.
    :return: reduced_data with data.bright_field added.
    :raises ValueError: if the bright field stack is neither 2D nor 3D.
    """
    # Get the bright field images
    image_key = nxentry.instrument.detector.get('image_key', None)
    if image_key and 'data' in nxentry.instrument.detector:
        field_indices = [index for index, key in enumerate(image_key) if key == 1]
        tbf_stack = nxentry.instrument.detector.data[field_indices,:,:]
        # RV the default NXtomo form does not accomodate bright or dark field stacks
    else:
        bright_field_scans = nxentry.spec_scans.bright_field
        bright_field = FlatField.construct_from_nxcollection(bright_field_scans)
        prefix = str(nxentry.instrument.detector.local_name)
        tbf_stack = bright_field.get_detector_data(prefix)
        if isinstance(tbf_stack, list):
            assert(len(tbf_stack) == 1) # TODO
            tbf_stack = tbf_stack[0]

    # Take the median if there is more than one image.
    # Median or mean: It may be best to try the median because of some image
    # artifacts that arise due to crinkles in the upstream kapton tape windows
    # causing some phase contrast images to appear on the detector.
    # One thing that also may be useful in a future implementation is to do a
    # brightfield adjustment on EACH frame of the tomo based on a ROI in the
    # corner of the frame where there is no sample but there is the direct X-ray
    # beam because there is frame to frame fluctuations from the incoming beam.
    # We don't typically account for them but potentially could.
    if tbf_stack.ndim == 2:
        # Work on a float copy: the in-place dark field subtraction below would
        # fail for integer detector data and must not modify the source images
        tbf = np.array(tbf_stack, dtype=np.float64)
    elif tbf_stack.ndim == 3:
        tbf = np.median(tbf_stack, axis=0)
        del tbf_stack
    else:
        raise ValueError(f'Invalid tbf_stack shape ({tbf_stack.shape})')

    # Subtract dark field
    if 'data' in reduced_data and 'dark_field' in reduced_data.data:
        tbf -= reduced_data.data.dark_field
    else:
        logger.warning('Dark field unavailable')

    # Set any values below one to one
    # (avoid negative bright field values for spikes in dark field)
    tbf[tbf < 1] = 1

    # Plot bright field
    if self.galaxy_flag:
        quick_imshow(tbf, title='bright field', path='tomo_reduce_plots',
                save_fig=self.save_figs, save_only=self.save_only)
    elif not self.test_mode:
        quick_imshow(tbf, title='bright field', path=self.output_folder,
                save_fig=self.save_figs, save_only=self.save_only)
        clear_imshow('bright field')

    # Add bright field to reduced data NXprocess
    if 'data' not in reduced_data:
        reduced_data.data = NXdata()
    reduced_data.data['bright_field'] = tbf

    return(reduced_data)
844
def _set_detector_bounds(self, nxentry, reduced_data, img_x_bounds=None):
    """Set the vertical (row) detector bounds used for the data reduction.

    Right now the same range is used for every tomography stack.

    For stations id1a3/id3a a default range is derived from a 'rectangle'
    fit to the row profile of the bright field (with a 5% margin on each
    side) or, without a usable fit, centered on the detector. For a single
    stack the fallback width is 1 mm; for multiple stacks the width follows
    from the smallest vertical shift between stacks. Outside Galaxy the user
    may accept the default or draw a range. For other stations (FMB) the
    range is selected on the first tomography image.

    :param nxentry: NeXus entry with the instrument/detector/sample data.
    :param reduced_data: NXprocess containing data.bright_field.
    :param img_x_bounds: optional (lower, upper) row range, lower bound
        inclusive, upper bound exclusive.
    :return: the selected (lower, upper) row bounds.
    :raises ValueError: for an invalid img_x_bounds or when no bounds could
        be selected.
    :raises NotImplementedError: for multiple stacks at an FMB station.
    """
    if self.test_mode:
        return(tuple(self.test_config['img_x_bounds']))

    # Get the first tomography image and the reference heights
    image_key = nxentry.instrument.detector.get('image_key', None)
    if image_key and 'data' in nxentry.instrument.detector:
        field_indices = [index for index, key in enumerate(image_key) if key == 0]
        first_image = np.asarray(nxentry.instrument.detector.data[field_indices[0],:,:])
        theta = float(nxentry.sample.rotation_angle[field_indices[0]])
        z_translation_all = nxentry.sample.z_translation[field_indices]
        vertical_shifts = sorted(list(set(z_translation_all)))
        num_tomo_stacks = len(vertical_shifts)
    else:
        tomo_field_scans = nxentry.spec_scans.tomo_fields
        tomo_fields = TomoField.construct_from_nxcollection(tomo_field_scans)
        vertical_shifts = tomo_fields.get_vertical_shifts()
        if not isinstance(vertical_shifts, list):
            vertical_shifts = [vertical_shifts]
        prefix = str(nxentry.instrument.detector.local_name)
        t0 = time()
        first_image = tomo_fields.get_detector_data(prefix, tomo_fields.scan_numbers[0], 0)
        logger.debug(f'Getting first image took {time()-t0:.2f} seconds')
        num_tomo_stacks = len(tomo_fields.scan_numbers)
        theta = tomo_fields.theta_range['start']

    # Select image bounds
    title = f'tomography image at theta={round(theta, 2)+0}'
    if (img_x_bounds is not None
            and not is_index_range(img_x_bounds, ge=0, le=first_image.shape[0])):
        raise ValueError(f'Invalid parameter img_x_bounds ({img_x_bounds})')
    if nxentry.instrument.source.attrs['station'] in ('id1a3', 'id3a'):
        pixel_size = nxentry.instrument.detector.x_pixel_size
        # Try to get a fit from the bright field
        tbf = np.asarray(reduced_data.data.bright_field)
        tbf_shape = tbf.shape
        x_sum = np.sum(tbf, 1)
        x_sum_min = x_sum.min()
        x_sum_max = x_sum.max()
        fit = Fit.fit_data(x_sum, 'rectangle', x=np.array(range(len(x_sum))), form='atan',
                guess=True)
        parameters = fit.best_values
        x_low_fit = parameters.get('center1', None)
        x_upp_fit = parameters.get('center2', None)
        sig_low = parameters.get('sigma1', None)
        sig_upp = parameters.get('sigma2', None)
        # Only trust a successful fit with sharp edges inside the detector
        have_fit = (fit.success and x_low_fit is not None and x_upp_fit is not None
                and sig_low is not None and sig_upp is not None
                and 0 <= x_low_fit < x_upp_fit <= x_sum.size
                and (sig_low+sig_upp)/(x_upp_fit-x_low_fit) < 0.1)
        if have_fit:
            # Set a 5% margin on each side
            margin = 0.05*(x_upp_fit-x_low_fit)
            x_low_fit = max(0, x_low_fit-margin)
            x_upp_fit = min(tbf_shape[0], x_upp_fit+margin)
        if num_tomo_stacks == 1:
            if have_fit:
                # Set the default range to enclose the full fitted window
                x_low = int(x_low_fit)
                x_upp = int(x_upp_fit)
            else:
                # Center a default range of 1 mm (RV: can we get this from the slits?)
                num_x_min = int((1.0-0.5*pixel_size)/pixel_size)
                x_low = int(0.5*(tbf_shape[0]-num_x_min))
                x_upp = x_low+num_x_min
        else:
            # Get the default range from the smallest spacing of the reference heights
            delta_z = min(upp-low for low, upp in zip(vertical_shifts, vertical_shifts[1:]))
            logger.debug(f'delta_z = {delta_z}')
            num_x_min = int((delta_z-0.5*pixel_size)/pixel_size)
            logger.debug(f'num_x_min = {num_x_min}')
            if num_x_min > tbf_shape[0]:
                logger.warning('Image bounds and pixel size prevent seamless stacking')
            if have_fit:
                # Center the default range relative to the fitted window
                x_low = int(0.5*(x_low_fit+x_upp_fit-num_x_min))
            else:
                # Center the default range
                x_low = int(0.5*(tbf_shape[0]-num_x_min))
            x_upp = x_low+num_x_min
        # Keep the default range on the detector: a negative lower bound would
        # silently wrap around when used as an index or slice limit
        x_low = max(0, x_low)
        x_upp = min(tbf_shape[0], x_upp)
        if self.galaxy_flag:
            # NOTE(review): a caller-supplied img_x_bounds is not used here
            # (unlike the FMB branch below); confirm this is intended
            img_x_bounds = (x_low, x_upp)
        else:
            tmp = np.copy(tbf)
            tmp_max = tmp.max()
            tmp[x_low,:] = tmp_max
            tmp[x_upp-1,:] = tmp_max
            quick_imshow(tmp, title='bright field')
            tmp = np.copy(first_image)
            tmp_max = tmp.max()
            tmp[x_low,:] = tmp_max
            tmp[x_upp-1,:] = tmp_max
            quick_imshow(tmp, title=title)
            del tmp
            quick_plot((range(x_sum.size), x_sum),
                    ([x_low, x_low], [x_sum_min, x_sum_max], 'r-'),
                    ([x_upp, x_upp], [x_sum_min, x_sum_max], 'r-'),
                    title='sum over theta and y')
            print(f'lower bound = {x_low} (inclusive)')
            print(f'upper bound = {x_upp} (exclusive)')
            accept = input_yesno('Accept these bounds (y/n)?', 'y')
            clear_imshow('bright field')
            clear_imshow(title)
            clear_plot('sum over theta and y')
            if accept:
                img_x_bounds = (x_low, x_upp)
            else:
                while True:
                    _, img_x_bounds = draw_mask_1d(x_sum, title='select x data range',
                            legend='sum over theta and y')
                    if len(img_x_bounds) == 1:
                        break
                    print('Choose a single connected data range')
                img_x_bounds = tuple(img_x_bounds[0])
                # The upper bound is exclusive, so the range holds upp-low rows
                if (num_tomo_stacks > 1
                        and img_x_bounds[1]-img_x_bounds[0] < num_x_min):
                    logger.warning('Image bounds and pixel size prevent seamless stacking')
    else:
        if num_tomo_stacks > 1:
            raise NotImplementedError('Selecting image bounds for multiple stacks on FMB')
        # For FMB: use the first tomography image to select range
        # RV: revisit if they do tomography with multiple stacks
        x_sum = np.sum(first_image, 1)
        x_sum_min = x_sum.min()
        x_sum_max = x_sum.max()
        if self.galaxy_flag:
            if img_x_bounds is None:
                img_x_bounds = (0, first_image.shape[0])
        else:
            quick_imshow(first_image, title=title)
            print('Select vertical data reduction range from first tomography image')
            img_x_bounds = select_image_bounds(first_image, 0, title=title)
            clear_imshow(title)
            if img_x_bounds is None:
                raise ValueError('Unable to select image bounds')

    # Plot results
    if self.galaxy_flag:
        path = 'tomo_reduce_plots'
    else:
        path = self.output_folder
    x_low = img_x_bounds[0]
    x_upp = img_x_bounds[1]
    tmp = np.copy(first_image)
    tmp_max = tmp.max()
    tmp[x_low,:] = tmp_max
    tmp[x_upp-1,:] = tmp_max
    quick_imshow(tmp, title=title, path=path, save_fig=self.save_figs, save_only=self.save_only,
            block=self.block)
    del tmp
    quick_plot((range(x_sum.size), x_sum),
            ([x_low, x_low], [x_sum_min, x_sum_max], 'r-'),
            ([x_upp, x_upp], [x_sum_min, x_sum_max], 'r-'),
            title='sum over theta and y', path=path, save_fig=self.save_figs,
            save_only=self.save_only, block=self.block)

    return(img_x_bounds)
1009
def _set_zoom_or_skip(self):
    """Return the zoom percentage and theta skip used to reduce memory usage.

    Interactive selection of either option is currently disabled, so no zoom
    and no theta skip are applied.

    :return: tuple (zoom_perc, num_theta_skip), both currently None.
    """
    # Both options are left unset until interactive selection is re-enabled
    zoom_perc, num_theta_skip = None, None
    for label, value in (('zoom_perc', zoom_perc), ('num_theta_skip', num_theta_skip)):
        logger.debug(f'{label} = {value}')
    return zoom_perc, num_theta_skip
1028
def _gen_tomo(self, nxentry, reduced_data):
    """Generate the reduced (normalized and linearized) tomography fields.

    Each tomography stack is cropped to the image bounds stored in
    reduced_data (the full frame by default) and converted to float64. It is
    then divided by the bright field, clipped from below at 1e-6, turned
    into -log, cleaned of non-finite values and stored as float32 in
    reduced_data.data.tomo_fields. The rotation angles and the
    horizontal/vertical stack translations are stored as well.

    :param nxentry: NeXus entry with the instrument/detector/sample data
        (or spec_scans.tomo_fields when there are no image keys).
    :param reduced_data: NXprocess containing data.bright_field and
        optionally img_x_bounds/img_y_bounds.
    :return: the updated reduced_data.
    :raises ValueError: if the tomography images are not stored in
        sequence-number order.
    """
    # Get full bright field
    tbf = np.asarray(reduced_data.data.bright_field)
    tbf_shape = tbf.shape

    # Get image bounds
    img_x_bounds = tuple(reduced_data.get('img_x_bounds', (0, tbf_shape[0])))
    img_y_bounds = tuple(reduced_data.get('img_y_bounds', (0, tbf_shape[1])))
    # Compare against the full (uncropped) frame, before tbf itself is cropped
    resize = img_x_bounds != (0, tbf_shape[0]) or img_y_bounds != (0, tbf_shape[1])

    # Dark field subtraction of the tomography images is currently disabled
    tdf = None

    # Resize bright field
    if resize:
        tbf = tbf[img_x_bounds[0]:img_x_bounds[1],img_y_bounds[0]:img_y_bounds[1]]

    # Get the tomography images
    image_key = nxentry.instrument.detector.get('image_key', None)
    if image_key and 'data' in nxentry.instrument.detector:
        field_indices_all = [index for index, key in enumerate(image_key) if key == 0]
        z_translation_all = nxentry.sample.z_translation[field_indices_all]
        z_translation_levels = sorted(list(set(z_translation_all)))
        horizontal_shifts = []
        vertical_shifts = []
        thetas = None
        tomo_stacks = []
        # One stack per distinct z translation
        for z_translation in z_translation_levels:
            field_indices = [field_indices_all[index]
                    for index, z in enumerate(z_translation_all) if z == z_translation]
            horizontal_shift = list(set(nxentry.sample.x_translation[field_indices]))
            assert(len(horizontal_shift) == 1)
            horizontal_shifts += horizontal_shift
            vertical_shift = list(set(nxentry.sample.z_translation[field_indices]))
            assert(len(vertical_shift) == 1)
            vertical_shifts += vertical_shift
            sequence_numbers = nxentry.instrument.detector.sequence_number[field_indices]
            if thetas is None:
                thetas = np.asarray(nxentry.sample.rotation_angle[field_indices]) \
                        [sequence_numbers]
            else:
                # All stacks must share the same rotation angles
                assert(all(thetas[i] == nxentry.sample.rotation_angle[field_indices[index]]
                        for i, index in enumerate(sequence_numbers)))
            assert(list(set(sequence_numbers)) == list(range(len(sequence_numbers))))
            if list(sequence_numbers) == list(range(len(sequence_numbers))):
                tomo_stack = np.asarray(nxentry.instrument.detector.data[field_indices])
            else:
                raise ValueError('Unable to load the tomography images')
            tomo_stacks.append(tomo_stack)
    else:
        tomo_field_scans = nxentry.spec_scans.tomo_fields
        tomo_fields = TomoField.construct_from_nxcollection(tomo_field_scans)
        horizontal_shifts = tomo_fields.get_horizontal_shifts()
        vertical_shifts = tomo_fields.get_vertical_shifts()
        prefix = str(nxentry.instrument.detector.local_name)
        t0 = time()
        tomo_stacks = tomo_fields.get_detector_data(prefix)
        logger.debug(f'Getting tomography images took {time()-t0:.2f} seconds')
        thetas = np.linspace(tomo_fields.theta_range['start'], tomo_fields.theta_range['end'],
                tomo_fields.theta_range['num'])
        if not isinstance(tomo_stacks, list):
            horizontal_shifts = [horizontal_shifts]
            vertical_shifts = [vertical_shifts]
            tomo_stacks = [tomo_stacks]

    reduced_tomo_stacks = []
    if self.galaxy_flag:
        path = 'tomo_reduce_plots'
    else:
        path = self.output_folder
    for i, tomo_stack in enumerate(tomo_stacks):
        # Crop to the image bounds (the same range for each stack) and convert
        # to a new float64 array: the in-place numexpr operations below need a
        # floating point output array and must not modify the loaded images
        t0 = time()
        tomo_stack = np.array(tomo_stack[:,img_x_bounds[0]:img_x_bounds[1],
                img_y_bounds[0]:img_y_bounds[1]], dtype='float64')
        if resize:
            logger.debug(f'Resizing tomography images took {time()-t0:.2f} seconds')

        # Subtract dark field
        if tdf is not None:
            t0 = time()
            with set_numexpr_threads(self.num_core):
                ne.evaluate('tomo_stack-tdf', out=tomo_stack)
            logger.debug(f'Subtracting dark field took {time()-t0:.2f} seconds')

        # Normalize
        t0 = time()
        with set_numexpr_threads(self.num_core):
            ne.evaluate('tomo_stack/tbf', out=tomo_stack, truediv=True)
        logger.debug(f'Normalizing took {time()-t0:.2f} seconds')

        # Remove non-positive values and linearize data
        t0 = time()
        cutoff = 1.e-6
        with set_numexpr_threads(self.num_core):
            ne.evaluate('where(tomo_stack<cutoff, cutoff, tomo_stack)', out=tomo_stack)
        with set_numexpr_threads(self.num_core):
            ne.evaluate('-log(tomo_stack)', out=tomo_stack)
        logger.debug('Removing non-positive values and linearizing data took '+
                f'{time()-t0:.2f} seconds')

        # Get rid of nans/infs that may be introduced by normalization
        # (NaNs pass the cutoff above unchanged)
        t0 = time()
        np.nan_to_num(tomo_stack, copy=False, nan=0., posinf=0., neginf=0.)
        logger.debug(f'Remove nans/infs took {time()-t0:.2f} seconds')

        # Store as float32 to reduce memory
        # TODO optionally zoom and/or skip thetas to reduce memory further
        tomo_stack = tomo_stack.astype('float32')
        if not self.test_mode:
            if len(tomo_stacks) == 1:
                title = f'red fullres theta {round(thetas[0], 2)+0}'
            else:
                title = f'red stack {i+1} fullres theta {round(thetas[0], 2)+0}'
            quick_imshow(tomo_stack[0,:,:], title=title, path=path, save_fig=self.save_figs,
                    save_only=self.save_only, block=self.block)

        # Save test data to file
        if self.test_mode:
            row_index = int(tomo_stack.shape[1]/2)
            np.savetxt(f'{self.output_folder}/red_stack_{i+1}.txt', tomo_stack[:,row_index,:],
                    fmt='%.6e')

        # Combine resized stacks
        reduced_tomo_stacks.append(tomo_stack)

    # Add tomo field info to reduced data NXprocess
    reduced_data['rotation_angle'] = thetas
    reduced_data['x_translation'] = np.asarray(horizontal_shifts)
    reduced_data['z_translation'] = np.asarray(vertical_shifts)
    reduced_data.data['tomo_fields'] = np.asarray(reduced_tomo_stacks)

    if tdf is not None:
        del tdf
    del tbf

    return(reduced_data)
1199
def _find_center_one_plane(self, sinogram, row, thetas, eff_pixel_size, cross_sectional_dim,
        path=None, tol=0.1, num_core=1, interactive=False):
    """Find the rotation center offset for a single tomography plane.

    The center is found with Nghia Vo's method (tomopy.find_center_vo). The
    plane is reconstructed at that center and its edges are plotted. When
    interactive is set, the user can also scan a range of center offsets and
    enter the chosen offset.

    :param sinogram: sinogram of the plane, index order theta,column.
    :param row: row index of the plane (used in plot titles and log messages).
    :param thetas: rotation angles passed on to _reconstruct_one_plane.
    :param eff_pixel_size: effective pixel size.
    :param cross_sectional_dim: sample cross-sectional dimension.
    :param path: output path for the edge plots.
    :param tol: tolerance for a phase correlation center search (currently unused).
    :param num_core: number of cores; find_center_vo is capped at
        num_core_tomopy_limit.
    :param interactive: also run the interactive center offset search. The
        default (False) accepts Nghia Vo's result, as before.
    :return: center offset in pixels relative to the column center.
    """
    # sinogram index order: theta,column
    # need column,theta for iradon, so take transpose
    sinogram = np.asarray(sinogram)
    sinogram_T = sinogram.T
    center = sinogram.shape[1]/2

    # Try using Nghia Vo’s method
    t0 = time()
    num_core_vo = min(num_core, num_core_tomopy_limit)
    logger.debug(f'Running find_center_vo on {num_core_vo} cores ...')
    tomo_center = tomopy.find_center_vo(sinogram, ncore=num_core_vo)
    logger.debug(f'... done in {time()-t0:.2f} seconds')
    logger.info(f'Finding the center using Nghia Vo’s method took {time()-t0:.2f} seconds')
    center_offset_vo = tomo_center-center
    logger.info(f'Center at row {row} using Nghia Vo’s method = {center_offset_vo:.2f}')
    t0 = time()
    logger.debug(f'Running _reconstruct_one_plane on {num_core} cores ...')
    recon_plane = self._reconstruct_one_plane(sinogram_T, tomo_center, thetas,
            eff_pixel_size, cross_sectional_dim, False, num_core)
    logger.debug(f'... done in {time()-t0:.2f} seconds')
    logger.info(f'Reconstructing row {row} took {time()-t0:.2f} seconds')

    title = f'edges row{row} center offset{center_offset_vo:.2f} Vo'
    self._plot_edges_one_plane(recon_plane, title, path=path)

    # NOTE: tomopy.find_center_pc (phase correlation, iterated to within tol)
    # is an alternative center finding method that is not used at present.

    if not interactive:
        # Accept the center offset found with Nghia Vo’s method
        del sinogram_T
        del recon_plane
        return float(center_offset_vo)

    # Interactive search over a user-supplied range of center offsets
    while True:
        center_offset_low = input_int('\nEnter lower bound for center offset', ge=-center,
                le=center)
        center_offset_upp = input_int('Enter upper bound for center offset',
                ge=center_offset_low, le=center)
        if center_offset_upp == center_offset_low:
            center_offset_step = 1
        else:
            center_offset_step = input_int('Enter step size for center offset search', ge=1,
                    le=center_offset_upp-center_offset_low)
        num_center_offset = 1+int((center_offset_upp-center_offset_low)/center_offset_step)
        center_offsets = np.linspace(center_offset_low, center_offset_upp, num_center_offset)
        for center_offset in center_offsets:
            if center_offset == center_offset_vo:
                # Already reconstructed and plotted above
                continue
            t0 = time()
            logger.debug(f'Running _reconstruct_one_plane on {num_core} cores ...')
            recon_plane = self._reconstruct_one_plane(sinogram_T, center_offset+center, thetas,
                    eff_pixel_size, cross_sectional_dim, False, num_core)
            logger.debug(f'... done in {time()-t0:.2f} seconds')
            logger.info(f'Reconstructing center_offset {center_offset} took '+
                    f'{time()-t0:.2f} seconds')
            title = f'edges row{row} center_offset{center_offset:.2f}'
            self._plot_edges_one_plane(recon_plane, title, path=path)
        if input_int('\nContinue (0) or end the search (1)', ge=0, le=1):
            break

    del sinogram_T
    del recon_plane
    center_offset = input_num(' Enter chosen center offset', ge=-center, le=center)
    return float(center_offset)
1301
def _reconstruct_one_plane(self, tomo_plane_T, center, thetas, eff_pixel_size,
        cross_sectional_dim, plot_sinogram=True, num_core=1):
    """Invert the sinogram of a single tomography plane.

    The sinogram is cropped around the rotation center to a half-width of
    0.55*cross_sectional_dim/eff_pixel_size pixels (at most half the number
    of columns). The cropped sinogram is inverted with iradon, Gaussian
    filtered and cleaned of ring artifacts.

    :param tomo_plane_T: sinogram in column,theta index order.
    :param center: rotation center column, 0 <= center < number of columns.
    :param thetas: projection angles passed to iradon.
    :param eff_pixel_size: effective pixel size.
    :param cross_sectional_dim: sample cross-sectional dimension.
    :param plot_sinogram: plot the cropped sinogram (outside Galaxy only).
    :param num_core: number of cores for the ring artifact removal.
    :return: the reconstructed plane with a leading axis of length one.
    """
    num_col = tomo_plane_T.shape[0]
    assert(0 <= center < num_col)
    center_offset = center-num_col/2
    two_offset = 2*int(np.round(center_offset))

    # Half-width to keep: 10% slack on the sample radius to avoid edge
    # effects, but never more than half the number of columns
    max_rad = min(int(0.55*(cross_sectional_dim/eff_pixel_size)), 0.5*num_col)
    dist_from_edge = max(1, int(np.floor((num_col-np.abs(two_offset))/2.)-max_rad))
    if two_offset < 0:
        first_col, last_col = dist_from_edge, two_offset-dist_from_edge
    else:
        first_col, last_col = two_offset+dist_from_edge, -dist_from_edge
    logger.debug(f'sinogram range = [{first_col}, {last_col}]')
    sinogram = tomo_plane_T[first_col:last_col,:]

    if plot_sinogram and not self.galaxy_flag:
        quick_imshow(sinogram.T, f'sinogram center offset{center_offset:.2f}', aspect='auto',
                path=self.output_folder, save_fig=self.save_figs, save_only=self.save_only,
                block=self.block)

    # Inverting sinogram
    start = time()
    recon_sinogram = iradon(sinogram, theta=thetas, circle=True)
    logger.debug(f'Inverting sinogram took {time()-start:.2f} seconds')
    del sinogram

    # Filter parameters; recon_parameters is not read from a configuration
    # yet (self.config.get('recon_parameters')), so the defaults apply
    recon_parameters = None
    sigma, ring_width = 1.0, 15
    if recon_parameters is not None:
        sigma = recon_parameters.get('gaussian_sigma', 1.0)
        if not is_num(sigma, ge=0.0):
            logger.warning(f'Invalid gaussian_sigma ({sigma}) in _reconstruct_one_plane, '+
                    'set to a default value of 1.0')
            sigma = 1.0
        ring_width = recon_parameters.get('ring_width', 15)
        if not is_int(ring_width, ge=0):
            logger.warning(f'Invalid ring_width ({ring_width}) in _reconstruct_one_plane, '+
                    'set to a default value of 15')
            ring_width = 15

    # Performing Gaussian filtering and removing ring artifacts
    start = time()
    recon_clean = np.expand_dims(
            spi.gaussian_filter(recon_sinogram, sigma, mode='nearest'), axis=0)
    del recon_sinogram
    recon_clean = tomopy.misc.corr.remove_ring(recon_clean, rwidth=ring_width, ncore=num_core)
    logger.debug(f'Filtering and removing ring artifacts took {time()-start:.2f} seconds')

    return recon_clean
1356
def _plot_edges_one_plane(self, recon_plane, title, path=None):
    """Denoise a reconstructed plane and plot it with two color maps.

    The plane is denoised with total-variation (Chambolle) denoising. Its
    first slice ([0,:,:]) is then shown in 'coolwarm' and in 'gray', the
    latter on a symmetric scale [-max, max].

    :param recon_plane: reconstructed plane with a leading axis (as returned
        by _reconstruct_one_plane).
    :param title: base title for both plots.
    :param path: output path for the plots, defaults to self.output_folder.
    """
    # vis_parameters is not read from a configuration yet
    # (self.config.get('vis_parameters')), so the default weight applies
    vis_parameters = None
    weight = 0.1
    if vis_parameters is not None:
        weight = vis_parameters.get('denoise_weight', 0.1)
        if not is_num(weight, ge=0.0):
            logger.warning(f'Invalid weight ({weight}) in _plot_edges_one_plane, '+
                    'set to a default value of 0.1')
            weight = 0.1
    edges = denoise_tv_chambolle(recon_plane, weight=weight)
    plane = edges[0,:,:]
    vmax = np.max(plane)
    plot_kwargs = dict(path=self.output_folder if path is None else path,
            save_fig=self.save_figs, save_only=self.save_only, block=self.block)
    quick_imshow(plane, f'{title} coolwarm', cmap='coolwarm', **plot_kwargs)
    quick_imshow(plane, f'{title} gray', cmap='gray', vmin=-vmax, vmax=vmax, **plot_kwargs)
    del plane, edges
1377
def _reconstruct_one_tomo_stack(self, tomo_stack, thetas, center_offsets=None, num_core=1,
        algorithm='gridrec'):
    """Reconstruct a single tomography stack.

    Horizontal stripes are removed (remove_stripe_fw), the stack is
    reconstructed with tomopy.recon, optionally refined with secondary SART
    iterations, and ring artifacts are removed.

    :param tomo_stack: stack in row,theta,column order (sinogram order).
    :param thetas: rotation angles in degrees.
    :param center_offsets: rotation axis offsets in pixels relative to the
        column center. None or empty means no offset; two values are
        linearly interpolated from the first to the last row; otherwise one
        value per row. The argument is not modified.
    :param num_core: number of cores; remove_stripe_fw is capped at
        num_core_tomopy_limit.
    :param algorithm: tomopy algorithm for the initial reconstruction.
    :return: the reconstructed stack.
    :raises ValueError: if the number of center offsets does not match the
        number of rows.
    """
    # RV should we remove stripes?
    # https://tomopy.readthedocs.io/en/latest/api/tomopy.prep.stripe.html
    # RV should we remove rings?
    # https://tomopy.readthedocs.io/en/latest/api/tomopy.misc.corr.html
    # RV: Add an option to do (extra) secondary iterations later or to do some sort of convergence test?
    num_row = tomo_stack.shape[0]
    if center_offsets is None or not len(center_offsets):
        centers = np.zeros((num_row))
    elif len(center_offsets) == 2:
        centers = np.linspace(center_offsets[0], center_offsets[1], num_row)
    else:
        # Float copy, so the shift below does not modify the caller's offsets
        centers = np.array(center_offsets, dtype=float)
        if centers.size != num_row:
            raise ValueError('center_offsets dimension mismatch in reconstruct_one_tomo_stack')
    centers += tomo_stack.shape[2]/2

    # Get reconstruction parameters; recon_parameters is not read from a
    # configuration yet (self.config.get('recon_parameters')), so the defaults apply
    recon_parameters = None
    if recon_parameters is None:
        sigma = 2.0
        secondary_iters = 0
        ring_width = 15
    else:
        sigma = recon_parameters.get('stripe_fw_sigma', 2.0)
        if not is_num(sigma, ge=0):
            logger.warning(f'Invalid stripe_fw_sigma ({sigma}) in '+
                    '_reconstruct_one_tomo_stack, set to a default value of 2.0')
            sigma = 2.0
        secondary_iters = recon_parameters.get('secondary_iters', 0)
        if not is_int(secondary_iters, ge=0):
            logger.warning(f'Invalid secondary_iters ({secondary_iters}) in '+
                    '_reconstruct_one_tomo_stack, set to a default value of 0 (skip them)')
            secondary_iters = 0
        ring_width = recon_parameters.get('ring_width', 15)
        if not is_int(ring_width, ge=0):
            logger.warning(f'Invalid ring_width ({ring_width}) in _reconstruct_one_tomo_stack, '+
                    'set to a default value of 15')
            ring_width = 15

    # Remove horizontal stripe
    t0 = time()
    num_core_stripe = min(num_core, num_core_tomopy_limit)
    logger.debug(f'Running remove_stripe_fw on {num_core_stripe} cores ...')
    tomo_stack = tomopy.prep.stripe.remove_stripe_fw(tomo_stack, sigma=sigma,
            ncore=num_core_stripe)
    logger.debug(f'... tomopy.prep.stripe.remove_stripe_fw took {time()-t0:.2f} seconds')

    # Perform initial image reconstruction
    logger.debug('Performing initial image reconstruction')
    t0 = time()
    logger.debug(f'Running recon on {num_core} cores ...')
    tomo_recon_stack = tomopy.recon(tomo_stack, np.radians(thetas), centers,
            sinogram_order=True, algorithm=algorithm, ncore=num_core)
    logger.debug(f'... done in {time()-t0:.2f} seconds')
    logger.info(f'Performing initial image reconstruction took {time()-t0:.2f} seconds')

    # Run optional secondary iterations
    if secondary_iters > 0:
        logger.debug(f'Running {secondary_iters} secondary iterations')
        # RV: SIRT_CUDA failed with "CUDA error 803: system has unsupported display
        # driver/cuda driver combination" and SIRT did not finish overnight
        options = {'method':'SART', 'proj_type':'linear', 'MinConstraint': 0,
                'num_iter':secondary_iters}
        t0 = time()
        logger.debug(f'Running recon on {num_core} cores ...')
        tomo_recon_stack = tomopy.recon(tomo_stack, np.radians(thetas), centers,
                init_recon=tomo_recon_stack, options=options, sinogram_order=True,
                algorithm=tomopy.astra, ncore=num_core)
        logger.debug(f'... done in {time()-t0:.2f} seconds')
        logger.info(f'Performing secondary iterations took {time()-t0:.2f} seconds')

    # Remove ring artifacts (use the returned array rather than relying on out=)
    t0 = time()
    tomo_recon_stack = tomopy.misc.corr.remove_ring(tomo_recon_stack, rwidth=ring_width,
            out=tomo_recon_stack, ncore=num_core)
    logger.debug(f'Removing ring artifacts took {time()-t0:.2f} seconds')

    return tomo_recon_stack
1470
def _resize_reconstructed_data(self, data, z_only=False):
    """Interactively select the bounds of the reconstructed tomography data.

    For each axis the user is asked whether to change the bounds. If so, a
    single connected index range is drawn on the sum of all stacks over the
    other two axes.

    :param data: a reconstructed stack (row(z),x,y) or a list of them.
    :param z_only: skip the x and y bound selection (both returned as None).
    :return: tuple (x_bounds, y_bounds, z_bounds). Each entry is the range
        selected with draw_mask_1d, or None when unchanged. z bounds can only
        be selected for a single stack.
    """
    # Data order: row(z),x,y or stack,row(z),x,y
    if isinstance(data, list):
        for stack in data:
            assert(stack.ndim == 3)
        tomo_recon_stacks = data
    else:
        assert(data.ndim == 3)
        tomo_recon_stacks = [data]
    num_tomo_stacks = len(tomo_recon_stacks)

    def _select_bounds(axis, prompt, default, title, legend):
        """Return one user-selected index range along the remaining axis, or None."""
        if not input_yesno(prompt, default):
            return None
        tomosum = sum(np.sum(stack, axis=axis) for stack in tomo_recon_stacks)
        _, bounds = draw_mask_1d(tomosum, current_index_ranges=None, title=title,
                legend=legend)
        while len(bounds) != 1:
            print('Please select exactly one continuous range')
            _, bounds = draw_mask_1d(tomosum, title=title, legend=legend)
        return bounds[0]

    if z_only:
        x_bounds = None
        y_bounds = None
    else:
        # Selecting x bounds (in yz-plane)
        x_bounds = _select_bounds((0,2), '\nDo you want to change the image x-bounds (y/n)?',
                'y', 'select x data range', 'recon stack sum yz')
        logger.debug(f'x_bounds = {x_bounds}')

        # Selecting y bounds (in xz-plane)
        y_bounds = _select_bounds((0,1), '\nDo you want to change the image y-bounds (y/n)?',
                'y', 'select y data range', 'recon stack sum xz')
        logger.debug(f'y_bounds = {y_bounds}')

    # Selecting z bounds (in xy-plane) (only valid for a single image stack)
    if num_tomo_stacks != 1:
        z_bounds = None
    else:
        z_bounds = _select_bounds((1,2), 'Do you want to change the image z-bounds (y/n)?',
                'n', 'select z data range', 'recon stack sum xy')
        logger.debug(f'z_bounds = {z_bounds}')

    return(x_bounds, y_bounds, z_bounds)
1568
1569
def run_tomo(input_file:str, output_file:str, modes:list[str], center_file=None, num_core=-1,
        output_folder='.', save_figs='no', test_mode=False) -> None:
    '''Run the requested steps of the tomography workflow.

    The steps always execute in the fixed order reduce_data -> find_center ->
    reconstruct_data -> combine_data. Each step operates on the result of the
    previous one, or on the data read from input_file for the first step run.

    :param input_file: file read with Tomo.read to provide the initial data
    :param output_file: file the final result is written to with Tomo.write
        (not written in test mode)
    :param modes: workflow steps to run, any of 'reduce_data', 'find_center',
        'reconstruct_data', 'combine_data' or 'all'. A single mode may also be
        given as a plain string; None is equivalent to ['all']
    :param center_file: file read with Tomo.read to obtain the rotation axis
        centers; required when 'reconstruct_data' is requested without
        'find_center' (or 'all')
    :param num_core: passed through to Tomo (presumably the number of
        processor cores to use; -1 by default)
    :param output_folder: passed through to Tomo; in test mode the log file
        tomo.log is also written here
    :param save_figs: passed through to Tomo
    :param test_mode: if True, (re)configure logging to <output_folder>/tomo.log
        at INFO level and skip writing output_file
    :raises ValueError: if any mode is invalid, or if the data must be
        reconstructed without find_center and no center_file is given
    '''
    if test_mode:
        logging_format = '%(asctime)s : %(levelname)s - %(module)s : %(funcName)s - %(message)s'
        # force=True replaces any handlers already installed on the root logger
        logging.basicConfig(filename=f'{output_folder}/tomo.log', filemode='w',
                format=logging_format, level=logging.INFO, force=True)
    logger.info(f'input_file = {input_file}')
    logger.info(f'center_file = {center_file}')
    logger.info(f'output_file = {output_file}')
    logger.debug(f'modes= {modes}')
    logger.debug(f'num_core= {num_core}')
    logger.info(f'output_folder = {output_folder}')
    logger.info(f'save_figs = {save_figs}')
    logger.info(f'test_mode = {test_mode}')

    # Validate and normalize the requested workflow modes
    legal_modes = ('reduce_data', 'find_center', 'reconstruct_data', 'combine_data', 'all')
    if modes is None:
        modes = ['all']
    elif isinstance(modes, str):
        # A bare string would otherwise be iterated character by character
        modes = [modes]
    else:
        # Materialize so that one-shot iterables survive the repeated membership tests
        modes = list(modes)
    if not all(mode in legal_modes for mode in modes):
        raise ValueError(f'Invalid parameter modes ({modes})')

    def requested(mode:str) -> bool:
        '''Return True if the given workflow step is to be run.'''
        return mode in modes or 'all' in modes

    run_find_center = requested('find_center')
    run_reconstruct = requested('reconstruct_data')

    # Fail before any (potentially slow) data reading/reduction when the
    # rotation axis centers can neither be computed nor read from file
    if run_reconstruct and not run_find_center and center_file is None:
        raise ValueError('Parameter center_file is required to reconstruct the data '
                'when the find_center mode is not requested')

    # Instantiate Tomo object
    tomo = Tomo(num_core=num_core, output_folder=output_folder, save_figs=save_figs,
            test_mode=test_mode)

    # Read input file
    data = tomo.read(input_file)

    # Generate reduced tomography images
    if requested('reduce_data'):
        data = tomo.gen_reduced_data(data)

    # Find rotation axis centers for the tomography stacks
    center_data = None
    if run_find_center:
        center_data = tomo.find_centers(data)

    # Reconstruct tomography stacks
    if run_reconstruct:
        if center_data is None:
            center_data = tomo.read(center_file)
        data = tomo.reconstruct_data(data, center_data)
        # The centers have been consumed; the reconstructed data is now the result
        center_data = None

    # Combine reconstructed tomography stacks
    if requested('combine_data'):
        data = tomo.combine_data(data)

    # Write output file: the centers if they were found but not consumed by a
    # reconstruction, otherwise the processed data
    if data is not None and not test_mode:
        if center_data is None:
            tomo.write(data, output_file)
        else:
            tomo.write(center_data, output_file)

    logger.info(f'Completed modes: {modes}')
1620 data = tomo.combine_data(data)
1621
1622 # Write output file
1623 if data is not None and not test_mode:
1624 if center_data is None:
1625 data = tomo.write(data, output_file)
1626 else:
1627 data = tomo.write(center_data, output_file)
1628
1629 logger.info(f'Completed modes: {modes}')