diff tomo.py @ 3:f9c52762c32c draft

"planemo upload for repository https://github.com/rolfverberg/galaxytools commit 7dce44d576e4149f31bdc2ee4dce0bb6962badb6"
author rv43
date Tue, 05 Apr 2022 18:23:54 +0000
parents b8977c98800b
children 845270a96464
line wrap: on
line diff
--- a/tomo.py	Thu Mar 31 20:48:17 2022 +0000
+++ b/tomo.py	Tue Apr 05 18:23:54 2022 +0000
@@ -34,13 +34,16 @@
 
     def __init__(self, num_core):
         cpu_count = mp.cpu_count()
+        logging.debug(f'start: num_core={num_core} cpu_count={cpu_count}')
         if num_core is None or num_core < 1 or num_core > cpu_count:
             self.num_core = cpu_count
         else:
             self.num_core = num_core
+        logging.debug(f'self.num_core={self.num_core}')
 
     def __enter__(self):
         self.num_core_org = ne.set_num_threads(self.num_core)
+        logging.debug(f'self.num_core={self.num_core}')
 
     def __exit__(self, exc_type, exc_value, traceback):
         ne.set_num_threads(self.num_core_org)
@@ -647,12 +650,12 @@
         logging.debug(f'tdf_cutoff = {tdf_cutoff}')
         logging.debug(f'tdf_mean = {tdf_mean}')
         np.nan_to_num(self.tdf, copy=False, nan=tdf_mean, posinf=tdf_mean, neginf=0.)
-        if not self.test_mode and not self.galaxy_flag:
+        if self.galaxy_flag:
+            msnc.quickImshow(self.tdf, title='dark field', name=dark_field_pngname,
+                    save_fig=True, save_only=True)
+        elif not self.test_mode:
             msnc.quickImshow(self.tdf, title='dark field', path=self.output_folder,
                     save_fig=self.save_plots, save_only=self.save_plots_only)
-        elif self.galaxy_flag:
-            msnc.quickImshow(self.tdf, title='dark field', name=dark_field_pngname,
-                    save_fig=True, save_only=True)
 
     def _genBright(self, tbf_files, bright_field_pngname):
         """Generate bright field.
@@ -681,12 +684,12 @@
             self.tbf -= self.tdf
         else:
             logging.warning('Dark field unavailable')
-        if not self.test_mode and not self.galaxy_flag:
+        if self.galaxy_flag:
+            msnc.quickImshow(self.tbf, title='bright field', name=bright_field_pngname,
+                    save_fig=True, save_only=True)
+        elif not self.test_mode:
             msnc.quickImshow(self.tbf, title='bright field', path=self.output_folder,
                     save_fig=self.save_plots, save_only=self.save_plots_only)
-        elif self.galaxy_flag:
-            msnc.quickImshow(self.tbf, title='bright field', name=bright_field_pngname,
-                    save_fig=True, save_only=True)
 
     def _setDetectorBounds(self, tomo_stack_files, tomo_field_pngname, detectorbounds_pngname):
         """Set vertical detector bounds for image stack.
@@ -743,14 +746,16 @@
             tomo_stack = msnc.loadImageStack(tomo_stack_files[0], self.config['data_filetype'],
                 stacks[0]['img_offset'], 1)
             x_sum = np.sum(tomo_stack[0,:,:], 1)
+            x_sum_min = x_sum.min()
+            x_sum_max = x_sum.max()
             title = f'tomography image at theta={self.config["theta_range"]["start"]}'
             msnc.quickImshow(tomo_stack[0,:,:], title=title, name=tomo_field_pngname,
-                    save_fig=True, save_only=True)
+                    save_fig=True, save_only=True, show_grid=True)
             msnc.quickPlot((range(x_sum.size), x_sum),
-                    ([img_x_bounds[0], img_x_bounds[0]], [x_sum.min(), x_sum.max()], 'r-'),
-                    ([img_x_bounds[1]-1, img_x_bounds[1]-1], [x_sum.min(), x_sum.max()], 'r-'),
+                    ([img_x_bounds[0], img_x_bounds[0]], [x_sum_min, x_sum_max], 'r-'),
+                    ([img_x_bounds[1]-1, img_x_bounds[1]-1], [x_sum_min, x_sum_max], 'r-'),
                     title='sum over theta and y', name=detectorbounds_pngname,
-                    save_fig=True, save_only=True)
+                    save_fig=True, save_only=True, show_grid=True)
             
             # Update config and save to file
             if preprocess is None:
@@ -758,6 +763,7 @@
             else:
                 preprocess['img_x_bounds'] = img_x_bounds
             self.cf.saveFile(self.config_out)
+            del x_sum
             return
 
         # For one tomography stack only: load the first image
@@ -777,9 +783,17 @@
             x_sum = np.sum(tomo_stack[0,:,:], 1)
             use_bounds = 'no'
             if img_x_bounds[0] is not None and img_x_bounds[1] is not None:
+                tmp = np.copy(tomo_stack[0,:,:])
+                tmp_max = tmp.max()
+                tmp[img_x_bounds[0],:] = tmp_max
+                tmp[img_x_bounds[1]-1,:] = tmp_max
+                msnc.quickImshow(tmp, title=title)
+                del tmp
+                x_sum_min = x_sum.min()
+                x_sum_max = x_sum.max()
                 msnc.quickPlot((range(x_sum.size), x_sum),
-                        ([img_x_bounds[0], img_x_bounds[0]], [x_sum.min(), x_sum.max()], 'r-'),
-                        ([img_x_bounds[1]-1, img_x_bounds[1]-1], [x_sum.min(), x_sum.max()], 'r-'),
+                        ([img_x_bounds[0], img_x_bounds[0]], [x_sum_min, x_sum_max], 'r-'),
+                        ([img_x_bounds[1]-1, img_x_bounds[1]-1], [x_sum_min, x_sum_max], 'r-'),
                         title='sum over theta and y')
                 print(f'lower bound = {img_x_bounds[0]} (inclusive)\n'+
                         f'upper bound = {img_x_bounds[1]} (exclusive)]')
@@ -799,11 +813,20 @@
                         save_fig=self.save_plots, save_only=True)
         else:
             x_sum = np.sum(self.tbf, 1)
+            x_sum_min = x_sum.min()
+            x_sum_max = x_sum.max()
             use_bounds = 'no'
             if img_x_bounds[0] is not None and img_x_bounds[1] is not None:
+                tmp = np.copy(self.tbf)
+                tmp_max = tmp.max()
+                tmp[img_x_bounds[0],:] = tmp_max
+                tmp[img_x_bounds[1]-1,:] = tmp_max
+                title = 'Bright field'
+                msnc.quickImshow(tmp, title=title)
+                del tmp
                 msnc.quickPlot((range(x_sum.size), x_sum),
-                        ([img_x_bounds[0], img_x_bounds[0]], [x_sum.min(), x_sum.max()], 'r-'),
-                        ([img_x_bounds[1]-1, img_x_bounds[1]-1], [x_sum.min(), x_sum.max()], 'r-'),
+                        ([img_x_bounds[0], img_x_bounds[0]], [x_sum_min, x_sum_max], 'r-'),
+                        ([img_x_bounds[1]-1, img_x_bounds[1]-1], [x_sum_min, x_sum_max], 'r-'),
                         title='sum over theta and y')
                 print(f'lower bound = {img_x_bounds[0]} (inclusive)\n'+
                         f'upper bound = {img_x_bounds[1]} (exclusive)]')
@@ -820,9 +843,16 @@
                     x_upp = int(x_upp+(x_upp-x_low)/10)
                     if x_upp >= x_sum.size:
                         x_upp = x_sum.size
+                    tmp = np.copy(self.tbf)
+                    tmp_max = tmp.max()
+                    tmp[x_low,:] = tmp_max
+                    tmp[x_upp-1,:] = tmp_max
+                    title = 'Bright field'
+                    msnc.quickImshow(tmp, title=title)
+                    del tmp
                     msnc.quickPlot((range(x_sum.size), x_sum),
-                            ([x_low, x_low], [x_sum.min(), x_sum.max()], 'r-'),
-                            ([x_upp, x_upp], [x_sum.min(), x_sum.max()], 'r-'),
+                            ([x_low, x_low], [x_sum_min, x_sum_max], 'r-'),
+                            ([x_upp, x_upp], [x_sum_min, x_sum_max], 'r-'),
                             title='sum over theta and y')
                     print(f'lower bound = {x_low} (inclusive)\nupper bound = {x_upp} (exclusive)]')
                     use_fit =  pyip.inputYesNo('Accept these bounds ([y]/n)?: ', blank=True)
@@ -837,6 +867,7 @@
                         x_sum[img_x_bounds[0]:img_x_bounds[1]],
                         title='sum over theta and y', path=self.output_folder,
                         save_fig=self.save_plots, save_only=True)
+            del x_sum
         logging.debug(f'img_x_bounds: {img_x_bounds}')
 
         if self.save_plots_only:
@@ -1069,13 +1100,10 @@
             stack.pop('reconstructed', 'reconstructed not found')
             find_center = self.config.get('find_center')
             if find_center:
+                find_center.pop('completed', 'completed not found')
                 if self.test_mode:
-                    find_center.pop('completed', 'completed not found')
                     find_center.pop('lower_center_offset', 'lower_center_offset not found')
                     find_center.pop('upper_center_offset', 'upper_center_offset not found')
-                else:
-                    if find_center.get('center_stack_index', -1) == index:
-                        self.config.pop('find_center')
             self.cf.saveFile(self.config_out)
 
         if self.tdf.size:
@@ -1083,7 +1111,7 @@
         del tbf
 
     def _reconstructOnePlane(self, tomo_plane_T, center, thetas_deg, eff_pixel_size,
-            cross_sectional_dim, plot_sinogram=True):
+            cross_sectional_dim, plot_sinogram=True, num_core=1):
         """Invert the sinogram for a single tomography plane.
         """
         # tomo_plane_T index order: column,theta
@@ -1117,24 +1145,28 @@
         recon_sinogram = spi.gaussian_filter(recon_sinogram, 0.5)
         recon_clean = np.expand_dims(recon_sinogram, axis=0)
         del recon_sinogram
-        recon_clean = tomopy.misc.corr.remove_ring(recon_clean, rwidth=17)
+        recon_clean = tomopy.misc.corr.remove_ring(recon_clean, rwidth=17, ncore=num_core)
         logging.debug(f'filtering and removing ring artifact took {time()-t0:.2f} seconds!')
         return recon_clean
 
-    def _plotEdgesOnePlane(self, recon_plane, base_name, weight=0.001):
+    def _plotEdgesOnePlane(self, recon_plane, title, name=None, weight=0.001):
         # RV parameters for the denoise, gaussian, and ring removal will be different for different feature sizes
         edges = denoise_tv_chambolle(recon_plane, weight = weight)
         vmax = np.max(edges[0,:,:])
         vmin = -vmax
-        msnc.quickImshow(edges[0,:,:], f'{base_name} coolwarm', path=self.output_folder,
-                save_fig=self.save_plots, save_only=self.save_plots_only, cmap='coolwarm')
-        msnc.quickImshow(edges[0,:,:], f'{base_name} gray', path=self.output_folder,
-                save_fig=self.save_plots, save_only=self.save_plots_only, cmap='gray',
-                vmin=vmin, vmax=vmax)
+        if self.galaxy_flag:
+            msnc.quickImshow(edges[0,:,:], title, name=name, save_fig=True, save_only=True,
+                    cmap='gray', vmin=vmin, vmax=vmax)
+        else:
+            msnc.quickImshow(edges[0,:,:], f'{title} coolwarm', path=self.output_folder,
+                    save_fig=self.save_plots, save_only=self.save_plots_only, cmap='coolwarm')
+            msnc.quickImshow(edges[0,:,:], f'{title} gray', path=self.output_folder,
+                    save_fig=self.save_plots, save_only=self.save_plots_only, cmap='gray',
+                    vmin=vmin, vmax=vmax)
         del edges
 
     def _findCenterOnePlane(self, sinogram, row, thetas_deg, eff_pixel_size, cross_sectional_dim,
-            tol=0.1):
+            tol=0.1, num_core=1, recon_pngname=None):
         """Find center for a single tomography plane.
         """
         # sinogram index order: theta,column
@@ -1143,20 +1175,31 @@
         center = sinogram.shape[1]/2
 
         # try automatic center finding routines for initial value
-        tomo_center = tomopy.find_center_vo(sinogram)
+        tomo_center = tomopy.find_center_vo(sinogram, ncore=num_core)
         center_offset_vo = tomo_center-center
-        if not self.test_mode:
+        if self.test_mode or self.galaxy_flag:
+            logging.info(f'Center at row {row} using Nghia Vo’s method = {center_offset_vo:.2f}')
+            if self.test_mode:
+                del sinogram_T
+                return float(center_offset_vo)
+        else:
             print(f'Center at row {row} using Nghia Vo’s method = {center_offset_vo:.2f}')
+            if recon_pngname:
+                logging.warning('Ignoring recon_pngname in _findCenterOnePlane (only for Galaxy)')
         recon_plane = self._reconstructOnePlane(sinogram_T, tomo_center, thetas_deg,
-                eff_pixel_size, cross_sectional_dim, False)
-        if not self.test_mode:
-            base_name=f'edges row{row} center_offset_vo{center_offset_vo:.2f}'
-            self._plotEdgesOnePlane(recon_plane, base_name)
-        use_phase_corr = 'no'
-        if not self.test_mode:
-            use_phase_corr = pyip.inputYesNo('Try finding center using phase correlation '+
-                    '(y/[n])? ', blank=True)
-        if use_phase_corr == 'yes':
+                eff_pixel_size, cross_sectional_dim, False, num_core)
+        if self.galaxy_flag:
+            assert(isinstance(recon_pngname, str))
+            title = os.path.basename(recon_pngname)
+            self._plotEdgesOnePlane(recon_plane, title, name=recon_pngname)
+            del sinogram_T
+            del recon_plane
+            return float(center_offset_vo)
+        else:
+            title = f'edges row{row} center_offset_vo{center_offset_vo:.2f}'
+            self._plotEdgesOnePlane(recon_plane, title)
+        if (pyip.inputYesNo('Try finding center using phase correlation '+
+                '(y/[n])? ', blank=True) == 'yes'):
             tomo_center = tomopy.find_center_pc(sinogram, sinogram, tol=0.1,
                     rotc_guess=tomo_center)
             error = 1.
@@ -1168,24 +1211,18 @@
             center_offset = tomo_center-center
             print(f'Center at row {row} using phase correlation = {center_offset:.2f}')
             recon_plane = self._reconstructOnePlane(sinogram_T, tomo_center, thetas_deg,
-                    eff_pixel_size, cross_sectional_dim, False)
-            base_name=f'edges row{row} center_offset{center_offset:.2f}'
-            self._plotEdgesOnePlane(recon_plane, base_name)
-        accept_center = 'yes'
-        if not self.test_mode:
-            accept_center = pyip.inputYesNo('Accept a center location ([y]) or continue '+
-                    'search (n)? ', blank=True)
-        if accept_center != 'no':
+                    eff_pixel_size, cross_sectional_dim, False, num_core)
+            title = f'edges row{row} center_offset{center_offset:.2f}'
+            self._plotEdgesOnePlane(recon_plane, title)
+        if (pyip.inputYesNo('Accept a center location ([y]) or continue '+
+                'search (n)? ', blank=True) != 'no'):
             del sinogram_T
             del recon_plane
-            if self.test_mode:
+            center_offset = pyip.inputNum(
+                    f'    Enter chosen center offset [{-int(center)}, {int(center)}] '+
+                    f'([{center_offset_vo}])): ', blank=True)
+            if center_offset == '':
                 center_offset = center_offset_vo
-            else:
-                center_offset = pyip.inputNum(
-                        f'    Enter chosen center offset [{-int(center)}, {int(center)}] '+
-                        f'([{center_offset_vo}])): ', blank=True)
-                if center_offset == '':
-                    center_offset = center_offset_vo
             return float(center_offset)
 
         while True:
@@ -1204,9 +1241,9 @@
                         center_offset_step):
                 logging.info(f'center_offset = {center_offset}')
                 recon_plane = self._reconstructOnePlane(sinogram_T, center_offset+center,
-                        thetas_deg, eff_pixel_size, cross_sectional_dim, False)
-                base_name=f'edges row{row} center_offset{center_offset}'
-                self._plotEdgesOnePlane(recon_plane, base_name)
+                        thetas_deg, eff_pixel_size, cross_sectional_dim, False, num_core)
+                title = f'edges row{row} center_offset{center_offset}'
+                self._plotEdgesOnePlane(recon_plane, title)
             if pyip.inputInt('\nContinue (0) or end the search (1): ', min=0, max=1):
                 break
 
@@ -1266,7 +1303,8 @@
                     init_recon=tomo_recon_stack, options=options, sinogram_order=True,
                     algorithm=tomopy.astra, ncore=num_core)
         if True:
-            tomopy.misc.corr.remove_ring(tomo_recon_stack, rwidth=rwidth, out=tomo_recon_stack)
+            tomopy.misc.corr.remove_ring(tomo_recon_stack, rwidth=rwidth, out=tomo_recon_stack,
+                    ncore=num_core)
         return tomo_recon_stack
 
     def findImageFiles(self):
@@ -1359,6 +1397,22 @@
 
         return dark_files, bright_files, tomo_stack_files
 
+    def loadTomoStacks(self, input_name):
+        """Load tomography stacks (only for Galaxy).
+        """
+        assert(self.galaxy_flag)
+        t0 = time()
+        logging.info(f'Loading preprocessed tomography stack from {input_name} ...')
+        stack_info = self.config['stack_info']
+        stacks = stack_info.get('stacks')
+        assert(len(self.tomo_stacks) == stack_info['num'])
+        with np.load(input_name) as f:
+            for i,stack in enumerate(stacks):
+                self.tomo_stacks[i] = f[f'set_{stack["index"]}']
+                logging.info(f'loaded stack {i}: index = {stack["index"]}, shape = '+
+                        f'{self.tomo_stacks[i].shape}')
+        logging.info(f'... done in {time()-t0:.2f} seconds!')
+
     def genTomoStacks(self, tdf_files=None, tbf_files=None, tomo_stack_files=None,
             dark_field_pngname=None, bright_field_pngname=None, tomo_field_pngname=None,
             detectorbounds_pngname=None, output_name=None):
@@ -1459,13 +1513,34 @@
             stack['ref_height'] = stack['ref_height']+pixel_size
         self.cf.saveFile(self.config_out)
 
-    def findCenters(self):
+    def findCenters(self, row_bounds=None, center_rows=None, recon_low_pngname=None,
+            recon_upp_pngname=None, num_core=None):
         """Find rotation axis centers for the tomography stacks.
         """
+        if num_core is None:
+            num_core = self.num_core
         logging.debug('Find centers for tomography stacks')
         stacks = self.config['stack_info']['stacks']
         available_stacks = [stack['index'] for stack in stacks if stack.get('preprocessed', False)]
         logging.debug('Available stacks: {available_stacks}')
+        if self.galaxy_flag:
+            assert(isinstance(row_bounds, list) and len(row_bounds) == 2)
+            assert(isinstance(center_rows, list) and len(center_rows) == 2)
+            assert(isinstance(recon_low_pngname, str))
+            assert(isinstance(recon_upp_pngname, str))
+        else:
+            if row_bounds:
+                logging.warning('Ignoring row_bounds in findCenters (only for Galaxy)')
+                row_bounds = None
+            if center_rows:
+                logging.warning('Ignoring center_rows in findCenters (only for Galaxy)')
+                center_rows = None
+            if recon_low_pngname:
+                logging.warning('Ignoring recon_low_pngname in findCenters (only for Galaxy)')
+                recon_low_pngname = None
+            if recon_upp_pngname:
+                logging.warning('Ignoring recon_upp_pngname in findCenters (only for Galaxy)')
+                recon_upp_pngname = None
 
         # Check for valid available center stack index
         find_center = self.config.get('find_center')
@@ -1481,8 +1556,8 @@
                 else:
                     print('\nFound calibration center offset info for stack '+
                             f'{center_stack_index}')
-                    if (pyip.inputYesNo('Do you want to use this again (y/n)? ') == 'yes' and
-                            find_center.get('completed', False) == True):
+                    if (pyip.inputYesNo('Do you want to use this again ([y]/n)? ',
+                            blank=True) != 'no' and find_center.get('completed', False) == True):
                         return
 
         # Load the required preprocessed stack
@@ -1503,6 +1578,9 @@
                 stacks[0]['preprocessed'] = False
                 raise OSError('Unable to load the required preprocessed tomography stack')
             assert(stacks[0].get('preprocessed', False) == True)
+        elif self.galaxy_flag:
+            logging.error('CHECK/FIX FOR GALAXY')
+            #center_stack_index = stacks[int(num_tomo_stacks/2)]['index']
         else:
             while True:
                 if not center_stack_index:
@@ -1528,6 +1606,10 @@
             find_center = self.config['find_center']
         else:
             find_center['center_stack_index'] = center_stack_index
+        if not self.galaxy_flag:
+            row_bounds = find_center.get('row_bounds', None)
+            center_rows = [find_center.get('lower_row', None),
+                    find_center.get('upper_row', None)]
 
         # Set thetas (in degrees)
         theta_range = self.config['theta_range']
@@ -1545,26 +1627,54 @@
             raise ValueError('Detector pixel size unavailable')
         eff_pixel_size = 100.*pixel_size/zoom_perc
         logging.debug(f'eff_pixel_size = {eff_pixel_size}')
-        tomo_ref_heights = [stack['ref_height'] for stack in stacks]
         if num_tomo_stacks == 1:
-            n1 = 0
-            height = center_stack.shape[0]*eff_pixel_size
-            if not self.test_mode and pyip.inputYesNo('\nDo you want to reconstruct the full '+
-                    f'sample height ({height:.3f} mm) (y/n)? ') == 'no':
-                height = pyip.inputNum('\nEnter the desired reconstructed sample height '+
-                        f'in mm [0, {height:.3f}]: ', min=0, max=height)
-            n1 = int(0.5*(center_stack.shape[0]-height/eff_pixel_size))
+             accept = 'yes'
+             if not self.test_mode and not self.galaxy_flag:
+                 accept = 'no'
+                 print('\nSelect bounds for image reconstruction')
+                 if msnc.is_index_range(row_bounds, 0, center_stack.shape[0]):
+                     a_tmp = np.copy(center_stack[:,0,:])
+                     a_tmp_max = a_tmp.max()
+                     a_tmp[row_bounds[0],:] = a_tmp_max
+                     a_tmp[row_bounds[1]-1,:] = a_tmp_max
+                     print(f'lower bound = {row_bounds[0]} (inclusive)')
+                     print(f'upper bound = {row_bounds[1]} (exclusive)')
+                     msnc.quickImshow(a_tmp, title=f'center stack theta={theta_start}',
+                         aspect='auto')
+                     del a_tmp
+                     accept = pyip.inputYesNo('Accept these bounds ([y]/n)?: ', blank=True)
+             if accept == 'no':
+                 (n1, n2) = msnc.selectImageBounds(center_stack[:,0,:], 0,
+                         title=f'center stack theta={theta_start}')
+             else:
+                 n1 = row_bounds[0]
+                 n2 = row_bounds[1]
         else:
+            logging.error('CHECK/FIX FOR GALAXY')
+            tomo_ref_heights = [stack['ref_height'] for stack in stacks]
             n1 = int((1.+(tomo_ref_heights[0]+center_stack.shape[0]*eff_pixel_size-
                 tomo_ref_heights[1])/eff_pixel_size)/2)
-        n2 = center_stack.shape[0]-n1
+            n2 = center_stack.shape[0]-n1
         logging.info(f'n1 = {n1}, n2 = {n2} (n2-n1) = {(n2-n1)*eff_pixel_size:.3f} mm')
         if not center_stack.size:
             RuntimeError('Center stack not loaded')
-        if not self.test_mode:
-            msnc.quickImshow(center_stack[:,0,:], title=f'center stack theta={theta_start}',
-                    path=self.output_folder, save_fig=self.save_plots,
-                    save_only=self.save_plots_only)
+        if not self.test_mode and not self.galaxy_flag:
+            tmp = center_stack[:,0,:]
+            tmp_max = tmp.max()
+            tmp[n1,:] = tmp_max
+            tmp[n2-1,:] = tmp_max
+            if msnc.is_index_range(center_rows, 0, tmp.shape[0]):
+                tmp[center_rows[0],:] = tmp_max
+                tmp[center_rows[1]-1,:] = tmp_max
+            extent = [0, tmp.shape[1], tmp.shape[0], 0]
+            msnc.quickImshow(tmp, title=f'center stack theta={theta_start}',
+                    path=self.output_folder, extent=extent, save_fig=self.save_plots,
+                    save_only=self.save_plots_only, aspect='auto')
+            del tmp
+            #extent = [0, center_stack.shape[2], n2, n1]
+            #msnc.quickImshow(center_stack[n1:n2,0,:], title=f'center stack theta={theta_start}',
+            #        path=self.output_folder, extent=extent, save_fig=self.save_plots,
+            #        save_only=self.save_plots_only, show_grid=True, aspect='auto')
 
         # Get cross sectional diameter in mm
         cross_sectional_dim = center_stack.shape[2]*eff_pixel_size
@@ -1576,9 +1686,9 @@
         # Lower row center
         use_row = 'no'
         use_center = 'no'
-        row = find_center.get('lower_row')
+        row = center_rows[0]
         if msnc.is_int(row, n1, n2-2):
-            if self.test_mode:
+            if self.test_mode or self.galaxy_flag:
                 assert(row is not None)
                 use_row = 'yes'
             else:
@@ -1605,8 +1715,8 @@
                     msnc.clearFig(f'theta={theta_start}')
             # center_stack order: row,theta,column
             center_offset = self._findCenterOnePlane(center_stack[row,:,:], row, thetas_deg,
-                    eff_pixel_size, cross_sectional_dim)
-        logging.info(f'Lower center offset = {center_offset}')
+                    eff_pixel_size, cross_sectional_dim, num_core=num_core,
+                    recon_pngname=recon_low_pngname)
 
         # Update config and save to file
         find_center['row_bounds'] = [n1, n2]
@@ -1618,9 +1728,9 @@
         # Upper row center
         use_row = 'no'
         use_center = 'no'
-        row = find_center.get('upper_row')
+        row = center_rows[1]
         if msnc.is_int(row, lower_row+1, n2-1):
-            if self.test_mode:
+            if self.test_mode or self.galaxy_flag:
                 assert(row is not None)
                 use_row = 'yes'
             else:
@@ -1647,7 +1757,8 @@
                     msnc.clearFig(f'theta={theta_start}')
             # center_stack order: row,theta,column
             center_offset = self._findCenterOnePlane(center_stack[row,:,:], row, thetas_deg,
-                    eff_pixel_size, cross_sectional_dim)
+                    eff_pixel_size, cross_sectional_dim, num_core=num_core,
+                    recon_pngname=recon_upp_pngname)
         logging.info(f'upper_center_offset = {center_offset}')
         del center_stack
 
@@ -1731,14 +1842,22 @@
         # Update config file
         self.config = msnc.update('config.txt', 'check_centers', True, 'find_centers')
 
-    def reconstructTomoStacks(self):
+    def reconstructTomoStacks(self, output_name=None, num_core=None):
         """Reconstruct tomography stacks.
         """
+        if num_core is None:
+            num_core = self.num_core
         logging.debug('Reconstruct tomography stacks')
         stacks = self.config['stack_info']['stacks']
         assert(len(self.tomo_stacks) == self.config['stack_info']['num'])
         assert(len(self.tomo_stacks) == len(stacks))
         assert(len(self.tomo_recon_stacks) == len(stacks))
+        if self.galaxy_flag:
+            assert(isinstance(output_name, str))
+        else:
+            if output_name:
+                logging.warning('Ignoring output_name in reconstructTomoStacks '+
+                    '(only used in Galaxy)')
 
         # Get rotation axis rows and centers
         find_center = self.config['find_center']
@@ -1784,56 +1903,65 @@
             # reconstructed stack order for each one in stack : row/z,x,y
             # preprocessed stack order for each one in stack: row,theta,column
             index = stack['index']
-            available = False
-            if stack.get('reconstructed', False):
-                self.tomo_recon_stacks[i], available = self._loadTomo('recon stack', index)
-            if self.tomo_recon_stacks[i].size or available:
-                if self.tomo_stacks[i].size:
-                    self.tomo_stacks[i] = np.array([])
-                assert(stack.get('preprocessed', False) == True)
-                assert(stack.get('reconstructed', False) == True)
-                continue
-            else:
-                stack['reconstructed'] = False
-                if not self.tomo_stacks[i].size:
-                    self.tomo_stacks[i], available = self._loadTomo('red stack', index,
-                            required=True)
-                if not self.tomo_stacks[i].size:
-                    logging.error(f'Unable to load tomography stack {index} for reconstruction')
-                    stack[i]['preprocessed'] = False
-                    load_error = True
+            if not self.galaxy_flag:
+                available = False
+                if stack.get('reconstructed', False):
+                    self.tomo_recon_stacks[i], available = self._loadTomo('recon stack', index)
+                if self.tomo_recon_stacks[i].size or available:
+                    if self.tomo_stacks[i].size:
+                        self.tomo_stacks[i] = np.array([])
+                    assert(stack.get('preprocessed', False) == True)
+                    assert(stack.get('reconstructed', False) == True)
                     continue
-                assert(0 <= lower_row < upper_row < self.tomo_stacks[i].shape[0])
-                center_offsets = [lower_center_offset-lower_row*center_slope,
-                        upper_center_offset+(self.tomo_stacks[i].shape[0]-1-upper_row)*center_slope]
-                t0 = time()
-                self.tomo_recon_stacks[i]= self._reconstructOneTomoStack(self.tomo_stacks[i],
-                        thetas, center_offsets=center_offsets, sigma=0.1, num_core=self.num_core,
-                        algorithm='gridrec', run_secondary_sirt=True, secondary_iter=25)
-                logging.info(f'Reconstruction of stack {index} took {time()-t0:.2f} seconds!')
-                if not self.test_mode:
-                    row_slice = int(self.tomo_stacks[i].shape[0]/2) 
-                    title = f'{basetitle} {index} slice{row_slice}'
-                    msnc.quickImshow(self.tomo_recon_stacks[i][row_slice,:,:], title=title,
-                            path=self.output_folder, save_fig=self.save_plots,
-                            save_only=self.save_plots_only)
-                    msnc.quickPlot(self.tomo_recon_stacks[i]
-                            [row_slice,int(self.tomo_recon_stacks[i].shape[2]/2),:],
-                            title=f'{title} cut{int(self.tomo_recon_stacks[i].shape[2]/2)}',
-                            path=self.output_folder, save_fig=self.save_plots,
-                            save_only=self.save_plots_only)
-                    self._saveTomo('recon stack', self.tomo_recon_stacks[i], index)
-#                else:
-#                    np.savetxt(self.output_folder+f'recon_stack_{index}.txt',
-#                            self.tomo_recon_stacks[i][row_slice,:,:], fmt='%.6e')
-                self.tomo_stacks[i] = np.array([])
+            stack['reconstructed'] = False
+            if not self.tomo_stacks[i].size:
+                self.tomo_stacks[i], available = self._loadTomo('red stack', index,
+                        required=True)
+            if not self.tomo_stacks[i].size:
+                logging.error(f'Unable to load tomography stack {index} for reconstruction')
+                stack[i]['preprocessed'] = False
+                load_error = True
+                continue
+            assert(0 <= lower_row < upper_row < self.tomo_stacks[i].shape[0])
+            center_offsets = [lower_center_offset-lower_row*center_slope,
+                    upper_center_offset+(self.tomo_stacks[i].shape[0]-1-upper_row)*center_slope]
+            t0 = time()
+            self.tomo_recon_stacks[i]= self._reconstructOneTomoStack(self.tomo_stacks[i],
+                    thetas, center_offsets=center_offsets, sigma=0.1, num_core=num_core,
+                    algorithm='gridrec', run_secondary_sirt=True, secondary_iter=25)
+            logging.info(f'Reconstruction of stack {index} took {time()-t0:.2f} seconds!')
+            if not self.test_mode and not self.galaxy_flag:
+                row_slice = int(self.tomo_stacks[i].shape[0]/2) 
+                title = f'{basetitle} {index} slice{row_slice}'
+                msnc.quickImshow(self.tomo_recon_stacks[i][row_slice,:,:], title=title,
+                        path=self.output_folder, save_fig=self.save_plots,
+                        save_only=self.save_plots_only)
+                msnc.quickPlot(self.tomo_recon_stacks[i]
+                        [row_slice,int(self.tomo_recon_stacks[i].shape[2]/2),:],
+                        title=f'{title} cut{int(self.tomo_recon_stacks[i].shape[2]/2)}',
+                        path=self.output_folder, save_fig=self.save_plots,
+                        save_only=self.save_plots_only)
+                self._saveTomo('recon stack', self.tomo_recon_stacks[i], index)
+#            else:
+#                np.savetxt(self.output_folder+f'recon_stack_{index}.txt',
+#                        self.tomo_recon_stacks[i][row_slice,:,:], fmt='%.6e')
+            self.tomo_stacks[i] = np.array([])
 
-                # Update config and save to file
-                stack['reconstructed'] = True
-                combine_stacks = self.config.get('combine_stacks')
-                if combine_stacks and index in combine_stacks.get('stacks', []):
-                    combine_stacks['stacks'].remove(index)
-                self.cf.saveFile(self.config_out)
+            # Update config and save to file
+            stack['reconstructed'] = True
+            combine_stacks = self.config.get('combine_stacks')
+            if combine_stacks and index in combine_stacks.get('stacks', []):
+                combine_stacks['stacks'].remove(index)
+            self.cf.saveFile(self.config_out)
+
+        # Save reconstructed tomography stack to file
+        if self.galaxy_flag:
+            t0 = time()
+            logging.info(f'Saving reconstructed tomography stack to {output_name} ...')
+            save_stacks = {f'set_{stack["index"]}':tomo_stack
+                    for stack,tomo_stack in zip(stacks,self.tomo_recon_stacks)}
+            np.savez(output_name, **save_stacks)
+            logging.info(f'... done in {time()-t0:.2f} seconds!')
 
     def combineTomoStacks(self):
         """Combine the reconstructed tomography stacks.
@@ -2095,6 +2223,7 @@
             default=False,
             help='Test mode flag')
     parser.add_argument('--num_core',
+            type=int,
             default=-1,
             help='Number of cores')
     args = parser.parse_args()