Mercurial > repos > rv43 > test_tomo_reconstruct
comparison workflow/run_tomo.py @ 0:98e23dff1de2 draft default tip
planemo upload for repository https://github.com/rolfverberg/galaxytools commit f8c4bdb31c20c468045ad5e6eb255a293244bc6c-dirty
author | rv43 |
---|---|
date | Tue, 21 Mar 2023 16:22:42 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:98e23dff1de2 |
---|---|
1 #!/usr/bin/env python3 | |
2 | |
3 import logging | |
4 logger = logging.getLogger(__name__) | |
5 | |
6 import numpy as np | |
7 try: | |
8 import numexpr as ne | |
9 except: | |
10 pass | |
11 try: | |
12 import scipy.ndimage as spi | |
13 except: | |
14 pass | |
15 | |
16 from multiprocessing import cpu_count | |
17 from nexusformat.nexus import * | |
18 from os import mkdir | |
19 from os import path as os_path | |
20 try: | |
21 from skimage.transform import iradon | |
22 except: | |
23 pass | |
24 try: | |
25 from skimage.restoration import denoise_tv_chambolle | |
26 except: | |
27 pass | |
28 from time import time | |
29 try: | |
30 import tomopy | |
31 except: | |
32 pass | |
33 from yaml import safe_load, safe_dump | |
34 | |
35 try: | |
36 from msnctools.fit import Fit | |
37 except: | |
38 from fit import Fit | |
39 try: | |
40 from msnctools.general import illegal_value, is_int, is_int_pair, is_num, is_index_range, \ | |
41 input_int, input_num, input_yesno, input_menu, draw_mask_1d, select_image_bounds, \ | |
42 select_one_image_bound, clear_imshow, quick_imshow, clear_plot, quick_plot | |
43 except: | |
44 from general import illegal_value, is_int, is_int_pair, is_num, is_index_range, \ | |
45 input_int, input_num, input_yesno, input_menu, draw_mask_1d, select_image_bounds, \ | |
46 select_one_image_bound, clear_imshow, quick_imshow, clear_plot, quick_plot | |
47 | |
48 try: | |
49 from workflow.models import import_scanparser, FlatField, TomoField, TomoWorkflow | |
50 from workflow.__version__ import __version__ | |
51 except: | |
52 pass | |
53 | |
# Module-level core-count limit. It is not referenced in the visible part of
# this file; presumably it caps the number of cores handed to tomopy --
# TODO confirm at its use sites.
num_core_tomopy_limit = 24
55 | |
def nxcopy(nxobject:NXobject, exclude_nxpaths:list[str]=None, nxpath_prefix:str='') -> NXobject:
    '''Function that returns a copy of a nexus object, optionally excluding certain child items.

    Only the `default` attribute is copied for the top-level object; for
    every nested group all attributes are copied. Children that are
    `NXgroup` instances are copied recursively, all other children are
    assigned to the copy as they are.

    :param nxobject: the original nexus object to return a "copy" of
    :type nxobject: nexusformat.nexus.NXobject
    :param exclude_nxpaths: a list of '/'-separated paths (relative to
        `nxobject`) to child nexus objects that should be excluded from
        the returned "copy", defaults to `None` (nothing is excluded)
    :type exclude_nxpaths: list[str], optional
    :param nxpath_prefix: For use in recursive calls from inside this
        function only!
    :type nxpath_prefix: str
    :return: a copy of `nxobject` with some children optionally excluded.
    :rtype: NXobject
    '''
    # Avoid a mutable default argument; an empty tuple excludes nothing.
    excluded = () if exclude_nxpaths is None else exclude_nxpaths

    nxobject_copy = nxobject.__class__()
    if not nxpath_prefix:
        # Top-level call: only carry over the 'default' attribute
        if 'default' in nxobject.attrs:
            nxobject_copy.attrs['default'] = nxobject.attrs['default']
    else:
        for k, v in nxobject.attrs.items():
            nxobject_copy.attrs[k] = v

    for k, v in nxobject.items():
        # NeXus paths always use '/' as separator. Building them with
        # os.path.join would yield '\\'-separated paths on Windows that can
        # never match the '/'-separated entries in the exclude list.
        nxpath = f'{nxpath_prefix}/{k}' if nxpath_prefix else k

        if nxpath in excluded:
            continue

        if isinstance(v, NXgroup):
            nxobject_copy[k] = nxcopy(v, exclude_nxpaths=excluded, nxpath_prefix=nxpath)
        else:
            nxobject_copy[k] = v

    return(nxobject_copy)
92 | |
class set_numexpr_threads:
    '''Context manager that sets the numexpr thread count for the duration
    of a `with` block and puts back the previous value on exit.

    :param num_core: requested number of threads; `None`, values below 1
        and values above `cpu_count()` all select `cpu_count()`
    '''

    def __init__(self, num_core):
        limit = cpu_count()
        if num_core is not None and not (num_core < 1 or num_core > limit):
            self.num_core = num_core
        else:
            self.num_core = limit

    def __enter__(self):
        # Keep whatever set_num_threads hands back so __exit__ can restore it
        previous = ne.set_num_threads(self.num_core)
        self.num_core_org = previous

    def __exit__(self, exc_type, exc_value, traceback):
        ne.set_num_threads(self.num_core_org)
106 | |
107 class Tomo: | |
108 """Processing tomography data with misalignment. | |
109 """ | |
def __init__(self, galaxy_flag=False, num_core=-1, output_folder='.', save_figs=None,
        test_mode=False):
    """Configure the run mode, output folder, figure handling and core count.

    :param galaxy_flag: run in galaxy mode; `output_folder`, `test_mode`
        and `save_figs` are then ignored (with a warning when they differ
        from their defaults) and figures are only saved
    :type galaxy_flag: bool
    :param num_core: number of cores to use, -1 selects `cpu_count()`;
        values above `cpu_count()` are reduced to it
    :param output_folder: output directory, created when it does not exist
    :param save_figs: one of 'only', 'yes', 'no' or `None` ('no' outside
        galaxy mode); forced to 'only' in test mode
    :param test_mode: run in test mode (non-galaxy mode only)
    :type test_mode: bool
    :raises ValueError: for an invalid `galaxy_flag`, `test_mode`,
        `save_figs` or `num_core`
    """
    if not isinstance(galaxy_flag, bool):
        raise ValueError(f'Invalid parameter galaxy_flag ({galaxy_flag})')
    self.galaxy_flag = galaxy_flag
    self.num_core = num_core

    if not self.galaxy_flag:
        self.output_folder = os_path.abspath(output_folder)
        if not os_path.isdir(output_folder):
            mkdir(self.output_folder)
        if not isinstance(test_mode, bool):
            raise ValueError(f'Invalid parameter test_mode ({test_mode})')
        self.test_mode = test_mode
        if save_figs is None:
            save_figs = 'no'
    else:
        # Galaxy mode dictates these settings; only warn about overrides
        if output_folder != '.':
            logger.warning('Ignoring output_folder in galaxy mode')
        self.output_folder = '.'
        if test_mode != False:
            logger.warning('Ignoring test_mode in galaxy mode')
        self.test_mode = False
        if save_figs is not None:
            logger.warning('Ignoring save_figs in galaxy mode')
        save_figs = 'only'

    self.test_config = {}
    if self.test_mode:
        if save_figs != 'only':
            logger.warning('Ignoring save_figs in test mode')
        save_figs = 'only'

    # Translate the save_figs keyword into the (save_only, save_figs) flags
    figure_modes = (('only', True, True), ('yes', False, True), ('no', False, False))
    for keyword, only_flag, figs_flag in figure_modes:
        if save_figs == keyword:
            self.save_only = only_flag
            self.save_figs = figs_flag
            break
    else:
        raise ValueError(f'Invalid parameter save_figs ({save_figs})')
    # Plot windows only block when figures are also shown
    self.block = not self.save_only

    if self.num_core == -1:
        self.num_core = cpu_count()
    if not is_int(self.num_core, gt=0, log=False):
        raise ValueError(f'Invalid parameter num_core ({num_core})')
    if self.num_core > cpu_count():
        logger.warning(f'num_core = {self.num_core} is larger than the number of available '
                f'processors and reduced to {cpu_count()}')
        self.num_core = cpu_count()
165 | |
def read(self, filename):
    """Read a YAML configuration file or a NeXus file.

    In galaxy mode the file extension is not used: the file is first
    parsed as YAML and, when that fails, opened as a NeXus file. Otherwise
    the reader is selected from the file extension.

    :param filename: name of the file to read
    :type filename: str
    :return: the loaded YAML content or the root object read from the
        NeXus file
    :raises ValueError: when the file cannot be opened in galaxy mode or
        has an unsupported extension otherwise
    """
    logger.info(f'looking for {filename}')
    if self.galaxy_flag:
        # Catch Exception rather than using a bare except so that
        # KeyboardInterrupt/SystemExit are not swallowed by the fallback.
        try:
            with open(filename, 'r') as f:
                config = safe_load(f)
            return(config)
        except Exception:
            try:
                with NXFile(filename, mode='r') as nxfile:
                    nxroot = nxfile.readfile()
                return(nxroot)
            except Exception as exc:
                # Keep the underlying failure attached for debugging
                raise ValueError(f'Unable to open ({filename})') from exc
    else:
        extension = os_path.splitext(filename)[1]
        if extension == '.yml' or extension == '.yaml':
            with open(filename, 'r') as f:
                config = safe_load(f)
            return(config)
        elif extension == '.nxs':
            with NXFile(filename, mode='r') as nxfile:
                nxroot = nxfile.readfile()
            return(nxroot)
        else:
            raise ValueError(f'Invalid filename extension ({extension})')
196 | |
def write(self, data, filename):
    """Write data to file, selecting the format from the file extension.

    :param data: the object to write; dumped with `safe_dump` for
        '.yml'/'.yaml', saved through its `save` method for '.nxs'/'.nex'
        and through its `to_netcdf` method for '.nc'
    :param filename: name of the output file
    :type filename: str
    :raises ValueError: for an unsupported file extension
    """
    extension = os_path.splitext(filename)[1]
    if extension in ('.yml', '.yaml'):
        with open(filename, 'w') as f:
            safe_dump(data, f)
    elif extension in ('.nxs', '.nex'):
        data.save(filename, mode='w')
    elif extension == '.nc':
        # The keyword is `path`; the former `os_path=` keyword made this
        # branch fail with a TypeError on every call.
        data.to_netcdf(path=filename)
    else:
        raise ValueError(f'Invalid filename extension ({extension})')
208 | |
def gen_reduced_data(self, data, img_x_bounds=None):
    """Generate the reduced tomography images.

    :param data: either a workflow configuration dictionary (used to
        construct a new NeXus tree through `TomoWorkflow`) or an existing
        `NXroot` object
    :param img_x_bounds: detector image bounds forwarded to
        `_set_detector_bounds`, defaults to `None`
    :return: a NeXus root object whose default entry holds a
        `reduced_data` NXprocess and a `data` group linking to it
    :raises ValueError: when `data` is neither a dict nor an `NXroot`, or
        when the configuration holds more than one sample map
    """
    logger.info('Generate the reduced tomography images')

    # Create plot galaxy path directory if needed
    if self.galaxy_flag and not os_path.exists('tomo_reduce_plots'):
        mkdir('tomo_reduce_plots')

    if isinstance(data, dict):
        # Create Nexus format object from input dictionary
        wf = TomoWorkflow(**data)
        if len(wf.sample_maps) > 1:
            raise ValueError(f'Multiple sample maps not yet implemented')
#        print(f'\nwf:\n{wf}\n')
        nxroot = NXroot()
        t0 = time()
        for sample_map in wf.sample_maps:
            logger.info(f'Start constructing the {sample_map.title} map.')
            import_scanparser(sample_map.station)
            sample_map.construct_nxentry(nxroot, include_raw_data=False)
        logger.info(f'Constructed all sample maps in {time()-t0:.2f} seconds.')
        nxentry = nxroot[nxroot.attrs['default']]
        # Get test mode configuration info
        # (only read for dictionary input; for NXroot input test_config
        # keeps the value it already had)
        if self.test_mode:
            self.test_config = data['sample_maps'][0]['test_mode']
    elif isinstance(data, NXroot):
        nxentry = data[data.attrs['default']]
    else:
        raise ValueError(f'Invalid parameter data ({data})')

    # Create an NXprocess to store data reduction (meta)data
    reduced_data = NXprocess()

    # Generate dark field
    # (optional: skipped when the entry has no dark field scans)
    if 'dark_field' in nxentry['spec_scans']:
        reduced_data = self._gen_dark(nxentry, reduced_data)

    # Generate bright field
    reduced_data = self._gen_bright(nxentry, reduced_data)

    # Set vertical detector bounds for image stack
    img_x_bounds = self._set_detector_bounds(nxentry, reduced_data, img_x_bounds=img_x_bounds)
    logger.info(f'img_x_bounds = {img_x_bounds}')
    reduced_data['img_x_bounds'] = img_x_bounds

    # Set zoom and/or theta skip to reduce memory the requirement
    # (both are stored as attributes of the NXprocess, not as fields)
    zoom_perc, num_theta_skip = self._set_zoom_or_skip()
    if zoom_perc is not None:
        reduced_data.attrs['zoom_perc'] = zoom_perc
    if num_theta_skip is not None:
        reduced_data.attrs['num_theta_skip'] = num_theta_skip

    # Generate reduced tomography fields
    reduced_data = self._gen_tomo(nxentry, reduced_data)

    # Create a copy of the input Nexus object and remove raw and any existing reduced data
    # (for dictionary input nxroot was freshly built above and is used as is)
    if isinstance(data, NXroot):
        exclude_items = [f'{nxentry._name}/reduced_data/data',
                f'{nxentry._name}/instrument/detector/data',
                f'{nxentry._name}/instrument/detector/image_key',
                f'{nxentry._name}/instrument/detector/sequence_number',
                f'{nxentry._name}/sample/rotation_angle',
                f'{nxentry._name}/sample/x_translation',
                f'{nxentry._name}/sample/z_translation',
                f'{nxentry._name}/data/data',
                f'{nxentry._name}/data/image_key',
                f'{nxentry._name}/data/rotation_angle',
                f'{nxentry._name}/data/x_translation',
                f'{nxentry._name}/data/z_translation']
        nxroot = nxcopy(data, exclude_nxpaths=exclude_items)
        nxentry = nxroot[nxroot.attrs['default']]

    # Add the reduced data NXprocess
    nxentry.reduced_data = reduced_data

    # Make the reduced data the default plottable signal of the entry
    if 'data' not in nxentry:
        nxentry.data = NXdata()
    nxentry.attrs['default'] = 'data'
    nxentry.data.makelink(nxentry.reduced_data.data.tomo_fields, name='reduced_data')
    nxentry.data.makelink(nxentry.reduced_data.rotation_angle, name='rotation_angle')
    nxentry.data.attrs['signal'] = 'reduced_data'

    return(nxroot)
293 | |
def find_centers(self, nxroot, center_rows=None, center_stack_index=None):
    """Find the calibrated center axis info

    The center offset is determined at a lower and an upper detector row
    of one selected tomography stack.

    :param nxroot: NeXus root object containing the reduced data
    :type nxroot: nexusformat.nexus.NXroot
    :param center_rows: pair of row indices at which to find the center;
        only used in galaxy mode (ignored with a warning otherwise); a
        first/second value of -1 selects the first/last row
    :param center_stack_index: zero-based index of the stack used for the
        calibration; only used in galaxy mode (ignored with a warning
        otherwise)
    :return: dictionary with the keys `lower_row`, `lower_center_offset`,
        `upper_row`, `upper_center_offset` and, for more than one stack,
        the one-based `center_stack_index`
    :rtype: dict
    :raises ValueError: for invalid input parameters
    :raises KeyError: when no valid reduced data is available
    """
    logger.info('Find the calibrated center axis info')

    if not isinstance(nxroot, NXroot):
        raise ValueError(f'Invalid parameter nxroot ({nxroot})')
    nxentry = nxroot[nxroot.attrs['default']]
    if not isinstance(nxentry, NXentry):
        raise ValueError(f'Invalid nxentry ({nxentry})')
    if self.galaxy_flag:
        if center_rows is not None:
            center_rows = tuple(center_rows)
            if not is_int_pair(center_rows):
                raise ValueError(f'Invalid parameter center_rows ({center_rows})')
    elif center_rows is not None:
        logger.warning(f'Ignoring parameter center_rows ({center_rows})')
        center_rows = None
    if self.galaxy_flag:
        if center_stack_index is not None and not is_int(center_stack_index, ge=0):
            raise ValueError(f'Invalid parameter center_stack_index ({center_stack_index})')
    elif center_stack_index is not None:
        logger.warning(f'Ignoring parameter center_stack_index ({center_stack_index})')
        center_stack_index = None

    # Create plot galaxy path directory and path if needed
    if self.galaxy_flag:
        if not os_path.exists('tomo_find_centers_plots'):
            mkdir('tomo_find_centers_plots')
        path = 'tomo_find_centers_plots'
    else:
        path = self.output_folder

    # Check if reduced data is available
    if ('reduced_data' not in nxentry or 'reduced_data' not in nxentry.data):
        raise KeyError(f'Unable to find valid reduced data in {nxentry}.')

    # Select the image stack to calibrate the center axis
    #   reduced data axes order: stack,theta,row,column
    # Note: Nexus cannot follow a link if the data it points to is too big,
    #       so get the data from the actual place, not from nxentry.data
    tomo_fields_shape = nxentry.reduced_data.data.tomo_fields.shape
    if len(tomo_fields_shape) != 4 or any(True for dim in tomo_fields_shape if not dim):
        raise KeyError('Unable to load the required reduced tomography stack')
    num_tomo_stacks = tomo_fields_shape[0]
    if num_tomo_stacks == 1:
        center_stack_index = 0
        default = 'n'
    else:
        if self.test_mode:
            center_stack_index = self.test_config['center_stack_index']-1 # make offset 0
        elif self.galaxy_flag:
            if center_stack_index is None:
                center_stack_index = int(num_tomo_stacks/2)
            if center_stack_index >= num_tomo_stacks:
                raise ValueError(f'Invalid parameter center_stack_index ({center_stack_index})')
        else:
            center_stack_index = input_int('\nEnter tomography stack index to calibrate the '
                    'center axis', ge=1, le=num_tomo_stacks, default=int(1+num_tomo_stacks/2))
            center_stack_index -= 1
        default = 'y'

    # Get thetas (in degrees)
    thetas = np.asarray(nxentry.reduced_data.rotation_angle)

    # Get effective pixel_size
    # zoom_perc is stored as an *attribute* of the reduced data NXprocess
    # (and read from .attrs below), so its presence must be tested on
    # .attrs as well; testing the group itself only looks at child fields
    # and would silently skip the zoom correction.
    if 'zoom_perc' in nxentry.reduced_data.attrs:
        eff_pixel_size = 100.*(nxentry.instrument.detector.x_pixel_size/
            nxentry.reduced_data.attrs['zoom_perc'])
    else:
        eff_pixel_size = nxentry.instrument.detector.x_pixel_size

    # Get cross sectional diameter
    cross_sectional_dim = tomo_fields_shape[3]*eff_pixel_size
    logger.debug(f'cross_sectional_dim = {cross_sectional_dim}')

    # Determine center offset at sample row boundaries
    logger.info('Determine center offset at sample row boundaries')

    # Lower row center
    if self.test_mode:
        lower_row = self.test_config['lower_row']
    elif self.galaxy_flag:
        if center_rows is None or center_rows[0] == -1:
            lower_row = 0
        else:
            lower_row = min(center_rows)
            if not 0 <= lower_row < tomo_fields_shape[2]-1:
                raise ValueError(f'Invalid parameter center_rows ({center_rows})')
    else:
        lower_row = select_one_image_bound(
                nxentry.reduced_data.data.tomo_fields[center_stack_index,0,:,:], 0, bound=0,
                title=f'theta={round(thetas[0], 2)+0}',
                bound_name='row index to find lower center', default=default, raise_error=True)
    logger.debug('Finding center...')
    t0 = time()
    lower_center_offset = self._find_center_one_plane(
            nxentry.reduced_data.data.tomo_fields[center_stack_index,:,lower_row,:],
            lower_row, thetas, eff_pixel_size, cross_sectional_dim, path=path,
            num_core=self.num_core)
    logger.debug(f'... done in {time()-t0:.2f} seconds')
    logger.debug(f'lower_row = {lower_row:.2f}')
    logger.debug(f'lower_center_offset = {lower_center_offset:.2f}')

    # Upper row center
    if self.test_mode:
        upper_row = self.test_config['upper_row']
    elif self.galaxy_flag:
        if center_rows is None or center_rows[1] == -1:
            upper_row = tomo_fields_shape[2]-1
        else:
            upper_row = max(center_rows)
            if not lower_row < upper_row < tomo_fields_shape[2]:
                raise ValueError(f'Invalid parameter center_rows ({center_rows})')
    else:
        upper_row = select_one_image_bound(
                nxentry.reduced_data.data.tomo_fields[center_stack_index,0,:,:], 0,
                bound=tomo_fields_shape[2]-1, title=f'theta={round(thetas[0], 2)+0}',
                bound_name='row index to find upper center', default=default, raise_error=True)
    logger.debug('Finding center...')
    t0 = time()
    upper_center_offset = self._find_center_one_plane(
            nxentry.reduced_data.data.tomo_fields[center_stack_index,:,upper_row,:],
            upper_row, thetas, eff_pixel_size, cross_sectional_dim, path=path,
            num_core=self.num_core)
    logger.debug(f'... done in {time()-t0:.2f} seconds')
    logger.debug(f'upper_row = {upper_row:.2f}')
    logger.debug(f'upper_center_offset = {upper_center_offset:.2f}')

    center_config = {'lower_row': lower_row, 'lower_center_offset': lower_center_offset,
            'upper_row': upper_row, 'upper_center_offset': upper_center_offset}
    if num_tomo_stacks > 1:
        center_config['center_stack_index'] = center_stack_index+1 # save as offset 1

    # Save test data to file
    if self.test_mode:
        with open(f'{self.output_folder}/center_config.yaml', 'w') as f:
            safe_dump(center_config, f)

    return(center_config)
436 | |
def reconstruct_data(self, nxroot, center_info, x_bounds=None, y_bounds=None):
    """Reconstruct the tomography data.

    :param nxroot: NeXus root object containing the reduced data
    :type nxroot: nexusformat.nexus.NXroot
    :param center_info: calibrated center axis info with the keys
        `lower_row`, `lower_center_offset`, `upper_row` and
        `upper_center_offset`
    :type center_info: dict
    :param x_bounds: pair of x-bounds for the reconstructed data
        (validated in galaxy mode, overridden by the test configuration in
        test mode and by the interactive selection otherwise)
    :param y_bounds: pair of y-bounds, handled like `x_bounds`
    :return: a copy of `nxroot` without the reduced data and with a
        `reconstructed_data` NXprocess added to its default entry
    :raises ValueError: for invalid input parameters
    :raises KeyError: when the reduced data or the center axis info is
        missing
    """
    logger.info('Reconstruct the tomography data')

    if not isinstance(nxroot, NXroot):
        raise ValueError(f'Invalid parameter nxroot ({nxroot})')
    nxentry = nxroot[nxroot.attrs['default']]
    if not isinstance(nxentry, NXentry):
        raise ValueError(f'Invalid nxentry ({nxentry})')
    if not isinstance(center_info, dict):
        raise ValueError(f'Invalid parameter center_info ({center_info})')

    # Create plot galaxy path directory and path if needed
    if self.galaxy_flag:
        if not os_path.exists('tomo_reconstruct_plots'):
            mkdir('tomo_reconstruct_plots')
        path = 'tomo_reconstruct_plots'
    else:
        path = self.output_folder

    # Check if reduced data is available
    if ('reduced_data' not in nxentry or 'reduced_data' not in nxentry.data):
        raise KeyError(f'Unable to find valid reduced data in {nxentry}.')

    # Create an NXprocess to store image reconstruction (meta)data
    nxprocess = NXprocess()

    # Get rotation axis rows and centers
    lower_row = center_info.get('lower_row')
    lower_center_offset = center_info.get('lower_center_offset')
    upper_row = center_info.get('upper_row')
    upper_center_offset = center_info.get('upper_center_offset')
    if (lower_row is None or lower_center_offset is None or upper_row is None or
            upper_center_offset is None):
        raise KeyError(f'Unable to find valid calibrated center axis info in {center_info}.')
    center_slope = (upper_center_offset-lower_center_offset)/(upper_row-lower_row)

    # Get thetas (in degrees)
    thetas = np.asarray(nxentry.reduced_data.rotation_angle)

    # Reconstruct tomography data
    #   reduced data axes order: stack,theta,row,column
    #   reconstructed data order in each stack: row/z,x,y
    # Note: Nexus cannot follow a link if the data it points to is too big,
    #       so get the data from the actual place, not from nxentry.data
    # zoom_perc is stored as an attribute of the reduced data NXprocess, so
    # test for it on .attrs (testing the group only looks at child fields)
    if 'zoom_perc' in nxentry.reduced_data.attrs:
        res_title = f'{nxentry.reduced_data.attrs["zoom_perc"]}p'
    else:
        res_title = 'fullres'
    num_tomo_stacks = nxentry.reduced_data.data.tomo_fields.shape[0]
    tomo_recon_stacks = num_tomo_stacks*[np.array([])]
    for i in range(num_tomo_stacks):
        # Convert reduced data stack from theta,row,column to row,theta,column
        logger.debug(f'Reading reduced data stack {i+1}...')
        t0 = time()
        tomo_stack = np.asarray(nxentry.reduced_data.data.tomo_fields[i])
        logger.debug(f'... done in {time()-t0:.2f} seconds')
        if len(tomo_stack.shape) != 3 or any(True for dim in tomo_stack.shape if not dim):
            raise ValueError(f'Unable to load tomography stack {i+1} for reconstruction')
        tomo_stack = np.swapaxes(tomo_stack, 0, 1)
        assert(len(thetas) == tomo_stack.shape[1])
        assert(0 <= lower_row < upper_row < tomo_stack.shape[0])
        # Linearly extrapolate the center offset to the first and last row
        center_offsets = [lower_center_offset-lower_row*center_slope,
                upper_center_offset+(tomo_stack.shape[0]-1-upper_row)*center_slope]
        t0 = time()
        logger.debug(f'Running _reconstruct_one_tomo_stack on {self.num_core} cores ...')
        tomo_recon_stack = self._reconstruct_one_tomo_stack(tomo_stack, thetas,
                center_offsets=center_offsets, num_core=self.num_core, algorithm='gridrec')
        logger.debug(f'... done in {time()-t0:.2f} seconds')
        logger.info(f'Reconstruction of stack {i+1} took {time()-t0:.2f} seconds')

        # Combine stacks
        tomo_recon_stacks[i] = tomo_recon_stack

    # Resize the reconstructed tomography data
    #   reconstructed data order in each stack: row/z,x,y
    if self.test_mode:
        x_bounds = self.test_config.get('x_bounds')
        y_bounds = self.test_config.get('y_bounds')
        z_bounds = None
    elif self.galaxy_flag:
        if x_bounds is not None and not is_int_pair(x_bounds, ge=0,
                lt=tomo_recon_stacks[0].shape[1]):
            raise ValueError(f'Invalid parameter x_bounds ({x_bounds})')
        # y is the last axis (index 2); it was previously checked against
        # the x dimension
        if y_bounds is not None and not is_int_pair(y_bounds, ge=0,
                lt=tomo_recon_stacks[0].shape[2]):
            raise ValueError(f'Invalid parameter y_bounds ({y_bounds})')
        z_bounds = None
    else:
        x_bounds, y_bounds, z_bounds = self._resize_reconstructed_data(tomo_recon_stacks)
    if x_bounds is None:
        x_range = (0, tomo_recon_stacks[0].shape[1])
        x_slice = int(x_range[1]/2)
    else:
        x_range = (min(x_bounds), max(x_bounds))
        x_slice = int((x_bounds[0]+x_bounds[1])/2)
    if y_bounds is None:
        y_range = (0, tomo_recon_stacks[0].shape[2])
        y_slice = int(y_range[1]/2)
    else:
        y_range = (min(y_bounds), max(y_bounds))
        y_slice = int((y_bounds[0]+y_bounds[1])/2)
    if z_bounds is None:
        z_range = (0, tomo_recon_stacks[0].shape[0])
        z_slice = int(z_range[1]/2)
    else:
        z_range = (min(z_bounds), max(z_bounds))
        z_slice = int((z_bounds[0]+z_bounds[1])/2)

    # Plot a few reconstructed image slices
    for i, stack in enumerate(tomo_recon_stacks):
        # The title must be built per stack: building it once before the
        # loop used the stale index of the reconstruction loop, so every
        # stack got the title of the last one and saved figures
        # overwrote each other.
        if num_tomo_stacks == 1:
            basetitle = 'recon'
        else:
            basetitle = f'recon stack {i+1}'
        title = f'{basetitle} {res_title} xslice{x_slice}'
        quick_imshow(stack[z_range[0]:z_range[1],x_slice,y_range[0]:y_range[1]],
                title=title, path=path, save_fig=self.save_figs, save_only=self.save_only,
                block=self.block)
        title = f'{basetitle} {res_title} yslice{y_slice}'
        quick_imshow(stack[z_range[0]:z_range[1],x_range[0]:x_range[1],y_slice],
                title=title, path=path, save_fig=self.save_figs, save_only=self.save_only,
                block=self.block)
        title = f'{basetitle} {res_title} zslice{z_slice}'
        quick_imshow(stack[z_slice,x_range[0]:x_range[1],y_range[0]:y_range[1]],
                title=title, path=path, save_fig=self.save_figs, save_only=self.save_only,
                block=self.block)

    # Save test data to file
    #   reconstructed data order in each stack: row/z,x,y
    if self.test_mode:
        for i, stack in enumerate(tomo_recon_stacks):
            np.savetxt(f'{self.output_folder}/recon_stack_{i+1}.txt',
                    stack[z_slice,x_range[0]:x_range[1],y_range[0]:y_range[1]], fmt='%.6e')

    # Add image reconstruction to reconstructed data NXprocess
    #   reconstructed data order in each stack: row/z,x,y
    nxprocess.data = NXdata()
    nxprocess.attrs['default'] = 'data'
    for k, v in center_info.items():
        nxprocess[k] = v
    if x_bounds is not None:
        nxprocess.x_bounds = x_bounds
    if y_bounds is not None:
        nxprocess.y_bounds = y_bounds
    if z_bounds is not None:
        nxprocess.z_bounds = z_bounds
    nxprocess.data['reconstructed_data'] = np.asarray([stack[z_range[0]:z_range[1],
            x_range[0]:x_range[1],y_range[0]:y_range[1]] for stack in tomo_recon_stacks])
    nxprocess.data.attrs['signal'] = 'reconstructed_data'

    # Create a copy of the input Nexus object and remove reduced data
    exclude_items = [f'{nxentry._name}/reduced_data/data', f'{nxentry._name}/data/reduced_data']
    nxroot_copy = nxcopy(nxroot, exclude_nxpaths=exclude_items)

    # Add the reconstructed data NXprocess to the new Nexus object
    nxentry_copy = nxroot_copy[nxroot_copy.attrs['default']]
    nxentry_copy.reconstructed_data = nxprocess
    if 'data' not in nxentry_copy:
        nxentry_copy.data = NXdata()
    nxentry_copy.attrs['default'] = 'data'
    nxentry_copy.data.makelink(nxprocess.data.reconstructed_data, name='reconstructed_data')
    nxentry_copy.data.attrs['signal'] = 'reconstructed_data'

    return(nxroot_copy)
604 | |
def combine_data(self, nxroot, x_bounds=None, y_bounds=None):
    """Combine the reconstructed tomography stacks.

    :param nxroot: NeXus root object containing the reconstructed data
    :type nxroot: nexusformat.nexus.NXroot
    :param x_bounds: pair of x-bounds for the combined data (validated in
        galaxy mode, reset to `None` in test mode and replaced by the
        interactive selection otherwise)
    :param y_bounds: pair of y-bounds, handled like `x_bounds`
    :return: a copy of `nxroot` without the reconstructed data and with a
        `combined_data` NXprocess added to its default entry, or `None`
        when only one stack is available
    :raises ValueError: for invalid input parameters
    :raises KeyError: when no valid reconstructed data is available
    """
    logger.info('Combine the reconstructed tomography stacks')

    if not isinstance(nxroot, NXroot):
        raise ValueError(f'Invalid parameter nxroot ({nxroot})')
    nxentry = nxroot[nxroot.attrs['default']]
    if not isinstance(nxentry, NXentry):
        raise ValueError(f'Invalid nxentry ({nxentry})')

    # Create plot galaxy path directory and path if needed
    if self.galaxy_flag:
        if not os_path.exists('tomo_combine_plots'):
            mkdir('tomo_combine_plots')
        path = 'tomo_combine_plots'
    else:
        path = self.output_folder

    # Check if reconstructed image data is available
    if ('reconstructed_data' not in nxentry or 'reconstructed_data' not in nxentry.data):
        raise KeyError(f'Unable to find valid reconstructed image data in {nxentry}.')

    # Create an NXprocess to store combined image reconstruction (meta)data
    nxprocess = NXprocess()

    # Get the reconstructed data
    #   reconstructed data order: stack,row(z),x,y
    # Note: Nexus cannot follow a link if the data it points to is too big,
    #       so get the data from the actual place, not from nxentry.data
    recon_field = nxentry.reconstructed_data.data.reconstructed_data
    num_tomo_stacks = recon_field.shape[0]
    if num_tomo_stacks == 1:
        logger.info('Only one stack available: leaving combine_data')
        return(None)

    # Combine the reconstructed stacks
    # (load one stack at a time to reduce risk of hitting Nexus data access limit)
    t0 = time()
    logger.debug('Combining the reconstructed stacks ...')
    tomo_recon_combined = np.concatenate(
            [np.asarray(recon_field[i]) for i in range(num_tomo_stacks)])
    logger.debug(f'... done in {time()-t0:.2f} seconds')
    logger.info(f'Combining the reconstructed stacks took {time()-t0:.2f} seconds')

    # Resize the combined tomography data stacks
    #   combined data order: row/z,x,y
    if self.test_mode:
        x_bounds = None
        y_bounds = None
        z_bounds = self.test_config.get('z_bounds')
    elif self.galaxy_flag:
        # Validate against the combined array: the name previously used
        # here (tomo_recon_stacks) does not exist in this method and
        # raised a NameError whenever bounds were supplied; y is checked
        # against the last axis (index 2).
        if x_bounds is not None and not is_int_pair(x_bounds, ge=0,
                lt=tomo_recon_combined.shape[1]):
            raise ValueError(f'Invalid parameter x_bounds ({x_bounds})')
        if y_bounds is not None and not is_int_pair(y_bounds, ge=0,
                lt=tomo_recon_combined.shape[2]):
            raise ValueError(f'Invalid parameter y_bounds ({y_bounds})')
        z_bounds = None
    else:
        x_bounds, y_bounds, z_bounds = self._resize_reconstructed_data(tomo_recon_combined,
                z_only=True)
    if x_bounds is None:
        x_range = (0, tomo_recon_combined.shape[1])
        x_slice = int(x_range[1]/2)
    else:
        x_range = x_bounds
        x_slice = int((x_bounds[0]+x_bounds[1])/2)
    if y_bounds is None:
        y_range = (0, tomo_recon_combined.shape[2])
        y_slice = int(y_range[1]/2)
    else:
        y_range = y_bounds
        y_slice = int((y_bounds[0]+y_bounds[1])/2)
    if z_bounds is None:
        z_range = (0, tomo_recon_combined.shape[0])
        z_slice = int(z_range[1]/2)
    else:
        z_range = z_bounds
        z_slice = int((z_bounds[0]+z_bounds[1])/2)

    # Plot a few combined image slices
    quick_imshow(tomo_recon_combined[z_range[0]:z_range[1],x_slice,y_range[0]:y_range[1]],
            title=f'recon combined xslice{x_slice}', path=path,
            save_fig=self.save_figs, save_only=self.save_only, block=self.block)
    quick_imshow(tomo_recon_combined[z_range[0]:z_range[1],x_range[0]:x_range[1],y_slice],
            title=f'recon combined yslice{y_slice}', path=path,
            save_fig=self.save_figs, save_only=self.save_only, block=self.block)
    quick_imshow(tomo_recon_combined[z_slice,x_range[0]:x_range[1],y_range[0]:y_range[1]],
            title=f'recon combined zslice{z_slice}', path=path,
            save_fig=self.save_figs, save_only=self.save_only, block=self.block)

    # Save test data to file
    #   combined data order: row/z,x,y
    if self.test_mode:
        np.savetxt(f'{self.output_folder}/recon_combined.txt', tomo_recon_combined[
                z_slice,x_range[0]:x_range[1],y_range[0]:y_range[1]], fmt='%.6e')

    # Add image reconstruction to reconstructed data NXprocess
    #   combined data order: row/z,x,y
    # NOTE(review): the bounds are recorded but the stored array is the
    # full, uncropped combined data -- confirm this is intended.
    nxprocess.data = NXdata()
    nxprocess.attrs['default'] = 'data'
    if x_bounds is not None:
        nxprocess.x_bounds = x_bounds
    if y_bounds is not None:
        nxprocess.y_bounds = y_bounds
    if z_bounds is not None:
        nxprocess.z_bounds = z_bounds
    nxprocess.data['combined_data'] = tomo_recon_combined
    nxprocess.data.attrs['signal'] = 'combined_data'

    # Create a copy of the input Nexus object and remove reconstructed data
    exclude_items = [f'{nxentry._name}/reconstructed_data/data',
            f'{nxentry._name}/data/reconstructed_data']
    nxroot_copy = nxcopy(nxroot, exclude_nxpaths=exclude_items)

    # Add the combined data NXprocess to the new Nexus object
    nxentry_copy = nxroot_copy[nxroot_copy.attrs['default']]
    nxentry_copy.combined_data = nxprocess
    if 'data' not in nxentry_copy:
        nxentry_copy.data = NXdata()
    nxentry_copy.attrs['default'] = 'data'
    nxentry_copy.data.makelink(nxprocess.data.combined_data, name='combined_data')
    nxentry_copy.data.attrs['signal'] = 'combined_data'

    return(nxroot_copy)
736 | |
def _gen_dark(self, nxentry, reduced_data):
    """Generate the dark field and add it to the reduced data NXprocess.

    The dark field frames are taken from the detector data (frames whose
    image_key equals 2) when an image_key is available, otherwise from the
    dark field spec scans. A stack of frames is reduced to a single image
    by taking the median over the first axis.

    :param nxentry: Nexus entry holding the instrument and scan information.
    :param reduced_data: NXprocess that receives the dark field.
    :return: reduced_data with a new data group containing 'dark_field'.
    :raises ValueError: if the dark field data is neither 2D nor 3D.
    """
    # Get the dark field images
    image_key = nxentry.instrument.detector.get('image_key', None)
    if image_key and 'data' in nxentry.instrument.detector:
        field_indices = [index for index, key in enumerate(image_key) if key == 2]
        tdf_stack = nxentry.instrument.detector.data[field_indices,:,:]
        # RV the default NXtomo form does not accomodate bright or dark field stacks
    else:
        dark_field_scans = nxentry.spec_scans.dark_field
        dark_field = FlatField.construct_from_nxcollection(dark_field_scans)
        prefix = str(nxentry.instrument.detector.local_name)
        tdf_stack = dark_field.get_detector_data(prefix)
        if isinstance(tdf_stack, list):
            assert(len(tdf_stack) == 1) # TODO
            tdf_stack = tdf_stack[0]

    # Take median
    if tdf_stack.ndim == 2:
        # Work on a float copy: the cutoff below writes NaN, which an integer
        # image cannot hold, and the input image must not be modified in place
        tdf = np.array(tdf_stack, dtype='float64')
    elif tdf_stack.ndim == 3:
        tdf = np.median(tdf_stack, axis=0)
        del tdf_stack
    else:
        raise ValueError(f'Invalid tdf_stack shape ({tdf_stack.shape})')

    # Remove dark field intensities above the cutoff
    # (set tdf_cutoff to None to skip this step)
    tdf_cutoff = tdf.min()+2*(np.median(tdf)-tdf.min())
    logger.debug(f'tdf_cutoff = {tdf_cutoff}')
    if tdf_cutoff is not None:
        if not is_num(tdf_cutoff, ge=0):
            logger.warning(f'Ignoring illegal value of tdf_cutoff {tdf_cutoff}')
        else:
            tdf[tdf > tdf_cutoff] = np.nan

    # Remove nans (replace them by the mean of the remaining pixels)
    tdf_mean = np.nanmean(tdf)
    logger.debug(f'tdf_mean = {tdf_mean}')
    np.nan_to_num(tdf, copy=False, nan=tdf_mean, posinf=tdf_mean, neginf=0.)

    # Plot dark field
    if self.galaxy_flag:
        quick_imshow(tdf, title='dark field', path='tomo_reduce_plots', save_fig=self.save_figs,
                save_only=self.save_only)
    elif not self.test_mode:
        quick_imshow(tdf, title='dark field', path=self.output_folder, save_fig=self.save_figs,
                save_only=self.save_only)
        clear_imshow('dark field')

    # Add dark field to reduced data NXprocess
    reduced_data.data = NXdata()
    reduced_data.data['dark_field'] = tdf

    return(reduced_data)
795 | |
def _gen_bright(self, nxentry, reduced_data):
    """Generate the bright field and add it to the reduced data NXprocess.

    The bright field frames are taken from the detector data (frames whose
    image_key equals 1) when an image_key is available, otherwise from the
    bright field spec scans. A stack of frames is reduced to a single image
    by taking the median over the first axis, after which the dark field
    (when present in reduced_data) is subtracted.

    :param nxentry: Nexus entry holding the instrument and scan information.
    :param reduced_data: NXprocess that receives the bright field.
    :return: reduced_data with 'bright_field' added to its data group.
    :raises ValueError: if the bright field data is neither 2D nor 3D.
    """
    # Get the bright field images
    image_key = nxentry.instrument.detector.get('image_key', None)
    if image_key and 'data' in nxentry.instrument.detector:
        field_indices = [index for index, key in enumerate(image_key) if key == 1]
        tbf_stack = nxentry.instrument.detector.data[field_indices,:,:]
        # RV the default NXtomo form does not accomodate bright or dark field stacks
    else:
        bright_field_scans = nxentry.spec_scans.bright_field
        bright_field = FlatField.construct_from_nxcollection(bright_field_scans)
        prefix = str(nxentry.instrument.detector.local_name)
        tbf_stack = bright_field.get_detector_data(prefix)
        if isinstance(tbf_stack, list):
            assert(len(tbf_stack) == 1) # TODO
            tbf_stack = tbf_stack[0]

    # Take median if more than one image
    # Median or mean: It may be best to try the median because of some image
    # artifacts that arise due to crinkles in the upstream kapton tape windows
    # causing some phase contrast images to appear on the detector.
    # One thing that also may be useful in a future implementation is to do a
    # brightfield adjustment on EACH frame of the tomo based on a ROI in the
    # corner of the frame where there is no sample but there is the direct X-ray
    # beam because there is frame to frame fluctuations from the incoming beam.
    # We don't typically account for them but potentially could.
    if tbf_stack.ndim == 2:
        # Work on a float copy: the in-place dark field subtraction below
        # fails on an integer image and must not modify the input image
        tbf = np.array(tbf_stack, dtype='float64')
    elif tbf_stack.ndim == 3:
        tbf = np.median(tbf_stack, axis=0)
        del tbf_stack
    else:
        raise ValueError(f'Invalid tbf_stack shape ({tbf_stack.shape})')

    # Subtract dark field
    if 'data' in reduced_data and 'dark_field' in reduced_data.data:
        tbf -= reduced_data.data.dark_field
    else:
        logger.warning('Dark field unavailable')

    # Set any non-positive values to one
    # (avoid negative bright field values for spikes in dark field)
    tbf[tbf < 1] = 1

    # Plot bright field
    if self.galaxy_flag:
        quick_imshow(tbf, title='bright field', path='tomo_reduce_plots',
                save_fig=self.save_figs, save_only=self.save_only)
    elif not self.test_mode:
        quick_imshow(tbf, title='bright field', path=self.output_folder,
                save_fig=self.save_figs, save_only=self.save_only)
        clear_imshow('bright field')

    # Add bright field to reduced data NXprocess
    if 'data' not in reduced_data:
        reduced_data.data = NXdata()
    reduced_data.data['bright_field'] = tbf

    return(reduced_data)
858 | |
def _set_detector_bounds(self, nxentry, reduced_data, img_x_bounds=None):
    """Set vertical detector bounds for each image stack.

    Right now the range is the same for each set in the image stack.

    In test mode the bounds are read from self.test_config['img_x_bounds'].
    For stations 'id1a3' and 'id3a' a default range is derived from a
    rectangle fit to the row sums of the bright field (or from the vertical
    shifts between stacks) and is either used directly (galaxy mode) or
    offered to the user for confirmation. For any other station the range
    is selected on the first tomography image.

    :param nxentry: Nexus entry holding the instrument and scan information.
    :param reduced_data: NXprocess holding the bright field.
    :param img_x_bounds: optional initial bounds; (-1, -1) is treated as unset.
    :return: the selected bounds as (lower, upper), upper bound exclusive.
    :raises ValueError: for invalid img_x_bounds or if no bounds were selected.
    :raises NotImplementedError: for multiple stacks on other stations.
    """
    if self.test_mode:
        return(tuple(self.test_config['img_x_bounds']))

    # Get the first tomography image and the reference heights
    image_key = nxentry.instrument.detector.get('image_key', None)
    if image_key and 'data' in nxentry.instrument.detector:
        field_indices = [index for index, key in enumerate(image_key) if key == 0]
        first_image = np.asarray(nxentry.instrument.detector.data[field_indices[0],:,:])
        theta = float(nxentry.sample.rotation_angle[field_indices[0]])
        z_translation_all = nxentry.sample.z_translation[field_indices]
        vertical_shifts = sorted(list(set(z_translation_all)))
        num_tomo_stacks = len(vertical_shifts)
    else:
        tomo_field_scans = nxentry.spec_scans.tomo_fields
        tomo_fields = TomoField.construct_from_nxcollection(tomo_field_scans)
        vertical_shifts = tomo_fields.get_vertical_shifts()
        if not isinstance(vertical_shifts, list):
            vertical_shifts = [vertical_shifts]
        prefix = str(nxentry.instrument.detector.local_name)
        t0 = time()
        first_image = tomo_fields.get_detector_data(prefix, tomo_fields.scan_numbers[0], 0)
        logger.debug(f'Getting first image took {time()-t0:.2f} seconds')
        num_tomo_stacks = len(tomo_fields.scan_numbers)
        theta = tomo_fields.theta_range['start']

    # Select image bounds
    # (the +0 turns a rounded -0.0 into 0.0 in the title)
    title = f'tomography image at theta={round(theta, 2)+0}'
    if img_x_bounds is not None:
        if is_int_pair(img_x_bounds) and img_x_bounds[0] == -1 and img_x_bounds[1] == -1:
            img_x_bounds = None
        elif not is_index_range(img_x_bounds, ge=0, le=first_image.shape[0]):
            raise ValueError(f'Invalid parameter img_x_bounds ({img_x_bounds})')
    if nxentry.instrument.source.attrs['station'] in ('id1a3', 'id3a'):
        pixel_size = nxentry.instrument.detector.x_pixel_size
        # Try to get a fit from the bright field
        tbf = np.asarray(reduced_data.data.bright_field)
        tbf_shape = tbf.shape
        x_sum = np.sum(tbf, 1)
        x_sum_min = x_sum.min()
        x_sum_max = x_sum.max()
        fit = Fit.fit_data(x_sum, 'rectangle', x=np.array(range(len(x_sum))), form='atan',
                guess=True)
        parameters = fit.best_values
        x_low_fit = parameters.get('center1', None)
        x_upp_fit = parameters.get('center2', None)
        sig_low = parameters.get('sigma1', None)
        sig_upp = parameters.get('sigma2', None)
        # Accept the fit only if both edges lie inside the image and are sharp
        # relative to the width of the fitted window
        have_fit = fit.success and x_low_fit is not None and x_upp_fit is not None and \
                sig_low is not None and sig_upp is not None and \
                0 <= x_low_fit < x_upp_fit <= x_sum.size and \
                (sig_low+sig_upp)/(x_upp_fit-x_low_fit) < 0.1
        if have_fit:
            # Set a 5% margin on each side
            margin = 0.05*(x_upp_fit-x_low_fit)
            x_low_fit = max(0, x_low_fit-margin)
            x_upp_fit = min(tbf_shape[0], x_upp_fit+margin)
        if num_tomo_stacks == 1:
            if have_fit:
                # Set the default range to enclose the full fitted window
                x_low = int(x_low_fit)
                x_upp = int(x_upp_fit)
            else:
                # Center a default range of 1 mm (RV: can we get this from the slits?)
                num_x_min = int((1.0-0.5*pixel_size)/pixel_size)
                x_low = int(0.5*(tbf_shape[0]-num_x_min))
                x_upp = x_low+num_x_min
        else:
            # Get the default range from the reference heights
            delta_z = vertical_shifts[1]-vertical_shifts[0]
            for i in range(2, num_tomo_stacks):
                delta_z = min(delta_z, vertical_shifts[i]-vertical_shifts[i-1])
            logger.debug(f'delta_z = {delta_z}')
            num_x_min = int((delta_z-0.5*pixel_size)/pixel_size)
            logger.debug(f'num_x_min = {num_x_min}')
            if num_x_min > tbf_shape[0]:
                logger.warning('Image bounds and pixel size prevent seamless stacking')
            if have_fit:
                # Center the default range relative to the fitted window
                x_low = int(0.5*(x_low_fit+x_upp_fit-num_x_min))
                x_upp = x_low+num_x_min
            else:
                # Center the default range
                x_low = int(0.5*(tbf_shape[0]-num_x_min))
                x_upp = x_low+num_x_min
        # Keep the default range on the detector: a range wider than the image
        # would otherwise give a negative lower bound (which wraps around when
        # used as an index) and an upper bound past the last row
        x_low = max(0, x_low)
        x_upp = min(tbf_shape[0], x_upp)
        if self.galaxy_flag:
            img_x_bounds = (x_low, x_upp)
        else:
            # Show the default bounds on the bright field, the first image and the row sums
            tmp = np.copy(tbf)
            tmp_max = tmp.max()
            tmp[x_low,:] = tmp_max
            tmp[x_upp-1,:] = tmp_max
            quick_imshow(tmp, title='bright field')
            tmp = np.copy(first_image)
            tmp_max = tmp.max()
            tmp[x_low,:] = tmp_max
            tmp[x_upp-1,:] = tmp_max
            quick_imshow(tmp, title=title)
            del tmp
            quick_plot((range(x_sum.size), x_sum),
                    ([x_low, x_low], [x_sum_min, x_sum_max], 'r-'),
                    ([x_upp, x_upp], [x_sum_min, x_sum_max], 'r-'),
                    title='sum over theta and y')
            print(f'lower bound = {x_low} (inclusive)')
            print(f'upper bound = {x_upp} (exclusive)]')
            accept = input_yesno('Accept these bounds (y/n)?', 'y')
            clear_imshow('bright field')
            clear_imshow(title)
            clear_plot('sum over theta and y')
            if accept:
                img_x_bounds = (x_low, x_upp)
            else:
                # Let the user draw the range until it is a single connected one
                while True:
                    _, img_x_bounds = draw_mask_1d(x_sum, title='select x data range',
                            legend='sum over theta and y')
                    if len(img_x_bounds) == 1:
                        break
                    else:
                        print(f'Choose a single connected data range')
                img_x_bounds = tuple(img_x_bounds[0])
        if (num_tomo_stacks > 1 and img_x_bounds[1]-img_x_bounds[0]+1 <
                int((delta_z-0.5*pixel_size)/pixel_size)):
            logger.warning('Image bounds and pixel size prevent seamless stacking')
    else:
        if num_tomo_stacks > 1:
            raise NotImplementedError('Selecting image bounds for multiple stacks on FMB')
        # For FMB: use the first tomography image to select range
        # RV: revisit if they do tomography with multiple stacks
        x_sum = np.sum(first_image, 1)
        x_sum_min = x_sum.min()
        x_sum_max = x_sum.max()
        if self.galaxy_flag:
            if img_x_bounds is None:
                img_x_bounds = (0, first_image.shape[0])
        else:
            quick_imshow(first_image, title=title)
            print('Select vertical data reduction range from first tomography image')
            img_x_bounds = select_image_bounds(first_image, 0, title=title)
            clear_imshow(title)
            if img_x_bounds is None:
                raise ValueError('Unable to select image bounds')

    # Plot results
    if self.galaxy_flag:
        path = 'tomo_reduce_plots'
    else:
        path = self.output_folder
    x_low = img_x_bounds[0]
    x_upp = img_x_bounds[1]
    tmp = np.copy(first_image)
    tmp_max = tmp.max()
    tmp[x_low,:] = tmp_max
    tmp[x_upp-1,:] = tmp_max
    quick_imshow(tmp, title=title, path=path, save_fig=self.save_figs, save_only=self.save_only,
            block=self.block)
    del tmp
    quick_plot((range(x_sum.size), x_sum),
            ([x_low, x_low], [x_sum_min, x_sum_max], 'r-'),
            ([x_upp, x_upp], [x_sum_min, x_sum_max], 'r-'),
            title='sum over theta and y', path=path, save_fig=self.save_figs,
            save_only=self.save_only, block=self.block)

    return(img_x_bounds)
1025 | |
def _set_zoom_or_skip(self):
    """Set zoom and/or theta skip to reduce the memory requirement for the analysis.

    Interactive selection of a zoom percentage and of a theta skip interval
    is not enabled at present, so both settings are always None.

    :return: the tuple (zoom_perc, num_theta_skip).
    """
    # Neither option can currently be chosen by the user
    settings = {'zoom_perc': None, 'num_theta_skip': None}
    for name, value in settings.items():
        logger.debug(f'{name} = {value}')

    return(settings['zoom_perc'], settings['num_theta_skip'])
1044 | |
def _gen_tomo(self, nxentry, reduced_data):
    """Generate the reduced tomography fields.

    Each tomography stack is cropped to the image bounds stored in
    reduced_data (the full bright field size by default), normalized by the
    bright field, linearized with -log and stored as float32.

    :param nxentry: Nexus entry holding the instrument and scan information.
    :param reduced_data: NXprocess holding the bright field and image bounds.
    :return: reduced_data with 'rotation_angle', 'x_translation' and
        'z_translation' set and 'tomo_fields' added to its data group.
    :raises ValueError: if the detector frames of a stack are not stored in
        sequence number order.
    """
    # Get full bright field
    tbf = np.asarray(reduced_data.data.bright_field)
    tbf_shape = tbf.shape

    # Get image bounds
    img_x_bounds = tuple(reduced_data.get('img_x_bounds', (0, tbf_shape[0])))
    img_y_bounds = tuple(reduced_data.get('img_y_bounds', (0, tbf_shape[1])))

    # The dark field is not subtracted from the tomography images at present
    tdf = None

    # Resize bright field (a no-op for bounds that span the full image)
    tbf = tbf[img_x_bounds[0]:img_x_bounds[1],img_y_bounds[0]:img_y_bounds[1]]

    # Get the tomography images
    image_key = nxentry.instrument.detector.get('image_key', None)
    if image_key and 'data' in nxentry.instrument.detector:
        field_indices_all = [index for index, key in enumerate(image_key) if key == 0]
        z_translation_all = nxentry.sample.z_translation[field_indices_all]
        z_translation_levels = sorted(list(set(z_translation_all)))
        horizontal_shifts = []
        vertical_shifts = []
        thetas = None
        tomo_stacks = []
        # One stack per distinct vertical translation
        for z_translation in z_translation_levels:
            field_indices = [field_indices_all[index]
                    for index, z in enumerate(z_translation_all) if z == z_translation]
            horizontal_shift = list(set(nxentry.sample.x_translation[field_indices]))
            assert(len(horizontal_shift) == 1)
            horizontal_shifts += horizontal_shift
            vertical_shift = list(set(nxentry.sample.z_translation[field_indices]))
            assert(len(vertical_shift) == 1)
            vertical_shifts += vertical_shift
            sequence_numbers = nxentry.instrument.detector.sequence_number[field_indices]
            if thetas is None:
                thetas = np.asarray(nxentry.sample.rotation_angle[field_indices]) \
                        [sequence_numbers]
            else:
                # Every stack must use the same rotation angles
                assert(all(thetas[i] == nxentry.sample.rotation_angle[field_indices[index]]
                        for i, index in enumerate(sequence_numbers)))
            assert(list(set(sequence_numbers)) == list(range(len(sequence_numbers))))
            if list(sequence_numbers) == list(range(len(sequence_numbers))):
                tomo_stack = np.asarray(nxentry.instrument.detector.data[field_indices])
            else:
                raise ValueError('Unable to load the tomography images')
            tomo_stacks.append(tomo_stack)
    else:
        tomo_field_scans = nxentry.spec_scans.tomo_fields
        tomo_fields = TomoField.construct_from_nxcollection(tomo_field_scans)
        horizontal_shifts = tomo_fields.get_horizontal_shifts()
        vertical_shifts = tomo_fields.get_vertical_shifts()
        prefix = str(nxentry.instrument.detector.local_name)
        t0 = time()
        tomo_stacks = tomo_fields.get_detector_data(prefix)
        logger.debug(f'Getting tomography images took {time()-t0:.2f} seconds')
        thetas = np.linspace(tomo_fields.theta_range['start'], tomo_fields.theta_range['end'],
                tomo_fields.theta_range['num'])
        if not isinstance(tomo_stacks, list):
            horizontal_shifts = [horizontal_shifts]
            vertical_shifts = [vertical_shifts]
            tomo_stacks = [tomo_stacks]

    reduced_tomo_stacks = []
    if self.galaxy_flag:
        path = 'tomo_reduce_plots'
    else:
        path = self.output_folder
    # NOTE: the numexpr expressions below refer to the local names tomo_stack,
    # tbf, tdf and cutoff, so these must not be renamed
    for i, tomo_stack in enumerate(tomo_stacks):
        # Resize the tomography images
        # Right now the range is the same for each set in the image stack.
        # Always crop and convert: the stack must match the resized bright
        # field and must be a float64 copy for the in-place evaluations below
        t0 = time()
        tomo_stack = tomo_stack[:,img_x_bounds[0]:img_x_bounds[1],
                img_y_bounds[0]:img_y_bounds[1]].astype('float64')
        logger.debug(f'Resizing tomography images took {time()-t0:.2f} seconds')

        # Subtract dark field
        if tdf is not None:
            t0 = time()
            with set_numexpr_threads(self.num_core):
                ne.evaluate('tomo_stack-tdf', out=tomo_stack)
            logger.debug(f'Subtracting dark field took {time()-t0:.2f} seconds')

        # Normalize
        t0 = time()
        with set_numexpr_threads(self.num_core):
            ne.evaluate('tomo_stack/tbf', out=tomo_stack, truediv=True)
        logger.debug(f'Normalizing took {time()-t0:.2f} seconds')

        # Remove non-positive values and linearize data
        t0 = time()
        cutoff = 1.e-6
        with set_numexpr_threads(self.num_core):
            ne.evaluate('where(tomo_stack<cutoff, cutoff, tomo_stack)', out=tomo_stack)
        with set_numexpr_threads(self.num_core):
            ne.evaluate('-log(tomo_stack)', out=tomo_stack)
        logger.debug('Removing non-positive values and linearizing data took '+
                f'{time()-t0:.2f} seconds')

        # Get rid of nans/infs that may be introduced by normalization
        # (replace them by zero in place)
        t0 = time()
        np.nan_to_num(tomo_stack, copy=False, nan=0., posinf=0., neginf=0.)
        logger.debug(f'Remove nans/infs took {time()-t0:.2f} seconds')

        # Downsize tomography stack to smaller size
        # TODO use zoom_perc and theta_skip as well
        tomo_stack = tomo_stack.astype('float32')
        if not self.test_mode:
            if len(tomo_stacks) == 1:
                title = f'red fullres theta {round(thetas[0], 2)+0}'
            else:
                title = f'red stack {i+1} fullres theta {round(thetas[0], 2)+0}'
            quick_imshow(tomo_stack[0,:,:], title=title, path=path, save_fig=self.save_figs,
                    save_only=self.save_only, block=self.block)

        # Save test data to file
        if self.test_mode:
            row_index = int(tomo_stack.shape[1]/2)
            np.savetxt(f'{self.output_folder}/red_stack_{i+1}.txt', tomo_stack[:,row_index,:],
                    fmt='%.6e')

        # Combine resized stacks
        reduced_tomo_stacks.append(tomo_stack)

    # Add tomo field info to reduced data NXprocess
    reduced_data['rotation_angle'] = thetas
    reduced_data['x_translation'] = np.asarray(horizontal_shifts)
    reduced_data['z_translation'] = np.asarray(vertical_shifts)
    reduced_data.data['tomo_fields'] = np.asarray(reduced_tomo_stacks)

    return(reduced_data)
1215 | |
def _find_center_one_plane(self, sinogram, row, thetas, eff_pixel_size, cross_sectional_dim,
        path=None, tol=0.1, num_core=1):
    """Find center for a single tomography plane.

    The center is obtained with tomopy.find_center_vo and the plane is
    reconstructed and plotted for that center. The interactive refinement
    of the center is disabled at present, so the result of find_center_vo is
    always returned.

    :param sinogram: sinogram of the plane, index order: theta,column.
    :param row: row index of the plane (used in log messages and plot titles).
    :param thetas: rotation angles passed on to the reconstruction.
    :param eff_pixel_size: passed on to _reconstruct_one_plane.
    :param cross_sectional_dim: passed on to _reconstruct_one_plane.
    :param path: output path for the plots.
    :param tol: not used at present.
    :param num_core: number of cores to use.
    :return: the center offset in pixels relative to the column center.
    """
    # Try automatic center finding routines for initial value
    # sinogram index order: theta,column
    # need column,theta for iradon, so take transpose
    sinogram = np.asarray(sinogram)
    sinogram_T = sinogram.T
    center = sinogram.shape[1]/2

    # Try using Nghia Vo's method
    t0 = time()
    # Do not run tomopy on more cores than the limit set for it
    ncore = min(num_core, num_core_tomopy_limit)
    logger.debug(f'Running find_center_vo on {ncore} cores ...')
    tomo_center = tomopy.find_center_vo(sinogram, ncore=ncore)
    logger.debug(f'... done in {time()-t0:.2f} seconds')
    logger.info(f'Finding the center using Nghia Vo\'s method took {time()-t0:.2f} seconds')
    center_offset_vo = tomo_center-center
    logger.info(f'Center at row {row} using Nghia Vo\'s method = {center_offset_vo:.2f}')
    t0 = time()
    logger.debug(f'Running _reconstruct_one_plane on {num_core} cores ...')
    recon_plane = self._reconstruct_one_plane(sinogram_T, tomo_center, thetas,
            eff_pixel_size, cross_sectional_dim, False, num_core)
    logger.debug(f'... done in {time()-t0:.2f} seconds')
    logger.info(f'Reconstructing row {row} took {time()-t0:.2f} seconds')

    title = f'edges row{row} center offset{center_offset_vo:.2f} Vo'
    self._plot_edges_one_plane(recon_plane, title, path=path)

    # TODO: optionally refine the center with the phase correlation method
    # (tomopy.find_center_pc) and let the user accept or adjust the center

    # Select center location
    # The interactive search below is disabled: always accept the Vo center
    accept_vo_center = True
    if accept_vo_center:
        return float(center_offset_vo)

    # Perform center finding search
    while True:
        center_offset_low = input_int('\nEnter lower bound for center offset', ge=-center,
                le=center)
        center_offset_upp = input_int('Enter upper bound for center offset',
                ge=center_offset_low, le=center)
        if center_offset_upp == center_offset_low:
            center_offset_step = 1
        else:
            center_offset_step = input_int('Enter step size for center offset search', ge=1,
                    le=center_offset_upp-center_offset_low)
        num_center_offset = 1+int((center_offset_upp-center_offset_low)/center_offset_step)
        center_offsets = np.linspace(center_offset_low, center_offset_upp, num_center_offset)
        for center_offset in center_offsets:
            # The plane for the Vo center has already been reconstructed
            if center_offset == center_offset_vo:
                continue
            t0 = time()
            logger.debug(f'Running _reconstruct_one_plane on {num_core} cores ...')
            recon_plane = self._reconstruct_one_plane(sinogram_T, center_offset+center, thetas,
                    eff_pixel_size, cross_sectional_dim, False, num_core)
            logger.debug(f'... done in {time()-t0:.2f} seconds')
            logger.info(f'Reconstructing center_offset {center_offset} took '+
                    f'{time()-t0:.2f} seconds')
            title = f'edges row{row} center_offset{center_offset:.2f}'
            self._plot_edges_one_plane(recon_plane, title, path=path)
        if input_int('\nContinue (0) or end the search (1)', ge=0, le=1):
            break

    del sinogram_T
    del recon_plane
    center_offset = input_num('    Enter chosen center offset', ge=-center, le=center)
    return float(center_offset)
1317 | |
def _reconstruct_one_plane(self, tomo_plane_T, center, thetas, eff_pixel_size,
        cross_sectional_dim, plot_sinogram=True, num_core=1):
    """Invert the sinogram for a single tomography plane.

    The sinogram is trimmed so that the given center becomes its midpoint,
    inverted with iradon, smoothed with a Gaussian filter and cleaned with
    tomopy's remove_ring.

    :param tomo_plane_T: sinogram of the plane, index order: column,theta.
    :param center: center column, must satisfy 0 <= center < number of columns.
    :param thetas: rotation angles passed to iradon.
    :param eff_pixel_size: pixel size, in the same units as cross_sectional_dim.
    :param cross_sectional_dim: cross sectional dimension used to limit the
        reconstruction radius.
    :param plot_sinogram: plot the trimmed sinogram (never in galaxy mode).
    :param num_core: number of cores passed to remove_ring.
    :return: the reconstructed plane with a leading axis of length one.
    """
    # tomo_plane_T index order: column,theta
    num_column = tomo_plane_T.shape[0]
    assert(0 <= center < num_column)
    center_offset = center-num_column/2
    two_offset = 2*int(np.round(center_offset))
    two_offset_abs = np.abs(two_offset)
    # 10% slack to avoid edge effects, but never beyond half the sinogram width
    max_rad = min(int(0.55*(cross_sectional_dim/eff_pixel_size)), 0.5*num_column)
    dist_from_edge = max(1, int(np.floor((num_column-two_offset_abs)/2.)-max_rad))

    # Trim the columns on the side away from which the center is shifted
    if two_offset >= 0:
        first, last = two_offset+dist_from_edge, -dist_from_edge
    else:
        first, last = dist_from_edge, two_offset-dist_from_edge
    logger.debug(f'sinogram range = [{first}, {last}]')
    sinogram = tomo_plane_T[first:last,:]
    if plot_sinogram and not self.galaxy_flag:
        quick_imshow(sinogram.T, f'sinogram center offset{center_offset:.2f}', aspect='auto',
                path=self.output_folder, save_fig=self.save_figs, save_only=self.save_only,
                block=self.block)

    # Inverting sinogram
    t0 = time()
    recon_sinogram = iradon(sinogram, theta=thetas, circle=True)
    logger.debug(f'Inverting sinogram took {time()-t0:.2f} seconds')
    del sinogram

    # Performing Gaussian filtering and removing ring artifacts
    # (reading the parameters from the configuration is disabled at present)
    recon_parameters = None#self.config.get('recon_parameters')
    sigma = 1.0
    ring_width = 15
    if recon_parameters is not None:
        sigma = recon_parameters.get('gaussian_sigma', 1.0)
        if not is_num(sigma, ge=0.0):
            logger.warning(f'Invalid gaussian_sigma ({sigma}) in _reconstruct_one_plane, '+
                    'set to a default value of 1.0')
            sigma = 1.0
        ring_width = recon_parameters.get('ring_width', 15)
        if not is_int(ring_width, ge=0):
            logger.warning(f'Invalid ring_width ({ring_width}) in _reconstruct_one_plane, '+
                    'set to a default value of 15')
            ring_width = 15
    t0 = time()
    filtered = spi.gaussian_filter(recon_sinogram, sigma, mode='nearest')
    del recon_sinogram
    # remove_ring is given a stack of planes, so add a leading axis
    recon_clean = tomopy.misc.corr.remove_ring(np.expand_dims(filtered, axis=0),
            rwidth=ring_width, ncore=num_core)
    logger.debug(f'Filtering and removing ring artifacts took {time()-t0:.2f} seconds')

    return recon_clean
1372 | |
def _plot_edges_one_plane(self, recon_plane, title, path=None):
    """Plot a denoised reconstructed plane in two color maps.

    The plane is denoised with denoise_tv_chambolle and shown once with the
    'coolwarm' color map and once with the 'gray' color map on a color scale
    that is symmetric around zero.

    :param recon_plane: reconstructed plane with a leading axis, only the
        first element along that axis is plotted.
    :param title: base title of the plots.
    :param path: output path for the plots, self.output_folder by default.
    """
    # Reading the denoising weight from the configuration is disabled at present
    vis_parameters = None#self.config.get('vis_parameters')
    weight = 0.1
    if vis_parameters is not None:
        weight = vis_parameters.get('denoise_weight', 0.1)
        if not is_num(weight, ge=0.0):
            logger.warning(f'Invalid weight ({weight}) in _plot_edges_one_plane, '+
                    'set to a default value of 0.1')
            weight = 0.1
    edge_plane = denoise_tv_chambolle(recon_plane, weight=weight)[0,:,:]
    vmax = np.max(edge_plane)
    if path is None:
        path = self.output_folder
    # Plot options shared by both figures
    plot_kwargs = dict(path=path, save_fig=self.save_figs, save_only=self.save_only,
            block=self.block)
    quick_imshow(edge_plane, f'{title} coolwarm', cmap='coolwarm', **plot_kwargs)
    quick_imshow(edge_plane, f'{title} gray', cmap='gray', vmin=-vmax, vmax=vmax,
            **plot_kwargs)
1393 | |
def _reconstruct_one_tomo_stack(self, tomo_stack, thetas, center_offsets=None, num_core=1,
        algorithm='gridrec'):
    """Reconstruct a single tomography stack.

    Steps: remove horizontal stripes from the sinograms, run the initial
    tomopy reconstruction, optionally run secondary ASTRA/SART iterations
    and finally remove ring artifacts from the reconstruction.

    :param tomo_stack: tomography data in sinogram order (row,theta,column)
    :param thetas: rotation angles in degrees
    :param center_offsets: tomography axis shift in pixels relative to the
        column center; empty or None for no shift, two values to
        interpolate linearly between the first and the last row, or one
        value per row. The input is never modified.
    :param num_core: number of cores handed to tomopy (the stripe removal
        is capped at the module-level num_core_tomopy_limit)
    :param algorithm: algorithm handed to tomopy.recon for the initial
        reconstruction
    :return: the reconstructed stack as returned by tomopy.recon, with the
        ring artifacts removed in place
    :raises ValueError: if the number of center offsets matches neither
        0, 2 nor the number of rows in tomo_stack
    """
    # RV should we remove stripes?
    # https://tomopy.readthedocs.io/en/latest/api/tomopy.prep.stripe.html
    # RV should we remove rings?
    # https://tomopy.readthedocs.io/en/latest/api/tomopy.misc.corr.html
    # RV: Add an option to do (extra) secondary iterations later or to do some sort of
    #     convergence test?
    num_rows = tomo_stack.shape[0]
    if center_offsets is None or not len(center_offsets):
        centers = np.zeros(num_rows)
    elif len(center_offsets) == 2:
        centers = np.linspace(center_offsets[0], center_offsets[1], num_rows)
    else:
        centers = np.asarray(center_offsets, dtype=float)
        if centers.size != num_rows:
            raise ValueError('center_offsets dimension mismatch in reconstruct_one_tomo_stack')
    # Out-of-place addition: np.asarray may return the caller's own array and an in-place
    # update would silently shift the caller's offsets
    centers = centers+tomo_stack.shape[2]/2

    # Get reconstruction parameters
    # (the configuration lookup is disabled for now, so the defaults are always used)
    sigma = 2.0
    secondary_iters = 0
    ring_width = 15
    recon_parameters = None#self.config.get('recon_parameters')
    if recon_parameters is not None:
        sigma = recon_parameters.get('stripe_fw_sigma', 2.0)
        if not is_num(sigma, ge=0):
            logger.warning(f'Invalid stripe_fw_sigma ({sigma}) in '+
                    '_reconstruct_one_tomo_stack, set to a default value of 2.0')
            sigma = 2.0
        secondary_iters = recon_parameters.get('secondary_iters', 0)
        if not is_int(secondary_iters, ge=0):
            logger.warning(f'Invalid secondary_iters ({secondary_iters}) in '+
                    '_reconstruct_one_tomo_stack, set to a default value of 0 (skip them)')
            secondary_iters = 0
        ring_width = recon_parameters.get('ring_width', 15)
        if not is_int(ring_width, ge=0):
            logger.warning(f'Invalid ring_width ({ring_width}) in '+
                    '_reconstruct_one_tomo_stack, set to a default value of 15')
            ring_width = 15

    # Remove horizontal stripe (never use more cores than the tomopy limit for this step)
    t0 = time()
    stripe_num_core = min(num_core, num_core_tomopy_limit)
    logger.debug(f'Running remove_stripe_fw on {stripe_num_core} cores ...')
    tomo_stack = tomopy.prep.stripe.remove_stripe_fw(tomo_stack, sigma=sigma,
            ncore=stripe_num_core)
    logger.debug(f'... tomopy.prep.stripe.remove_stripe_fw took {time()-t0:.2f} seconds')

    # tomopy.recon is called with the angles in radians
    thetas_rad = np.radians(thetas)

    # Perform initial image reconstruction
    logger.debug('Performing initial image reconstruction')
    t0 = time()
    logger.debug(f'Running recon on {num_core} cores ...')
    tomo_recon_stack = tomopy.recon(tomo_stack, thetas_rad, centers, sinogram_order=True,
            algorithm=algorithm, ncore=num_core)
    logger.debug(f'... done in {time()-t0:.2f} seconds')
    logger.info(f'Performing initial image reconstruction took {time()-t0:.2f} seconds')

    # Run optional secondary iterations
    if secondary_iters > 0:
        logger.debug(f'Running {secondary_iters} secondary iterations')
        # RV: alternatives that were tried and abandoned:
        #   SIRT_CUDA/cuda: "Error: CUDA error 803: system has unsupported display
        #       driver/cuda driver combination."
        #   SIRT/linear with MinConstraint 0: did not finish while running overnight
        #   SART/linear without MinConstraint
        options = {'method':'SART', 'proj_type':'linear', 'MinConstraint': 0,
                'num_iter':secondary_iters}
        t0 = time()
        logger.debug(f'Running recon on {num_core} cores ...')
        tomo_recon_stack = tomopy.recon(tomo_stack, thetas_rad, centers,
                init_recon=tomo_recon_stack, options=options, sinogram_order=True,
                algorithm=tomopy.astra, ncore=num_core)
        logger.debug(f'... done in {time()-t0:.2f} seconds')
        logger.info(f'Performing secondary iterations took {time()-t0:.2f} seconds')

    # Remove ring artifacts (result is written back into tomo_recon_stack)
    t0 = time()
    tomopy.misc.corr.remove_ring(tomo_recon_stack, rwidth=ring_width, out=tomo_recon_stack,
            ncore=num_core)
    logger.debug(f'Removing ring artifacts took {time()-t0:.2f} seconds')

    return tomo_recon_stack
1486 | |
1487 def _resize_reconstructed_data(self, data, z_only=False): | |
1488 """Resize the reconstructed tomography data. | |
1489 """ | |
1490 # Data order: row(z),x,y or stack,row(z),x,y | |
1491 if isinstance(data, list): | |
1492 for stack in data: | |
1493 assert(stack.ndim == 3) | |
1494 num_tomo_stacks = len(data) | |
1495 tomo_recon_stacks = data | |
1496 else: | |
1497 assert(data.ndim == 3) | |
1498 num_tomo_stacks = 1 | |
1499 tomo_recon_stacks = [data] | |
1500 | |
1501 if z_only: | |
1502 x_bounds = None | |
1503 y_bounds = None | |
1504 else: | |
1505 # Selecting x bounds (in yz-plane) | |
1506 tomosum = 0 | |
1507 [tomosum := tomosum+np.sum(tomo_recon_stacks[i], axis=(0,2)) | |
1508 for i in range(num_tomo_stacks)] | |
1509 select_x_bounds = input_yesno('\nDo you want to change the image x-bounds (y/n)?', 'y') | |
1510 if not select_x_bounds: | |
1511 x_bounds = None | |
1512 else: | |
1513 accept = False | |
1514 index_ranges = None | |
1515 while not accept: | |
1516 mask, x_bounds = draw_mask_1d(tomosum, current_index_ranges=index_ranges, | |
1517 title='select x data range', legend='recon stack sum yz') | |
1518 while len(x_bounds) != 1: | |
1519 print('Please select exactly one continuous range') | |
1520 mask, x_bounds = draw_mask_1d(tomosum, title='select x data range', | |
1521 legend='recon stack sum yz') | |
1522 x_bounds = x_bounds[0] | |
1523 # quick_plot(tomosum, vlines=x_bounds, title='recon stack sum yz') | |
1524 # print(f'x_bounds = {x_bounds} (lower bound inclusive, upper bound '+ | |
1525 # 'exclusive)') | |
1526 # accept = input_yesno('Accept these bounds (y/n)?', 'y') | |
1527 accept = True | |
1528 logger.debug(f'x_bounds = {x_bounds}') | |
1529 | |
1530 # Selecting y bounds (in xz-plane) | |
1531 tomosum = 0 | |
1532 [tomosum := tomosum+np.sum(tomo_recon_stacks[i], axis=(0,1)) | |
1533 for i in range(num_tomo_stacks)] | |
1534 select_y_bounds = input_yesno('\nDo you want to change the image y-bounds (y/n)?', 'y') | |
1535 if not select_y_bounds: | |
1536 y_bounds = None | |
1537 else: | |
1538 accept = False | |
1539 index_ranges = None | |
1540 while not accept: | |
1541 mask, y_bounds = draw_mask_1d(tomosum, current_index_ranges=index_ranges, | |
1542 title='select x data range', legend='recon stack sum xz') | |
1543 while len(y_bounds) != 1: | |
1544 print('Please select exactly one continuous range') | |
1545 mask, y_bounds = draw_mask_1d(tomosum, title='select x data range', | |
1546 legend='recon stack sum xz') | |
1547 y_bounds = y_bounds[0] | |
1548 # quick_plot(tomosum, vlines=y_bounds, title='recon stack sum xz') | |
1549 # print(f'y_bounds = {y_bounds} (lower bound inclusive, upper bound '+ | |
1550 # 'exclusive)') | |
1551 # accept = input_yesno('Accept these bounds (y/n)?', 'y') | |
1552 accept = True | |
1553 logger.debug(f'y_bounds = {y_bounds}') | |
1554 | |
1555 # Selecting z bounds (in xy-plane) (only valid for a single image stack) | |
1556 if num_tomo_stacks != 1: | |
1557 z_bounds = None | |
1558 else: | |
1559 tomosum = 0 | |
1560 [tomosum := tomosum+np.sum(tomo_recon_stacks[i], axis=(1,2)) | |
1561 for i in range(num_tomo_stacks)] | |
1562 select_z_bounds = input_yesno('Do you want to change the image z-bounds (y/n)?', 'n') | |
1563 if not select_z_bounds: | |
1564 z_bounds = None | |
1565 else: | |
1566 accept = False | |
1567 index_ranges = None | |
1568 while not accept: | |
1569 mask, z_bounds = draw_mask_1d(tomosum, current_index_ranges=index_ranges, | |
1570 title='select x data range', legend='recon stack sum xy') | |
1571 while len(z_bounds) != 1: | |
1572 print('Please select exactly one continuous range') | |
1573 mask, z_bounds = draw_mask_1d(tomosum, title='select x data range', | |
1574 legend='recon stack sum xy') | |
1575 z_bounds = z_bounds[0] | |
1576 # quick_plot(tomosum, vlines=z_bounds, title='recon stack sum xy') | |
1577 # print(f'z_bounds = {z_bounds} (lower bound inclusive, upper bound '+ | |
1578 # 'exclusive)') | |
1579 # accept = input_yesno('Accept these bounds (y/n)?', 'y') | |
1580 accept = True | |
1581 logger.debug(f'z_bounds = {z_bounds}') | |
1582 | |
1583 return(x_bounds, y_bounds, z_bounds) | |
1584 | |
1585 | |
def run_tomo(input_file:str, output_file:str, modes:list[str], center_file=None, num_core=-1,
        output_folder='.', save_figs='no', test_mode=False) -> None:
    """Run the requested steps of the tomography workflow.

    :param input_file: file read with `Tomo.read` to obtain the input data
    :param output_file: file handed to `Tomo.write` for the result
    :param modes: steps to run, any of 'reduce_data', 'find_center',
        'reconstruct_data', 'combine_data' or 'all'; None is treated as
        ['all']
    :param center_file: file read for the center data when the
        reconstruction runs without a preceding 'find_center' step
    :param num_core: passed on to the `Tomo` constructor
    :param output_folder: passed on to the `Tomo` constructor; in test mode
        the log file tomo.log is also written there
    :param save_figs: passed on to the `Tomo` constructor
    :param test_mode: log to file and skip writing the output file
    :raises ValueError: if `modes` contains an unknown mode
    """
    if test_mode:
        # In test mode all logging is redirected to a file in the output folder
        logging.basicConfig(
                filename=f'{output_folder}/tomo.log', filemode='w',
                format='%(asctime)s : %(levelname)s - %(module)s : %(funcName)s - %(message)s',
                level=logging.getLevelName('INFO'), force=True)
    logger.info(f'input_file = {input_file}')
    logger.info(f'center_file = {center_file}')
    logger.info(f'output_file = {output_file}')
    logger.debug(f'modes= {modes}')
    logger.debug(f'num_core= {num_core}')
    logger.info(f'output_folder = {output_folder}')
    logger.info(f'save_figs = {save_figs}')
    logger.info(f'test_mode = {test_mode}')

    # Check for correction modes
    legal_modes = ('reduce_data', 'find_center', 'reconstruct_data', 'combine_data', 'all')
    if modes is None:
        modes = ['all']
    if any(mode not in legal_modes for mode in modes):
        raise ValueError(f'Invalid parameter modes ({modes})')
    run_all = 'all' in modes

    def requested(mode):
        """Return True if `mode` was selected, either directly or through 'all'."""
        return mode in modes or run_all

    # Instantiate Tomo object
    tomo = Tomo(num_core=num_core, output_folder=output_folder, save_figs=save_figs,
            test_mode=test_mode)

    # Read input file
    data = tomo.read(input_file)

    # Generate reduced tomography images
    if requested('reduce_data'):
        data = tomo.gen_reduced_data(data)

    # Find rotation axis centers for the tomography stacks.
    center_data = tomo.find_centers(data) if requested('find_center') else None

    # Reconstruct tomography stacks
    if requested('reconstruct_data'):
        if center_data is None:
            # No centers from this run: read them from the center file
            center_data = tomo.read(center_file)
        data = tomo.reconstruct_data(data, center_data)
        # The centers are consumed by the reconstruction and must not be written out
        center_data = None

    # Combine reconstructed tomography stacks
    if requested('combine_data'):
        data = tomo.combine_data(data)

    # Write output file (the center data takes precedence when it is still around)
    if data is not None and not test_mode:
        output_data = data if center_data is None else center_data
        data = tomo.write(output_data, output_file)

    logger.info(f'Completed modes: {modes}')