changeset 0:98e23dff1de2 draft default tip

planemo upload for repository https://github.com/rolfverberg/galaxytools commit f8c4bdb31c20c468045ad5e6eb255a293244bc6c-dirty
author rv43
date Tue, 21 Mar 2023 16:22:42 +0000
parents
children
files fit.py general.py tomo_macros.xml tomo_reconstruct.py tomo_reconstruct.xml workflow/__main__.py workflow/__version__.py workflow/link_to_galaxy.py workflow/models.py workflow/run_tomo.py
diffstat 10 files changed, 7816 insertions(+), 0 deletions(-) [+]
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/fit.py	Tue Mar 21 16:22:42 2023 +0000
@@ -0,0 +1,2576 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+Created on Mon Dec  6 15:36:22 2021
+
+@author: rv43
+"""
+
+import logging
+
+from asteval import Interpreter, get_ast_names
+from copy import deepcopy
+from lmfit import Model, Parameters
+from lmfit.model import ModelResult
+from lmfit.models import ConstantModel, LinearModel, QuadraticModel, PolynomialModel,\
+        ExponentialModel, StepModel, RectangleModel, ExpressionModel, GaussianModel,\
+        LorentzianModel
+import numpy as np
+from os import cpu_count, getpid, listdir, mkdir, path
+from re import compile, sub
+from shutil import rmtree
+try:
+    from sympy import diff, simplify
+except ImportError:
+    # sympy is optional; it is only needed for the expression-model linearity checks
+    pass
+try:
+    from joblib import Parallel, delayed
+    have_joblib = True
+except ImportError:
+    have_joblib = False
+try:
+    import xarray as xr
+    have_xarray = True
+except ImportError:
+    have_xarray = False
+
+try:
+    from .general import illegal_value, is_int, is_dict_series, is_index, index_nearest, \
+            almost_equal, input_num, quick_plot #, eval_expr
+except ImportError:
+    try:
+        from sys import path as syspath
+        syspath.append('/nfs/chess/user/rv43/msnctools/msnctools')
+        from general import illegal_value, is_int, is_dict_series, is_index, index_nearest, \
+                almost_equal, input_num, quick_plot #, eval_expr
+    except ImportError:
+        from general import illegal_value, is_int, is_dict_series, is_index, index_nearest, \
+                almost_equal, input_num, quick_plot #, eval_expr
+
+from sys import float_info
+float_min = float_info.min
+float_max = float_info.max
+
+# Expressions for sigma in terms of fwhm
+fwhm_factor = {
+    'gaussian': 'fwhm/(2*sqrt(2*log(2)))',
+    'lorentzian': '0.5*fwhm',
+    'splitlorentzian': '0.5*fwhm', # sigma = sigma_r
+    'voight': '0.2776*fwhm', # sigma = gamma
+    'pseudovoight': '0.5*fwhm'} # fraction = 0.5
+
+# Expressions for amplitude in terms of height and fwhm
+height_factor = {
+    'gaussian': 'height*fwhm*0.5*sqrt(pi/log(2))',
+    'lorentzian': 'height*fwhm*0.5*pi',
+    'splitlorentzian': 'height*fwhm*0.5*pi', # sigma = sigma_r
+    'voight': '3.334*height*fwhm', # sigma = gamma
+    'pseudovoight': '1.268*height*fwhm'} # fraction = 0.5
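+
+# Worked example (assuming a Gaussian peak with height = 1 and fwhm = 2):
+#     sigma = 2/(2*sqrt(2*log(2))) ~ 0.8493
+#     amplitude = 1*2*0.5*sqrt(pi/log(2)) ~ 2.1289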
+
+class Fit:
+    """Wrapper class for lmfit
+    """
+    def __init__(self, y, x=None, models=None, normalize=True, **kwargs):
+        if not isinstance(normalize, bool):
+            raise ValueError(f'Invalid parameter normalize ({normalize})')
+        self._mask = None
+        self._model = None
+        self._norm = None
+        self._normalized = False
+        self._parameters = Parameters()
+        self._parameter_bounds = None
+        self._parameter_norms = {}
+        self._linear_parameters = []
+        self._nonlinear_parameters = []
+        self._result = None
+        self._try_linear_fit = True
+        self._y = None
+        self._y_norm = None
+        self._y_range = None
+        if 'try_linear_fit' in kwargs:
+            try_linear_fit = kwargs.pop('try_linear_fit')
+            if not isinstance(try_linear_fit, bool):
+                illegal_value(try_linear_fit, 'try_linear_fit', 'Fit.__init__', raise_error=True)
+            self._try_linear_fit = try_linear_fit
+        if y is not None:
+            if isinstance(y, (tuple, list, np.ndarray)):
+                if x is None:
+                    self._x = np.arange(len(y))
+                else:
+                    self._x = np.asarray(x)
+            elif have_xarray and isinstance(y, xr.DataArray):
+                if x is not None:
+                    logging.warning(f'Ignoring superfluous input x ({x}) in Fit.__init__')
+                if y.ndim != 1:
+                    illegal_value(y.ndim, 'DataArray dimensions', 'Fit.__init__', raise_error=True)
+                self._x = np.asarray(y[y.dims[0]])
+            else:
+                illegal_value(y, 'y', 'Fit.__init__', raise_error=True)
+            self._y = y
+            if self._x.ndim != 1:
+                raise ValueError(f'Invalid dimension for input x ({self._x.ndim})')
+            if self._x.size != self._y.size:
+                raise ValueError(f'Inconsistent x and y dimensions ({self._x.size} vs '+
+                        f'{self._y.size})')
+            if 'mask' in kwargs:
+                self._mask = kwargs.pop('mask')
+            if self._mask is None:
+                y_min = float(self._y.min())
+                self._y_range = float(self._y.max())-y_min
+                if normalize and self._y_range > 0.0:
+                    self._norm = (y_min, self._y_range)
+            else:
+                self._mask = np.asarray(self._mask).astype(bool)
+                if self._x.size != self._mask.size:
+                    raise ValueError(f'Inconsistent x and mask dimensions ({self._x.size} vs '+
+                            f'{self._mask.size})')
+                y_masked = np.asarray(self._y)[~self._mask]
+                y_min = float(y_masked.min())
+                self._y_range = float(y_masked.max())-y_min
+                if normalize and self._y_range > 0.0:
+                    self._norm = (y_min, self._y_range)
+        if models is not None:
+            if callable(models) or isinstance(models, str):
+                kwargs = self.add_model(models, **kwargs)
+            elif isinstance(models, (tuple, list)):
+                for model in models:
+                    kwargs = self.add_model(model, **kwargs)
+            self.fit(**kwargs)
+
+    @classmethod
+    def fit_data(cls, y, models, x=None, normalize=True, **kwargs):
+        return(cls(y, x=x, models=models, normalize=normalize, **kwargs))
+
+    @property
+    def best_errors(self):
+        if self._result is None:
+            return(None)
+        return({name:self._result.params[name].stderr for name in sorted(self._result.params)
+                if name != 'tmp_normalization_offset_c'})
+
+    @property
+    def best_fit(self):
+        if self._result is None:
+            return(None)
+        return(self._result.best_fit)
+
+    @property
+    def best_parameters(self):
+        if self._result is None:
+            return(None)
+        parameters = {}
+        for name in sorted(self._result.params):
+            if name != 'tmp_normalization_offset_c':
+                par = self._result.params[name]
+                parameters[name] = {'value': par.value, 'error': par.stderr,
+                        'init_value': par.init_value, 'min': par.min, 'max': par.max,
+                        'vary': par.vary, 'expr': par.expr}
+        return(parameters)
+
+    @property
+    def best_results(self):
+        """Convert the input data array to a data set and add the fit results.
+        """
+        if self._result is None:
+            return(None)
+        if not have_xarray:
+            logging.warning('Unable to create best_results in Fit.best_results '+
+                    '(xarray package not available)')
+            return(None)
+        if isinstance(self._y, xr.DataArray):
+            best_results = self._y.to_dataset()
+            dims = self._y.dims
+            fit_name = f'{self._y.name}_fit'
+        else:
+            coords = {'x': (['x'], self._x)}
+            dims = ('x',)
+            best_results = xr.Dataset(coords=coords)
+            best_results['y'] = (dims, self._y)
+            fit_name = 'y_fit'
+        best_results[fit_name] = (dims, self.best_fit)
+        if self._mask is not None:
+            best_results['mask'] = self._mask
+        best_results.coords['par_names'] = ('par_names',
+                [name for name in self.best_values.keys()])
+        best_results['best_values'] = (['par_names'], [v for v in self.best_values.values()])
+        best_results['best_errors'] = (['par_names'], [v for v in self.best_errors.values()])
+        best_results.attrs['components'] = self.components
+        return(best_results)
+
+    @property
+    def best_values(self):
+        if self._result is None:
+            return(None)
+        return({name:self._result.params[name].value for name in sorted(self._result.params)
+                if name != 'tmp_normalization_offset_c'})
+
+    @property
+    def chisqr(self):
+        if self._result is None:
+            return(None)
+        return(self._result.chisqr)
+
+    @property
+    def components(self):
+        components = {}
+        if self._result is None:
+            logging.warning('Unable to collect components in Fit.components')
+            return(components)
+        for component in self._result.components:
+            if 'tmp_normalization_offset_c' in component.param_names:
+                continue
+            parameters = {}
+            for name in component.param_names:
+                par = self._parameters[name]
+                parameters[name] = {'free': par.vary, 'value': self._result.params[name].value}
+                if par.expr is not None:
+                    parameters[name]['expr'] = par.expr
+            expr = None
+            if isinstance(component, ExpressionModel):
+                name = component._name
+                if name[-1] == '_':
+                    name = name[:-1]
+                expr = component.expr
+            else:
+                prefix = component.prefix
+                if len(prefix):
+                    if prefix[-1] == '_':
+                        prefix = prefix[:-1]
+                    name = f'{prefix} ({component._name})'
+                else:
+                    name = f'{component._name}'
+            if expr is None:
+                components[name] = {'parameters': parameters}
+            else:
+                components[name] = {'expr': expr, 'parameters': parameters}
+        return(components)
+
+    @property
+    def covar(self):
+        if self._result is None:
+            return(None)
+        return(self._result.covar)
+
+    @property
+    def init_parameters(self):
+        if self._result is None or self._result.init_params is None:
+            return(None)
+        parameters = {}
+        for name in sorted(self._result.init_params):
+            if name != 'tmp_normalization_offset_c':
+                par = self._result.init_params[name]
+                parameters[name] = {'value': par.value, 'min': par.min, 'max': par.max,
+                        'vary': par.vary, 'expr': par.expr}
+        return(parameters)
+
+    @property
+    def init_values(self):
+        if self._result is None or self._result.init_params is None:
+            return(None)
+        return({name:self._result.init_params[name].value for name in
+                sorted(self._result.init_params) if name != 'tmp_normalization_offset_c'})
+
+    @property
+    def normalization_offset(self):
+        if self._result is None:
+            return(None)
+        if self._norm is None:
+            return(0.0)
+        else:
+            if self._result.init_params is not None:
+                normalization_offset = self._result.init_params['tmp_normalization_offset_c'].value
+            else:
+                normalization_offset = self._result.params['tmp_normalization_offset_c'].value
+            return(normalization_offset)
+
+    @property
+    def num_func_eval(self):
+        if self._result is None:
+            return(None)
+        return(self._result.nfev)
+
+    @property
+    def parameters(self):
+        return({name:{'min': par.min, 'max': par.max, 'vary': par.vary, 'expr': par.expr}
+                for name, par in self._parameters.items() if name != 'tmp_normalization_offset_c'})
+
+    @property
+    def redchi(self):
+        if self._result is None:
+            return(None)
+        return(self._result.redchi)
+
+    @property
+    def residual(self):
+        if self._result is None:
+            return(None)
+        return(self._result.residual)
+
+    @property
+    def success(self):
+        if self._result is None:
+            return(None)
+        if not self._result.success:
+#            print(f'ier = {self._result.ier}')
+#            print(f'lmdif_message = {self._result.lmdif_message}')
+#            print(f'message = {self._result.message}')
+#            print(f'nfev = {self._result.nfev}')
+#            print(f'redchi = {self._result.redchi}')
+#            print(f'success = {self._result.success}')
+            if self._result.ier == 0 or self._result.ier == 5:
+                logging.warning(f'ier = {self._result.ier}: {self._result.message}')
+            else:
+                # Any other ier value indicates a usable solution: report it,
+                # but treat the fit as successful
+                logging.warning(f'ier = {self._result.ier}: {self._result.message}')
+                return(True)
+#            self.print_fit_report()
+#            self.plot()
+        return(self._result.success)
+
+    @property
+    def var_names(self):
+        """Intended to be used with covar
+        """
+        if self._result is None:
+            return(None)
+        return(getattr(self._result, 'var_names', None))
+
+    @property
+    def x(self):
+        return(self._x)
+
+    @property
+    def y(self):
+        return(self._y)
+
+    def print_fit_report(self, result=None, show_correl=False):
+        if result is None:
+            result = self._result
+        if result is not None:
+            print(result.fit_report(show_correl=show_correl))
+
+    def add_parameter(self, **parameter):
+        if parameter.get('expr') is not None:
+            raise KeyError(f'Illegal "expr" key in parameter {parameter}')
+        name = parameter.get('name')
+        if not isinstance(name, str):
+            raise ValueError(f'Illegal "name" value ({name}) in parameter {parameter}')
+        if parameter.get('norm') is None:
+            self._parameter_norms[name] = False
+        else:
+            norm = parameter.pop('norm')
+            if self._norm is None:
+                logging.warning(f'Ignoring norm in parameter {name} in '+
+                            f'Fit.add_parameter (normalization is turned off)')
+                self._parameter_norms[name] = False
+            else:
+                if not isinstance(norm, bool):
+                    raise ValueError(f'Illegal "norm" value ({norm}) in parameter {parameter}')
+                self._parameter_norms[name] = norm
+        vary = parameter.get('vary')
+        if vary is not None:
+            if not isinstance(vary, bool):
+                raise ValueError(f'Illegal "vary" value ({vary}) in parameter {parameter}')
+            if not vary:
+                if 'min' in parameter:
+                    logging.warning(f'Ignoring min in parameter {name} in '+
+                            f'Fit.add_parameter (vary = {vary})')
+                    parameter.pop('min')
+                if 'max' in parameter:
+                    logging.warning(f'Ignoring max in parameter {name} in '+
+                            f'Fit.add_parameter (vary = {vary})')
+                    parameter.pop('max')
+        if self._norm is not None and name not in self._parameter_norms:
+            raise ValueError(f'Missing parameter normalization type for parameter {name}')
+        self._parameters.add(**parameter)
+
+    def add_model(self, model, prefix=None, parameters=None, parameter_norms=None, **kwargs):
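+        """Add a model component to the overall fit model.
+
+        A sketch of typical calls (prefixes, expressions, and parameter values
+        here are illustrative, not fixed by the code):
+
+            fit.add_model('gaussian', prefix='pk1_', parameters=(
+                    {'name': 'amplitude', 'value': 1.0, 'min': 0.0},))
+            fit.add_model('expression', expr='a*exp(-x/b)', parameters=(
+                    {'name': 'a', 'value': 1.0}, {'name': 'b', 'value': 1.0}))
+
+        Any unused keyword arguments are returned so they can be passed on to fit().
+        """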
+        # Create the new model
+#        print(f'at start add_model:\nself._parameters:\n{self._parameters}')
+#        print(f'at start add_model: kwargs = {kwargs}')
+#        print(f'parameters = {parameters}')
+#        print(f'parameter_norms = {parameter_norms}')
+#        if len(self._parameters.keys()):
+#            print('\nAt start adding model:')
+#            self._parameters.pretty_print()
+#            print(f'parameter_norms:\n{self._parameter_norms}')
+        if prefix is not None and not isinstance(prefix, str):
+            logging.warning(f'Ignoring illegal prefix: {prefix} {type(prefix)}')
+            prefix = None
+        if prefix is None:
+            pprefix = ''
+        else:
+            pprefix = prefix
+        if parameters is not None:
+            if isinstance(parameters, dict):
+                parameters = (parameters, )
+            elif not is_dict_series(parameters):
+                illegal_value(parameters, 'parameters', 'Fit.add_model', raise_error=True)
+            parameters = deepcopy(parameters)
+        if parameter_norms is not None:
+            if isinstance(parameter_norms, dict):
+                parameter_norms = (parameter_norms, )
+            if not is_dict_series(parameter_norms):
+                illegal_value(parameter_norms, 'parameter_norms', 'Fit.add_model', raise_error=True)
+        new_parameter_norms = {}
+        if callable(model):
+            # Linear fit not yet implemented for callable models
+            self._try_linear_fit = False
+            if parameter_norms is None:
+                if parameters is None:
+                    raise ValueError('Either "parameters" or "parameter_norms" is required in '+
+                            f'{model}')
+                for par in parameters:
+                    name = par['name']
+                    if not isinstance(name, str):
+                        raise ValueError(f'Illegal "name" value ({name}) in input parameters')
+                    if par.get('norm') is not None:
+                        norm = par.pop('norm')
+                        if not isinstance(norm, bool):
+                            raise ValueError(f'Illegal "norm" value ({norm}) in input parameters')
+                        new_parameter_norms[f'{pprefix}{name}'] = norm
+            else:
+                for par in parameter_norms:
+                    name = par['name']
+                    if not isinstance(name, str):
+                        raise ValueError(f'Illegal "name" value ({name}) in input parameters')
+                    norm = par.get('norm')
+                    if norm is None or not isinstance(norm, bool):
+                        raise ValueError(f'Illegal "norm" value ({norm}) in input parameters')
+                    new_parameter_norms[f'{pprefix}{name}'] = norm
+            if parameters is not None:
+                for par in parameters:
+                    if par.get('expr') is not None:
+                        raise KeyError(f'Illegal "expr" key ({par.get("expr")}) in parameter '+
+                                f'{name} for a callable model {model}')
+                    name = par['name']
+                    if not isinstance(name, str):
+                        raise ValueError(f'Illegal "name" value ({name}) in input parameters')
+# RV FIX: a callable model needs partial derivative functions for any linear
+# parameters to build the linearized matrix, so skip the linear solution option for now
+            newmodel = Model(model, prefix=prefix)
+        elif isinstance(model, str):
+            if model == 'constant':      # Par: c
+                newmodel = ConstantModel(prefix=prefix)
+                new_parameter_norms[f'{pprefix}c'] = True
+                self._linear_parameters.append(f'{pprefix}c')
+            elif model == 'linear':      # Par: slope, intercept
+                newmodel = LinearModel(prefix=prefix)
+                new_parameter_norms[f'{pprefix}slope'] = True
+                new_parameter_norms[f'{pprefix}intercept'] = True
+                self._linear_parameters.append(f'{pprefix}slope')
+                self._linear_parameters.append(f'{pprefix}intercept')
+            elif model == 'quadratic':   # Par: a, b, c
+                newmodel = QuadraticModel(prefix=prefix)
+                new_parameter_norms[f'{pprefix}a'] = True
+                new_parameter_norms[f'{pprefix}b'] = True
+                new_parameter_norms[f'{pprefix}c'] = True
+                self._linear_parameters.append(f'{pprefix}a')
+                self._linear_parameters.append(f'{pprefix}b')
+                self._linear_parameters.append(f'{pprefix}c')
+            elif model == 'gaussian':    # Par: amplitude, center, sigma (fwhm, height)
+                newmodel = GaussianModel(prefix=prefix)
+                new_parameter_norms[f'{pprefix}amplitude'] = True
+                new_parameter_norms[f'{pprefix}center'] = False
+                new_parameter_norms[f'{pprefix}sigma'] = False
+                self._linear_parameters.append(f'{pprefix}amplitude')
+                self._nonlinear_parameters.append(f'{pprefix}center')
+                self._nonlinear_parameters.append(f'{pprefix}sigma')
+                # parameter norms for height and fwhm are needed to get correct errors
+                new_parameter_norms[f'{pprefix}height'] = True
+                new_parameter_norms[f'{pprefix}fwhm'] = False
+            elif model == 'lorentzian':    # Par: amplitude, center, sigma (fwhm, height)
+                newmodel = LorentzianModel(prefix=prefix)
+                new_parameter_norms[f'{pprefix}amplitude'] = True
+                new_parameter_norms[f'{pprefix}center'] = False
+                new_parameter_norms[f'{pprefix}sigma'] = False
+                self._linear_parameters.append(f'{pprefix}amplitude')
+                self._nonlinear_parameters.append(f'{pprefix}center')
+                self._nonlinear_parameters.append(f'{pprefix}sigma')
+                # parameter norms for height and fwhm are needed to get correct errors
+                new_parameter_norms[f'{pprefix}height'] = True
+                new_parameter_norms[f'{pprefix}fwhm'] = False
+            elif model == 'exponential': # Par: amplitude, decay
+                newmodel = ExponentialModel(prefix=prefix)
+                new_parameter_norms[f'{pprefix}amplitude'] = True
+                new_parameter_norms[f'{pprefix}decay'] = False
+                self._linear_parameters.append(f'{pprefix}amplitude')
+                self._nonlinear_parameters.append(f'{pprefix}decay')
+            elif model == 'step':        # Par: amplitude, center, sigma
+                form = kwargs.pop('form', None)
+                if form not in ('linear', 'atan', 'arctan', 'erf', 'logistic'):
+                    raise ValueError(f'Invalid parameter form for built-in step model ({form})')
+                newmodel = StepModel(prefix=prefix, form=form)
+                new_parameter_norms[f'{pprefix}amplitude'] = True
+                new_parameter_norms[f'{pprefix}center'] = False
+                new_parameter_norms[f'{pprefix}sigma'] = False
+                self._linear_parameters.append(f'{pprefix}amplitude')
+                self._nonlinear_parameters.append(f'{pprefix}center')
+                self._nonlinear_parameters.append(f'{pprefix}sigma')
+            elif model == 'rectangle':   # Par: amplitude, center1, center2, sigma1, sigma2
+                form = kwargs.pop('form', None)
+                if form not in ('linear', 'atan', 'arctan', 'erf', 'logistic'):
+                    raise ValueError('Invalid parameter form for built-in rectangle model '+
+                            f'({form})')
+                newmodel = RectangleModel(prefix=prefix, form=form)
+                new_parameter_norms[f'{pprefix}amplitude'] = True
+                new_parameter_norms[f'{pprefix}center1'] = False
+                new_parameter_norms[f'{pprefix}center2'] = False
+                new_parameter_norms[f'{pprefix}sigma1'] = False
+                new_parameter_norms[f'{pprefix}sigma2'] = False
+                self._linear_parameters.append(f'{pprefix}amplitude')
+                self._nonlinear_parameters.append(f'{pprefix}center1')
+                self._nonlinear_parameters.append(f'{pprefix}center2')
+                self._nonlinear_parameters.append(f'{pprefix}sigma1')
+                self._nonlinear_parameters.append(f'{pprefix}sigma2')
+            elif model == 'expression':  # Par: by expression
+                expr = kwargs.pop('expr')
+                if not isinstance(expr, str):
+                    raise ValueError(f'Illegal "expr" value ({expr}) in {model}')
+                if parameter_norms is not None:
+                    logging.warning('Ignoring parameter_norms (normalization determined from '+
+                            'linearity)')
+                if parameters is not None:
+                    for par in parameters:
+                        if par.get('expr') is not None:
+                            raise KeyError(f'Illegal "expr" key ({par.get("expr")}) in parameter '+
+                                    f'({par}) for an expression model')
+                        if par.get('norm') is not None:
+                            logging.warning(f'Ignoring "norm" key in parameter ({par}) '+
+                                '(normalization determined from linearity)}')
+                            par.pop('norm')
+                        name = par['name']
+                        if not isinstance(name, str):
+                            raise ValueError(f'Illegal "name" value ({name}) in input parameters')
+                ast = Interpreter()
+                expr_parameters = [name for name in get_ast_names(ast.parse(expr))
+                        if name != 'x' and name not in self._parameters
+                        and name not in ast.symtable]
+#                print(f'\nexpr_parameters: {expr_parameters}')
+#                print(f'expr = {expr}')
+                if prefix is None:
+                    newmodel = ExpressionModel(expr=expr)
+                else:
+                    for name in expr_parameters:
+                        expr = sub(rf'\b{name}\b', f'{prefix}{name}', expr)
+                    expr_parameters = [f'{prefix}{name}' for name in expr_parameters]
+#                    print(f'\nexpr_parameters: {expr_parameters}')
+#                    print(f'expr = {expr}')
+                    newmodel = ExpressionModel(expr=expr, name=prefix)
+#                print(f'\nnewmodel = {newmodel.__dict__}')
+#                print(f'params_names = {newmodel._param_names}')
+#                print(f'params_names = {newmodel.param_names}')
+                # Remove already existing names
+                for name in newmodel.param_names.copy():
+                    if name not in expr_parameters:
+                        newmodel._func_allargs.remove(name)
+                        newmodel._param_names.remove(name)
+#                print(f'params_names = {newmodel._param_names}')
+#                print(f'params_names = {newmodel.param_names}')
+            else:
+                raise ValueError(f'Unknown built-in fit model ({model})')
+        else:
+            illegal_value(model, 'model', 'Fit.add_model', raise_error=True)
+
+        # Add the new model to the current one
+#        print('\nBefore adding model:')
+#        print(f'\nnewmodel = {newmodel.__dict__}')
+#        if len(self._parameters):
+#            self._parameters.pretty_print()
+        if self._model is None:
+            self._model = newmodel
+        else:
+            self._model += newmodel
+        new_parameters = newmodel.make_params()
+        self._parameters += new_parameters
+#        print('\nAfter adding model:')
+#        print(f'\nnewmodel = {newmodel.__dict__}')
+#        print(f'\nnew_parameters = {new_parameters}')
+#        self._parameters.pretty_print()
+
+        # Check linearity of expression model parameters
+        if isinstance(newmodel, ExpressionModel):
+            for name in newmodel.param_names:
+                if not diff(newmodel.expr, name, name):
+                    if name not in self._linear_parameters:
+                        self._linear_parameters.append(name)
+                        new_parameter_norms[name] = True
+#                        print(f'\nADDING {name} TO LINEAR')
+                else:
+                    if name not in self._nonlinear_parameters:
+                        self._nonlinear_parameters.append(name)
+                        new_parameter_norms[name] = False
+#                        print(f'\nADDING {name} TO NONLINEAR')
+#        print(f'new_parameter_norms:\n{new_parameter_norms}')
+
+        # Scale the default initial model parameters
+        if self._norm is not None:
+            for name, norm in new_parameter_norms.copy().items():
+                par = self._parameters.get(name)
+                if par is None:
+                    new_parameter_norms.pop(name)
+                    continue
+                if par.expr is None and norm:
+                    value = par.value*self._norm[1]
+                    _min = par.min
+                    _max = par.max
+                    if not np.isinf(_min) and abs(_min) != float_min:
+                        _min *= self._norm[1]
+                    if not np.isinf(_max) and abs(_max) != float_min:
+                        _max *= self._norm[1]
+                    par.set(value=value, min=_min, max=_max)
+#        print('\nAfter norm defaults:')
+#        self._parameters.pretty_print()
+#        print(f'parameters:\n{parameters}')
+#        print(f'all_parameters:\n{list(self.parameters)}')
+#        print(f'new_parameter_norms:\n{new_parameter_norms}')
+#        print(f'parameter_norms:\n{self._parameter_norms}')
+
+        # Initialize the model parameters from parameters
+        if prefix is None:
+            prefix = ""
+        if parameters is not None:
+            for parameter in parameters:
+                name = parameter['name']
+                if not isinstance(name, str):
+                    raise ValueError(f'Illegal "name" value ({name}) in input parameters')
+                if name not in new_parameters:
+                    name = prefix+name
+                    parameter['name'] = name
+                if name not in new_parameters:
+                    logging.warning(f'Ignoring superfluous parameter info for {name}')
+                    continue
+                if name in self._parameters:
+                    parameter.pop('name')
+                    if 'norm' in parameter:
+                        if not isinstance(parameter['norm'], bool):
+                            illegal_value(parameter['norm'], 'norm', 'Fit.add_model',
+                                    raise_error=True)
+                        new_parameter_norms[name] = parameter['norm']
+                        parameter.pop('norm')
+                    if parameter.get('expr') is not None:
+                        if 'value' in parameter:
+                            logging.warning(f'Ignoring value in parameter {name} '+
+                                    f'(set by expression: {parameter["expr"]})')
+                            parameter.pop('value')
+                        if 'vary' in parameter:
+                            logging.warning(f'Ignoring vary in parameter {name} '+
+                                    f'(set by expression: {parameter["expr"]})')
+                            parameter.pop('vary')
+                        if 'min' in parameter:
+                            logging.warning(f'Ignoring min in parameter {name} '+
+                                    f'(set by expression: {parameter["expr"]})')
+                            parameter.pop('min')
+                        if 'max' in parameter:
+                            logging.warning(f'Ignoring max in parameter {name} '+
+                                    f'(set by expression: {parameter["expr"]})')
+                            parameter.pop('max')
+                    if 'vary' in parameter:
+                        if not isinstance(parameter['vary'], bool):
+                            illegal_value(parameter['vary'], 'vary', 'Fit.add_model',
+                                    raise_error=True)
+                        if not parameter['vary']:
+                            if 'min' in parameter:
+                                logging.warning(f'Ignoring min in parameter {name} in '+
+                                        f'Fit.add_model (vary = {parameter["vary"]})')
+                                parameter.pop('min')
+                            if 'max' in parameter:
+                                logging.warning(f'Ignoring max in parameter {name} in '+
+                                        f'Fit.add_model (vary = {parameter["vary"]})')
+                                parameter.pop('max')
+                    self._parameters[name].set(**parameter)
+                    parameter['name'] = name
+                else:
+                    illegal_value(parameter, 'parameter name', 'Fit.model', raise_error=True)
+        self._parameter_norms = {**self._parameter_norms, **new_parameter_norms}
+#        print('\nAfter parameter init:')
+#        self._parameters.pretty_print()
+#        print(f'parameters:\n{parameters}')
+#        print(f'new_parameter_norms:\n{new_parameter_norms}')
+#        print(f'parameter_norms:\n{self._parameter_norms}')
+#        print(f'kwargs:\n{kwargs}')
+
+        # Initialize the model parameters from kwargs
+        for name, value in {**kwargs}.items():
+            full_name = f'{pprefix}{name}'
+            if full_name in new_parameter_norms and isinstance(value, (int, float)):
+                kwargs.pop(name)
+                if self._parameters[full_name].expr is None:
+                    self._parameters[full_name].set(value=value)
+                else:
+                    logging.warning(f'Ignoring parameter {name} in Fit.fit (set by expression: '+
+                            f'{self._parameters[full_name].expr})')
+#        print('\nAfter kwargs init:')
+#        self._parameters.pretty_print()
+#        print(f'parameter_norms:\n{self._parameter_norms}')
+#        print(f'kwargs:\n{kwargs}')
+
+        # Check parameter norms (also need it for expressions to renormalize the errors)
+        if self._norm is not None and (callable(model) or model == 'expression'):
+            missing_norm = False
+            for name in new_parameters.valuesdict():
+                if name not in self._parameter_norms:
+                    logging.debug(f'new_parameters:\n{new_parameters.valuesdict()}')
+                    logging.debug(f'self._parameter_norms:\n{self._parameter_norms}')
+                    logging.error(f'Missing parameter normalization type for {name} in {model}')
+                    missing_norm = True
+            if missing_norm:
+                raise ValueError('Missing parameter normalization type(s)')
+
+#        print(f'at end add_model:\nself._parameters:\n{list(self.parameters)}')
+#        print(f'at end add_model: kwargs = {kwargs}')
+#        print(f'\nat end add_model: newmodel:\n{newmodel.__dict__}\n')
+        return(kwargs)
+
+    def fit(self, interactive=False, guess=False, **kwargs):
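+        """Fit the model to the data, using a direct linear matrix solve when
+        all free parameters enter the model linearly and lmfit otherwise.
+
+        A sketch of a refit that adjusts an existing parameter (the parameter
+        name is illustrative):
+
+            fit.fit(parameters=({'name': 'amplitude', 'min': 0.0},))
+        """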
+        # Check inputs
+        if self._model is None:
+            logging.error('Undefined fit model')
+            return
+        if not isinstance(interactive, bool):
+            illegal_value(interactive, 'interactive', 'Fit.fit', raise_error=True)
+        if not isinstance(guess, bool):
+            illegal_value(guess, 'guess', 'Fit.fit', raise_error=True)
+        if 'try_linear_fit' in kwargs:
+            try_linear_fit = kwargs.pop('try_linear_fit')
+            if not isinstance(try_linear_fit, bool):
+                illegal_value(try_linear_fit, 'try_linear_fit', 'Fit.fit', raise_error=True)
+            if not self._try_linear_fit:
+                logging.warning('Ignoring superfluous keyword argument "try_linear_fit" (not '+
+                        'yet supported for callable models)')
+            else:
+                self._try_linear_fit = try_linear_fit
+#        if self._result is None:
+#            if 'parameters' in kwargs:
+#                raise ValueError('Invalid parameter parameters ({kwargs["parameters"]})')
+#        else:
+        if self._result is not None:
+            if guess:
+                logging.warning('Ignoring input parameter guess in Fit.fit during refitting')
+                guess = False
+
+        # Check for circular expressions
+        # FIX TODO
+#        for name1, par1 in self._parameters.items():
+#            if par1.expr is not None:
+
+        # Apply mask if supplied:
+        if 'mask' in kwargs:
+            self._mask = kwargs.pop('mask')
+        if self._mask is not None:
+            self._mask = np.asarray(self._mask).astype(bool)
+            if self._x.size != self._mask.size:
+                raise ValueError(f'Inconsistent x and mask dimensions ({self._x.size} vs '+
+                        f'{self._mask.size})')
+
+        # Estimate initial parameters with built-in lmfit guess method (only for a single model)
+#        print(f'\nat start fit: kwargs = {kwargs}')
+#RV        print('\nAt start of fit:')
+#RV        self._parameters.pretty_print()
+#        print(f'parameter_norms:\n{self._parameter_norms}')
+        if guess:
+            if self._mask is None:
+                self._parameters = self._model.guess(self._y, x=self._x)
+            else:
+                self._parameters = self._model.guess(np.asarray(self._y)[~self._mask],
+                        x=self._x[~self._mask])
+#            print('\nAfter guess:')
+#            self._parameters.pretty_print()
+
+        # Add constant offset for a normalized model
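+        # (The fit operates on y_norm = (y - y_min)/y_range; a fixed constant
+        # c = -y_min, rescaled to -y_min/y_range through its "norm" flag,
+        # compensates for the y_min shift in the normalized data.)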
+        if self._result is None and self._norm is not None and self._norm[0]:
+            self.add_model('constant', prefix='tmp_normalization_offset_', parameters={'name': 'c',
+                    'value': -self._norm[0], 'vary': False, 'norm': True})
+                    #'value': -self._norm[0]/self._norm[1], 'vary': False, 'norm': False})
+
+        # Adjust existing parameters for refit:
+        if 'parameters' in kwargs:
+            parameters = kwargs.pop('parameters')
+            if isinstance(parameters, dict):
+                parameters = (parameters, )
+            elif not is_dict_series(parameters):
+                illegal_value(parameters, 'parameters', 'Fit.fit', raise_error=True)
+            for par in parameters:
+                name = par['name']
+                if name not in self._parameters:
+                    raise ValueError(f'Unable to match {name} parameter {par} to an existing one')
+                if self._parameters[name].expr is not None:
+                    raise ValueError(f'Unable to modify {name} parameter {par} (currently an '+
+                            'expression)')
+                if par.get('expr') is not None:
+                    raise KeyError(f'Illegal "expr" key in {name} parameter {par}')
+                self._parameters[name].set(vary=par.get('vary'))
+                self._parameters[name].set(min=par.get('min'))
+                self._parameters[name].set(max=par.get('max'))
+                self._parameters[name].set(value=par.get('value'))
+#RV            print('\nAfter adjust:')
+#RV            self._parameters.pretty_print()
+
+        # Apply parameter updates through keyword arguments
+#        print(f'kwargs = {kwargs}')
+#        print(f'parameter_norms = {self._parameter_norms}')
+        for name in set(self._parameters) & set(kwargs):
+            value = kwargs.pop(name)
+            if self._parameters[name].expr is None:
+                self._parameters[name].set(value=value)
+            else:
+                logging.warning(f'Ignoring parameter {name} in Fit.fit (set by expression: '+
+                        f'{self._parameters[name].expr})')
+
+        # Check for uninitialized parameters
+        for name, par in self._parameters.items():
+            if par.expr is None:
+                value = par.value
+                if value is None or np.isinf(value) or np.isnan(value):
+                    if interactive:
+                        value = input_num(f'Enter an initial value for {name}', default=1.0)
+                    else:
+                        value = 1.0
+                    if self._norm is not None and self._parameter_norms.get(name):
+                        self._parameters[name].set(value=value*self._norm[1])
+                    else:
+                        self._parameters[name].set(value=value)
+
+        # Check if model is linear
+        try:
+            linear_model = self._check_linearity_model()
+        except Exception:
+            linear_model = False
+#        print(f'\n\n--------> linear_model = {linear_model}\n')
+        if kwargs.get('check_only_linearity') is not None:
+            return(linear_model)
+
+        # Normalize the data and initial parameters
+#RV        print('\nBefore normalization:')
+#RV        self._parameters.pretty_print()
+#        print(f'parameter_norms:\n{self._parameter_norms}')
+        self._normalize()
+#        print(f'norm = {self._norm}') 
+#RV        print('\nAfter normalization:')
+#RV        self._parameters.pretty_print()
+#        self.print_fit_report()
+#        print(f'parameter_norms:\n{self._parameter_norms}')
+
+        if linear_model:
+            # Perform a linear fit by direct matrix solution with numpy
+            try:
+                if self._mask is None:
+                    self._fit_linear_model(self._x, self._y_norm)
+                else:
+                    self._fit_linear_model(self._x[~self._mask],
+                            np.asarray(self._y_norm)[~self._mask])
+            except Exception:
+                linear_model = False
+        if not linear_model:
+            # Perform a non-linear fit with lmfit
+            # Prevent initial values from sitting at boundaries
+            self._parameter_bounds = {name:{'min': par.min, 'max': par.max} for name, par in
+                    self._parameters.items() if par.vary}
+            for par in self._parameters.values():
+                if par.vary:
+                    par.set(value=self._reset_par_at_boundary(par, par.value))
+#            print('\nAfter checking boundaries:')
+#            self._parameters.pretty_print()
+
+            # Perform the fit
+#            fit_kws = None
+#            if 'Dfun' in kwargs:
+#                fit_kws = {'Dfun': kwargs.pop('Dfun')}
+#            self._result = self._model.fit(self._y_norm, self._parameters, x=self._x,
+#                    fit_kws=fit_kws, **kwargs)
+            if self._mask is None:
+                self._result = self._model.fit(self._y_norm, self._parameters, x=self._x, **kwargs)
+            else:
+                self._result = self._model.fit(np.asarray(self._y_norm)[~self._mask],
+                        self._parameters, x=self._x[~self._mask], **kwargs)
+#RV        print('\nAfter fit:')
+#        print(f'\nself._result ({self._result}):\n\t{self._result.__dict__}')
+#RV        self._parameters.pretty_print()
+#        self.print_fit_report()
+
+        # Set internal parameter values to fit results upon success
+        if self.success:
+            for name, par in self._parameters.items():
+                if par.expr is None and par.vary:
+                    par.set(value=self._result.params[name].value)
+#            print('\nAfter update parameter values:')
+#            self._parameters.pretty_print()
+
+        # Renormalize the data and results
+        self._renormalize()
+#RV        print('\nAfter renormalization:')
+#RV        self._parameters.pretty_print()
+#        self.print_fit_report()
+
+    def plot(self, y=None, y_title=None, result=None, skip_init=False, plot_comp_legends=False,
+            plot_residual=False, plot_masked_data=True, **kwargs):
+        if result is None:
+            result = self._result
+        if result is None:
+            return
+        plots = []
+        legend = []
+        if self._mask is None:
+            mask = np.zeros(self._x.size).astype(bool)
+            plot_masked_data = False
+        else:
+            mask = self._mask
+        if y is not None:
+            if not isinstance(y, (tuple, list, np.ndarray)):
+                illegal_value(y, 'y', 'Fit.plot')
+                y = None
+            elif len(y) != len(self._x):
+                logging.warning('Ignoring parameter y in Fit.plot (wrong dimension)')
+                y = None
+        if y is not None:
+            if y_title is None or not isinstance(y_title, str):
+                y_title = 'data'
+            plots += [(self._x, y, '.')]
+            legend += [y_title]
+        if self._y is not None:
+            plots += [(self._x, np.asarray(self._y), 'b.')]
+            legend += ['data']
+            if plot_masked_data:
+                plots += [(self._x[mask], np.asarray(self._y)[mask], 'bx')]
+                legend += ['masked data']
+        if isinstance(plot_residual, bool) and plot_residual:
+            plots += [(self._x[~mask], result.residual, 'k-')]
+            legend += ['residual']
+        plots += [(self._x[~mask], result.best_fit, 'k-')]
+        legend += ['best fit']
+        if not skip_init and hasattr(result, 'init_fit'):
+            plots += [(self._x[~mask], result.init_fit, 'g-')]
+            legend += ['init']
+        components = result.eval_components(x=self._x[~mask])
+        num_components = len(components)
+        if 'tmp_normalization_offset_' in components:
+            num_components -= 1
+        if num_components > 1:
+            eval_index = 0
+            for modelname, y in components.items():
+                if modelname == 'tmp_normalization_offset_':
+                    continue
+                if modelname == '_eval':
+                    modelname = f'eval{eval_index}'
+                if len(modelname) > 20:
+                    modelname = f'{modelname[0:16]} ...'
+                if isinstance(y, (int, float)):
+                    y *= np.ones(self._x[~mask].size)
+                plots += [(self._x[~mask], y, '--')]
+                if plot_comp_legends:
+                    if modelname[-1] == '_':
+                        legend.append(modelname[:-1])
+                    else:
+                        legend.append(modelname)
+        title = kwargs.pop('title', None)
+        quick_plot(tuple(plots), legend=legend, title=title, block=True, **kwargs)
+
+    @staticmethod
+    def guess_init_peak(x, y, *args, center_guess=None, use_max_for_center=True):
+        """ Return a guess for the initial height, center and fwhm for a peak
+        """
+#        print(f'\n\nargs = {args}')
+#        print(f'center_guess = {center_guess}')
+#        quick_plot(x, y, vlines=center_guess, block=True)
+        center_guesses = None
+        x = np.asarray(x)
+        y = np.asarray(y)
+        if len(x) != len(y):
+            logging.error(f'Invalid x and y lengths ({len(x)}, {len(y)}), skip initial guess')
+            return(None, None, None)
+        if isinstance(center_guess, (int, float)):
+            if len(args):
+                logging.warning('Ignoring additional arguments for single center_guess value')
+            center_guesses = [center_guess]
+        elif isinstance(center_guess, (tuple, list, np.ndarray)):
+            if len(center_guess) == 1:
+                if len(args):
+                    logging.warning('Ignoring additional arguments for single center_guess value')
+                if not isinstance(center_guess[0], (int, float)):
+                    raise ValueError(f'Invalid parameter center_guess ({type(center_guess[0])})')
+                center_guess = center_guess[0]
+            else:
+                if len(args) != 1:
+                    raise ValueError(f'Invalid number of arguments ({len(args)})')
+                n = args[0]
+                if not is_index(n, 0, len(center_guess)):
+                    raise ValueError('Invalid argument')
+                center_guesses = center_guess
+                center_guess = center_guesses[n]
+        elif center_guess is not None:
+            raise ValueError(f'Invalid center_guess type ({type(center_guess)})')
+#        print(f'x = {x}')
+#        print(f'y = {y}')
+#        print(f'center_guess = {center_guess}')
+
+        # Sort the inputs
+        index = np.argsort(x)
+        x = x[index]
+        y = y[index]
+        miny = y.min()
+#        print(f'miny = {miny}')
+#        print(f'x_range = {x[0]} {x[-1]} {len(x)}')
+#        print(f'y_range = {y[0]} {y[-1]} {len(y)}')
+#        quick_plot(x, y, vlines=center_guess, block=True)
+
+#        xx = x
+#        yy = y
+        # Set range for current peak
+#        print(f'n = {n}')
+#        print(f'center_guesses = {center_guesses}')
+        if center_guesses is not None:
+            if len(center_guesses) > 1:
+                index = np.argsort(center_guesses)
+                n = list(index).index(n)
+#                print(f'n = {n}')
+#                print(f'index = {index}')
+                center_guesses = np.asarray(center_guesses)[index]
+#                print(f'center_guesses = {center_guesses}')
+            if n == 0:
+                low = 0
+                upp = index_nearest(x, (center_guesses[0]+center_guesses[1])/2)
+            elif n == len(center_guesses)-1:
+                low = index_nearest(x, (center_guesses[n-1]+center_guesses[n])/2)
+                upp = len(x)
+            else:
+                low = index_nearest(x, (center_guesses[n-1]+center_guesses[n])/2)
+                upp = index_nearest(x, (center_guesses[n]+center_guesses[n+1])/2)
+#            print(f'low = {low}')
+#            print(f'upp = {upp}')
+            x = x[low:upp]
+            y = y[low:upp]
+#            quick_plot(x, y, vlines=(x[0], center_guess, x[-1]), block=True)
+
+        # Estimate FWHM
+        maxy = y.max()
+#        print(f'x_range = {x[0]} {x[-1]} {len(x)}')
+#        print(f'y_range = {y[0]} {y[-1]} {len(y)} {miny} {maxy}')
+#        print(f'center_guess = {center_guess}')
+        if center_guess is None:
+            center_index = np.argmax(y)
+            center = x[center_index]
+            height = maxy-miny
+        else:
+            if use_max_for_center:
+                center_index = np.argmax(y)
+                center = x[center_index]
+                if center_index < 0.1*len(x) or center_index > 0.9*len(x):
+                    center_index = index_nearest(x, center_guess)
+                    center = center_guess
+            else:
+                center_index = index_nearest(x, center_guess)
+                center = center_guess
+            height = y[center_index]-miny
+#        print(f'center_index = {center_index}')
+#        print(f'center = {center}')
+#        print(f'height = {height}')
+        half_height = miny+0.5*height
+#        print(f'half_height = {half_height}')
+        fwhm_index1 = 0
+        for i in range(center_index, fwhm_index1, -1):
+            if y[i] < half_height:
+                fwhm_index1 = i
+                break
+#        print(f'fwhm_index1 = {fwhm_index1} {x[fwhm_index1]}')
+        fwhm_index2 = len(x)-1
+        for i in range(center_index, fwhm_index2):
+            if y[i] < half_height:
+                fwhm_index2 = i
+                break
+#        print(f'fwhm_index2 = {fwhm_index2} {x[fwhm_index2]}')
+#        quick_plot((x,y,'o'), vlines=(x[fwhm_index1], center, x[fwhm_index2]), block=True)
+        if fwhm_index1 == 0 and fwhm_index2 < len(x)-1:
+            fwhm = 2*(x[fwhm_index2]-center)
+        elif fwhm_index1 > 0 and fwhm_index2 == len(x)-1:
+            fwhm = 2*(center-x[fwhm_index1])
+        else:
+            fwhm = x[fwhm_index2]-x[fwhm_index1]
+#        print(f'fwhm_index1 = {fwhm_index1} {x[fwhm_index1]}')
+#        print(f'fwhm_index2 = {fwhm_index2} {x[fwhm_index2]}')
+#        print(f'fwhm = {fwhm}')
+
+        # Return height, center and FWHM
+#        quick_plot((x,y,'o'), (xx,yy), vlines=(x[fwhm_index1], center, x[fwhm_index2]), block=True)
+        return(height, center, fwhm)
+
+    def _check_linearity_model(self):
+        """Identify the linearity of all model parameters and check if the model is linear or not
+        """
+        if not self._try_linear_fit:
+            logging.info('Skip linearity check (not yet supported for callable models)')
+            return(False)
+        free_parameters = [name for name, par in self._parameters.items() if par.vary]
+        for component in self._model.components:
+            if 'tmp_normalization_offset_c' in component.param_names:
+                continue
+            if isinstance(component, ExpressionModel):
+                for name in free_parameters:
+                    if diff(component.expr, name, name):
+#                        print(f'\t\t{component.expr} is non-linear in {name}')
+                        self._nonlinear_parameters.append(name)
+                        if name in self._linear_parameters:
+                            self._linear_parameters.remove(name)
+            else:
+                model_parameters = component.param_names.copy()
+                for basename, hint in component.param_hints.items():
+                    name = f'{component.prefix}{basename}'
+                    if hint.get('expr') is not None:
+                        model_parameters.remove(name)
+                for name in model_parameters:
+                    expr = self._parameters[name].expr
+                    if expr is not None:
+                        for nname in free_parameters:
+                            if name in self._nonlinear_parameters:
+                                if diff(expr, nname):
+#                                    print(f'\t\t{component} is non-linear in {nname} (through {name} = "{expr}")')
+                                    self._nonlinear_parameters.append(nname)
+                                    if nname in self._linear_parameters:
+                                        self._linear_parameters.remove(nname)
+                            else:
+                                assert(name in self._linear_parameters)
+#                                print(f'\n\nexpr ({type(expr)}) = {expr}\nnname ({type(nname)}) = {nname}\n\n')
+                                if diff(expr, nname, nname):
+#                                    print(f'\t\t{component} is non-linear in {nname} (through {name} = "{expr}")')
+                                    self._nonlinear_parameters.append(nname)
+                                    if nname in self._linear_parameters:
+                                        self._linear_parameters.remove(nname)
+#        print(f'\nfree parameters:\n\t{free_parameters}')
+#        print(f'linear parameters:\n\t{self._linear_parameters}')
+#        print(f'nonlinear parameters:\n\t{self._nonlinear_parameters}\n')
+        if any(self._parameters[name].vary for name in self._nonlinear_parameters):
+            return(False)
+        return(True)
+
+    def _fit_linear_model(self, x, y):
+        """Perform a linear fit by direct matrix solution with numpy
+        """
+        # Construct the matrix and the free parameter vector
+#        print(f'\nparameters:')
+#        self._parameters.pretty_print()
+#        print(f'\nparameter_norms:\n\t{self._parameter_norms}')
+#        print(f'\nlinear_parameters:\n\t{self._linear_parameters}')
+#        print(f'nonlinear_parameters:\n\t{self._nonlinear_parameters}')
+        free_parameters = [name for name, par in self._parameters.items() if par.vary]
+#        print(f'free parameters:\n\t{free_parameters}\n')
+        expr_parameters = {name:par.expr for name, par in self._parameters.items()
+                if par.expr is not None}
+        model_parameters = []
+        for component in self._model.components:
+            if 'tmp_normalization_offset_c' in component.param_names:
+                continue
+            model_parameters += component.param_names
+            for basename, hint in component.param_hints.items():
+                name = f'{component.prefix}{basename}'
+                if hint.get('expr') is not None:
+                    expr_parameters.pop(name)
+                    model_parameters.remove(name)
+#        print(f'expr parameters:\n{expr_parameters}')
+#        print(f'model parameters:\n\t{model_parameters}\n')
+        norm = 1.0
+        if self._normalized:
+            norm = self._norm[1]
+#        print(f'\n\nself._normalized = {self._normalized}\nnorm = {norm}\nself._norm = {self._norm}\n')
+        # Add expression parameters to asteval
+        ast = Interpreter()
+#        print(f'Adding to asteval sym table:')
+        for name, expr in expr_parameters.items():
+#            print(f'\tadding {name} {expr}')
+            ast.symtable[name] = expr
+        # Add constant parameters to asteval
+        # (renormalize to use correctly in evaluation of expression models)
+        for name, par in self._parameters.items():
+            if par.expr is None and not par.vary:
+                if self._parameter_norms[name]:
+#                    print(f'\tadding {name} {par.value*norm}')
+                    ast.symtable[name] = par.value*norm
+                else:
+#                    print(f'\tadding {name} {par.value}')
+                    ast.symtable[name] = par.value
+        A = np.zeros((len(x), len(free_parameters)), dtype='float64')
+        y_const = np.zeros(len(x), dtype='float64')
+        have_expression_model = False
+        for component in self._model.components:
+            if isinstance(component, ConstantModel):
+                name = component.param_names[0]
+#                print(f'\nConstant model: {name} {self._parameters[name]}\n')
+                if name in free_parameters:
+#                    print(f'\t\t{name} is a free constant set matrix column {free_parameters.index(name)} to 1.0')
+                    A[:,free_parameters.index(name)] = 1.0
+                else:
+                    if self._parameter_norms[name]:
+                        delta_y_const = self._parameters[name].value*np.ones(len(x))
+                    else:
+                        delta_y_const = (self._parameters[name].value*norm)*np.ones(len(x))
+                    y_const += delta_y_const
+#                    print(f'\ndelta_y_const ({type(delta_y_const)}):\n{delta_y_const}\n')
+            elif isinstance(component, ExpressionModel):
+                have_expression_model = True
+                const_expr = component.expr
+#                print(f'\nExpression model:\nconst_expr: {const_expr}\n')
+                for name in free_parameters:
+                    dexpr_dname = diff(component.expr, name)
+                    if dexpr_dname:
+                        const_expr = f'{const_expr}-({str(dexpr_dname)})*{name}'
+#                        print(f'\tconst_expr: {const_expr}')
+                        if not self._parameter_norms[name]:
+                            dexpr_dname = f'({dexpr_dname})/{norm}'
+#                        print(f'\t{component.expr} is linear in {name}\n\t\tadd "{str(dexpr_dname)}" to matrix as column {free_parameters.index(name)}')
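+                        # Evaluate the derivative on the grid: each lambda call
+                        #   first sets x in the asteval symbol table, then
+                        #   evaluates dexpr_dname at that x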
+                        fx = [(lambda _: ast.eval(str(dexpr_dname)))(ast(f'x={v}')) for v in x]
+#                        print(f'\tfx:\n{fx}')
+                        if len(ast.error):
+                            raise ValueError(f'Unable to evaluate {dexpr_dname}')
+                        A[:,free_parameters.index(name)] += fx
+#                        if self._parameter_norms[name]:
+#                            print(f'\t\t{component.expr} is linear in {name} add "{str(dexpr_dname)}" to matrix as column {free_parameters.index(name)}')
+#                            A[:,free_parameters.index(name)] += fx
+#                        else:
+#                            print(f'\t\t{component.expr} is linear in {name} add "({str(dexpr_dname)})/{norm}" to matrix as column {free_parameters.index(name)}')
+#                            A[:,free_parameters.index(name)] += np.asarray(fx)/norm
+                # FIX: find another solution if expr not supported by simplify
+                const_expr = str(simplify(f'({const_expr})/{norm}'))
+#                print(f'\nconst_expr: {const_expr}')
+                delta_y_const = [(lambda _: ast.eval(const_expr))(ast(f'x = {v}')) for v in x]
+                y_const += delta_y_const
+#                print(f'\ndelta_y_const ({type(delta_y_const)}):\n{delta_y_const}\n')
+                if len(ast.error):
+                    raise ValueError(f'Unable to evaluate {const_expr}')
+            else:
+                free_model_parameters = [name for name in component.param_names
+                        if name in free_parameters or name in expr_parameters]
+#                print(f'\nBuild-in model ({component}):\nfree_model_parameters: {free_model_parameters}\n')
+                if not len(free_model_parameters):
+                    y_const += component.eval(params=self._parameters, x=x)
+                elif isinstance(component, LinearModel):
+                    if f'{component.prefix}slope' in free_model_parameters:
+                        A[:,free_parameters.index(f'{component.prefix}slope')] = x
+                    else:
+                        y_const += self._parameters[f'{component.prefix}slope'].value*x
+                    if f'{component.prefix}intercept' in free_model_parameters:
+                        A[:,free_parameters.index(f'{component.prefix}intercept')] = 1.0
+                    else:
+                        y_const += self._parameters[f'{component.prefix}intercept'].value* \
+                                np.ones(len(x))
+                elif isinstance(component, QuadraticModel):
+                    if f'{component.prefix}a' in free_model_parameters:
+                        A[:,free_parameters.index(f'{component.prefix}a')] = x**2
+                    else:
+                        y_const += self._parameters[f'{component.prefix}a'].value*x**2
+                    if f'{component.prefix}b' in free_model_parameters:
+                        A[:,free_parameters.index(f'{component.prefix}b')] = x
+                    else:
+                        y_const += self._parameters[f'{component.prefix}b'].value*x
+                    if f'{component.prefix}c' in free_model_parameters:
+                        A[:,free_parameters.index(f'{component.prefix}c')] = 1.0
+                    else:
+                        y_const += self._parameters[f'{component.prefix}c'].value*np.ones(len(x))
+                else:
+                    # At this point each built-in model must be strictly proportional to each
+                    #   linear model parameter; without this assumption the full model equation
+                    #   would be needed. For the current built-in lmfit models, this can only
+                    #   ever be the amplitude
+                    assert(len(free_model_parameters) == 1)
+                    name = f'{component.prefix}amplitude'
+                    assert(free_model_parameters[0] == name)
+                    assert(self._parameter_norms[name])
+                    expr = self._parameters[name].expr
+                    if expr is None:
+#                        print(f'\t{component} is linear in {name} add to matrix as column {free_parameters.index(name)}')
+                        parameters = deepcopy(self._parameters)
+                        parameters[name].set(value=1.0)
+                        index = free_parameters.index(name)
+                        A[:,index] += component.eval(params=parameters, x=x)
+                    else:
+                        const_expr = expr
+#                        print(f'\tconst_expr: {const_expr}')
+                        parameters = deepcopy(self._parameters)
+                        parameters[name].set(value=1.0)
+                        dcomp_dname = component.eval(params=parameters, x=x)
+#                        print(f'\tdcomp_dname ({type(dcomp_dname)}):\n{dcomp_dname}')
+                        for nname in free_parameters:
+                            dexpr_dnname = diff(expr, nname)
+                            if dexpr_dnname:
+                                assert(self._parameter_norms[name])
+#                                print(f'\t\td({expr})/d{nname} = {dexpr_dnname}')
+#                                print(f'\t\t{component} is linear in {nname} (through {name} = "{expr}", add to matrix as column {free_parameters.index(nname)})')
+                                fx = np.asarray(dexpr_dnname*dcomp_dname, dtype='float64')
+#                                print(f'\t\tfx ({type(fx)}): {fx}')
+#                                print(f'free_parameters.index({nname}): {free_parameters.index(nname)}')
+                                if self._parameter_norms[nname]:
+                                    A[:,free_parameters.index(nname)] += fx
+                                else:
+                                    A[:,free_parameters.index(nname)] += fx/norm
+                                const_expr = f'{const_expr}-({dexpr_dnname})*{nname}'
+#                                print(f'\t\tconst_expr: {const_expr}')
+                        const_expr = str(simplify(f'({const_expr})/{norm}'))
+#                        print(f'\tconst_expr: {const_expr}')
+                        fx = [(lambda _: ast.eval(const_expr))(ast(f'x = {v}')) for v in x]
+#                        print(f'\tfx: {fx}')
+                        delta_y_const = np.multiply(fx, dcomp_dname)
+                        y_const += delta_y_const
+#                        print(f'\ndelta_y_const ({type(delta_y_const)}):\n{delta_y_const}\n')
+#            print(A)
+#            print(y_const)
+        solution, residual, rank, s = np.linalg.lstsq(A, y-y_const, rcond=None)
+#        print(f'\nsolution ({type(solution)} {solution.shape}):\n\t{solution}')
+#        print(f'\nresidual ({type(residual)} {residual.shape}):\n\t{residual}')
+#        print(f'\nrank ({type(rank)} {rank.shape}):\n\t{rank}')
+#        print(f'\ns ({type(s)} {s.shape}):\n\t{s}\n')
+
+        # Assemble result (compensate for normalization in expression models)
+        for name, value in zip(free_parameters, solution):
+            self._parameters[name].set(value=value)
+        if self._normalized and (have_expression_model or len(expr_parameters)):
+            for name, norm_flag in self._parameter_norms.items():
+                par = self._parameters[name]
+                if par.expr is None and norm_flag:
+                    self._parameters[name].set(value=par.value*self._norm[1])
+#        self._parameters.pretty_print()
+#        print(f'\nself._parameter_norms:\n\t{self._parameter_norms}')
+        self._result = ModelResult(self._model, deepcopy(self._parameters))
+        self._result.best_fit = self._model.eval(params=self._parameters, x=x)
+        if self._normalized and (have_expression_model or len(expr_parameters)):
+            if 'tmp_normalization_offset_c' in self._parameters:
+                offset = self._parameters['tmp_normalization_offset_c'].value
+            else:
+                offset = 0.0
+            self._result.best_fit = (self._result.best_fit-offset-self._norm[0])/self._norm[1]
+            for name, norm_flag in self._parameter_norms.items():
+                par = self._parameters[name]
+                if par.expr is None and norm_flag:
+                    value = par.value/self._norm[1]
+                    self._parameters[name].set(value=value)
+                    self._result.params[name].set(value=value)
+#        self._parameters.pretty_print()
+        self._result.residual = self._result.best_fit-y
+        self._result.components = self._model.components
+        self._result.init_params = None
+#        quick_plot((x, y, '.'), (x, y_const, 'g'), (x, self._result.best_fit, 'k'), (x, self._result.residual, 'r'), block=True)
+
+    def _normalize(self):
+        """Normalize the data and initial parameters
+        """
+        if self._normalized:
+            return
+        if self._norm is None:
+            if self._y is not None and self._y_norm is None:
+                self._y_norm = np.asarray(self._y)
+        else:
+            if self._y is not None and self._y_norm is None:
+                self._y_norm = (np.asarray(self._y)-self._norm[0])/self._norm[1]
+            self._y_range = 1.0
+            for name, norm in self._parameter_norms.items():
+                par = self._parameters[name]
+                if par.expr is None and norm:
+                    value = par.value/self._norm[1]
+                    _min = par.min
+                    _max = par.max
+                    if not np.isinf(_min) and abs(_min) != float_min:
+                        _min /= self._norm[1]
+                    if not np.isinf(_max) and abs(_max) != float_min:
+                        _max /= self._norm[1]
+                    par.set(value=value, min=_min, max=_max)
+            self._normalized = True
+
+    def _renormalize(self):
+        """Renormalize the data and results
+        """
+        if self._norm is None or not self._normalized:
+            return
+        self._normalized = False
+        for name, norm in self._parameter_norms.items():
+            par = self._parameters[name]
+            if par.expr is None and norm:
+                value = par.value*self._norm[1]
+                _min = par.min
+                _max = par.max
+                if not np.isinf(_min) and abs(_min) != float_min:
+                    _min *= self._norm[1]
+                if not np.isinf(_max) and abs(_max) != float_min:
+                    _max *= self._norm[1]
+                par.set(value=value, min=_min, max=_max)
+        if self._result is None:
+            return
+        self._result.best_fit = self._result.best_fit*self._norm[1]+self._norm[0]
+        for name, par in self._result.params.items():
+            if self._parameter_norms.get(name, False):
+                if par.stderr is not None:
+                    par.stderr *= self._norm[1]
+                if par.expr is None:
+                    _min = par.min
+                    _max = par.max
+                    value = par.value*self._norm[1]
+                    if par.init_value is not None:
+                        par.init_value *= self._norm[1]
+                    if not np.isinf(_min) and abs(_min) != float_min:
+                        _min *= self._norm[1]
+                    if not np.isinf(_max) and abs(_max) != float_min:
+                        _max *= self._norm[1]
+                    par.set(value=value, min=_min, max=_max)
+        if hasattr(self._result, 'init_fit'):
+            self._result.init_fit = self._result.init_fit*self._norm[1]+self._norm[0]
+        if hasattr(self._result, 'init_values'):
+            init_values = {}
+            for name, value in self._result.init_values.items():
+                if name not in self._parameter_norms or self._parameters[name].expr is not None:
+                    init_values[name] = value
+                elif self._parameter_norms[name]:
+                    init_values[name] = value*self._norm[1]
+            self._result.init_values = init_values
+            for name, par in self._result.init_params.items():
+                if par.expr is None and self._parameter_norms.get(name, False):
+                    value = par.value
+                    _min = par.min
+                    _max = par.max
+                    value *= self._norm[1]
+                    if not np.isinf(_min) and abs(_min) != float_min:
+                        _min *= self._norm[1]
+                    if not np.isinf(_max) and abs(_max) != float_min:
+                        _max *= self._norm[1]
+                    par.set(value=value, min=_min, max=_max)
+                par.init_value = par.value
+        # Don't renormalize chisqr, it has no useful meaning in physical units
+        #self._result.chisqr *= self._norm[1]*self._norm[1]
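+        # Each covariance element picks up one factor of self._norm[1] per
+        #   normalized variable, so the diagonal of a normalized parameter is
+        #   scaled by self._norm[1]**2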
+        if self._result.covar is not None:
+            for i, name in enumerate(self._result.var_names):
+                if self._parameter_norms.get(name, False):
+                    for j in range(len(self._result.var_names)):
+                        if self._result.covar[i,j] is not None:
+                            self._result.covar[i,j] *= self._norm[1]
+                        if self._result.covar[j,i] is not None:
+                            self._result.covar[j,i] *= self._norm[1]
+        # Don't renormalize redchi, it has no useful meaning in physical units
+        #self._result.redchi *= self._norm[1]*self._norm[1]
+        if self._result.residual is not None:
+            self._result.residual *= self._norm[1]
+
+    def _reset_par_at_boundary(self, par, value):
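+        """Return a value nudged away from the parameter bounds: with a two-sided
+           bound the value is kept inside the inner 80% of [min, max] (e.g. for
+           min=0, max=10 anything above 0.1*0+0.9*10 = 9 resets to 9); one-sided
+           bounds use a 10% margin based on the data range or the bound itself
+        """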
+        assert(par.vary)
+        name = par.name
+        _min = self._parameter_bounds[name]['min']
+        _max = self._parameter_bounds[name]['max']
+        if np.isinf(_min):
+            if not np.isinf(_max):
+                if self._parameter_norms.get(name, False):
+                    upp = _max-0.1*self._y_range
+                elif _max == 0.0:
+                    upp = _max-0.1
+                else:
+                    upp = _max-0.1*abs(_max)
+                if value >= upp:
+                    return(upp)
+        else:
+            if np.isinf(_max):
+                if self._parameter_norms.get(name, False):
+                    low = _min+0.1*self._y_range
+                elif _min == 0.0:
+                    low = _min+0.1
+                else:
+                    low = _min+0.1*abs(_min)
+                if value <= low:
+                    return(low)
+            else:
+                low = 0.9*_min+0.1*_max
+                upp = 0.1*_min+0.9*_max
+                if value <= low:
+                    return(low)
+                elif value >= upp:
+                    return(upp)
+        return(value)
+
+
+class FitMultipeak(Fit):
+    """Fit data with multiple peaks
+    """
+    def __init__(self, y, x=None, normalize=True):
+        super().__init__(y, x=x, normalize=normalize)
+        self._fwhm_max = None
+        self._sigma_max = None
+
+    @classmethod
+    def fit_multipeak(cls, y, centers, x=None, normalize=True, peak_models='gaussian',
+            center_exprs=None, fit_type=None, background_order=None, background_exp=False,
+            fwhm_max=None, plot_components=False):
+        """Make sure that centers and fwhm_max are in the correct units and consistent with expr
+           for a uniform fit (fit_type == 'uniform')
+        """
+        fit = cls(y, x=x, normalize=normalize)
+        success = fit.fit(centers, fit_type=fit_type, peak_models=peak_models, fwhm_max=fwhm_max,
+                center_exprs=center_exprs, background_order=background_order,
+                background_exp=background_exp, plot_components=plot_components)
+        if success:
+            return(fit.best_fit, fit.residual, fit.best_values, fit.best_errors, fit.redchi,
+                    fit.success)
+        else:
+            return(np.array([]), np.array([]), {}, {}, float_max, False)
+
+    def fit(self, centers, fit_type=None, peak_models=None, center_exprs=None, fwhm_max=None,
+                background_order=None, background_exp=False, plot_components=False,
+                param_constraint=False):
+        self._fwhm_max = fwhm_max
+        # Create the multipeak model
+        self._create_model(centers, fit_type, peak_models, center_exprs, background_order,
+                background_exp, param_constraint)
+
+        # RV: Obsolete Normalize the data and results
+#        print('\nBefore fit before normalization in FitMultipeak:')
+#        self._parameters.pretty_print()
+#        self._normalize()
+#        print('\nBefore fit after normalization in FitMultipeak:')
+#        self._parameters.pretty_print()
+
+        # Perform the fit
+        try:
+            if param_constraint:
+                super().fit(fit_kws={'xtol': 1.e-5, 'ftol': 1.e-5, 'gtol': 1.e-5})
+            else:
+                super().fit()
+        except:
+            return(False)
+
+        # Check for valid fit parameter results
+        fit_failure = self._check_validity()
+        success = True
+        if fit_failure:
+            if param_constraint:
+                logging.warning('  -> Should not happen with param_constraint set, fail the fit')
+                success = False
+            else:
+                logging.info('  -> Retry fitting with constraints')
+                success = self.fit(centers, fit_type, peak_models, center_exprs,
+                        fwhm_max=fwhm_max, background_order=background_order,
+                        background_exp=background_exp, plot_components=plot_components,
+                        param_constraint=True)
+        else:
+            # RV: Obsolete Renormalize the data and results
+#            print('\nAfter fit before renormalization in FitMultipeak:')
+#            self._parameters.pretty_print()
+#            self.print_fit_report()
+#            self._renormalize()
+#            print('\nAfter fit after renormalization in FitMultipeak:')
+#            self._parameters.pretty_print()
+#            self.print_fit_report()
+
+            # Print report and plot components if requested
+            if plot_components:
+                self.print_fit_report()
+                self.plot()
+
+        return(success)
+
+    def _create_model(self, centers, fit_type=None, peak_models=None, center_exprs=None,
+                background_order=None, background_exp=False, param_constraint=False):
+        """Create the multipeak model
+        """
+        if isinstance(centers, (int, float)):
+            centers = [centers]
+        num_peaks = len(centers)
+        if peak_models is None:
+            peak_models = num_peaks*['gaussian']
+        elif isinstance(peak_models, str):
+            peak_models = num_peaks*[peak_models]
+        if len(peak_models) != num_peaks:
+            raise ValueError(f'Inconsistent number of peaks in peak_models ({len(peak_models)} vs '+
+                    f'{num_peaks})')
+        if num_peaks == 1:
+            if fit_type is not None:
+                logging.debug('Ignoring fit_type input for fitting one peak')
+            fit_type = None
+            if center_exprs is not None:
+                logging.debug('Ignoring center_exprs input for fitting one peak')
+                center_exprs = None
+        else:
+            if fit_type == 'uniform':
+                if center_exprs is None:
+                    center_exprs = [f'scale_factor*{cen}' for cen in centers]
+                if len(center_exprs) != num_peaks:
+                    raise ValueError(f'Inconsistent number of peaks in center_exprs '+
+                            f'({len(center_exprs)} vs {num_peaks})')
+            elif fit_type == 'unconstrained' or fit_type is None:
+                if center_exprs is not None:
+                    logging.warning('Ignoring center_exprs input for unconstrained fit')
+                    center_exprs = None
+            else:
+                raise ValueError(f'Invalid parameter fit_type ({fit_type})')
+        self._sigma_max = None
+        if param_constraint:
+            min_value = float_min
+            if self._fwhm_max is not None:
+                self._sigma_max = np.zeros(num_peaks)
+        else:
+            min_value = None
+
+        # Reset the fit
+        self._model = None
+        self._parameters = Parameters()
+        self._result = None
+
+        # Add background model
+        if background_order is not None:
+            if background_order == 0:
+                self.add_model('constant', prefix='background', parameters=
+                        {'name': 'c', 'value': float_min, 'min': min_value})
+            elif background_order == 1:
+                self.add_model('linear', prefix='background', slope=0.0, intercept=0.0)
+            elif background_order == 2:
+                self.add_model('quadratic', prefix='background', a=0.0, b=0.0, c=0.0)
+            else:
+                raise ValueError(f'Invalid parameter background_order ({background_order})')
+        if background_exp:
+            self.add_model('exponential', prefix='background', parameters=(
+                        {'name': 'amplitude', 'value': float_min, 'min': min_value},
+                        {'name': 'decay', 'value': float_min, 'min': min_value}))
+
+        # Add peaks and guess initial fit parameters
+        ast = Interpreter()
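+        # The fwhm_factor/height_factor expressions convert the guessed fwhm and
+        #   height to the native lmfit sigma and amplitude, e.g. for a gaussian
+        #   sigma = fwhm/(2*sqrt(2*log(2))) ~= 0.4247*fwhm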
+        if num_peaks == 1:
+            height_init, cen_init, fwhm_init = self.guess_init_peak(self._x, self._y)
+            if self._fwhm_max is not None and fwhm_init > self._fwhm_max:
+                fwhm_init = self._fwhm_max
+            ast(f'fwhm = {fwhm_init}')
+            ast(f'height = {height_init}')
+            sig_init = ast(fwhm_factor[peak_models[0]])
+            amp_init = ast(height_factor[peak_models[0]])
+            sig_max = None
+            if self._sigma_max is not None:
+                ast(f'fwhm = {self._fwhm_max}')
+                sig_max = ast(fwhm_factor[peak_models[0]])
+                self._sigma_max[0] = sig_max
+            self.add_model(peak_models[0], parameters=(
+                    {'name': 'amplitude', 'value': amp_init, 'min': min_value},
+                    {'name': 'center', 'value': cen_init, 'min': min_value},
+                    {'name': 'sigma', 'value': sig_init, 'min': min_value, 'max': sig_max}))
+        else:
+            if fit_type == 'uniform':
+                self.add_parameter(name='scale_factor', value=1.0)
+            for i in range(num_peaks):
+                height_init, cen_init, fwhm_init = self.guess_init_peak(self._x, self._y, i,
+                        center_guess=centers)
+                if self._fwhm_max is not None and fwhm_init > self._fwhm_max:
+                    fwhm_init = self._fwhm_max
+                ast(f'fwhm = {fwhm_init}')
+                ast(f'height = {height_init}')
+                sig_init = ast(fwhm_factor[peak_models[i]])
+                amp_init = ast(height_factor[peak_models[i]])
+                sig_max = None
+                if self._sigma_max is not None:
+                    ast(f'fwhm = {self._fwhm_max}')
+                    sig_max = ast(fwhm_factor[peak_models[i]])
+                    self._sigma_max[i] = sig_max
+                if fit_type == 'uniform':
+                    self.add_model(peak_models[i], prefix=f'peak{i+1}_', parameters=(
+                            {'name': 'amplitude', 'value': amp_init, 'min': min_value},
+                            {'name': 'center', 'expr': center_exprs[i]},
+                            {'name': 'sigma', 'value': sig_init, 'min': min_value,
+                            'max': sig_max}))
+                else:
+                    self.add_model(peak_models[i], prefix=f'peak{i+1}_', parameters=(
+                            {'name': 'amplitude', 'value': amp_init, 'min': min_value},
+                            {'name': 'center', 'value': cen_init, 'min': min_value},
+                            {'name': 'sigma', 'value': sig_init, 'min': min_value,
+                            'max': sig_max}))
+
+    def _check_validity(self):
+        """Check for valid fit parameter results
+        """
+        fit_failure = False
+        index = compile(r'\d+')
+        for name, par in self.best_parameters.items():
+            if 'background' in name:
+#                if ((name == 'backgroundc' and par['value'] <= 0.0) or
+#                        (name.endswith('amplitude') and par['value'] <= 0.0) or
+                if ((name.endswith('amplitude') and par['value'] <= 0.0) or
+                        (name.endswith('decay') and par['value'] <= 0.0)):
+                    logging.info(f'Invalid fit result for {name} ({par["value"]})')
+                    fit_failure = True
+            elif (((name.endswith('amplitude') or name.endswith('height')) and
+                    par['value'] <= 0.0) or
+                    ((name.endswith('sigma') or name.endswith('fwhm')) and par['value'] <= 0.0) or
+                    (name.endswith('center') and par['value'] <= 0.0) or
+                    (name == 'scale_factor' and par['value'] <= 0.0)):
+                logging.info(f'Invalid fit result for {name} ({par["value"]})')
+                fit_failure = True
+            if name.endswith('sigma') and self._sigma_max is not None:
+                if name == 'sigma':
+                    sigma_max = self._sigma_max[0]
+                else:
+                    sigma_max = self._sigma_max[int(index.search(name).group())-1]
+                if par['value'] > sigma_max:
+                    logging.info(f'Invalid fit result for {name} ({par["value"]})')
+                    fit_failure = True
+                elif par['value'] == sigma_max:
+                    logging.warning(f'Edge result for {name} ({par["value"]})')
+            if name.endswith('fwhm') and self._fwhm_max is not None:
+                if par['value'] > self._fwhm_max:
+                    logging.info(f'Invalid fit result for {name} ({par["value"]})')
+                    fit_failure = True
+                elif par['value'] == self._fwhm_max:
+                    logging.warning(f'Edge result for {name} ({par["value"]})')
+        return(fit_failure)
+
+
+class FitMap(Fit):
+    """Fit a map of data
+    """
+    def __init__(self, ymap, x=None, models=None, normalize=True, transpose=None, **kwargs):
+        super().__init__(None)
+        self._best_errors = None
+        self._best_fit = None
+        self._best_parameters = None
+        self._best_values = None
+        self._inv_transpose = None
+        self._max_nfev = None
+        self._memfolder = None
+        self._new_parameters = None
+        self._out_of_bounds = None
+        self._plot = False
+        self._print_report = False
+        self._redchi = None
+        self._redchi_cutoff = 0.1
+        self._skip_init = True
+        self._success = None
+        self._transpose = None
+        self._try_no_bounds = True
+
+        # At this point the fastest index should always be the signal dimension so that the slowest 
+        #    ndim-1 dimensions are the map dimensions
+        if isinstance(ymap, (tuple, list, np.ndarray)):
+            self._x = np.asarray(x)
+        elif have_xarray and isinstance(ymap, xr.DataArray):
+            if x is not None:
+                logging.warning(f'Ignoring superfluous input x ({x}) in FitMap.__init__')
+            self._x = np.asarray(ymap[ymap.dims[-1]])
+        else:
+            illegal_value(ymap, 'ymap', 'FitMap:__init__', raise_error=True)
+        self._ymap = ymap
+
+        # Verify the input parameters
+        if self._x.ndim != 1:
+            raise ValueError(f'Invalid dimension for input x {self._x.ndim}')
+        if self._ymap.ndim < 2:
+            raise ValueError('Invalid number of dimensions of the input dataset '+
+                    f'{self._ymap.ndim}')
+        if self._x.size != self._ymap.shape[-1]:
+            raise ValueError(f'Inconsistent x and y dimensions ({self._x.size} vs '+
+                    f'{self._ymap.shape[-1]})')
+        if not isinstance(normalize, bool):
+            logging.warning(f'Invalid value for normalize ({normalize}) in FitMap.__init__: '+
+                    'setting normalize to True')
+            normalize = True
+        if isinstance(transpose, bool) and not transpose:
+            transpose = None
+        if transpose is not None and self._ymap.ndim < 3:
+            logging.warning(f'Transpose meaningless for {self._ymap.ndim-1}D data maps: ignoring '+
+                    'transpose')
+            transpose = None
+        if transpose is not None:
+            if self._ymap.ndim == 3 and isinstance(transpose, bool) and transpose:
+                self._transpose = (1, 0)
+            elif not isinstance(transpose, (tuple, list)):
+                logging.warning(f'Invalid data type for transpose ({transpose}, '+
+                        f'{type(transpose)}) in FitMap.__init__: ignoring transpose')
+            elif len(transpose) != self._ymap.ndim-1:
+                logging.warning(f'Invalid dimension for transpose ({transpose}, must be equal to '+
+                        f'{self._ymap.ndim-1}) in FitMap.__init__: ignoring transpose')
+            elif any(i not in transpose for i in range(len(transpose))):
+                logging.warning(f'Invalid index in transpose ({transpose}) '+
+                        'in FitMap.__init__: ignoring transpose')
+            elif not all(i == transpose[i] for i in range(self._ymap.ndim-1)):
+                self._transpose = transpose
+                self._transpose = transpose
+            if self._transpose is not None:
+                self._inv_transpose = tuple(self._transpose.index(i)
+                        for i in range(len(self._transpose)))
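+                # e.g. transpose = (1, 2, 0) gives inv_transpose = (2, 0, 1), the
+                #   permutation that restores the original axis order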
+
+        # Flatten the map (transpose if requested)
+        # Store the flattened map in self._ymap_norm, whether normalized or not
+        if self._transpose is not None:
+            self._ymap_norm = np.transpose(np.asarray(self._ymap), list(self._transpose)+
+                    [len(self._transpose)])
+        else:
+            self._ymap_norm = np.asarray(self._ymap)
+        self._map_dim = int(self._ymap_norm.size/self._x.size)
+        self._map_shape = self._ymap_norm.shape[:-1]
+        self._ymap_norm = np.reshape(self._ymap_norm, (self._map_dim, self._x.size))
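+        # e.g. a (5, 4, N) ymap with N = x.size gives map_shape = (5, 4) and a
+        #   flattened (20, N) working array with map_dim = 20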
+
+        # Check if a mask is provided
+        if 'mask' in kwargs:
+            self._mask = kwargs.pop('mask')
+        if self._mask is None:
+            ymap_min = float(self._ymap_norm.min())
+            ymap_max = float(self._ymap_norm.max())
+        else:
+            self._mask = np.asarray(self._mask).astype(bool)
+            if self._x.size != self._mask.size:
+                raise ValueError(f'Inconsistent mask dimension ({self._x.size} vs '+
+                        f'{self._mask.size})')
+            ymap_masked = np.asarray(self._ymap_norm)[:,~self._mask]
+            ymap_min = float(ymap_masked.min())
+            ymap_max = float(ymap_masked.max())
+
+        # Normalize the data
+        self._y_range = ymap_max-ymap_min
+        if normalize and self._y_range > 0.0:
+            self._norm = (ymap_min, self._y_range)
+            self._ymap_norm = (self._ymap_norm-self._norm[0])/self._norm[1]
+        else:
+            self._redchi_cutoff *= self._y_range**2
+        if models is not None:
+            if callable(models) or isinstance(models, str):
+                kwargs = self.add_model(models, **kwargs)
+            elif isinstance(models, (tuple, list)):
+                for model in models:
+                    kwargs = self.add_model(model, **kwargs)
+            self.fit(**kwargs)
+
+    @classmethod
+    def fit_map(cls, ymap, models, x=None, normalize=True, **kwargs):
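+        """Convenience wrapper that creates a FitMap and runs the fit; a minimal
+           usage sketch (the model name and data are illustrative only):
+
+               fit = FitMap.fit_map(ymap, 'gaussian', x=x)
+               values, errors = fit.best_values, fit.best_errors
+        """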
+        return(cls(ymap, x=x, models=models, normalize=normalize, **kwargs))
+
+    @property
+    def best_errors(self):
+        return(self._best_errors)
+
+    @property
+    def best_fit(self):
+        return(self._best_fit)
+
+    @property
+    def best_results(self):
+        """Convert the input data array to a data set and add the fit results.
+        """
+        if self.best_values is None or self.best_errors is None or self.best_fit is None:
+            return(None)
+        if not have_xarray:
+            logging.warning('Unable to load xarray module')
+            return(None)
+        best_values = self.best_values
+        best_errors = self.best_errors
+        if isinstance(self._ymap, xr.DataArray):
+            best_results = self._ymap.to_dataset()
+            dims = self._ymap.dims
+            fit_name = f'{self._ymap.name}_fit'
+        else:
+            coords = {f'dim{n}_index':([f'dim{n}_index'], range(self._ymap.shape[n]))
+                    for n in range(self._ymap.ndim-1)}
+            coords['x'] = (['x'], self._x)
+            dims = list(coords.keys())
+            best_results = xr.Dataset(coords=coords)
+            best_results['y'] = (dims, self._ymap)
+            fit_name = 'y_fit'
+        best_results[fit_name] = (dims, self.best_fit)
+        if self._mask is not None:
+            best_results['mask'] = self._mask
+        for n in range(best_values.shape[0]):
+            best_results[f'{self._best_parameters[n]}_values'] = (dims[:-1], best_values[n])
+            best_results[f'{self._best_parameters[n]}_errors'] = (dims[:-1], best_errors[n])
+        best_results.attrs['components'] = self.components
+        return(best_results)
+
+    @property
+    def best_values(self):
+        return(self._best_values)
+
+    @property
+    def chisqr(self):
+        logging.warning('property chisqr not defined for fit.FitMap')
+        return(None)
+
+    @property
+    def components(self):
+        components = {}
+        if self._result is None:
+            logging.warning('Unable to collect components in FitMap.components')
+            return(components)
+        for component in self._result.components:
+            if 'tmp_normalization_offset_c' in component.param_names:
+                continue
+            parameters = {}
+            for name in component.param_names:
+                if self._parameters[name].vary:
+                    parameters[name] = {'free': True}
+                elif self._parameters[name].expr is not None:
+                    parameters[name] = {'free': False, 'expr': self._parameters[name].expr}
+                else:
+                    parameters[name] = {'free': False, 'value': self.init_parameters[name]['value']}
+            expr = None
+            if isinstance(component, ExpressionModel):
+                name = component._name
+                if name[-1] == '_':
+                    name = name[:-1]
+                expr = component.expr
+            else:
+                prefix = component.prefix
+                if len(prefix):
+                    if prefix[-1] == '_':
+                        prefix = prefix[:-1]
+                    name = f'{prefix} ({component._name})'
+                else:
+                    name = f'{component._name}'
+            if expr is None:
+                components[name] = {'parameters': parameters}
+            else:
+                components[name] = {'expr': expr, 'parameters': parameters}
+        return(components)
+
+    @property
+    def covar(self):
+        logging.warning('property covar not defined for fit.FitMap')
+        return(None)
+
+    @property
+    def max_nfev(self):
+        return(self._max_nfev)
+
+    @property
+    def num_func_eval(self):
+        logging.warning('property num_func_eval not defined for fit.FitMap')
+        return(None)
+
+    @property
+    def out_of_bounds(self):
+        return(self._out_of_bounds)
+
+    @property
+    def redchi(self):
+        return(self._redchi)
+
+    @property
+    def residual(self):
+        if self.best_fit is None:
+            return(None)
+        if self._mask is None:
+            return(np.asarray(self._ymap)-self.best_fit)
+        else:
+            ymap_flat = np.reshape(np.asarray(self._ymap), (self._map_dim, self._x.size))
+            ymap_flat_masked = ymap_flat[:,~self._mask]
+            ymap_masked = np.reshape(ymap_flat_masked,
+                    list(self._map_shape)+[ymap_flat_masked.shape[-1]])
+            return(ymap_masked-self.best_fit)
+
+    @property
+    def success(self):
+        return(self._success)
+
+    @property
+    def var_names(self):
+        logging.warning('property var_names not defined for fit.FitMap')
+        return(None)
+
+    @property
+    def y(self):
+        logging.warning('property y not defined for fit.FitMap')
+        return(None)
+
+    @property
+    def ymap(self):
+        return(self._ymap)
+
+    def best_parameters(self, dims=None):
+        if dims is None:
+            return(self._best_parameters)
+        if not isinstance(dims, (list, tuple)) or len(dims) != len(self._map_shape):
+            illegal_value(dims, 'dims', 'FitMap.best_parameters', raise_error=True)
+        if self.best_values is None or self.best_errors is None:
+            logging.warning(f'Unable to obtain best parameter values for dims = {dims} in '+
+                    'FitMap.best_parameters')
+            return({})
+        # Create current parameters
+        parameters = deepcopy(self._parameters)
+        for n, name in enumerate(self._best_parameters):
+            if self._parameters[name].vary:
+                parameters[name].set(value=self.best_values[n][dims])
+            parameters[name].stderr = self.best_errors[n][dims]
+        parameters_dict = {}
+        for name in sorted(parameters):
+            if name != 'tmp_normalization_offset_c':
+                par = parameters[name]
+                parameters_dict[name] = {'value': par.value, 'error': par.stderr,
+                        'init_value': self.init_parameters[name]['value'], 'min': par.min,
+                        'max': par.max, 'vary': par.vary, 'expr': par.expr}
+        return(parameters_dict)
+
+    def freemem(self):
+        if self._memfolder is None:
+            return
+        try:
+            rmtree(self._memfolder)
+            self._memfolder = None
+        except:
+            logging.warning('Could not clean up the memmap folder automatically.')
+
+    def plot(self, dims, y_title=None, plot_residual=False, plot_comp_legends=False,
+            plot_masked_data=True):
+        if not isinstance(dims, (list, tuple)) or len(dims) != len(self._map_shape):
+            illegal_value(dims, 'dims', 'FitMap.plot', raise_error=True)
+        if self._result is None or self.best_fit is None or self.best_values is None:
+            logging.warning(f'Unable to plot fit for dims = {dims} in FitMap.plot')
+            return
+        if y_title is None or not isinstance(y_title, str):
+            y_title = 'data'
+        if self._mask is None:
+            mask = np.zeros(self._x.size, dtype=bool)
+            plot_masked_data = False
+        else:
+            mask = self._mask
+        plots = [(self._x, np.asarray(self._ymap[dims]), 'b.')]
+        legend = [y_title]
+        if plot_masked_data:
+            plots += [(self._x[mask], np.asarray(self._ymap)[(*dims,mask)], 'bx')]
+            legend += ['masked data']
+        plots += [(self._x[~mask], self.best_fit[dims], 'k-')]
+        legend += ['best fit']
+        if plot_residual:
+            plots += [(self._x[~mask], self.residual[dims], 'k--')]
+            legend += ['residual']
+        # Create current parameters
+        parameters = deepcopy(self._parameters)
+        for name in self._best_parameters:
+            if self._parameters[name].vary:
+                parameters[name].set(value=
+                        self.best_values[self._best_parameters.index(name)][dims])
+        for component in self._result.components:
+            if 'tmp_normalization_offset_c' in component.param_names:
+                continue
+            if isinstance(component, ExpressionModel):
+                prefix = component._name
+                if prefix[-1] == '_':
+                    prefix = prefix[:-1]
+                modelname = f'{prefix}: {component.expr}'
+            else:
+                prefix = component.prefix
+                if len(prefix):
+                    if prefix[-1] == '_':
+                        prefix = prefix[:-1]
+                    modelname = f'{prefix} ({component._name})'
+                else:
+                    modelname = f'{component._name}'
+            if len(modelname) > 20:
+                modelname = f'{modelname[0:16]} ...'
+            y = component.eval(params=parameters, x=self._x[~mask])
+            if isinstance(y, (int, float)):
+                y *= np.ones(self._x[~mask].size)
+            plots += [(self._x[~mask], y, '--')]
+            if plot_comp_legends:
+                legend.append(modelname)
+        quick_plot(tuple(plots), legend=legend, title=str(dims), block=True)
+
+    def fit(self, **kwargs):
+#        t0 = time()
+        # Check input parameters
+        if self._model is None:
+            logging.error('Undefined fit model')
+            return
+        if 'num_proc' in kwargs:
+            num_proc = kwargs.pop('num_proc')
+            if not is_int(num_proc, ge=1):
+                illegal_value(num_proc, 'num_proc', 'FitMap.fit', raise_error=True)
+        else:
+            num_proc = cpu_count()
+        if num_proc > 1 and not have_joblib:
+            logging.warning('Missing joblib in the conda environment, running FitMap serially')
+            num_proc = 1
+        if num_proc > cpu_count():
+            logging.warning(f'The requested number of processors ({num_proc}) exceeds the maximum '+
+                    f'number of processors, num_proc reduced to ({cpu_count()})')
+            num_proc = cpu_count()
+        if 'try_no_bounds' in kwargs:
+            self._try_no_bounds = kwargs.pop('try_no_bounds')
+            if not isinstance(self._try_no_bounds, bool):
+                illegal_value(self._try_no_bounds, 'try_no_bounds', 'FitMap.fit', raise_error=True)
+        if 'redchi_cutoff' in kwargs:
+            self._redchi_cutoff = kwargs.pop('redchi_cutoff')
+            if not isinstance(self._redchi_cutoff, (int, float)) or self._redchi_cutoff <= 0:
+                illegal_value(self._redchi_cutoff, 'redchi_cutoff', 'FitMap.fit', raise_error=True)
+        if 'print_report' in kwargs:
+            self._print_report = kwargs.pop('print_report')
+            if not isinstance(self._print_report, bool):
+                illegal_value(self._print_report, 'print_report', 'FitMap.fit', raise_error=True)
+        if 'plot' in kwargs:
+            self._plot = kwargs.pop('plot')
+            if not isinstance(self._plot, bool):
+                illegal_value(self._plot, 'plot', 'FitMap.fit', raise_error=True)
+        if 'skip_init' in kwargs:
+            self._skip_init = kwargs.pop('skip_init')
+            if not isinstance(self._skip_init, bool):
+                illegal_value(self._skip_init, 'skip_init', 'FitMap.fit', raise_error=True)
+
+        # Apply mask if supplied:
+        if 'mask' in kwargs:
+            self._mask = kwargs.pop('mask')
+        if self._mask is not None:
+            self._mask = np.asarray(self._mask).astype(bool)
+            if self._x.size != self._mask.size:
+                raise ValueError(f'Inconsistent x and mask dimensions ({self._x.size} vs '+
+                        f'{self._mask.size})')
+
+        # Add constant offset for a normalized single component model
+        if self._result is None and self._norm is not None and self._norm[0]:
+            self.add_model('constant', prefix='tmp_normalization_offset_', parameters={'name': 'c',
+                    'value': -self._norm[0], 'vary': False, 'norm': True})
+                    #'value': -self._norm[0]/self._norm[1], 'vary': False, 'norm': False})
+
+        # Adjust existing parameters for refit:
+        if 'parameters' in kwargs:
+#            print('\nIn FitMap before adjusting existing parameters for refit:')
+#            self._parameters.pretty_print()
+#            if self._result is None:
+#                raise ValueError('Invalid parameter parameters ({parameters})')
+#            if self._best_values is None:
+#                raise ValueError('Valid self._best_values required for refitting in FitMap.fit')
+            parameters = kwargs.pop('parameters')
+#            print(f'\nparameters:\n{parameters}')
+            if isinstance(parameters, dict):
+                parameters = (parameters, )
+            elif not is_dict_series(parameters):
+                illegal_value(parameters, 'parameters', 'Fit.fit', raise_error=True)
+            for par in parameters:
+                name = par['name']
+                if name not in self._parameters:
+                    raise ValueError(f'Unable to match {name} parameter {par} to an existing one')
+                if self._parameters[name].expr is not None:
+                    raise ValueError(f'Unable to modify {name} parameter {par} (currently an '+
+                            'expression)')
+                value = par.get('value')
+                vary = par.get('vary')
+                if par.get('expr') is not None:
+                    raise KeyError(f'Illegal "expr" key in {name} parameter {par}')
+                self._parameters[name].set(value=value, vary=vary, min=par.get('min'),
+                        max=par.get('max'))
+                # Overwrite existing best values for fixed parameters when a value is specified
+#                print(f'best values befored resetting:\n{self._best_values}')
+                if isinstance(value, (int, float)) and vary is False:
+                    for i, nname in enumerate(self._best_parameters):
+                        if nname == name:
+                            self._best_values[i] = value
+#                print(f'best values after resetting (value={value}, vary={vary}):\n{self._best_values}')
+#RV            print('\nIn FitMap after adjusting existing parameters for refit:')
+#RV            self._parameters.pretty_print()
+
+        # Check for uninitialized parameters
+        for name, par in self._parameters.items():
+            if par.expr is None:
+                value = par.value
+                if value is None or np.isinf(value) or np.isnan(value):
+                    value = 1.0
+                    if self._norm is None or name not in self._parameter_norms:
+                        self._parameters[name].set(value=value)
+                    elif self._parameter_norms[name]:
+                        self._parameters[name].set(value=value*self._norm[1])
+
+        # Create the best parameter list, consisting of all varying parameters plus the expression
+        #   parameters in order to collect their errors
+        if self._result is None:
+            # Initial fit
+            assert(self._best_parameters is None)
+            self._best_parameters = [name for name, par in self._parameters.items()
+                    if par.vary or par.expr is not None]
+            num_new_parameters = 0
+        else:
+            # Refit
+            assert(len(self._best_parameters))
+            self._new_parameters = [name for name, par in self._parameters.items()
+                if name != 'tmp_normalization_offset_c' and name not in self._best_parameters and
+                    (par.vary or par.expr is not None)]
+            num_new_parameters = len(self._new_parameters)
+        num_best_parameters = len(self._best_parameters)
+
+        # Flatten and normalize the best values of the previous fit, remove the remaining results
+        #   of the previous fit
+        if self._result is not None:
+#            print('\nBefore flatten and normalize:')
+#            print(f'self._best_values:\n{self._best_values}')
+            self._out_of_bounds = None
+            self._max_nfev = None
+            self._redchi = None
+            self._success = None
+            self._best_fit = None
+            self._best_errors = None
+            assert(self._best_values is not None)
+            assert(self._best_values.shape[0] == num_best_parameters)
+            assert(self._best_values.shape[1:] == self._map_shape)
+            if self._transpose is not None:
+                self._best_values = np.transpose(self._best_values,
+                        [0]+[i+1 for i in self._transpose])
+            self._best_values = [np.reshape(self._best_values[i], self._map_dim)
+                for i in range(num_best_parameters)]
+            if self._norm is not None:
+                for i, name in enumerate(self._best_parameters):
+                    if self._parameter_norms.get(name, False):
+                        self._best_values[i] /= self._norm[1]
+#RV            print('\nAfter flatten and normalize:')
+#RV            print(f'self._best_values:\n{self._best_values}')
+
+        # Normalize the initial parameters (and best values for a refit)
+#        print('\nIn FitMap before normalize:')
+#        self._parameters.pretty_print()
+#        print(f'\nparameter_norms:\n{self._parameter_norms}\n')
+        self._normalize()
+#        print('\nIn FitMap after normalize:')
+#        self._parameters.pretty_print()
+#        print(f'\nparameter_norms:\n{self._parameter_norms}\n')
+
+        # Prevent initial values from sitting at boundaries
+        self._parameter_bounds = {name:{'min': par.min, 'max': par.max}
+                for name, par in self._parameters.items() if par.vary}
+        for name, par in self._parameters.items():
+            if par.vary:
+                par.set(value=self._reset_par_at_boundary(par, par.value))
+#        print('\nAfter checking boundaries:')
+#        self._parameters.pretty_print()
+
+        # Set parameter bounds to unbound (only use bounds when fit fails)
+        if self._try_no_bounds:
+            for name in self._parameter_bounds.keys():
+                self._parameters[name].set(min=-np.inf, max=np.inf)
+
+        # Allocate memory to store fit results
+        if self._mask is None:
+            x_size = self._x.size
+        else:
+            x_size = self._x[~self._mask].size
+        if num_proc == 1:
+            self._out_of_bounds_flat = np.zeros(self._map_dim, dtype=bool)
+            self._max_nfev_flat = np.zeros(self._map_dim, dtype=bool)
+            self._redchi_flat = np.zeros(self._map_dim, dtype=np.float64)
+            self._success_flat = np.zeros(self._map_dim, dtype=bool)
+            self._best_fit_flat = np.zeros((self._map_dim, x_size),
+                    dtype=self._ymap_norm.dtype)
+            self._best_errors_flat = [np.zeros(self._map_dim, dtype=np.float64)
+                    for _ in range(num_best_parameters+num_new_parameters)]
+            if self._result is None:
+                self._best_values_flat = [np.zeros(self._map_dim, dtype=np.float64)
+                        for _ in range(num_best_parameters)]
+            else:
+                self._best_values_flat = self._best_values
+                self._best_values_flat += [np.zeros(self._map_dim, dtype=np.float64)
+                        for _ in range(num_new_parameters)]
+        else:
+            self._memfolder = './joblib_memmap'
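+            # Scratch folder for np.memmap arrays, shared with the joblib
+            #   workers so each fit writes its results in place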
+            try:
+                mkdir(self._memfolder)
+            except FileExistsError:
+                pass
+            filename_memmap = path.join(self._memfolder, 'out_of_bounds_memmap')
+            self._out_of_bounds_flat = np.memmap(filename_memmap, dtype=bool,
+                    shape=(self._map_dim), mode='w+')
+            filename_memmap = path.join(self._memfolder, 'max_nfev_memmap')
+            self._max_nfev_flat = np.memmap(filename_memmap, dtype=bool,
+                    shape=(self._map_dim), mode='w+')
+            filename_memmap = path.join(self._memfolder, 'redchi_memmap')
+            self._redchi_flat = np.memmap(filename_memmap, dtype=np.float64,
+                    shape=(self._map_dim), mode='w+')
+            filename_memmap = path.join(self._memfolder, 'success_memmap')
+            self._success_flat = np.memmap(filename_memmap, dtype=bool,
+                    shape=(self._map_dim), mode='w+')
+            filename_memmap = path.join(self._memfolder, 'best_fit_memmap')
+            self._best_fit_flat = np.memmap(filename_memmap, dtype=self._ymap_norm.dtype,
+                    shape=(self._map_dim, x_size), mode='w+')
+            self._best_errors_flat = []
+            for i in range(num_best_parameters+num_new_parameters):
+                filename_memmap = path.join(self._memfolder, f'best_errors_memmap_{i}')
+                self._best_errors_flat.append(np.memmap(filename_memmap, dtype=np.float64,
+                        shape=self._map_dim, mode='w+'))
+            self._best_values_flat = []
+            for i in range(num_best_parameters):
+                filename_memmap = path.join(self._memfolder, f'best_values_memmap_{i}')
+                self._best_values_flat.append(np.memmap(filename_memmap, dtype=np.float64,
+                        shape=self._map_dim, mode='w+'))
+                if self._result is not None:
+                    self._best_values_flat[i][:] = self._best_values[i][:]
+            for i in range(num_new_parameters):
+                filename_memmap = path.join(self._memfolder,
+                        f'best_values_memmap_{i+num_best_parameters}')
+                self._best_values_flat.append(np.memmap(filename_memmap, dtype=np.float64,
+                        shape=self._map_dim, mode='w+'))
+
+        # Update the best parameter list
+        if num_new_parameters:
+            self._best_parameters += self._new_parameters
+
+        # Perform the first fit to get model component info and initial parameters
+        current_best_values = {}
+#        print(f'0 before:\n{current_best_values}')
+#        t1 = time()
+        self._result = self._fit(0, current_best_values, return_result=True, **kwargs)
+#        t2 = time()
+#        print(f'0 after:\n{current_best_values}')
+#        print('\nAfter the first fit:')
+#        self._parameters.pretty_print()
+#        print(self._result.fit_report(show_correl=False))
+
+        # Remove all irrelevant content from self._result
+        for attr in ('_abort', 'aborted', 'aic', 'best_fit', 'best_values', 'bic', 'calc_covar',
+                'call_kws', 'chisqr', 'ci_out', 'col_deriv', 'covar', 'data', 'errorbars',
+                'flatchain', 'ier', 'init_vals', 'init_fit', 'iter_cb', 'jacfcn', 'kws',
+                'last_internal_values', 'lmdif_message', 'message', 'method', 'nan_policy',
+                'ndata', 'nfev', 'nfree', 'params', 'redchi', 'reduce_fcn', 'residual', 'result',
+                'scale_covar', 'show_candidates', 'success', 'userargs', 'userfcn',
+                'userkws', 'values', 'var_names', 'weights', 'user_options'):
+            try:
+                delattr(self._result, attr)
+            except AttributeError:
+#                logging.warning(f'Unknown attribute {attr} in fit.FtMap._cleanup_result')
+                pass
+
+#        t3 = time()
+        if num_proc == 1:
+            # Perform the remaining fits serially
+            for n in range(1, self._map_dim):
+#                print(f'{n} before:\n{current_best_values}')
+                self._fit(n, current_best_values, **kwargs)
+#                print(f'{n} after:\n{current_best_values}')
+        else:
+            # Perform the remaining fits in parallel
+            num_fit = self._map_dim-1
+#            print(f'num_fit = {num_fit}')
+            if num_proc > num_fit:
+                logging.warning(f'The requested number of processors ({num_proc}) exceeds the '+
+                        f'number of fits; num_proc reduced to {num_fit}')
+                num_proc = num_fit
+                num_fit_per_proc = 1
+            else:
+                num_fit_per_proc = round(num_fit/num_proc)
+                if num_proc*num_fit_per_proc < num_fit:
+                    num_fit_per_proc += 1
+#            print(f'num_fit_per_proc = {num_fit_per_proc}')
+            num_fit_batch = min(num_fit_per_proc, 40)
+#            print(f'num_fit_batch = {num_fit_batch}')
+            with Parallel(n_jobs=num_proc) as parallel:
+                parallel(delayed(self._fit_parallel)(current_best_values, num_fit_batch,
+                        n_start, **kwargs) for n_start in range(1, self._map_dim, num_fit_batch))
+#        t4 = time()
+
+        # Renormalize the initial parameters for external use
+        if self._norm is not None and self._normalized:
+            init_values = {}
+            for name, value in self._result.init_values.items():
+                if name not in self._parameter_norms or self._parameters[name].expr is not None:
+                    init_values[name] = value
+                elif self._parameter_norms[name]:
+                    init_values[name] = value*self._norm[1]
+            self._result.init_values = init_values
+            for name, par in self._result.init_params.items():
+                if par.expr is None and self._parameter_norms.get(name, False):
+                    _min = par.min
+                    _max = par.max
+                    value = par.value*self._norm[1]
+                    if not np.isinf(_min) and abs(_min) != float_min:
+                        _min *= self._norm[1]
+                    if not np.isinf(_max) and abs(_max) != float_min:
+                        _max *= self._norm[1]
+                    par.set(value=value, min=_min, max=_max)
+                par.init_value = par.value
+
+        # Remap the best results
+#        t5 = time()
+        self._out_of_bounds = np.copy(np.reshape(self._out_of_bounds_flat, self._map_shape))
+        self._max_nfev = np.copy(np.reshape(self._max_nfev_flat, self._map_shape))
+        self._redchi = np.copy(np.reshape(self._redchi_flat, self._map_shape))
+        self._success = np.copy(np.reshape(self._success_flat, self._map_shape))
+        self._best_fit = np.copy(np.reshape(self._best_fit_flat,
+                list(self._map_shape)+[x_size]))
+        self._best_values = np.asarray([np.reshape(par, list(self._map_shape))
+                for par in self._best_values_flat])
+        self._best_errors = np.asarray([np.reshape(par, list(self._map_shape))
+                for par in self._best_errors_flat])
+        if self._inv_transpose is not None:
+            self._out_of_bounds = np.transpose(self._out_of_bounds, self._inv_transpose)
+            self._max_nfev = np.transpose(self._max_nfev, self._inv_transpose)
+            self._redchi = np.transpose(self._redchi, self._inv_transpose)
+            self._success = np.transpose(self._success, self._inv_transpose)
+            self._best_fit = np.transpose(self._best_fit,
+                    list(self._inv_transpose)+[len(self._inv_transpose)])
+            self._best_values = np.transpose(self._best_values,
+                    [0]+[i+1 for i in self._inv_transpose])
+            self._best_errors = np.transpose(self._best_errors,
+                    [0]+[i+1 for i in self._inv_transpose])
+        del self._out_of_bounds_flat
+        del self._max_nfev_flat
+        del self._redchi_flat
+        del self._success_flat
+        del self._best_fit_flat
+        del self._best_values_flat
+        del self._best_errors_flat
+#        t6 = time()
+
+        # Restore parameter bounds and renormalize the parameters
+        for name, par in self._parameter_bounds.items():
+            self._parameters[name].set(min=par['min'], max=par['max'])
+        self._normalized = False
+        if self._norm is not None:
+            for name, norm in self._parameter_norms.items():
+                par = self._parameters[name]
+                if par.expr is None and norm:
+                    value = par.value*self._norm[1]
+                    _min = par.min
+                    _max = par.max
+                    if not np.isinf(_min) and abs(_min) != float_min:
+                        _min *= self._norm[1]
+                    if not np.isinf(_max) and abs(_max) != float_min:
+                        _max *= self._norm[1]
+                    par.set(value=value, min=_min, max=_max)
+#        t7 = time()
+#        print(f'total run time in fit: {t7-t0:.2f} seconds')
+#        print(f'run time first fit: {t2-t1:.2f} seconds')
+#        print(f'run time remaining fits: {t4-t3:.2f} seconds')
+#        print(f'run time remapping results: {t6-t5:.2f} seconds')
+
+#        print('\n\nAt end fit:')
+#        self._parameters.pretty_print()
+#        print(f'self._best_values:\n{self._best_values}\n\n')
+
+        # Free the shared memory
+        self.freemem()
+
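+    # Worker entry point for joblib: fit a contiguous batch of map points starting at
+    # n_start, writing results into the shared (memmapped) arrays allocated above.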
+    def _fit_parallel(self, current_best_values, num, n_start, **kwargs):
+        num = min(num, self._map_dim-n_start)
+        for n in range(num):
+#            print(f'{n_start+n} before:\n{current_best_values}')
+            self._fit(n_start+n, current_best_values, **kwargs)
+#            print(f'{n_start+n} after:\n{current_best_values}')
+
+    def _fit(self, n, current_best_values, return_result=False, **kwargs):
+#RV        print(f'\n\nstart FitMap._fit {n}\n')
+#RV        print(f'current_best_values = {current_best_values}')
+#RV        print(f'self._best_parameters = {self._best_parameters}')
+#RV        print(f'self._new_parameters = {self._new_parameters}\n\n')
+#        self._parameters.pretty_print()
+        # Set parameters to current best values, but prevent them from sitting at boundaries
+        if self._new_parameters is None:
+            # Initial fit
+            for name, value in current_best_values.items():
+                par = self._parameters[name]
+                par.set(value=self._reset_par_at_boundary(par, value))
+        else:
+            # Refit
+            for i, name in enumerate(self._best_parameters):
+                par = self._parameters[name]
+                if name in self._new_parameters:
+                    if name in current_best_values:
+                        par.set(value=self._reset_par_at_boundary(par, current_best_values[name]))
+                elif par.expr is None:
+                    par.set(value=self._best_values[i][n])
+#RV        print(f'\nbefore fit {n}')
+#RV        self._parameters.pretty_print()
+        if self._mask is None:
+            result = self._model.fit(self._ymap_norm[n], self._parameters, x=self._x, **kwargs)
+        else:
+            result = self._model.fit(self._ymap_norm[n][~self._mask], self._parameters,
+                    x=self._x[~self._mask], **kwargs)
+#        print(f'\nafter fit {n}')
+#        self._parameters.pretty_print()
+#        print(result.fit_report(show_correl=False))
+        out_of_bounds = False
+        for name, par in self._parameter_bounds.items():
+            value = result.params[name].value
+            if not np.isinf(par['min']) and value < par['min']:
+                out_of_bounds = True
+                break
+            if not np.isinf(par['max']) and value > par['max']:
+                out_of_bounds = True
+                break
+        self._out_of_bounds_flat[n] = out_of_bounds
+        if self._try_no_bounds and out_of_bounds:
+            # Rerun fit with parameter bounds in place
+            for name, par in self._parameter_bounds.items():
+                self._parameters[name].set(min=par['min'], max=par['max'])
+            # Set parameters to current best values, but prevent them from sitting at boundaries
+            if self._new_parameters is None:
+                # Initial fit
+                for name, value in current_best_values.items():
+                    par = self._parameters[name]
+                    par.set(value=self._reset_par_at_boundary(par, value))
+            else:
+                # Refit
+                for i, name in enumerate(self._best_parameters):
+                    par = self._parameters[name]
+                    if name in self._new_parameters:
+                        if name in current_best_values:
+                            par.set(value=self._reset_par_at_boundary(par,
+                                    current_best_values[name]))
+                    elif par.expr is None:
+                        par.set(value=self._best_values[i][n])
+#            print('\nbefore fit')
+#            self._parameters.pretty_print()
+#            print(result.fit_report(show_correl=False))
+            if self._mask is None:
+                result = self._model.fit(self._ymap_norm[n], self._parameters, x=self._x, **kwargs)
+            else:
+                result = self._model.fit(self._ymap_norm[n][~self._mask], self._parameters,
+                    x=self._x[~self._mask], **kwargs)
+#            print(f'\nafter fit {n}')
+#            self._parameters.pretty_print()
+#            print(result.fit_report(show_correl=False))
+            out_of_bounds = False
+            for name, par in self._parameter_bounds.items():
+                value = result.params[name].value
+                if not np.isinf(par['min']) and value < par['min']:
+                    out_of_bounds = True
+                    break
+                if not np.isinf(par['max']) and value > par['max']:
+                    out_of_bounds = True
+                    break
+#            print(f'{n} redchi < redchi_cutoff = {result.redchi < self._redchi_cutoff} success = {result.success} out_of_bounds = {out_of_bounds}')
+            # Reset parameters back to unbound
+            for name in self._parameter_bounds.keys():
+                self._parameters[name].set(min=-np.inf, max=np.inf)
+        assert(not out_of_bounds)
+        if result.redchi >= self._redchi_cutoff:
+            result.success = False
+        if result.nfev == result.max_nfev:
+#            print(f'Maximum number of function evaluations reached for n = {n}')
+#            logging.warning(f'Maximum number of function evaluations reached for n = {n}')
+            if result.redchi < self._redchi_cutoff:
+                result.success = True
+            self._max_nfev_flat[n] = True
+        if result.success:
+            assert(all(name in result.params for name in current_best_values))
+            for par in result.params.values():
+                if par.vary:
+                    current_best_values[par.name] = par.value
+        else:
+            logging.warning(f'Fit for n = {n} failed: {result.lmdif_message}')
+        # Renormalize the data and results
+        self._renormalize(n, result)
+        if self._print_report:
+            print(result.fit_report(show_correl=False))
+        if self._plot:
+            dims = np.unravel_index(n, self._map_shape)
+            if self._inv_transpose is not None:
+                dims = tuple(dims[self._inv_transpose[i]] for i in range(len(dims)))
+            super().plot(result=result, y=np.asarray(self._ymap[dims]), plot_comp_legends=True,
+                skip_init=self._skip_init, title=str(dims))
+#RV        print(f'\n\nend FitMap._fit {n}\n')
+#RV        print(f'current_best_values = {current_best_values}')
+#        self._parameters.pretty_print()
+#        print(result.fit_report(show_correl=False))
+#RV        print(f'\nself._best_values_flat:\n{self._best_values_flat}\n\n')
+        if return_result:
+            return(result)
+
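+    # Store the results for map point n in the flat result arrays, undoing the data
+    # normalization first when one is in effect.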
+    def _renormalize(self, n, result):
+        self._redchi_flat[n] = np.float64(result.redchi)
+        self._success_flat[n] = result.success
+        if self._norm is None or not self._normalized:
+            self._best_fit_flat[n] = result.best_fit
+            for i, name in enumerate(self._best_parameters):
+                self._best_values_flat[i][n] = np.float64(result.params[name].value)
+                self._best_errors_flat[i][n] = np.float64(result.params[name].stderr)
+        else:
+            pars = set(self._parameter_norms) & set(self._best_parameters)
+            for name, par in result.params.items():
+                if name in pars and self._parameter_norms[name]:
+                    if par.stderr is not None:
+                        par.stderr *= self._norm[1]
+                    if par.expr is None:
+                        par.value *= self._norm[1]
+                        if self._print_report:
+                            if par.init_value is not None:
+                                par.init_value *= self._norm[1]
+                            if not np.isinf(par.min) and abs(par.min) != float_min:
+                                par.min *= self._norm[1]
+                            if not np.isinf(par.max) and abs(par.max) != float_min:
+                                par.max *= self._norm[1]
+            self._best_fit_flat[n] = result.best_fit*self._norm[1]+self._norm[0]
+            for i, name in enumerate(self._best_parameters):
+                self._best_values_flat[i][n] = np.float64(result.params[name].value)
+                self._best_errors_flat[i][n] = np.float64(result.params[name].stderr)
+            if self._plot:
+                if not self._skip_init:
+                    result.init_fit = result.init_fit*self._norm[1]+self._norm[0]
+                result.best_fit = np.copy(self._best_fit_flat[n])
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/general.py	Tue Mar 21 16:22:42 2023 +0000
@@ -0,0 +1,1965 @@
+#!/usr/bin/env python3
+
+#FIX write a function that returns a list of peak indices for a given plot
+#FIX use raise_error concept on more functions to optionally raise an error
+
+# -*- coding: utf-8 -*-
+"""
+Created on Mon Dec  6 15:36:22 2021
+
+@author: rv43
+"""
+
+import logging
+logger = logging.getLogger(__name__)
+
+import os
+import sys
+import re
+try:
+    from yaml import safe_load, safe_dump
+except ImportError:
+    pass
+try:
+    import h5py
+except ImportError:
+    pass
+import numpy as np
+try:
+    import matplotlib.pyplot as plt
+    import matplotlib.lines as mlines
+    from matplotlib import transforms
+    from matplotlib.widgets import Button
+except ImportError:
+    pass
+
+from ast import literal_eval
+try:
+    from asteval import Interpreter, get_ast_names
+except ImportError:
+    pass
+from copy import deepcopy
+try:
+    from sympy import diff, simplify
+except ImportError:
+    pass
+from time import time
+
+
+def depth_list(L):
+    return(isinstance(L, list) and max(map(depth_list, L))+1)
+
+def depth_tuple(T):
+    return(isinstance(T, tuple) and max(map(depth_tuple, T))+1)
+
+def unwrap_tuple(T):
+    if depth_tuple(T) > 1 and len(T) == 1:
+        T = unwrap_tuple(*T)
+    return(T)
+
+def illegal_value(value, name, location=None, raise_error=False, log=True):
+    if not isinstance(location, str):
+        location = ''
+    else:
+        location = f'in {location} '
+    if isinstance(name, str):
+        error_msg = f'Illegal value for {name} {location}({value}, {type(value)})'
+    else:
+        error_msg = f'Illegal value {location}({value}, {type(value)})'
+    if log:
+        logger.error(error_msg)
+    if raise_error:
+        raise ValueError(error_msg)
+
+def illegal_combination(value1, name1, value2, name2, location=None, raise_error=False,
+        log=True):
+    if not isinstance(location, str):
+        location = ''
+    else:
+        location = f'in {location} '
+    if isinstance(name1, str):
+        error_msg = f'Illegal combination for {name1} and {name2} {location}'+ \
+                f'({value1}, {type(value1)} and {value2}, {type(value2)})'
+    else:
+        error_msg = f'Illegal combination {location}'+ \
+                f'({value1}, {type(value1)} and {value2}, {type(value2)})'
+    if log:
+        logger.error(error_msg)
+    if raise_error:
+        raise ValueError(error_msg)
+
+def test_ge_gt_le_lt(ge, gt, le, lt, func, location=None, raise_error=False, log=True):
+    """Check individual and mutual validity of ge, gt, le, lt qualifiers
+       func: is_int or is_num to test for int or numbers
+       Return: True upon success or False when mutually exlusive
+    """
+    if ge is None and gt is None and le is None and lt is None:
+        return(True)
+    if ge is not None:
+        if not func(ge):
+            illegal_value(ge, 'ge', location, raise_error, log) 
+            return(False)
+        if gt is not None:
+            illegal_combination(ge, 'ge', gt, 'gt', location, raise_error, log) 
+            return(False)
+    elif gt is not None and not func(gt):
+        illegal_value(gt, 'gt', location, raise_error, log) 
+        return(False)
+    if le is not None:
+        if not func(le):
+            illegal_value(le, 'le', location, raise_error, log) 
+            return(False)
+        if lt is not None:
+            illegal_combination(le, 'le', lt, 'lt', location, raise_error, log) 
+            return(False)
+    elif lt is not None and not func(lt):
+        illegal_value(lt, 'lt', location, raise_error, log) 
+        return(False)
+    if ge is not None:
+        if le is not None and ge > le:
+            illegal_combination(ge, 'ge', le, 'le', location, raise_error, log) 
+            return(False)
+        elif lt is not None and ge >= lt:
+            illegal_combination(ge, 'ge', lt, 'lt', location, raise_error, log) 
+            return(False)
+    elif gt is not None:
+        if le is not None and gt >= le:
+            illegal_combination(gt, 'gt', le, 'le', location, raise_error, log) 
+            return(False)
+        elif lt is not None and gt >= lt:
+            illegal_combination(gt, 'gt', lt, 'lt', location, raise_error, log) 
+            return(False)
+    return(True)
+
+def range_string_ge_gt_le_lt(ge=None, gt=None, le=None, lt=None):
+    """Return a range string representation matching the ge, gt, le, lt qualifiers
+       Does not validate the inputs, do that as needed before calling
+    """
+    range_string = ''
+    if ge is not None:
+        if le is None and lt is None:
+            range_string += f'>= {ge}'
+        else:
+            range_string += f'[{ge}, '
+    elif gt is not None:
+        if le is None and lt is None:
+            range_string += f'> {gt}'
+        else:
+            range_string += f'({gt}, '
+    if le is not None:
+        if ge is None and gt is None:
+            range_string += f'<= {le}'
+        else:
+            range_string += f'{le}]'
+    elif lt is not None:
+        if ge is None and gt is None:
+            range_string += f'< {lt}'
+        else:
+            range_string += f'{lt})'
+    return(range_string)
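+
+# Illustrative examples (hypothetical values):
+#     >>> range_string_ge_gt_le_lt(ge=0, lt=10)
+#     '[0, 10)'
+#     >>> range_string_ge_gt_le_lt(gt=0)
+#     '> 0'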
+
+def is_int(v, ge=None, gt=None, le=None, lt=None, raise_error=False, log=True):
+    """Value is an integer in range ge <= v <= le or gt < v < lt or some combination.
+       Return: True if yes, False if no
+    """
+    return(_is_int_or_num(v, 'int', ge, gt, le, lt, raise_error, log))
+
+def is_num(v, ge=None, gt=None, le=None, lt=None, raise_error=False, log=True):
+    """Value is a number in range ge <= v <= le or gt < v < lt or some combination.
+       Return: True if yes, False if no
+    """
+    return(_is_int_or_num(v, 'num', ge, gt, le, lt, raise_error, log))
+
+def _is_int_or_num(v, type_str, ge=None, gt=None, le=None, lt=None, raise_error=False,
+        log=True):
+    if type_str == 'int':
+        if not isinstance(v, int):
+            illegal_value(v, 'v', '_is_int_or_num', raise_error, log) 
+            return(False)
+        if not test_ge_gt_le_lt(ge, gt, le, lt, is_int, '_is_int_or_num', raise_error, log):
+            return(False)
+    elif type_str == 'num':
+        if not isinstance(v, (int, float)):
+            illegal_value(v, 'v', '_is_int_or_num', raise_error, log) 
+            return(False)
+        if not test_ge_gt_le_lt(ge, gt, le, lt, is_num, '_is_int_or_num', raise_error, log):
+            return(False)
+    else:
+        illegal_value(type_str, 'type_str', '_is_int_or_num', raise_error, log) 
+        return(False)
+    if ge is None and gt is None and le is None and lt is None:
+        return(True)
+    error = False
+    if ge is not None and v < ge:
+        error = True
+        error_msg = f'Value {v} out of range: {v} !>= {ge}'
+    if not error and gt is not None and v <= gt:
+        error = True
+        error_msg = f'Value {v} out of range: {v} !> {gt}'
+    if not error and le is not None and v > le:
+        error = True
+        error_msg = f'Value {v} out of range: {v} !<= {le}'
+    if not error and lt is not None and v >= lt:
+        error = True
+        error_msg = f'Value {v} out of range: {v} !< {lt}'
+    if error:
+        if log:
+            logger.error(error_msg)
+        if raise_error:
+            raise ValueError(error_msg)
+        return(False)
+    return(True)
+
+def is_int_pair(v, ge=None, gt=None, le=None, lt=None, raise_error=False, log=True):
+    """Value is an integer pair, each in range ge <= v[i] <= le or gt < v[i] < lt or
+           ge[i] <= v[i] <= le[i] or gt[i] < v[i] < lt[i] or some combination.
+       Return: True if yes, False if no
+    """
+    return(_is_int_or_num_pair(v, 'int', ge, gt, le, lt, raise_error, log))
+
+def is_num_pair(v, ge=None, gt=None, le=None, lt=None, raise_error=False, log=True):
+    """Value is a number pair, each in range ge <= v[i] <= le or gt < v[i] < lt or
+           ge[i] <= v[i] <= le[i] or gt[i] < v[i] < lt[i] or some combination.
+       Return: True if yes, False if no
+    """
+    return(_is_int_or_num_pair(v, 'num', ge, gt, le, lt, raise_error, log))
+
+def _is_int_or_num_pair(v, type_str, ge=None, gt=None, le=None, lt=None, raise_error=False,
+        log=True):
+    if type_str == 'int':
+        if not (isinstance(v, (tuple, list)) and len(v) == 2 and isinstance(v[0], int) and
+                isinstance(v[1], int)):
+            illegal_value(v, 'v', '_is_int_or_num_pair', raise_error, log) 
+            return(False)
+        func = is_int
+    elif type_str == 'num':
+        if not (isinstance(v, (tuple, list)) and len(v) == 2 and isinstance(v[0], (int, float)) and
+                isinstance(v[1], (int, float))):
+            illegal_value(v, 'v', '_is_int_or_num_pair', raise_error, log) 
+            return(False)
+        func = is_num
+    else:
+        illegal_value(type_str, 'type_str', '_is_int_or_num_pair', raise_error, log)
+        return(False)
+    if ge is None and gt is None and le is None and lt is None:
+        return(True)
+    # Expand scalar bounds to pairs (suppress logging while testing for the scalar form)
+    if ge is None or func(ge, log=False):
+        ge = 2*[ge]
+    elif not _is_int_or_num_pair(ge, type_str, raise_error=raise_error, log=log):
+        return(False)
+    if gt is None or func(gt, log=False):
+        gt = 2*[gt]
+    elif not _is_int_or_num_pair(gt, type_str, raise_error=raise_error, log=log):
+        return(False)
+    if le is None or func(le, log=False):
+        le = 2*[le]
+    elif not _is_int_or_num_pair(le, type_str, raise_error=raise_error, log=log):
+        return(False)
+    if lt is None or func(lt, log=False):
+        lt = 2*[lt]
+    elif not _is_int_or_num_pair(lt, type_str, raise_error=raise_error, log=log):
+        return(False)
+    if (not func(v[0], ge[0], gt[0], le[0], lt[0], raise_error, log) or
+            not func(v[1], ge[1], gt[1], le[1], lt[1], raise_error, log)):
+        return(False)
+    return(True)
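+
+# Illustrative sketch (hypothetical values): bounds may be scalars applied to both
+# elements or pairs applied element-wise:
+#     >>> is_int_pair((1, 5), ge=0)
+#     True
+#     >>> is_int_pair((1, 5), ge=(0, 2))
+#     True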
+
+def is_int_series(l, ge=None, gt=None, le=None, lt=None, raise_error=False, log=True):
+    """Value is a tuple or list of integers, each in range ge <= l[i] <= le or
+       gt < l[i] < lt or some combination.
+    """
+    if not test_ge_gt_le_lt(ge, gt, le, lt, is_int, 'is_int_series', raise_error, log):
+        return(False)
+    if not isinstance(l, (tuple, list)):
+        illegal_value(l, 'l', 'is_int_series', raise_error, log)
+        return(False)
+    if any(not is_int(v, ge, gt, le, lt, raise_error, log) for v in l):
+        return(False)
+    return(True)
+
+def is_num_series(l, ge=None, gt=None, le=None, lt=None, raise_error=False, log=True):
+    """Value is a tuple or list of numbers, each in range ge <= l[i] <= le or
+       gt < l[i] < lt or some combination.
+    """
+    if not test_ge_gt_le_lt(ge, gt, le, lt, is_num, 'is_num_series', raise_error, log):
+        return(False)
+    if not isinstance(l, (tuple, list)):
+        illegal_value(l, 'l', 'is_num_series', raise_error, log)
+        return(False)
+    if any(not is_num(v, ge, gt, le, lt, raise_error, log) for v in l):
+        return(False)
+    return(True)
+
+def is_str_series(l, raise_error=False, log=True):
+    """Value is a tuple or list of strings.
+    """
+    if (not isinstance(l, (tuple, list)) or
+            any(not isinstance(s, str) for s in l)):
+        illegal_value(l, 'l', 'is_str_series', raise_error, log)
+        return(False)
+    return(True)
+
+def is_dict_series(l, raise_error=False, log=True):
+    """Value is a tuple or list of dictionaries.
+    """
+    if (not isinstance(l, (tuple, list)) or
+            any(not isinstance(d, dict) for d in l)):
+        illegal_value(l, 'l', 'is_dict_series', raise_error, log)
+        return(False)
+    return(True)
+
+def is_dict_nums(l, raise_error=False, log=True):
+    """Value is a dictionary with single number values
+    """
+    if (not isinstance(l, dict) or
+            any(not is_num(v, log=False) for v in l.values())):
+        illegal_value(l, 'l', 'is_dict_nums', raise_error, log)
+        return(False)
+    return(True)
+
+def is_dict_strings(l, raise_error=False, log=True):
+    """Value is a dictionary with single string values
+    """
+    if (not isinstance(l, dict) or
+            any(not isinstance(v, str) for v in l.values())):
+        illegal_value(l, 'l', 'is_dict_strings', raise_error, log)
+        return(False)
+    return(True)
+
+def is_index(v, ge=0, lt=None, raise_error=False, log=True):
+    """Value is an array index in range ge <= v < lt.
+       NOTE lt IS NOT included!
+    """
+    if isinstance(lt, int):
+        if lt <= ge:
+            illegal_combination(ge, 'ge', lt, 'lt', 'is_index', raise_error, log) 
+            return(False)
+    return(is_int(v, ge=ge, lt=lt, raise_error=raise_error, log=log))
+
+def is_index_range(v, ge=0, le=None, lt=None, raise_error=False, log=True):
+    """Value is an array index range in range ge <= v[0] <= v[1] <= le or ge <= v[0] <= v[1] < lt.
+       NOTE le IS included!
+    """
+    if not is_int_pair(v, raise_error=raise_error, log=log):
+        return(False)
+    if not test_ge_gt_le_lt(ge, None, le, lt, is_int, 'is_index_range', raise_error, log):
+        return(False)
+    if not ge <= v[0] <= v[1] or (le is not None and v[1] > le) or (lt is not None and v[1] >= lt):
+        if le is not None:
+            error_msg = f'Value {v} out of range: !({ge} <= {v[0]} <= {v[1]} <= {le})'
+        else:
+            error_msg = f'Value {v} out of range: !({ge} <= {v[0]} <= {v[1]} < {lt})'
+        if log:
+            logger.error(error_msg)
+        if raise_error:
+            raise ValueError(error_msg)
+        return(False)
+    return(True)
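+
+# Illustrative examples (hypothetical values); le is included, lt is not:
+#     >>> is_index_range((2, 5), ge=0, lt=10)
+#     True
+#     >>> is_index_range((2, 12), lt=10)  # logs an error
+#     False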
+
+def index_nearest(a, value):
+    a = np.asarray(a)
+    if a.ndim > 1:
+        raise ValueError(f'Invalid array dimension for parameter a ({a.ndim}, {a})')
+    # Round up at exactly .5 (for positive values) by nudging value by a relative epsilon
+    value *= 1.0+sys.float_info.epsilon
+    return((int)(np.argmin(np.abs(a-value))))
+
+def index_nearest_low(a, value):
+    a = np.asarray(a)
+    if a.ndim > 1:
+        raise ValueError(f'Invalid array dimension for parameter a ({a.ndim}, {a})')
+    index = int(np.argmin(np.abs(a-value)))
+    if value < a[index] and index > 0:
+        index -= 1
+    return(index)
+
+def index_nearest_upp(a, value):
+    a = np.asarray(a)
+    if a.ndim > 1:
+        raise ValueError(f'Invalid array dimension for parameter a ({a.ndim}, {a})')
+    index = int(np.argmin(np.abs(a-value)))
+    if value > a[index] and index < a.size-1:
+        index += 1
+    return(index)
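+
+# Illustrative examples (hypothetical values): index_nearest breaks .5 ties upward,
+# while the _low/_upp variants bias towards the lower/upper neighbor:
+#     >>> index_nearest([0., 1., 2., 3.], 1.5)
+#     2
+#     >>> index_nearest_low([0., 1., 2., 3.], 1.5)
+#     1
+#     >>> index_nearest_upp([0., 1., 2., 3.], 1.2)
+#     2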
+
+def round_to_n(x, n=1):
+    if x == 0.0:
+        return(0)
+    else:
+        return(type(x)(round(x, n-1-int(np.floor(np.log10(abs(x)))))))
+
+def round_up_to_n(x, n=1):
+    xr = round_to_n(x, n)
+    if abs(x/xr) > 1.0:
+        xr += np.sign(x)*10**(np.floor(np.log10(abs(x)))+1-n)
+    return(type(x)(xr))
+
+def trunc_to_n(x, n=1):
+    xr = round_to_n(x, n)
+    if abs(xr/x) > 1.0:
+        xr -= np.sign(x)*10**(np.floor(np.log10(abs(x)))+1-n)
+    return(type(x)(xr))
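+
+# Illustrative examples (hypothetical values), all to n=2 significant figures:
+#     >>> round_to_n(1234.5, 2)
+#     1200.0
+#     >>> round_up_to_n(1234.5, 2)
+#     1300.0
+#     >>> trunc_to_n(1278.0, 2)
+#     1200.0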
+
+def almost_equal(a, b, sig_figs):
+    if is_num(a) and is_num(b):
+        return(abs(round_to_n(a-b, sig_figs)) < pow(10, -sig_figs+1))
+    else:
+        raise ValueError(f'Invalid value for a or b in almost_equal (a: {a}, {type(a)}, '+
+                f'b: {b}, {type(b)})')
+
+def string_to_list(s, split_on_dash=True, remove_duplicates=True, sort=True):
+    """Return a list of numbers by splitting/expanding a string on any combination of
+       commas, whitespaces, or dashes (when split_on_dash=True)
+       e.g: '1, 3, 5-8, 12 ' -> [1, 3, 5, 6, 7, 8, 12]
+    """
+    if not isinstance(s, str):
+        illegal_value(s, 's', location='string_to_list')
+        return(None)
+    if not len(s):
+        return([])
+    try:
+        ll = re.split(r'\s+,\s+|\s+,|,\s+|\s+|,', s.strip())
+    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
+        return(None)
+    if split_on_dash:
+        try:
+            l = []
+            for l1 in ll:
+                l2 = [literal_eval(x) for x in re.split(r'\s+-\s+|\s+-|-\s+|\s+|-', l1)]
+                if len(l2) == 1:
+                    l += l2
+                elif len(l2) == 2 and l2[1] > l2[0]:
+                    l += [i for i in range(l2[0], l2[1]+1)]
+                else:
+                    raise ValueError
+        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
+            return(None)
+    else:
+        l = [literal_eval(x) for x in ll]
+    if remove_duplicates:
+        l = list(dict.fromkeys(l))
+    if sort:
+        l = sorted(l)
+    return(l)
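+
+# In addition to the dash expansion shown in the docstring, duplicates are dropped and
+# the result is sorted by default (hypothetical input):
+#     >>> string_to_list('3 1 2 2')
+#     [1, 2, 3]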
+
+def get_trailing_int(string):
+    indexRegex = re.compile(r'\d+$')
+    mo = indexRegex.search(string)
+    if mo is None:
+        return(None)
+    else:
+        return(int(mo.group()))
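+
+# Illustrative examples (hypothetical input):
+#     >>> get_trailing_int('scan_0042')
+#     42
+#     >>> get_trailing_int('scan') is None
+#     True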
+
+def input_int(s=None, ge=None, gt=None, le=None, lt=None, default=None, inset=None,
+        raise_error=False, log=True):
+    return(_input_int_or_num('int', s, ge, gt, le, lt, default, inset, raise_error, log))
+
+def input_num(s=None, ge=None, gt=None, le=None, lt=None, default=None, raise_error=False,
+        log=True):
+    return(_input_int_or_num('num', s, ge, gt, le, lt, default, None, raise_error, log))
+
+def _input_int_or_num(type_str, s=None, ge=None, gt=None, le=None, lt=None, default=None,
+        inset=None, raise_error=False, log=True):
+    if type_str == 'int':
+        if not test_ge_gt_le_lt(ge, gt, le, lt, is_int, '_input_int_or_num', raise_error, log):
+            return(None)
+    elif type_str == 'num':
+        if not test_ge_gt_le_lt(ge, gt, le, lt, is_num, '_input_int_or_num', raise_error, log):
+            return(None)
+    else:
+        illegal_value(type_str, 'type_str', '_input_int_or_num', raise_error, log)
+        return(None)
+    if default is not None:
+        if not _is_int_or_num(default, type_str, raise_error=raise_error, log=log):
+            return(None)
+        if ge is not None and default < ge:
+            illegal_combination(ge, 'ge', default, 'default', '_input_int_or_num', raise_error,
+                log)
+            return(None)
+        if gt is not None and default <= gt:
+            illegal_combination(gt, 'gt', default, 'default', '_input_int_or_num', raise_error,
+                log)
+            return(None)
+        if le is not None and default > le:
+            illegal_combination(le, 'le', default, 'default', '_input_int_or_num', raise_error,
+                log)
+            return(None)
+        if lt is not None and default >= lt:
+            illegal_combination(lt, 'lt', default, 'default', '_input_int_or_num', raise_error,
+                log)
+            return(None)
+        default_string = f' [{default}]'
+    else:
+        default_string = ''
+    if inset is not None:
+        if (not isinstance(inset, (tuple, list)) or
+                any(not isinstance(i, int) for i in inset)):
+            illegal_value(inset, 'inset', '_input_int_or_num', raise_error, log) 
+            return(None)
+    v_range = f'{range_string_ge_gt_le_lt(ge, gt, le, lt)}'
+    if len(v_range):
+        v_range = f' {v_range}'
+    if s is None:
+        if type_str == 'int':
+            print(f'Enter an integer{v_range}{default_string}: ')
+        else:
+            print(f'Enter a number{v_range}{default_string}: ')
+    else:
+        print(f'{s}{v_range}{default_string}: ')
+    try:
+        i = input()
+        if isinstance(i, str) and not len(i):
+            v = default
+            print(f'{v}')
+        else:
+            v = literal_eval(i)
+        if inset and v not in inset:
+            raise ValueError(f'{v} not part of the set {inset}')
+    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
+        v = None
+    except:
+        if log:
+            logger.error('Unexpected error')
+        if raise_error:
+            raise ValueError('Unexpected error')
+    if not _is_int_or_num(v, type_str, ge, gt, le, lt):
+        v = _input_int_or_num(type_str, s, ge, gt, le, lt, default, inset, raise_error, log)
+    return(v)
+
+def input_int_list(s=None, ge=None, le=None, split_on_dash=True, remove_duplicates=True,
+        sort=True, raise_error=False, log=True):
+    """Prompt the user to input a list of interger and split the entered string on any combination
+       of commas, whitespaces, or dashes (when split_on_dash is True)
+       e.g: '1 3,5-8 , 12 ' -> [1, 3, 5, 6, 7, 8, 12]
+       remove_duplicates: removes duplicates if True (may also change the order)
+       sort: sort in ascending order if True
+       return None upon an illegal input
+    """
+    return(_input_int_or_num_list('int', s, ge, le, split_on_dash, remove_duplicates, sort,
+        raise_error, log))
+
+def input_num_list(s=None, ge=None, le=None, remove_duplicates=True, sort=True, raise_error=False,
+        log=True):
+    """Prompt the user to input a list of numbers and split the entered string on any combination
+       of commas or whitespaces
+       e.g: '1.0, 3, 5.8, 12 ' -> [1.0, 3.0, 5.8, 12.0]
+       remove_duplicates: removes duplicates if True (may also change the order)
+       sort: sort in ascending order if True
+       return None upon an illegal input
+    """
+    return(_input_int_or_num_list('num', s, ge, le, False, remove_duplicates, sort, raise_error,
+        log))
+
+def _input_int_or_num_list(type_str, s=None, ge=None, le=None, split_on_dash=True,
+        remove_duplicates=True, sort=True, raise_error=False, log=True):
+    #FIX do we want a limit on max dimension?
+    if type_str == 'int':
+        if not test_ge_gt_le_lt(ge, None, le, None, is_int, '_input_int_or_num_list', raise_error,
+                log):
+            return(None)
+    elif type_str == 'num':
+        if not test_ge_gt_le_lt(ge, None, le, None, is_num, '_input_int_or_num_list', raise_error,
+                log):
+            return(None)
+    else:
+        illegal_value(type_str, 'type_str', '_input_int_or_num_list', raise_error, log)
+        return(None)
+    v_range = f'{range_string_ge_gt_le_lt(ge=ge, le=le)}'
+    if len(v_range):
+        v_range = f' (each value in {v_range})'
+    if s is None:
+        if type_str == 'int':
+            print(f'Enter a series of integers{v_range}: ')
+        else:
+            print(f'Enter a series of numbers{v_range}: ')
+    else:
+        print(f'{s}{v_range}: ')
+    try:
+        l = string_to_list(input(), split_on_dash, remove_duplicates, sort)
+    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
+        l = None
+    except:
+        print('Unexpected error')
+        raise
+    if (not isinstance(l, list) or
+            any(not _is_int_or_num(v, type_str, ge=ge, le=le) for v in l)):
+        if split_on_dash:
+            print('Invalid input: enter a valid set of dash/comma/whitespace separated integers '+
+                    'e.g. 1 3,5-8 , 12')
+        else:
+            print('Invalid input: enter a valid set of comma/whitespace separated values '+
+                    'e.g. 1 3,5 8 , 12')
+        l = _input_int_or_num_list(type_str, s, ge, le, split_on_dash, remove_duplicates, sort,
+            raise_error, log)
+    return(l)
+
+def input_yesno(s=None, default=None):
+    if default is not None:
+        if not isinstance(default, str):
+            illegal_value(default, 'default', 'input_yesno') 
+            return(None)
+        if default.lower() in ('y', 'yes'):
+            default = 'y'
+        elif default.lower() in ('n', 'no'):
+            default = 'n'
+        else:
+            illegal_value(default, 'default', 'input_yesno') 
+            return(None)
+        default_string = f' [{default}]'
+    else:
+        default_string = ''
+    if s is None:
+        print(f'Enter yes or no{default_string}: ')
+    else:
+        print(f'{s}{default_string}: ')
+    i = input()
+    if isinstance(i, str) and not len(i):
+        i = default
+        print(f'{i}')
+    if i is not None and i.lower() in ('y', 'yes'):
+        v = True
+    elif i is not None and i.lower() in ('n', 'no'):
+        v = False
+    else:
+        print('Invalid input, enter yes or no')
+        v = input_yesno(s, default)
+    return(v)
+
+def input_menu(items, default=None, header=None):
+    if not isinstance(items, (tuple, list)) or any(not isinstance(i, str) for i in items):
+        illegal_value(items, 'items', 'input_menu') 
+        return(None)
+    if default is not None:
+        if not (isinstance(default, str) and default in items):
+            logger.error(f'Invalid value for default ({default}), must be in {items}') 
+            return(None)
+        default_string = f' [{items.index(default)+1}]'
+    else:
+        default_string = ''
+    if header is None:
+        print(f'Choose one of the following items (1, {len(items)}){default_string}:')
+    else:
+        print(f'{header} (1, {len(items)}){default_string}:')
+    for i, choice in enumerate(items):
+        print(f'  {i+1}: {choice}')
+    try:
+        choice = input()
+        if isinstance(choice, str) and not len(choice):
+            choice = items.index(default)
+            print(f'{choice+1}')
+        else:
+            choice = literal_eval(choice)
+            if isinstance(choice, int) and 1 <= choice <= len(items):
+                choice -= 1
+            else:
+                raise ValueError
+    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
+        choice = None
+    except:
+        print('Unexpected error')
+        raise
+    if choice is None:
+        print(f'Invalid choice, enter a number between 1 and {len(items)}')
+        choice = input_menu(items, default, header)
+    return(choice)
+
+def assert_no_duplicates_in_list_of_dicts(l: list, raise_error=False) -> list:
+    if not isinstance(l, list):
+        illegal_value(l, 'l', 'assert_no_duplicates_in_list_of_dicts', raise_error)
+        return(None)
+    if any(not isinstance(d, dict) for d in l):
+        illegal_value(l, 'l', 'assert_no_duplicates_in_list_of_dicts', raise_error)
+        return(None)
+    if len(l) != len([dict(t) for t in {tuple(sorted(d.items())) for d in l}]):
+        if raise_error:
+            raise ValueError(f'Duplicate items found in {l}')
+        else:
+            logger.error(f'Duplicate items found in {l}')
+        return(None)
+    else:
+        return(l)
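+
+# Illustrative example (hypothetical values): the list is returned unchanged when all
+# dicts are unique; duplicates yield None (with an error logged):
+#     >>> assert_no_duplicates_in_list_of_dicts([{'a': 1}, {'a': 2}])
+#     [{'a': 1}, {'a': 2}]
+#     >>> assert_no_duplicates_in_list_of_dicts([{'a': 1}, {'a': 1}]) is None
+#     True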
+
+def assert_no_duplicate_key_in_list_of_dicts(l: list, key: str, raise_error=False) -> list:
+    if not isinstance(key, str):
+        illegal_value(key, 'key', 'assert_no_duplicate_key_in_list_of_dicts', raise_error)
+        return(None)
+    if not isinstance(l, list):
+        illegal_value(l, 'l', 'assert_no_duplicate_key_in_list_of_dicts', raise_error)
+        return(None)
+    if any(not isinstance(d, dict) for d in l):
+        illegal_value(l, 'l', 'assert_no_duplicate_key_in_list_of_dicts', raise_error)
+        return(None)
+    keys = [d.get(key, None) for d in l]
+    if None in keys or len(set(keys)) != len(l):
+        if raise_error:
+            raise ValueError(f'Duplicate or missing key ({key}) found in {l}')
+        else:
+            logger.error(f'Duplicate or missing key ({key}) found in {l}')
+        return(None)
+    else:
+        return(l)
+
+def assert_no_duplicate_attr_in_list_of_objs(l: list, attr: str, raise_error=False) -> list:
+    if not isinstance(attr, str):
+        illegal_value(attr, 'attr', 'assert_no_duplicate_attr_in_list_of_objs', raise_error)
+        return(None)
+    if not isinstance(l, list):
+        illegal_value(l, 'l', 'assert_no_duplicate_attr_in_list_of_objs', raise_error)
+        return(None)
+    attrs = [getattr(obj, attr, None) for obj in l]
+    if None in attrs or len(set(attrs)) != len(l):
+        if raise_error:
+            raise ValueError(f'Duplicate or missing attr ({attr}) found in {l}')
+        else:
+            logger.error(f'Duplicate or missing attr ({attr}) found in {l}')
+        return(None)
+    else:
+        return(l)
+
+def file_exists_and_readable(path):
+    if not os.path.isfile(path):
+        raise ValueError(f'{path} is not a valid file')
+    elif not os.access(path, os.R_OK):
+        raise ValueError(f'{path} is not accessible for reading')
+    else:
+        return(path)
+
+def create_mask(x, bounds=None, exclude_bounds=False, current_mask=None):
+    # bounds is a pair of numbers in the same units as x
+    if not isinstance(x, (tuple, list, np.ndarray)) or not len(x):
+        logger.warning(f'Invalid input array ({x}, {type(x)})')
+        return(None)
+    if bounds is not None and not is_num_pair(bounds):
+        logger.warning(f'Invalid bounds parameter ({bounds}, {type(bounds)}), input ignored')
+        bounds = None
+    if bounds is not None:
+        if exclude_bounds:
+            mask = np.logical_or(x < min(bounds), x > max(bounds))
+        else:
+            mask = np.logical_and(x > min(bounds), x < max(bounds))
+    else:
+        mask = np.ones(len(x), dtype=bool)
+    if current_mask is not None:
+        if not isinstance(current_mask, (tuple, list, np.ndarray)) or len(current_mask) != len(x):
+            logger.warning(f'Invalid current_mask ({current_mask}, {type(current_mask)}), '+
+                    'input ignored')
+        else:
+            mask = np.logical_or(mask, current_mask)
+    if True not in mask:
+        logger.warning('Entire data array is masked')
+    return(mask)
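+
+# Illustrative example (hypothetical values); note the strict inequalities, so the
+# bounds themselves are excluded from an inclusion mask:
+#     >>> create_mask(np.arange(5), bounds=(1, 3))
+#     array([False, False,  True, False, False])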
+
+def eval_expr(name, expr, expr_variables, user_variables=None, max_depth=10, raise_error=False,
+        log=True, **kwargs):
+    """Evaluate an expression of expressions
+    """
+    if not isinstance(name, str):
+        illegal_value(name, 'name', 'eval_expr', raise_error, log)
+        return(None)
+    if not isinstance(expr, str):
+        illegal_value(expr, 'expr', 'eval_expr', raise_error, log)
+        return(None)
+    if not is_dict_strings(expr_variables, log=False):
+        illegal_value(expr_variables, 'expr_variables', 'eval_expr', raise_error, log)
+        return(None)
+    if user_variables is not None and not is_dict_nums(user_variables, log=False):
+        illegal_value(user_variables, 'user_variables', 'eval_expr', raise_error, log)
+        return(None)
+    if not is_int(max_depth, gt=1, log=False):
+        illegal_value(max_depth, 'max_depth', 'eval_expr', raise_error, log)
+        return(None)
+    if not isinstance(raise_error, bool):
+        illegal_value(raise_error, 'raise_error', 'eval_expr', raise_error, log)
+        return(None)
+    if not isinstance(log, bool):
+        illegal_value(log, 'log', 'eval_expr', raise_error, log)
+        return(None)
+#    print(f'\nEvaluate the full expression for {expr}')
+    if 'chain' in kwargs:
+        chain = kwargs.pop('chain')
+        if not is_str_series(chain):
+            illegal_value(chain, 'chain', 'eval_expr', raise_error, log)
+            return(None)
+    else:
+        chain = []
+    if len(chain) > max_depth:
+        error_msg = f'Exceeded maximum depth ({max_depth}) in eval_expr'
+        if log:
+            logger.error(error_msg)
+        if raise_error:
+            raise ValueError(error_msg)
+        return(None)
+    if name not in chain:
+        chain.append(name)
+#    print(f'start: chain = {chain}')
+    if 'ast' in kwargs:
+        ast = kwargs.pop('ast')
+    else:
+        ast = Interpreter()
+    if user_variables is not None:
+        ast.symtable.update(user_variables)
+    chain_vars = [var for var in get_ast_names(ast.parse(expr))
+            if var in expr_variables and var not in ast.symtable]
+#    print(f'chain_vars: {chain_vars}')
+    save_chain = chain.copy()
+    for var in chain_vars:
+#        print(f'\n\tname = {name}, var = {var}:\n\t\t{expr_variables[var]}')
+#        print(f'\tchain = {chain}')
+        if var in chain:
+            error_msg = f'Circular variable {var} in eval_expr'
+            if log:
+                logger.error(error_msg)
+            if raise_error:
+                raise ValueError(error_msg)
+            return(None)
+#        print(f'\tknown symbols:\n\t\t{ast.user_defined_symbols()}\n')
+        if var in ast.user_defined_symbols():
+            val = ast.symtable[var]
+        else:
+            #val = eval_expr(var, expr_variables[var], expr_variables, user_variables=user_variables,
+            val = eval_expr(var, expr_variables[var], expr_variables, max_depth=max_depth,
+                    raise_error=raise_error, log=log, chain=chain, ast=ast)
+            if val is None:
+                return(None)
+            ast.symtable[var] = val
+#        print(f'\tval = {val}')
+#        print(f'\t{var} = {ast.symtable[var]}')
+        chain = save_chain.copy()
+#        print(f'\treset loop for {var}: chain = {chain}')
+    val = ast.eval(expr)
+#    print(f'return val for {expr} = {val}\n')
+    return(val)
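+
+# A minimal sketch of nested-expression evaluation (hypothetical variables): 'b'
+# depends on 'c', which is itself an expression:
+#     >>> eval_expr('a', 'b+1', {'b': 'c*2', 'c': '3'})
+#     7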
+
+def full_gradient(expr, x, expr_name=None, expr_variables=None, valid_variables=None, max_depth=10,
+        raise_error=False, log=True, **kwargs):
+    """Compute the full gradient dexpr/dx
+    """
+    if not isinstance(x, str):
+        illegal_value(x, 'x', 'full_gradient', raise_error, log)
+        return(None)
+    if expr_name is not None and not isinstance(expr_name, str):
+        illegal_value(expr_name, 'expr_name', 'full_gradient', raise_error, log)
+        return(None)
+    if expr_variables is not None and not is_dict_strings(expr_variables, log=False):
+        illegal_value(expr_variables, 'expr_variables', 'full_gradient', raise_error, log)
+        return(None)
+    if valid_variables is not None and not is_str_series(valid_variables, log=False):
+        illegal_value(valid_variables, 'valid_variables', 'full_gradient', raise_error, log)
+        return(None)
+    if not is_int(max_depth, gt=1, log=False):
+        illegal_value(max_depth, 'max_depth', 'full_gradient', raise_error, log)
+        return(None)
+    if not isinstance(raise_error, bool):
+        illegal_value(raise_error, 'raise_error', 'full_gradient', raise_error, log)
+        return(None)
+    if not isinstance(log, bool):
+        illegal_value(log, 'log', 'full_gradient', raise_error, log)
+        return(None)
+#    print(f'\nGet full gradient of {expr_name} = {expr} with respect to {x}')
+    if expr_name is not None and expr_name == x:
+        return(1.0)
+    if 'chain' in kwargs:
+        chain = kwargs.pop('chain')
+        if not is_str_series(chain):
+            illegal_value(chain, 'chain', 'full_gradient', raise_error, log)
+            return(None)
+    else:
+        chain = []
+    if len(chain) > max_depth:
+        error_msg = f'Exceeded maximum depth ({max_depth}) in full_gradient'
+        if log:
+            logger.error(error_msg)
+        if raise_error:
+            raise ValueError(error_msg)
+        return(None)
+    if expr_name is not None and expr_name not in chain:
+        chain.append(expr_name)
+#    print(f'start ({x}): chain = {chain}')
+    ast = Interpreter()
+    if expr_variables is None:
+        chain_vars = []
+    else:
+        chain_vars = [var for var in get_ast_names(ast.parse(f'{expr}'))
+                if var in expr_variables and var != x and var not in ast.symtable]
+#    print(f'chain_vars: {chain_vars}')
+    if valid_variables is not None:
+        unknown_vars = [var for var in chain_vars if var not in valid_variables]
+        if len(unknown_vars):
+            error_msg = f'Unknown variable {unknown_vars} in {expr}'
+            if log:
+                logger.error(error_msg)
+            if raise_error:
+                raise ValueError(error_msg)
+            return(None)
+    dexpr_dx = diff(expr, x)
+#    print(f'direct gradient: d({expr})/d({x}) = {dexpr_dx} ({type(dexpr_dx)})')
+    save_chain = chain.copy()
+    for var in chain_vars:
+#        print(f'\n\texpr_name = {expr_name}, var = {var}:\n\t\t{expr}')
+#        print(f'\tchain = {chain}')
+        if var in chain:
+            error_msg = f'Circular variable {var} in full_gradient'
+            if log:
+                logger.error(error_msg)
+            if raise_error:
+                raise ValueError(error_msg)
+            return(None)
+        dexpr_dvar = diff(expr, var)
+#        print(f'\td({expr})/d({var}) = {dexpr_dvar}')
+        if dexpr_dvar:
+            dvar_dx = full_gradient(expr_variables[var], x, expr_name=var,
+                    expr_variables=expr_variables, valid_variables=valid_variables,
+                    max_depth=max_depth, raise_error=raise_error, log=log, chain=chain)
+#            print(f'\t\td({var})/d({x}) = {dvar_dx}')
+            if dvar_dx:
+                dexpr_dx = f'{dexpr_dx}+({dexpr_dvar})*({dvar_dx})'
+#            print(f'\t\t2: chain = {chain}')
+        chain = save_chain.copy()
+#        print(f'\treset loop for {var}: chain = {chain}')
+#    print(f'full gradient: d({expr})/d({x}) = {dexpr_dx} ({type(dexpr_dx)})')
+#    print(f'reset end: chain = {chain}\n\n')
+    return(simplify(dexpr_dx))
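+
+# A minimal sketch (hypothetical expressions, requires sympy): the direct derivative of
+# a*x+b with respect to x is a, and the chain rule through a = x**2 adds x*(2*x); the
+# printed form of the sympy result may vary slightly:
+#     >>> full_gradient('a*x + b', 'x', expr_variables={'a': 'x**2'})
+#     a + 2*x**2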
+
+def bounds_from_mask(mask, return_include_bounds:bool=True):
+    bounds = []
+    for i, m in enumerate(mask):
+        if m == return_include_bounds:
+            if len(bounds) == 0 or isinstance(bounds[-1], tuple):
+                bounds.append(i)
+        else:
+            if len(bounds) > 0 and isinstance(bounds[-1], int):
+                bounds[-1] = (bounds[-1], i-1)
+    if len(bounds) > 0 and isinstance(bounds[-1], int):
+        bounds[-1] = (bounds[-1], mask.size-1)
+    return(bounds)
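+
+# Illustrative example (hypothetical mask): (start, stop) index ranges where the mask
+# is True (stops are inclusive):
+#     >>> bounds_from_mask(np.array([True, True, False, True]))
+#     [(0, 1), (3, 3)]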
+
+def draw_mask_1d(ydata, xdata=None, current_index_ranges=None, current_mask=None,
+        select_mask=True, num_index_ranges_max=None, title=None, legend=None, test_mode=False):
+    #FIX make color blind friendly
+    def draw_selections(ax, current_include, current_exclude, selected_index_ranges):
+        ax.clear()
+        ax.set_title(title)
+        ax.legend([legend])
+        ax.plot(xdata, ydata, 'k')
+        for (low, upp) in current_include:
+            xlow = 0.5*(xdata[max(0, low-1)]+xdata[low])
+            xupp = 0.5*(xdata[upp]+xdata[min(num_data-1, upp+1)])
+            ax.axvspan(xlow, xupp, facecolor='green', alpha=0.5)
+        for (low, upp) in current_exclude:
+            xlow = 0.5*(xdata[max(0, low-1)]+xdata[low])
+            xupp = 0.5*(xdata[upp]+xdata[min(num_data-1, upp+1)])
+            ax.axvspan(xlow, xupp, facecolor='red', alpha=0.5)
+        for (low, upp) in selected_index_ranges:
+            xlow = 0.5*(xdata[max(0, low-1)]+xdata[low])
+            xupp = 0.5*(xdata[upp]+xdata[min(num_data-1, upp+1)])
+            ax.axvspan(xlow, xupp, facecolor=selection_color, alpha=0.5)
+        ax.get_figure().canvas.draw()
+
+    def onclick(event):
+        if event.inaxes in [fig.axes[0]]:
+            selected_index_ranges.append(index_nearest_upp(xdata, event.xdata))
+
+    def onrelease(event):
+        if len(selected_index_ranges) > 0:
+            if isinstance(selected_index_ranges[-1], int):
+                if event.inaxes in [fig.axes[0]]:
+                    event.xdata = index_nearest_low(xdata, event.xdata)
+                    if selected_index_ranges[-1] <= event.xdata:
+                        selected_index_ranges[-1] = (selected_index_ranges[-1], event.xdata)
+                    else:
+                        selected_index_ranges[-1] = (event.xdata, selected_index_ranges[-1])
+                    draw_selections(event.inaxes, current_include, current_exclude, selected_index_ranges)
+                else:
+                    selected_index_ranges.pop(-1)
+
+    def confirm_selection(event):
+        plt.close()
+
+    def clear_last_selection(event):
+        if len(selected_index_ranges):
+            selected_index_ranges.pop(-1)
+        else:
+            while len(current_include):
+                current_include.pop()
+            while len(current_exclude):
+                current_exclude.pop()
+            selected_mask.fill(False)
+        draw_selections(ax, current_include, current_exclude, selected_index_ranges)
+
+    def update_mask(mask, selected_index_ranges, unselected_index_ranges):
+        for (low, upp) in selected_index_ranges:
+            selected_mask = np.logical_and(xdata >= xdata[low], xdata <= xdata[upp])
+            mask = np.logical_or(mask, selected_mask)
+        for (low, upp) in unselected_index_ranges:
+            unselected_mask = np.logical_and(xdata >= xdata[low], xdata <= xdata[upp])
+            mask[unselected_mask] = False
+        return(mask)
+
+    def update_index_ranges(mask):
+        # Update the currently included index ranges (where mask is True)
+        current_include = []
+        for i, m in enumerate(mask):
+            if m:
+                if len(current_include) == 0 or isinstance(current_include[-1], tuple):
+                    current_include.append(i)
+            else:
+                if len(current_include) > 0 and isinstance(current_include[-1], int):
+                    current_include[-1] = (current_include[-1], i-1)
+        if len(current_include) > 0 and isinstance(current_include[-1], int):
+            current_include[-1] = (current_include[-1], num_data-1)
+        return(current_include)
+
+    # Check inputs
+    ydata = np.asarray(ydata)
+    if ydata.ndim > 1:
+        logger.warning(f'Invalid ydata dimension ({ydata.ndim})')
+        return(None, None)
+    num_data = ydata.size
+    if xdata is None:
+        xdata = np.arange(num_data)
+    else:
+        xdata = np.asarray(xdata, dtype=np.float64)
+        if xdata.ndim > 1 or xdata.size != num_data:
+            logger.warning(f'Invalid xdata shape ({xdata.shape})')
+            return(None, None)
+        if not np.all(xdata[:-1] < xdata[1:]):
+            logger.warning('Invalid xdata: must be monotonically increasing')
+            return(None, None)
+    if current_index_ranges is not None:
+        if not isinstance(current_index_ranges, (tuple, list)):
+            logger.warning(f'Invalid current_index_ranges parameter ({current_index_ranges}, '+
+                    f'{type(current_index_ranges)})')
+            return(None, None)
+    if not isinstance(select_mask, bool):
+        logger.warning(f'Invalid select_mask parameter ({select_mask}, {type(select_mask)})')
+        return(None, None)
+    if num_index_ranges_max is not None:
+        logger.warning('num_index_ranges_max input not yet implemented in draw_mask_1d')
+    if title is None:
+        title = 'select ranges of data'
+    elif not isinstance(title, str):
+        illegal_value(title, 'title')
+        title = ''
+    if legend is not None and not isinstance(legend, str):
+        illegal_value(legend, 'legend')
+        legend = None
+
+    if select_mask:
+        title = f'Click and drag to {title} you wish to include'
+        selection_color = 'green'
+    else:
+        title = f'Click and drag to {title} you wish to exclude'
+        selection_color = 'red'
+
+    # Set initial selected mask and the selected/unselected index ranges as needed
+    selected_index_ranges = []
+    unselected_index_ranges = []
+    selected_mask = np.full(xdata.shape, False, dtype=bool)
+    if current_index_ranges is None:
+        if current_mask is None:
+            if not select_mask:
+                selected_index_ranges = [(0, num_data-1)]
+                selected_mask = np.full(xdata.shape, True, dtype=bool)
+        else:
+            selected_mask = np.copy(np.asarray(current_mask, dtype=bool))
+    if current_index_ranges is not None and len(current_index_ranges):
+        current_index_ranges = sorted([(low, upp) for (low, upp) in current_index_ranges])
+        for (low, upp) in current_index_ranges:
+            if low > upp or low >= num_data or upp < 0:
+                continue
+            if low < 0:
+                low = 0
+            if upp >= num_data:
+                upp = num_data-1
+            selected_index_ranges.append((low, upp))
+        selected_mask = update_mask(selected_mask, selected_index_ranges, unselected_index_ranges)
+    if current_index_ranges is not None and current_mask is not None:
+        selected_mask = np.logical_and(current_mask, selected_mask)
+    if current_mask is not None:
+        selected_index_ranges = update_index_ranges(selected_mask)
+
+    # Set up range selections for display
+    current_include = selected_index_ranges
+    current_exclude = []
+    selected_index_ranges = []
+    if not len(current_include):
+        if select_mask:
+            current_exclude = [(0, num_data-1)]
+        else:
+            current_include = [(0, num_data-1)]
+    else:
+        if current_include[0][0] > 0:
+            current_exclude.append((0, current_include[0][0]-1))
+        for i in range(1, len(current_include)):
+            current_exclude.append((current_include[i-1][1]+1, current_include[i][0]-1))
+        if current_include[-1][1] < num_data-1:
+            current_exclude.append((current_include[-1][1]+1, num_data-1))
+
+    if not test_mode:
+
+        # Set up matplotlib figure
+        plt.close('all')
+        fig, ax = plt.subplots()
+        plt.subplots_adjust(bottom=0.2)
+        draw_selections(ax, current_include, current_exclude, selected_index_ranges)
+
+        # Set up event handling for click-and-drag range selection
+        cid_click = fig.canvas.mpl_connect('button_press_event', onclick)
+        cid_release = fig.canvas.mpl_connect('button_release_event', onrelease)
+
+        # Set up confirm / clear range selection buttons
+        confirm_b = Button(plt.axes([0.75, 0.05, 0.15, 0.075]), 'Confirm')
+        clear_b = Button(plt.axes([0.59, 0.05, 0.15, 0.075]), 'Clear')
+        cid_confirm = confirm_b.on_clicked(confirm_selection)
+        cid_clear = clear_b.on_clicked(clear_last_selection)
+
+        # Show figure
+        plt.show(block=True)
+
+        # Disconnect callbacks when figure is closed
+        fig.canvas.mpl_disconnect(cid_click)
+        fig.canvas.mpl_disconnect(cid_release)
+        confirm_b.disconnect(cid_confirm)
+        clear_b.disconnect(cid_clear)
+
+    # Swap selection depending on select_mask
+    if not select_mask:
+        selected_index_ranges, unselected_index_ranges = unselected_index_ranges, \
+                selected_index_ranges
+
+    # Update the mask with the currently selected/unselected x-ranges
+    selected_mask = update_mask(selected_mask, selected_index_ranges, unselected_index_ranges)
+
+    # Update the currently included index ranges (where mask is True)
+    current_include = update_index_ranges(selected_mask)
+    
+    return(selected_mask, current_include)
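+
+# Example usage of draw_mask_1d (a minimal sketch; the keyword names are taken from
+# the checks above and an interactive matplotlib backend is assumed):
+#
+#     mask, include_ranges = draw_mask_1d(ydata, select_mask=True)
+#
+# mask is a boolean array over ydata; include_ranges lists the matching inclusive
+# (low, upp) index tuples.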
+
+def select_peaks(ydata:np.ndarray, x_values:np.ndarray=None, x_mask:np.ndarray=None,
+                 peak_x_values:np.ndarray=np.array([]), peak_x_indices:np.ndarray=np.array([]),
+                 return_peak_x_values:bool=False, return_peak_x_indices:bool=False,
+                 return_peak_input_indices:bool=False, return_sorted:bool=False,
+                 title:str=None, xlabel:str=None, ylabel:str=None) -> list :
+
+    # Check arguments
+    if (len(peak_x_values) > 0 or return_peak_x_values) and (x_values is None or not len(x_values) > 0):
+        raise RuntimeError('Cannot use peak_x_values or return_peak_x_values without x_values')
+    if not ((len(peak_x_values) > 0) ^ (len(peak_x_indices) > 0)):
+        raise RuntimeError('Use exactly one of peak_x_values or peak_x_indices')
+    if sum((return_peak_x_values, return_peak_x_indices, return_peak_input_indices)) != 1:
+        raise RuntimeError('Exactly one of return_peak_x_values, return_peak_x_indices, or '+
+                'return_peak_input_indices must be True')
+
+    EXCLUDE_PEAK_PROPERTIES = {'color': 'black', 'linestyle': '--','linewidth': 1,
+                               'marker': 10, 'markersize': 5, 'fillstyle': 'none'}
+    INCLUDE_PEAK_PROPERTIES = {'color': 'green', 'linestyle': '-', 'linewidth': 2,
+                               'marker': 10, 'markersize': 10, 'fillstyle': 'full'}
+    MASKED_PEAK_PROPERTIES = {'color': 'gray', 'linestyle': ':', 'linewidth': 1}
+
+    # Setup reference data & plot
+    x_indices = np.arange(len(ydata))
+    if x_values is None:
+        x_values = x_indices
+    if x_mask is None:
+        x_mask = np.full(x_values.shape, True, dtype=bool)
+    fig, ax = plt.subplots()
+    handles = ax.plot(x_values, ydata, label='Reference data')
+    handles.append(mlines.Line2D([], [], label='Excluded / unselected HKL', **EXCLUDE_PEAK_PROPERTIES))
+    handles.append(mlines.Line2D([], [], label='Included / selected HKL', **INCLUDE_PEAK_PROPERTIES))
+    handles.append(mlines.Line2D([], [], label='HKL in masked region (unselectable)', **MASKED_PEAK_PROPERTIES))
+    ax.legend(handles=handles, loc='upper right')
+    ax.set(title=title, xlabel=xlabel, ylabel=ylabel)
+
+
+    # Plot vertical line at each peak
+    value_to_index = lambda x_value: int(np.argmin(abs(x_values - x_value)))
+    if len(peak_x_indices) > 0:
+        peak_x_values = x_values[peak_x_indices]
+    else:
+        peak_x_indices = np.array(list(map(value_to_index, peak_x_values)))
+    peak_vlines = []
+    for loc in peak_x_values:
+        nearest_index = value_to_index(loc)
+        if nearest_index in x_indices[x_mask]:
+            peak_vline = ax.axvline(loc, **EXCLUDE_PEAK_PROPERTIES)
+            peak_vline.set_picker(5)
+        else:
+            peak_vline = ax.axvline(loc, **MASKED_PEAK_PROPERTIES)
+        peak_vlines.append(peak_vline)
+
+    # Indicate masked regions by gray-ing out the axes facecolor
+    mask_exclude_bounds = bounds_from_mask(x_mask, return_include_bounds=False)
+    for (low, upp) in mask_exclude_bounds:
+        xlow = x_values[low]
+        xupp = x_values[upp]
+        ax.axvspan(xlow, xupp, facecolor='gray', alpha=0.5)
+
+    # Setup peak picking
+    selected_peak_input_indices = []
+    def onpick(event):
+        try:
+            peak_index = peak_vlines.index(event.artist)
+        except:
+            pass
+        else:
+            peak_vline = event.artist
+            if peak_index in selected_peak_input_indices:
+                peak_vline.set(**EXCLUDE_PEAK_PROPERTIES)
+                selected_peak_input_indices.remove(peak_index)
+            else:
+                peak_vline.set(**INCLUDE_PEAK_PROPERTIES)
+                selected_peak_input_indices.append(peak_index)
+            plt.draw()
+    cid_pick_peak = fig.canvas.mpl_connect('pick_event', onpick)
+
+    # Setup "Confirm" button
+    def confirm_selection(event):
+        plt.close()
+    plt.subplots_adjust(bottom=0.2)
+    confirm_b = Button(plt.axes([0.75, 0.05, 0.15, 0.075]), 'Confirm')
+    cid_confirm = confirm_b.on_clicked(confirm_selection)
+
+    # Show figure for user interaction
+    plt.show()
+
+    # Disconnect callbacks when figure is closed
+    fig.canvas.mpl_disconnect(cid_pick_peak)
+    confirm_b.disconnect(cid_confirm)
+
+    if return_peak_input_indices:
+        selected_peaks = np.array(selected_peak_input_indices)
+    if return_peak_x_values:
+        selected_peaks = peak_x_values[selected_peak_input_indices]
+    if return_peak_x_indices:
+        selected_peaks = peak_x_indices[selected_peak_input_indices]
+
+    if return_sorted:
+        selected_peaks.sort()
+
+    return(selected_peaks)
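+
+# Example usage of select_peaks (a sketch; an interactive backend is assumed):
+#
+#     selected = select_peaks(ydata, x_values=x_values,
+#                             peak_x_values=np.array([1.2, 3.4]),
+#                             return_peak_x_values=True)
+#
+# Exactly one of peak_x_values/peak_x_indices and exactly one return_* flag may be set.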
+
+def find_image_files(path, filetype, name=None):
+    if isinstance(name, str):
+        name = f'{name.strip()} '
+    else:
+        name = ''
+    # Find available index range
+    if filetype == 'tif':
+        if not isinstance(path, str) or not os.path.isdir(path):
+            illegal_value(path, 'path', 'find_image_files')
+            return(-1, 0, [])
+        indexRegex = re.compile(r'\d+')
+        # At this point only tiffs
+        files = sorted([f for f in os.listdir(path) if os.path.isfile(os.path.join(path, f)) and
+                f.endswith('.tif') and indexRegex.search(f)])
+        num_img = len(files)
+        if num_img < 1:
+            logger.warning(f'No available {name}files')
+            return(-1, 0, [])
+        first_match = indexRegex.search(files[0])
+        last_match = indexRegex.search(files[-1])
+        if first_match is None or last_match is None:
+            logger.error(f'Unable to find correctly indexed {name}images')
+            return(-1, 0, [])
+        first_index = int(first_match.group())
+        last_index = int(last_match.group())
+        if num_img != last_index-first_index+1:
+            logger.error(f'Non-consecutive set of indices for {name}images')
+            return(-1, 0, [])
+        paths = [os.path.join(path, f) for f in files]
+    elif filetype == 'h5':
+        if not isinstance(path, str) or not os.path.isfile(path):
+            illegal_value(path, 'path', 'find_image_files')
+            return(-1, 0, [])
+        # At this point only h5 in alamo2 detector style
+        first_index = 0
+        with h5py.File(path, 'r') as f:
+            num_img = f['entry/instrument/detector/data'].shape[0]
+            last_index = num_img-1
+        paths = [path]
+    else:
+        illegal_value(filetype, 'filetype', 'find_image_files')
+        return(-1, 0, [])
+    logger.info(f'Number of available {name}images: {num_img}')
+    logger.info(f'Index range of available {name}images: [{first_index}, '+
+            f'{last_index}]')
+
+    return(first_index, num_img, paths)
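+
+# find_image_files returns (first_index, num_img, paths); e.g. for a directory of
+# consecutively indexed tiffs img_0001.tif ... img_0100.tif (hypothetical names) it
+# returns (1, 100, [list of full file paths]), and (-1, 0, []) on failure.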
+
+def select_image_range(first_index, offset, num_available, num_img=None, name=None,
+        num_required=None):
+    if isinstance(name, str):
+        name = f'{name.strip()} '
+    else:
+        name = ''
+    # Check existing values
+    if not is_int(num_available, gt=0):
+        logger.warning(f'No available {name}images')
+        return(0, 0, 0)
+    if num_img is not None and not is_int(num_img, ge=0):
+        illegal_value(num_img, 'num_img', 'select_image_range')
+        return(0, 0, 0)
+    if is_int(first_index, ge=0) and is_int(offset, ge=0):
+        if num_required is None:
+            if input_yesno(f'\nCurrent {name}first image index/offset = {first_index}/{offset}, '+
+                    'use these values (y/n)?', 'y'):
+                if num_img is not None:
+                    if input_yesno(f'Current number of {name}images = {num_img}, '+
+                            'use this value (y/n)? ', 'y'):
+                        return(first_index, offset, num_img)
+                else:
+                    if input_yesno(f'Number of available {name}images = {num_available}, '+
+                            'use all (y/n)? ', 'y'):
+                        return(first_index, offset, num_available)
+        else:
+            if input_yesno(f'\nCurrent {name}first image offset = {offset}, '+
+                    f'use this value (y/n)?', 'y'):
+                return(first_index, offset, num_required)
+
+    # Check range against requirements
+    if num_required is None:
+        if num_available == 1:
+            return(first_index, 0, 1)
+    else:
+        if not is_int(num_required, ge=1):
+            illegal_value(num_required, 'num_required', 'select_image_range')
+            return(0, 0, 0)
+        if num_available < num_required:
+            logger.error(f'Unable to find the required {name}images ({num_available} out of '+
+                    f'{num_required})')
+            return(0, 0, 0)
+
+    # Select index range
+    print(f'\nThe number of available {name}images is {num_available}')
+    if num_required is None:
+        last_index = first_index+num_available-1
+        use_all = f'Use all ([{first_index}, {last_index}])'
+        pick_offset = 'Pick the first image index offset and the number of images'
+        pick_bounds = 'Pick the first and last image index'
+        choice = input_menu([use_all, pick_offset, pick_bounds], default=pick_offset)
+        if not choice:
+            offset = 0
+            num_img = num_available
+        elif choice == 1:
+            offset = input_int('Enter the first index offset', ge=0, le=last_index-first_index)
+            if first_index+offset == last_index:
+                num_img = 1
+            else:
+                num_img = input_int('Enter the number of images', ge=1, le=num_available-offset)
+        else:
+            offset = input_int('Enter the first index', ge=first_index, le=last_index)
+            num_img = 1-offset+input_int('Enter the last index', ge=offset, le=last_index)
+            offset -= first_index
+    else:
+        use_all = f'Use ([{first_index}, {first_index+num_required-1}])'
+        pick_offset = 'Pick the first index offset'
+        choice = input_menu([use_all, pick_offset], pick_offset)
+        offset = 0
+        if choice == 1:
+            offset = input_int('Enter the first index offset', ge=0, le=num_available-num_required)
+        num_img = num_required
+
+    return(first_index, offset, num_img)
+
+def load_image(f, img_x_bounds=None, img_y_bounds=None):
+    """Load a single image from file.
+    """
+    if not os.path.isfile(f):
+        logger.error(f'Unable to load {f}')
+        return(None)
+    img_read = plt.imread(f)
+    if not img_x_bounds:
+        img_x_bounds = (0, img_read.shape[0])
+    else:
+        if (not isinstance(img_x_bounds, (tuple, list)) or len(img_x_bounds) != 2 or 
+                not (0 <= img_x_bounds[0] < img_x_bounds[1] <= img_read.shape[0])):
+            logger.error(f'inconsistent row dimension in {f}')
+            return(None)
+    if not img_y_bounds:
+        img_y_bounds = (0, img_read.shape[1])
+    else:
+        if (not isinstance(img_y_bounds, (tuple, list)) or len(img_y_bounds) != 2 or
+                not (0 <= img_y_bounds[0] < img_y_bounds[1] <= img_read.shape[1])):
+            logger.error(f'inconsistent column dimension in {f}')
+            return(None)
+    return(img_read[img_x_bounds[0]:img_x_bounds[1],img_y_bounds[0]:img_y_bounds[1]])
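+
+# The bounds are half-open, as in Python slices; a sketch with a hypothetical file:
+#
+#     img = load_image('image_0001.tif', img_x_bounds=(10, 20))  # keeps rows 10-19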
+
+def load_image_stack(files, filetype, img_offset, num_img, num_img_skip=0,
+        img_x_bounds=None, img_y_bounds=None):
+    """Load a set of images and return them as a stack.
+    """
+    logger.debug(f'img_offset = {img_offset}')
+    logger.debug(f'num_img = {num_img}')
+    logger.debug(f'num_img_skip = {num_img_skip}')
+    logger.debug(f'\nfiles:\n{files}\n')
+    img_stack = np.array([])
+    if filetype == 'tif':
+        img_read_stack = []
+        i = 1
+        t0 = time()
+        for f in files[img_offset:img_offset+num_img:num_img_skip+1]:
+            if not i%20:
+                logger.info(f'    loading {i}/{num_img}: {f}')
+            else:
+                logger.debug(f'    loading {i}/{num_img}: {f}')
+            img_read = load_image(f, img_x_bounds, img_y_bounds)
+            img_read_stack.append(img_read)
+            i += num_img_skip+1
+        img_stack = np.stack(img_read_stack)
+        logger.info(f'... done in {time()-t0:.2f} seconds!')
+        logger.debug(f'img_stack shape = {np.shape(img_stack)}')
+        del img_read_stack, img_read
+    elif filetype == 'h5':
+        if not isinstance(files[0], str) or not os.path.isfile(files[0]):
+            illegal_value(files[0], 'files[0]', 'load_image_stack')
+            return(img_stack)
+        t0 = time()
+        logger.info(f'Loading {files[0]}')
+        with h5py.File(files[0], 'r') as f:
+            shape = f['entry/instrument/detector/data'].shape
+            if len(shape) != 3:
+                logger.error(f'inconsistent dimensions in {files[0]}')
+            if not img_x_bounds:
+                img_x_bounds = (0, shape[1])
+            else:
+                if (not isinstance(img_x_bounds, (tuple, list)) or len(img_x_bounds) != 2 or 
+                        not (0 <= img_x_bounds[0] < img_x_bounds[1] <= shape[1])):
+                    logger.error(f'inconsistent row dimension in {files[0]} {img_x_bounds} '+
+                            f'{shape[1]}')
+            if not img_y_bounds:
+                img_y_bounds = (0, shape[2])
+            else:
+                if (not isinstance(img_y_bounds, (tuple, list)) or len(img_y_bounds) != 2 or
+                        not (0 <= img_y_bounds[0] < img_y_bounds[1] <= shape[2])):
+                    logger.error(f'inconsistent column dimension in {files[0]}')
+            img_stack = f.get('entry/instrument/detector/data')[
+                    img_offset:img_offset+num_img:num_img_skip+1,
+                    img_x_bounds[0]:img_x_bounds[1],img_y_bounds[0]:img_y_bounds[1]]
+        logger.info(f'... done in {time()-t0:.2f} seconds!')
+    else:
+        illegal_value(filetype, 'filetype', 'load_image_stack')
+    return(img_stack)
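+
+# Example (a sketch with hypothetical paths): load every other image of a tiff series
+#
+#     stack = load_image_stack(paths, 'tif', img_offset=0, num_img=100, num_img_skip=1)
+#     # stack.shape = (50, rows, columns)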
+
+def combine_tiffs_in_h5(files, num_img, h5_filename):
+    img_stack = load_image_stack(files, 'tif', 0, num_img)
+    with h5py.File(h5_filename, 'w') as f:
+        f.create_dataset('entry/instrument/detector/data', data=img_stack)
+    del img_stack
+    return([h5_filename])
+
+def clear_imshow(title=None):
+    plt.ioff()
+    if title is None:
+        title = 'quick imshow'
+    elif not isinstance(title, str):
+        illegal_value(title, 'title', 'clear_imshow')
+        return
+    plt.close(fig=title)
+
+def clear_plot(title=None):
+    plt.ioff()
+    if title is None:
+        title = 'quick plot'
+    elif not isinstance(title, str):
+        illegal_value(title, 'title', 'clear_plot')
+        return
+    plt.close(fig=title)
+
+def quick_imshow(a, title=None, path=None, name=None, save_fig=False, save_only=False,
+            clear=True, extent=None, show_grid=False, grid_color='w', grid_linewidth=1,
+            block=False, **kwargs):
+    if title is not None and not isinstance(title, str):
+        illegal_value(title, 'title', 'quick_imshow')
+        return
+    if path is not None and not isinstance(path, str):
+        illegal_value(path, 'path', 'quick_imshow')
+        return
+    if not isinstance(save_fig, bool):
+        illegal_value(save_fig, 'save_fig', 'quick_imshow')
+        return
+    if not isinstance(save_only, bool):
+        illegal_value(save_only, 'save_only', 'quick_imshow')
+        return
+    if not isinstance(clear, bool):
+        illegal_value(clear, 'clear', 'quick_imshow')
+        return
+    if not isinstance(block, bool):
+        illegal_value(block, 'block', 'quick_imshow')
+        return
+    if not title:
+        title='quick imshow'
+#    else:
+#        title = re.sub(r"\s+", '_', title)
+    if name is None:
+        ttitle = re.sub(r"\s+", '_', title)
+        if path is None:
+            path = f'{ttitle}.png'
+        else:
+            path = f'{path}/{ttitle}.png'
+    else:
+        if path is None:
+            path = name
+        else:
+            path = f'{path}/{name}'
+    if 'cmap' in kwargs and a.ndim == 3 and (a.shape[2] == 3 or a.shape[2] == 4):
+        use_cmap = True
+        if a.shape[2] == 4 and a[:,:,-1].min() != a[:,:,-1].max():
+            use_cmap = False
+        if any(a[i,j,0] != a[i,j,1] or a[i,j,0] != a[i,j,2]
+                for i in range(a.shape[0]) for j in range(a.shape[1])):
+            use_cmap = False
+        if use_cmap:
+            a = a[:,:,0]
+        else:
+            logger.warning('Image incompatible with cmap option, ignore cmap')
+            kwargs.pop('cmap')
+    if extent is None:
+        extent = (0, a.shape[1], a.shape[0], 0)
+    if clear:
+        try:
+            plt.close(fig=title)
+        except:
+            pass
+    if not save_only:
+        if block:
+            plt.ioff()
+        else:
+            plt.ion()
+    plt.figure(title)
+    plt.imshow(a, extent=extent, **kwargs)
+    if show_grid:
+        ax = plt.gca()
+        ax.grid(color=grid_color, linewidth=grid_linewidth)
+#    if title != 'quick imshow':
+#        plt.title = title
+    if save_only:
+        plt.savefig(path)
+        plt.close(fig=title)
+    else:
+        if save_fig:
+            plt.savefig(path)
+        if block:
+            plt.show(block=block)
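+
+# Example usage of quick_imshow (a sketch): save a figure without displaying it
+#
+#     quick_imshow(a, title='dark field', path='plots', save_only=True, cmap='gray')
+#
+# With save_only=True the figure is written to plots/dark_field.png and closed.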
+
+def quick_plot(*args, xerr=None, yerr=None, vlines=None, title=None, xlim=None, ylim=None,
+        xlabel=None, ylabel=None, legend=None, path=None, name=None, show_grid=False,
+        save_fig=False, save_only=False, clear=True, block=False, **kwargs):
+    if title is not None and not isinstance(title, str):
+        illegal_value(title, 'title', 'quick_plot')
+        title = None
+    if xlim is not None and (not isinstance(xlim, (tuple, list)) or len(xlim) != 2):
+        illegal_value(xlim, 'xlim', 'quick_plot')
+        xlim = None
+    if ylim is not None and (not isinstance(ylim, (tuple, list)) or len(ylim) != 2):
+        illegal_value(ylim, 'ylim', 'quick_plot')
+        ylim = None
+    if xlabel is not None and not isinstance(xlabel, str):
+        illegal_value(xlabel, 'xlabel', 'quick_plot')
+        xlabel = None
+    if ylabel is not None and not isinstance(ylabel, str):
+        illegal_value(ylabel, 'ylabel', 'quick_plot')
+        ylabel = None
+    if legend is not None and not isinstance(legend, (tuple, list)):
+        illegal_value(legend, 'legend', 'quick_plot')
+        legend = None
+    if path is not None and not isinstance(path, str):
+        illegal_value(path, 'path', 'quick_plot')
+        return
+    if not isinstance(show_grid, bool):
+        illegal_value(show_grid, 'show_grid', 'quick_plot')
+        return
+    if not isinstance(save_fig, bool):
+        illegal_value(save_fig, 'save_fig', 'quick_plot')
+        return
+    if not isinstance(save_only, bool):
+        illegal_value(save_only, 'save_only', 'quick_plot')
+        return
+    if not isinstance(clear, bool):
+        illegal_value(clear, 'clear', 'quick_plot')
+        return
+    if not isinstance(block, bool):
+        illegal_value(block, 'block', 'quick_plot')
+        return
+    if title is None:
+        title = 'quick plot'
+#    else:
+#        title = re.sub(r"\s+", '_', title)
+    if name is None:
+        ttitle = re.sub(r"\s+", '_', title)
+        if path is None:
+            path = f'{ttitle}.png'
+        else:
+            path = f'{path}/{ttitle}.png'
+    else:
+        if path is None:
+            path = name
+        else:
+            path = f'{path}/{name}'
+    if clear:
+        try:
+            plt.close(fig=title)
+        except:
+            pass
+    args = unwrap_tuple(args)
+    if depth_tuple(args) > 1 and (xerr is not None or yerr is not None):
+        logger.warning('Error bars ignored for multiple curves')
+    if not save_only:
+        if block:
+            plt.ioff()
+        else:
+            plt.ion()
+    plt.figure(title)
+    if depth_tuple(args) > 1:
+        for y in args:
+            plt.plot(*y, **kwargs)
+    else:
+        if xerr is None and yerr is None:
+            plt.plot(*args, **kwargs)
+        else:
+            plt.errorbar(*args, xerr=xerr, yerr=yerr, **kwargs)
+    if vlines is not None:
+        if isinstance(vlines, (int, float)):
+            vlines = [vlines]
+        for v in vlines:
+            plt.axvline(v, color='r', linestyle='--', **kwargs)
+#    if vlines is not None:
+#        for s in tuple(([x, x], list(plt.gca().get_ylim())) for x in vlines):
+#            plt.plot(*s, color='red', **kwargs)
+    if xlim is not None:
+        plt.xlim(xlim)
+    if ylim is not None:
+        plt.ylim(ylim)
+    if xlabel is not None:
+        plt.xlabel(xlabel)
+    if ylabel is not None:
+        plt.ylabel(ylabel)
+    if show_grid:
+        ax = plt.gca()
+        ax.grid(color='k')#, linewidth=1)
+    if legend is not None:
+        plt.legend(legend)
+    if save_only:
+        plt.savefig(path)
+        plt.close(fig=title)
+    else:
+        if save_fig:
+            plt.savefig(path)
+        if block:
+            plt.show(block=block)
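+
+# Example usage of quick_plot (a sketch): overlay vertical marker lines on a curve
+#
+#     quick_plot(np.arange(10), np.arange(10)**2, vlines=(2, 7),
+#             xlabel='index', ylabel='value', block=True)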
+
+def select_array_bounds(a, x_low=None, x_upp=None, num_x_min=None, ask_bounds=False,
+        title='select array bounds'):
+    """Interactively select the lower and upper data bounds for a numpy array.
+    """
+    a = np.asarray(a)
+    if a.ndim != 1:
+        illegal_value(a.ndim, 'array type or dimension', 'select_array_bounds')
+        return(None)
+    len_a = len(a)
+    if num_x_min is None:
+        num_x_min = 1
+    else:
+        if num_x_min < 2 or num_x_min > len_a:
+            logger.warning('Invalid value for num_x_min in select_array_bounds, input ignored')
+            num_x_min = 1
+
+    # Ask to use current bounds
+    if ask_bounds and (x_low is not None or x_upp is not None):
+        if x_low is None:
+            x_low = 0
+        if not is_int(x_low, ge=0, le=len_a-num_x_min):
+            illegal_value(x_low, 'x_low', 'select_array_bounds')
+            return(None)
+        if x_upp is None:
+            x_upp = len_a
+        if not is_int(x_upp, ge=x_low+num_x_min, le=len_a):
+            illegal_value(x_upp, 'x_upp', 'select_array_bounds')
+            return(None)
+        quick_plot((range(len_a), a), vlines=(x_low,x_upp), title=title)
+        if not input_yesno(f'\nCurrent array bounds: [{x_low}, {x_upp}], '+
+                    'use these values (y/n)?', 'y'):
+            x_low = None
+            x_upp = None
+        else:
+            clear_plot(title)
+            return(x_low, x_upp)
+
+    if x_low is None:
+        x_min = 0
+        x_max = len_a
+        x_low_max = len_a-num_x_min
+        while True:
+            quick_plot(range(x_min, x_max), a[x_min:x_max], title=title)
+            zoom_flag = input_yesno('Set lower data bound (y) or zoom in (n)?', 'y')
+            if zoom_flag:
+                x_low = input_int('    Set lower data bound', ge=0, le=x_low_max)
+                break
+            else:
+                x_min = input_int('    Set lower zoom index', ge=0, le=x_low_max)
+                x_max = input_int('    Set upper zoom index', ge=x_min+1, le=x_low_max+1)
+    else:
+        if not is_int(x_low, ge=0, le=len_a-num_x_min):
+            illegal_value(x_low, 'x_low', 'select_array_bounds')
+            return(None)
+    if x_upp is None:
+        x_min = x_low+num_x_min
+        x_max = len_a
+        x_upp_min = x_min
+        while True:
+            quick_plot(range(x_min, x_max), a[x_min:x_max], title=title)
+            zoom_flag = input_yesno('Set upper data bound (y) or zoom in (n)?', 'y')
+            if zoom_flag:
+                x_upp = input_int('    Set upper data bound', ge=x_upp_min, le=len_a)
+                break
+            else:
+                x_min = input_int('    Set lower zoom index', ge=x_upp_min, le=len_a-1)
+                x_max = input_int('    Set upper zoom index', ge=x_min+1, le=len_a)
+    else:
+        if not is_int(x_upp, ge=x_low+num_x_min, le=len_a):
+            illegal_value(x_upp, 'x_upp', 'select_array_bounds')
+            return(None)
+    print(f'lower bound = {x_low} (inclusive)\nupper bound = {x_upp} (exclusive)')
+    quick_plot((range(len_a), a), vlines=(x_low,x_upp), title=title)
+    if not input_yesno('Accept these bounds (y/n)?', 'y'):
+        x_low, x_upp = select_array_bounds(a, None, None, num_x_min, title=title)
+    clear_plot(title)
+    return(x_low, x_upp)
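+
+# select_array_bounds returns (x_low, x_upp) with x_low inclusive and x_upp exclusive,
+# matching Python slice semantics, or None on invalid input.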
+
+def select_image_bounds(a, axis, low=None, upp=None, num_min=None, title='select array bounds',
+        raise_error=False):
+    """Interactively select the lower and upper data bounds for a 2D numpy array.
+    """
+    a = np.asarray(a)
+    if a.ndim != 2:
+        illegal_value(a.ndim, 'array dimension', location='select_image_bounds',
+                raise_error=raise_error)
+        return(None)
+    if axis < 0 or axis >= a.ndim:
+        illegal_value(axis, 'axis', location='select_image_bounds', raise_error=raise_error)
+        return(None)
+    low_save = low
+    upp_save = upp
+    num_min_save = num_min
+    if num_min is None:
+        num_min = 1
+    else:
+        if num_min < 2 or num_min > a.shape[axis]:
+            logger.warning('Invalid input for num_min in select_image_bounds, input ignored')
+            num_min = 1
+    if low is None:
+        min_ = 0
+        max_ = a.shape[axis]
+        low_max = a.shape[axis]-num_min
+        while True:
+            if axis:
+                quick_imshow(a[:,min_:max_], title=title, aspect='auto',
+                        extent=[min_,max_,a.shape[0],0])
+            else:
+                quick_imshow(a[min_:max_,:], title=title, aspect='auto',
+                        extent=[0,a.shape[1], max_,min_])
+            zoom_flag = input_yesno('Set lower data bound (y) or zoom in (n)?', 'y')
+            if zoom_flag:
+                low = input_int('    Set lower data bound', ge=0, le=low_max)
+                break
+            else:
+                min_ = input_int('    Set lower zoom index', ge=0, le=low_max)
+                max_ = input_int('    Set upper zoom index', ge=min_+1, le=low_max+1)
+    else:
+        if not is_int(low, ge=0, le=a.shape[axis]-num_min):
+            illegal_value(low, 'low', location='select_image_bounds', raise_error=raise_error)
+            return(None)
+    if upp is None:
+        min_ = low+num_min
+        max_ = a.shape[axis]
+        upp_min = min_
+        while True:
+            if axis:
+                quick_imshow(a[:,min_:max_], title=title, aspect='auto',
+                        extent=[min_,max_,a.shape[0],0])
+            else:
+                quick_imshow(a[min_:max_,:], title=title, aspect='auto',
+                        extent=[0,a.shape[1], max_,min_])
+            zoom_flag = input_yesno('Set upper data bound (y) or zoom in (n)?', 'y')
+            if zoom_flag:
+                upp = input_int('    Set upper data bound', ge=upp_min, le=a.shape[axis])
+                break
+            else:
+                min_ = input_int('    Set lower zoom index', ge=upp_min, le=a.shape[axis]-1)
+                max_ = input_int('    Set upper zoom index', ge=min_+1, le=a.shape[axis])
+    else:
+        if not is_int(upp, ge=low+num_min, le=a.shape[axis]):
+            illegal_value(upp, 'upp', location='select_image_bounds', raise_error=raise_error)
+            return(None)
+    bounds = (low, upp)
+    a_tmp = np.copy(a)
+    a_tmp_max = a.max()
+    if axis:
+        a_tmp[:,bounds[0]] = a_tmp_max
+        a_tmp[:,bounds[1]-1] = a_tmp_max
+    else:
+        a_tmp[bounds[0],:] = a_tmp_max
+        a_tmp[bounds[1]-1,:] = a_tmp_max
+    print(f'lower bound = {low} (inclusive)\nupper bound = {upp} (exclusive)')
+    quick_imshow(a_tmp, title=title, aspect='auto')
+    del a_tmp
+    if not input_yesno('Accept these bounds (y/n)?', 'y'):
+        bounds = select_image_bounds(a, axis, low=low_save, upp=upp_save, num_min=num_min_save,
+            title=title)
+    return(bounds)
+
+def select_one_image_bound(a, axis, bound=None, bound_name=None, title='select array bounds',
+        default='y', raise_error=False):
+    """Interactively select a data boundary for a 2D numpy array.
+    """
+    a = np.asarray(a)
+    if a.ndim != 2:
+        illegal_value(a.ndim, 'array dimension', location='select_one_image_bound',
+                raise_error=raise_error)
+        return(None)
+    if axis < 0 or axis >= a.ndim:
+        illegal_value(axis, 'axis', location='select_one_image_bound', raise_error=raise_error)
+        return(None)
+    if bound_name is None:
+        bound_name = 'data bound'
+    if bound is None:
+        min_ = 0
+        max_ = a.shape[axis]
+        bound_max = a.shape[axis]-1
+        while True:
+            if axis:
+                quick_imshow(a[:,min_:max_], title=title, aspect='auto',
+                        extent=[min_,max_,a.shape[0],0])
+            else:
+                quick_imshow(a[min_:max_,:], title=title, aspect='auto',
+                        extent=[0,a.shape[1], max_,min_])
+            zoom_flag = input_yesno(f'Set {bound_name} (y) or zoom in (n)?', 'y')
+            if zoom_flag:
+                bound = input_int(f'    Set {bound_name}', ge=0, le=bound_max)
+                clear_imshow(title)
+                break
+            else:
+                min_ = input_int('    Set lower zoom index', ge=0, le=bound_max)
+                max_ = input_int('    Set upper zoom index', ge=min_+1, le=bound_max+1)
+
+    elif not is_int(bound, ge=0, le=a.shape[axis]-1):
+        illegal_value(bound, 'bound', location='select_one_image_bound', raise_error=raise_error)
+        return(None)
+    else:
+        print(f'Current {bound_name} = {bound}')
+    a_tmp = np.copy(a)
+    a_tmp_max = a.max()
+    if axis:
+        a_tmp[:,bound] = a_tmp_max
+    else:
+        a_tmp[bound,:] = a_tmp_max
+    quick_imshow(a_tmp, title=title, aspect='auto')
+    del a_tmp
+    if not input_yesno(f'Accept this {bound_name} (y/n)?', default):
+        bound = select_one_image_bound(a, axis, bound_name=bound_name, title=title)
+    clear_imshow(title)
+    return(bound)
+
+
+class Config:
+    """Base class for processing a config file or dictionary.
+    """
+    def __init__(self, config_file=None, config_dict=None):
+        self.config = {}
+        self.load_flag = False
+        self.suffix = None
+
+        # Load config file 
+        if config_file is not None and config_dict is not None:
+            logger.warning('Ignoring config_dict (both config_file and config_dict are specified)')
+        if config_file is not None:
+            self.load_file(config_file)
+        elif config_dict is not None:
+            self.load_dict(config_dict)
+
+    def load_file(self, config_file):
+        """Load a config file.
+        """
+        if self.load_flag:
+            logger.warning('Overwriting any previously loaded config file')
+        self.config = {}
+
+        # Ensure config file exists
+        if not os.path.isfile(config_file):
+            logger.error(f'Unable to load {config_file}')
+            return
+
+        # Load config file (for now for Galaxy, allow .dat extension)
+        self.suffix = os.path.splitext(config_file)[1]
+        if self.suffix in ('.yml', '.yaml', '.dat'):
+            with open(config_file, 'r') as f:
+                self.config = safe_load(f)
+        elif self.suffix == '.txt':
+            with open(config_file, 'r') as f:
+                lines = f.read().splitlines()
+            self.config = {item[0].strip():literal_eval(item[1].strip()) for item in
+                    [line.split('#')[0].split('=') for line in lines if '=' in line.split('#')[0]]}
+        else:
+            illegal_value(self.suffix, 'config file extension', 'Config.load_file')
+
+        # Make sure config file was correctly loaded
+        if isinstance(self.config, dict):
+            self.load_flag = True
+        else:
+            logger.error(f'Unable to load dictionary from config file: {config_file}')
+            self.config = {}
+
+    def load_dict(self, config_dict):
+        """Takes a dictionary and places it into self.config.
+        """
+        if self.load_flag:
+            logger.warning('Overwriting the previously loaded config file')
+
+        if isinstance(config_dict, dict):
+            self.config = config_dict
+            self.load_flag = True
+        else:
+            illegal_value(config_dict, 'dictionary config object', 'Config.load_dict')
+            self.config = {}
+
+    def save_file(self, config_file):
+        """Save the config file (as a yaml file only right now).
+        """
+        suffix = os.path.splitext(config_file)[1]
+        if suffix not in ('.yml', '.yaml'):
+            illegal_value(suffix, 'config file extension', 'Config.save_file')
+
+        # Check if config file exists
+        if os.path.isfile(config_file):
+            logger.info(f'Updating {config_file}')
+        else:
+            logger.info(f'Saving {config_file}')
+
+        # Save config file
+        with open(config_file, 'w') as f:
+            safe_dump(self.config, f)
+
+    def validate(self, pars_required, pars_missing=None):
+        """Returns False if any required keys are missing.
+        """
+        if not self.load_flag:
+            logger.error('Load a config file prior to calling Config.validate')
+            return(False)
+
+        def validate_nested_pars(config, par):
+            par_levels = par.split(':')
+            first_level_par = par_levels[0]
+            try:
+                first_level_par = int(first_level_par)
+            except:
+                pass
+            try:
+                next_level_config = config[first_level_par]
+                if len(par_levels) > 1:
+                    next_level_par = ':'.join(par_levels[1:])
+                    return(validate_nested_pars(next_level_config, next_level_par))
+                else:
+                    return(True)
+            except:
+                return(False)
+
+        missing = [p for p in pars_required if not validate_nested_pars(self.config, p)]
+        if pars_missing is not None:
+            # Rebinding the pars_missing argument never reached the caller; extend it instead
+            pars_missing.extend(missing)
+        if len(missing) > 0:
+            logger.error(f'Missing item(s) in configuration: {", ".join(missing)}')
+            return(False)
+        else:
+            return(True)
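+
+# Example usage of Config.validate (a sketch): nested keys are colon-separated, with
+# integer segments interpreted as list indices:
+#
+#     config = Config(config_dict={'detector': {'pixel_sizes': [0.1, 0.1]}})
+#     config.validate(['detector:pixel_sizes:0'])  # True
+#     config.validate(['detector:rows'])           # False, and logs the missing key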
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/tomo_macros.xml	Tue Mar 21 16:22:42 2023 +0000
@@ -0,0 +1,30 @@
+<macros>
+    <xml name="requirements">
+        <requirements>
+            <requirement type="package" version="1.0.3">lmfit</requirement>
+            <requirement type="package" version="3.5.2">matplotlib</requirement>
+            <requirement type="package" version="1.0.0">nexusformat</requirement>
+            <requirement type="package" version="1.12.2">tomopy</requirement>
+        </requirements>
+    </xml>
+    <xml name="citations">
+        <citations>
+            <citation type="bibtex">
+@misc{github_files,
+  author = {Verberg, Rolf},
+  year = {2022},
+  title = {Tomo Reconstruction},
+}</citation>
+        </citations>
+    </xml>
+    <!--
+    <xml name="common_inputs">
+        <param name="config" type='data' format='yaml' optional='false' label="Input config"/>
+        <param name="config" type='data' format='tomo.config.yaml' optional='true' label="Input config"/>
+    </xml>
+    -->
+    <xml name="common_outputs">
+        <data name="log" format="txt" label="Log"/>
+    </xml>
+</macros>
+ 
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/tomo_reconstruct.py	Tue Mar 21 16:22:42 2023 +0000
@@ -0,0 +1,113 @@
+#!/usr/bin/env python3
+
+import logging
+
+import argparse
+import io
+import pathlib
+import sys
+#import tracemalloc
+
+from workflow.run_tomo import Tomo
+
+#from memory_profiler import profile
+#@profile
+def __main__():
+    # Parse command line arguments
+    parser = argparse.ArgumentParser(
+            description='Perform a tomography reconstruction')
+    parser.add_argument('-i', '--input_file',
+            required=True,
+            type=pathlib.Path,
+            help='''Full or relative path to the input file (in Nexus format).''')
+    parser.add_argument('-c', '--center_file',
+            required=True,
+            type=pathlib.Path,
+            help='''Full or relative path to the center info file (in yaml format).''')
+    parser.add_argument('-o', '--output_file',
+            required=False,
+            type=pathlib.Path,
+            help='''Full or relative path to the output file (in Nexus format).''')
+    parser.add_argument('--galaxy_flag',
+            action='store_true',
+            help='''Use this flag to run the scripts as a galaxy tool.''')
+    parser.add_argument('-l', '--log',
+#            type=argparse.FileType('w'),
+            default=sys.stdout,
+            help='Logging stream or filename')
+    parser.add_argument('--log_level',
+            choices=logging._nameToLevel.keys(),
+            default='INFO',
+            help='''Specify a preferred logging level.''')
+    parser.add_argument('--x_bounds',
+            required=False,
+            nargs=2,
+            type=int,
+            help='''Boundaries of reconstructed images in x-direction.''')
+    parser.add_argument('--y_bounds',
+            required=False,
+            nargs=2,
+            type=int,
+            help='''Boundaries of reconstructed images in y-direction.''')
+    args = parser.parse_args()
+
+    # Set log configuration
+    # When logging to file, the stdout log level defaults to WARNING
+    logging_format = '%(asctime)s : %(levelname)s - %(module)s : %(funcName)s - %(message)s'
+    level = logging.getLevelName(args.log_level)
+    if args.log is sys.stdout:
+        logging.basicConfig(format=logging_format, level=level, force=True,
+                handlers=[logging.StreamHandler()])
+    else:
+        if isinstance(args.log, str):
+            logging.basicConfig(filename=f'{args.log}', filemode='w',
+                    format=logging_format, level=level, force=True)
+        elif isinstance(args.log, io.TextIOWrapper):
+            logging.basicConfig(filemode='w', format=logging_format, level=level,
+                    stream=args.log, force=True)
+        else:
+            raise ValueError(f'Invalid argument --log: {args.log}')
+        stream_handler = logging.StreamHandler()
+        logging.getLogger().addHandler(stream_handler)
+        stream_handler.setLevel(logging.WARNING)
+        stream_handler.setFormatter(logging.Formatter(logging_format))
+
+    # Starting memory monitoring
+#    tracemalloc.start()
+
+    # Log command line arguments
+    logging.info(f'input_file = {args.input_file}')
+    logging.info(f'center_file = {args.center_file}')
+    logging.info(f'output_file = {args.output_file}')
+    logging.info(f'galaxy_flag = {args.galaxy_flag}')
+    logging.debug(f'log = {args.log}')
+    logging.debug(f'is log stdout? {args.log is sys.stdout}')
+    logging.debug(f'log_level = {args.log_level}')
+    logging.info(f'x_bounds = {args.x_bounds}')
+    logging.info(f'y_bounds = {args.y_bounds}')
+
+    # Instantiate Tomo object
+    tomo = Tomo(galaxy_flag=args.galaxy_flag)
+
+    # Read input file
+    data = tomo.read(args.input_file)
+
+    # Read center data
+    center_data = tomo.read(args.center_file)
+
+    # Find the calibrated center axis info
+    data = tomo.reconstruct_data(data, center_data, x_bounds=args.x_bounds, y_bounds=args.y_bounds)
+
+    # Write output file
+    data = tomo.write(data, args.output_file)
+
+    # Displaying memory usage
+#    logging.info(f'Memory usage: {tracemalloc.get_traced_memory()}')
+ 
+    # stopping memory monitoring
+#    tracemalloc.stop()
+
+    logging.info('Completed tomography reconstruction')
+
+
+if __name__ == "__main__":
+    __main__()
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/tomo_reconstruct.xml	Tue Mar 21 16:22:42 2023 +0000
@@ -0,0 +1,34 @@
+<tool id="tomo_reconstruct" name="Tomo Reconstruction" version="0.3.0" python_template_version="3.9">
+    <description>Perform a tomography reconstruction</description>
+    <macros>
+        <import>tomo_macros.xml</import>
+    </macros>
+    <expand macro="requirements"/>
+    <command detect_errors="exit_code">
+        <![CDATA[
+            mkdir tomo_reconstruct_plots;
+            $__tool_directory__/tomo_reconstruct.py
+            --input_file "$input_file"
+            --center_file "$center_file"
+            --output_file "output.nex"
+            --galaxy_flag
+            -l "$log"
+        ]]>
+    </command>
+    <inputs>
+        <!-- <expand macro="common_inputs"/> (common_inputs is commented out in tomo_macros.xml) -->
+        <param name="input_file" type="data" format="nex" optional="false" label="Reduced tomography data"/>
+        <param name="center_file" type="data" format="yaml" optional="false" label="Center axis input file"/>
+    </inputs>
+    <outputs>
+        <expand macro="common_outputs"/>
+        <collection name="tomo_reconstruct_plots" type="list" label="Reconstructed data images">
+            <discover_datasets pattern="__name_and_ext__" directory="tomo_reconstruct_plots"/>
+        </collection>
+        <data name="output_file" format="nex" label="Reconstructed tomography data" from_work_dir="output.nex"/>
+    </outputs>
+    <help><![CDATA[
+        Reconstruct tomography images.
+    ]]></help>
+    <expand macro="citations"/>
+</tool>
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/workflow/__main__.py	Tue Mar 21 16:22:42 2023 +0000
@@ -0,0 +1,236 @@
+#!/usr/bin/env python3
+
+import logging
+logger = logging.getLogger(__name__)
+
+import argparse
+import io
+import pathlib
+import sys
+
+from .models import TomoWorkflow as Workflow
+try:
+    from deepdiff import DeepDiff
+except:
+    pass
+
+parser = argparse.ArgumentParser(description='''Operate on representations of
+        Tomo data workflows saved to files.''')
+parser.add_argument('-l', '--log',
+#        type=argparse.FileType('w'),
+        default=sys.stdout,
+        help='Logging stream or filename')
+parser.add_argument('--log_level',
+        choices=logging._nameToLevel.keys(),
+        default='INFO',
+        help='''Specify a preferred logging level.''')
+subparsers = parser.add_subparsers(title='subcommands', required=True)#, dest='command')
+
+
+# CONSTRUCT
+def construct(args:list) -> None:
+    if args.template_file is not None:
+        wf = Workflow.construct_from_file(args.template_file)
+        wf.cli()
+    else:
+        wf = Workflow.construct_from_cli()
+    wf.write_to_file(args.output_file, force_overwrite=args.force_overwrite)
+
+construct_parser = subparsers.add_parser('construct', help='''Construct a valid Tomo
+        workflow representation on the command line and save it to a file. Optionally use
+        an existing file as a template and/or perform the reconstruction or transfer to Galaxy.''')
+construct_parser.set_defaults(func=construct)
+construct_parser.add_argument('-t', '--template_file',
+        type=pathlib.Path,
+        required=False,
+        help='''Full or relative template file path for the constructed workflow.''')
+construct_parser.add_argument('-f', '--force_overwrite',
+        action='store_true',
+        help='''Use this flag to overwrite the output file if it already exists.''')
+construct_parser.add_argument('-o', '--output_file',
+        type=pathlib.Path,
+        help='''Full or relative file path to which the constructed workflow will be written.''')
+
+
+# VALIDATE
+def validate(args:list) -> bool:
+    try:
+        wf = Workflow.construct_from_file(args.input_file)
+        logger.info(f'Success: {args.input_file} represents a valid Tomo workflow configuration.')
+        return(True)
+    except BaseException as e:
+        logger.error(f'{e.__class__.__name__}: {str(e)}')
+        logger.info(f'''Failure: {args.input_file} does not represent a valid Tomo workflow
+                configuration.''')
+        return(False)
+
+validate_parser = subparsers.add_parser('validate',
+        help='''Validate a file as a representation of a Tomo workflow (this is most useful
+                after a .yaml file has been manually edited).''')
+validate_parser.set_defaults(func=validate)
+validate_parser.add_argument('input_file',
+        type=pathlib.Path,
+        help='''Full or relative file path to validate as a Tomo workflow.''')
+
+
+# CONVERT
+def convert(args:list) -> None:
+    wf = Workflow.construct_from_file(args.input_file)
+    wf.write_to_file(args.output_file, force_overwrite=args.force_overwrite)
+
+convert_parser = subparsers.add_parser('convert', help='''Convert one Tomo workflow
+        representation to another. File format of both input and output files will be
+        automatically determined from the files' extensions.''')
+convert_parser.set_defaults(func=convert)
+convert_parser.add_argument('-f', '--force_overwrite',
+        action='store_true',
+        help='''Use this flag to overwrite the output file if it already exists.''')
+convert_parser.add_argument('-i', '--input_file',
+        type=pathlib.Path,
+        required=True,
+        help='''Full or relative input file path to be converted.''')
+convert_parser.add_argument('-o', '--output_file',
+        type=pathlib.Path,
+        required=True,
+        help='''Full or relative file path to which the converted input will be written.''')
+
+
+# DIFF / COMPARE
+def diff(args:list) -> bool:
+    raise ValueError('diff not tested')
+#    wf1 = Workflow.construct_from_file(args.file1).dict_for_yaml()
+#    wf2 = Workflow.construct_from_file(args.file2).dict_for_yaml()
+#    diff = DeepDiff(wf1,wf2,
+#                    ignore_order_func=lambda level:'independent_dimensions' not in level.path(),
+#                    report_repetition=True,
+#                    ignore_string_type_changes=True,
+#                    ignore_numeric_type_changes=True)
+    diff_report = diff.pretty()
+    if len(diff_report) > 0:
+        logger.info(f'The configurations in {args.file1} and {args.file2} are not identical.')
+        print(diff_report)
+        return(True)
+    else:
+        logger.info(f'The configurations in {args.file1} and {args.file2} are identical.')
+        return(False)
+
+diff_parser = subparsers.add_parser('diff', aliases=['compare'], help='''Print a comparison of 
+        two Tomo workflow representations stored in files. The files may have different formats.''')
+diff_parser.set_defaults(func=diff)
+diff_parser.add_argument('file1',
+        type=pathlib.Path,
+        help='''Full or relative path to the first file for comparison.''')
+diff_parser.add_argument('file2',
+        type=pathlib.Path,
+        help='''Full or relative path to the second file for comparison.''')
+
+
+# LINK TO GALAXY
+def link_to_galaxy(args:list) -> None:
+    from .link_to_galaxy import link_to_galaxy
+    link_to_galaxy(args.input_file, galaxy=args.galaxy, user=args.user,
+            password=args.password, api_key=args.api_key)
+
+link_parser = subparsers.add_parser('link_to_galaxy', help='''Construct a Galaxy history and link
+        to an existing Tomo workflow representation in a NeXus file.''')
+link_parser.set_defaults(func=link_to_galaxy)
+link_parser.add_argument('-i', '--input_file',
+        type=pathlib.Path,
+        required=True,
+        help='''Full or relative input file path to the existing Tomo workflow representation as
+                a NeXus file.''')
+link_parser.add_argument('-g', '--galaxy',
+        required=True,
+        help='Target Galaxy instance URL/IP address')
+link_parser.add_argument('-u', '--user',
+        default=None,
+        help='Galaxy user email address')
+link_parser.add_argument('-p', '--password',
+        default=None,
+        help='Password for the Galaxy user')
+link_parser.add_argument('-a', '--api_key',
+        default=None,
+        help='Galaxy admin user API key (required if not defined in the tools list file)')
+
+
+# RUN THE RECONSTRUCTION
+def run_tomo(args:list) -> None:
+    from .run_tomo import run_tomo
+    run_tomo(args.input_file, args.output_file, args.modes, center_file=args.center_file,
+            num_core=args.num_core, output_folder=args.output_folder, save_figs=args.save_figs)
+
+tomo_parser = subparsers.add_parser('run_tomo', help='''Construct and add reconstructed tomography
+        data to an existing Tomo workflow representation in a NeXus file.''')
+tomo_parser.set_defaults(func=run_tomo)
+tomo_parser.add_argument('-i', '--input_file',
+        required=True,
+        type=pathlib.Path,
+        help='''Full or relative input file path containing raw and/or reduced data.''')
+tomo_parser.add_argument('-o', '--output_file',
+        required=True,
+        type=pathlib.Path,
+        help='''Full or relative output file path for the reconstructed data.''')
+tomo_parser.add_argument('-c', '--center_file',
+        type=pathlib.Path,
+        help='''Full or relative input file path containing the rotation axis centers info.''')
+#tomo_parser.add_argument('-f', '--force_overwrite',
+#        action='store_true',
+#        help='''Use this flag to overwrite any existing reduced data.''')
+tomo_parser.add_argument('-n', '--num_core',
+        type=int,
+        default=-1,
+        help='''Specify the number of processors to use.''')
+tomo_parser.add_argument('--output_folder',
+        type=pathlib.Path,
+        default='.',
+        help='Full or relative path to an output folder')
+tomo_parser.add_argument('-s', '--save_figs',
+        choices=['yes', 'no', 'only'],
+        default='no',
+        help='''Specify whether to display figures without saving ('no'), display and
+                save them ('yes'), or only save them ('only').''')
+tomo_parser.add_argument('--reduce_data',
+        dest='modes',
+        const='reduce_data',
+        action='append_const',
+        help='''Use this flag to create and add reduced data to the input file.''')
+tomo_parser.add_argument('--find_center',
+        dest='modes',
+        const='find_center',
+        action='append_const',
+        help='''Use this flag to find and add the calibrated center axis info to the input file.''')
+tomo_parser.add_argument('--reconstruct_data',
+        dest='modes',
+        const='reconstruct_data',
+        action='append_const',
+        help='''Use this flag to create and add reconstructed data to the input file.''')
+tomo_parser.add_argument('--combine_data',
+        dest='modes',
+        const='combine_data',
+        action='append_const',
+        help='''Use this flag to combine the reconstructed data and add it to the input file.''')
+
+
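+# Example invocation (a sketch; file names are hypothetical):
+#
+#     python -m workflow run_tomo -i map.nxs -o reconstructed.nxs \
+#             --reduce_data --find_center --reconstruct_data
+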
+if __name__ == '__main__':
+    args = parser.parse_args(sys.argv[1:])
+
+    # Set log configuration
+    # When logging to file, the stdout log level defaults to WARNING
+    logging_format = '%(asctime)s : %(levelname)s - %(module)s : %(funcName)s - %(message)s'
+    level = logging.getLevelName(args.log_level)
+    if args.log is sys.stdout:
+        logging.basicConfig(format=logging_format, level=level, force=True,
+                handlers=[logging.StreamHandler()])
+    else:
+        if isinstance(args.log, str):
+            logging.basicConfig(filename=f'{args.log}', filemode='w',
+                    format=logging_format, level=level, force=True)
+        elif isinstance(args.log, io.TextIOWrapper):
+            logging.basicConfig(filemode='w', format=logging_format, level=level,
+                    stream=args.log, force=True)
+        else:
+            raise ValueError(f'Invalid argument --log: {args.log}')
+        stream_handler = logging.StreamHandler()
+        logging.getLogger().addHandler(stream_handler)
+        stream_handler.setLevel(logging.WARNING)
+        stream_handler.setFormatter(logging.Formatter(logging_format))
+
+    args.func(args)
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/workflow/__version__.py	Tue Mar 21 16:22:42 2023 +0000
@@ -0,0 +1,1 @@
+__version__ = '2022.3.0'
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/workflow/link_to_galaxy.py	Tue Mar 21 16:22:42 2023 +0000
@@ -0,0 +1,120 @@
+#!/usr/bin/env python3
+
+import logging
+logger = logging.getLogger(__name__)
+
+from bioblend.galaxy import GalaxyInstance
+from nexusformat.nexus import *
+from os import path
+from yaml import safe_load
+
+from .models import import_scanparser, TomoWorkflow
+
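+# get_folder_id walks a Galaxy data library: given a path split into
+# [library_name, folder, subfolder, ...], it returns the matching library id,
+# the id of the deepest already-existing folder on that path, and the list of
+# folder names that still have to be created below it.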
+def get_folder_id(gi, lib_path):
+    library_id = None
+    folder_id = None
+    folder_names = lib_path[1:] if len(lib_path) > 1 else []
+    new_folders = folder_names
+    libs = gi.libraries.get_libraries(name=lib_path[0])
+    if libs:
+        for lib in libs:
+            library_id = lib['id']
+            folders = gi.libraries.get_folders(library_id, folder_id=None, name=None)
+            for i, folder in enumerate(folders):
+                fid = folder['id']
+                details = gi.libraries.show_folder(library_id, fid)
+                library_path = details['library_path']
+                if library_path == folder_names:
+                    return (library_id, fid, [])
+                elif len(library_path) < len(folder_names):
+                    if library_path == folder_names[:len(library_path)]:
+                        nf = folder_names[len(library_path):]
+                        if len(nf) < len(new_folders):
+                            folder_id = fid
+                            new_folders = nf
+    return (library_id, folder_id, new_folders)
+
+def link_to_galaxy(filename:str, galaxy=None, user=None, password=None, api_key=None) -> None:
+    # Read input file
+    extension = path.splitext(filename)[1]
+# RV yaml input not incorporated yet, since Galaxy can't use pyspec right now
+#    if extension == '.yml' or extension == '.yaml':
+#        with open(filename, 'r') as f:
+#            data = safe_load(f)
+#    elif extension == '.nxs':
+    if extension == '.nxs':
+        with NXFile(filename, mode='r') as nxfile:
+            data = nxfile.readfile()
+    else:
+        raise ValueError(f'Invalid filename extension ({extension})')
+    if isinstance(data, dict):
+        # Create Nexus format object from input dictionary
+        wf = TomoWorkflow(**data)
+        if len(wf.sample_maps) > 1:
+            raise ValueError('Multiple sample maps are not yet implemented')
+        nxroot = NXroot()
+        for sample_map in wf.sample_maps:
+            import_scanparser(sample_map.station)
+# RV raw data must be included, since Galaxy can't use pyspec right now
+#            sample_map.construct_nxentry(nxroot, include_raw_data=False)
+            sample_map.construct_nxentry(nxroot, include_raw_data=True)
+        nxentry = nxroot[nxroot.attrs['default']]
+    elif isinstance(data, NXroot):
+        nxentry = data[data.attrs['default']]
+    else:
+        raise ValueError(f'Invalid input file data ({data})')
+
+    # Get a Galaxy instance
+    if user is not None and password is not None:
+        gi = GalaxyInstance(url=galaxy, email=user, password=password)
+    elif api_key is not None:
+        gi = GalaxyInstance(url=galaxy, key=api_key)
+    else:
+        raise ValueError('Please specify either a valid Galaxy username/password or an API key.')
+
+    cycle = nxentry.instrument.source.attrs['cycle']
+    btr = nxentry.instrument.source.attrs['btr']
+    sample = nxentry.sample.name
+
+    # Create a Galaxy work library/folder
+    # Combine the cycle, BTR and sample name as the base library name
+    lib_path = [p.strip() for p in f'{cycle}/{btr}/{sample}'.split('/')]
+    (library_id, folder_id, folder_names) = get_folder_id(gi, lib_path)
+    if not library_id:
+        library = gi.libraries.create_library(lib_path[0], description=None, synopsis=None)
+        library_id = library['id']
+#        if user:
+#            gi.libraries.set_library_permissions(library_id, access_ids=user,
+#                    manage_ids=user, modify_ids=user)
+        logger.info(f'Created Library:\n{library}')
+    if len(folder_names):
+        folder = gi.libraries.create_folder(library_id, folder_names[0], description=None,
+                base_folder_id=folder_id)[0]
+        folder_id = folder['id']
+        logger.info(f'Created Folder:\n{folder}')
+        folder_names.pop(0)
+        while len(folder_names):
+            folder = gi.folders.create_folder(folder['id'], folder_names[0],
+                    description=None)
+            folder_id = folder['id']
+            logger.info(f'Created Folder:\n{folder}')
+            folder_names.pop(0)
+
+    # Create a sym link for the Nexus file
+    dataset = gi.libraries.upload_from_galaxy_filesystem(library_id, path.abspath(filename),
+            folder_id=folder_id, file_type='auto', dbkey='?', link_data_only='link_to_files',
+            roles='', preserve_dirs=False, tag_using_filenames=False, tags=None)[0]
+
+    # Make a history for the data
+    history_name = f'tomo {btr} {sample}'
+    history = gi.histories.create_history(name=history_name)
+    logger.info(f'Created history:\n{history}')
+    history_id = history['id']
+    gi.histories.copy_dataset(history_id, dataset['id'], source='library')
+
+# TODO add option to either 
+#     get a URL to share the history
+#     or to share with specific users
+# This might require using: 
+#     https://bioblend.readthedocs.io/en/latest/api_docs/galaxy/docs.html#using-bioblend-for-raw-api-calls
+
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/workflow/models.py	Tue Mar 21 16:22:42 2023 +0000
@@ -0,0 +1,1096 @@
+#!/usr/bin/env python3
+
+import logging
+logger = logging.getLogger(__name__)
+
+import numpy as np
+import os
+import yaml
+
+from functools import cache
+from pathlib import PosixPath
+from pydantic import BaseModel as PydanticBaseModel
+from pydantic import validator, ValidationError, conint, confloat, constr, conlist, FilePath, \
+        PrivateAttr
+from nexusformat.nexus import *
+from time import time
+from typing import Optional, Literal
+from typing_extensions import TypedDict
+try:
+    from pyspec.file.spec import FileSpec
+except:
+    pass
+
+try:
+    from msnctools.general import is_int, is_num, input_int, input_int_list, input_num, \
+            input_yesno, input_menu, index_nearest, string_to_list, file_exists_and_readable
+except:
+    from general import is_int, is_num, input_int, input_int_list, input_num, \
+            input_yesno, input_menu, index_nearest, string_to_list, file_exists_and_readable
+
+
+def import_scanparser(station):
+    if station in ('id1a3', 'id3a'):
+        try:
+            from msnctools.scanparsers import SMBRotationScanParser
+            globals()['ScanParser'] = SMBRotationScanParser
+        except:
+            try:
+                from scanparsers import SMBRotationScanParser
+                globals()['ScanParser'] = SMBRotationScanParser
+            except:
+                pass
+    elif station in ('id3b',):
+        try:
+            from msnctools.scanparsers import FMBRotationScanParser
+            globals()['ScanParser'] = FMBRotationScanParser
+        except:
+            try:
+                from scanparsers import FMBRotationScanParser
+                globals()['ScanParser'] = FMBRotationScanParser
+            except:
+                pass
+    else:
+        raise RuntimeError(f'Invalid station: {station}')
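+
+# For example, import_scanparser('id3a') binds SMBRotationScanParser to the
+# module-global name ScanParser, which the cached helpers below rely on.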
+
+@cache
+def get_available_scan_numbers(spec_file:str):
+    scans = FileSpec(spec_file).scans
+    scan_numbers = list(scans.keys())
+    for scan_number in scan_numbers.copy():
+        # Drop any scan that the station's ScanParser cannot handle
+        try:
+            ScanParser(spec_file, scan_number)
+        except:
+            scan_numbers.remove(scan_number)
+    return(scan_numbers)
+
+@cache
+def get_scanparser(spec_file:str, scan_number:int):
+    if scan_number not in get_available_scan_numbers(spec_file):
+        return(None)
+    else:
+        return(ScanParser(spec_file, scan_number))
+
+
+class BaseModel(PydanticBaseModel):
+    class Config:
+        validate_assignment = True
+        arbitrary_types_allowed = True
+
+    @classmethod
+    def construct_from_cli(cls):
+        obj = cls.construct()
+        obj.cli()
+        return(obj)
+
+    @classmethod
+    def construct_from_yaml(cls, filename):
+        try:
+            with open(filename, 'r') as infile:
+                indict = yaml.load(infile, Loader=yaml.CLoader)
+        except:
+            raise ValueError(f'Could not load a dictionary from {filename}')
+        else:
+            obj = cls(**indict)
+            return(obj)
+
+    @classmethod
+    def construct_from_file(cls, filename):
+        file_exists_and_readable(filename)
+        filename = os.path.abspath(filename)
+        fileformat = os.path.splitext(filename)[1]
+        yaml_extensions = ('.yaml','.yml')
+        nexus_extensions = ('.nxs','.nx5','.h5','.hdf5')
+        t0 = time()
+        if fileformat.lower() in yaml_extensions:
+            obj = cls.construct_from_yaml(filename)
+            logger.info(f'Constructed a model from {filename} in {time()-t0:.2f} seconds.')
+            return(obj)
+        elif fileformat.lower() in nexus_extensions:
+            obj = cls.construct_from_nexus(filename)
+            logger.info(f'Constructed a model from {filename} in {time()-t0:.2f} seconds.')
+            return(obj)
+        else:
+            logger.error(f'Unsupported file extension for constructing a model: {fileformat}')
+            raise TypeError(f'Unrecognized file extension: {fileformat}')
+
+    def dict_for_yaml(self, exclude_fields=[]):
+        yaml_dict = {}
+        for field_name in self.__fields__:
+            if field_name in exclude_fields:
+                continue
+            else:
+                field_value = getattr(self, field_name, None)
+                if field_value is not None:
+                    if isinstance(field_value, BaseModel):
+                        yaml_dict[field_name] = field_value.dict_for_yaml()
+                    elif isinstance(field_value,list) and all(isinstance(item,BaseModel)
+                            for item in field_value):
+                        yaml_dict[field_name] = [item.dict_for_yaml() for item in field_value]
+                    elif isinstance(field_value, PosixPath):
+                        yaml_dict[field_name] = str(field_value)
+                    else:
+                        yaml_dict[field_name] = field_value
+                else:
+                    continue
+        return(yaml_dict)
+
+    def write_to_yaml(self, filename=None):
+        yaml_dict = self.dict_for_yaml()
+        if filename is None:
+            logger.info('Printing yaml representation here:\n'+
+                    f'{yaml.dump(yaml_dict, sort_keys=False)}')
+        else:
+            try:
+                with open(filename, 'w') as outfile:
+                    yaml.dump(yaml_dict, outfile, sort_keys=False)
+                logger.info(f'Successfully wrote this model to {filename}')
+            except:
+                logger.error(f'Unknown error -- could not write to {filename} in yaml format.')
+                logger.info('Printing yaml representation here:\n'+
+                        f'{yaml.dump(yaml_dict, sort_keys=False)}')
+
+    def write_to_file(self, filename, force_overwrite=False):
+        file_writeable, fileformat = self.output_file_valid(filename,
+                force_overwrite=force_overwrite)
+        if fileformat == 'yaml':
+            if file_writeable:
+                self.write_to_yaml(filename=filename)
+            else:
+                self.write_to_yaml()
+        elif fileformat == 'nexus':
+            if file_writeable:
+                self.write_to_nexus(filename=filename)
+
+    def output_file_valid(self, filename, force_overwrite=False):
+        filename = os.path.abspath(filename)
+        fileformat = os.path.splitext(filename)[1]
+        yaml_extensions = ('.yaml','.yml')
+        nexus_extensions = ('.nxs','.nx5','.h5','.hdf5')
+        if fileformat.lower() not in (*yaml_extensions, *nexus_extensions):
+            return(False, None) # Only yaml and NeXus files allowed for output now.
+        elif fileformat.lower() in yaml_extensions:
+            fileformat = 'yaml'
+        elif fileformat.lower() in nexus_extensions:
+            fileformat = 'nexus'
+        if os.path.isfile(filename):
+            if os.access(filename, os.W_OK):
+                if not force_overwrite:
+                    logger.error(f'{filename} will not be overwritten.')
+                    return(False, fileformat)
+            else:
+                logger.error(f'Cannot access {filename} for writing.')
+                return(False, fileformat)
+        if os.path.isdir(os.path.dirname(filename)):
+            if os.access(os.path.dirname(filename), os.W_OK):
+                return(True, fileformat)
+            else:
+                logger.error(f'Cannot access {os.path.dirname(filename)} for writing.')
+                return(False, fileformat)
+        else:
+            try:
+                os.makedirs(os.path.dirname(filename))
+                return(True, fileformat)
+            except:
+                logger.error(f'Cannot create {os.path.dirname(filename)} for output.')
+                return(False, fileformat)
+
+    def set_single_attr_cli(self, attr_name, attr_desc='unknown attribute', list_flag=False,
+            **cli_kwargs):
+        if cli_kwargs.get('chain_attr_desc', False):
+            cli_kwargs['attr_desc'] = attr_desc
+        try:
+            attr = getattr(self, attr_name, None)
+            if attr is None:
+                attr = self.__fields__[attr_name].type_.construct()
+            input_accepted = False
+            while not input_accepted:
+                try:
+                    attr.cli(**cli_kwargs)
+                except ValidationError as e:
+                    print(e)
+                    print(f'Removing {attr_desc} configuration')
+                    attr = self.__fields__[attr_name].type_.construct()
+                    continue
+                except KeyboardInterrupt as e:
+                    raise e
+                except BaseException as e:
+                    print(f'{type(e).__name__}: {e}')
+                    print(f'Removing {attr_desc} configuration')
+                    attr = self.__fields__[attr_name].type_.construct()
+                    continue
+                try:
+                    setattr(self, attr_name, attr)
+                except ValidationError as e:
+                    print(e)
+                except KeyboardInterrupt as e:
+                    raise e
+                except BaseException as e:
+                    print(f'{type(e).__name__}: {e}')
+                else:
+                    input_accepted = True
+        except:
+            input_accepted = False
+            while not input_accepted:
+                attr = getattr(self, attr_name, None)
+                if attr is None:
+                    input_value = input(f'Type and enter a value for {attr_desc}: ')
+                else:
+                    input_value = input(f'Type and enter a new value for {attr_desc} or press '+
+                            f'enter to keep the current one ({attr}): ')
+                if list_flag:
+                    input_value = string_to_list(input_value, remove_duplicates=False, sort=False)
+                if len(input_value) == 0:
+                    input_value = getattr(self, attr_name, None)
+                try:
+                    setattr(self, attr_name, input_value)
+                except ValidationError as e:
+                    print(e)
+                except KeyboardInterrupt as e:
+                    raise e
+                except BaseException as e:
+                    print(f'Unexpected {type(e).__name__}: {e}')
+                else:
+                    input_accepted = True
+
+    def set_list_attr_cli(self, attr_name, attr_desc='unknown attribute', **cli_kwargs):
+        if cli_kwargs.get('chain_attr_desc', False):
+            cli_kwargs['attr_desc'] = attr_desc
+        attr = getattr(self, attr_name, None)
+        if attr is not None:
+            # Check existing items
+            for item in attr:
+                item_accepted = False
+                while not item_accepted:
+                    item.cli(**cli_kwargs)
+                    try:
+                        setattr(self, attr_name, attr)
+                    except ValidationError as e:
+                        print(e)
+                    except KeyboardInterrupt as e:
+                        raise e
+                    except BaseException as e:
+                        print(f'{type(e).__name__}: {e}')
+                    else:
+                        item_accepted = True
+        else:
+            # Initialize list for new attribute & starting item
+            attr = []
+            item = self.__fields__[attr_name].type_.construct()
+        # Append (optional) additional items
+        append = input_yesno(f'Add a {attr_desc} configuration? (y/n)', 'n')
+        while append:
+            attr.append(item.__class__.construct_from_cli())
+            try:
+                setattr(self, attr_name, attr)
+            except ValidationError as e:
+                print(e)
+                print(f'Removing last {attr_desc} configuration from the list')
+                attr.pop()
+            except KeyboardInterrupt as e:
+                raise e
+            except BaseException as e:
+                print(f'{type(e).__name__}: {e}')
+                print(f'Removing last {attr_desc} configuration from the list')
+                attr.pop()
+            else:
+                append = input_yesno(f'Add another {attr_desc} configuration? (y/n)', 'n')
+
+
+class Detector(BaseModel):
+    prefix: constr(strip_whitespace=True, min_length=1)
+    rows: conint(gt=0)
+    columns: conint(gt=0)
+    pixel_size: conlist(item_type=confloat(gt=0), min_items=1, max_items=2)
+    lens_magnification: confloat(gt=0) = 1.0
+
+    @property
+    def get_pixel_size(self):
+        return(list(np.asarray(self.pixel_size)/self.lens_magnification))
+
+    def construct_from_yaml(self, filename):
+        try:
+            with open(filename, 'r') as infile:
+                indict = yaml.load(infile, Loader=yaml.CLoader)
+            detector = indict['detector']
+            self.prefix = detector['id']
+            pixels = detector['pixels']
+            self.rows = pixels['rows']
+            self.columns = pixels['columns']
+            self.pixel_size = pixels['size']
+            self.lens_magnification = indict['lens_magnification']
+        except:
+            logger.warning(f'Could not load a dictionary from {filename}')
+            return(False)
+        else:
+            return(True)
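+
+    # A minimal sketch of the YAML layout that construct_from_yaml expects;
+    # the detector name and pixel values below are illustrative only:
+    #
+    #   detector:
+    #     id: andor2
+    #     pixels:
+    #       rows: 2048
+    #       columns: 2048
+    #       size: [0.0065]
+    #   lens_magnification: 5.0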
+
+    def cli(self):
+        print('\n -- Configure the detector -- ')
+        self.set_single_attr_cli('prefix', 'detector ID')
+        self.set_single_attr_cli('rows', 'number of pixel rows')
+        self.set_single_attr_cli('columns', 'number of pixel columns')
+        self.set_single_attr_cli('pixel_size', 'pixel size in mm (enter either a single value for '+
+                'square pixels or a pair of values for the size in the respective row and column '+
+                'directions)', list_flag=True)
+        self.set_single_attr_cli('lens_magnification', 'lens magnification')
+
+    def construct_nxdetector(self):
+        nxdetector = NXdetector()
+        nxdetector.local_name = self.prefix
+        pixel_size = self.get_pixel_size
+        if len(pixel_size) == 1:
+            nxdetector.x_pixel_size = pixel_size[0]
+            nxdetector.y_pixel_size = pixel_size[0]
+        else:
+            nxdetector.x_pixel_size = pixel_size[0]
+            nxdetector.y_pixel_size = pixel_size[1]
+        nxdetector.x_pixel_size.attrs['units'] = 'mm'
+        nxdetector.y_pixel_size.attrs['units'] = 'mm'
+        return(nxdetector)
+
+
+class ScanInfo(TypedDict):
+    scan_number: int
+    starting_image_offset: conint(ge=0)
+    num_image: conint(gt=0)
+    ref_x: float
+    ref_z: float
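+
+# An example ScanInfo entry (values are illustrative):
+#   {'scan_number': 1, 'starting_image_offset': 5, 'num_image': 180,
+#    'ref_x': 0.0, 'ref_z': 0.0}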
+
+class SpecScans(BaseModel):
+    spec_file: FilePath
+    scan_numbers: conlist(item_type=conint(gt=0), min_items=1)
+    stack_info: conlist(item_type=ScanInfo, min_items=1) = []
+
+    @validator('spec_file')
+    def validate_spec_file(cls, spec_file):
+        try:
+            spec_file = os.path.abspath(spec_file)
+            FileSpec(spec_file)  # parse the file to validate it as a SPEC file
+        except:
+            raise ValueError(f'Invalid SPEC file {spec_file}')
+        else:
+            return(spec_file)
+
+    @validator('scan_numbers')
+    def validate_scan_numbers(cls, scan_numbers, values):
+        spec_file = values.get('spec_file')
+        if spec_file is not None:
+            spec_scans = FileSpec(spec_file)
+            for scan_number in scan_numbers:
+                scan = spec_scans.get_scan_by_number(scan_number)
+                if scan is None:
+                    raise ValueError(f'There is no scan number {scan_number} in {spec_file}')
+        return(scan_numbers)
+
+    @validator('stack_info')
+    def validate_stack_info(cls, stack_info, values):
+        scan_numbers = values.get('scan_numbers')
+        assert(len(scan_numbers) == len(stack_info))
+        for scan_info in stack_info:
+            assert(scan_info['scan_number'] in scan_numbers)
+            is_int(scan_info['starting_image_offset'], ge=0, lt=scan_info['num_image'],
+                    raise_error=True)
+        return(stack_info)
+
+    @classmethod
+    def construct_from_nxcollection(cls, nxcollection:NXcollection):
+        config = {}
+        config['spec_file'] = nxcollection.attrs['spec_file']
+        scan_numbers = []
+        stack_info = []
+        for nxsubentry_name, nxsubentry in nxcollection.items():
+            scan_number = int(nxsubentry_name.split('_')[-1])
+            scan_numbers.append(scan_number)
+            stack_info.append({'scan_number': scan_number,
+                    'starting_image_offset': int(nxsubentry.instrument.detector.frame_start_number),
+                    'num_image': len(nxsubentry.sample.rotation_angle),
+                    'ref_x': float(nxsubentry.sample.x_translation),
+                    'ref_z': float(nxsubentry.sample.z_translation)})
+        config['scan_numbers'] = sorted(scan_numbers)
+        config['stack_info'] = stack_info
+        return(cls(**config))
+
+    @property
+    def available_scan_numbers(self):
+        return(get_available_scan_numbers(self.spec_file))
+
+    def set_from_nxcollection(self, nxcollection:NXcollection):
+        self.spec_file = nxcollection.attrs['spec_file']
+        scan_numbers = []
+        stack_info = []
+        for nxsubentry_name, nxsubentry in nxcollection.items():
+            scan_number = int(nxsubentry_name.split('_')[-1])
+            scan_numbers.append(scan_number)
+            stack_info.append({'scan_number': scan_number,
+                    'starting_image_offset': int(nxsubentry.instrument.detector.frame_start_number),
+                    'num_image': len(nxsubentry.sample.rotation_angle),
+                    'ref_x': float(nxsubentry.sample.x_translation),
+                    'ref_z': float(nxsubentry.sample.z_translation)})
+        self.scan_numbers = sorted(scan_numbers)
+        self.stack_info = stack_info
+
+    def get_scan_index(self, scan_number):
+        scan_index = [scan_index for scan_index, scan_info in enumerate(self.stack_info)
+                if scan_info['scan_number'] == scan_number]
+        if len(scan_index) > 1:
+            raise ValueError('Duplicate scan_numbers in image stack')
+        elif len(scan_index) == 1:
+            return(scan_index[0])
+        else:
+            return(None)
+
+    def get_scanparser(self, scan_number):
+        return(get_scanparser(self.spec_file, scan_number))
+
+    def get_detector_data(self, detector_prefix, scan_number=None, scan_step_index=None):
+        image_stacks = []
+        # Remember whether a specific scan was requested before the loop below
+        # reuses the scan_number name as its loop variable
+        single_scan = scan_number is not None
+        if scan_number is None:
+            scan_numbers = self.scan_numbers
+        else:
+            scan_numbers = [scan_number]
+        for scan_number in scan_numbers:
+            parser = self.get_scanparser(scan_number)
+            scan_info = self.stack_info[self.get_scan_index(scan_number)]
+            image_offset = scan_info['starting_image_offset']
+            if scan_step_index is None:
+                num_image = scan_info['num_image']
+                image_stacks.append(parser.get_detector_data(detector_prefix,
+                        (image_offset, image_offset+num_image)))
+            else:
+                image_stacks.append(parser.get_detector_data(detector_prefix,
+                        image_offset+scan_step_index))
+        if single_scan and scan_step_index is not None:
+            # Return a single image for a specific scan_number and scan_step_index request
+            return(image_stacks[0])
+        else:
+            # Return a list otherwise
+            return(image_stacks)
+
+    def scan_numbers_cli(self, attr_desc, **kwargs):
+        available_scan_numbers = self.available_scan_numbers
+        station = kwargs.get('station')
+        if (station is not None and station in ('id1a3', 'id3a') and
+                'scan_type' in kwargs):
+            scan_type = kwargs['scan_type']
+            if scan_type == 'ts1':
+                available_scan_numbers = []
+                for scan_number in self.available_scan_numbers:
+                    parser = self.get_scanparser(scan_number)
+                    try:
+                        if parser.scan_type == scan_type:
+                            available_scan_numbers.append(scan_number)
+                    except:
+                        pass
+            elif scan_type == 'df1':
+                tomo_scan_numbers = kwargs['tomo_scan_numbers']
+                available_scan_numbers = []
+                for scan_number in tomo_scan_numbers:
+                    parser = self.get_scanparser(scan_number-2)
+                    assert(parser.scan_type == scan_type)
+                    available_scan_numbers.append(scan_number-2)
+            elif scan_type == 'bf1':
+                tomo_scan_numbers = kwargs['tomo_scan_numbers']
+                available_scan_numbers = []
+                for scan_number in tomo_scan_numbers:
+                    parser = self.get_scanparser(scan_number-1)
+                    assert(parser.scan_type == scan_type)
+                    available_scan_numbers.append(scan_number-1)
+        if len(available_scan_numbers) == 1:
+            input_mode = 1
+        else:
+            if hasattr(self, 'scan_numbers'):
+                print(f'Currently selected {attr_desc}scan numbers are: {self.scan_numbers}')
+                menu_options = [f'Select a subset of the available {attr_desc}scan numbers',
+                        f'Use all available {attr_desc}scan numbers in {self.spec_file}',
+                        f'Keep the currently selected {attr_desc}scan numbers']
+            else:
+                menu_options = [f'Select a subset of the available {attr_desc}scan numbers',
+                        f'Use all available {attr_desc}scan numbers in {self.spec_file}']
+            print(f'Available scan numbers in {self.spec_file} are: '+
+                    f'{available_scan_numbers}')
+            input_mode = input_menu(menu_options, header='Choose one of the following options '+
+                    'for selecting scan numbers')
+        if input_mode == 0:
+            accept_scan_numbers = False
+            while not accept_scan_numbers:
+                try:
+                    self.scan_numbers = \
+                            input_int_list(f'Enter a series of {attr_desc}scan numbers')
+                except ValidationError as e:
+                    print(e)
+                except KeyboardInterrupt as e:
+                    raise e
+                except BaseException as e:
+                    print(f'Unexpected {type(e).__name__}: {e}')
+                else:
+                    accept_scan_numbers = True
+        elif input_mode == 1:
+            self.scan_numbers = available_scan_numbers
+        elif input_mode == 2:
+            pass
+
+    def cli(self, **cli_kwargs):
+        if cli_kwargs.get('attr_desc') is not None:
+            attr_desc = f'{cli_kwargs["attr_desc"]} '
+        else:
+            attr_desc = ''
+        print(f'\n -- Configure which scans to use from a single {attr_desc}SPEC file')
+        self.set_single_attr_cli('spec_file', attr_desc+'SPEC file path')
+        self.scan_numbers_cli(attr_desc)
+
+    def construct_nxcollection(self, image_key, thetas, detector):
+        nxcollection = NXcollection()
+        nxcollection.attrs['spec_file'] = str(self.spec_file)
+        parser = self.get_scanparser(self.scan_numbers[0])
+        nxcollection.attrs['date'] = parser.spec_scan.file_date
+        for scan_number in self.scan_numbers:
+            # Get scan info
+            scan_info = self.stack_info[self.get_scan_index(scan_number)]
+            # Add an NXsubentry to the NXcollection for each scan
+            entry_name = f'scan_{scan_number}'
+            nxsubentry = NXsubentry()
+            nxcollection[entry_name] = nxsubentry
+            parser = self.get_scanparser(scan_number)
+            nxsubentry.start_time = parser.spec_scan.date
+            nxsubentry.spec_command = parser.spec_command
+            # Add an NXdata for independent dimensions to the scan's NXsubentry
+            num_image = scan_info['num_image']
+            if thetas is None:
+                thetas = num_image*[0.0]
+            else:
+                assert(num_image == len(thetas))
+#            nxsubentry.independent_dimensions = NXdata()
+#            nxsubentry.independent_dimensions.rotation_angle = thetas
+#            nxsubentry.independent_dimensions.rotation_angle.units = 'degrees'
+            # Add an NXinstrument to the scan's NXsubentry
+            nxsubentry.instrument = NXinstrument()
+            # Add an NXdetector to the NXinstrument to the scan's NXsubentry
+            nxsubentry.instrument.detector = detector.construct_nxdetector()
+            nxsubentry.instrument.detector.frame_start_number = scan_info['starting_image_offset']
+            nxsubentry.instrument.detector.image_key = image_key
+            # Add an NXsample to the scan's NXsubentry
+            nxsubentry.sample = NXsample()
+            nxsubentry.sample.rotation_angle = thetas
+            nxsubentry.sample.rotation_angle.units = 'degrees'
+            nxsubentry.sample.x_translation = scan_info['ref_x']
+            nxsubentry.sample.x_translation.units = 'mm'
+            nxsubentry.sample.z_translation = scan_info['ref_z']
+            nxsubentry.sample.z_translation.units = 'mm'
+        return(nxcollection)
+
+
+class FlatField(SpecScans):
+
+    def image_range_cli(self, attr_desc, detector_prefix):
+        stack_info = self.stack_info
+        for scan_number in self.scan_numbers:
+            # Parse the available image range
+            parser = self.get_scanparser(scan_number)
+            image_offset = parser.starting_image_offset
+            num_image = parser.get_num_image(detector_prefix.upper())
+            scan_index = self.get_scan_index(scan_number)
+
+            # Select the image set
+            last_image_index = image_offset+num_image
+            print(f'Available good image set index range: [{image_offset}, {last_image_index})')
+            image_set_approved = False
+            if scan_index is not None:
+                scan_info = stack_info[scan_index]
+                print(f'Current starting image offset and number of images: '+
+                        f'{scan_info["starting_image_offset"]} and {scan_info["num_image"]}')
+                image_set_approved = input_yesno(f'Accept these values (y/n)?', 'y')
+            if not image_set_approved:
+                print(f'Default starting image offset and number of images: '+
+                        f'{image_offset} and {num_image}')
+                image_set_approved = input_yesno(f'Accept these values (y/n)?', 'y')
+                if image_set_approved:
+                    offset = image_offset
+                    num = last_image_index-offset
+                while not image_set_approved:
+                    offset = input_int(f'Enter the starting image offset', ge=image_offset,
+                            lt=last_image_index)#, default=image_offset)
+                    num = input_int(f'Enter the number of images', ge=1,
+                            le=last_image_index-offset)#, default=last_image_index-offset)
+                    print(f'Current starting image offset and number of images: {offset} and {num}')
+                    image_set_approved = input_yesno(f'Accept these values (y/n)?', 'y')
+                if scan_index is not None:
+                    scan_info['starting_image_offset'] = offset
+                    scan_info['num_image'] = num
+                    scan_info['ref_x'] = parser.horizontal_shift
+                    scan_info['ref_z'] = parser.vertical_shift
+                else:
+                    stack_info.append({'scan_number': scan_number, 'starting_image_offset': offset,
+                            'num_image': num, 'ref_x': parser.horizontal_shift,
+                            'ref_z': parser.vertical_shift})
+        self.stack_info = stack_info
+
+    def cli(self, **cli_kwargs):
+        if cli_kwargs.get('attr_desc') is not None:
+            attr_desc = f'{cli_kwargs["attr_desc"]} '
+        else:
+            attr_desc = ''
+        station = cli_kwargs.get('station')
+        detector = cli_kwargs.get('detector')
+        print(f'\n -- Configure the location of the {attr_desc}scan data -- ')
+        if station in ('id1a3', 'id3a'):
+            self.spec_file = cli_kwargs['spec_file']
+            tomo_scan_numbers = cli_kwargs['tomo_scan_numbers']
+            scan_type = cli_kwargs['scan_type']
+            self.scan_numbers_cli(attr_desc, station=station, tomo_scan_numbers=tomo_scan_numbers,
+                    scan_type=scan_type)
+        else:
+            self.set_single_attr_cli('spec_file', attr_desc+'SPEC file path')
+            self.scan_numbers_cli(attr_desc)
+        self.image_range_cli(attr_desc, detector.prefix)
+
+
+class TomoField(SpecScans):
+    theta_range: dict = {}
+
+    @validator('theta_range')
+    def validate_theta_range(cls, theta_range):
+        if len(theta_range) != 3 and len(theta_range) != 4:
+            raise ValueError(f'Invalid theta range {theta_range}')
+        is_num(theta_range['start'], raise_error=True)
+        is_num(theta_range['end'], raise_error=True)
+        is_int(theta_range['num'], gt=1, raise_error=True)
+        if theta_range['end'] <= theta_range['start']:
+            raise ValueError(f'Invalid theta range {theta_range}')
+        if 'start_index' in theta_range:
+            is_int(theta_range['start_index'], ge=0, raise_error=True)
+        return(theta_range)
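+
+    # A valid theta_range thus looks like, e.g. (illustrative values):
+    #   {'start': 0.0, 'end': 180.0, 'num': 181, 'start_index': 0}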
+
+    @classmethod
+    def construct_from_nxcollection(cls, nxcollection:NXcollection):
+        #RV Can I derive this from the same classfunction for SpecScans by adding theta_range
+        config = {}
+        config['spec_file'] = nxcollection.attrs['spec_file']
+        scan_numbers = []
+        stack_info = []
+        for nxsubentry_name, nxsubentry in nxcollection.items():
+            scan_number = int(nxsubentry_name.split('_')[-1])
+            scan_numbers.append(scan_number)
+            stack_info.append({'scan_number': scan_number,
+                    'starting_image_offset': int(nxsubentry.instrument.detector.frame_start_number),
+                    'num_image': len(nxsubentry.sample.rotation_angle),
+                    'ref_x': float(nxsubentry.sample.x_translation),
+                    'ref_z': float(nxsubentry.sample.z_translation)})
+        config['scan_numbers'] = sorted(scan_numbers)
+        config['stack_info'] = stack_info
+        for name in nxcollection.entries:
+            if 'scan_' in name:
+                thetas = np.asarray(nxcollection[name].sample.rotation_angle)
+                config['theta_range'] = {'start': thetas[0], 'end': thetas[-1], 'num': thetas.size}
+                break
+        return(cls(**config))
+
+    def get_horizontal_shifts(self, scan_number=None):
+        horizontal_shifts = []
+        if scan_number is None:
+            scan_numbers = self.scan_numbers
+        else:
+            scan_numbers = [scan_number]
+        for scan_number in scan_numbers:
+            parser = self.get_scanparser(scan_number)
+            horizontal_shifts.append(parser.horizontal_shift)
+        if len(horizontal_shifts) == 1:
+            return(horizontal_shifts[0])
+        else:
+            return(horizontal_shifts)
+
+    def get_vertical_shifts(self, scan_number=None):
+        vertical_shifts = []
+        if scan_number is None:
+            scan_numbers = self.scan_numbers
+        else:
+            scan_numbers = [scan_number]
+        for scan_number in scan_numbers:
+            parser = self.get_scanparser(scan_number)
+            vertical_shifts.append(parser.vertical_shift)
+        if len(vertical_shifts) == 1:
+            return(vertical_shifts[0])
+        else:
+            return(vertical_shifts)
+
+    def theta_range_cli(self, scan_number, attr_desc, station):
+        # Parse the available theta range
+        parser = self.get_scanparser(scan_number)
+        theta_vals = parser.theta_vals
+        spec_theta_start = theta_vals.get('start')
+        spec_theta_end = theta_vals.get('end')
+        spec_num_theta = theta_vals.get('num')
+
+        # Check for consistency of theta ranges between scans
+        if scan_number != self.scan_numbers[0]:
+            parser = self.get_scanparser(self.scan_numbers[0])
+            if (parser.theta_vals.get('start') != spec_theta_start or
+                    parser.theta_vals.get('end') != spec_theta_end or
+                    parser.theta_vals.get('num') != spec_num_theta):
+                raise ValueError(f'Incompatible theta ranges between {attr_desc}scans:'+
+                        f'\n\tScan {scan_number}: {theta_vals}'+
+                        f'\n\tScan {self.scan_numbers[0]}: {parser.theta_vals}')
+            return
+
+        # Select the theta range for the tomo reconstruction from the first scan
+        theta_range_approved = False
+        thetas = np.linspace(spec_theta_start, spec_theta_end, spec_num_theta)
+        delta_theta = thetas[1]-thetas[0]
+        print(f'Theta range obtained from SPEC data: [{spec_theta_start}, {spec_theta_end}]')
+        print(f'Theta step size = {delta_theta}')
+        print(f'Number of theta values: {spec_num_theta}')
+        default_start = None
+        default_end = None
+        if station in ('id1a3', 'id3a'):
+            theta_range_approved = input_yesno(f'Accept this theta range (y/n)?', 'y')
+            if theta_range_approved:
+                self.theta_range = {'start': float(spec_theta_start), 'end': float(spec_theta_end),
+                        'num': int(spec_num_theta), 'start_index': 0}
+                return
+        elif station in ('id3b',):
+            if spec_theta_start <= 0.0 and spec_theta_end >= 180.0:
+                default_start = 0
+                default_end = 180
+            elif spec_theta_end-spec_theta_start == 180:
+                default_start = spec_theta_start
+                default_end = spec_theta_end
+        while not theta_range_approved:
+            theta_start = input_num(f'Enter the first theta (included)', ge=spec_theta_start,
+                    lt=spec_theta_end, default=default_start)
+            theta_index_start = index_nearest(thetas, theta_start)
+            theta_start = thetas[theta_index_start]
+            theta_end = input_num(f'Enter the last theta (excluded)',
+                    ge=theta_start+delta_theta, le=spec_theta_end, default=default_end)
+            theta_index_end = index_nearest(thetas, theta_end)
+            theta_end = thetas[theta_index_end]
+            num_theta = theta_index_end-theta_index_start
+            print(f'Selected theta range: [{theta_start}, {theta_start+delta_theta}, ..., '+
+                    f'{theta_end})')
+            print(f'Number of theta values: {num_theta}')
+            theta_range_approved = input_yesno(f'Accept this theta range (y/n)?', 'y')
+        self.theta_range = {'start': float(theta_start), 'end': float(theta_end),
+                'num': int(num_theta), 'start_index': int(theta_index_start)}
+
+    def image_range_cli(self, attr_desc, detector_prefix):
+        stack_info = self.stack_info
+        for scan_number in self.scan_numbers:
+            # Parse the available image range
+            parser = self.get_scanparser(scan_number)
+            image_offset = parser.starting_image_offset
+            num_image = parser.get_num_image(detector_prefix.upper())
+            scan_index = self.get_scan_index(scan_number)
+
+            # Select the image set matching the theta range
+            num_theta = self.theta_range['num']
+            theta_index_start = self.theta_range['start_index']
+            if num_theta > num_image-theta_index_start:
+                raise ValueError(f'Available {attr_desc}image indices incompatible with thetas:'+
+                        f'\n\tNumber of thetas and offset = {num_theta} and {theta_index_start}'+
+                        f'\n\tNumber of available images {num_image}')
+            if scan_index is not None:
+                scan_info = stack_info[scan_index]
+                scan_info['starting_image_offset'] = image_offset+theta_index_start
+                scan_info['num_image'] = num_theta
+                scan_info['ref_x'] = parser.horizontal_shift
+                scan_info['ref_z'] = parser.vertical_shift
+            else:
+                stack_info.append({'scan_number': scan_number,
+                        'starting_image_offset': image_offset+theta_index_start,
+                        'num_image': num_theta, 'ref_x': parser.horizontal_shift,
+                        'ref_z': parser.vertical_shift})
+        self.stack_info = stack_info
+
+    def cli(self, **cli_kwargs):
+        if cli_kwargs.get('attr_desc') is not None:
+            attr_desc = f'{cli_kwargs["attr_desc"]} '
+        else:
+            attr_desc = ''
+        cycle = cli_kwargs.get('cycle')
+        btr = cli_kwargs.get('btr')
+        station = cli_kwargs.get('station')
+        detector = cli_kwargs.get('detector')
+        sample_name = cli_kwargs.get('sample_name')
+        print(f'\n -- Configure the location of the {attr_desc}scan data -- ')
+        if station in ('id1a3', 'id3a'):
+            basedir = f'/nfs/chess/{station}/{cycle}/{btr}'
+            runs = [d for d in os.listdir(basedir) if os.path.isdir(os.path.join(basedir, d))]
+#RV            index = 15-1
+#RV            index = 7-1
+            if sample_name is not None and sample_name in runs:
+                index = runs.index(sample_name)
+            else:
+                index = input_menu(runs, header='Choose a sample directory')
+            self.spec_file = f'{basedir}/{runs[index]}/spec.log'
+            self.scan_numbers_cli(attr_desc, station=station, scan_type='ts1')
+        else:
+            self.set_single_attr_cli('spec_file', attr_desc+'SPEC file path')
+            self.scan_numbers_cli(attr_desc)
+        for scan_number in self.scan_numbers:
+            self.theta_range_cli(scan_number, attr_desc, station)
+        self.image_range_cli(attr_desc, detector.prefix)
+
+
+class Sample(BaseModel):
+    name: constr(min_length=1)
+    description: Optional[str]
+    rotation_angles: Optional[list]
+    x_translations: Optional[list]
+    z_translations: Optional[list]
+
+    @classmethod
+    def construct_from_nxsample(cls, nxsample:NXsample):
+        config = {}
+        config['name'] = nxsample.name.nxdata
+        if 'description' in nxsample:
+            config['description'] = nxsample.description.nxdata
+        if 'rotation_angle' in nxsample:
+            config['rotation_angles'] = nxsample.rotation_angle.nxdata
+        if 'x_translation' in nxsample:
+            config['x_translations'] = nxsample.x_translation.nxdata
+        if 'z_translation' in nxsample:
+            config['z_translations'] = nxsample.z_translation.nxdata
+        return(cls(**config))
+
+    def cli(self):
+        print('\n -- Configure the sample metadata -- ')
+#RV        self.name = 'sobhani-3249-A'
+#RV        self.name = 'tenstom_1304r-1'
+        self.set_single_attr_cli('name', 'the sample name')
+#RV        self.description = 'test sample'
+        self.set_single_attr_cli('description', 'a description of the sample (optional)')
+
+
+class MapConfig(BaseModel):
+    cycle: constr(strip_whitespace=True, min_length=1)
+    btr: constr(strip_whitespace=True, min_length=1)
+    title: constr(strip_whitespace=True, min_length=1)
+    station: Literal['id1a3', 'id3a', 'id3b'] = None
+    sample: Sample
+    detector: Detector = Detector.construct()
+    tomo_fields: TomoField
+    dark_field: Optional[FlatField]
+    bright_field: FlatField
+    _thetas: list[float] = PrivateAttr()
+    _field_types = ({'name': 'dark_field', 'image_key': 2}, {'name': 'bright_field',
+            'image_key': 1}, {'name': 'tomo_fields', 'image_key': 0})
+
+    @classmethod
+    def construct_from_nxentry(cls, nxentry:NXentry):
+        config = {}
+        config['cycle'] = nxentry.instrument.source.attrs['cycle']
+        config['btr'] = nxentry.instrument.source.attrs['btr']
+        config['title'] = nxentry.nxname
+        config['station'] = nxentry.instrument.source.attrs['station']
+        config['sample'] = Sample.construct_from_nxsample(nxentry['sample'])
+        for nxobject_name, nxobject in nxentry.spec_scans.items():
+            if isinstance(nxobject, NXcollection):
+                config[nxobject_name] = SpecScans.construct_from_nxcollection(nxobject)
+        return(cls(**config))
+
+#FIX cache?
+    @property
+    def thetas(self):
+        try:
+            return(self._thetas)
+        except AttributeError:  # _thetas has not been computed yet
+            theta_range = self.tomo_fields.theta_range
+            self._thetas = list(np.linspace(theta_range['start'], theta_range['end'],
+                    theta_range['num']))
+            return(self._thetas)
+
+    def cli(self):
+        print('\n -- Configure a map from a set of SPEC scans (dark, bright, and tomo), '+
+                'and / or detector data -- ')
+#RV        self.cycle = '2021-3'
+#RV        self.cycle = '2022-2'
+#RV        self.cycle = '2023-1'
+        self.set_single_attr_cli('cycle', 'beam cycle')
+#RV        self.btr = 'z-3234-A'
+#RV        self.btr = 'sobhani-3249-A'
+#RV        self.btr = 'przybyla-3606-a'
+        self.set_single_attr_cli('btr', 'BTR')
+#RV        self.title = 'z-3234-A'
+#RV        self.title = 'tomo7C'
+#RV        self.title = 'cmc-test-dwell-1'
+        self.set_single_attr_cli('title', 'title for the map entry')
+#RV        self.station = 'id3a'
+#RV        self.station = 'id3b'
+#RV        self.station = 'id1a3'
+        self.set_single_attr_cli('station', 'name of the station at which scans were collected '+
+                '(currently choose from: id1a3, id3a, id3b)')
+        import_scanparser(self.station)
+        self.set_single_attr_cli('sample')
+        use_detector_config = False
+        if hasattr(self.detector, 'prefix') and len(self.detector.prefix):
+            use_detector_config = input_yesno(f'Current detector settings:\n{self.detector}\n'+
+                    f'Keep these settings? (y/n)')
+        if not use_detector_config:
+            menu_options = ['not listed', 'andor2', 'manta', 'retiga']
+            input_mode = input_menu(menu_options, header='Choose one of the following detector '+
+                    'configuration options')
+            if input_mode:
+                detector_config_file = f'{menu_options[input_mode]}.yaml'
+                have_detector_config = self.detector.construct_from_yaml(detector_config_file)
+            else:
+                have_detector_config = False
+            if not have_detector_config:
+                self.set_single_attr_cli('detector', 'detector')
+        self.set_single_attr_cli('tomo_fields', 'Tomo field', chain_attr_desc=True,
+                cycle=self.cycle, btr=self.btr, station=self.station, detector=self.detector,
+                sample_name=self.sample.name)
+        if self.station in ('id1a3', 'id3a'):
+            have_dark_field = True
+            tomo_spec_file = self.tomo_fields.spec_file
+        else:
+            have_dark_field = input_yesno('Are dark field images available? (y/n)')
+            tomo_spec_file = None
+        if have_dark_field:
+            self.set_single_attr_cli('dark_field', 'Dark field', chain_attr_desc=True,
+                    station=self.station, detector=self.detector, spec_file=tomo_spec_file,
+                    tomo_scan_numbers=self.tomo_fields.scan_numbers, scan_type='df1')
+        self.set_single_attr_cli('bright_field', 'Bright field', chain_attr_desc=True,
+                station=self.station, detector=self.detector, spec_file=tomo_spec_file,
+                tomo_scan_numbers=self.tomo_fields.scan_numbers, scan_type='bf1')
+
+    def construct_nxentry(self, nxroot, include_raw_data=True):
+        # Construct base NXentry
+        nxentry = NXentry()
+
+        # Add an NXentry to the NXroot
+        nxroot[self.title] = nxentry
+        nxroot.attrs['default'] = self.title
+        nxentry.definition = 'NXtomo'
+#        nxentry.attrs['default'] = 'data'
+
+        # Add an NXinstrument to the NXentry
+        nxinstrument = NXinstrument()
+        nxentry.instrument = nxinstrument
+
+        # Add an NXsource to the NXinstrument
+        nxsource = NXsource()
+        nxinstrument.source = nxsource
+        nxsource.type = 'Synchrotron X-ray Source'
+        nxsource.name = 'CHESS'
+        nxsource.probe = 'x-ray'
+
+        # Tag the NXsource with the runinfo (as an attribute)
+        nxsource.attrs['cycle'] = self.cycle
+        nxsource.attrs['btr'] = self.btr
+        nxsource.attrs['station'] = self.station
+
+        # Add an NXdetector to the NXinstrument (don't fill in data fields yet)
+        nxinstrument.detector = self.detector.construct_nxdetector()
+
+        # Add an NXsample to NXentry (don't fill in data fields yet)
+        nxsample = NXsample()
+        nxentry.sample = nxsample
+        nxsample.name = self.sample.name
+        nxsample.description = self.sample.description
+
+        # Add an NXcollection to the base NXentry to hold metadata about the spec scans in the map
+        # Also obtain the data fields in NXsample and NXdetector
+        nxspec_scans = NXcollection()
+        nxentry.spec_scans = nxspec_scans
+        image_keys = []
+        sequence_numbers = []
+        image_stacks = []
+        rotation_angles = []
+        x_translations = []
+        z_translations = []
+        for field_type in self._field_types:
+            field_name = field_type['name']
+            field = getattr(self, field_name)
+            if field is None:
+                continue
+            image_key = field_type['image_key']
+            if field_type['name'] == 'tomo_fields':
+                thetas = self.thetas
+            else:
+                thetas = None
+            # Add the scans in a single spec file
+            nxspec_scans[field_name] = field.construct_nxcollection(image_key, thetas,
+                    self.detector)
+            if include_raw_data:
+                image_stacks += field.get_detector_data(self.detector.prefix)
+                for scan_number in field.scan_numbers:
+                    parser = field.get_scanparser(scan_number)
+                    scan_info = field.stack_info[field.get_scan_index(scan_number)]
+                    num_image = scan_info['num_image']
+                    image_keys += num_image*[image_key]
+                    sequence_numbers += [i for i in range(num_image)]
+                    if thetas is None:
+                        rotation_angles += num_image*[0.0]
+                    else:
+                        assert(num_image == len(thetas))
+                        rotation_angles += thetas
+                    x_translations += num_image*[scan_info['ref_x']]
+                    z_translations += num_image*[scan_info['ref_z']]
+
+        if include_raw_data:
+            # Add image data to NXdetector
+            nxinstrument.detector.image_key = image_keys
+            nxinstrument.detector.sequence_number = sequence_numbers
+            nxinstrument.detector.data = np.concatenate([image for image in image_stacks])
+
+            # Add image data to NXsample
+            nxsample.rotation_angle = rotation_angles
+            nxsample.rotation_angle.attrs['units'] = 'degrees'
+            nxsample.x_translation = x_translations
+            nxsample.x_translation.attrs['units'] = 'mm'
+            nxsample.z_translation = z_translations
+            nxsample.z_translation.attrs['units'] = 'mm'
+
+            # Add an NXdata to NXentry
+            nxdata = NXdata()
+            nxentry.data = nxdata
+            nxdata.makelink(nxentry.instrument.detector.data, name='data')
+            nxdata.makelink(nxentry.instrument.detector.image_key)
+            nxdata.makelink(nxentry.sample.rotation_angle)
+            nxdata.makelink(nxentry.sample.x_translation)
+            nxdata.makelink(nxentry.sample.z_translation)
+#            nxdata.attrs['axes'] = ['field', 'row', 'column']
+#            nxdata.attrs['field_indices'] = 0
+#            nxdata.attrs['row_indices'] = 1
+#            nxdata.attrs['column_indices'] = 2
+
+
+class TomoWorkflow(BaseModel):
+    sample_maps: conlist(item_type=MapConfig, min_items=1) = [MapConfig.construct()]
+
+    @classmethod
+    def construct_from_nexus(cls, filename):
+        nxroot = nxload(filename)
+        sample_maps = []
+        config = {'sample_maps': sample_maps}
+        for nxentry_name, nxentry in nxroot.items():
+            sample_maps.append(MapConfig.construct_from_nxentry(nxentry))
+        return(cls(**config))
+
+    def cli(self):
+        print('\n -- Configure a map -- ')
+        self.set_list_attr_cli('sample_maps', 'sample map')
+
+    def construct_nxfile(self, filename, mode='w-'):
+        nxroot = NXroot()
+        t0 = time()
+        for sample_map in self.sample_maps:
+            logger.info(f'Start constructing the {sample_map.title} map.')
+            import_scanparser(sample_map.station)
+            sample_map.construct_nxentry(nxroot)
+        logger.info(f'Constructed all sample maps in {time()-t0:.2f} seconds.')
+        logger.info(f'Start saving all sample maps to {filename}.')
+        nxroot.save(filename, mode=mode)
+
+    def write_to_nexus(self, filename):
+        t0 = time()
+        self.construct_nxfile(filename, mode='w')
+        logger.info(f'Saved all sample maps to {filename} in {time()-t0:.2f} seconds.')
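+
+# A typical round trip with this model, assuming a map configuration in a
+# YAML file (file names below are hypothetical):
+#   wf = TomoWorkflow.construct_from_file('map.yaml')
+#   wf.write_to_nexus('map.nxs')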
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/workflow/run_tomo.py	Tue Mar 21 16:22:42 2023 +0000
@@ -0,0 +1,1645 @@
+#!/usr/bin/env python3
+
+import logging
+logger = logging.getLogger(__name__)
+
+import numpy as np
+try:
+    import numexpr as ne
+except:
+    pass
+try:
+    import scipy.ndimage as spi
+except:
+    pass
+
+from multiprocessing import cpu_count
+from nexusformat.nexus import *
+from os import mkdir
+from os import path as os_path
+try:
+    from skimage.transform import iradon
+except:
+    pass
+try:
+    from skimage.restoration import denoise_tv_chambolle
+except:
+    pass
+from time import time
+try:
+    import tomopy
+except:
+    pass
+from yaml import safe_load, safe_dump
+
+try:
+    from msnctools.fit import Fit
+except:
+    from fit import Fit
+try:
+    from msnctools.general import illegal_value, is_int, is_int_pair, is_num, is_index_range, \
+            input_int, input_num, input_yesno, input_menu, draw_mask_1d, select_image_bounds, \
+            select_one_image_bound, clear_imshow, quick_imshow, clear_plot, quick_plot
+except:
+    from general import illegal_value, is_int, is_int_pair, is_num, is_index_range, \
+            input_int, input_num, input_yesno, input_menu, draw_mask_1d, select_image_bounds, \
+            select_one_image_bound, clear_imshow, quick_imshow, clear_plot, quick_plot
+
+try:
+    from workflow.models import import_scanparser, FlatField, TomoField, TomoWorkflow
+    from workflow.__version__ import __version__
+except:
+    pass
+
+num_core_tomopy_limit = 24
+
+def nxcopy(nxobject:NXobject, exclude_nxpaths:list[str]=[], nxpath_prefix:str='') -> NXobject:
+    '''Function that returns a copy of a NeXus object, optionally excluding certain child items.
+
+    :param nxobject: the original NeXus object to return a "copy" of
+    :type nxobject: nexusformat.nexus.NXobject
+    :param exclude_nxpaths: a list of paths to child NeXus objects that
+        should be excluded from the returned "copy", defaults to `[]`
+    :type exclude_nxpaths: list[str], optional
+    :param nxpath_prefix: for use in recursive calls from inside this
+        function only!
+    :type nxpath_prefix: str
+    :return: a copy of `nxobject` with some children optionally excluded.
+    :rtype: NXobject
+    '''
+
+    nxobject_copy = nxobject.__class__()
+    if not len(nxpath_prefix):
+        if 'default' in nxobject.attrs:
+            nxobject_copy.attrs['default'] = nxobject.attrs['default']
+    else:
+        for k, v in nxobject.attrs.items():
+            nxobject_copy.attrs[k] = v
+
+    for k, v in nxobject.items():
+        nxpath = os_path.join(nxpath_prefix, k)
+
+        if nxpath in exclude_nxpaths:
+            continue
+
+        if isinstance(v, NXgroup):
+            nxobject_copy[k] = nxcopy(v, exclude_nxpaths=exclude_nxpaths,
+                    nxpath_prefix=os_path.join(nxpath_prefix, k))
+        else:
+            nxobject_copy[k] = v
+
+    return(nxobject_copy)
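+
+# Example (illustrative only; the entry and path names below are hypothetical):
+#     nxroot_copy = nxcopy(nxroot, exclude_nxpaths=['entry/instrument/detector/data'])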
+
+class set_numexpr_threads:
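+    """Context manager that temporarily sets the number of numexpr threads and
+       restores the previous setting on exit.
+    """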
+
+    def __init__(self, num_core):
+        if num_core is None or num_core < 1 or num_core > cpu_count():
+            self.num_core = cpu_count()
+        else:
+            self.num_core = num_core
+
+    def __enter__(self):
+        self.num_core_org = ne.set_num_threads(self.num_core)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        ne.set_num_threads(self.num_core_org)
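+
+# Example (illustrative only): cap numexpr at 4 threads for a single evaluation.
+#     with set_numexpr_threads(4):
+#         ne.evaluate('tomo_stack-tdf', out=tomo_stack)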
+
+class Tomo:
+    """Processing tomography data with misalignment.
+    """
+    def __init__(self, galaxy_flag=False, num_core=-1, output_folder='.', save_figs=None,
+            test_mode=False):
+        """Initialize with optional config input file or dictionary
+        """
+        if not isinstance(galaxy_flag, bool):
+            raise ValueError(f'Invalid parameter galaxy_flag ({galaxy_flag})')
+        self.galaxy_flag = galaxy_flag
+        self.num_core = num_core
+        if self.galaxy_flag:
+            if output_folder != '.':
+                logger.warning('Ignoring output_folder in galaxy mode')
+            self.output_folder = '.'
+            if test_mode:
+                logger.warning('Ignoring test_mode in galaxy mode')
+            self.test_mode = False
+            if save_figs is not None:
+                logger.warning('Ignoring save_figs in galaxy mode')
+            save_figs = 'only'
+        else:
+            self.output_folder = os_path.abspath(output_folder)
+            if not os_path.isdir(output_folder):
+                mkdir(os_path.abspath(output_folder))
+            if not isinstance(test_mode, bool):
+                raise ValueError(f'Invalid parameter test_mode ({test_mode})')
+            self.test_mode = test_mode
+            if save_figs is None:
+                save_figs = 'no'
+        self.test_config = {}
+        if self.test_mode:
+            if save_figs != 'only':
+                logger.warning('Ignoring save_figs in test mode')
+            save_figs = 'only'
+        if save_figs == 'only':
+            self.save_only = True
+            self.save_figs = True
+        elif save_figs == 'yes':
+            self.save_only = False
+            self.save_figs = True
+        elif save_figs == 'no':
+            self.save_only = False
+            self.save_figs = False
+        else:
+            raise ValueError(f'Invalid parameter save_figs ({save_figs})')
+        if self.save_only:
+            self.block = False
+        else:
+            self.block = True
+        if self.num_core == -1:
+            self.num_core = cpu_count()
+        if not is_int(self.num_core, gt=0, log=False):
+            raise ValueError(f'Invalid parameter num_core ({num_core})')
+        if self.num_core > cpu_count():
+            logger.warning(f'num_core = {self.num_core} is larger than the number of available '
+                    f'processors and reduced to {cpu_count()}')
+            self.num_core = cpu_count()
+
+    def read(self, filename):
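+        """Read a YAML configuration or a NeXus file and return its contents.
+        """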
+        logger.info(f'Looking for {filename}')
+        if self.galaxy_flag:
+            try:
+                with open(filename, 'r') as f:
+                    config = safe_load(f)
+                return(config)
+            except:
+                try:
+                    with NXFile(filename, mode='r') as nxfile:
+                        nxroot = nxfile.readfile()
+                    return(nxroot)
+                except:
+                    raise ValueError(f'Unable to open ({filename})')
+        else:
+            extension = os_path.splitext(filename)[1]
+            if extension == '.yml' or extension == '.yaml':
+                with open(filename, 'r') as f:
+                    config = safe_load(f)
+#                if len(config) > 1:
+#                    raise ValueError(f'Multiple root entries in {filename} not yet implemented')
+#                if len(list(config.values())[0]) > 1:
+#                    raise ValueError(f'Multiple sample maps in {filename} not yet implemented')
+                return(config)
+            elif extension == '.nxs':
+                with NXFile(filename, mode='r') as nxfile:
+                    nxroot = nxfile.readfile()
+                return(nxroot)
+            else:
+                raise ValueError(f'Invalid filename extension ({extension})')
+
+    def write(self, data, filename):
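+        """Write data to file; the format follows from the filename extension
+           (.yml/.yaml: YAML, .nxs/.nex: NeXus, .nc: NetCDF).
+        """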
+        extension = os_path.splitext(filename)[1]
+        if extension == '.yml' or extension == '.yaml':
+            with open(filename, 'w') as f:
+                safe_dump(data, f)
+        elif extension == '.nxs' or extension == '.nex':
+            data.save(filename, mode='w')
+        elif extension == '.nc':
+            data.to_netcdf(path=filename)
+        else:
+            raise ValueError(f'Invalid filename extension ({extension})')
+
+    def gen_reduced_data(self, data, img_x_bounds=None):
+        """Generate the reduced tomography images.
+        """
+        logger.info('Generate the reduced tomography images')
+
+        # Create plot galaxy path directory if needed
+        if self.galaxy_flag and not os_path.exists('tomo_reduce_plots'):
+            mkdir('tomo_reduce_plots')
+
+        if isinstance(data, dict):
+            # Create Nexus format object from input dictionary
+            wf = TomoWorkflow(**data)
+            if len(wf.sample_maps) > 1:
+                raise ValueError(f'Multiple sample maps not yet implemented')
+#            print(f'\nwf:\n{wf}\n')
+            nxroot = NXroot()
+            t0 = time()
+            for sample_map in wf.sample_maps:
+                logger.info(f'Start constructing the {sample_map.title} map.')
+                import_scanparser(sample_map.station)
+                sample_map.construct_nxentry(nxroot, include_raw_data=False)
+            logger.info(f'Constructed all sample maps in {time()-t0:.2f} seconds.')
+            nxentry = nxroot[nxroot.attrs['default']]
+            # Get test mode configuration info
+            if self.test_mode:
+                self.test_config = data['sample_maps'][0]['test_mode']
+        elif isinstance(data, NXroot):
+            nxentry = data[data.attrs['default']]
+        else:
+            raise ValueError(f'Invalid parameter data ({data})')
+
+        # Create an NXprocess to store data reduction (meta)data
+        reduced_data = NXprocess()
+
+        # Generate dark field
+        if 'dark_field' in nxentry['spec_scans']:
+            reduced_data = self._gen_dark(nxentry, reduced_data)
+
+        # Generate bright field
+        reduced_data = self._gen_bright(nxentry, reduced_data)
+
+        # Set vertical detector bounds for image stack
+        img_x_bounds = self._set_detector_bounds(nxentry, reduced_data, img_x_bounds=img_x_bounds)
+        logger.info(f'img_x_bounds = {img_x_bounds}')
+        reduced_data['img_x_bounds'] = img_x_bounds
+
+        # Set zoom and/or theta skip to reduce the memory requirement
+        zoom_perc, num_theta_skip = self._set_zoom_or_skip()
+        if zoom_perc is not None:
+            reduced_data.attrs['zoom_perc'] = zoom_perc
+        if num_theta_skip is not None:
+            reduced_data.attrs['num_theta_skip'] = num_theta_skip
+
+        # Generate reduced tomography fields
+        reduced_data = self._gen_tomo(nxentry, reduced_data)
+
+        # Create a copy of the input Nexus object and remove raw and any existing reduced data
+        if isinstance(data, NXroot):
+            exclude_items = [f'{nxentry._name}/reduced_data/data',
+                    f'{nxentry._name}/instrument/detector/data',
+                    f'{nxentry._name}/instrument/detector/image_key',
+                    f'{nxentry._name}/instrument/detector/sequence_number',
+                    f'{nxentry._name}/sample/rotation_angle',
+                    f'{nxentry._name}/sample/x_translation',
+                    f'{nxentry._name}/sample/z_translation',
+                    f'{nxentry._name}/data/data',
+                    f'{nxentry._name}/data/image_key',
+                    f'{nxentry._name}/data/rotation_angle',
+                    f'{nxentry._name}/data/x_translation',
+                    f'{nxentry._name}/data/z_translation']
+            nxroot = nxcopy(data, exclude_nxpaths=exclude_items)
+            nxentry = nxroot[nxroot.attrs['default']]
+
+        # Add the reduced data NXprocess
+        nxentry.reduced_data = reduced_data
+
+        if 'data' not in nxentry:
+            nxentry.data = NXdata()
+        nxentry.attrs['default'] = 'data'
+        nxentry.data.makelink(nxentry.reduced_data.data.tomo_fields, name='reduced_data')
+        nxentry.data.makelink(nxentry.reduced_data.rotation_angle, name='rotation_angle')
+        nxentry.data.attrs['signal'] = 'reduced_data'
+ 
+        return(nxroot)
+
+    def find_centers(self, nxroot, center_rows=None, center_stack_index=None):
+        """Find the calibrated center axis info
+        """
+        logger.info('Find the calibrated center axis info')
+
+        if not isinstance(nxroot, NXroot):
+            raise ValueError(f'Invalid parameter nxroot ({nxroot})')
+        nxentry = nxroot[nxroot.attrs['default']]
+        if not isinstance(nxentry, NXentry):
+            raise ValueError(f'Invalid nxentry ({nxentry})')
+        if self.galaxy_flag:
+            if center_rows is not None:
+                center_rows = tuple(center_rows)
+                if not is_int_pair(center_rows):
+                    raise ValueError(f'Invalid parameter center_rows ({center_rows})')
+        elif center_rows is not None:
+            logger.warning(f'Ignoring parameter center_rows ({center_rows})')
+            center_rows = None
+        if self.galaxy_flag:
+            if center_stack_index is not None and not is_int(center_stack_index, ge=0):
+                raise ValueError(f'Invalid parameter center_stack_index ({center_stack_index})')
+        elif center_stack_index is not None:
+            logger.warning(f'Ignoring parameter center_stack_index ({center_stack_index})')
+            center_stack_index = None
+
+        # Create plot galaxy path directory and path if needed
+        if self.galaxy_flag:
+            if not os_path.exists('tomo_find_centers_plots'):
+                mkdir('tomo_find_centers_plots')
+            path = 'tomo_find_centers_plots'
+        else:
+            path = self.output_folder
+
+        # Check if reduced data is available
+        if ('reduced_data' not in nxentry or 'reduced_data' not in nxentry.data):
+            raise KeyError(f'Unable to find valid reduced data in {nxentry}.')
+
+        # Select the image stack to calibrate the center axis
+        #   reduced data axes order: stack,theta,row,column
+        #   Note: Nexus cannot follow a link if the data it points to is too big,
+        #         so get the data from the actual place, not from nxentry.data
+        tomo_fields_shape = nxentry.reduced_data.data.tomo_fields.shape
+        if len(tomo_fields_shape) != 4 or not all(tomo_fields_shape):
+            raise KeyError('Unable to load the required reduced tomography stack')
+        num_tomo_stacks = tomo_fields_shape[0]
+        if num_tomo_stacks == 1:
+            center_stack_index = 0
+            default = 'n'
+        else:
+            if self.test_mode:
+                center_stack_index = self.test_config['center_stack_index']-1 # convert to a 0-based index
+            elif self.galaxy_flag:
+                if center_stack_index is None:
+                    center_stack_index = int(num_tomo_stacks/2)
+                if center_stack_index >= num_tomo_stacks:
+                    raise ValueError(f'Invalid parameter center_stack_index ({center_stack_index})')
+            else:
+                center_stack_index = input_int('\nEnter tomography stack index to calibrate the '
+                        'center axis', ge=1, le=num_tomo_stacks, default=int(1+num_tomo_stacks/2))
+                center_stack_index -= 1
+            default = 'y'
+
+        # Get thetas (in degrees)
+        thetas = np.asarray(nxentry.reduced_data.rotation_angle)
+
+        # Get effective pixel_size
+        if 'zoom_perc' in nxentry.reduced_data:
+            eff_pixel_size = 100.*(nxentry.instrument.detector.x_pixel_size/
+                nxentry.reduced_data.attrs['zoom_perc'])
+        else:
+            eff_pixel_size = nxentry.instrument.detector.x_pixel_size
+
+        # Get cross sectional diameter
+        cross_sectional_dim = tomo_fields_shape[3]*eff_pixel_size
+        logger.debug(f'cross_sectional_dim = {cross_sectional_dim}')
+
+        # Determine center offset at sample row boundaries
+        logger.info('Determine center offset at sample row boundaries')
+
+        # Lower row center
+        if self.test_mode:
+            lower_row = self.test_config['lower_row']
+        elif self.galaxy_flag:
+            if center_rows is None or center_rows[0] == -1:
+                lower_row = 0
+            else:
+                lower_row = min(center_rows)
+                if not 0 <= lower_row < tomo_fields_shape[2]-1:
+                    raise ValueError(f'Invalid parameter center_rows ({center_rows})')
+        else:
+            lower_row = select_one_image_bound(
+                    nxentry.reduced_data.data.tomo_fields[center_stack_index,0,:,:], 0, bound=0,
+                    title=f'theta={round(thetas[0], 2)+0}',
+                    bound_name='row index to find lower center', default=default, raise_error=True)
+        logger.debug('Finding center...')
+        t0 = time()
+        lower_center_offset = self._find_center_one_plane(
+                #np.asarray(nxentry.reduced_data.data.tomo_fields[center_stack_index,:,lower_row,:]),
+                nxentry.reduced_data.data.tomo_fields[center_stack_index,:,lower_row,:],
+                lower_row, thetas, eff_pixel_size, cross_sectional_dim, path=path,
+                num_core=self.num_core)
+        logger.debug(f'... done in {time()-t0:.2f} seconds')
+        logger.debug(f'lower_row = {lower_row:.2f}')
+        logger.debug(f'lower_center_offset = {lower_center_offset:.2f}')
+
+        # Upper row center
+        if self.test_mode:
+            upper_row = self.test_config['upper_row']
+        elif self.galaxy_flag:
+            if center_rows is None or center_rows[1] == -1:
+                upper_row = tomo_fields_shape[2]-1
+            else:
+                upper_row = max(center_rows)
+                if not lower_row < upper_row < tomo_fields_shape[2]:
+                    raise ValueError(f'Invalid parameter center_rows ({center_rows})')
+        else:
+            upper_row = select_one_image_bound(
+                    nxentry.reduced_data.data.tomo_fields[center_stack_index,0,:,:], 0,
+                    bound=tomo_fields_shape[2]-1, title=f'theta={round(thetas[0], 2)+0}',
+                    bound_name='row index to find upper center', default=default, raise_error=True)
+        logger.debug('Finding center...')
+        t0 = time()
+        upper_center_offset = self._find_center_one_plane(
+                #np.asarray(nxentry.reduced_data.data.tomo_fields[center_stack_index,:,upper_row,:]),
+                nxentry.reduced_data.data.tomo_fields[center_stack_index,:,upper_row,:],
+                upper_row, thetas, eff_pixel_size, cross_sectional_dim, path=path,
+                num_core=self.num_core)
+        logger.debug(f'... done in {time()-t0:.2f} seconds')
+        logger.debug(f'upper_row = {upper_row:.2f}')
+        logger.debug(f'upper_center_offset = {upper_center_offset:.2f}')
+
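+        # The calibrated center info is returned as a plain dict, e.g.
+        # (values illustrative only):
+        #     {'lower_row': 10, 'lower_center_offset': -1.2,
+        #      'upper_row': 500, 'upper_center_offset': 0.8}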
+        center_config = {'lower_row': lower_row, 'lower_center_offset': lower_center_offset,
+                'upper_row': upper_row, 'upper_center_offset': upper_center_offset}
+        if num_tomo_stacks > 1:
+            center_config['center_stack_index'] = center_stack_index+1 # save as a 1-based index
+
+        # Save test data to file
+        if self.test_mode:
+            with open(f'{self.output_folder}/center_config.yaml', 'w') as f:
+                safe_dump(center_config, f)
+
+        return(center_config)
+
+    def reconstruct_data(self, nxroot, center_info, x_bounds=None, y_bounds=None):
+        """Reconstruct the tomography data.
+        """
+        logger.info('Reconstruct the tomography data')
+
+        if not isinstance(nxroot, NXroot):
+            raise ValueError(f'Invalid parameter nxroot ({nxroot})')
+        nxentry = nxroot[nxroot.attrs['default']]
+        if not isinstance(nxentry, NXentry):
+            raise ValueError(f'Invalid nxentry ({nxentry})')
+        if not isinstance(center_info, dict):
+            raise ValueError(f'Invalid parameter center_info ({center_info})')
+
+        # Create plot galaxy path directory and path if needed
+        if self.galaxy_flag:
+            if not os_path.exists('tomo_reconstruct_plots'):
+                mkdir('tomo_reconstruct_plots')
+            path = 'tomo_reconstruct_plots'
+        else:
+            path = self.output_folder
+
+        # Check if reduced data is available
+        if ('reduced_data' not in nxentry or 'reduced_data' not in nxentry.data):
+            raise KeyError(f'Unable to find valid reduced data in {nxentry}.')
+
+        # Create an NXprocess to store image reconstruction (meta)data
+        nxprocess = NXprocess()
+
+        # Get rotation axis rows and centers
+        lower_row = center_info.get('lower_row')
+        lower_center_offset = center_info.get('lower_center_offset')
+        upper_row = center_info.get('upper_row')
+        upper_center_offset = center_info.get('upper_center_offset')
+        if (lower_row is None or lower_center_offset is None or upper_row is None or
+                upper_center_offset is None):
+            raise KeyError(f'Unable to find valid calibrated center axis info in {center_info}.')
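+        # The center offset is assumed to vary linearly with row between the two
+        # calibrated rows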
+        center_slope = (upper_center_offset-lower_center_offset)/(upper_row-lower_row)
+
+        # Get thetas (in degrees)
+        thetas = np.asarray(nxentry.reduced_data.rotation_angle)
+
+        # Reconstruct tomography data
+        #   reduced data axes order: stack,theta,row,column
+        #   reconstructed data order in each stack: row/z,x,y
+        #   Note: Nexus cannot follow a link if the data it points to is too big,
+        #         so get the data from the actual place, not from nxentry.data
+        if 'zoom_perc' in nxentry.reduced_data:
+            res_title = f'{nxentry.reduced_data.attrs["zoom_perc"]}p'
+        else:
+            res_title = 'fullres'
+        load_error = False
+        num_tomo_stacks = nxentry.reduced_data.data.tomo_fields.shape[0]
+        tomo_recon_stacks = num_tomo_stacks*[np.array([])]
+        for i in range(num_tomo_stacks):
+            # Convert reduced data stack from theta,row,column to row,theta,column
+            logger.debug(f'Reading reduced data stack {i+1}...')
+            t0 = time()
+            tomo_stack = np.asarray(nxentry.reduced_data.data.tomo_fields[i])
+            logger.debug(f'... done in {time()-t0:.2f} seconds')
+            if len(tomo_stack.shape) != 3 or not all(tomo_stack.shape):
+                raise ValueError(f'Unable to load tomography stack {i+1} for reconstruction')
+            tomo_stack = np.swapaxes(tomo_stack, 0, 1)
+            assert(len(thetas) == tomo_stack.shape[1])
+            assert(0 <= lower_row < upper_row < tomo_stack.shape[0])
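+            # Extrapolate the calibrated center offsets to the first and last row
+            # of the stack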
+            center_offsets = [lower_center_offset-lower_row*center_slope,
+                    upper_center_offset+(tomo_stack.shape[0]-1-upper_row)*center_slope]
+            t0 = time()
+            logger.debug(f'Running _reconstruct_one_tomo_stack on {self.num_core} cores ...')
+            tomo_recon_stack = self._reconstruct_one_tomo_stack(tomo_stack, thetas,
+                    center_offsets=center_offsets, num_core=self.num_core, algorithm='gridrec')
+            logger.debug(f'... done in {time()-t0:.2f} seconds')
+            logger.info(f'Reconstruction of stack {i+1} took {time()-t0:.2f} seconds')
+
+            # Combine stacks
+            tomo_recon_stacks[i] = tomo_recon_stack
+
+        # Resize the reconstructed tomography data
+        #   reconstructed data order in each stack: row/z,x,y
+        if self.test_mode:
+            x_bounds = self.test_config.get('x_bounds')
+            y_bounds = self.test_config.get('y_bounds')
+            z_bounds = None
+        elif self.galaxy_flag:
+            if x_bounds is not None and not is_int_pair(x_bounds, ge=0,
+                    lt=tomo_recon_stacks[0].shape[1]):
+                raise ValueError(f'Invalid parameter x_bounds ({x_bounds})')
+            if y_bounds is not None and not is_int_pair(y_bounds, ge=0,
+                    lt=tomo_recon_stacks[0].shape[2]):
+                raise ValueError(f'Invalid parameter y_bounds ({y_bounds})')
+            z_bounds = None
+        else:
+            x_bounds, y_bounds, z_bounds = self._resize_reconstructed_data(tomo_recon_stacks)
+        if x_bounds is None:
+            x_range = (0, tomo_recon_stacks[0].shape[1])
+            x_slice = int(x_range[1]/2)
+        else:
+            x_range = (min(x_bounds), max(x_bounds))
+            x_slice = int((x_bounds[0]+x_bounds[1])/2)
+        if y_bounds is None:
+            y_range = (0, tomo_recon_stacks[0].shape[2])
+            y_slice = int(y_range[1]/2)
+        else:
+            y_range = (min(y_bounds), max(y_bounds))
+            y_slice = int((y_bounds[0]+y_bounds[1])/2)
+        if z_bounds is None:
+            z_range = (0, tomo_recon_stacks[0].shape[0])
+            z_slice = int(z_range[1]/2)
+        else:
+            z_range = (min(z_bounds), max(z_bounds))
+            z_slice = int((z_bounds[0]+z_bounds[1])/2)
+
+        # Plot a few reconstructed image slices
+        for i, stack in enumerate(tomo_recon_stacks):
+            if num_tomo_stacks == 1:
+                basetitle = 'recon'
+            else:
+                basetitle = f'recon stack {i+1}'
+            title = f'{basetitle} {res_title} xslice{x_slice}'
+            quick_imshow(stack[z_range[0]:z_range[1],x_slice,y_range[0]:y_range[1]],
+                    title=title, path=path, save_fig=self.save_figs, save_only=self.save_only,
+                    block=self.block)
+            title = f'{basetitle} {res_title} yslice{y_slice}'
+            quick_imshow(stack[z_range[0]:z_range[1],x_range[0]:x_range[1],y_slice],
+                    title=title, path=path, save_fig=self.save_figs, save_only=self.save_only,
+                    block=self.block)
+            title = f'{basetitle} {res_title} zslice{z_slice}'
+            quick_imshow(stack[z_slice,x_range[0]:x_range[1],y_range[0]:y_range[1]],
+                    title=title, path=path, save_fig=self.save_figs, save_only=self.save_only,
+                    block=self.block)
+
+        # Save test data to file
+        #   reconstructed data order in each stack: row/z,x,y
+        if self.test_mode:
+            for i, stack in enumerate(tomo_recon_stacks):
+                np.savetxt(f'{self.output_folder}/recon_stack_{i+1}.txt',
+                        stack[z_slice,x_range[0]:x_range[1],y_range[0]:y_range[1]], fmt='%.6e')
+
+        # Add image reconstruction to reconstructed data NXprocess
+        #   reconstructed data order in each stack: row/z,x,y
+        nxprocess.data = NXdata()
+        nxprocess.attrs['default'] = 'data'
+        for k, v in center_info.items():
+            nxprocess[k] = v
+        if x_bounds is not None:
+            nxprocess.x_bounds = x_bounds
+        if y_bounds is not None:
+            nxprocess.y_bounds = y_bounds
+        if z_bounds is not None:
+            nxprocess.z_bounds = z_bounds
+        nxprocess.data['reconstructed_data'] = np.asarray([stack[z_range[0]:z_range[1],
+                x_range[0]:x_range[1],y_range[0]:y_range[1]] for stack in tomo_recon_stacks])
+        nxprocess.data.attrs['signal'] = 'reconstructed_data'
+
+        # Create a copy of the input Nexus object and remove reduced data
+        exclude_items = [f'{nxentry._name}/reduced_data/data', f'{nxentry._name}/data/reduced_data']
+        nxroot_copy = nxcopy(nxroot, exclude_nxpaths=exclude_items)
+
+        # Add the reconstructed data NXprocess to the new Nexus object
+        nxentry_copy = nxroot_copy[nxroot_copy.attrs['default']]
+        nxentry_copy.reconstructed_data = nxprocess
+        if 'data' not in nxentry_copy:
+            nxentry_copy.data = NXdata()
+        nxentry_copy.attrs['default'] = 'data'
+        nxentry_copy.data.makelink(nxprocess.data.reconstructed_data, name='reconstructed_data')
+        nxentry_copy.data.attrs['signal'] = 'reconstructed_data'
+
+        return(nxroot_copy)
+
+    def combine_data(self, nxroot, x_bounds=None, y_bounds=None):
+        """Combine the reconstructed tomography stacks.
+        """
+        logger.info('Combine the reconstructed tomography stacks')
+
+        if not isinstance(nxroot, NXroot):
+            raise ValueError(f'Invalid parameter nxroot ({nxroot})')
+        nxentry = nxroot[nxroot.attrs['default']]
+        if not isinstance(nxentry, NXentry):
+            raise ValueError(f'Invalid nxentry ({nxentry})')
+
+        # Create plot galaxy path directory and path if needed
+        if self.galaxy_flag:
+            if not os_path.exists('tomo_combine_plots'):
+                mkdir('tomo_combine_plots')
+            path = 'tomo_combine_plots'
+        else:
+            path = self.output_folder
+
+        # Check if reconstructed image data is available
+        if ('reconstructed_data' not in nxentry or 'reconstructed_data' not in nxentry.data):
+            raise KeyError(f'Unable to find valid reconstructed image data in {nxentry}.')
+
+        # Create an NXprocess to store combined image reconstruction (meta)data
+        nxprocess = NXprocess()
+
+        # Get the reconstructed data
+        #   reconstructed data order: stack,row(z),x,y
+        #   Note: Nexus cannot follow a link if the data it points to is too big,
+        #         so get the data from the actual place, not from nxentry.data
+        num_tomo_stacks = nxentry.reconstructed_data.data.reconstructed_data.shape[0]
+        if num_tomo_stacks == 1:
+            logger.info('Only one stack available: leaving combine_data')
+            return(None)
+
+        # Combine the reconstructed stacks
+        # (load one stack at a time to reduce risk of hitting Nexus data access limit)
+        t0 = time()
+        logger.debug(f'Combining the reconstructed stacks ...')
+        tomo_recon_combined = np.concatenate(
+                [np.asarray(nxentry.reconstructed_data.data.reconstructed_data[i])
+                for i in range(num_tomo_stacks)])
+        logger.debug(f'... done in {time()-t0:.2f} seconds')
+        logger.info(f'Combining the reconstructed stacks took {time()-t0:.2f} seconds')
+
+        # Resize the combined tomography data stacks
+        #   combined data order: row/z,x,y
+        if self.test_mode:
+            x_bounds = None
+            y_bounds = None
+            z_bounds = self.test_config.get('z_bounds')
+        elif self.galaxy_flag:
+            if x_bounds is not None and not is_int_pair(x_bounds, ge=0,
+                    lt=tomo_recon_combined.shape[1]):
+                raise ValueError(f'Invalid parameter x_bounds ({x_bounds})')
+            if y_bounds is not None and not is_int_pair(y_bounds, ge=0,
+                    lt=tomo_recon_combined.shape[2]):
+                raise ValueError(f'Invalid parameter y_bounds ({y_bounds})')
+            z_bounds = None
+        else:
+            x_bounds, y_bounds, z_bounds = self._resize_reconstructed_data(tomo_recon_combined,
+                    z_only=True)
+        if x_bounds is None:
+            x_range = (0, tomo_recon_combined.shape[1])
+            x_slice = int(x_range[1]/2)
+        else:
+            x_range = x_bounds
+            x_slice = int((x_bounds[0]+x_bounds[1])/2)
+        if y_bounds is None:
+            y_range = (0, tomo_recon_combined.shape[2])
+            y_slice = int(y_range[1]/2)
+        else:
+            y_range = y_bounds
+            y_slice = int((y_bounds[0]+y_bounds[1])/2)
+        if z_bounds is None:
+            z_range = (0, tomo_recon_combined.shape[0])
+            z_slice = int(z_range[1]/2)
+        else:
+            z_range = z_bounds
+            z_slice = int((z_bounds[0]+z_bounds[1])/2)
+
+        # Plot a few combined image slices
+        quick_imshow(tomo_recon_combined[z_range[0]:z_range[1],x_slice,y_range[0]:y_range[1]],
+                title=f'recon combined xslice{x_slice}', path=path,
+                save_fig=self.save_figs, save_only=self.save_only, block=self.block)
+        quick_imshow(tomo_recon_combined[z_range[0]:z_range[1],x_range[0]:x_range[1],y_slice],
+                title=f'recon combined yslice{y_slice}', path=path,
+                save_fig=self.save_figs, save_only=self.save_only, block=self.block)
+        quick_imshow(tomo_recon_combined[z_slice,x_range[0]:x_range[1],y_range[0]:y_range[1]],
+                title=f'recon combined zslice{z_slice}', path=path,
+                save_fig=self.save_figs, save_only=self.save_only, block=self.block)
+
+        # Save test data to file
+        #   combined data order: row/z,x,y
+        if self.test_mode:
+            np.savetxt(f'{self.output_folder}/recon_combined.txt', tomo_recon_combined[
+                    z_slice,x_range[0]:x_range[1],y_range[0]:y_range[1]], fmt='%.6e')
+
+        # Add image reconstruction to reconstructed data NXprocess
+        #   combined data order: row/z,x,y
+        nxprocess.data = NXdata()
+        nxprocess.attrs['default'] = 'data'
+        if x_bounds is not None:
+            nxprocess.x_bounds = x_bounds
+        if y_bounds is not None:
+            nxprocess.y_bounds = y_bounds
+        if z_bounds is not None:
+            nxprocess.z_bounds = z_bounds
+        nxprocess.data['combined_data'] = tomo_recon_combined
+        nxprocess.data.attrs['signal'] = 'combined_data'
+
+        # Create a copy of the input Nexus object and remove reconstructed data
+        exclude_items = [f'{nxentry._name}/reconstructed_data/data',
+                f'{nxentry._name}/data/reconstructed_data']
+        nxroot_copy = nxcopy(nxroot, exclude_nxpaths=exclude_items)
+
+        # Add the combined data NXprocess to the new Nexus object
+        nxentry_copy = nxroot_copy[nxroot_copy.attrs['default']]
+        nxentry_copy.combined_data = nxprocess
+        if 'data' not in nxentry_copy:
+            nxentry_copy.data = NXdata()
+        nxentry_copy.attrs['default'] = 'data'
+        nxentry_copy.data.makelink(nxprocess.data.combined_data, name='combined_data')
+        nxentry_copy.data.attrs['signal'] = 'combined_data'
+
+        return(nxroot_copy)
+
+    def _gen_dark(self, nxentry, reduced_data):
+        """Generate dark field.
+        """
+        # Get the dark field images
+        image_key = nxentry.instrument.detector.get('image_key', None)
+        if image_key is not None and 'data' in nxentry.instrument.detector:
+            field_indices = [index for index, key in enumerate(image_key) if key == 2]
+            tdf_stack = nxentry.instrument.detector.data[field_indices,:,:]
+            # RV the default NXtomo form does not accommodate bright or dark field stacks
+        else:
+            dark_field_scans = nxentry.spec_scans.dark_field
+            dark_field = FlatField.construct_from_nxcollection(dark_field_scans)
+            prefix = str(nxentry.instrument.detector.local_name)
+            tdf_stack = dark_field.get_detector_data(prefix)
+            if isinstance(tdf_stack, list):
+                assert(len(tdf_stack) == 1) # TODO
+                tdf_stack = tdf_stack[0]
+
+        # Take median
+        if tdf_stack.ndim == 2:
+            tdf = tdf_stack
+        elif tdf_stack.ndim == 3:
+            tdf = np.median(tdf_stack, axis=0)
+            del tdf_stack
+        else:
+            raise ValueError(f'Invalid tdf_stack shape ({tdf_stack.shape})')
+
+        # Remove dark field intensities above the cutoff
+#RV        tdf_cutoff = None
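+        # Heuristic cutoff: twice the median-to-minimum spread above the minimum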
+        tdf_cutoff = tdf.min()+2*(np.median(tdf)-tdf.min())
+        logger.debug(f'tdf_cutoff = {tdf_cutoff}')
+        if tdf_cutoff is not None:
+            if not is_num(tdf_cutoff, ge=0):
+                logger.warning(f'Ignoring illegal value of tdf_cutoff {tdf_cutoff}')
+            else:
+                tdf[tdf > tdf_cutoff] = np.nan
+                logger.debug(f'tdf_cutoff = {tdf_cutoff}')
+
+        # Remove nans
+        tdf_mean = np.nanmean(tdf)
+        logger.debug(f'tdf_mean = {tdf_mean}')
+        np.nan_to_num(tdf, copy=False, nan=tdf_mean, posinf=tdf_mean, neginf=0.)
+
+        # Plot dark field
+        if self.galaxy_flag:
+            quick_imshow(tdf, title='dark field', path='tomo_reduce_plots', save_fig=self.save_figs,
+                    save_only=self.save_only)
+        elif not self.test_mode:
+            quick_imshow(tdf, title='dark field', path=self.output_folder, save_fig=self.save_figs,
+                    save_only=self.save_only)
+            clear_imshow('dark field')
+#        quick_imshow(tdf, title='dark field', block=True)
+
+        # Add dark field to reduced data NXprocess
+        reduced_data.data = NXdata()
+        reduced_data.data['dark_field'] = tdf
+
+        return(reduced_data)
+
+    def _gen_bright(self, nxentry, reduced_data):
+        """Generate bright field.
+        """
+        # Get the bright field images
+        image_key = nxentry.instrument.detector.get('image_key', None)
+        if image_key is not None and 'data' in nxentry.instrument.detector:
+            field_indices = [index for index, key in enumerate(image_key) if key == 1]
+            tbf_stack = nxentry.instrument.detector.data[field_indices,:,:]
+            # RV the default NXtomo form does not accommodate bright or dark field stacks
+        else:
+            bright_field_scans = nxentry.spec_scans.bright_field
+            bright_field = FlatField.construct_from_nxcollection(bright_field_scans)
+            prefix = str(nxentry.instrument.detector.local_name)
+            tbf_stack = bright_field.get_detector_data(prefix)
+            if isinstance(tbf_stack, list):
+                assert(len(tbf_stack) == 1) # TODO
+                tbf_stack = tbf_stack[0]
+
+        # Take median if more than one image
+        # Median or mean: the median may be best because of image artifacts that arise
+        # from crinkles in the upstream kapton tape windows, which cause some phase
+        # contrast features to appear on the detector. A future implementation could
+        # also apply a bright field adjustment to each tomography frame, based on an
+        # ROI in a corner of the frame where the direct X-ray beam hits but no sample
+        # is present, to correct for frame-to-frame fluctuations in the incoming beam.
+        # These fluctuations are not typically accounted for, but could be.
+        if tbf_stack.ndim == 2:
+            tbf = tbf_stack
+        elif tbf_stack.ndim == 3:
+            tbf = np.median(tbf_stack, axis=0)
+            del tbf_stack
+        else:
+            raise ValueError(f'Invalid tbf_stack shape ({tbf_stack.shape})')
+
+        # Subtract dark field
+        if 'data' in reduced_data and 'dark_field' in reduced_data.data:
+            tbf -= reduced_data.data.dark_field
+        else:
+            logger.warning('Dark field unavailable')
+
+        # Set any non-positive values to one
+        # (avoid negative bright field values for spikes in dark field)
+        tbf[tbf < 1] = 1
+
+        # Plot bright field
+        if self.galaxy_flag:
+            quick_imshow(tbf, title='bright field', path='tomo_reduce_plots',
+                    save_fig=self.save_figs, save_only=self.save_only)
+        elif not self.test_mode:
+            quick_imshow(tbf, title='bright field', path=self.output_folder,
+                    save_fig=self.save_figs, save_only=self.save_only)
+            clear_imshow('bright field')
+#        quick_imshow(tbf, title='bright field', block=True)
+
+        # Add bright field to reduced data NXprocess
+        if 'data' not in reduced_data: 
+            reduced_data.data = NXdata()
+        reduced_data.data['bright_field'] = tbf
+
+        return(reduced_data)
+
+    def _set_detector_bounds(self, nxentry, reduced_data, img_x_bounds=None):
+        """Set vertical detector bounds for each image stack.
+        Right now the range is the same for each set in the image stack.
+        """
+        if self.test_mode:
+            return(tuple(self.test_config['img_x_bounds']))
+
+        # Get the first tomography image and the reference heights
+        image_key = nxentry.instrument.detector.get('image_key', None)
+        if image_key is not None and 'data' in nxentry.instrument.detector:
+            field_indices = [index for index, key in enumerate(image_key) if key == 0]
+            first_image = np.asarray(nxentry.instrument.detector.data[field_indices[0],:,:])
+            theta = float(nxentry.sample.rotation_angle[field_indices[0]])
+            z_translation_all = nxentry.sample.z_translation[field_indices]
+            vertical_shifts = sorted(list(set(z_translation_all)))
+            num_tomo_stacks = len(vertical_shifts)
+        else:
+            tomo_field_scans = nxentry.spec_scans.tomo_fields
+            tomo_fields = TomoField.construct_from_nxcollection(tomo_field_scans)
+            vertical_shifts = tomo_fields.get_vertical_shifts()
+            if not isinstance(vertical_shifts, list):
+               vertical_shifts = [vertical_shifts]
+            prefix = str(nxentry.instrument.detector.local_name)
+            t0 = time()
+            first_image = tomo_fields.get_detector_data(prefix, tomo_fields.scan_numbers[0], 0)
+            logger.debug(f'Getting first image took {time()-t0:.2f} seconds')
+            num_tomo_stacks = len(tomo_fields.scan_numbers)
+            theta = tomo_fields.theta_range['start']
+
+        # Select image bounds
+        title = f'tomography image at theta={round(theta, 2)+0}'
+        if img_x_bounds is not None:
+            if is_int_pair(img_x_bounds) and img_x_bounds[0] == -1 and img_x_bounds[1] == -1:
+                img_x_bounds = None
+            elif not is_index_range(img_x_bounds, ge=0, le=first_image.shape[0]):
+                raise ValueError(f'Invalid parameter img_x_bounds ({img_x_bounds})')
+        if nxentry.instrument.source.attrs['station'] in ('id1a3', 'id3a'):
+            pixel_size = nxentry.instrument.detector.x_pixel_size
+            # Try to get a fit from the bright field
+            tbf = np.asarray(reduced_data.data.bright_field)
+            tbf_shape = tbf.shape
+            x_sum = np.sum(tbf, 1)
+            x_sum_min = x_sum.min()
+            x_sum_max = x_sum.max()
+            fit = Fit.fit_data(x_sum, 'rectangle', x=np.array(range(len(x_sum))), form='atan',
+                    guess=True)
+            parameters = fit.best_values
+            x_low_fit = parameters.get('center1', None)
+            x_upp_fit = parameters.get('center2', None)
+            sig_low = parameters.get('sigma1', None)
+            sig_upp = parameters.get('sigma2', None)
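+            # Accept the fit only if both edges were found inside the image, are
+            # ordered, and the edge widths are small (<10% of the fitted window)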
+            have_fit = fit.success and x_low_fit is not None and x_upp_fit is not None and \
+                    sig_low is not None and sig_upp is not None and \
+                    0 <= x_low_fit < x_upp_fit <= x_sum.size and \
+                    (sig_low+sig_upp)/(x_upp_fit-x_low_fit) < 0.1
+            if have_fit:
+                # Set a 5% margin on each side
+                margin = 0.05*(x_upp_fit-x_low_fit)
+                x_low_fit = max(0, x_low_fit-margin)
+                x_upp_fit = min(tbf_shape[0], x_upp_fit+margin)
+            if num_tomo_stacks == 1:
+                if have_fit:
+                    # Set the default range to enclose the full fitted window
+                    x_low = int(x_low_fit)
+                    x_upp = int(x_upp_fit)
+                else:
+                    # Center a default range of 1 mm (RV: can we get this from the slits?)
+                    num_x_min = int((1.0-0.5*pixel_size)/pixel_size)
+                    x_low = int(0.5*(tbf_shape[0]-num_x_min))
+                    x_upp = x_low+num_x_min
+            else:
+                # Get the default range from the reference heights
+                delta_z = vertical_shifts[1]-vertical_shifts[0]
+                for i in range(2, num_tomo_stacks):
+                    delta_z = min(delta_z, vertical_shifts[i]-vertical_shifts[i-1])
+                logger.debug(f'delta_z = {delta_z}')
+                num_x_min = int((delta_z-0.5*pixel_size)/pixel_size)
+                logger.debug(f'num_x_min = {num_x_min}')
+                if num_x_min > tbf_shape[0]:
+                    logger.warning('Image bounds and pixel size prevent seamless stacking')
+                if have_fit:
+                    # Center the default range relative to the fitted window
+                    x_low = int(0.5*(x_low_fit+x_upp_fit-num_x_min))
+                    x_upp = x_low+num_x_min
+                else:
+                    # Center the default range
+                    x_low = int(0.5*(tbf_shape[0]-num_x_min))
+                    x_upp = x_low+num_x_min
+            if self.galaxy_flag:
+                img_x_bounds = (x_low, x_upp)
+            else:
+                tmp = np.copy(tbf)
+                tmp_max = tmp.max()
+                tmp[x_low,:] = tmp_max
+                tmp[x_upp-1,:] = tmp_max
+                quick_imshow(tmp, title='bright field')
+                tmp = np.copy(first_image)
+                tmp_max = tmp.max()
+                tmp[x_low,:] = tmp_max
+                tmp[x_upp-1,:] = tmp_max
+                quick_imshow(tmp, title=title)
+                del tmp
+                quick_plot((range(x_sum.size), x_sum),
+                        ([x_low, x_low], [x_sum_min, x_sum_max], 'r-'),
+                        ([x_upp, x_upp], [x_sum_min, x_sum_max], 'r-'),
+                        title='sum over theta and y')
+                print(f'lower bound = {x_low} (inclusive)')
+                print(f'upper bound = {x_upp} (exclusive)')
+                accept = input_yesno('Accept these bounds (y/n)?', 'y')
+                clear_imshow('bright field')
+                clear_imshow(title)
+                clear_plot('sum over theta and y')
+                if accept:
+                    img_x_bounds = (x_low, x_upp)
+                else:
+                    while True:
+                        mask, img_x_bounds = draw_mask_1d(x_sum, title='select x data range',
+                                legend='sum over theta and y')
+                        if len(img_x_bounds) == 1:
+                            break
+                        else:
+                            print(f'Choose a single connected data range')
+                    img_x_bounds = tuple(img_x_bounds[0])
+            if (num_tomo_stacks > 1 and img_x_bounds[1]-img_x_bounds[0]+1 < 
+                    int((delta_z-0.5*pixel_size)/pixel_size)):
+                logger.warning('Image bounds and pixel size prevent seamless stacking')
+        else:
+            if num_tomo_stacks > 1:
+                raise NotImplementedError('Selecting image bounds for multiple stacks on FMB')
+            # For FMB: use the first tomography image to select range
+            # RV: revisit if they do tomography with multiple stacks
+            x_sum = np.sum(first_image, 1)
+            x_sum_min = x_sum.min()
+            x_sum_max = x_sum.max()
+            if self.galaxy_flag:
+                if img_x_bounds is None:
+                    img_x_bounds = (0, first_image.shape[0])
+            else:
+                quick_imshow(first_image, title=title)
+                print('Select vertical data reduction range from first tomography image')
+                img_x_bounds = select_image_bounds(first_image, 0, title=title)
+                clear_imshow(title)
+                if img_x_bounds is None:
+                    raise ValueError('Unable to select image bounds')
+
+        # Plot results
+        if self.galaxy_flag:
+            path = 'tomo_reduce_plots'
+        else:
+            path = self.output_folder
+        x_low = img_x_bounds[0]
+        x_upp = img_x_bounds[1]
+        tmp = np.copy(first_image)
+        tmp_max = tmp.max()
+        tmp[x_low,:] = tmp_max
+        tmp[x_upp-1,:] = tmp_max
+        quick_imshow(tmp, title=title, path=path, save_fig=self.save_figs, save_only=self.save_only,
+                block=self.block)
+        del tmp
+        quick_plot((range(x_sum.size), x_sum),
+                ([x_low, x_low], [x_sum_min, x_sum_max], 'r-'),
+                ([x_upp, x_upp], [x_sum_min, x_sum_max], 'r-'),
+                title='sum over theta and y', path=path, save_fig=self.save_figs,
+                save_only=self.save_only, block=self.block)
+
+        return(img_x_bounds)
+
+    def _set_zoom_or_skip(self):
+        """Set zoom and/or theta skip to reduce memory the requirement for the analysis.
+        """
+#        if input_yesno('\nDo you want to zoom in to reduce memory requirement (y/n)?', 'n'):
+#            zoom_perc = input_int('    Enter zoom percentage', ge=1, le=100)
+#        else:
+#            zoom_perc = None
+        zoom_perc = None
+#        if input_yesno('Do you want to skip thetas to reduce memory requirement (y/n)?', 'n'):
+#            num_theta_skip = input_int('    Enter the number skip theta interval', ge=0,
+#                    lt=num_theta)
+#        else:
+#            num_theta_skip = None
+        num_theta_skip = None
+        logger.debug(f'zoom_perc = {zoom_perc}')
+        logger.debug(f'num_theta_skip = {num_theta_skip}')
+
+        return(zoom_perc, num_theta_skip)
+
+    def _gen_tomo(self, nxentry, reduced_data):
+        """Generate tomography fields.
+        """
+        # Get full bright field
+        tbf = np.asarray(reduced_data.data.bright_field)
+        tbf_shape = tbf.shape
+
+        # Get image bounds
+        img_x_bounds = tuple(reduced_data.get('img_x_bounds', (0, tbf_shape[0])))
+        img_y_bounds = tuple(reduced_data.get('img_y_bounds', (0, tbf_shape[1])))
+
+        # Get resized dark field
+#        if 'dark_field' in data:
+#            tbf = np.asarray(reduced_data.data.dark_field[
+#                    img_x_bounds[0]:img_x_bounds[1],img_y_bounds[0]:img_y_bounds[1]])
+#        else:
+#            logger.warning('Dark field unavailable')
+#            tdf = None
+        tdf = None
+
+        # Resize bright field
+        if img_x_bounds != (0, tbf.shape[0]) or img_y_bounds != (0, tbf.shape[1]):
+            tbf = tbf[img_x_bounds[0]:img_x_bounds[1],img_y_bounds[0]:img_y_bounds[1]]
+
+        # Get the tomography images
+        image_key = nxentry.instrument.detector.get('image_key', None)
+        if image_key is not None and 'data' in nxentry.instrument.detector:
+            field_indices_all = [index for index, key in enumerate(image_key) if key == 0]
+            z_translation_all = nxentry.sample.z_translation[field_indices_all]
+            z_translation_levels = sorted(list(set(z_translation_all)))
+            num_tomo_stacks = len(z_translation_levels)
+            horizontal_shifts = []
+            vertical_shifts = []
+            thetas = None
+            tomo_stacks = []
+            for i, z_translation in enumerate(z_translation_levels):
+                field_indices = [field_indices_all[index]
+                        for index, z in enumerate(z_translation_all) if z == z_translation]
+                horizontal_shift = list(set(nxentry.sample.x_translation[field_indices]))
+                assert(len(horizontal_shift) == 1)
+                horizontal_shifts += horizontal_shift
+                vertical_shift = list(set(nxentry.sample.z_translation[field_indices]))
+                assert(len(vertical_shift) == 1)
+                vertical_shifts += vertical_shift
+                sequence_numbers = nxentry.instrument.detector.sequence_number[field_indices]
+                if thetas is None:
+                    thetas = np.asarray(nxentry.sample.rotation_angle[field_indices]) \
+                             [sequence_numbers]
+                else:
+                    assert(all(thetas[i] == nxentry.sample.rotation_angle[field_indices[index]]
+                            for i, index in enumerate(sequence_numbers)))
+                assert(sorted(set(sequence_numbers)) == list(range(len(sequence_numbers))))
+                if list(sequence_numbers) == list(range(len(sequence_numbers))):
+                    tomo_stack = np.asarray(nxentry.instrument.detector.data[field_indices])
+                else:
+                    raise ValueError('Unable to load the tomography images')
+                tomo_stacks.append(tomo_stack)
+        else:
+            tomo_field_scans = nxentry.spec_scans.tomo_fields
+            tomo_fields = TomoField.construct_from_nxcollection(tomo_field_scans)
+            horizontal_shifts = tomo_fields.get_horizontal_shifts()
+            vertical_shifts = tomo_fields.get_vertical_shifts()
+            prefix = str(nxentry.instrument.detector.local_name)
+            t0 = time()
+            tomo_stacks = tomo_fields.get_detector_data(prefix)
+            logger.debug(f'Getting all tomography images took {time()-t0:.2f} seconds')
+            thetas = np.linspace(tomo_fields.theta_range['start'], tomo_fields.theta_range['end'],
+                    tomo_fields.theta_range['num'])
+            if not isinstance(tomo_stacks, list):
+                horizontal_shifts = [horizontal_shifts]
+                vertical_shifts = [vertical_shifts]
+                tomo_stacks = [tomo_stacks]
+
+        reduced_tomo_stacks = []
+        if self.galaxy_flag:
+            path = 'tomo_reduce_plots'
+        else:
+            path = self.output_folder
+        for i, tomo_stack in enumerate(tomo_stacks):
+            # Resize the tomography images
+            # Right now the range is the same for each set in the image stack.
+            if img_x_bounds != (0, tbf.shape[0]) or img_y_bounds != (0, tbf.shape[1]):
+                t0 = time()
+                tomo_stack = tomo_stack[:,img_x_bounds[0]:img_x_bounds[1],
+                        img_y_bounds[0]:img_y_bounds[1]].astype('float64')
+                logger.debug(f'Resizing tomography images took {time()-t0:.2f} seconds')
+
+            # Subtract dark field
+            if tdf is not None:
+                t0 = time()
+                with set_numexpr_threads(self.num_core):
+                    ne.evaluate('tomo_stack-tdf', out=tomo_stack)
+                logger.debug(f'Subtracting dark field took {time()-t0:.2f} seconds')
+
+            # Normalize
+            t0 = time()
+            with set_numexpr_threads(self.num_core):
+                ne.evaluate('tomo_stack/tbf', out=tomo_stack, truediv=True)
+            logger.debug(f'Normalizing took {time()-t0:.2f} seconds')
+
+            # Remove non-positive values and linearize data
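+            # (Beer-Lambert: attenuation = -log(I/I0); clip to a small positive
+            #  cutoff first so the log is defined everywhere)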
+            t0 = time()
+            cutoff = 1.e-6
+            with set_numexpr_threads(self.num_core):
+                ne.evaluate('where(tomo_stack<cutoff, cutoff, tomo_stack)', out=tomo_stack)
+            with set_numexpr_threads(self.num_core):
+                ne.evaluate('-log(tomo_stack)', out=tomo_stack)
+            logger.debug('Removing non-positive values and linearizing data took '+
+                    f'{time()-t0:.2f} seconds')
+
+            # Get rid of nans/infs that may be introduced by normalization
+            t0 = time()
+            tomo_stack = np.where(np.isfinite(tomo_stack), tomo_stack, 0.)
+            logger.debug(f'Remove nans/infs took {time()-t0:.2f} seconds')
+
+            # Downsize tomography stack to smaller size
+            # TODO use theta_skip as well
+            tomo_stack = tomo_stack.astype('float32')
+            if not self.test_mode:
+                if len(tomo_stacks) == 1:
+                    title = f'red fullres theta {round(thetas[0], 2)+0}'
+                else:
+                    title = f'red stack {i+1} fullres theta {round(thetas[0], 2)+0}'
+                quick_imshow(tomo_stack[0,:,:], title=title, path=path, save_fig=self.save_figs,
+                        save_only=self.save_only, block=self.block)
+#                if not self.block:
+#                    clear_imshow(title)
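+            # Zooming is currently disabled: the 'False and' guard below keeps this
+            # branch dead (note that zoom_perc is not defined in this method)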
+            if False and zoom_perc != 100:
+                t0 = time()
+                logger.debug(f'Zooming in ...')
+                tomo_zoom_list = []
+                for j in range(tomo_stack.shape[0]):
+                    tomo_zoom = spi.zoom(tomo_stack[j,:,:], 0.01*zoom_perc)
+                    tomo_zoom_list.append(tomo_zoom)
+                tomo_stack = np.stack([tomo_zoom for tomo_zoom in tomo_zoom_list])
+                logger.debug(f'... done in {time()-t0:.2f} seconds')
+                logger.info(f'Zooming in took {time()-t0:.2f} seconds')
+                del tomo_zoom_list
+                if not self.test_mode:
+                    title = f'red stack {zoom_perc}p theta {round(thetas[0], 2)+0}'
+                    quick_imshow(tomo_stack[0,:,:], title=title, path=path, save_fig=self.save_figs,
+                            save_only=self.save_only, block=self.block)
+#                    if not self.block:
+#                        clear_imshow(title)
+
+            # Save test data to file
+            if self.test_mode:
+#                row_index = int(tomo_stack.shape[0]/2)
+#                np.savetxt(f'{self.output_folder}/red_stack_{i+1}.txt', tomo_stack[row_index,:,:],
+#                        fmt='%.6e')
+                row_index = int(tomo_stack.shape[1]/2)
+                np.savetxt(f'{self.output_folder}/red_stack_{i+1}.txt', tomo_stack[:,row_index,:],
+                        fmt='%.6e')
+
+            # Combine resized stacks
+            reduced_tomo_stacks.append(tomo_stack)
+
+        # Add tomo field info to reduced data NXprocess
+        reduced_data['rotation_angle'] = thetas
+        reduced_data['x_translation'] = np.asarray(horizontal_shifts)
+        reduced_data['z_translation'] = np.asarray(vertical_shifts)
+        reduced_data.data['tomo_fields'] = np.asarray(reduced_tomo_stacks)
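+        # Assuming equally-sized stacks, reduced_data.data['tomo_fields'] now has
+        # shape (num_stacks, num_thetas, num_rows, num_columns)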
+
+        if tdf is not None:
+            del tdf
+        del tbf
+
+        return reduced_data
+
+    def _find_center_one_plane(self, sinogram, row, thetas, eff_pixel_size, cross_sectional_dim,
+            path=None, tol=0.1, num_core=1):
+        """Find center for a single tomography plane.
+        """
+        # Try automatic center finding routines for initial value
+        # sinogram index order: theta,column
+        # need column,theta for iradon, so take transpose
+        sinogram = np.asarray(sinogram)
+        sinogram_T = sinogram.T
+        center = sinogram.shape[1]/2
+
+        # Try using Nghia Vo’s method
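+        # tomopy.find_center_vo estimates the rotation axis directly from the
+        # sinogram (see Vo et al., Optics Express 22, 2014)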
+        t0 = time()
+        if num_core > num_core_tomopy_limit:
+            logger.debug(f'Running find_center_vo on {num_core_tomopy_limit} cores ...')
+            tomo_center = tomopy.find_center_vo(sinogram, ncore=num_core_tomopy_limit)
+        else:
+            logger.debug(f'Running find_center_vo on {num_core} cores ...')
+            tomo_center = tomopy.find_center_vo(sinogram, ncore=num_core)
+        logger.debug(f'... done in {time()-t0:.2f} seconds')
+        logger.info(f'Finding the center using Nghia Vo’s method took {time()-t0:.2f} seconds')
+        center_offset_vo = tomo_center-center
+        logger.info(f'Center at row {row} using Nghia Vo’s method = {center_offset_vo:.2f}')
+        t0 = time()
+        logger.debug(f'Running _reconstruct_one_plane on {self.num_core} cores ...')
+        recon_plane = self._reconstruct_one_plane(sinogram_T, tomo_center, thetas,
+                eff_pixel_size, cross_sectional_dim, False, num_core)
+        logger.debug(f'... done in {time()-t0:.2f} seconds')
+        logger.info(f'Reconstructing row {row} took {time()-t0:.2f} seconds')
+
+        title = f'edges row{row} center offset{center_offset_vo:.2f} Vo'
+        self._plot_edges_one_plane(recon_plane, title, path=path)
+
+        # Try using phase correlation method
+#        if input_yesno('Try finding center using phase correlation (y/n)?', 'n'):
+#            t0 = time()
+#            logger.debug(f'Running find_center_pc ...')
+#            tomo_center = tomopy.find_center_pc(sinogram, sinogram, tol=0.1, rotc_guess=tomo_center)
+#            error = 1.
+#            while error > tol:
+#                prev = tomo_center
+#                tomo_center = tomopy.find_center_pc(sinogram, sinogram, tol=tol,
+#                        rotc_guess=tomo_center)
+#                error = np.abs(tomo_center-prev)
+#            logger.debug(f'... done in {time()-t0:.2f} seconds')
+#            logger.info('Finding the center using the phase correlation method took '+
+#                    f'{time()-t0:.2f} seconds')
+#            center_offset = tomo_center-center
+#            print(f'Center at row {row} using phase correlation = {center_offset:.2f}')
+#            t0 = time()
+#            logger.debug(f'Running _reconstruct_one_plane on {self.num_core} cores ...')
+#            recon_plane = self._reconstruct_one_plane(sinogram_T, tomo_center, thetas,
+#                    eff_pixel_size, cross_sectional_dim, False, num_core)
+#            logger.debug(f'... done in {time()-t0:.2f} seconds')
+#            logger.info(f'Reconstructing row {row} took {time()-t0:.2f} seconds')
+#
+#            title = f'edges row{row} center_offset{center_offset:.2f} PC'
+#            self._plot_edges_one_plane(recon_plane, title, path=path)
+
+        # Select center location
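+        # The interactive prompts are disabled: Nghia Vo's center is accepted as-is,
+        # which makes the manual center-offset search further below unreachable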
+#        if input_yesno('Accept a center location (y) or continue search (n)?', 'y'):
+        if True:
+#            center_offset = input_num('    Enter chosen center offset', ge=-center, le=center,
+#                    default=center_offset_vo)
+            center_offset = center_offset_vo
+            del sinogram_T
+            del recon_plane
+            return float(center_offset)
+
+        # Perform center-finding search
+        while True:
+            center_offset_low = input_int('\nEnter lower bound for center offset', ge=-center,
+                    le=center)
+            center_offset_upp = input_int('Enter upper bound for center offset',
+                    ge=center_offset_low, le=center)
+            if center_offset_upp == center_offset_low:
+                center_offset_step = 1
+            else:
+                center_offset_step = input_int('Enter step size for center offset search', ge=1,
+                        le=center_offset_upp-center_offset_low)
+            num_center_offset = 1+int((center_offset_upp-center_offset_low)/center_offset_step)
+            center_offsets = np.linspace(center_offset_low, center_offset_upp, num_center_offset)
+            for center_offset in center_offsets:
+                if center_offset == center_offset_vo:
+                    continue
+                t0 = time()
+                logger.debug(f'Running _reconstruct_one_plane on {num_core} cores ...')
+                recon_plane = self._reconstruct_one_plane(sinogram_T, center_offset+center, thetas,
+                        eff_pixel_size, cross_sectional_dim, False, num_core)
+                logger.debug(f'... done in {time()-t0:.2f} seconds')
+                logger.info(f'Reconstructing center_offset {center_offset} took '+
+                        f'{time()-t0:.2f} seconds')
+                title = f'edges row{row} center_offset{center_offset:.2f}'
+                self._plot_edges_one_plane(recon_plane, title, path=path)
+            if input_int('\nContinue (0) or end the search (1)', ge=0, le=1):
+                break
+
+        del sinogram_T
+        del recon_plane
+        center_offset = input_num('    Enter chosen center offset', ge=-center, le=center)
+        return float(center_offset)
+
+    def _reconstruct_one_plane(self, tomo_plane_T, center, thetas, eff_pixel_size,
+            cross_sectional_dim, plot_sinogram=True, num_core=1):
+        """Invert the sinogram for a single tomography plane.
+        """
+        # tomo_plane_T index order: column,theta
+        assert(0 <= center < tomo_plane_T.shape[0])
+        center_offset = center-tomo_plane_T.shape[0]/2
+        two_offset = 2*int(np.round(center_offset))
+        two_offset_abs = np.abs(two_offset)
+        max_rad = int(0.55*(cross_sectional_dim/eff_pixel_size)) # 10% slack to avoid edge effects
+        if max_rad > 0.5*tomo_plane_T.shape[0]:
+            max_rad = 0.5*tomo_plane_T.shape[0]
+        dist_from_edge = max(1, int(np.floor((tomo_plane_T.shape[0]-two_offset_abs)/2.)-max_rad))
+        if two_offset >= 0:
+            logger.debug(f'sinogram range = [{two_offset+dist_from_edge}, {-dist_from_edge}]')
+            sinogram = tomo_plane_T[two_offset+dist_from_edge:-dist_from_edge,:]
+        else:
+            logger.debug(f'sinogram range = [{dist_from_edge}, {two_offset-dist_from_edge}]')
+            sinogram = tomo_plane_T[dist_from_edge:two_offset-dist_from_edge,:]
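+        # Worked example (hypothetical numbers): for tomo_plane_T.shape[0] = 100 and
+        # center = 55, center_offset = 5 and two_offset = 10; with max_rad = 40,
+        # dist_from_edge = max(1, 45-40) = 5, so the slice [15:-5] keeps 80 columns
+        # centered symmetrically about the rotation axis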
+        if not self.galaxy_flag and plot_sinogram:
+            quick_imshow(sinogram.T, f'sinogram center offset{center_offset:.2f}', aspect='auto',
+                    path=self.output_folder, save_fig=self.save_figs, save_only=self.save_only,
+                    block=self.block)
+
+        # Inverting sinogram
+        t0 = time()
+        recon_sinogram = iradon(sinogram, theta=thetas, circle=True)
+        logger.debug(f'Inverting sinogram took {time()-t0:.2f} seconds')
+        del sinogram
+
+        # Performing Gaussian filtering and removing ring artifacts
+        recon_parameters = None # self.config.get('recon_parameters')
+        if recon_parameters is None:
+            sigma = 1.0
+            ring_width = 15
+        else:
+            sigma = recon_parameters.get('gaussian_sigma', 1.0)
+            if not is_num(sigma, ge=0.0):
+                logger.warning(f'Invalid gaussian_sigma ({sigma}) in _reconstruct_one_plane, '+
+                        'set to a default value of 1.0')
+                sigma = 1.0
+            ring_width = recon_parameters.get('ring_width', 15)
+            if not is_int(ring_width, ge=0):
+                logger.warning(f'Invalid ring_width ({ring_width}) in _reconstruct_one_plane, '+
+                        'set to a default value of 15')
+                ring_width = 15
+        t0 = time()
+        recon_sinogram = spi.gaussian_filter(recon_sinogram, sigma, mode='nearest')
+        recon_clean = np.expand_dims(recon_sinogram, axis=0)
+        del recon_sinogram
+        recon_clean = tomopy.misc.corr.remove_ring(recon_clean, rwidth=ring_width, ncore=num_core)
+        logger.debug(f'Filtering and removing ring artifacts took {time()-t0:.2f} seconds')
+
+        return recon_clean
+
+    def _plot_edges_one_plane(self, recon_plane, title, path=None):
+        vis_parameters = None # self.config.get('vis_parameters')
+        if vis_parameters is None:
+            weight = 0.1
+        else:
+            weight = vis_parameters.get('denoise_weight', 0.1)
+            if not is_num(weight, ge=0.0):
+                logger.warning(f'Invalid weight ({weight}) in _plot_edges_one_plane, '+
+                        'set to a default value of 0.1')
+                weight = 0.1
+        edges = denoise_tv_chambolle(recon_plane, weight=weight)
+        vmax = np.max(edges[0,:,:])
+        vmin = -vmax
+        if path is None:
+            path = self.output_folder
+        quick_imshow(edges[0,:,:], f'{title} coolwarm', path=path, cmap='coolwarm',
+                save_fig=self.save_figs, save_only=self.save_only, block=self.block)
+        quick_imshow(edges[0,:,:], f'{title} gray', path=path, cmap='gray', vmin=vmin, vmax=vmax,
+                save_fig=self.save_figs, save_only=self.save_only, block=self.block)
+        del edges
+
+    def _reconstruct_one_tomo_stack(self, tomo_stack, thetas, center_offsets=[], num_core=1,
+            algorithm='gridrec'):
+        """Reconstruct a single tomography stack.
+        """
+        # tomo_stack order: row,theta,column
+        # input thetas must be in degrees
+        # center_offsets: tomography axis shift in pixels relative to column center
+        # RV should we remove stripes?
+        # https://tomopy.readthedocs.io/en/latest/api/tomopy.prep.stripe.html
+        # RV should we remove rings?
+        # https://tomopy.readthedocs.io/en/latest/api/tomopy.misc.corr.html
+        # RV: Add an option to do (extra) secondary iterations later or to do some sort of convergence test?
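+        # Example: center_offsets = [-2., 3.] with a 6-row stack interpolates below to
+        # per-row offsets [-2., -1., 0., 1., 2., 3.] before the half-width shift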
+        if not len(center_offsets):
+            centers = np.zeros((tomo_stack.shape[0]))
+        elif len(center_offsets) == 2:
+            centers = np.linspace(center_offsets[0], center_offsets[1], tomo_stack.shape[0])
+        else:
+            if len(center_offsets) != tomo_stack.shape[0]:
+                raise ValueError('center_offsets dimension mismatch in _reconstruct_one_tomo_stack')
+            centers = np.asarray(center_offsets, dtype=float)
+        centers = centers+tomo_stack.shape[2]/2
+
+        # Get reconstruction parameters
+        recon_parameters = None # self.config.get('recon_parameters')
+        if recon_parameters is None:
+            sigma = 2.0
+            secondary_iters = 0
+            ring_width = 15
+        else:
+            sigma = recon_parameters.get('stripe_fw_sigma', 2.0)
+            if not is_num(sigma, ge=0):
+                logger.warning(f'Invalid stripe_fw_sigma ({sigma}) in '+
+                        '_reconstruct_one_tomo_stack, set to a default value of 2.0')
+                sigma = 2.0
+            secondary_iters = recon_parameters.get('secondary_iters', 0)
+            if not is_int(secondary_iters, ge=0):
+                logger.warning(f'Invalid secondary_iters ({secondary_iters}) in '+
+                        '_reconstruct_one_tomo_stack, set to a default value of 0 (skip them)')
+                secondary_iters = 0
+            ring_width = recon_parameters.get('ring_width', 15)
+            if not is_int(ring_width, ge=0):
+                logger.warning(f'Invalid ring_width ({ring_width}) in '+
+                        '_reconstruct_one_tomo_stack, set to a default value of 15')
+                ring_width = 15
+
+        # Remove horizontal stripe
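+        # remove_stripe_fw is TomoPy's Fourier-Wavelet based stripe removal;
+        # horizontal sinogram stripes would otherwise reconstruct as ring artifacts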
+        t0 = time()
+        if num_core > num_core_tomopy_limit:
+            logger.debug(f'Running remove_stripe_fw on {num_core_tomopy_limit} cores ...')
+            tomo_stack = tomopy.prep.stripe.remove_stripe_fw(tomo_stack, sigma=sigma,
+                    ncore=num_core_tomopy_limit)
+        else:
+            logger.debug(f'Running remove_stripe_fw on {num_core} cores ...')
+            tomo_stack = tomopy.prep.stripe.remove_stripe_fw(tomo_stack, sigma=sigma,
+                    ncore=num_core)
+        logger.debug(f'... tomopy.prep.stripe.remove_stripe_fw took {time()-t0:.2f} seconds')
+
+        # Perform initial image reconstruction
+        logger.debug('Performing initial image reconstruction')
+        t0 = time()
+        logger.debug(f'Running recon on {num_core} cores ...')
+        tomo_recon_stack = tomopy.recon(tomo_stack, np.radians(thetas), centers,
+                sinogram_order=True, algorithm=algorithm, ncore=num_core)
+        logger.debug(f'... done in {time()-t0:.2f} seconds')
+        logger.info(f'Performing initial image reconstruction took {time()-t0:.2f} seconds')
+
+        # Run optional secondary iterations
+        if secondary_iters > 0:
+            logger.debug(f'Running {secondary_iters} secondary iterations')
+            #options = {'method':'SIRT_CUDA', 'proj_type':'cuda', 'num_iter':secondary_iters}
+            #RV: doesn't work for me:
+            #"Error: CUDA error 803: system has unsupported display driver/cuda driver combination."
+            #options = {'method':'SIRT', 'proj_type':'linear', 'MinConstraint': 0, 'num_iter':secondary_iters}
+            #SIRT did not finish while running overnight
+            #options = {'method':'SART', 'proj_type':'linear', 'num_iter':secondary_iters}
+            options = {'method':'SART', 'proj_type':'linear', 'MinConstraint': 0,
+                    'num_iter':secondary_iters}
+            t0 = time()
+            logger.debug(f'Running recon on {num_core} cores ...')
+            tomo_recon_stack = tomopy.recon(tomo_stack, np.radians(thetas), centers,
+                    init_recon=tomo_recon_stack, options=options, sinogram_order=True,
+                    algorithm=tomopy.astra, ncore=num_core)
+            logger.debug(f'... done in {time()-t0:.2f} seconds')
+            logger.info(f'Performing secondary iterations took {time()-t0:.2f} seconds')
+
+        # Remove ring artifacts
+        t0 = time()
+        tomopy.misc.corr.remove_ring(tomo_recon_stack, rwidth=ring_width, out=tomo_recon_stack,
+                ncore=num_core)
+        logger.debug(f'Removing ring artifacts took {time()-t0:.2f} seconds')
+
+        return tomo_recon_stack
+
+    def _resize_reconstructed_data(self, data, z_only=False):
+        """Resize the reconstructed tomography data.
+        """
+        # Data order: row(z),x,y or stack,row(z),x,y
+        if isinstance(data, list):
+            for stack in data:
+                assert(stack.ndim == 3)
+            num_tomo_stacks = len(data)
+            tomo_recon_stacks = data
+        else:
+            assert(data.ndim == 3)
+            num_tomo_stacks = 1
+            tomo_recon_stacks = [data]
+
+        if z_only:
+            x_bounds = None
+            y_bounds = None
+        else:
+            # Selecting x bounds (in yz-plane)
+            tomosum = sum(np.sum(stack, axis=(0,2)) for stack in tomo_recon_stacks)
+            select_x_bounds = input_yesno('\nDo you want to change the image x-bounds (y/n)?', 'y')
+            if not select_x_bounds:
+                x_bounds = None
+            else:
+                accept = False
+                index_ranges = None
+                while not accept:
+                    mask, x_bounds = draw_mask_1d(tomosum, current_index_ranges=index_ranges,
+                            title='select x data range', legend='recon stack sum yz')
+                    while len(x_bounds) != 1:
+                        print('Please select exactly one continuous range')
+                        mask, x_bounds = draw_mask_1d(tomosum, title='select x data range',
+                                legend='recon stack sum yz')
+                    x_bounds = x_bounds[0]
+#                    quick_plot(tomosum, vlines=x_bounds, title='recon stack sum yz')
+#                    print(f'x_bounds = {x_bounds} (lower bound inclusive, upper bound '+
+#                            'exclusive)')
+#                    accept = input_yesno('Accept these bounds (y/n)?', 'y')
+                    accept = True
+            logger.debug(f'x_bounds = {x_bounds}')
+
+            # Selecting y bounds (in xz-plane)
+            tomosum = sum(np.sum(stack, axis=(0,1)) for stack in tomo_recon_stacks)
+            select_y_bounds = input_yesno('\nDo you want to change the image y-bounds (y/n)?', 'y')
+            if not select_y_bounds:
+                y_bounds = None
+            else:
+                accept = False
+                index_ranges = None
+                while not accept:
+                    mask, y_bounds = draw_mask_1d(tomosum, current_index_ranges=index_ranges,
+                            title='select y data range', legend='recon stack sum xz')
+                    while len(y_bounds) != 1:
+                        print('Please select exactly one continuous range')
+                        mask, y_bounds = draw_mask_1d(tomosum, title='select y data range',
+                                legend='recon stack sum xz')
+                    y_bounds = y_bounds[0]
+#                    quick_plot(tomosum, vlines=y_bounds, title='recon stack sum xz')
+#                    print(f'y_bounds = {y_bounds} (lower bound inclusive, upper bound '+
+#                            'exclusive)')
+#                    accept = input_yesno('Accept these bounds (y/n)?', 'y')
+                    accept = True
+            logger.debug(f'y_bounds = {y_bounds}')
+
+        # Selecting z bounds (in xy-plane) (only valid for a single image stack)
+        if num_tomo_stacks != 1:
+            z_bounds = None
+        else:
+            tomosum = sum(np.sum(stack, axis=(1,2)) for stack in tomo_recon_stacks)
+            select_z_bounds = input_yesno('Do you want to change the image z-bounds (y/n)?', 'n')
+            if not select_z_bounds:
+                z_bounds = None
+            else:
+                accept = False
+                index_ranges = None
+                while not accept:
+                    mask, z_bounds = draw_mask_1d(tomosum, current_index_ranges=index_ranges,
+                            title='select z data range', legend='recon stack sum xy')
+                    while len(z_bounds) != 1:
+                        print('Please select exactly one continuous range')
+                        mask, z_bounds = draw_mask_1d(tomosum, title='select z data range',
+                                legend='recon stack sum xy')
+                    z_bounds = z_bounds[0]
+#                    quick_plot(tomosum, vlines=z_bounds, title='recon stack sum xy')
+#                    print(f'z_bounds = {z_bounds} (lower bound inclusive, upper bound '+
+#                            'exclusive)')
+#                    accept = input_yesno('Accept these bounds (y/n)?', 'y')
+                    accept = True
+            logger.debug(f'z_bounds = {z_bounds}')
+
+        return x_bounds, y_bounds, z_bounds
+
+
+def run_tomo(input_file:str, output_file:str, modes:list[str], center_file=None, num_core=-1,
+        output_folder='.', save_figs='no', test_mode=False) -> None:
+
+    if test_mode:
+        logging_format = '%(asctime)s : %(levelname)s - %(module)s : %(funcName)s - %(message)s'
+        level = logging.getLevelName('INFO')
+        logging.basicConfig(filename=f'{output_folder}/tomo.log', filemode='w',
+                format=logging_format, level=level, force=True)
+    logger.info(f'input_file = {input_file}')
+    logger.info(f'center_file = {center_file}')
+    logger.info(f'output_file = {output_file}')
+    logger.debug(f'modes = {modes}')
+    logger.debug(f'num_core = {num_core}')
+    logger.info(f'output_folder = {output_folder}')
+    logger.info(f'save_figs = {save_figs}')
+    logger.info(f'test_mode = {test_mode}')
+
+    # Check the requested processing modes
+    legal_modes = ['reduce_data', 'find_center', 'reconstruct_data', 'combine_data', 'all']
+    if modes is None:
+        modes = ['all']
+    if not all(mode in legal_modes for mode in modes):
+        raise ValueError(f'Invalid parameter modes ({modes})')
+
+    # Instantiate Tomo object
+    tomo = Tomo(num_core=num_core, output_folder=output_folder, save_figs=save_figs,
+            test_mode=test_mode)
+
+    # Read input file
+    data = tomo.read(input_file)
+
+    # Generate reduced tomography images
+    if 'reduce_data' in modes or 'all' in modes:
+        data = tomo.gen_reduced_data(data)
+
+    # Find rotation axis centers for the tomography stacks.
+    center_data = None
+    if 'find_center' in modes or 'all' in modes:
+        center_data = tomo.find_centers(data)
+
+    # Reconstruct tomography stacks
+    if 'reconstruct_data' in modes or 'all' in modes:
+        if center_data is None:
+            # Read input file
+            center_data = tomo.read(center_file)
+        data = tomo.reconstruct_data(data, center_data)
+        center_data = None
+
+    # Combine reconstructed tomography stacks
+    if 'combine_data' in modes or 'all' in modes:
+        data = tomo.combine_data(data)
+
+    # Write output file
+    if data is not None and not test_mode:
+        if center_data is None:
+            data = tomo.write(data, output_file)
+        else:
+            data = tomo.write(center_data, output_file)
+
+    logger.info(f'Completed modes: {modes}')
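+
+# Example invocation (hypothetical file names; run_tomo is normally called from the
+# Galaxy tool wrapper):
+#     run_tomo('tomo_input.nex', 'tomo_output.nex', ['reduce_data', 'find_center'],
+#             num_core=4, output_folder='output')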