Mercurial > repos > rv43 > tomo
comparison workflow/run_tomo.py @ 69:fba792d5f83b draft
planemo upload for repository https://github.com/rolfverberg/galaxytools commit ab9f412c362a4ab986d00e21d5185cfcf82485c1
author | rv43 |
---|---|
date | Fri, 10 Mar 2023 16:02:04 +0000 |
parents | |
children | 1cf15b61cd83 |
comparison
equal
deleted
inserted
replaced
68:ba5866d0251d | 69:fba792d5f83b |
---|---|
1 #!/usr/bin/env python3 | |
2 | |
3 import logging | |
4 logger = logging.getLogger(__name__) | |
5 | |
6 import numpy as np | |
7 try: | |
8 import numexpr as ne | |
9 except: | |
10 pass | |
11 try: | |
12 import scipy.ndimage as spi | |
13 except: | |
14 pass | |
15 | |
16 from multiprocessing import cpu_count | |
17 from nexusformat.nexus import * | |
18 from os import mkdir | |
19 from os import path as os_path | |
20 try: | |
21 from skimage.transform import iradon | |
22 except: | |
23 pass | |
24 try: | |
25 from skimage.restoration import denoise_tv_chambolle | |
26 except: | |
27 pass | |
28 from time import time | |
29 try: | |
30 import tomopy | |
31 except: | |
32 pass | |
33 from yaml import safe_load, safe_dump | |
34 | |
35 from msnctools.fit import Fit | |
36 from msnctools.general import illegal_value, is_int, is_int_pair, is_num, is_index_range, \ | |
37 input_int, input_num, input_yesno, input_menu, draw_mask_1d, select_image_bounds, \ | |
38 select_one_image_bound, clear_imshow, quick_imshow, clear_plot, quick_plot | |
39 | |
40 from workflow.models import import_scanparser, FlatField, TomoField, TomoWorkflow | |
41 from workflow.__version__ import __version__ | |
42 | |
# NOTE(review): not referenced in the part of this file visible here; presumably an
# upper limit on the number of cores handed to tomopy -- confirm against the rest of the file.
num_core_tomopy_limit = 24
44 | |
def nxcopy(nxobject: NXobject, exclude_nxpaths: 'list[str] | None' = None,
        nxpath_prefix: str = '') -> NXobject:
    '''Function that returns a copy of a nexus object, optionally excluding certain child items.

    Child groups are copied recursively through this function; every other
    child item is assigned to the copy directly.

    :param nxobject: the original nexus object to return a "copy" of
    :type nxobject: nexusformat.nexus.NXobject
    :param exclude_nxpaths: a list of paths to child nexus objects that
        should be excluded from the returned "copy". Paths are relative to
        the top-level `nxobject` (e.g. `'entry/data/data'`), defaults to
        `None` (exclude nothing)
    :type exclude_nxpaths: list[str], optional
    :param nxpath_prefix: For use in recursive calls from inside this
        function only!
    :type nxpath_prefix: str
    :return: a copy of `nxobject` with some children optionally excluded.
    :rtype: NXobject
    '''
    # Avoid a shared mutable default argument
    if exclude_nxpaths is None:
        exclude_nxpaths = []

    nxobject_copy = nxobject.__class__()
    if not nxpath_prefix:
        # Top-level object: only the 'default' attribute is carried over
        if 'default' in nxobject.attrs:
            nxobject_copy.attrs['default'] = nxobject.attrs['default']
    else:
        for k, v in nxobject.attrs.items():
            nxobject_copy.attrs[k] = v

    for k, v in nxobject.items():
        nxpath = os_path.join(nxpath_prefix, k)

        if nxpath in exclude_nxpaths:
            continue

        if isinstance(v, NXgroup):
            nxobject_copy[k] = nxcopy(v, exclude_nxpaths=exclude_nxpaths,
                    nxpath_prefix=nxpath)
        else:
            nxobject_copy[k] = v

    return nxobject_copy
81 | |
class set_numexpr_threads:
    """Context manager that temporarily changes the number of numexpr threads.

    Entering the context sets the numexpr thread count to `num_core`;
    leaving it restores the value that `ne.set_num_threads` reported as the
    previous setting.

    :param num_core: requested number of threads; `None`, values below 1 and
        values above `cpu_count()` all fall back to `cpu_count()`
    """

    def __init__(self, num_core):
        available = cpu_count()
        out_of_range = num_core is None or num_core < 1 or num_core > available
        self.num_core = available if out_of_range else num_core

    def __enter__(self):
        # Keep the previous thread count so that __exit__ can restore it
        self.num_core_org = ne.set_num_threads(self.num_core)

    def __exit__(self, exc_type, exc_value, traceback):
        ne.set_num_threads(self.num_core_org)
95 | |
96 class Tomo: | |
97 """Processing tomography data with misalignment. | |
98 """ | |
99 def __init__(self, galaxy_flag=False, num_core=-1, output_folder='.', save_figs=None, | |
100 test_mode=False): | |
101 """Initialize with optional config input file or dictionary | |
102 """ | |
103 if not isinstance(galaxy_flag, bool): | |
104 raise ValueError(f'Invalid parameter galaxy_flag ({galaxy_flag})') | |
105 self.galaxy_flag = galaxy_flag | |
106 self.num_core = num_core | |
107 if self.galaxy_flag: | |
108 if output_folder != '.': | |
109 logger.warning('Ignoring output_folder in galaxy mode') | |
110 self.output_folder = '.' | |
111 if test_mode != False: | |
112 logger.warning('Ignoring test_mode in galaxy mode') | |
113 self.test_mode = False | |
114 if save_figs is not None: | |
115 logger.warning('Ignoring save_figs in galaxy mode') | |
116 save_figs = 'only' | |
117 else: | |
118 self.output_folder = os_path.abspath(output_folder) | |
119 if not os_path.isdir(output_folder): | |
120 mkdir(os_path.abspath(output_folder)) | |
121 if not isinstance(test_mode, bool): | |
122 raise ValueError(f'Invalid parameter test_mode ({test_mode})') | |
123 self.test_mode = test_mode | |
124 if save_figs is None: | |
125 save_figs = 'no' | |
126 self.test_config = {} | |
127 if self.test_mode: | |
128 if save_figs != 'only': | |
129 logger.warning('Ignoring save_figs in test mode') | |
130 save_figs = 'only' | |
131 if save_figs == 'only': | |
132 self.save_only = True | |
133 self.save_figs = True | |
134 elif save_figs == 'yes': | |
135 self.save_only = False | |
136 self.save_figs = True | |
137 elif save_figs == 'no': | |
138 self.save_only = False | |
139 self.save_figs = False | |
140 else: | |
141 raise ValueError(f'Invalid parameter save_figs ({save_figs})') | |
142 if self.save_only: | |
143 self.block = False | |
144 else: | |
145 self.block = True | |
146 if self.num_core == -1: | |
147 self.num_core = cpu_count() | |
148 if not is_int(self.num_core, gt=0, log=False): | |
149 raise ValueError(f'Invalid parameter num_core ({num_core})') | |
150 if self.num_core > cpu_count(): | |
151 logger.warning(f'num_core = {self.num_core} is larger than the number of available ' | |
152 f'processors and reduced to {cpu_count()}') | |
153 self.num_core= cpu_count() | |
154 | |
155 def read(self, filename): | |
156 extension = os_path.splitext(filename)[1] | |
157 if extension == '.yml' or extension == '.yaml': | |
158 with open(filename, 'r') as f: | |
159 config = safe_load(f) | |
160 # if len(config) > 1: | |
161 # raise ValueError(f'Multiple root entries in {filename} not yet implemented') | |
162 # if len(list(config.values())[0]) > 1: | |
163 # raise ValueError(f'Multiple sample maps in {filename} not yet implemented') | |
164 return(config) | |
165 elif extension == '.nxs': | |
166 with NXFile(filename, mode='r') as nxfile: | |
167 nxroot = nxfile.readfile() | |
168 return(nxroot) | |
169 else: | |
170 raise ValueError(f'Invalid filename extension ({extension})') | |
171 | |
172 def write(self, data, filename): | |
173 extension = os_path.splitext(filename)[1] | |
174 if extension == '.yml' or extension == '.yaml': | |
175 with open(filename, 'w') as f: | |
176 safe_dump(data, f) | |
177 elif extension == '.nxs': | |
178 data.save(filename, mode='w') | |
179 elif extension == '.nc': | |
180 data.to_netcdf(os_path=filename) | |
181 else: | |
182 raise ValueError(f'Invalid filename extension ({extension})') | |
183 | |
184 def gen_reduced_data(self, data, img_x_bounds=None): | |
185 """Generate the reduced tomography images. | |
186 """ | |
187 logger.info('Generate the reduced tomography images') | |
188 | |
189 # Create plot galaxy path directory if needed | |
190 if self.galaxy_flag and not os_path.exists('tomo_reduce_plots'): | |
191 mkdir('tomo_reduce_plots') | |
192 | |
193 if isinstance(data, dict): | |
194 # Create Nexus format object from input dictionary | |
195 wf = TomoWorkflow(**data) | |
196 if len(wf.sample_maps) > 1: | |
197 raise ValueError(f'Multiple sample maps not yet implemented') | |
198 # print(f'\nwf:\n{wf}\n') | |
199 nxroot = NXroot() | |
200 t0 = time() | |
201 for sample_map in wf.sample_maps: | |
202 logger.info(f'Start constructing the {sample_map.title} map.') | |
203 import_scanparser(sample_map.station) | |
204 sample_map.construct_nxentry(nxroot, include_raw_data=False) | |
205 logger.info(f'Constructed all sample maps in {time()-t0:.2f} seconds.') | |
206 nxentry = nxroot[nxroot.attrs['default']] | |
207 # Get test mode configuration info | |
208 if self.test_mode: | |
209 self.test_config = data['sample_maps'][0]['test_mode'] | |
210 elif isinstance(data, NXroot): | |
211 nxentry = data[data.attrs['default']] | |
212 else: | |
213 raise ValueError(f'Invalid parameter data ({data})') | |
214 | |
215 # Create an NXprocess to store data reduction (meta)data | |
216 reduced_data = NXprocess() | |
217 | |
218 # Generate dark field | |
219 if 'dark_field' in nxentry['spec_scans']: | |
220 reduced_data = self._gen_dark(nxentry, reduced_data) | |
221 | |
222 # Generate bright field | |
223 reduced_data = self._gen_bright(nxentry, reduced_data) | |
224 | |
225 # Set vertical detector bounds for image stack | |
226 img_x_bounds = self._set_detector_bounds(nxentry, reduced_data, img_x_bounds=img_x_bounds) | |
227 logger.info(f'img_x_bounds = {img_x_bounds}') | |
228 reduced_data['img_x_bounds'] = img_x_bounds | |
229 | |
230 # Set zoom and/or theta skip to reduce memory the requirement | |
231 zoom_perc, num_theta_skip = self._set_zoom_or_skip() | |
232 if zoom_perc is not None: | |
233 reduced_data.attrs['zoom_perc'] = zoom_perc | |
234 if num_theta_skip is not None: | |
235 reduced_data.attrs['num_theta_skip'] = num_theta_skip | |
236 | |
237 # Generate reduced tomography fields | |
238 reduced_data = self._gen_tomo(nxentry, reduced_data) | |
239 | |
240 # Create a copy of the input Nexus object and remove raw and any existing reduced data | |
241 if isinstance(data, NXroot): | |
242 exclude_items = [f'{nxentry._name}/reduced_data/data', | |
243 f'{nxentry._name}/instrument/detector/data', | |
244 f'{nxentry._name}/instrument/detector/image_key', | |
245 f'{nxentry._name}/instrument/detector/sequence_number', | |
246 f'{nxentry._name}/sample/rotation_angle', | |
247 f'{nxentry._name}/sample/x_translation', | |
248 f'{nxentry._name}/sample/z_translation', | |
249 f'{nxentry._name}/data/data', | |
250 f'{nxentry._name}/data/image_key', | |
251 f'{nxentry._name}/data/rotation_angle', | |
252 f'{nxentry._name}/data/x_translation', | |
253 f'{nxentry._name}/data/z_translation'] | |
254 nxroot = nxcopy(data, exclude_nxpaths=exclude_items) | |
255 nxentry = nxroot[nxroot.attrs['default']] | |
256 | |
257 # Add the reduced data NXprocess | |
258 nxentry.reduced_data = reduced_data | |
259 | |
260 if 'data' not in nxentry: | |
261 nxentry.data = NXdata() | |
262 nxentry.attrs['default'] = 'data' | |
263 nxentry.data.makelink(nxentry.reduced_data.data.tomo_fields, name='reduced_data') | |
264 nxentry.data.makelink(nxentry.reduced_data.rotation_angle, name='rotation_angle') | |
265 nxentry.data.attrs['signal'] = 'reduced_data' | |
266 | |
267 return(nxroot) | |
268 | |
269 def find_centers(self, nxroot, center_rows=None): | |
270 """Find the calibrated center axis info | |
271 """ | |
272 logger.info('Find the calibrated center axis info') | |
273 | |
274 if not isinstance(nxroot, NXroot): | |
275 raise ValueError(f'Invalid parameter nxroot ({nxroot})') | |
276 nxentry = nxroot[nxroot.attrs['default']] | |
277 if not isinstance(nxentry, NXentry): | |
278 raise ValueError(f'Invalid nxentry ({nxentry})') | |
279 if self.galaxy_flag: | |
280 if center_rows is None: | |
281 raise ValueError(f'Missing parameter center_rows ({center_rows})') | |
282 if not is_int_pair(center_rows): | |
283 raise ValueError(f'Invalid parameter center_rows ({center_rows})') | |
284 elif center_rows is not None: | |
285 logging.warning(f'Ignoring parameter center_rows ({center_rows})') | |
286 center_rows = None | |
287 | |
288 # Create plot galaxy path directory and path if needed | |
289 if self.galaxy_flag: | |
290 if not os_path.exists('tomo_find_centers_plots'): | |
291 mkdir('tomo_find_centers_plots') | |
292 path = 'tomo_find_centers_plots' | |
293 else: | |
294 path = self.output_folder | |
295 | |
296 # Check if reduced data is available | |
297 if ('reduced_data' not in nxentry or 'reduced_data' not in nxentry.data): | |
298 raise KeyError(f'Unable to find valid reduced data in {nxentry}.') | |
299 | |
300 # Select the image stack to calibrate the center axis | |
301 # reduced data axes order: stack,row,theta,column | |
302 # Note: Nexus cannot follow a link if the data it points to is too big, | |
303 # so get the data from the actual place, not from nxentry.data | |
304 num_tomo_stacks = nxentry.reduced_data.data.tomo_fields.shape[0] | |
305 if num_tomo_stacks == 1: | |
306 center_stack_index = 0 | |
307 center_stack = np.asarray(nxentry.reduced_data.data.tomo_fields[0]) | |
308 if not center_stack.size: | |
309 raise KeyError('Unable to load the required reduced tomography stack') | |
310 default = 'n' | |
311 else: | |
312 if self.test_mode: | |
313 center_stack_index = self.test_config['center_stack_index']-1 # make offset 0 | |
314 else: | |
315 center_stack_index = input_int('\nEnter tomography stack index to calibrate the ' | |
316 'center axis', ge=0, le=num_tomo_stacks-1, default=int(num_tomo_stacks/2)) | |
317 center_stack = \ | |
318 np.asarray(nxentry.reduced_data.data.tomo_fields[center_stack_index]) | |
319 if not center_stack.size: | |
320 raise KeyError('Unable to load the required reduced tomography stack') | |
321 default = 'y' | |
322 | |
323 # Get thetas (in degrees) | |
324 thetas = np.asarray(nxentry.reduced_data.rotation_angle) | |
325 | |
326 # Get effective pixel_size | |
327 if 'zoom_perc' in nxentry.reduced_data: | |
328 eff_pixel_size = 100.*(nxentry.instrument.detector.x_pixel_size/ | |
329 nxentry.reduced_data.attrs['zoom_perc']) | |
330 else: | |
331 eff_pixel_size = nxentry.instrument.detector.x_pixel_size | |
332 | |
333 # Get cross sectional diameter | |
334 cross_sectional_dim = center_stack.shape[2]*eff_pixel_size | |
335 logger.debug(f'cross_sectional_dim = {cross_sectional_dim}') | |
336 | |
337 # Determine center offset at sample row boundaries | |
338 logger.info('Determine center offset at sample row boundaries') | |
339 | |
340 # Lower row center | |
341 # center_stack order: row,theta,column | |
342 if self.test_mode: | |
343 lower_row = self.test_config['lower_row'] | |
344 elif self.galaxy_flag: | |
345 lower_row = min(center_rows) | |
346 if not 0 <= lower_row < center_stack.shape[0]-1: | |
347 raise ValueError(f'Invalid parameter center_rows ({center_rows})') | |
348 else: | |
349 lower_row = select_one_image_bound(center_stack[:,0,:], 0, bound=0, | |
350 title=f'theta={round(thetas[0], 2)+0}', | |
351 bound_name='row index to find lower center', default=default) | |
352 lower_center_offset = self._find_center_one_plane(center_stack[lower_row,:,:], lower_row, | |
353 thetas, eff_pixel_size, cross_sectional_dim, path=path, num_core=self.num_core) | |
354 logger.debug(f'lower_row = {lower_row:.2f}') | |
355 logger.debug(f'lower_center_offset = {lower_center_offset:.2f}') | |
356 | |
357 # Upper row center | |
358 if self.test_mode: | |
359 upper_row = self.test_config['upper_row'] | |
360 elif self.galaxy_flag: | |
361 upper_row = max(center_rows) | |
362 if not lower_row < upper_row < center_stack.shape[0]: | |
363 raise ValueError(f'Invalid parameter center_rows ({center_rows})') | |
364 else: | |
365 upper_row = select_one_image_bound(center_stack[:,0,:], 0, | |
366 bound=center_stack.shape[0]-1, title=f'theta={round(thetas[0], 2)+0}', | |
367 bound_name='row index to find upper center', default=default) | |
368 upper_center_offset = self._find_center_one_plane(center_stack[upper_row,:,:], upper_row, | |
369 thetas, eff_pixel_size, cross_sectional_dim, path=path, num_core=self.num_core) | |
370 logger.debug(f'upper_row = {upper_row:.2f}') | |
371 logger.debug(f'upper_center_offset = {upper_center_offset:.2f}') | |
372 del center_stack | |
373 | |
374 center_config = {'lower_row': lower_row, 'lower_center_offset': lower_center_offset, | |
375 'upper_row': upper_row, 'upper_center_offset': upper_center_offset} | |
376 if num_tomo_stacks > 1: | |
377 center_config['center_stack_index'] = center_stack_index+1 # save as offset 1 | |
378 | |
379 # Save test data to file | |
380 if self.test_mode: | |
381 with open(f'{self.output_folder}/center_config.yaml', 'w') as f: | |
382 safe_dump(center_config, f) | |
383 | |
384 return(center_config) | |
385 | |
386 def reconstruct_data(self, nxroot, center_info, x_bounds=None, y_bounds=None): | |
387 """Reconstruct the tomography data. | |
388 """ | |
389 logger.info('Reconstruct the tomography data') | |
390 | |
391 if not isinstance(nxroot, NXroot): | |
392 raise ValueError(f'Invalid parameter nxroot ({nxroot})') | |
393 nxentry = nxroot[nxroot.attrs['default']] | |
394 if not isinstance(nxentry, NXentry): | |
395 raise ValueError(f'Invalid nxentry ({nxentry})') | |
396 if not isinstance(center_info, dict): | |
397 raise ValueError(f'Invalid parameter center_info ({center_info})') | |
398 | |
399 # Create plot galaxy path directory and path if needed | |
400 if self.galaxy_flag: | |
401 if not os_path.exists('tomo_reconstruct_plots'): | |
402 mkdir('tomo_reconstruct_plots') | |
403 path = 'tomo_reconstruct_plots' | |
404 else: | |
405 path = self.output_folder | |
406 | |
407 # Check if reduced data is available | |
408 if ('reduced_data' not in nxentry or 'reduced_data' not in nxentry.data): | |
409 raise KeyError(f'Unable to find valid reduced data in {nxentry}.') | |
410 | |
411 # Create an NXprocess to store image reconstruction (meta)data | |
412 # if 'reconstructed_data' in nxentry: | |
413 # logger.warning(f'Existing reconstructed data in {nxentry} will be overwritten.') | |
414 # del nxentry['reconstructed_data'] | |
415 # if 'data' in nxentry and 'reconstructed_data' in nxentry.data: | |
416 # del nxentry.data['reconstructed_data'] | |
417 nxprocess = NXprocess() | |
418 | |
419 # Get rotation axis rows and centers | |
420 lower_row = center_info.get('lower_row') | |
421 lower_center_offset = center_info.get('lower_center_offset') | |
422 upper_row = center_info.get('upper_row') | |
423 upper_center_offset = center_info.get('upper_center_offset') | |
424 if (lower_row is None or lower_center_offset is None or upper_row is None or | |
425 upper_center_offset is None): | |
426 raise KeyError(f'Unable to find valid calibrated center axis info in {center_info}.') | |
427 center_slope = (upper_center_offset-lower_center_offset)/(upper_row-lower_row) | |
428 | |
429 # Get thetas (in degrees) | |
430 thetas = np.asarray(nxentry.reduced_data.rotation_angle) | |
431 | |
432 # Reconstruct tomography data | |
433 # reduced data axes order: stack,row,theta,column | |
434 # reconstructed data order in each stack: row/z,x,y | |
435 # Note: Nexus cannot follow a link if the data it points to is too big, | |
436 # so get the data from the actual place, not from nxentry.data | |
437 if 'zoom_perc' in nxentry.reduced_data: | |
438 res_title = f'{nxentry.reduced_data.attrs["zoom_perc"]}p' | |
439 else: | |
440 res_title = 'fullres' | |
441 load_error = False | |
442 num_tomo_stacks = nxentry.reduced_data.data.tomo_fields.shape[0] | |
443 tomo_recon_stacks = num_tomo_stacks*[np.array([])] | |
444 for i in range(num_tomo_stacks): | |
445 tomo_stack = np.asarray(nxentry.reduced_data.data.tomo_fields[i]) | |
446 if not tomo_stack.size: | |
447 raise KeyError(f'Unable to load tomography stack {i} for reconstruction') | |
448 assert(0 <= lower_row < upper_row < tomo_stack.shape[0]) | |
449 center_offsets = [lower_center_offset-lower_row*center_slope, | |
450 upper_center_offset+(tomo_stack.shape[0]-1-upper_row)*center_slope] | |
451 t0 = time() | |
452 logger.debug(f'Running _reconstruct_one_tomo_stack on {self.num_core} cores ...') | |
453 tomo_recon_stack = self._reconstruct_one_tomo_stack(tomo_stack, thetas, | |
454 center_offsets=center_offsets, num_core=self.num_core, algorithm='gridrec') | |
455 logger.debug(f'... done in {time()-t0:.2f} seconds') | |
456 logger.info(f'Reconstruction of stack {i} took {time()-t0:.2f} seconds') | |
457 | |
458 # Combine stacks | |
459 tomo_recon_stacks[i] = tomo_recon_stack | |
460 | |
461 # Resize the reconstructed tomography data | |
462 # reconstructed data order in each stack: row/z,x,y | |
463 if self.test_mode: | |
464 x_bounds = self.test_config.get('x_bounds') | |
465 y_bounds = self.test_config.get('y_bounds') | |
466 z_bounds = None | |
467 elif self.galaxy_flag: | |
468 if x_bounds is not None and not is_int_pair(x_bounds, ge=0, | |
469 lt=tomo_recon_stacks[0].shape[1]): | |
470 raise ValueError(f'Invalid parameter x_bounds ({x_bounds})') | |
471 if y_bounds is not None and not is_int_pair(y_bounds, ge=0, | |
472 lt=tomo_recon_stacks[0].shape[1]): | |
473 raise ValueError(f'Invalid parameter y_bounds ({y_bounds})') | |
474 z_bounds = None | |
475 else: | |
476 x_bounds, y_bounds, z_bounds = self._resize_reconstructed_data(tomo_recon_stacks) | |
477 if x_bounds is None: | |
478 x_range = (0, tomo_recon_stacks[0].shape[1]) | |
479 x_slice = int(x_range[1]/2) | |
480 else: | |
481 x_range = (min(x_bounds), max(x_bounds)) | |
482 x_slice = int((x_bounds[0]+x_bounds[1])/2) | |
483 if y_bounds is None: | |
484 y_range = (0, tomo_recon_stacks[0].shape[2]) | |
485 y_slice = int(y_range[1]/2) | |
486 else: | |
487 y_range = (min(y_bounds), max(y_bounds)) | |
488 y_slice = int((y_bounds[0]+y_bounds[1])/2) | |
489 if z_bounds is None: | |
490 z_range = (0, tomo_recon_stacks[0].shape[0]) | |
491 z_slice = int(z_range[1]/2) | |
492 else: | |
493 z_range = (min(z_bounds), max(z_bounds)) | |
494 z_slice = int((z_bounds[0]+z_bounds[1])/2) | |
495 | |
496 # Plot a few reconstructed image slices | |
497 if num_tomo_stacks == 1: | |
498 basetitle = 'recon' | |
499 else: | |
500 basetitle = f'recon stack {i}' | |
501 for i, stack in enumerate(tomo_recon_stacks): | |
502 title = f'{basetitle} {res_title} xslice{x_slice}' | |
503 quick_imshow(stack[z_range[0]:z_range[1],x_slice,y_range[0]:y_range[1]], | |
504 title=title, path=path, save_fig=self.save_figs, save_only=self.save_only, | |
505 block=self.block) | |
506 title = f'{basetitle} {res_title} yslice{y_slice}' | |
507 quick_imshow(stack[z_range[0]:z_range[1],x_range[0]:x_range[1],y_slice], | |
508 title=title, path=path, save_fig=self.save_figs, save_only=self.save_only, | |
509 block=self.block) | |
510 title = f'{basetitle} {res_title} zslice{z_slice}' | |
511 quick_imshow(stack[z_slice,x_range[0]:x_range[1],y_range[0]:y_range[1]], | |
512 title=title, path=path, save_fig=self.save_figs, save_only=self.save_only, | |
513 block=self.block) | |
514 | |
515 # Save test data to file | |
516 # reconstructed data order in each stack: row/z,x,y | |
517 if self.test_mode: | |
518 for i, stack in enumerate(tomo_recon_stacks): | |
519 np.savetxt(f'{self.output_folder}/recon_stack_{i+1}.txt', | |
520 stack[z_slice,x_range[0]:x_range[1],y_range[0]:y_range[1]], fmt='%.6e') | |
521 | |
522 # Add image reconstruction to reconstructed data NXprocess | |
523 # reconstructed data order in each stack: row/z,x,y | |
524 nxprocess.data = NXdata() | |
525 nxprocess.attrs['default'] = 'data' | |
526 for k, v in center_info.items(): | |
527 nxprocess[k] = v | |
528 if x_bounds is not None: | |
529 nxprocess.x_bounds = x_bounds | |
530 if y_bounds is not None: | |
531 nxprocess.y_bounds = y_bounds | |
532 if z_bounds is not None: | |
533 nxprocess.z_bounds = z_bounds | |
534 nxprocess.data['reconstructed_data'] = np.asarray([stack[z_range[0]:z_range[1], | |
535 x_range[0]:x_range[1],y_range[0]:y_range[1]] for stack in tomo_recon_stacks]) | |
536 nxprocess.data.attrs['signal'] = 'reconstructed_data' | |
537 | |
538 # Create a copy of the input Nexus object and remove reduced data | |
539 exclude_items = [f'{nxentry._name}/reduced_data/data', f'{nxentry._name}/data/reduced_data'] | |
540 nxroot_copy = nxcopy(nxroot, exclude_nxpaths=exclude_items) | |
541 | |
542 # Add the reconstructed data NXprocess to the new Nexus object | |
543 nxentry_copy = nxroot_copy[nxroot_copy.attrs['default']] | |
544 nxentry_copy.reconstructed_data = nxprocess | |
545 if 'data' not in nxentry_copy: | |
546 nxentry_copy.data = NXdata() | |
547 nxentry_copy.attrs['default'] = 'data' | |
548 nxentry_copy.data.makelink(nxprocess.data.reconstructed_data, name='reconstructed_data') | |
549 nxentry_copy.data.attrs['signal'] = 'reconstructed_data' | |
550 | |
551 return(nxroot_copy) | |
552 | |
553 def combine_data(self, nxroot): | |
554 """Combine the reconstructed tomography stacks. | |
555 """ | |
556 logger.info('Combine the reconstructed tomography stacks') | |
557 | |
558 if not isinstance(nxroot, NXroot): | |
559 raise ValueError(f'Invalid parameter nxroot ({nxroot})') | |
560 nxentry = nxroot[nxroot.attrs['default']] | |
561 if not isinstance(nxentry, NXentry): | |
562 raise ValueError(f'Invalid nxentry ({nxentry})') | |
563 | |
564 # Create plot galaxy path directory and path if needed | |
565 if self.galaxy_flag: | |
566 if not os_path.exists('tomo_combine_plots'): | |
567 mkdir('tomo_combine_plots') | |
568 path = 'tomo_combine_plots' | |
569 else: | |
570 path = self.output_folder | |
571 | |
572 # Check if reconstructed image data is available | |
573 if ('reconstructed_data' not in nxentry or 'reconstructed_data' not in nxentry.data): | |
574 raise KeyError(f'Unable to find valid reconstructed image data in {nxentry}.') | |
575 | |
576 # Create an NXprocess to store combined image reconstruction (meta)data | |
577 # if 'combined_data' in nxentry: | |
578 # logger.warning(f'Existing combined data in {nxentry} will be overwritten.') | |
579 # del nxentry['combined_data'] | |
580 # if 'data' in nxentry 'combined_data' in nxentry.data: | |
581 # del nxentry.data['combined_data'] | |
582 nxprocess = NXprocess() | |
583 | |
584 # Get the reconstructed data | |
585 # reconstructed data order: stack,row(z),x,y | |
586 # Note: Nexus cannot follow a link if the data it points to is too big, | |
587 # so get the data from the actual place, not from nxentry.data | |
588 tomo_recon_stacks = np.asarray(nxentry.reconstructed_data.data.reconstructed_data) | |
589 num_tomo_stacks = tomo_recon_stacks.shape[0] | |
590 if num_tomo_stacks == 1: | |
591 return(nxroot) | |
592 t0 = time() | |
593 logger.debug(f'Combining the reconstructed stacks ...') | |
594 tomo_recon_combined = tomo_recon_stacks[0,:,:,:] | |
595 if num_tomo_stacks > 2: | |
596 tomo_recon_combined = np.concatenate([tomo_recon_combined]+ | |
597 [tomo_recon_stacks[i,:,:,:] for i in range(1, num_tomo_stacks-1)]) | |
598 if num_tomo_stacks > 1: | |
599 tomo_recon_combined = np.concatenate([tomo_recon_combined]+ | |
600 [tomo_recon_stacks[num_tomo_stacks-1,:,:,:]]) | |
601 logger.debug(f'... done in {time()-t0:.2f} seconds') | |
602 logger.info(f'Combining the reconstructed stacks took {time()-t0:.2f} seconds') | |
603 | |
604 # Resize the combined tomography data set | |
605 # combined data order: row/z,x,y | |
606 if self.test_mode: | |
607 x_bounds = None | |
608 y_bounds = None | |
609 z_bounds = self.test_config.get('z_bounds') | |
610 elif self.galaxy_flag: | |
611 exit('TODO') | |
612 if x_bounds is not None and not is_int_pair(x_bounds, ge=0, | |
613 lt=tomo_recon_stacks[0].shape[1]): | |
614 raise ValueError(f'Invalid parameter x_bounds ({x_bounds})') | |
615 if y_bounds is not None and not is_int_pair(y_bounds, ge=0, | |
616 lt=tomo_recon_stacks[0].shape[1]): | |
617 raise ValueError(f'Invalid parameter y_bounds ({y_bounds})') | |
618 z_bounds = None | |
619 else: | |
620 x_bounds, y_bounds, z_bounds = self._resize_reconstructed_data(tomo_recon_combined, | |
621 z_only=True) | |
622 if x_bounds is None: | |
623 x_range = (0, tomo_recon_combined.shape[1]) | |
624 x_slice = int(x_range[1]/2) | |
625 else: | |
626 x_range = x_bounds | |
627 x_slice = int((x_bounds[0]+x_bounds[1])/2) | |
628 if y_bounds is None: | |
629 y_range = (0, tomo_recon_combined.shape[2]) | |
630 y_slice = int(y_range[1]/2) | |
631 else: | |
632 y_range = y_bounds | |
633 y_slice = int((y_bounds[0]+y_bounds[1])/2) | |
634 if z_bounds is None: | |
635 z_range = (0, tomo_recon_combined.shape[0]) | |
636 z_slice = int(z_range[1]/2) | |
637 else: | |
638 z_range = z_bounds | |
639 z_slice = int((z_bounds[0]+z_bounds[1])/2) | |
640 | |
641 # Plot a few combined image slices | |
642 quick_imshow(tomo_recon_combined[z_range[0]:z_range[1],x_slice,y_range[0]:y_range[1]], | |
643 title=f'recon combined xslice{x_slice}', path=path, | |
644 save_fig=self.save_figs, save_only=self.save_only, block=self.block) | |
645 quick_imshow(tomo_recon_combined[z_range[0]:z_range[1],x_range[0]:x_range[1],y_slice], | |
646 title=f'recon combined yslice{y_slice}', path=path, | |
647 save_fig=self.save_figs, save_only=self.save_only, block=self.block) | |
648 quick_imshow(tomo_recon_combined[z_slice,x_range[0]:x_range[1],y_range[0]:y_range[1]], | |
649 title=f'recon combined zslice{z_slice}', path=path, | |
650 save_fig=self.save_figs, save_only=self.save_only, block=self.block) | |
651 | |
652 # Save test data to file | |
653 # combined data order: row/z,x,y | |
654 if self.test_mode: | |
655 np.savetxt(f'{self.output_folder}/recon_combined.txt', tomo_recon_combined[ | |
656 z_slice,x_range[0]:x_range[1],y_range[0]:y_range[1]], fmt='%.6e') | |
657 | |
658 # Add image reconstruction to reconstructed data NXprocess | |
659 # combined data order: row/z,x,y | |
660 nxprocess.data = NXdata() | |
661 nxprocess.attrs['default'] = 'data' | |
662 if x_bounds is not None: | |
663 nxprocess.x_bounds = x_bounds | |
664 if y_bounds is not None: | |
665 nxprocess.y_bounds = y_bounds | |
666 if z_bounds is not None: | |
667 nxprocess.z_bounds = z_bounds | |
668 nxprocess.data['combined_data'] = tomo_recon_combined | |
669 nxprocess.data.attrs['signal'] = 'combined_data' | |
670 | |
671 # Create a copy of the input Nexus object and remove reconstructed data | |
672 exclude_items = [f'{nxentry._name}/reconstructed_data/data', | |
673 f'{nxentry._name}/data/reconstructed_data'] | |
674 nxroot_copy = nxcopy(nxroot, exclude_nxpaths=exclude_items) | |
675 | |
676 # Add the combined data NXprocess to the new Nexus object | |
677 nxentry_copy = nxroot_copy[nxroot_copy.attrs['default']] | |
678 nxentry_copy.combined_data = nxprocess | |
679 if 'data' not in nxentry_copy: | |
680 nxentry_copy.data = NXdata() | |
681 nxentry_copy.attrs['default'] = 'data' | |
682 nxentry_copy.data.makelink(nxprocess.data.combined_data, name='combined_data') | |
683 nxentry_copy.data.attrs['signal'] = 'combined_data' | |
684 | |
685 return(nxroot_copy) | |
686 | |
def _gen_dark(self, nxentry, reduced_data):
    """Generate the dark field image and add it to the reduced data.

    The dark field frames are taken from the detector ``data`` stack
    (frames with ``image_key == 2``) when available, otherwise from the
    dark field scans under ``nxentry.spec_scans.dark_field``. A stack of
    frames is collapsed to one image by the per-pixel median. Pixels above
    the cutoff ``min + 2*(median - min)`` are flagged as outliers and,
    together with any other non-finite value, replaced by the mean of the
    remaining pixels (negative infinities become 0).

    :param nxentry: the NeXus entry holding the instrument/scan data
    :param reduced_data: the NXprocess to which the dark field is added as
        ``reduced_data.data.dark_field`` (a new NXdata group is created)
    :return: the updated ``reduced_data``
    :raises NotImplementedError: if the dark field data comes back as a
        list, which is not supported yet
    :raises ValueError: if the dark field stack is neither 2D nor 3D
    """
    # Get the dark field images
    image_key = nxentry.instrument.detector.get('image_key', None)
    if image_key and 'data' in nxentry.instrument.detector:
        field_indices = [index for index, key in enumerate(image_key) if key == 2]
        tdf_stack = nxentry.instrument.detector.data[field_indices,:,:]
        # RV the default NXtomo form does not accomodate bright or dark field stacks
    else:
        dark_field_scans = nxentry.spec_scans.dark_field
        dark_field = FlatField.construct_from_nxcollection(dark_field_scans)
        prefix = str(nxentry.instrument.detector.local_name)
        tdf_stack = dark_field.get_detector_data(prefix)
        if isinstance(tdf_stack, list):
            # Raise instead of terminating the interpreter from library code
            raise NotImplementedError(
                'Dark field data returned as a list is not supported yet')

    # Take median
    if tdf_stack.ndim == 2:
        # Copy so that the in-place edits below never touch the source frames
        tdf = np.array(tdf_stack)
    elif tdf_stack.ndim == 3:
        tdf = np.median(tdf_stack, axis=0)
        del tdf_stack
    else:
        raise ValueError(f'Invalid tdf_stack shape ({tdf_stack.shape})')
    # Integer detector images cannot hold the NaN outlier markers used below
    if not np.issubdtype(tdf.dtype, np.floating):
        tdf = tdf.astype(np.float64)

    # Remove dark field intensities above the cutoff
    # (float() so that is_num also accepts numpy scalar types like float32)
    tdf_cutoff = float(tdf.min()+2*(np.median(tdf)-tdf.min()))
    logger.debug(f'tdf_cutoff = {tdf_cutoff}')
    if not is_num(tdf_cutoff, ge=0):
        logger.warning(f'Ignoring illegal value of tdf_cutoff {tdf_cutoff}')
    else:
        tdf[tdf > tdf_cutoff] = np.nan

    # Remove nans
    tdf_mean = np.nanmean(tdf)
    logger.debug(f'tdf_mean = {tdf_mean}')
    np.nan_to_num(tdf, copy=False, nan=tdf_mean, posinf=tdf_mean, neginf=0.)

    # Plot dark field
    if self.galaxy_flag:
        quick_imshow(tdf, title='dark field', path='tomo_reduce_plots', save_fig=self.save_figs,
            save_only=self.save_only)
    elif not self.test_mode:
        quick_imshow(tdf, title='dark field', path=self.output_folder, save_fig=self.save_figs,
            save_only=self.save_only)
        clear_imshow('dark field')

    # Add dark field to reduced data NXprocess
    reduced_data.data = NXdata()
    reduced_data.data['dark_field'] = tdf

    return(reduced_data)
744 | |
def _gen_bright(self, nxentry, reduced_data):
    """Generate the bright field image and add it to the reduced data.

    The bright field frames are taken from the detector ``data`` stack
    (frames with ``image_key == 1``) when available, otherwise from the
    bright field scans under ``nxentry.spec_scans.bright_field``. A stack
    of frames is collapsed to one image by the per-pixel median. The dark
    field stored in ``reduced_data`` (if any) is then subtracted, and every
    value below one is set to one.

    :param nxentry: the NeXus entry holding the instrument/scan data
    :param reduced_data: the NXprocess to which the bright field is added
        as ``reduced_data.data.bright_field``
    :return: the updated ``reduced_data``
    :raises NotImplementedError: if the bright field data comes back as a
        list, which is not supported yet
    :raises ValueError: if the bright field stack is neither 2D nor 3D
    """
    # Get the bright field images
    image_key = nxentry.instrument.detector.get('image_key', None)
    if image_key and 'data' in nxentry.instrument.detector:
        field_indices = [index for index, key in enumerate(image_key) if key == 1]
        tbf_stack = nxentry.instrument.detector.data[field_indices,:,:]
        # RV the default NXtomo form does not accomodate bright or dark field stacks
    else:
        bright_field_scans = nxentry.spec_scans.bright_field
        bright_field = FlatField.construct_from_nxcollection(bright_field_scans)
        prefix = str(nxentry.instrument.detector.local_name)
        tbf_stack = bright_field.get_detector_data(prefix)
        if isinstance(tbf_stack, list):
            # Raise instead of terminating the interpreter from library code
            raise NotImplementedError(
                'Bright field data returned as a list is not supported yet')

    # Take median if more than one image
    # Median or mean: It may be best to try the median because of some image
    # artifacts that arise due to crinkles in the upstream kapton tape windows
    # causing some phase contrast images to appear on the detector.
    # One thing that also may be useful in a future implementation is to do a
    # brightfield adjustment on EACH frame of the tomo based on a ROI in the
    # corner of the frame where there is no sample but there is the direct X-ray
    # beam because there is frame to frame fluctuations from the incoming beam.
    # We don't typically account for them but potentially could.
    if tbf_stack.ndim == 2:
        # Copy so that the in-place edits below never touch the source frames
        tbf = np.array(tbf_stack)
    elif tbf_stack.ndim == 3:
        tbf = np.median(tbf_stack, axis=0)
        del tbf_stack
    else:
        raise ValueError(f'Invalid tbf_stack shape ({tbf_stack.shape})')
    # In-place subtraction of a floating point dark field is rejected by numpy
    # on an integer image (float -> int cast), so promote integer data first
    if not np.issubdtype(tbf.dtype, np.floating):
        tbf = tbf.astype(np.float64)

    # Subtract dark field
    if 'data' in reduced_data and 'dark_field' in reduced_data.data:
        tbf -= np.asarray(reduced_data.data.dark_field)
    else:
        logger.warning('Dark field unavailable')

    # Set any non-positive values to one
    # (avoid negative bright field values for spikes in dark field)
    tbf[tbf < 1] = 1

    # Plot bright field
    if self.galaxy_flag:
        quick_imshow(tbf, title='bright field', path='tomo_reduce_plots',
            save_fig=self.save_figs, save_only=self.save_only)
    elif not self.test_mode:
        quick_imshow(tbf, title='bright field', path=self.output_folder,
            save_fig=self.save_figs, save_only=self.save_only)
        clear_imshow('bright field')

    # Add bright field to reduced data NXprocess
    if 'data' not in reduced_data:
        reduced_data.data = NXdata()
    reduced_data.data['bright_field'] = tbf

    return(reduced_data)
806 | |
def _set_detector_bounds(self, nxentry, reduced_data, img_x_bounds=None):
    """Set the vertical detector bounds for each image stack.

    Right now the range is the same for each set in the image stack.

    For stations 'id1a3' and 'id3a', a default range is derived from a
    rectangle fit to the row sums of the bright field. With several stacks,
    the smallest vertical shift between stacks is used instead. Outside
    Galaxy the user may accept this default or draw a different range. For
    any other station, the range comes from ``img_x_bounds`` (Galaxy) or is
    selected interactively on the first tomography image.

    :param nxentry: the NeXus entry holding the instrument/scan data
    :param reduced_data: the NXprocess holding ``data.bright_field``
    :param img_x_bounds: optional (lower, upper) row index range
    :return: the (lower inclusive, upper exclusive) row index bounds; in
        test mode the configured ``img_x_bounds``
    :raises ValueError: for an invalid ``img_x_bounds`` or if no bounds
        were selected
    :raises NotImplementedError: for multiple stacks at other stations
    """
    if self.test_mode:
        return(tuple(self.test_config['img_x_bounds']))

    # Get the first tomography image and the reference heights
    image_key = nxentry.instrument.detector.get('image_key', None)
    if image_key and 'data' in nxentry.instrument.detector:
        field_indices = [index for index, key in enumerate(image_key) if key == 0]
        first_image = np.asarray(nxentry.instrument.detector.data[field_indices[0],:,:])
        theta = float(nxentry.sample.rotation_angle[field_indices[0]])
        z_translation_all = nxentry.sample.z_translation[field_indices]
        z_translation_levels = sorted(list(set(z_translation_all)))
        num_tomo_stacks = len(z_translation_levels)
        # The distinct sample heights are the stack reference heights; they are
        # needed below to size the default range when there are multiple stacks
        vertical_shifts = z_translation_levels
    else:
        tomo_field_scans = nxentry.spec_scans.tomo_fields
        tomo_fields = TomoField.construct_from_nxcollection(tomo_field_scans)
        vertical_shifts = tomo_fields.get_vertical_shifts()
        if not isinstance(vertical_shifts, list):
            vertical_shifts = [vertical_shifts]
        prefix = str(nxentry.instrument.detector.local_name)
        t0 = time()
        first_image = tomo_fields.get_detector_data(prefix, tomo_fields.scan_numbers[0], 0)
        logger.debug(f'Getting first image took {time()-t0:.2f} seconds')
        num_tomo_stacks = len(tomo_fields.scan_numbers)
        theta = tomo_fields.theta_range['start']

    # Select image bounds
    title = f'tomography image at theta={round(theta, 2)+0}'
    if (img_x_bounds is not None and not is_index_range(img_x_bounds, ge=0,
            le=first_image.shape[0])):
        raise ValueError(f'Invalid parameter img_x_bounds ({img_x_bounds})')
    if nxentry.instrument.source.attrs['station'] in ('id1a3', 'id3a'):
        pixel_size = nxentry.instrument.detector.x_pixel_size
        # Try to get a fit from the bright field
        tbf = np.asarray(reduced_data.data.bright_field)
        tbf_shape = tbf.shape
        x_sum = np.sum(tbf, 1)
        x_sum_min = x_sum.min()
        x_sum_max = x_sum.max()
        fit = Fit.fit_data(x_sum, 'rectangle', x=np.array(range(len(x_sum))), form='atan',
            guess=True)
        parameters = fit.best_values
        x_low_fit = parameters.get('center1', None)
        x_upp_fit = parameters.get('center2', None)
        sig_low = parameters.get('sigma1', None)
        sig_upp = parameters.get('sigma2', None)
        # Only trust a fit with a sharp, well ordered window inside the detector
        have_fit = fit.success and x_low_fit is not None and x_upp_fit is not None and \
            sig_low is not None and sig_upp is not None and \
            0 <= x_low_fit < x_upp_fit <= x_sum.size and \
            (sig_low+sig_upp)/(x_upp_fit-x_low_fit) < 0.1
        if have_fit:
            # Set a 5% margin on each side
            margin = 0.05*(x_upp_fit-x_low_fit)
            x_low_fit = max(0, x_low_fit-margin)
            x_upp_fit = min(tbf_shape[0], x_upp_fit+margin)
        if num_tomo_stacks == 1:
            if have_fit:
                # Set the default range to enclose the full fitted window
                x_low = int(x_low_fit)
                x_upp = int(x_upp_fit)
            else:
                # Center a default range of 1 mm (RV: can we get this from the slits?)
                num_x_min = int((1.0-0.5*pixel_size)/pixel_size)
                x_low = int(0.5*(tbf_shape[0]-num_x_min))
                x_upp = x_low+num_x_min
        else:
            # Get the default range from the reference heights
            delta_z = vertical_shifts[1]-vertical_shifts[0]
            for i in range(2, num_tomo_stacks):
                delta_z = min(delta_z, vertical_shifts[i]-vertical_shifts[i-1])
            logger.debug(f'delta_z = {delta_z}')
            num_x_min = int((delta_z-0.5*pixel_size)/pixel_size)
            logger.debug(f'num_x_min = {num_x_min}')
            if num_x_min > tbf_shape[0]:
                logger.warning('Image bounds and pixel size prevent seamless stacking')
            if have_fit:
                # Center the default range relative to the fitted window
                x_low = int(0.5*(x_low_fit+x_upp_fit-num_x_min))
                x_upp = x_low+num_x_min
            else:
                # Center the default range
                x_low = int(0.5*(tbf_shape[0]-num_x_min))
                x_upp = x_low+num_x_min
        # Keep the default range on the detector: a negative lower bound would
        # silently index from the end, an upper bound past the end would fail
        x_low = max(0, x_low)
        x_upp = min(tbf_shape[0], x_upp)
        if self.galaxy_flag:
            img_x_bounds = (x_low, x_upp)
        else:
            tmp = np.copy(tbf)
            tmp_max = tmp.max()
            tmp[x_low,:] = tmp_max
            tmp[x_upp-1,:] = tmp_max
            quick_imshow(tmp, title='bright field')
            tmp = np.copy(first_image)
            tmp_max = tmp.max()
            tmp[x_low,:] = tmp_max
            tmp[x_upp-1,:] = tmp_max
            quick_imshow(tmp, title=title)
            del tmp
            quick_plot((range(x_sum.size), x_sum),
                ([x_low, x_low], [x_sum_min, x_sum_max], 'r-'),
                ([x_upp, x_upp], [x_sum_min, x_sum_max], 'r-'),
                title='sum over theta and y')
            print(f'lower bound = {x_low} (inclusive)')
            print(f'upper bound = {x_upp} (exclusive)]')
            accept = input_yesno('Accept these bounds (y/n)?', 'y')
            clear_imshow('bright field')
            clear_imshow(title)
            clear_plot('sum over theta and y')
            if accept:
                img_x_bounds = (x_low, x_upp)
            else:
                while True:
                    mask, img_x_bounds = draw_mask_1d(x_sum, title='select x data range',
                        legend='sum over theta and y')
                    if len(img_x_bounds) == 1:
                        break
                    else:
                        print('Choose a single connected data range')
                img_x_bounds = tuple(img_x_bounds[0])
                if (num_tomo_stacks > 1 and img_x_bounds[1]-img_x_bounds[0]+1 <
                        int((delta_z-0.5*pixel_size)/pixel_size)):
                    logger.warning('Image bounds and pixel size prevent seamless stacking')
    else:
        if num_tomo_stacks > 1:
            raise NotImplementedError('Selecting image bounds for multiple stacks on FMB')
        # For FMB: use the first tomography image to select range
        # RV: revisit if they do tomography with multiple stacks
        x_sum = np.sum(first_image, 1)
        x_sum_min = x_sum.min()
        x_sum_max = x_sum.max()
        if self.galaxy_flag:
            if img_x_bounds is None:
                img_x_bounds = (0, first_image.shape[0])
        else:
            quick_imshow(first_image, title=title)
            print('Select vertical data reduction range from first tomography image')
            img_x_bounds = select_image_bounds(first_image, 0, title=title)
            clear_imshow(title)
            if img_x_bounds is None:
                raise ValueError('Unable to select image bounds')

    # Plot results
    if self.galaxy_flag:
        path = 'tomo_reduce_plots'
    else:
        path = self.output_folder
    x_low = img_x_bounds[0]
    x_upp = img_x_bounds[1]
    tmp = np.copy(first_image)
    tmp_max = tmp.max()
    tmp[x_low,:] = tmp_max
    tmp[x_upp-1,:] = tmp_max
    quick_imshow(tmp, title=title, path=path, save_fig=self.save_figs, save_only=self.save_only,
        block=self.block)
    del tmp
    quick_plot((range(x_sum.size), x_sum),
        ([x_low, x_low], [x_sum_min, x_sum_max], 'r-'),
        ([x_upp, x_upp], [x_sum_min, x_sum_max], 'r-'),
        title='sum over theta and y', path=path, save_fig=self.save_figs,
        save_only=self.save_only, block=self.block)

    return(img_x_bounds)
971 | |
def _set_zoom_or_skip(self):
    """Return the zoom percentage and theta skip used to limit memory use.

    Interactive selection of either setting is currently disabled, so both
    values are always None (no zoom, no skipped thetas).

    :return: the tuple (zoom_perc, num_theta_skip)
    """
    zoom_perc = None        # interactive zoom prompt is disabled
    num_theta_skip = None   # interactive theta-skip prompt is disabled
    logger.debug(f'zoom_perc = {zoom_perc}')
    logger.debug(f'num_theta_skip = {num_theta_skip}')
    return zoom_perc, num_theta_skip
990 | |
def _gen_tomo(self, nxentry, reduced_data):
    """Generate the reduced tomography fields.

    Steps applied to each tomography stack:
    - crop to the image bounds stored in ``reduced_data`` (if any);
    - normalize by the equally cropped bright field;
    - clip below at 1e-6 and linearize with -log;
    - zero out non-finite values;
    - convert to float32 and reorder from theta,row,column to
      row,theta,column.

    :param nxentry: the NeXus entry holding the instrument/scan data
    :param reduced_data: the NXprocess holding ``data.bright_field`` and
        optionally ``img_x_bounds``/``img_y_bounds``; the rotation angles,
        x/z translations and the reduced stacks (``data.tomo_fields``) are
        added to it
    :return: the updated ``reduced_data``
    :raises ValueError: if the tomography images are not stored in sequence
    """
    # Get full bright field
    tbf = np.asarray(reduced_data.data.bright_field)
    tbf_shape = tbf.shape

    # Get image bounds
    img_x_bounds = tuple(reduced_data.get('img_x_bounds', (0, tbf_shape[0])))
    img_y_bounds = tuple(reduced_data.get('img_y_bounds', (0, tbf_shape[1])))
    # Decide once, against the FULL detector size, whether images need
    # cropping (tbf itself is cropped below, so its shape cannot be used later)
    resize = img_x_bounds != (0, tbf_shape[0]) or img_y_bounds != (0, tbf_shape[1])

    # Dark field subtraction of the tomography images is currently disabled
    tdf = None

    # Resize bright field
    if resize:
        tbf = tbf[img_x_bounds[0]:img_x_bounds[1],img_y_bounds[0]:img_y_bounds[1]]

    # Get the tomography images
    image_key = nxentry.instrument.detector.get('image_key', None)
    if image_key and 'data' in nxentry.instrument.detector:
        field_indices_all = [index for index, key in enumerate(image_key) if key == 0]
        z_translation_all = nxentry.sample.z_translation[field_indices_all]
        z_translation_levels = sorted(list(set(z_translation_all)))
        horizontal_shifts = []
        vertical_shifts = []
        thetas = None
        tomo_stacks = []
        # One stack per distinct sample height
        for z_translation in z_translation_levels:
            field_indices = [field_indices_all[index]
                for index, z in enumerate(z_translation_all) if z == z_translation]
            horizontal_shift = list(set(nxentry.sample.x_translation[field_indices]))
            assert(len(horizontal_shift) == 1)
            horizontal_shifts += horizontal_shift
            vertical_shift = list(set(nxentry.sample.z_translation[field_indices]))
            assert(len(vertical_shift) == 1)
            vertical_shifts += vertical_shift
            sequence_numbers = nxentry.instrument.detector.sequence_number[field_indices]
            if thetas is None:
                thetas = np.asarray(nxentry.sample.rotation_angle[field_indices]) \
                    [sequence_numbers]
            else:
                # All stacks must share the same rotation angles
                assert(all(thetas[i] == nxentry.sample.rotation_angle[field_indices[index]]
                    for i, index in enumerate(sequence_numbers)))
            assert(list(set(sequence_numbers)) == [i for i in range(len(sequence_numbers))])
            if list(sequence_numbers) == [i for i in range(len(sequence_numbers))]:
                tomo_stack = np.asarray(nxentry.instrument.detector.data[field_indices])
            else:
                raise ValueError('Unable to load the tomography images')
            tomo_stacks.append(tomo_stack)
    else:
        tomo_field_scans = nxentry.spec_scans.tomo_fields
        tomo_fields = TomoField.construct_from_nxcollection(tomo_field_scans)
        horizontal_shifts = tomo_fields.get_horizontal_shifts()
        vertical_shifts = tomo_fields.get_vertical_shifts()
        prefix = str(nxentry.instrument.detector.local_name)
        t0 = time()
        tomo_stacks = tomo_fields.get_detector_data(prefix)
        logger.debug(f'Getting tomography images took {time()-t0:.2f} seconds')
        thetas = np.linspace(tomo_fields.theta_range['start'], tomo_fields.theta_range['end'],
            tomo_fields.theta_range['num'])
        if not isinstance(tomo_stacks, list):
            horizontal_shifts = [horizontal_shifts]
            vertical_shifts = [vertical_shifts]
            tomo_stacks = [tomo_stacks]

    reduced_tomo_stacks = []
    if self.galaxy_flag:
        path = 'tomo_reduce_plots'
    else:
        path = self.output_folder
    for i, tomo_stack in enumerate(tomo_stacks):
        # Resize the tomography images
        # Right now the range is the same for each set in the image stack.
        # Either way, work on a float64 copy: the numexpr operations below
        # write in place and need a floating point output array
        if resize:
            t0 = time()
            tomo_stack = tomo_stack[:,img_x_bounds[0]:img_x_bounds[1],
                img_y_bounds[0]:img_y_bounds[1]].astype('float64')
            logger.debug(f'Resizing tomography images took {time()-t0:.2f} seconds')
        else:
            tomo_stack = np.array(tomo_stack, dtype='float64')

        # Subtract dark field
        if tdf is not None:
            t0 = time()
            with set_numexpr_threads(self.num_core):
                ne.evaluate('tomo_stack-tdf', out=tomo_stack)
            logger.debug(f'Subtracting dark field took {time()-t0:.2f} seconds')

        # Normalize
        t0 = time()
        with set_numexpr_threads(self.num_core):
            ne.evaluate('tomo_stack/tbf', out=tomo_stack, truediv=True)
        logger.debug(f'Normalizing took {time()-t0:.2f} seconds')

        # Remove non-positive values and linearize data
        t0 = time()
        cutoff = 1.e-6
        with set_numexpr_threads(self.num_core):
            ne.evaluate('where(tomo_stack<cutoff, cutoff, tomo_stack)', out=tomo_stack)
        with set_numexpr_threads(self.num_core):
            ne.evaluate('-log(tomo_stack)', out=tomo_stack)
        logger.debug('Removing non-positive values and linearizing data took '+
            f'{time()-t0:.2f} seconds')

        # Get rid of nans/infs that may be introduced by normalization
        # (in place; the previous np.where result was discarded)
        t0 = time()
        tomo_stack[~np.isfinite(tomo_stack)] = 0.
        logger.debug(f'Remove nans/infs took {time()-t0:.2f} seconds')

        # Downsize tomography stack to smaller size
        # TODO use theta_skip and zoom as well
        tomo_stack = tomo_stack.astype('float32')
        if not self.test_mode:
            if len(tomo_stacks) == 1:
                title = f'red fullres theta {round(thetas[0], 2)+0}'
            else:
                title = f'red stack {i} fullres theta {round(thetas[0], 2)+0}'
            quick_imshow(tomo_stack[0,:,:], title=title, path=path, save_fig=self.save_figs,
                save_only=self.save_only, block=self.block)

        # Convert tomography stack from theta,row,column to row,theta,column
        t0 = time()
        tomo_stack = np.swapaxes(tomo_stack, 0, 1)
        logger.debug(f'Converting coordinate order took {time()-t0:.2f} seconds')

        # Save test data to file
        if self.test_mode:
            row_index = int(tomo_stack.shape[0]/2)
            np.savetxt(f'{self.output_folder}/red_stack_{i+1}.txt', tomo_stack[row_index,:,:],
                fmt='%.6e')

        # Combine resized stacks
        reduced_tomo_stacks.append(tomo_stack)

    # Add tomo field info to reduced data NXprocess
    reduced_data['rotation_angle'] = thetas
    reduced_data['x_translation'] = np.asarray(horizontal_shifts)
    reduced_data['z_translation'] = np.asarray(vertical_shifts)
    reduced_data.data['tomo_fields'] = np.asarray(reduced_tomo_stacks)

    if tdf is not None:
        del tdf
    del tbf

    return(reduced_data)
1163 | |
def _find_center_one_plane(self, sinogram, row, thetas, eff_pixel_size, cross_sectional_dim,
        path=None, tol=0.1, num_core=1):
    """Find the rotation center offset for a single tomography plane.

    The center is found with tomopy's find_center_vo (Nghia Vo's method).
    The plane is then reconstructed at that center, and the edges of the
    reconstruction are plotted for inspection.

    :param sinogram: the plane's sinogram, index order theta,column
    :param row: the row index of the plane (used in log messages and titles)
    :param thetas: projection angles passed on to _reconstruct_one_plane
    :param eff_pixel_size: effective pixel size
    :param cross_sectional_dim: sample cross sectional dimension
    :param path: output folder for the edge plots
    :param tol: currently unused (kept for interface compatibility)
    :param num_core: number of cores to use
    :return: the center offset in pixels relative to the column center
    """
    # sinogram index order: theta,column
    # need column,theta for iradon, so take transpose
    sinogram_T = sinogram.T
    center = sinogram.shape[1]/2

    # Try using Nghia Vo’s method (tomopy runs on at most num_core_tomopy_limit cores)
    ncore_vo = min(num_core, num_core_tomopy_limit)
    t0 = time()
    logger.debug(f'Running find_center_vo on {ncore_vo} cores ...')
    tomo_center = tomopy.find_center_vo(sinogram, ncore=ncore_vo)
    logger.debug(f'... done in {time()-t0:.2f} seconds')
    logger.info(f'Finding the center using Nghia Vo’s method took {time()-t0:.2f} seconds')
    center_offset_vo = tomo_center-center
    logger.info(f'Center at row {row} using Nghia Vo’s method = {center_offset_vo:.2f}')
    t0 = time()
    # Report the core count actually passed on (previously logged self.num_core)
    logger.debug(f'Running _reconstruct_one_plane on {num_core} cores ...')
    recon_plane = self._reconstruct_one_plane(sinogram_T, tomo_center, thetas,
        eff_pixel_size, cross_sectional_dim, False, num_core)
    logger.debug(f'... done in {time()-t0:.2f} seconds')
    logger.info(f'Reconstructing row {row} took {time()-t0:.2f} seconds')

    title = f'edges row{row} center offset{center_offset_vo:.2f} Vo'
    self._plot_edges_one_plane(recon_plane, title, path=path)

    # Interactive refinement (phase correlation or a manual scan over center
    # offsets) is disabled: Nghia Vo's result is always accepted
    del sinogram_T
    del recon_plane
    return float(center_offset_vo)
1264 | |
def _reconstruct_one_plane(self, tomo_plane_T, center, thetas, eff_pixel_size,
        cross_sectional_dim, plot_sinogram=True, num_core=1):
    """Invert the sinogram of a single tomography plane.

    The sinogram is cropped around the rotation center to a radius of at
    most 0.55*cross_sectional_dim/eff_pixel_size pixels (capped at half the
    number of columns). It is then inverted with skimage's iradon,
    smoothed with a Gaussian filter, and cleaned of ring artifacts with
    tomopy.

    :param tomo_plane_T: the transposed sinogram, index order column,theta
    :param center: rotation center in pixels, 0 <= center < number of columns
    :param thetas: projection angles passed to iradon
    :param eff_pixel_size: effective pixel size (same units as
        cross_sectional_dim)
    :param cross_sectional_dim: sample cross sectional dimension
    :param plot_sinogram: plot the cropped sinogram (never in Galaxy mode)
    :param num_core: number of cores for tomopy's remove_ring
    :return: the reconstructed plane with a leading axis of length 1
    """
    num_col = tomo_plane_T.shape[0]
    assert(0 <= center < num_col)
    center_offset = center-num_col/2
    # Twice the rounded offset: cutting these extra columns on one side
    # re-centers the sinogram on the rotation axis
    two_offset = 2*int(np.round(center_offset))
    # Radius with 10% slack to avoid edge effects, capped at half the width
    max_rad = min(int(0.55*(cross_sectional_dim/eff_pixel_size)), 0.5*num_col)
    dist_from_edge = max(1, int(np.floor((num_col-np.abs(two_offset))/2.)-max_rad))
    if two_offset < 0:
        logger.debug(f'sinogram range = [{dist_from_edge}, {two_offset-dist_from_edge}]')
        sinogram = tomo_plane_T[dist_from_edge:two_offset-dist_from_edge,:]
    else:
        logger.debug(f'sinogram range = [{two_offset+dist_from_edge}, {-dist_from_edge}]')
        sinogram = tomo_plane_T[two_offset+dist_from_edge:-dist_from_edge,:]
    if not self.galaxy_flag and plot_sinogram:
        quick_imshow(sinogram.T, f'sinogram center offset{center_offset:.2f}', aspect='auto',
            path=self.output_folder, save_fig=self.save_figs, save_only=self.save_only,
            block=self.block)

    # Inverting sinogram
    t0 = time()
    recon_sinogram = iradon(sinogram, theta=thetas, circle=True)
    logger.debug(f'Inverting sinogram took {time()-t0:.2f} seconds')
    del sinogram

    # Gaussian filter width and ring removal width; the configuration hook is
    # not wired up yet, so the defaults always apply
    sigma = 1.0
    ring_width = 15
    recon_parameters = None#self.config.get('recon_parameters')
    if recon_parameters is not None:
        sigma = recon_parameters.get('gaussian_sigma', 1.0)
        if not is_num(sigma, ge=0.0):
            logger.warning(f'Invalid gaussian_sigma ({sigma}) in _reconstruct_one_plane, '+
                'set to a default value of 1.0')
            sigma = 1.0
        ring_width = recon_parameters.get('ring_width', 15)
        if not is_int(ring_width, ge=0):
            logger.warning(f'Invalid ring_width ({ring_width}) in _reconstruct_one_plane, '+
                'set to a default value of 15')
            ring_width = 15

    # Performing Gaussian filtering and removing ring artifacts
    t0 = time()
    smoothed = spi.gaussian_filter(recon_sinogram, sigma, mode='nearest')
    del recon_sinogram
    recon_clean = tomopy.misc.corr.remove_ring(np.expand_dims(smoothed, axis=0),
        rwidth=ring_width, ncore=num_core)
    logger.debug(f'Filtering and removing ring artifacts took {time()-t0:.2f} seconds')

    return recon_clean
1319 | |
def _plot_edges_one_plane(self, recon_plane, title, path=None):
    """Plot a total-variation denoised version of a reconstructed plane.

    The denoised image (index 0 of the leading axis) is shown twice: with
    the 'coolwarm' colormap, and in gray scale with a color range that is
    symmetric around zero.

    :param recon_plane: reconstructed plane with a leading axis
    :param title: base title for both figures
    :param path: output folder for saved figures, defaults to
        self.output_folder
    """
    # Denoising weight; the configuration hook is not wired up yet
    weight = 0.1
    vis_parameters = None#self.config.get('vis_parameters')
    if vis_parameters is not None:
        weight = vis_parameters.get('denoise_weight', 0.1)
        if not is_num(weight, ge=0.0):
            logger.warning(f'Invalid weight ({weight}) in _plot_edges_one_plane, '+
                'set to a default value of 0.1')
            weight = 0.1
    edges = denoise_tv_chambolle(recon_plane, weight=weight)
    image = edges[0,:,:]
    limit = np.max(image)
    out_dir = self.output_folder if path is None else path
    shared = dict(path=out_dir, save_fig=self.save_figs, save_only=self.save_only,
        block=self.block)
    quick_imshow(image, f'{title} coolwarm', cmap='coolwarm', **shared)
    quick_imshow(image, f'{title} gray', cmap='gray', vmin=-limit, vmax=limit, **shared)
    del edges
1340 | |
def _reconstruct_one_tomo_stack(self, tomo_stack, thetas, center_offsets=(), num_core=1,
        algorithm='gridrec'):
    """Reconstruct a single tomography stack.

    Steps:
    - remove horizontal stripes with tomopy's remove_stripe_fw;
    - reconstruct with tomopy.recon using ``algorithm``;
    - run optional secondary SART iterations through tomopy's astra
      wrapper (currently disabled, since secondary_iters is 0);
    - remove ring artifacts in place.

    :param tomo_stack: the reduced stack, index order row,theta,column
    :param thetas: projection angles in degrees
    :param center_offsets: rotation axis shift in pixels relative to the
        column center. Either empty (no shift), a pair (first row, last row)
        that is linearly interpolated over the rows, or one value per row.
        The caller's sequence is never modified.
    :param num_core: number of cores for tomopy
    :param algorithm: tomopy algorithm for the initial reconstruction
    :return: the reconstructed stack
    :raises ValueError: if per-row center_offsets do not match the number
        of rows
    """
    # RV should we remove stripes?
    # https://tomopy.readthedocs.io/en/latest/api/tomopy.prep.stripe.html
    # RV should we remove rings?
    # https://tomopy.readthedocs.io/en/latest/api/tomopy.misc.corr.html
    # RV: Add an option to do (extra) secondary iterations later or to do some sort of convergence test?
    num_row = tomo_stack.shape[0]
    if center_offsets is None or not len(center_offsets):
        centers = np.zeros(num_row)
    elif len(center_offsets) == 2:
        centers = np.linspace(center_offsets[0], center_offsets[1], num_row)
    else:
        # Float copy: the shift below is applied in place and must neither
        # modify the caller's array nor fail on integer offsets or lists
        centers = np.array(center_offsets, dtype=float)
        if centers.size != num_row:
            raise ValueError('center_offsets dimension mismatch in reconstruct_one_tomo_stack')
    centers += tomo_stack.shape[2]/2

    # Get reconstruction parameters (the configuration hook is not wired up
    # yet, so the defaults always apply)
    sigma = 2.0
    secondary_iters = 0
    ring_width = 15
    recon_parameters = None#self.config.get('recon_parameters')
    if recon_parameters is not None:
        sigma = recon_parameters.get('stripe_fw_sigma', 2.0)
        if not is_num(sigma, ge=0):
            logger.warning(f'Invalid stripe_fw_sigma ({sigma}) in '+
                '_reconstruct_one_tomo_stack, set to a default value of 2.0')
            sigma = 2.0
        secondary_iters = recon_parameters.get('secondary_iters', 0)
        if not is_int(secondary_iters, ge=0):
            logger.warning(f'Invalid secondary_iters ({secondary_iters}) in '+
                '_reconstruct_one_tomo_stack, set to a default value of 0 (skip them)')
            secondary_iters = 0
        ring_width = recon_parameters.get('ring_width', 15)
        if not is_int(ring_width, ge=0):
            logger.warning(f'Invalid ring_width ({ring_width}) in _reconstruct_one_tomo_stack, '+
                'set to a default value of 15')
            ring_width = 15

    # Remove horizontal stripe (tomopy runs on at most num_core_tomopy_limit cores)
    ncore_stripe = min(num_core, num_core_tomopy_limit)
    t0 = time()
    logger.debug(f'Running remove_stripe_fw on {ncore_stripe} cores ...')
    tomo_stack = tomopy.prep.stripe.remove_stripe_fw(tomo_stack, sigma=sigma,
        ncore=ncore_stripe)
    logger.debug(f'... tomopy.prep.stripe.remove_stripe_fw took {time()-t0:.2f} seconds')

    # Perform initial image reconstruction
    logger.debug('Performing initial image reconstruction')
    t0 = time()
    logger.debug(f'Running recon on {num_core} cores ...')
    tomo_recon_stack = tomopy.recon(tomo_stack, np.radians(thetas), centers,
        sinogram_order=True, algorithm=algorithm, ncore=num_core)
    logger.debug(f'... done in {time()-t0:.2f} seconds')
    logger.info(f'Performing initial image reconstruction took {time()-t0:.2f} seconds')

    # Run optional secondary iterations
    if secondary_iters > 0:
        logger.debug(f'Running {secondary_iters} secondary iterations')
        # SIRT_CUDA failed locally with a CUDA driver error and plain SIRT did
        # not finish overnight, hence SART
        options = {'method':'SART', 'proj_type':'linear', 'MinConstraint': 0,
            'num_iter':secondary_iters}
        t0 = time()
        logger.debug(f'Running recon on {num_core} cores ...')
        tomo_recon_stack = tomopy.recon(tomo_stack, np.radians(thetas), centers,
            init_recon=tomo_recon_stack, options=options, sinogram_order=True,
            algorithm=tomopy.astra, ncore=num_core)
        logger.debug(f'... done in {time()-t0:.2f} seconds')
        logger.info(f'Performing secondary iterations took {time()-t0:.2f} seconds')

    # Remove ring artifacts
    t0 = time()
    tomopy.misc.corr.remove_ring(tomo_recon_stack, rwidth=ring_width, out=tomo_recon_stack,
        ncore=num_core)
    logger.debug(f'Removing ring artifacts took {time()-t0:.2f} seconds')

    return tomo_recon_stack
1433 | |
def _resize_reconstructed_data(self, data, z_only=False):
    """Interactively select bounds for reconstructed tomography data.

    For each axis the user is asked whether to change its bounds. If so, the
    data are summed over the other two axes (and over all stacks) and the
    user must select exactly one contiguous index range in the resulting 1D
    profile with draw_mask_1d. The z bounds can only be selected when there
    is a single stack.

    :param data: one reconstructed stack with axis order row(z),x,y, a list
        or tuple of such stacks, or a 4D array ordered stack,row(z),x,y
    :type data: numpy.ndarray or list[numpy.ndarray]
    :param z_only: skip the x and y bounds selection, defaults to False
    :type z_only: bool, optional
    :raises ValueError: if a stack is not three-dimensional
    :return: the x, y and z bounds, each being the single index range
        selected with draw_mask_1d, or None when left unchanged
    :rtype: tuple
    """
    # Normalize the input to a list of row(z),x,y stacks
    if isinstance(data, (list, tuple)):
        tomo_recon_stacks = list(data)
    elif isinstance(data, np.ndarray) and data.ndim == 4:
        # stack,row(z),x,y: split along the first axis into individual stacks
        tomo_recon_stacks = list(data)
    else:
        tomo_recon_stacks = [data]
    for stack in tomo_recon_stacks:
        if np.ndim(stack) != 3:
            raise ValueError('Invalid reconstructed stack dimension '
                    f'({np.ndim(stack)}), expected a row(z),x,y stack')

    def select_bounds(label, prompt, default, sum_axes, legend):
        """Optionally let the user pick one contiguous index range along one axis."""
        if not input_yesno(prompt, default):
            return None
        # 1D profile along the selected axis: sum over the other axes and all stacks
        profile = sum(np.sum(stack, axis=sum_axes) for stack in tomo_recon_stacks)
        title = f'select {label} data range'
        _, bounds = draw_mask_1d(profile, current_index_ranges=None, title=title,
                legend=legend)
        # Insist on exactly one continuous range
        while len(bounds) != 1:
            print('Please select exactly one continuous range')
            _, bounds = draw_mask_1d(profile, title=title, legend=legend)
        bounds = bounds[0]
        logger.debug(f'{label}_bounds = {bounds}')
        return bounds

    if z_only:
        x_bounds = None
        y_bounds = None
    else:
        # x bounds (in the yz-plane) and y bounds (in the xz-plane)
        x_bounds = select_bounds('x', '\nDo you want to change the image x-bounds (y/n)?',
                'y', (0, 2), 'recon stack sum yz')
        y_bounds = select_bounds('y', '\nDo you want to change the image y-bounds (y/n)?',
                'y', (0, 1), 'recon stack sum xz')

    # z bounds (in the xy-plane) are only valid for a single image set
    if len(tomo_recon_stacks) == 1:
        z_bounds = select_bounds('z', 'Do you want to change the image z-bounds (y/n)?',
                'n', (1, 2), 'recon stack sum xy')
    else:
        z_bounds = None

    return x_bounds, y_bounds, z_bounds
1531 | |
1532 | |
def run_tomo(input_file:str, output_file:str, modes:list[str], center_file=None, num_core=-1,
        output_folder='.', save_figs='no', test_mode=False) -> None:
    """Run the requested stages of the tomography workflow on an input file.

    The stages run in this order: 'reduce_data', 'find_center',
    'reconstruct_data' and 'combine_data'; 'all' selects every stage.
    Unless test_mode is set, the result is written to output_file: the
    center data when 'find_center' ran without 'reconstruct_data',
    otherwise the (processed) data.

    :param input_file: input file, read with Tomo.read
    :type input_file: str
    :param output_file: output file, written with Tomo.write
    :type output_file: str
    :param modes: stages to run; None selects ['all'], a single string is
        treated as a one-element list
    :type modes: list[str]
    :param center_file: file with the rotation axis center data, required
        when 'reconstruct_data' runs without 'find_center'
    :param num_core: number of processors, passed to Tomo
    :param output_folder: output folder, passed to Tomo (in test mode the
        log file tomo.log is also written there)
    :param save_figs: passed to Tomo
    :param test_mode: log to {output_folder}/tomo.log and skip writing output_file
    :raises ValueError: if 'reconstruct_data' is requested without
        'find_center' and without a center_file
    """
    # Default to running every stage; accept a single mode name as well
    if modes is None:
        modes = ['all']
    elif isinstance(modes, str):
        modes = [modes]
    run_all = 'all' in modes

    def requested(mode):
        """Return True if the workflow stage mode should run."""
        return run_all or mode in modes

    # Fail fast, before any expensive processing: reconstruction needs the
    # rotation axis centers, either from find_center or from center_file
    if (requested('reconstruct_data') and not requested('find_center')
            and center_file is None):
        raise ValueError('center_file is required for reconstruct_data when '
                'find_center is not run')

    if test_mode:
        logging_format = '%(asctime)s : %(levelname)s - %(module)s : %(funcName)s - %(message)s'
        level = logging.getLevelName('INFO')
        logging.basicConfig(filename=f'{output_folder}/tomo.log', filemode='w',
                format=logging_format, level=level, force=True)
    logger.info(f'input_file = {input_file}')
    logger.info(f'center_file = {center_file}')
    logger.info(f'output_file = {output_file}')
    logger.debug(f'modes= {modes}')
    logger.debug(f'num_core= {num_core}')
    logger.info(f'output_folder = {output_folder}')
    logger.info(f'save_figs = {save_figs}')
    logger.info(f'test_mode = {test_mode}')
    logger.debug(f'modes {type(modes)} = {modes}')

    known_modes = ('all', 'reduce_data', 'find_center', 'reconstruct_data', 'combine_data')
    unknown_modes = [mode for mode in modes if mode not in known_modes]
    if unknown_modes:
        logger.warning(f'Ignoring unknown mode(s): {unknown_modes}')

    # Instantiate Tomo object
    tomo = Tomo(num_core=num_core, output_folder=output_folder, save_figs=save_figs,
            test_mode=test_mode)

    # Read input file
    data = tomo.read(input_file)

    # Generate reduced tomography images
    if requested('reduce_data'):
        data = tomo.gen_reduced_data(data)

    # Find rotation axis centers for the tomography stacks
    center_data = None
    if requested('find_center'):
        center_data = tomo.find_centers(data)

    # Reconstruct tomography stacks
    if requested('reconstruct_data'):
        if center_data is None:
            center_data = tomo.read(center_file)
        data = tomo.reconstruct_data(data, center_data)
        # The centers are consumed here: the reconstructed data is what gets written
        center_data = None

    # Combine reconstructed tomography stacks
    if requested('combine_data'):
        data = tomo.combine_data(data)

    # Write output file
    if not test_mode:
        if center_data is None:
            tomo.write(data, output_file)
        else:
            tomo.write(center_data, output_file)
1589 |