changeset 4:9aa288729b9a draft

planemo upload for repository https://github.com/rolfverberg/galaxytools commit f8c4bdb31c20c468045ad5e6eb255a293244bc6c-dirty
author rv43
date Mon, 20 Mar 2023 18:44:23 +0000
parents fc38431f257f
children 543dba81eb15
files tomo_reduce.py tomo_reduce.xml workflow/__main__.py workflow/__version__.py workflow/link_to_galaxy.py workflow/models.py workflow/run_tomo.py
diffstat 7 files changed, 3088 insertions(+), 1 deletions(-)
--- a/tomo_reduce.py	Mon Mar 20 18:30:26 2023 +0000
+++ b/tomo_reduce.py	Mon Mar 20 18:44:23 2023 +0000
@@ -73,6 +73,7 @@
     logging.debug(f'log = {args.log}')
     logging.debug(f'is log stdout? {args.log is sys.stdout}')
     logging.debug(f'log_level = {args.log_level}')
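+    # NOTE: this early return makes the Tomo processing below unreachable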
+    return
 
     # Instantiate Tomo object
     tomo = Tomo(galaxy_flag=args.galaxy_flag)
--- a/tomo_reduce.xml	Mon Mar 20 18:30:26 2023 +0000
+++ b/tomo_reduce.xml	Mon Mar 20 18:44:23 2023 +0000
@@ -1,4 +1,4 @@
-<tool id="tomo_reduce" name="Tomo Reduce" version="0.1.0" python_template_version="3.9">
+<tool id="tomo_reduce" name="Tomo Reduce" version="0.1.1" python_template_version="3.9">
     <description>Reduce tomography images</description>
     <macros>
         <import>tomo_macros.xml</import>
@@ -17,6 +17,10 @@
     </command>
     <inputs>
         <param name="input_file" type="data" optional="false" label="Input file"/>
+        <section name="x_bounds" title="Reduction x bounds">
+            <param name="x_bound_low" type="integer" value="-1" label="Lower bound"/>
+            <param name="x_bound_upp" type="integer" value="-1" label="Upper bound"/>
+        </section>
     </inputs>
     <outputs>
         <expand macro="common_outputs"/>
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/workflow/__main__.py	Mon Mar 20 18:44:23 2023 +0000
@@ -0,0 +1,236 @@
+#!/usr/bin/env python3
+
+import logging
+logger = logging.getLogger(__name__)
+
+import argparse
+import io
+import pathlib
+import sys
+
+from .models import TomoWorkflow as Workflow
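+# DeepDiff is optional; it is only needed by the (currently disabled) 'diff' subcommand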
+try:
+    from deepdiff import DeepDiff
+except ImportError:
+    pass
+
+parser = argparse.ArgumentParser(description='''Operate on representations of
+        Tomo data workflows saved to files.''')
+parser.add_argument('-l', '--log',
+#        type=argparse.FileType('w'),
+        default=sys.stdout,
+        help='Logging stream or filename')
+parser.add_argument('--log_level',
+        choices=logging._nameToLevel.keys(),
+        default='INFO',
+        help='''Specify a preferred logging level.''')
+subparsers = parser.add_subparsers(title='subcommands', required=True)#, dest='command')
+
+
+# CONSTRUCT
+def construct(args:list) -> None:
+    if args.template_file is not None:
+        wf = Workflow.construct_from_file(args.template_file)
+        wf.cli()
+    else:
+        wf = Workflow.construct_from_cli()
+    wf.write_to_file(args.output_file, force_overwrite=args.force_overwrite)
+
+construct_parser = subparsers.add_parser('construct', help='''Construct a valid Tomo
+        workflow representation on the command line and save it to a file. Optionally use
+        an existing file as a template and/or perform the reconstruction or transfer to Galaxy.''')
+construct_parser.set_defaults(func=construct)
+construct_parser.add_argument('-t', '--template_file',
+        type=pathlib.Path,
+        required=False,
+        help='''Full or relative template file path for the constructed workflow.''')
+construct_parser.add_argument('-f', '--force_overwrite',
+        action='store_true',
+        help='''Use this flag to overwrite the output file if it already exists.''')
+construct_parser.add_argument('-o', '--output_file',
+        type=pathlib.Path,
+        help='''Full or relative file path to which the constructed workflow will be written.''')
+
+
+# VALIDATE
+def validate(args:list) -> bool:
+    try:
+        wf = Workflow.construct_from_file(args.input_file)
+        logger.info(f'Success: {args.input_file} represents a valid Tomo workflow configuration.')
+        return(True)
+    except BaseException as e:
+        logger.error(f'{e.__class__.__name__}: {str(e)}')
+        logger.info(f'Failure: {args.input_file} does not represent '+
+                'a valid Tomo workflow configuration.')
+        return(False)
+
+validate_parser = subparsers.add_parser('validate',
+        help='''Validate a file as a representation of a Tomo workflow (this is most useful
+                after a .yaml file has been manually edited).''')
+validate_parser.set_defaults(func=validate)
+validate_parser.add_argument('input_file',
+        type=pathlib.Path,
+        help='''Full or relative file path to validate as a Tomo workflow.''')
+
+
+# CONVERT
+def convert(args:list) -> None:
+    wf = Workflow.construct_from_file(args.input_file)
+    wf.write_to_file(args.output_file, force_overwrite=args.force_overwrite)
+
+convert_parser = subparsers.add_parser('convert', help='''Convert one Tomo workflow
+        representation to another. File format of both input and output files will be
+        automatically determined from the files' extensions.''')
+convert_parser.set_defaults(func=convert)
+convert_parser.add_argument('-f', '--force_overwrite',
+        action='store_true',
+        help='''Use this flag to overwrite the output file if it already exists.''')
+convert_parser.add_argument('-i', '--input_file',
+        type=pathlib.Path,
+        required=True,
+        help='''Full or relative input file path to be converted.''')
+convert_parser.add_argument('-o', '--output_file',
+        type=pathlib.Path,
+        required=True,
+        help='''Full or relative file path to which the converted input will be written.''')
+
+
+# DIFF / COMPARE
+def diff(args:list) -> bool:
+    raise ValueError('diff not tested')
+#    wf1 = Workflow.construct_from_file(args.file1).dict_for_yaml()
+#    wf2 = Workflow.construct_from_file(args.file2).dict_for_yaml()
+#    diff = DeepDiff(wf1,wf2,
+#                    ignore_order_func=lambda level:'independent_dimensions' not in level.path(),
+#                    report_repetition=True,
+#                    ignore_string_type_changes=True,
+#                    ignore_numeric_type_changes=True)
+    diff_report = diff.pretty()
+    if len(diff_report) > 0:
+        logger.info(f'The configurations in {args.file1} and {args.file2} are not identical.')
+        print(diff_report)
+        return(True)
+    else:
+        logger.info(f'The configurations in {args.file1} and {args.file2} are identical.')
+        return(False)
+
+diff_parser = subparsers.add_parser('diff', aliases=['compare'], help='''Print a comparison of 
+        two Tomo workflow representations stored in files. The files may have different formats.''')
+diff_parser.set_defaults(func=diff)
+diff_parser.add_argument('file1',
+        type=pathlib.Path,
+        help='''Full or relative path to the first file for comparison.''')
+diff_parser.add_argument('file2',
+        type=pathlib.Path,
+        help='''Full or relative path to the second file for comparison.''')
+
+
+# LINK TO GALAXY
+def link_to_galaxy(args:list) -> None:
+    from .link_to_galaxy import link_to_galaxy
+    link_to_galaxy(args.input_file, galaxy=args.galaxy, user=args.user,
+            password=args.password, api_key=args.api_key)
+
+link_parser = subparsers.add_parser('link_to_galaxy', help='''Construct a Galaxy history and link
+        to an existing Tomo workflow representation in a NeXus file.''')
+link_parser.set_defaults(func=link_to_galaxy)
+link_parser.add_argument('-i', '--input_file',
+        type=pathlib.Path,
+        required=True,
+        help='''Full or relative input file path to the existing Tomo workflow representation as
+                a NeXus file.''')
+link_parser.add_argument('-g', '--galaxy',
+        required=True,
+        help='Target Galaxy instance URL/IP address')
+link_parser.add_argument('-u', '--user',
+        default=None,
+        help='Galaxy user email address')
+link_parser.add_argument('-p', '--password',
+        default=None,
+        help='Password for the Galaxy user')
+link_parser.add_argument('-a', '--api_key',
+        default=None,
+        help='Galaxy admin user API key (required if not defined in the tools list file)')
+
+
+# RUN THE RECONSTRUCTION
+def run_tomo(args:list) -> None:
+    from .run_tomo import run_tomo
+    run_tomo(args.input_file, args.output_file, args.modes, center_file=args.center_file,
+            num_core=args.num_core, output_folder=args.output_folder, save_figs=args.save_figs)
+
+tomo_parser = subparsers.add_parser('run_tomo', help='''Construct and add reconstructed tomography
+        data to an existing Tomo workflow representation in a NeXus file.''')
+tomo_parser.set_defaults(func=run_tomo)
+tomo_parser.add_argument('-i', '--input_file',
+        required=True,
+        type=pathlib.Path,
+        help='''Full or relative input file path containing raw and/or reduced data.''')
+tomo_parser.add_argument('-o', '--output_file',
+        required=True,
+        type=pathlib.Path,
+        help='''Full or relative output file path for the processed data.''')
+tomo_parser.add_argument('-c', '--center_file',
+        type=pathlib.Path,
+        help='''Full or relative input file path containing the rotation axis centers info.''')
+#tomo_parser.add_argument('-f', '--force_overwrite',
+#        action='store_true',
+#        help='''Use this flag to overwrite any existing reduced data.''')
+tomo_parser.add_argument('-n', '--num_core',
+        type=int,
+        default=-1,
+        help='''Specify the number of processors to use.''')
+tomo_parser.add_argument('--output_folder',
+        type=pathlib.Path,
+        default='.',
+        help='Full or relative path to an output folder')
+tomo_parser.add_argument('-s', '--save_figs',
+        choices=['yes', 'no', 'only'],
+        default='no',
+        help='''Specify whether to display ('yes' or 'no'), save ('yes'), or only save ('only').''')
+tomo_parser.add_argument('--reduce_data',
+        dest='modes',
+        const='reduce_data',
+        action='append_const',
+        help='''Use this flag to create and add reduced data to the input file.''')
+tomo_parser.add_argument('--find_center',
+        dest='modes',
+        const='find_center',
+        action='append_const',
+        help='''Use this flag to find and add the calibrated center axis info to the input file.''')
+tomo_parser.add_argument('--reconstruct_data',
+        dest='modes',
+        const='reconstruct_data',
+        action='append_const',
+        help='''Use this flag to create and add reconstructed data to the input file.''')
+tomo_parser.add_argument('--combine_data',
+        dest='modes',
+        const='combine_data',
+        action='append_const',
+        help='''Use this flag to combine reconstructed data and add it to the input file.''')
+
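+# Example invocations of this CLI (file names are illustrative):
+#   python -m workflow construct -o workflow.yaml
+#   python -m workflow validate workflow.yaml
+#   python -m workflow convert -i workflow.yaml -o workflow.nxs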
+
+if __name__ == '__main__':
+    args = parser.parse_args(sys.argv[1:])
+
+    # Set log configuration
+    # When logging to file, the stdout log level defaults to WARNING
+    logging_format = '%(asctime)s : %(levelname)s - %(module)s : %(funcName)s - %(message)s'
+    level = logging.getLevelName(args.log_level)
+    if args.log is sys.stdout:
+        logging.basicConfig(format=logging_format, level=level, force=True,
+                handlers=[logging.StreamHandler()])
+    else:
+        if isinstance(args.log, str):
+            logging.basicConfig(filename=f'{args.log}', filemode='w',
+                    format=logging_format, level=level, force=True)
+        elif isinstance(args.log, io.TextIOWrapper):
+            logging.basicConfig(filemode='w', format=logging_format, level=level,
+                    stream=args.log, force=True)
+        else:
+            raise ValueError(f'Invalid argument --log: {args.log}')
+        stream_handler = logging.StreamHandler()
+        logging.getLogger().addHandler(stream_handler)
+        stream_handler.setLevel(logging.WARNING)
+        stream_handler.setFormatter(logging.Formatter(logging_format))
+
+    args.func(args)
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/workflow/__version__.py	Mon Mar 20 18:44:23 2023 +0000
@@ -0,0 +1,1 @@
+__version__ = '2022.3.0'
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/workflow/link_to_galaxy.py	Mon Mar 20 18:44:23 2023 +0000
@@ -0,0 +1,120 @@
+#!/usr/bin/env python3
+
+import logging
+logger = logging.getLogger(__name__)
+
+from bioblend.galaxy import GalaxyInstance
+from nexusformat.nexus import *
+from os import path
+from yaml import safe_load
+
+from .models import import_scanparser, TomoWorkflow
+
+def get_folder_id(gi, lib_path):
+    # Find the deepest existing Galaxy library folder along lib_path and
+    # return (library_id, folder_id, names of subfolders still to be created)
+    library_id = None
+    folder_id = None
+    folder_names = lib_path[1:] if len(lib_path) > 1 else []
+    new_folders = folder_names
+    libs = gi.libraries.get_libraries(name=lib_path[0])
+    if libs:
+        for lib in libs:
+            library_id = lib['id']
+            folders = gi.libraries.get_folders(library_id, folder_id=None, name=None)
+            for i, folder in enumerate(folders):
+                fid = folder['id']
+                details = gi.libraries.show_folder(library_id, fid)
+                library_path = details['library_path']
+                if library_path == folder_names:
+                    return (library_id, fid, [])
+                elif len(library_path) < len(folder_names):
+                    if library_path == folder_names[:len(library_path)]:
+                        nf = folder_names[len(library_path):]
+                        if len(nf) < len(new_folders):
+                            folder_id = fid
+                            new_folders = nf
+    return (library_id, folder_id, new_folders)
+
+def link_to_galaxy(filename:str, galaxy=None, user=None, password=None, api_key=None) -> None:
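+    # Upload the workflow file to a Galaxy data library as a linked dataset and copy it into a new history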
+    # Read input file
+    extension = path.splitext(filename)[1]
+# RV yaml input not incorporated yet, since Galaxy can't use pyspec right now
+#    if extension == '.yml' or extension == '.yaml':
+#        with open(filename, 'r') as f:
+#            data = safe_load(f)
+#    elif extension == '.nxs':
+    if extension == '.nxs':
+        with NXFile(filename, mode='r') as nxfile:
+            data = nxfile.readfile()
+    else:
+        raise ValueError(f'Invalid filename extension ({extension})')
+    if isinstance(data, dict):
+        # Create Nexus format object from input dictionary
+        wf = TomoWorkflow(**data)
+        if len(wf.sample_maps) > 1:
+            raise ValueError('Multiple sample maps not yet implemented')
+        nxroot = NXroot()
+        for sample_map in wf.sample_maps:
+            import_scanparser(sample_map.station)
+# RV raw data must be included, since Galaxy can't use pyspec right now
+#            sample_map.construct_nxentry(nxroot, include_raw_data=False)
+            sample_map.construct_nxentry(nxroot, include_raw_data=True)
+        nxentry = nxroot[nxroot.attrs['default']]
+    elif isinstance(data, NXroot):
+        nxentry = data[data.attrs['default']]
+    else:
+        raise ValueError(f'Invalid input file data ({data})')
+
+    # Get a Galaxy instance
+    if user is not None and password is not None:
+        gi = GalaxyInstance(url=galaxy, email=user, password=password)
+    elif api_key is not None:
+        gi = GalaxyInstance(url=galaxy, key=api_key)
+    else:
+        exit('Please specify either a valid Galaxy username/password or an API key.')
+
+    cycle = nxentry.instrument.source.attrs['cycle']
+    btr = nxentry.instrument.source.attrs['btr']
+    sample = nxentry.sample.name
+
+    # Create a Galaxy work library/folder
+    # Combine the cycle, BTR and sample name as the base library name
+    lib_path = [p.strip() for p in f'{cycle}/{btr}/{sample}'.split('/')]
+    (library_id, folder_id, folder_names) = get_folder_id(gi, lib_path)
+    if not library_id:
+        library = gi.libraries.create_library(lib_path[0], description=None, synopsis=None)
+        library_id = library['id']
+#        if user:
+#            gi.libraries.set_library_permissions(library_id, access_ids=user,
+#                    manage_ids=user, modify_ids=user)
+        logger.info(f'Created Library:\n{library}')
+    if len(folder_names):
+        folder = gi.libraries.create_folder(library_id, folder_names[0], description=None,
+                base_folder_id=folder_id)[0]
+        folder_id = folder['id']
+        logger.info(f'Created Folder:\n{folder}')
+        folder_names.pop(0)
+        while len(folder_names):
+            folder = gi.folders.create_folder(folder['id'], folder_names[0],
+                    description=None)
+            folder_id = folder['id']
+            logger.info(f'Created Folder:\n{folder}')
+            folder_names.pop(0)
+
+    # Create a sym link for the Nexus file
+    dataset = gi.libraries.upload_from_galaxy_filesystem(library_id, path.abspath(filename),
+            folder_id=folder_id, file_type='auto', dbkey='?', link_data_only='link_to_files',
+            roles='', preserve_dirs=False, tag_using_filenames=False, tags=None)[0]
+
+    # Make a history for the data
+    history_name = f'tomo {btr} {sample}'
+    history = gi.histories.create_history(name=history_name)
+    logger.info(f'Created history:\n{history}')
+    history_id = history['id']
+    gi.histories.copy_dataset(history_id, dataset['id'], source='library')
+
+# TODO add option to either 
+#     get a URL to share the history
+#     or to share with specific users
+# This might require using: 
+#     https://bioblend.readthedocs.io/en/latest/api_docs/galaxy/docs.html#using-bioblend-for-raw-api-calls
+
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/workflow/models.py	Mon Mar 20 18:44:23 2023 +0000
@@ -0,0 +1,1096 @@
+#!/usr/bin/env python3
+
+import logging
+logger = logging.getLogger(__name__)
+
+import numpy as np
+import os
+import yaml
+
+from functools import cache
+from pathlib import PosixPath
+from pydantic import BaseModel as PydanticBaseModel
+from pydantic import validator, ValidationError, conint, confloat, constr, conlist, FilePath, \
+        PrivateAttr
+from nexusformat.nexus import *
+from time import time
+from typing import Optional, Literal
+from typing_extensions import TypedDict
+try:
+    from pyspec.file.spec import FileSpec
+except ImportError:
+    pass
+
+try:
+    from msnctools.general import is_int, is_num, input_int, input_int_list, input_num, \
+            input_yesno, input_menu, index_nearest, string_to_list, file_exists_and_readable
+except ImportError:
+    from general import is_int, is_num, input_int, input_int_list, input_num, \
+            input_yesno, input_menu, index_nearest, string_to_list, file_exists_and_readable
+
+
+def import_scanparser(station):
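+    # Bind the station-appropriate rotation-scan parser class to the module-global name 'ScanParser'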
+    if station in ('id1a3', 'id3a'):
+        try:
+            from msnctools.scanparsers import SMBRotationScanParser
+            globals()['ScanParser'] = SMBRotationScanParser
+        except:
+            try:
+                from scanparsers import SMBRotationScanParser
+                globals()['ScanParser'] = SMBRotationScanParser
+            except:
+                pass
+    elif station == 'id3b':
+        try:
+            from msnctools.scanparsers import FMBRotationScanParser
+            globals()['ScanParser'] = FMBRotationScanParser
+        except:
+            try:
+                from scanparsers import FMBRotationScanParser
+                globals()['ScanParser'] = FMBRotationScanParser
+            except:
+                pass
+    else:
+        raise RuntimeError(f'Invalid station: {station}')
+
+@cache
+def get_available_scan_numbers(spec_file:str):
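+    # Keep only the scans that ScanParser can actually open; unparsable scans are dropped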
+    scans = FileSpec(spec_file).scans
+    scan_numbers = list(scans.keys())
+    for scan_number in scan_numbers.copy():
+        try:
+            parser = ScanParser(spec_file, scan_number)
+            try:
+                scan_type = parser.scan_type
+            except:
+                scan_type = None
+        except:
+            scan_numbers.remove(scan_number)
+    return(scan_numbers)
+
+@cache
+def get_scanparser(spec_file:str, scan_number:int):
+    if scan_number not in get_available_scan_numbers(spec_file):
+        return(None)
+    else:
+        return(ScanParser(spec_file, scan_number))
+
+
+class BaseModel(PydanticBaseModel):
+    class Config:
+        validate_assignment = True
+        arbitrary_types_allowed = True
+
+    @classmethod
+    def construct_from_cli(cls):
+        obj = cls.construct()
+        obj.cli()
+        return(obj)
+
+    @classmethod
+    def construct_from_yaml(cls, filename):
+        try:
+            with open(filename, 'r') as infile:
+                indict = yaml.load(infile, Loader=yaml.CLoader)
+        except:
+            raise ValueError(f'Could not load a dictionary from {filename}')
+        else:
+            obj = cls(**indict)
+            return(obj)
+
+    @classmethod
+    def construct_from_file(cls, filename):
+        file_exists_and_readable(filename)
+        filename = os.path.abspath(filename)
+        fileformat = os.path.splitext(filename)[1]
+        yaml_extensions = ('.yaml','.yml')
+        nexus_extensions = ('.nxs','.nx5','.h5','.hdf5')
+        t0 = time()
+        if fileformat.lower() in yaml_extensions:
+            obj = cls.construct_from_yaml(filename)
+            logger.info(f'Constructed a model from {filename} in {time()-t0:.2f} seconds.')
+            return(obj)
+        elif fileformat.lower() in nexus_extensions:
+            obj = cls.construct_from_nexus(filename)
+            logger.info(f'Constructed a model from {filename} in {time()-t0:.2f} seconds.')
+            return(obj)
+        else:
+            logger.error(f'Unsupported file extension for constructing a model: {fileformat}')
+            raise TypeError(f'Unrecognized file extension: {fileformat}')
+
+    def dict_for_yaml(self, exclude_fields=[]):
+        yaml_dict = {}
+        for field_name in self.__fields__:
+            if field_name in exclude_fields:
+                continue
+            else:
+                field_value = getattr(self, field_name, None)
+                if field_value is not None:
+                    if isinstance(field_value, BaseModel):
+                        yaml_dict[field_name] = field_value.dict_for_yaml()
+                    elif isinstance(field_value,list) and all(isinstance(item,BaseModel)
+                            for item in field_value):
+                        yaml_dict[field_name] = [item.dict_for_yaml() for item in field_value]
+                    elif isinstance(field_value, PosixPath):
+                        yaml_dict[field_name] = str(field_value)
+                    else:
+                        yaml_dict[field_name] = field_value
+                else:
+                    continue
+        return(yaml_dict)
+
+    def write_to_yaml(self, filename=None):
+        yaml_dict = self.dict_for_yaml()
+        if filename is None:
+            logger.info('Printing yaml representation here:\n'+
+                    f'{yaml.dump(yaml_dict, sort_keys=False)}')
+        else:
+            try:
+                with open(filename, 'w') as outfile:
+                    yaml.dump(yaml_dict, outfile, sort_keys=False)
+                logger.info(f'Successfully wrote this model to {filename}')
+            except:
+                logger.error(f'Unknown error -- could not write to {filename} in yaml format.')
+                logger.info('Printing yaml representation here:\n'+
+                        f'{yaml.dump(yaml_dict, sort_keys=False)}')
+
+    def write_to_file(self, filename, force_overwrite=False):
+        file_writeable, fileformat = self.output_file_valid(filename,
+                force_overwrite=force_overwrite)
+        if fileformat == 'yaml':
+            if file_writeable:
+                self.write_to_yaml(filename=filename)
+            else:
+                self.write_to_yaml()
+        elif fileformat == 'nexus':
+            if file_writeable:
+                self.write_to_nexus(filename=filename)
+
+    def output_file_valid(self, filename, force_overwrite=False):
+        filename = os.path.abspath(filename)
+        fileformat = os.path.splitext(filename)[1]
+        yaml_extensions = ('.yaml','.yml')
+        nexus_extensions = ('.nxs','.nx5','.h5','.hdf5')
+        if fileformat.lower() not in (*yaml_extensions, *nexus_extensions):
+            return(False, None) # Only yaml and NeXus files allowed for output now.
+        elif fileformat.lower() in yaml_extensions:
+            fileformat = 'yaml'
+        elif fileformat.lower() in nexus_extensions:
+            fileformat = 'nexus'
+        if os.path.isfile(filename):
+            if os.access(filename, os.W_OK):
+                if not force_overwrite:
+                    logger.error(f'{filename} will not be overwritten.')
+                    return(False, fileformat)
+            else:
+                logger.error(f'Cannot access {filename} for writing.')
+                return(False, fileformat)
+        if os.path.isdir(os.path.dirname(filename)):
+            if os.access(os.path.dirname(filename), os.W_OK):
+                return(True, fileformat)
+            else:
+                logger.error(f'Cannot access {os.path.dirname(filename)} for writing.')
+                return(False, fileformat)
+        else:
+            try:
+                os.makedirs(os.path.dirname(filename))
+                return(True, fileformat)
+            except:
+                logger.error(f'Cannot create {os.path.dirname(filename)} for output.')
+                return(False, fileformat)
+
+    def set_single_attr_cli(self, attr_name, attr_desc='unknown attribute', list_flag=False,
+            **cli_kwargs):
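+        # First try the attribute as a nested model with its own cli(); fall back to plain keyboard input otherwise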
+        if cli_kwargs.get('chain_attr_desc', False):
+            cli_kwargs['attr_desc'] = attr_desc
+        try:
+            attr = getattr(self, attr_name, None)
+            if attr is None:
+                attr = self.__fields__[attr_name].type_.construct()
+            input_accepted = False
+            while not input_accepted:
+                try:
+                    attr.cli(**cli_kwargs)
+                except ValidationError as e:
+                    print(e)
+                    print(f'Removing {attr_desc} configuration')
+                    attr = self.__fields__[attr_name].type_.construct()
+                    continue
+                except KeyboardInterrupt as e:
+                    raise e
+                except BaseException as e:
+                    print(f'{type(e).__name__}: {e}')
+                    print(f'Removing {attr_desc} configuration')
+                    attr = self.__fields__[attr_name].type_.construct()
+                    continue
+                try:
+                    setattr(self, attr_name, attr)
+                except ValidationError as e:
+                    print(e)
+                except KeyboardInterrupt as e:
+                    raise e
+                except BaseException as e:
+                    print(f'{type(e).__name__}: {e}')
+                else:
+                    input_accepted = True
+        except:
+            input_accepted = False
+            while not input_accepted:
+                attr = getattr(self, attr_name, None)
+                if attr is None:
+                    input_value = input(f'Type and enter a value for {attr_desc}: ')
+                else:
+                    input_value = input(f'Type and enter a new value for {attr_desc} or press '+
+                            f'enter to keep the current one ({attr}): ')
+                if list_flag:
+                    input_value = string_to_list(input_value, remove_duplicates=False, sort=False)
+                if len(input_value) == 0:
+                    input_value = getattr(self, attr_name, None)
+                try:
+                    setattr(self, attr_name, input_value)
+                except ValidationError as e:
+                    print(e)
+                except KeyboardInterrupt as e:
+                    raise e
+                except BaseException as e:
+                    print(f'Unexpected {type(e).__name__}: {e}')
+                else:
+                    input_accepted = True
+
+    def set_list_attr_cli(self, attr_name, attr_desc='unknown attribute', **cli_kwargs):
+        if cli_kwargs.get('chain_attr_desc', False):
+            cli_kwargs['attr_desc'] = attr_desc
+        attr = getattr(self, attr_name, None)
+        if attr is not None:
+            # Check existing items
+            for item in attr:
+                item_accepted = False
+                while not item_accepted:
+                    item.cli(**cli_kwargs)
+                    try:
+                        setattr(self, attr_name, attr)
+                    except ValidationError as e:
+                        print(e)
+                    except KeyboardInterrupt as e:
+                        raise e
+                    except BaseException as e:
+                        print(f'{type(e).__name__}: {e}')
+                    else:
+                        item_accepted = True
+        else:
+            # Initialize list for new attribute & starting item
+            attr = []
+            item = self.__fields__[attr_name].type_.construct()
+        # Append (optional) additional items
+        append = input_yesno(f'Add a {attr_desc} configuration? (y/n)', 'n')
+        while append:
+            attr.append(item.__class__.construct_from_cli())
+            try:
+                setattr(self, attr_name, attr)
+            except ValidationError as e:
+                print(e)
+                print(f'Removing last {attr_desc} configuration from the list')
+                attr.pop()
+            except KeyboardInterrupt as e:
+                raise e
+            except BaseException as e:
+                print(f'{type(e).__name__}: {e}')
+                print(f'Removing last {attr_desc} configuration from the list')
+                attr.pop()
+            else:
+                append = input_yesno(f'Add another {attr_desc} configuration? (y/n)', 'n')
+
+
+class Detector(BaseModel):
+    prefix: constr(strip_whitespace=True, min_length=1)
+    rows: conint(gt=0)
+    columns: conint(gt=0)
+    pixel_size: conlist(item_type=confloat(gt=0), min_items=1, max_items=2)
+    lens_magnification: confloat(gt=0) = 1.0
+
+    @property
+    def get_pixel_size(self):
+        return(list(np.asarray(self.pixel_size)/self.lens_magnification))
+
+    def construct_from_yaml(self, filename):
+        try:
+            with open(filename, 'r') as infile:
+                indict = yaml.load(infile, Loader=yaml.CLoader)
+            detector = indict['detector']
+            self.prefix = detector['id']
+            pixels = detector['pixels']
+            self.rows = pixels['rows']
+            self.columns = pixels['columns']
+            self.pixel_size = pixels['size']
+            self.lens_magnification = indict['lens_magnification']
+        except:
+            logging.warning(f'Could not load a dictionary from {filename}')
+            return(False)
+        else:
+            return(True)
+
+    def cli(self):
+        print('\n -- Configure the detector -- ')
+        self.set_single_attr_cli('prefix', 'detector ID')
+        self.set_single_attr_cli('rows', 'number of pixel rows')
+        self.set_single_attr_cli('columns', 'number of pixel columns')
+        self.set_single_attr_cli('pixel_size', 'pixel size in mm (enter either a single value for '+
+                'square pixels or a pair of values for the size in the respective row and column '+
+                'directions)', list_flag=True)
+        self.set_single_attr_cli('lens_magnification', 'lens magnification')
+
+    def construct_nxdetector(self):
+        nxdetector = NXdetector()
+        nxdetector.local_name = self.prefix
+        pixel_size = self.get_pixel_size
+        if len(pixel_size) == 1:
+            nxdetector.x_pixel_size = pixel_size[0]
+            nxdetector.y_pixel_size = pixel_size[0]
+        else:
+            nxdetector.x_pixel_size = pixel_size[0]
+            nxdetector.y_pixel_size = pixel_size[1]
+        nxdetector.x_pixel_size.attrs['units'] = 'mm'
+        nxdetector.y_pixel_size.attrs['units'] = 'mm'
+        return(nxdetector)
+
+
+class ScanInfo(TypedDict):
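+    # One entry per SPEC scan in an image stack: offset of the first usable detector frame, frame count, and x/z reference positions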
+    scan_number: int
+    starting_image_offset: conint(ge=0)
+    num_image: conint(gt=0)
+    ref_x: float
+    ref_z: float
+
+class SpecScans(BaseModel):
+    spec_file: FilePath
+    scan_numbers: conlist(item_type=conint(gt=0), min_items=1)
+    stack_info: conlist(item_type=ScanInfo, min_items=1) = []
+
+    @validator('spec_file')
+    def validate_spec_file(cls, spec_file):
+        try:
+            spec_file = os.path.abspath(spec_file)
+            sspec_file = FileSpec(spec_file)
+        except:
+            raise ValueError(f'Invalid SPEC file {spec_file}')
+        else:
+            return(spec_file)
+
+    @validator('scan_numbers')
+    def validate_scan_numbers(cls, scan_numbers, values):
+        spec_file = values.get('spec_file')
+        if spec_file is not None:
+            spec_scans = FileSpec(spec_file)
+            for scan_number in scan_numbers:
+                scan = spec_scans.get_scan_by_number(scan_number)
+                if scan is None:
+                    raise ValueError(f'There is no scan number {scan_number} in {spec_file}')
+        return(scan_numbers)
+
+    @validator('stack_info')
+    def validate_stack_info(cls, stack_info, values):
+        scan_numbers = values.get('scan_numbers')
+        assert(len(scan_numbers) == len(stack_info))
+        for scan_info in stack_info:
+            assert(scan_info['scan_number'] in scan_numbers)
+            is_int(scan_info['starting_image_offset'], ge=0, lt=scan_info['num_image'],
+                    raise_error=True)
+        return(stack_info)
+
+    @classmethod
+    def construct_from_nxcollection(cls, nxcollection:NXcollection):
+        config = {}
+        config['spec_file'] = nxcollection.attrs['spec_file']
+        scan_numbers = []
+        stack_info = []
+        for nxsubentry_name, nxsubentry in nxcollection.items():
+            scan_number = int(nxsubentry_name.split('_')[-1])
+            scan_numbers.append(scan_number)
+            stack_info.append({'scan_number': scan_number,
+                    'starting_image_offset': int(nxsubentry.instrument.detector.frame_start_number),
+                    'num_image': len(nxsubentry.sample.rotation_angle),
+                    'ref_x': float(nxsubentry.sample.x_translation),
+                    'ref_z': float(nxsubentry.sample.z_translation)})
+        config['scan_numbers'] = sorted(scan_numbers)
+        config['stack_info'] = stack_info
+        return(cls(**config))
+
+    @property
+    def available_scan_numbers(self):
+        return(get_available_scan_numbers(self.spec_file))
+
+    def set_from_nxcollection(self, nxcollection:NXcollection):
+        self.spec_file = nxcollection.attrs['spec_file']
+        scan_numbers = []
+        stack_info = []
+        for nxsubentry_name, nxsubentry in nxcollection.items():
+            scan_number = int(nxsubentry_name.split('_')[-1])
+            scan_numbers.append(scan_number)
+            stack_info.append({'scan_number': scan_number,
+                    'starting_image_offset': int(nxsubentry.instrument.detector.frame_start_number),
+                    'num_image': len(nxsubentry.sample.rotation_angle),
+                    'ref_x': float(nxsubentry.sample.x_translation),
+                    'ref_z': float(nxsubentry.sample.z_translation)})
+        self.scan_numbers = sorted(scan_numbers)
+        self.stack_info = stack_info
+
+    def get_scan_index(self, scan_number):
+        scan_index = [scan_index for scan_index, scan_info in enumerate(self.stack_info)
+                if scan_info['scan_number'] == scan_number]
+        if len(scan_index) > 1:
+            raise ValueError('Duplicate scan_numbers in image stack')
+        elif len(scan_index) == 1:
+            return(scan_index[0])
+        else:
+            return(None)
+
+    def get_scanparser(self, scan_number):
+        return(get_scanparser(self.spec_file, scan_number))
+
+    def get_detector_data(self, detector_prefix, scan_number=None, scan_step_index=None):
+        image_stacks = []
+        if scan_number is None:
+            scan_numbers = self.scan_numbers
+        else:
+            scan_numbers = [scan_number]
+        for scan_number in scan_numbers:
+            parser = self.get_scanparser(scan_number)
+            scan_info = self.stack_info[self.get_scan_index(scan_number)]
+            image_offset = scan_info['starting_image_offset']
+            if scan_step_index is None:
+                num_image = scan_info['num_image']
+                image_stacks.append(parser.get_detector_data(detector_prefix,
+                        (image_offset, image_offset+num_image)))
+            else:
+                image_stacks.append(parser.get_detector_data(detector_prefix,
+                        image_offset+scan_step_index))
+        if scan_number is not None and scan_step_index is not None:
+            # Return a single image for a specific scan_number and scan_step_index request
+            return(image_stacks[0])
+        else:
+            # Return a list otherwise
+            return(image_stacks)
+
+    def scan_numbers_cli(self, attr_desc, **kwargs):
+        available_scan_numbers = self.available_scan_numbers
+        station = kwargs.get('station')
+        if (station is not None and station in ('id1a3', 'id3a') and
+                'scan_type' in kwargs):
+            scan_type = kwargs['scan_type']
+            if scan_type == 'ts1':
+                available_scan_numbers = []
+                for scan_number in self.available_scan_numbers:
+                    parser = self.get_scanparser(scan_number)
+                    try:
+                        if parser.scan_type == scan_type:
+                            available_scan_numbers.append(scan_number)
+                    except:
+                        pass
+            elif scan_type == 'df1':
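+                # Assumed station convention: a dark-field (df1) scan number is its tomo scan number minus 2 (bright field bf1: minus 1)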
+                tomo_scan_numbers = kwargs['tomo_scan_numbers']
+                available_scan_numbers = []
+                for scan_number in tomo_scan_numbers:
+                    parser = self.get_scanparser(scan_number-2)
+                    assert(parser.scan_type == scan_type)
+                    available_scan_numbers.append(scan_number-2)
+            elif scan_type == 'bf1':
+                tomo_scan_numbers = kwargs['tomo_scan_numbers']
+                available_scan_numbers = []
+                for scan_number in tomo_scan_numbers:
+                    parser = self.get_scanparser(scan_number-1)
+                    assert(parser.scan_type == scan_type)
+                    available_scan_numbers.append(scan_number-1)
+        if len(available_scan_numbers) == 1:
+            input_mode = 1
+        else:
+            if hasattr(self, 'scan_numbers'):
+                print(f'Currently selected {attr_desc}scan numbers are: {self.scan_numbers}')
+                menu_options = [f'Select a subset of the available {attr_desc}scan numbers',
+                        f'Use all available {attr_desc}scan numbers in {self.spec_file}',
+                        f'Keep the currently selected {attr_desc}scan numbers']
+            else:
+                menu_options = [f'Select a subset of the available {attr_desc}scan numbers',
+                        f'Use all available {attr_desc}scan numbers in {self.spec_file}']
+            print(f'Available scan numbers in {self.spec_file} are: '+
+                    f'{available_scan_numbers}')
+            input_mode = input_menu(menu_options, header='Choose one of the following options '+
+                    'for selecting scan numbers')
+        if input_mode == 0:
+            accept_scan_numbers = False
+            while not accept_scan_numbers:
+                try:
+                    self.scan_numbers = \
+                            input_int_list(f'Enter a series of {attr_desc}scan numbers')
+                except ValidationError as e:
+                    print(e)
+                except KeyboardInterrupt as e:
+                    raise e
+                except BaseException as e:
+                    print(f'Unexpected {type(e).__name__}: {e}')
+                else:
+                    accept_scan_numbers = True
+        elif input_mode == 1:
+            self.scan_numbers = available_scan_numbers
+        elif input_mode == 2:
+            pass
+
+    def cli(self, **cli_kwargs):
+        if cli_kwargs.get('attr_desc') is not None:
+            attr_desc = f'{cli_kwargs["attr_desc"]} '
+        else:
+            attr_desc = ''
+        print(f'\n -- Configure which scans to use from a single {attr_desc}SPEC file')
+        self.set_single_attr_cli('spec_file', attr_desc+'SPEC file path')
+        self.scan_numbers_cli(attr_desc)
+
+    def construct_nxcollection(self, image_key, thetas, detector):
+        nxcollection = NXcollection()
+        nxcollection.attrs['spec_file'] = str(self.spec_file)
+        parser = self.get_scanparser(self.scan_numbers[0])
+        nxcollection.attrs['date'] = parser.spec_scan.file_date
+        for scan_number in self.scan_numbers:
+            # Get scan info
+            scan_info = self.stack_info[self.get_scan_index(scan_number)]
+            # Add an NXsubentry to the NXcollection for each scan
+            entry_name = f'scan_{scan_number}'
+            nxsubentry = NXsubentry()
+            nxcollection[entry_name] = nxsubentry
+            parser = self.get_scanparser(scan_number)
+            nxsubentry.start_time = parser.spec_scan.date
+            nxsubentry.spec_command = parser.spec_command
+            # Add an NXdata for independent dimensions to the scan's NXsubentry
+            num_image = scan_info['num_image']
+            if thetas is None:
+                thetas = num_image*[0.0]
+            else:
+                assert(num_image == len(thetas))
+#            nxsubentry.independent_dimensions = NXdata()
+#            nxsubentry.independent_dimensions.rotation_angle = thetas
+#            nxsubentry.independent_dimensions.rotation_angle.units = 'degrees'
+            # Add an NXinstrument to the scan's NXsubentry
+            nxsubentry.instrument = NXinstrument()
+            # Add an NXdetector to the NXinstrument to the scan's NXsubentry
+            nxsubentry.instrument.detector = detector.construct_nxdetector()
+            nxsubentry.instrument.detector.frame_start_number = scan_info['starting_image_offset']
+            nxsubentry.instrument.detector.image_key = image_key
+            # Add an NXsample to the scan's NXsubentry
+            nxsubentry.sample = NXsample()
+            nxsubentry.sample.rotation_angle = thetas
+            nxsubentry.sample.rotation_angle.units = 'degrees'
+            nxsubentry.sample.x_translation = scan_info['ref_x']
+            nxsubentry.sample.x_translation.units = 'mm'
+            nxsubentry.sample.z_translation = scan_info['ref_z']
+            nxsubentry.sample.z_translation.units = 'mm'
+        return(nxcollection)
+
+
+class FlatField(SpecScans):
+
+    def image_range_cli(self, attr_desc, detector_prefix):
+        stack_info = self.stack_info
+        for scan_number in self.scan_numbers:
+            # Parse the available image range
+            parser = self.get_scanparser(scan_number)
+            image_offset = parser.starting_image_offset
+            num_image = parser.get_num_image(detector_prefix.upper())
+            scan_index = self.get_scan_index(scan_number)
+
+            # Select the image set
+            last_image_index = image_offset+num_image
+            print(f'Available good image set index range: [{image_offset}, {last_image_index})')
+            image_set_approved = False
+            if scan_index is not None:
+                scan_info = stack_info[scan_index]
+                print(f'Current starting image offset and number of images: '+
+                        f'{scan_info["starting_image_offset"]} and {scan_info["num_image"]}')
+                image_set_approved = input_yesno(f'Accept these values (y/n)?', 'y')
+            if not image_set_approved:
+                print(f'Default starting image offset and number of images: '+
+                        f'{image_offset} and {num_image}')
+                image_set_approved = input_yesno(f'Accept these values (y/n)?', 'y')
+                if image_set_approved:
+                    offset = image_offset
+                    num = last_image_index-offset
+                while not image_set_approved:
+                    offset = input_int(f'Enter the starting image offset', ge=image_offset,
+                            lt=last_image_index)#, default=image_offset)
+                    num = input_int(f'Enter the number of images', ge=1,
+                            le=last_image_index-offset)#, default=last_image_index-offset)
+                    print(f'Current starting image offset and number of images: {offset} and {num}')
+                    image_set_approved = input_yesno(f'Accept these values (y/n)?', 'y')
+                if scan_index is not None:
+                    scan_info['starting_image_offset'] = offset
+                    scan_info['num_image'] = num
+                    scan_info['ref_x'] = parser.horizontal_shift
+                    scan_info['ref_z'] = parser.vertical_shift
+                else:
+                    stack_info.append({'scan_number': scan_number, 'starting_image_offset': offset,
+                            'num_image': num, 'ref_x': parser.horizontal_shift,
+                            'ref_z': parser.vertical_shift})
+        self.stack_info = stack_info
+
+    def cli(self, **cli_kwargs):
+        if cli_kwargs.get('attr_desc') is not None:
+            attr_desc = f'{cli_kwargs["attr_desc"]} '
+        else:
+            attr_desc = ''
+        station = cli_kwargs.get('station')
+        detector = cli_kwargs.get('detector')
+        print(f'\n -- Configure the location of the {attr_desc}scan data -- ')
+        if station in ('id1a3', 'id3a'):
+            self.spec_file = cli_kwargs['spec_file']
+            tomo_scan_numbers = cli_kwargs['tomo_scan_numbers']
+            scan_type = cli_kwargs['scan_type']
+            self.scan_numbers_cli(attr_desc, station=station, tomo_scan_numbers=tomo_scan_numbers,
+                    scan_type=scan_type)
+        else:
+            self.set_single_attr_cli('spec_file', attr_desc+'SPEC file path')
+            self.scan_numbers_cli(attr_desc)
+        self.image_range_cli(attr_desc, detector.prefix)
+
+
+class TomoField(SpecScans):
+    theta_range: dict = {}
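+    # Expected keys: 'start', 'end', 'num', and optionally 'start_index' (see the validator below)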
+
+    @validator('theta_range')
+    def validate_theta_range(cls, theta_range):
+        if len(theta_range) != 3 and len(theta_range) != 4:
+            raise ValueError(f'Invalid theta range {theta_range}')
+        is_num(theta_range['start'], raise_error=True)
+        is_num(theta_range['end'], raise_error=True)
+        is_int(theta_range['num'], gt=1, raise_error=True)
+        if theta_range['end'] <= theta_range['start']:
+            raise ValueError(f'Invalid theta range {theta_range}')
+        if 'start_index' in theta_range:
+            is_int(theta_range['start_index'], ge=0, raise_error=True)
+        return(theta_range)
+
+    @classmethod
+    def construct_from_nxcollection(cls, nxcollection:NXcollection):
+        #RV Can I derive this from the same classfunction for SpecScans by adding theta_range
+        config = {}
+        config['spec_file'] = nxcollection.attrs['spec_file']
+        scan_numbers = []
+        stack_info = []
+        for nxsubentry_name, nxsubentry in nxcollection.items():
+            scan_number = int(nxsubentry_name.split('_')[-1])
+            scan_numbers.append(scan_number)
+            stack_info.append({'scan_number': scan_number,
+                    'starting_image_offset': int(nxsubentry.instrument.detector.frame_start_number),
+                    'num_image': len(nxsubentry.sample.rotation_angle),
+                    'ref_x': float(nxsubentry.sample.x_translation),
+                    'ref_z': float(nxsubentry.sample.z_translation)})
+        config['scan_numbers'] = sorted(scan_numbers)
+        config['stack_info'] = stack_info
+        for name in nxcollection.entries:
+            if 'scan_' in name:
+                thetas = np.asarray(nxcollection[name].sample.rotation_angle)
+                config['theta_range'] = {'start': thetas[0], 'end': thetas[-1], 'num': thetas.size}
+                break
+        return(cls(**config))
+
+    def get_horizontal_shifts(self, scan_number=None):
+        horizontal_shifts = []
+        if scan_number is None:
+            scan_numbers = self.scan_numbers
+        else:
+            scan_numbers = [scan_number]
+        for scan_number in scan_numbers:
+            parser = self.get_scanparser(scan_number)
+            horizontal_shifts.append(parser.horizontal_shift)
+        if len(horizontal_shifts) == 1:
+            return(horizontal_shifts[0])
+        else:
+            return(horizontal_shifts)
+
+    def get_vertical_shifts(self, scan_number=None):
+        vertical_shifts = []
+        if scan_number is None:
+            scan_numbers = self.scan_numbers
+        else:
+            scan_numbers = [scan_number]
+        for scan_number in scan_numbers:
+            parser = self.get_scanparser(scan_number)
+            vertical_shifts.append(parser.vertical_shift)
+        if len(vertical_shifts) == 1:
+            return(vertical_shifts[0])
+        else:
+            return(vertical_shifts)
+
+    def theta_range_cli(self, scan_number, attr_desc, station):
+        # Parse the available theta range
+        parser = self.get_scanparser(scan_number)
+        theta_vals = parser.theta_vals
+        spec_theta_start = theta_vals.get('start')
+        spec_theta_end = theta_vals.get('end')
+        spec_num_theta = theta_vals.get('num')
+
+        # Check for consistency of theta ranges between scans
+        if scan_number != self.scan_numbers[0]:
+            parser = self.get_scanparser(self.scan_numbers[0])
+            if (parser.theta_vals.get('start') != spec_theta_start or
+                    parser.theta_vals.get('end') != spec_theta_end or
+                    parser.theta_vals.get('num') != spec_num_theta):
+                raise ValueError(f'Incompatible theta ranges between {attr_desc}scans:'+
+                        f'\n\tScan {scan_number}: {theta_vals}'+
+                        f'\n\tScan {self.scan_numbers[0]}: {parser.theta_vals}')
+            return
+
+        # Select the theta range for the tomo reconstruction from the first scan
+        theta_range_approved = False
+        thetas = np.linspace(spec_theta_start, spec_theta_end, spec_num_theta)
+        delta_theta = thetas[1]-thetas[0]
+        print(f'Theta range obtained from SPEC data: [{spec_theta_start}, {spec_theta_end}]')
+        print(f'Theta step size = {delta_theta}')
+        print(f'Number of theta values: {spec_num_theta}')
+        default_start = None
+        default_end = None
+        if station in ('id1a3', 'id3a'):
+            theta_range_approved = input_yesno(f'Accept this theta range (y/n)?', 'y')
+            if theta_range_approved:
+                self.theta_range = {'start': float(spec_theta_start), 'end': float(spec_theta_end),
+                        'num': int(spec_num_theta), 'start_index': 0}
+                return
+        elif station == 'id3b':
+            if spec_theta_start <= 0.0 and spec_theta_end >= 180.0:
+                default_start = 0
+                default_end = 180
+            elif spec_theta_end-spec_theta_start == 180:
+                default_start = spec_theta_start
+                default_end = spec_theta_end
+        while not theta_range_approved:
+            theta_start = input_num(f'Enter the first theta (included)', ge=spec_theta_start,
+                    lt=spec_theta_end, default=default_start)
+            theta_index_start = index_nearest(thetas, theta_start)
+            theta_start = thetas[theta_index_start]
+            theta_end = input_num(f'Enter the last theta (excluded)',
+                    ge=theta_start+delta_theta, le=spec_theta_end, default=default_end)
+            theta_index_end = index_nearest(thetas, theta_end)
+            theta_end = thetas[theta_index_end]
+            num_theta = theta_index_end-theta_index_start
+            print(f'Selected theta range: [{theta_start}, {theta_start+delta_theta}, ..., '+
+                    f'{theta_end})')
+            print(f'Number of theta values: {num_theta}')
+            theta_range_approved = input_yesno(f'Accept this theta range (y/n)?', 'y')
+        self.theta_range = {'start': float(theta_start), 'end': float(theta_end),
+                'num': int(num_theta), 'start_index': int(theta_index_start)}
+
+    def image_range_cli(self, attr_desc, detector_prefix):
+        stack_info = self.stack_info
+        for scan_number in self.scan_numbers:
+            # Parse the available image range
+            parser = self.get_scanparser(scan_number)
+            image_offset = parser.starting_image_offset
+            num_image = parser.get_num_image(detector_prefix.upper())
+            scan_index = self.get_scan_index(scan_number)
+
+            # Select the image set matching the theta range
+            num_theta = self.theta_range['num']
+            theta_index_start = self.theta_range['start_index']
+            if num_theta > num_image-theta_index_start:
+                raise ValueError(f'Available {attr_desc}image indices incompatible with thetas:'+
+                        f'\n\tNumber of thetas and offset = {num_theta} and {theta_index_start}'+
+                        f'\n\tNumber of available images {num_image}')
+            if scan_index is not None:
+                scan_info = stack_info[scan_index]
+                scan_info['starting_image_offset'] = image_offset+theta_index_start
+                scan_info['num_image'] = num_theta
+                scan_info['ref_x'] = parser.horizontal_shift
+                scan_info['ref_z'] = parser.vertical_shift
+            else:
+                stack_info.append({'scan_number': scan_number,
+                        'starting_image_offset': image_offset+theta_index_start,
+                        'num_image': num_theta, 'ref_x': parser.horizontal_shift,
+                        'ref_z': parser.vertical_shift})
+        self.stack_info = stack_info
+
+    def cli(self, **cli_kwargs):
+        if cli_kwargs.get('attr_desc') is not None:
+            attr_desc = f'{cli_kwargs["attr_desc"]} '
+        else:
+            attr_desc = ''
+        cycle = cli_kwargs.get('cycle')
+        btr = cli_kwargs.get('btr')
+        station = cli_kwargs.get('station')
+        detector = cli_kwargs.get('detector')
+        sample_name = cli_kwargs.get('sample_name')
+        print(f'\n -- Configure the location of the {attr_desc}scan data -- ')
+        if station in ('id1a3', 'id3a'):
+            basedir = f'/nfs/chess/{station}/{cycle}/{btr}'
+            runs = [d for d in os.listdir(basedir) if os.path.isdir(os.path.join(basedir, d))]
+#RV            index = 15-1
+#RV            index = 7-1
+            if sample_name is not None and sample_name in runs:
+                index = runs.index(sample_name)
+            else:
+                index = input_menu(runs, header='Choose a sample directory')
+            self.spec_file = f'{basedir}/{runs[index]}/spec.log'
+            self.scan_numbers_cli(attr_desc, station=station, scan_type='ts1')
+        else:
+            self.set_single_attr_cli('spec_file', attr_desc+'SPEC file path')
+            self.scan_numbers_cli(attr_desc)
+        for scan_number in self.scan_numbers:
+            self.theta_range_cli(scan_number, attr_desc, station)
+        self.image_range_cli(attr_desc, detector.prefix)
+
+
+class Sample(BaseModel):
+    name: constr(min_length=1)
+    description: Optional[str]
+    rotation_angles: Optional[list]
+    x_translations: Optional[list]
+    z_translations: Optional[list]
+
+    @classmethod
+    def construct_from_nxsample(cls, nxsample:NXsample):
+        config = {}
+        config['name'] = nxsample.name.nxdata
+        if 'description' in nxsample:
+            config['description'] = nxsample.description.nxdata
+        if 'rotation_angle' in nxsample:
+            config['rotation_angles'] = nxsample.rotation_angle.nxdata
+        if 'x_translation' in nxsample:
+            config['x_translations'] = nxsample.x_translation.nxdata
+        if 'z_translation' in nxsample:
+            config['z_translations'] = nxsample.z_translation.nxdata
+        return(cls(**config))
+
+    def cli(self):
+        print('\n -- Configure the sample metadata -- ')
+#RV        self.name = 'sobhani-3249-A'
+#RV        self.name = 'tenstom_1304r-1'
+        self.set_single_attr_cli('name', 'the sample name')
+#RV        self.description = 'test sample'
+        self.set_single_attr_cli('description', 'a description of the sample (optional)')
+
+
+class MapConfig(BaseModel):
+    cycle: constr(strip_whitespace=True, min_length=1)
+    btr: constr(strip_whitespace=True, min_length=1)
+    title: constr(strip_whitespace=True, min_length=1)
+    station: Optional[Literal['id1a3', 'id3a', 'id3b']] = None
+    sample: Sample
+    detector: Detector = Detector.construct()
+    tomo_fields: TomoField
+    dark_field: Optional[FlatField]
+    bright_field: FlatField
+    _thetas: list[float] = PrivateAttr()
+    _field_types = ({'name': 'dark_field', 'image_key': 2}, {'name': 'bright_field',
+            'image_key': 1}, {'name': 'tomo_fields', 'image_key': 0})
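+    # Note: the image_key values above follow the NXtomo convention:
+    # 0 = projection (tomo), 1 = flat/bright field, 2 = dark field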
+
+    @classmethod
+    def construct_from_nxentry(cls, nxentry:NXentry):
+        config = {}
+        config['cycle'] = nxentry.instrument.source.attrs['cycle']
+        config['btr'] = nxentry.instrument.source.attrs['btr']
+        config['title'] = nxentry.nxname
+        config['station'] = nxentry.instrument.source.attrs['station']
+        config['sample'] = Sample.construct_from_nxsample(nxentry['sample'])
+        for nxobject_name, nxobject in nxentry.spec_scans.items():
+            if isinstance(nxobject, NXcollection):
+                config[nxobject_name] = SpecScans.construct_from_nxcollection(nxobject)
+        return(cls(**config))
+
+#FIX cache?
+    @property
+    def thetas(self):
+        try:
+            return(self._thetas)
+        except AttributeError:
+            # Compute and cache the theta list (in degrees) on first access
+            theta_range = self.tomo_fields.theta_range
+            self._thetas = list(np.linspace(theta_range['start'], theta_range['end'],
+                    theta_range['num']))
+            return(self._thetas)
+
+    def cli(self):
+        print('\n -- Configure a map from a set of SPEC scans (dark, bright, and tomo), '+
+                'and/or detector data -- ')
+#RV        self.cycle = '2021-3'
+#RV        self.cycle = '2022-2'
+#RV        self.cycle = '2023-1'
+        self.set_single_attr_cli('cycle', 'beam cycle')
+#RV        self.btr = 'z-3234-A'
+#RV        self.btr = 'sobhani-3249-A'
+#RV        self.btr = 'przybyla-3606-a'
+        self.set_single_attr_cli('btr', 'BTR')
+#RV        self.title = 'z-3234-A'
+#RV        self.title = 'tomo7C'
+#RV        self.title = 'cmc-test-dwell-1'
+        self.set_single_attr_cli('title', 'title for the map entry')
+#RV        self.station = 'id3a'
+#RV        self.station = 'id3b'
+#RV        self.station = 'id1a3'
+        self.set_single_attr_cli('station', 'name of the station at which scans were collected '+
+                '(currently choose from: id1a3, id3a, id3b)')
+        import_scanparser(self.station)
+        self.set_single_attr_cli('sample')
+        use_detector_config = False
+        if hasattr(self.detector, 'prefix') and len(self.detector.prefix):
+            use_detector_config = input_yesno(f'Current detector settings:\n{self.detector}\n'+
+                    'Keep these settings? (y/n)')
+        if not use_detector_config:
+            menu_options = ['not listed', 'andor2', 'manta', 'retiga']
+            input_mode = input_menu(menu_options, header='Choose one of the following detector '+
+                    'configuration options')
+            if input_mode:
+                detector_config_file = f'{menu_options[input_mode]}.yaml'
+                have_detector_config = self.detector.construct_from_yaml(detector_config_file)
+            else:
+                have_detector_config = False
+            if not have_detector_config:
+                self.set_single_attr_cli('detector', 'detector')
+        self.set_single_attr_cli('tomo_fields', 'Tomo field', chain_attr_desc=True,
+                cycle=self.cycle, btr=self.btr, station=self.station, detector=self.detector,
+                sample_name=self.sample.name)
+        if self.station in ('id1a3', 'id3a'):
+            have_dark_field = True
+            tomo_spec_file = self.tomo_fields.spec_file
+        else:
+            have_dark_field = input_yesno('Are dark field images available? (y/n)')
+            tomo_spec_file = None
+        if have_dark_field:
+            self.set_single_attr_cli('dark_field', 'Dark field', chain_attr_desc=True,
+                    station=self.station, detector=self.detector, spec_file=tomo_spec_file,
+                    tomo_scan_numbers=self.tomo_fields.scan_numbers, scan_type='df1')
+        self.set_single_attr_cli('bright_field', 'Bright field', chain_attr_desc=True,
+                station=self.station, detector=self.detector, spec_file=tomo_spec_file,
+                tomo_scan_numbers=self.tomo_fields.scan_numbers, scan_type='bf1')
+
+    def construct_nxentry(self, nxroot, include_raw_data=True):
+        # Construct base NXentry
+        nxentry = NXentry()
+
+        # Add an NXentry to the NXroot
+        nxroot[self.title] = nxentry
+        nxroot.attrs['default'] = self.title
+        nxentry.definition = 'NXtomo'
+#        nxentry.attrs['default'] = 'data'
+
+        # Add an NXinstrument to the NXentry
+        nxinstrument = NXinstrument()
+        nxentry.instrument = nxinstrument
+
+        # Add an NXsource to the NXinstrument
+        nxsource = NXsource()
+        nxinstrument.source = nxsource
+        nxsource.type = 'Synchrotron X-ray Source'
+        nxsource.name = 'CHESS'
+        nxsource.probe = 'x-ray'
+ 
+        # Tag the NXsource with the runinfo (as an attribute)
+        nxsource.attrs['cycle'] = self.cycle
+        nxsource.attrs['btr'] = self.btr
+        nxsource.attrs['station'] = self.station
+
+        # Add an NXdetector to the NXinstrument (don't fill in data fields yet)
+        nxinstrument.detector = self.detector.construct_nxdetector()
+
+        # Add an NXsample to NXentry (don't fill in data fields yet)
+        nxsample = NXsample()
+        nxentry.sample = nxsample
+        nxsample.name = self.sample.name
+        nxsample.description = self.sample.description
+
+        # Add an NXcollection to the base NXentry to hold metadata about the spec scans in the map
+        # Also obtain the data fields in NXsample and NXdetector
+        nxspec_scans = NXcollection()
+        nxentry.spec_scans = nxspec_scans
+        image_keys = []
+        sequence_numbers = []
+        image_stacks = []
+        rotation_angles = []
+        x_translations = []
+        z_translations = []
+        for field_type in self._field_types:
+            field_name = field_type['name']
+            field = getattr(self, field_name)
+            if field is None:
+                continue
+            image_key = field_type['image_key']
+            if field_type['name'] == 'tomo_fields':
+                thetas = self.thetas
+            else:
+                thetas = None
+            # Add the scans in a single spec file
+            nxspec_scans[field_name] = field.construct_nxcollection(image_key, thetas,
+                    self.detector)
+            if include_raw_data:
+                image_stacks += field.get_detector_data(self.detector.prefix)
+                for scan_number in field.scan_numbers:
+                    parser = field.get_scanparser(scan_number)
+                    scan_info = field.stack_info[field.get_scan_index(scan_number)]
+                    num_image = scan_info['num_image']
+                    image_keys += num_image*[image_key]
+                    sequence_numbers += list(range(num_image))
+                    if thetas is None:
+                        rotation_angles += num_image*[0.0]
+                    else:
+                        assert(num_image == len(thetas))
+                        rotation_angles += thetas
+                    x_translations += num_image*[scan_info['ref_x']]
+                    z_translations += num_image*[scan_info['ref_z']]
+
+        if include_raw_data:
+            # Add image data to NXdetector
+            nxinstrument.detector.image_key = image_keys
+            nxinstrument.detector.sequence_number = sequence_numbers
+            nxinstrument.detector.data = np.concatenate(image_stacks)
+
+            # Add image data to NXsample
+            nxsample.rotation_angle = rotation_angles
+            nxsample.rotation_angle.attrs['units'] = 'degrees'
+            nxsample.x_translation = x_translations
+            nxsample.x_translation.attrs['units'] = 'mm'
+            nxsample.z_translation = z_translations
+            nxsample.z_translation.attrs['units'] = 'mm'
+
+            # Add an NXdata to NXentry
+            nxdata = NXdata()
+            nxentry.data = nxdata
+            nxdata.makelink(nxentry.instrument.detector.data, name='data')
+            nxdata.makelink(nxentry.instrument.detector.image_key)
+            nxdata.makelink(nxentry.sample.rotation_angle)
+            nxdata.makelink(nxentry.sample.x_translation)
+            nxdata.makelink(nxentry.sample.z_translation)
+#            nxdata.attrs['axes'] = ['field', 'row', 'column']
+#            nxdata.attrs['field_indices'] = 0
+#            nxdata.attrs['row_indices'] = 1
+#            nxdata.attrs['column_indices'] = 2
+
+
+class TomoWorkflow(BaseModel):
+    sample_maps: conlist(item_type=MapConfig, min_items=1) = [MapConfig.construct()]
+
+    @classmethod
+    def construct_from_nexus(cls, filename):
+        nxroot = nxload(filename)
+        sample_maps = []
+        config = {'sample_maps': sample_maps}
+        for nxentry_name, nxentry in nxroot.items():
+            sample_maps.append(MapConfig.construct_from_nxentry(nxentry))
+        return(cls(**config))
+
+    def cli(self):
+        print('\n -- Configure a map -- ')
+        self.set_list_attr_cli('sample_maps', 'sample map')
+
+    def construct_nxfile(self, filename, mode='w-'):
+        nxroot = NXroot()
+        t0 = time()
+        for sample_map in self.sample_maps:
+            logger.info(f'Start constructing the {sample_map.title} map.')
+            import_scanparser(sample_map.station)
+            sample_map.construct_nxentry(nxroot)
+        logger.info(f'Constructed all sample maps in {time()-t0:.2f} seconds.')
+        logger.info(f'Start saving all sample maps to {filename}.')
+        nxroot.save(filename, mode=mode)
+
+    def write_to_nexus(self, filename):
+        t0 = time()
+        self.construct_nxfile(filename, mode='w')
+        logger.info(f'Saved all sample maps to {filename} in {time()-t0:.2f} seconds.')
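+
+# Hedged usage sketch (file names are hypothetical): load previously saved
+# sample maps from a NeXus file and write them back out to a new file.
+#
+#     wf = TomoWorkflow.construct_from_nexus('map.nxs')
+#     wf.write_to_nexus('map_copy.nxs')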
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/workflow/run_tomo.py	Mon Mar 20 18:44:23 2023 +0000
@@ -0,0 +1,1629 @@
+#!/usr/bin/env python3
+
+import logging
+logger = logging.getLogger(__name__)
+
+import numpy as np
+try:
+    import numexpr as ne
+except:
+    pass
+try:
+    import scipy.ndimage as spi
+except:
+    pass
+
+from multiprocessing import cpu_count
+from nexusformat.nexus import *
+from os import mkdir
+from os import path as os_path
+try:
+    from skimage.transform import iradon
+except:
+    pass
+try:
+    from skimage.restoration import denoise_tv_chambolle
+except:
+    pass
+from time import time
+try:
+    import tomopy
+except:
+    pass
+from yaml import safe_load, safe_dump
+
+try:
+    from msnctools.fit import Fit
+except:
+    from fit import Fit
+try:
+    from msnctools.general import illegal_value, is_int, is_int_pair, is_num, is_index_range, \
+            input_int, input_num, input_yesno, input_menu, draw_mask_1d, select_image_bounds, \
+            select_one_image_bound, clear_imshow, quick_imshow, clear_plot, quick_plot
+except:
+    from general import illegal_value, is_int, is_int_pair, is_num, is_index_range, \
+            input_int, input_num, input_yesno, input_menu, draw_mask_1d, select_image_bounds, \
+            select_one_image_bound, clear_imshow, quick_imshow, clear_plot, quick_plot
+
+try:
+    from workflow.models import import_scanparser, FlatField, TomoField, TomoWorkflow
+    from workflow.__version__ import __version__
+except:
+    pass
+
+num_core_tomopy_limit = 24
+
+def nxcopy(nxobject:NXobject, exclude_nxpaths:list[str]=[], nxpath_prefix:str='') -> NXobject:
+    '''Function that returns a copy of a nexus object, optionally excluding certain child items.
+
+    :param nxobject: the original nexus object to return a "copy" of
+    :type nxobject: nexusformat.nexus.NXobject
+    :param exclude_nxpaths: a list of paths to child nexus objects that
+        should be excluded from the returned "copy", defaults to `[]`
+    :type exclude_nxpaths: list[str], optional
+    :param nxpath_prefix: For use in recursive calls from inside this
+        function only!
+    :type nxpath_prefix: str
+    :return: a copy of `nxobject` with some children optionally excluded.
+    :rtype: NXobject
+    '''
+
+    nxobject_copy = nxobject.__class__()
+    if not len(nxpath_prefix):
+        if 'default' in nxobject.attrs:
+            nxobject_copy.attrs['default'] = nxobject.attrs['default']
+    else:
+        for k, v in nxobject.attrs.items():
+            nxobject_copy.attrs[k] = v
+
+    for k, v in nxobject.items():
+        nxpath = os_path.join(nxpath_prefix, k)
+
+        if nxpath in exclude_nxpaths:
+            continue
+
+        if isinstance(v, NXgroup):
+            nxobject_copy[k] = nxcopy(v, exclude_nxpaths=exclude_nxpaths,
+                    nxpath_prefix=os_path.join(nxpath_prefix, k))
+        else:
+            nxobject_copy[k] = v
+
+    return(nxobject_copy)
+
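+# Hedged usage sketch for nxcopy (the path below is illustrative): copy a
+# NeXus object while omitting a large raw-data field.
+#
+#     nxroot_copy = nxcopy(nxroot, exclude_nxpaths=['entry/data/data'])
+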
+class set_numexpr_threads:
+
+    def __init__(self, num_core):
+        if num_core is None or num_core < 1 or num_core > cpu_count():
+            self.num_core = cpu_count()
+        else:
+            self.num_core = num_core
+
+    def __enter__(self):
+        self.num_core_org = ne.set_num_threads(self.num_core)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        ne.set_num_threads(self.num_core_org)
+
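+# Hedged usage sketch: cap the numexpr thread pool within a block and restore
+# the previous setting on exit (the thread count and expression are illustrative):
+#
+#     with set_numexpr_threads(4):
+#         tnf = ne.evaluate('tomo_stack-tdf')   # runs on at most 4 threads
+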
+class Tomo:
+    """Processing tomography data with misalignment.
+    """
+    def __init__(self, galaxy_flag=False, num_core=-1, output_folder='.', save_figs=None,
+            test_mode=False):
+        """Initialize with optional config input file or dictionary
+        """
+        if not isinstance(galaxy_flag, bool):
+            raise ValueError(f'Invalid parameter galaxy_flag ({galaxy_flag})')
+        self.galaxy_flag = galaxy_flag
+        self.num_core = num_core
+        if self.galaxy_flag:
+            if output_folder != '.':
+                logger.warning('Ignoring output_folder in galaxy mode')
+            self.output_folder = '.'
+            if test_mode:
+                logger.warning('Ignoring test_mode in galaxy mode')
+            self.test_mode = False
+            if save_figs is not None:
+                logger.warning('Ignoring save_figs in galaxy mode')
+            save_figs = 'only'
+        else:
+            self.output_folder = os_path.abspath(output_folder)
+            if not os_path.isdir(self.output_folder):
+                mkdir(self.output_folder)
+            if not isinstance(test_mode, bool):
+                raise ValueError(f'Invalid parameter test_mode ({test_mode})')
+            self.test_mode = test_mode
+            if save_figs is None:
+                save_figs = 'no'
+        self.test_config = {}
+        if self.test_mode:
+            if save_figs != 'only':
+                logger.warning('Ignoring save_figs in test mode')
+            save_figs = 'only'
+        if save_figs == 'only':
+            self.save_only = True
+            self.save_figs = True
+        elif save_figs == 'yes':
+            self.save_only = False
+            self.save_figs = True
+        elif save_figs == 'no':
+            self.save_only = False
+            self.save_figs = False
+        else:
+            raise ValueError(f'Invalid parameter save_figs ({save_figs})')
+        if self.save_only:
+            self.block = False
+        else:
+            self.block = True
+        if self.num_core == -1:
+            self.num_core = cpu_count()
+        if not is_int(self.num_core, gt=0, log=False):
+            raise ValueError(f'Invalid parameter num_core ({num_core})')
+        if self.num_core > cpu_count():
+            logger.warning(f'num_core = {self.num_core} is larger than the number of available '
+                    f'processors and reduced to {cpu_count()}')
+            self.num_core = cpu_count()
+
+    def read(self, filename):
+        extension = os_path.splitext(filename)[1]
+        if extension in ('.yml', '.yaml'):
+            with open(filename, 'r') as f:
+                config = safe_load(f)
+#            if len(config) > 1:
+#                raise ValueError(f'Multiple root entries in {filename} not yet implemented')
+#            if len(list(config.values())[0]) > 1:
+#                raise ValueError(f'Multiple sample maps in {filename} not yet implemented')
+            return(config)
+        elif extension == '.nxs':
+            with NXFile(filename, mode='r') as nxfile:
+                nxroot = nxfile.readfile()
+            return(nxroot)
+        else:
+            raise ValueError(f'Invalid filename extension ({extension})')
+
+    def write(self, data, filename):
+        extension = os_path.splitext(filename)[1]
+        if extension in ('.yml', '.yaml'):
+            with open(filename, 'w') as f:
+                safe_dump(data, f)
+        elif extension == '.nxs':
+            data.save(filename, mode='w')
+        elif extension == '.nc':
+            data.to_netcdf(path=filename)
+        else:
+            raise ValueError(f'Invalid filename extension ({extension})')
+
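+    # Hedged usage sketch for read/write (file names are hypothetical):
+    #
+    #     tomo = Tomo(num_core=4)
+    #     config = tomo.read('map.yaml')   # returns a dict from YAML
+    #     nxroot = tomo.read('data.nxs')   # returns an NXroot from NeXus
+    #     tomo.write(nxroot, 'copy.nxs')
+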
+    def gen_reduced_data(self, data, img_x_bounds=None):
+        """Generate the reduced tomography images.
+        """
+        logger.info('Generate the reduced tomography images')
+
+        # Create plot galaxy path directory if needed
+        if self.galaxy_flag and not os_path.exists('tomo_reduce_plots'):
+            mkdir('tomo_reduce_plots')
+
+        if isinstance(data, dict):
+            # Create Nexus format object from input dictionary
+            wf = TomoWorkflow(**data)
+            if len(wf.sample_maps) > 1:
+                raise ValueError('Multiple sample maps not yet implemented')
+#            print(f'\nwf:\n{wf}\n')
+            nxroot = NXroot()
+            t0 = time()
+            for sample_map in wf.sample_maps:
+                logger.info(f'Start constructing the {sample_map.title} map.')
+                import_scanparser(sample_map.station)
+                sample_map.construct_nxentry(nxroot, include_raw_data=False)
+            logger.info(f'Constructed all sample maps in {time()-t0:.2f} seconds.')
+            nxentry = nxroot[nxroot.attrs['default']]
+            # Get test mode configuration info
+            if self.test_mode:
+                self.test_config = data['sample_maps'][0]['test_mode']
+        elif isinstance(data, NXroot):
+            nxentry = data[data.attrs['default']]
+        else:
+            raise ValueError(f'Invalid parameter data ({data})')
+
+        # Create an NXprocess to store data reduction (meta)data
+        reduced_data = NXprocess()
+
+        # Generate dark field
+        if 'dark_field' in nxentry['spec_scans']:
+            reduced_data = self._gen_dark(nxentry, reduced_data)
+
+        # Generate bright field
+        reduced_data = self._gen_bright(nxentry, reduced_data)
+
+        # Set vertical detector bounds for image stack
+        img_x_bounds = self._set_detector_bounds(nxentry, reduced_data, img_x_bounds=img_x_bounds)
+        logger.info(f'img_x_bounds = {img_x_bounds}')
+        reduced_data['img_x_bounds'] = img_x_bounds
+
+        # Set zoom and/or theta skip to reduce the memory requirement
+        zoom_perc, num_theta_skip = self._set_zoom_or_skip()
+        if zoom_perc is not None:
+            reduced_data.attrs['zoom_perc'] = zoom_perc
+        if num_theta_skip is not None:
+            reduced_data.attrs['num_theta_skip'] = num_theta_skip
+
+        # Generate reduced tomography fields
+        reduced_data = self._gen_tomo(nxentry, reduced_data)
+
+        # Create a copy of the input Nexus object and remove raw and any existing reduced data
+        if isinstance(data, NXroot):
+            exclude_items = [f'{nxentry._name}/reduced_data/data',
+                    f'{nxentry._name}/instrument/detector/data',
+                    f'{nxentry._name}/instrument/detector/image_key',
+                    f'{nxentry._name}/instrument/detector/sequence_number',
+                    f'{nxentry._name}/sample/rotation_angle',
+                    f'{nxentry._name}/sample/x_translation',
+                    f'{nxentry._name}/sample/z_translation',
+                    f'{nxentry._name}/data/data',
+                    f'{nxentry._name}/data/image_key',
+                    f'{nxentry._name}/data/rotation_angle',
+                    f'{nxentry._name}/data/x_translation',
+                    f'{nxentry._name}/data/z_translation']
+            nxroot = nxcopy(data, exclude_nxpaths=exclude_items)
+            nxentry = nxroot[nxroot.attrs['default']]
+
+        # Add the reduced data NXprocess
+        nxentry.reduced_data = reduced_data
+
+        if 'data' not in nxentry:
+            nxentry.data = NXdata()
+        nxentry.attrs['default'] = 'data'
+        nxentry.data.makelink(nxentry.reduced_data.data.tomo_fields, name='reduced_data')
+        nxentry.data.makelink(nxentry.reduced_data.rotation_angle, name='rotation_angle')
+        nxentry.data.attrs['signal'] = 'reduced_data'
+ 
+        return(nxroot)
+
+    def find_centers(self, nxroot, center_rows=None, center_stack_index=None):
+        """Find the calibrated center axis info
+        """
+        logger.info('Find the calibrated center axis info')
+
+        if not isinstance(nxroot, NXroot):
+            raise ValueError(f'Invalid parameter nxroot ({nxroot})')
+        nxentry = nxroot[nxroot.attrs['default']]
+        if not isinstance(nxentry, NXentry):
+            raise ValueError(f'Invalid nxentry ({nxentry})')
+        if self.galaxy_flag:
+            if center_rows is not None:
+                center_rows = tuple(center_rows)
+                if not is_int_pair(center_rows):
+                    raise ValueError(f'Invalid parameter center_rows ({center_rows})')
+        elif center_rows is not None:
+            logger.warning(f'Ignoring parameter center_rows ({center_rows})')
+            center_rows = None
+        if self.galaxy_flag:
+            if center_stack_index is not None and not is_int(center_stack_index, ge=0):
+                raise ValueError(f'Invalid parameter center_stack_index ({center_stack_index})')
+        elif center_stack_index is not None:
+            logger.warning(f'Ignoring parameter center_stack_index ({center_stack_index})')
+            center_stack_index = None
+
+        # Create plot galaxy path directory and path if needed
+        if self.galaxy_flag:
+            if not os_path.exists('tomo_find_centers_plots'):
+                mkdir('tomo_find_centers_plots')
+            path = 'tomo_find_centers_plots'
+        else:
+            path = self.output_folder
+
+        # Check if reduced data is available
+        if ('reduced_data' not in nxentry or 'reduced_data' not in nxentry.data):
+            raise KeyError(f'Unable to find valid reduced data in {nxentry}.')
+
+        # Select the image stack to calibrate the center axis
+        #   reduced data axes order: stack,theta,row,column
+        #   Note: Nexus cannot follow a link if the data it points to is too big,
+        #         so get the data from the actual place, not from nxentry.data
+        tomo_fields_shape = nxentry.reduced_data.data.tomo_fields.shape
+        if len(tomo_fields_shape) != 4 or any(not dim for dim in tomo_fields_shape):
+            raise KeyError('Unable to load the required reduced tomography stack')
+        num_tomo_stacks = tomo_fields_shape[0]
+        if num_tomo_stacks == 1:
+            center_stack_index = 0
+            default = 'n'
+        else:
+            if self.test_mode:
+                center_stack_index = self.test_config['center_stack_index']-1 # make offset 0
+            elif self.galaxy_flag:
+                if center_stack_index is None:
+                    center_stack_index = int(num_tomo_stacks/2)
+                if center_stack_index >= num_tomo_stacks:
+                    raise ValueError(f'Invalid parameter center_stack_index ({center_stack_index})')
+            else:
+                center_stack_index = input_int('\nEnter tomography stack index to calibrate the '
+                        'center axis', ge=1, le=num_tomo_stacks, default=int(1+num_tomo_stacks/2))
+                center_stack_index -= 1
+            default = 'y'
+
+        # Get thetas (in degrees)
+        thetas = np.asarray(nxentry.reduced_data.rotation_angle)
+
+        # Get effective pixel_size
+        if 'zoom_perc' in nxentry.reduced_data:
+            eff_pixel_size = 100.*(nxentry.instrument.detector.x_pixel_size/
+                nxentry.reduced_data.attrs['zoom_perc'])
+        else:
+            eff_pixel_size = nxentry.instrument.detector.x_pixel_size
+
+        # Get cross sectional diameter
+        cross_sectional_dim = tomo_fields_shape[3]*eff_pixel_size
+        logger.debug(f'cross_sectional_dim = {cross_sectional_dim}')
+
+        # Determine center offset at sample row boundaries
+        logger.info('Determine center offset at sample row boundaries')
+
+        # Lower row center
+        if self.test_mode:
+            lower_row = self.test_config['lower_row']
+        elif self.galaxy_flag:
+            if center_rows is None:
+                lower_row = 0
+            else:
+                lower_row = min(center_rows)
+                if not 0 <= lower_row < tomo_fields_shape[2]-1:
+                    raise ValueError(f'Invalid parameter center_rows ({center_rows})')
+        else:
+            lower_row = select_one_image_bound(
+                    nxentry.reduced_data.data.tomo_fields[center_stack_index,0,:,:], 0, bound=0,
+                    title=f'theta={round(thetas[0], 2)+0}',
+                    bound_name='row index to find lower center', default=default, raise_error=True)
+        logger.debug('Finding center...')
+        t0 = time()
+        lower_center_offset = self._find_center_one_plane(
+                #np.asarray(nxentry.reduced_data.data.tomo_fields[center_stack_index,:,lower_row,:]),
+                nxentry.reduced_data.data.tomo_fields[center_stack_index,:,lower_row,:],
+                lower_row, thetas, eff_pixel_size, cross_sectional_dim, path=path,
+                num_core=self.num_core)
+        logger.debug(f'... done in {time()-t0:.2f} seconds')
+        logger.debug(f'lower_row = {lower_row:.2f}')
+        logger.debug(f'lower_center_offset = {lower_center_offset:.2f}')
+
+        # Upper row center
+        if self.test_mode:
+            upper_row = self.test_config['upper_row']
+        elif self.galaxy_flag:
+            if center_rows is None:
+                upper_row = tomo_fields_shape[2]-1
+            else:
+                upper_row = max(center_rows)
+                if not lower_row < upper_row < tomo_fields_shape[2]:
+                    raise ValueError(f'Invalid parameter center_rows ({center_rows})')
+        else:
+            upper_row = select_one_image_bound(
+                    nxentry.reduced_data.data.tomo_fields[center_stack_index,0,:,:], 0,
+                    bound=tomo_fields_shape[2]-1, title=f'theta={round(thetas[0], 2)+0}',
+                    bound_name='row index to find upper center', default=default, raise_error=True)
+        logger.debug('Finding center...')
+        t0 = time()
+        upper_center_offset = self._find_center_one_plane(
+                #np.asarray(nxentry.reduced_data.data.tomo_fields[center_stack_index,:,upper_row,:]),
+                nxentry.reduced_data.data.tomo_fields[center_stack_index,:,upper_row,:],
+                upper_row, thetas, eff_pixel_size, cross_sectional_dim, path=path,
+                num_core=self.num_core)
+        logger.debug(f'... done in {time()-t0:.2f} seconds')
+        logger.debug(f'upper_row = {upper_row:.2f}')
+        logger.debug(f'upper_center_offset = {upper_center_offset:.2f}')
+
+        center_config = {'lower_row': lower_row, 'lower_center_offset': lower_center_offset,
+                'upper_row': upper_row, 'upper_center_offset': upper_center_offset}
+        if num_tomo_stacks > 1:
+            center_config['center_stack_index'] = center_stack_index+1 # save as offset 1
+
+        # Save test data to file
+        if self.test_mode:
+            with open(f'{self.output_folder}/center_config.yaml', 'w') as f:
+                safe_dump(center_config, f)
+
+        return(center_config)
+
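+    # A hedged sketch of the center_config returned above (values are
+    # illustrative only):
+    #
+    #     {'lower_row': 10, 'lower_center_offset': -2.1,
+    #      'upper_row': 245, 'upper_center_offset': -1.7,
+    #      'center_stack_index': 2}   # present only for multiple stacks
+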
+    def reconstruct_data(self, nxroot, center_info, x_bounds=None, y_bounds=None):
+        """Reconstruct the tomography data.
+        """
+        logger.info('Reconstruct the tomography data')
+
+        if not isinstance(nxroot, NXroot):
+            raise ValueError(f'Invalid parameter nxroot ({nxroot})')
+        nxentry = nxroot[nxroot.attrs['default']]
+        if not isinstance(nxentry, NXentry):
+            raise ValueError(f'Invalid nxentry ({nxentry})')
+        if not isinstance(center_info, dict):
+            raise ValueError(f'Invalid parameter center_info ({center_info})')
+
+        # Create plot galaxy path directory and path if needed
+        if self.galaxy_flag:
+            if not os_path.exists('tomo_reconstruct_plots'):
+                mkdir('tomo_reconstruct_plots')
+            path = 'tomo_reconstruct_plots'
+        else:
+            path = self.output_folder
+
+        # Check if reduced data is available
+        if ('reduced_data' not in nxentry or 'reduced_data' not in nxentry.data):
+            raise KeyError(f'Unable to find valid reduced data in {nxentry}.')
+
+        # Create an NXprocess to store image reconstruction (meta)data
+        nxprocess = NXprocess()
+
+        # Get rotation axis rows and centers
+        lower_row = center_info.get('lower_row')
+        lower_center_offset = center_info.get('lower_center_offset')
+        upper_row = center_info.get('upper_row')
+        upper_center_offset = center_info.get('upper_center_offset')
+        if (lower_row is None or lower_center_offset is None or upper_row is None or
+                upper_center_offset is None):
+            raise KeyError(f'Unable to find valid calibrated center axis info in {center_info}.')
+        center_slope = (upper_center_offset-lower_center_offset)/(upper_row-lower_row)
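+        # The center offset at any other detector row then follows by linear
+        # interpolation: offset(row) = lower_center_offset
+        #                              + (row-lower_row)*center_slope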
+
+        # Get thetas (in degrees)
+        thetas = np.asarray(nxentry.reduced_data.rotation_angle)
+
+        # Reconstruct tomography data
+        #   reduced data axes order: stack,theta,row,column
+        #   reconstructed data order in each stack: row/z,x,y
+        #   Note: Nexus cannot follow a link if the data it points to is too big,
+        #         so get the data from the actual place, not from nxentry.data
+        if 'zoom_perc' in nxentry.reduced_data:
+            res_title = f'{nxentry.reduced_data.attrs["zoom_perc"]}p'
+        else:
+            res_title = 'fullres'
+        load_error = False
+        num_tomo_stacks = nxentry.reduced_data.data.tomo_fields.shape[0]
+        tomo_recon_stacks = num_tomo_stacks*[np.array([])]
+        for i in range(num_tomo_stacks):
+            # Convert reduced data stack from theta,row,column to row,theta,column
+            logger.debug(f'Reading reduced data stack {i+1}...')
+            t0 = time()
+            tomo_stack = np.asarray(nxentry.reduced_data.data.tomo_fields[i])
+            logger.debug(f'... done in {time()-t0:.2f} seconds')
+            if len(tomo_stack.shape) != 3 or any(not dim for dim in tomo_stack.shape):
+                raise ValueError(f'Unable to load tomography stack {i+1} for reconstruction')
+            tomo_stack = np.swapaxes(tomo_stack, 0, 1)
+            assert(len(thetas) == tomo_stack.shape[1])
+            assert(0 <= lower_row < upper_row < tomo_stack.shape[0])
+            center_offsets = [lower_center_offset-lower_row*center_slope,
+                    upper_center_offset+(tomo_stack.shape[0]-1-upper_row)*center_slope]
+            t0 = time()
+            logger.debug(f'Running _reconstruct_one_tomo_stack on {self.num_core} cores ...')
+            tomo_recon_stack = self._reconstruct_one_tomo_stack(tomo_stack, thetas,
+                    center_offsets=center_offsets, num_core=self.num_core, algorithm='gridrec')
+            logger.debug(f'... done in {time()-t0:.2f} seconds')
+            logger.info(f'Reconstruction of stack {i+1} took {time()-t0:.2f} seconds')
+
+            # Combine stacks
+            tomo_recon_stacks[i] = tomo_recon_stack
+
+        # Resize the reconstructed tomography data
+        #   reconstructed data order in each stack: row/z,x,y
+        if self.test_mode:
+            x_bounds = self.test_config.get('x_bounds')
+            y_bounds = self.test_config.get('y_bounds')
+            z_bounds = None
+        elif self.galaxy_flag:
+            if x_bounds is not None and not is_int_pair(x_bounds, ge=0,
+                    lt=tomo_recon_stacks[0].shape[1]):
+                raise ValueError(f'Invalid parameter x_bounds ({x_bounds})')
+            if y_bounds is not None and not is_int_pair(y_bounds, ge=0,
+                    lt=tomo_recon_stacks[0].shape[2]):
+                raise ValueError(f'Invalid parameter y_bounds ({y_bounds})')
+            z_bounds = None
+        else:
+            x_bounds, y_bounds, z_bounds = self._resize_reconstructed_data(tomo_recon_stacks)
+        if x_bounds is None:
+            x_range = (0, tomo_recon_stacks[0].shape[1])
+            x_slice = int(x_range[1]/2)
+        else:
+            x_range = (min(x_bounds), max(x_bounds))
+            x_slice = int((x_bounds[0]+x_bounds[1])/2)
+        if y_bounds is None:
+            y_range = (0, tomo_recon_stacks[0].shape[2])
+            y_slice = int(y_range[1]/2)
+        else:
+            y_range = (min(y_bounds), max(y_bounds))
+            y_slice = int((y_bounds[0]+y_bounds[1])/2)
+        if z_bounds is None:
+            z_range = (0, tomo_recon_stacks[0].shape[0])
+            z_slice = int(z_range[1]/2)
+        else:
+            z_range = (min(z_bounds), max(z_bounds))
+            z_slice = int((z_bounds[0]+z_bounds[1])/2)
+
+        # Plot a few reconstructed image slices
+        for i, stack in enumerate(tomo_recon_stacks):
+            if num_tomo_stacks == 1:
+                basetitle = 'recon'
+            else:
+                basetitle = f'recon stack {i+1}'
+            title = f'{basetitle} {res_title} xslice{x_slice}'
+            quick_imshow(stack[z_range[0]:z_range[1],x_slice,y_range[0]:y_range[1]],
+                    title=title, path=path, save_fig=self.save_figs, save_only=self.save_only,
+                    block=self.block)
+            title = f'{basetitle} {res_title} yslice{y_slice}'
+            quick_imshow(stack[z_range[0]:z_range[1],x_range[0]:x_range[1],y_slice],
+                    title=title, path=path, save_fig=self.save_figs, save_only=self.save_only,
+                    block=self.block)
+            title = f'{basetitle} {res_title} zslice{z_slice}'
+            quick_imshow(stack[z_slice,x_range[0]:x_range[1],y_range[0]:y_range[1]],
+                    title=title, path=path, save_fig=self.save_figs, save_only=self.save_only,
+                    block=self.block)
+
+        # Save test data to file
+        #   reconstructed data order in each stack: row/z,x,y
+        if self.test_mode:
+            for i, stack in enumerate(tomo_recon_stacks):
+                np.savetxt(f'{self.output_folder}/recon_stack_{i+1}.txt',
+                        stack[z_slice,x_range[0]:x_range[1],y_range[0]:y_range[1]], fmt='%.6e')
+
+        # Add image reconstruction to reconstructed data NXprocess
+        #   reconstructed data order in each stack: row/z,x,y
+        nxprocess.data = NXdata()
+        nxprocess.attrs['default'] = 'data'
+        for k, v in center_info.items():
+            nxprocess[k] = v
+        if x_bounds is not None:
+            nxprocess.x_bounds = x_bounds
+        if y_bounds is not None:
+            nxprocess.y_bounds = y_bounds
+        if z_bounds is not None:
+            nxprocess.z_bounds = z_bounds
+        nxprocess.data['reconstructed_data'] = np.asarray([stack[z_range[0]:z_range[1],
+                x_range[0]:x_range[1],y_range[0]:y_range[1]] for stack in tomo_recon_stacks])
+        nxprocess.data.attrs['signal'] = 'reconstructed_data'
+
+        # Create a copy of the input Nexus object and remove reduced data
+        exclude_items = [f'{nxentry._name}/reduced_data/data', f'{nxentry._name}/data/reduced_data']
+        nxroot_copy = nxcopy(nxroot, exclude_nxpaths=exclude_items)
+
+        # Add the reconstructed data NXprocess to the new Nexus object
+        nxentry_copy = nxroot_copy[nxroot_copy.attrs['default']]
+        nxentry_copy.reconstructed_data = nxprocess
+        if 'data' not in nxentry_copy:
+            nxentry_copy.data = NXdata()
+        nxentry_copy.attrs['default'] = 'data'
+        nxentry_copy.data.makelink(nxprocess.data.reconstructed_data, name='reconstructed_data')
+        nxentry_copy.data.attrs['signal'] = 'reconstructed_data'
+
+        return(nxroot_copy)
+
+    def combine_data(self, nxroot, x_bounds=None, y_bounds=None):
+        """Combine the reconstructed tomography stacks.
+        """
+        logger.info('Combine the reconstructed tomography stacks')
+
+        if not isinstance(nxroot, NXroot):
+            raise ValueError(f'Invalid parameter nxroot ({nxroot})')
+        nxentry = nxroot[nxroot.attrs['default']]
+        if not isinstance(nxentry, NXentry):
+            raise ValueError(f'Invalid nxentry ({nxentry})')
+
+        # Create plot galaxy path directory and path if needed
+        if self.galaxy_flag:
+            if not os_path.exists('tomo_combine_plots'):
+                mkdir('tomo_combine_plots')
+            path = 'tomo_combine_plots'
+        else:
+            path = self.output_folder
+
+        # Check if reconstructed image data is available
+        if ('reconstructed_data' not in nxentry or 'reconstructed_data' not in nxentry.data):
+            raise KeyError(f'Unable to find valid reconstructed image data in {nxentry}.')
+
+        # Create an NXprocess to store combined image reconstruction (meta)data
+        nxprocess = NXprocess()
+
+        # Get the reconstructed data
+        #   reconstructed data order: stack,row(z),x,y
+        #   Note: Nexus cannot follow a link if the data it points to is too big,
+        #         so get the data from the actual place, not from nxentry.data
+        num_tomo_stacks = nxentry.reconstructed_data.data.reconstructed_data.shape[0]
+        if num_tomo_stacks == 1:
+            logger.info('Only one stack available: leaving combine_data')
+            return(None)
+
+        # Combine the reconstructed stacks
+        # (load one stack at a time to reduce risk of hitting Nexus data access limit)
+        t0 = time()
+        logger.debug('Combining the reconstructed stacks ...')
+        tomo_recon_combined = np.concatenate(
+                [np.asarray(nxentry.reconstructed_data.data.reconstructed_data[i])
+                for i in range(num_tomo_stacks)])
+        logger.debug(f'... done in {time()-t0:.2f} seconds')
+        logger.info(f'Combining the reconstructed stacks took {time()-t0:.2f} seconds')
+
+        # Resize the combined tomography data stacks
+        #   combined data order: row/z,x,y
+        if self.test_mode:
+            x_bounds = None
+            y_bounds = None
+            z_bounds = self.test_config.get('z_bounds')
+        elif self.galaxy_flag:
+            if x_bounds is not None and not is_int_pair(x_bounds, ge=0,
+                    lt=tomo_recon_combined.shape[1]):
+                raise ValueError(f'Invalid parameter x_bounds ({x_bounds})')
+            if y_bounds is not None and not is_int_pair(y_bounds, ge=0,
+                    lt=tomo_recon_combined.shape[2]):
+                raise ValueError(f'Invalid parameter y_bounds ({y_bounds})')
+            z_bounds = None
+        else:
+            x_bounds, y_bounds, z_bounds = self._resize_reconstructed_data(tomo_recon_combined,
+                    z_only=True)
+        if x_bounds is None:
+            x_range = (0, tomo_recon_combined.shape[1])
+            x_slice = int(x_range[1]/2)
+        else:
+            x_range = x_bounds
+            x_slice = int((x_bounds[0]+x_bounds[1])/2)
+        if y_bounds is None:
+            y_range = (0, tomo_recon_combined.shape[2])
+            y_slice = int(y_range[1]/2)
+        else:
+            y_range = y_bounds
+            y_slice = int((y_bounds[0]+y_bounds[1])/2)
+        if z_bounds is None:
+            z_range = (0, tomo_recon_combined.shape[0])
+            z_slice = int(z_range[1]/2)
+        else:
+            z_range = z_bounds
+            z_slice = int((z_bounds[0]+z_bounds[1])/2)
+
+        # Plot a few combined image slices
+        quick_imshow(tomo_recon_combined[z_range[0]:z_range[1],x_slice,y_range[0]:y_range[1]],
+                title=f'recon combined xslice{x_slice}', path=path,
+                save_fig=self.save_figs, save_only=self.save_only, block=self.block)
+        quick_imshow(tomo_recon_combined[z_range[0]:z_range[1],x_range[0]:x_range[1],y_slice],
+                title=f'recon combined yslice{y_slice}', path=path,
+                save_fig=self.save_figs, save_only=self.save_only, block=self.block)
+        quick_imshow(tomo_recon_combined[z_slice,x_range[0]:x_range[1],y_range[0]:y_range[1]],
+                title=f'recon combined zslice{z_slice}', path=path,
+                save_fig=self.save_figs, save_only=self.save_only, block=self.block)
+
+        # Save test data to file
+        #   combined data order: row/z,x,y
+        if self.test_mode:
+            np.savetxt(f'{self.output_folder}/recon_combined.txt', tomo_recon_combined[
+                    z_slice,x_range[0]:x_range[1],y_range[0]:y_range[1]], fmt='%.6e')
+
+        # Add image reconstruction to reconstructed data NXprocess
+        #   combined data order: row/z,x,y
+        nxprocess.data = NXdata()
+        nxprocess.attrs['default'] = 'data'
+        if x_bounds is not None:
+            nxprocess.x_bounds = x_bounds
+        if y_bounds is not None:
+            nxprocess.y_bounds = y_bounds
+        if z_bounds is not None:
+            nxprocess.z_bounds = z_bounds
+        nxprocess.data['combined_data'] = tomo_recon_combined
+        nxprocess.data.attrs['signal'] = 'combined_data'
+
+        # Create a copy of the input Nexus object and remove reconstructed data
+        exclude_items = [f'{nxentry._name}/reconstructed_data/data',
+                f'{nxentry._name}/data/reconstructed_data']
+        nxroot_copy = nxcopy(nxroot, exclude_nxpaths=exclude_items)
+
+        # Add the combined data NXprocess to the new Nexus object
+        nxentry_copy = nxroot_copy[nxroot_copy.attrs['default']]
+        nxentry_copy.combined_data = nxprocess
+        if 'data' not in nxentry_copy:
+            nxentry_copy.data = NXdata()
+        nxentry_copy.attrs['default'] = 'data'
+        nxentry_copy.data.makelink(nxprocess.data.combined_data, name='combined_data')
+        nxentry_copy.data.attrs['signal'] = 'combined_data'
+
+        return(nxroot_copy)
+
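+    # Hedged end-to-end sketch of the public processing chain (arguments and
+    # file name are illustrative):
+    #
+    #     tomo = Tomo(num_core=4)
+    #     nxroot = tomo.gen_reduced_data(tomo.read('map.yaml'))
+    #     center_info = tomo.find_centers(nxroot)
+    #     nxrecon = tomo.reconstruct_data(nxroot, center_info)
+    #     nxcombined = tomo.combine_data(nxrecon)
+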
+    def _gen_dark(self, nxentry, reduced_data):
+        """Generate dark field.
+        """
+        # Get the dark field images
+        image_key = nxentry.instrument.detector.get('image_key', None)
+        if image_key is not None and 'data' in nxentry.instrument.detector:
+            field_indices = [index for index, key in enumerate(image_key) if key == 2]
+            tdf_stack = nxentry.instrument.detector.data[field_indices,:,:]
+            # RV the default NXtomo form does not accommodate bright or dark field stacks
+        else:
+            dark_field_scans = nxentry.spec_scans.dark_field
+            dark_field = FlatField.construct_from_nxcollection(dark_field_scans)
+            prefix = str(nxentry.instrument.detector.local_name)
+            tdf_stack = dark_field.get_detector_data(prefix)
+            if isinstance(tdf_stack, list):
+                assert(len(tdf_stack) == 1) # TODO
+                tdf_stack = tdf_stack[0]
+
+        # Take median
+        if tdf_stack.ndim == 2:
+            tdf = tdf_stack
+        elif tdf_stack.ndim == 3:
+            tdf = np.median(tdf_stack, axis=0)
+            del tdf_stack
+        else:
+            raise ValueError(f'Invalid tdf_stack shape ({tdf_stack.shape})')
+
+        # Remove dark field intensities above the cutoff
+#RV        tdf_cutoff = None
+        tdf_cutoff = tdf.min()+2*(np.median(tdf)-tdf.min())
+        logger.debug(f'tdf_cutoff = {tdf_cutoff}')
+        if tdf_cutoff is not None:
+            if not is_num(tdf_cutoff, ge=0):
+                logger.warning(f'Ignoring illegal value of tdf_cutoff {tdf_cutoff}')
+            else:
+                tdf[tdf > tdf_cutoff] = np.nan
+                logger.debug(f'Removed {np.sum(np.isnan(tdf))} dark field values above the cutoff')
+
+        # Remove nans
+        tdf_mean = np.nanmean(tdf)
+        logger.debug(f'tdf_mean = {tdf_mean}')
+        np.nan_to_num(tdf, copy=False, nan=tdf_mean, posinf=tdf_mean, neginf=0.)
+
+        # Plot dark field
+        if self.galaxy_flag:
+            quick_imshow(tdf, title='dark field', path='tomo_reduce_plots', save_fig=self.save_figs,
+                    save_only=self.save_only)
+        elif not self.test_mode:
+            quick_imshow(tdf, title='dark field', path=self.output_folder, save_fig=self.save_figs,
+                    save_only=self.save_only)
+            clear_imshow('dark field')
+#        quick_imshow(tdf, title='dark field', block=True)
+
+        # Add dark field to reduced data NXprocess
+        reduced_data.data = NXdata()
+        reduced_data.data['dark_field'] = tdf
+
+        return(reduced_data)
+
+    def _gen_bright(self, nxentry, reduced_data):
+        """Generate bright field.
+        """
+        # Get the bright field images
+        image_key = nxentry.instrument.detector.get('image_key', None)
+        if image_key is not None and 'data' in nxentry.instrument.detector:
+            field_indices = [index for index, key in enumerate(image_key) if key == 1]
+            tbf_stack = nxentry.instrument.detector.data[field_indices,:,:]
+            # RV the default NXtomo form does not accommodate bright or dark field stacks
+        else:
+            bright_field_scans = nxentry.spec_scans.bright_field
+            bright_field = FlatField.construct_from_nxcollection(bright_field_scans)
+            prefix = str(nxentry.instrument.detector.local_name)
+            tbf_stack = bright_field.get_detector_data(prefix)
+            if isinstance(tbf_stack, list):
+                assert(len(tbf_stack) == 1) # TODO
+                tbf_stack = tbf_stack[0]
+
+        # Take median if more than one image.
+        # Median or mean: it may be best to use the median because of image
+        # artifacts that arise from crinkles in the upstream kapton tape windows,
+        # which cause some phase contrast images to appear on the detector.
+        # A future implementation could also apply a bright field adjustment to
+        # EACH tomography frame, based on an ROI in a corner of the frame that
+        # contains no sample but only the direct X-ray beam, since there are
+        # frame-to-frame fluctuations in the incoming beam; these are not
+        # typically accounted for, but potentially could be.
+        if tbf_stack.ndim == 2:
+            tbf = tbf_stack
+        elif tbf_stack.ndim == 3:
+            tbf = np.median(tbf_stack, axis=0)
+            del tbf_stack
+        else:
+            raise ValueError(f'Invalid tbf_stack shape ({tbf_stack.shape})')
+
+        # Subtract dark field
+        if 'data' in reduced_data and 'dark_field' in reduced_data.data:
+            tbf -= reduced_data.data.dark_field
+        else:
+            logger.warning('Dark field unavailable')
+
+        # Set any non-positive values to one
+        # (avoid negative bright field values for spikes in dark field)
+        tbf[tbf < 1] = 1
+
+        # Plot bright field
+        if self.galaxy_flag:
+            quick_imshow(tbf, title='bright field', path='tomo_reduce_plots',
+                    save_fig=self.save_figs, save_only=self.save_only)
+        elif not self.test_mode:
+            quick_imshow(tbf, title='bright field', path=self.output_folder,
+                    save_fig=self.save_figs, save_only=self.save_only)
+            clear_imshow('bright field')
+#        quick_imshow(tbf, title='bright field', block=True)
+
+        # Add bright field to reduced data NXprocess
+        if 'data' not in reduced_data: 
+            reduced_data.data = NXdata()
+        reduced_data.data['bright_field'] = tbf
+
+        return(reduced_data)
+
+    def _set_detector_bounds(self, nxentry, reduced_data, img_x_bounds=None):
+        """Set vertical detector bounds for each image stack.
+        Right now the range is the same for each set in the image stack.
+        """
+        if self.test_mode:
+            return(tuple(self.test_config['img_x_bounds']))
+
+        # Get the first tomography image and the reference heights
+        image_key = nxentry.instrument.detector.get('image_key', None)
+        if image_key is not None and 'data' in nxentry.instrument.detector:
+            field_indices = [index for index, key in enumerate(image_key) if key == 0]
+            first_image = np.asarray(nxentry.instrument.detector.data[field_indices[0],:,:])
+            theta = float(nxentry.sample.rotation_angle[field_indices[0]])
+            z_translation_all = nxentry.sample.z_translation[field_indices]
+            vertical_shifts = sorted(list(set(z_translation_all)))
+            num_tomo_stacks = len(vertical_shifts)
+        else:
+            tomo_field_scans = nxentry.spec_scans.tomo_fields
+            tomo_fields = TomoField.construct_from_nxcollection(tomo_field_scans)
+            vertical_shifts = tomo_fields.get_vertical_shifts()
+            if not isinstance(vertical_shifts, list):
+                vertical_shifts = [vertical_shifts]
+            prefix = str(nxentry.instrument.detector.local_name)
+            t0 = time()
+            first_image = tomo_fields.get_detector_data(prefix, tomo_fields.scan_numbers[0], 0)
+            logger.debug(f'Getting first image took {time()-t0:.2f} seconds')
+            num_tomo_stacks = len(tomo_fields.scan_numbers)
+            theta = tomo_fields.theta_range['start']
+
+        # Select image bounds
+        title = f'tomography image at theta={round(theta, 2)+0}'
+        if (img_x_bounds is not None and not is_index_range(img_x_bounds, ge=0,
+                le=first_image.shape[0])):
+            raise ValueError(f'Invalid parameter img_x_bounds ({img_x_bounds})')
+        if nxentry.instrument.source.attrs['station'] in ('id1a3', 'id3a'):
+            pixel_size = nxentry.instrument.detector.x_pixel_size
+            # Try to get a fit from the bright field
+            tbf = np.asarray(reduced_data.data.bright_field)
+            tbf_shape = tbf.shape
+            x_sum = np.sum(tbf, 1)
+            x_sum_min = x_sum.min()
+            x_sum_max = x_sum.max()
+            fit = Fit.fit_data(x_sum, 'rectangle', x=np.array(range(len(x_sum))), form='atan',
+                    guess=True)
+            parameters = fit.best_values
+            x_low_fit = parameters.get('center1', None)
+            x_upp_fit = parameters.get('center2', None)
+            sig_low = parameters.get('sigma1', None)
+            sig_upp = parameters.get('sigma2', None)
+            have_fit = fit.success and x_low_fit is not None and x_upp_fit is not None and \
+                    sig_low is not None and sig_upp is not None and \
+                    0 <= x_low_fit < x_upp_fit <= x_sum.size and \
+                    (sig_low+sig_upp)/(x_upp_fit-x_low_fit) < 0.1
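+            # I.e. only accept the fit when both edges were found, they are
+            # ordered and lie inside the data range, and the combined edge
+            # widths are less than 10% of the fitted window width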
+            if have_fit:
+                # Set a 5% margin on each side
+                margin = 0.05*(x_upp_fit-x_low_fit)
+                x_low_fit = max(0, x_low_fit-margin)
+                x_upp_fit = min(tbf_shape[0], x_upp_fit+margin)
+            if num_tomo_stacks == 1:
+                if have_fit:
+                    # Set the default range to enclose the full fitted window
+                    x_low = int(x_low_fit)
+                    x_upp = int(x_upp_fit)
+                else:
+                    # Center a default range of 1 mm (RV: can we get this from the slits?)
+                    num_x_min = int((1.0-0.5*pixel_size)/pixel_size)
+                    x_low = int(0.5*(tbf_shape[0]-num_x_min))
+                    x_upp = x_low+num_x_min
+            else:
+                # Get the default range from the reference heights
+                delta_z = vertical_shifts[1]-vertical_shifts[0]
+                for i in range(2, num_tomo_stacks):
+                    delta_z = min(delta_z, vertical_shifts[i]-vertical_shifts[i-1])
+                logger.debug(f'delta_z = {delta_z}')
+                num_x_min = int((delta_z-0.5*pixel_size)/pixel_size)
+                logger.debug(f'num_x_min = {num_x_min}')
+                if num_x_min > tbf_shape[0]:
+                    logger.warning('Image bounds and pixel size prevent seamless stacking')
+                if have_fit:
+                    # Center the default range relative to the fitted window
+                    x_low = int(0.5*(x_low_fit+x_upp_fit-num_x_min))
+                    x_upp = x_low+num_x_min
+                else:
+                    # Center the default range
+                    x_low = int(0.5*(tbf_shape[0]-num_x_min))
+                    x_upp = x_low+num_x_min
+            if self.galaxy_flag:
+                img_x_bounds = (x_low, x_upp)
+            else:
+                tmp = np.copy(tbf)
+                tmp_max = tmp.max()
+                tmp[x_low,:] = tmp_max
+                tmp[x_upp-1,:] = tmp_max
+                quick_imshow(tmp, title='bright field')
+                tmp = np.copy(first_image)
+                tmp_max = tmp.max()
+                tmp[x_low,:] = tmp_max
+                tmp[x_upp-1,:] = tmp_max
+                quick_imshow(tmp, title=title)
+                del tmp
+                quick_plot((range(x_sum.size), x_sum),
+                        ([x_low, x_low], [x_sum_min, x_sum_max], 'r-'),
+                        ([x_upp, x_upp], [x_sum_min, x_sum_max], 'r-'),
+                        title='sum over theta and y')
+                print(f'lower bound = {x_low} (inclusive)')
+                print(f'upper bound = {x_upp} (exclusive)')
+                accept = input_yesno('Accept these bounds (y/n)?', 'y')
+                clear_imshow('bright field')
+                clear_imshow(title)
+                clear_plot('sum over theta and y')
+                if accept:
+                    img_x_bounds = (x_low, x_upp)
+                else:
+                    while True:
+                        mask, img_x_bounds = draw_mask_1d(x_sum, title='select x data range',
+                                legend='sum over theta and y')
+                        if len(img_x_bounds) == 1:
+                            break
+                        else:
+                            print('Choose a single connected data range')
+                    img_x_bounds = tuple(img_x_bounds[0])
+            if (num_tomo_stacks > 1 and img_x_bounds[1]-img_x_bounds[0]+1 < 
+                    int((delta_z-0.5*pixel_size)/pixel_size)):
+                logger.warning('Image bounds and pixel size prevent seamless stacking')
+        else:
+            if num_tomo_stacks > 1:
+                raise NotImplementedError('Selecting image bounds for multiple stacks on FMB')
+            # For FMB: use the first tomography image to select range
+            # RV: revisit if they do tomography with multiple stacks
+            x_sum = np.sum(first_image, 1)
+            x_sum_min = x_sum.min()
+            x_sum_max = x_sum.max()
+            if self.galaxy_flag:
+                if img_x_bounds is None:
+                    img_x_bounds = (0, first_image.shape[0])
+            else:
+                quick_imshow(first_image, title=title)
+                print('Select vertical data reduction range from first tomography image')
+                img_x_bounds = select_image_bounds(first_image, 0, title=title)
+                clear_imshow(title)
+                if img_x_bounds is None:
+                    raise ValueError('Unable to select image bounds')
+
+        # Plot results
+        if self.galaxy_flag:
+            path = 'tomo_reduce_plots'
+        else:
+            path = self.output_folder
+        x_low = img_x_bounds[0]
+        x_upp = img_x_bounds[1]
+        tmp = np.copy(first_image)
+        tmp_max = tmp.max()
+        tmp[x_low,:] = tmp_max
+        tmp[x_upp-1,:] = tmp_max
+        quick_imshow(tmp, title=title, path=path, save_fig=self.save_figs, save_only=self.save_only,
+                block=self.block)
+        del tmp
+        quick_plot((range(x_sum.size), x_sum),
+                ([x_low, x_low], [x_sum_min, x_sum_max], 'r-'),
+                ([x_upp, x_upp], [x_sum_min, x_sum_max], 'r-'),
+                title='sum over theta and y', path=path, save_fig=self.save_figs,
+                save_only=self.save_only, block=self.block)
+
+        return(img_x_bounds)
+
+    def _set_zoom_or_skip(self):
+        """Set zoom and/or theta skip to reduce the memory requirement of the analysis.
+        Both options are currently disabled and always return None.
+        """
+#        if input_yesno('\nDo you want to zoom in to reduce memory requirement (y/n)?', 'n'):
+#            zoom_perc = input_int('    Enter zoom percentage', ge=1, le=100)
+#        else:
+#            zoom_perc = None
+        zoom_perc = None
+#        if input_yesno('Do you want to skip thetas to reduce memory requirement (y/n)?', 'n'):
+#            num_theta_skip = input_int('    Enter the number skip theta interval', ge=0,
+#                    lt=num_theta)
+#        else:
+#            num_theta_skip = None
+        num_theta_skip = None
+        logger.debug(f'zoom_perc = {zoom_perc}')
+        logger.debug(f'num_theta_skip = {num_theta_skip}')
+
+        return(zoom_perc, num_theta_skip)
+
+    def _gen_tomo(self, nxentry, reduced_data):
+        """Generate tomography fields.
+        """
+        # Get full bright field
+        tbf = np.asarray(reduced_data.data.bright_field)
+        tbf_shape = tbf.shape
+
+        # Get image bounds
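+        # (half-open pixel ranges: lower bound inclusive, upper bound exclusive)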
+        img_x_bounds = tuple(reduced_data.get('img_x_bounds', (0, tbf_shape[0])))
+        img_y_bounds = tuple(reduced_data.get('img_y_bounds', (0, tbf_shape[1])))
+
+        # Get resized dark field
+#        if 'dark_field' in reduced_data.data:
+#            tdf = np.asarray(reduced_data.data.dark_field[
+#                    img_x_bounds[0]:img_x_bounds[1],img_y_bounds[0]:img_y_bounds[1]])
+#        else:
+#            logger.warning('Dark field unavailable')
+#            tdf = None
+        tdf = None
+
+        # Resize bright field
+        if img_x_bounds != (0, tbf.shape[0]) or img_y_bounds != (0, tbf.shape[1]):
+            tbf = tbf[img_x_bounds[0]:img_x_bounds[1],img_y_bounds[0]:img_y_bounds[1]]
+
+        # Get the tomography images
+        image_key = nxentry.instrument.detector.get('image_key', None)
+        if image_key is not None and 'data' in nxentry.instrument.detector:
+            field_indices_all = [index for index, key in enumerate(image_key) if key == 0]
+            z_translation_all = nxentry.sample.z_translation[field_indices_all]
+            z_translation_levels = sorted(list(set(z_translation_all)))
+            num_tomo_stacks = len(z_translation_levels)
+            horizontal_shifts = []
+            vertical_shifts = []
+            thetas = None
+            tomo_stacks = []
+            for i, z_translation in enumerate(z_translation_levels):
+                field_indices = [field_indices_all[index]
+                        for index, z in enumerate(z_translation_all) if z == z_translation]
+                horizontal_shift = list(set(nxentry.sample.x_translation[field_indices]))
+                assert(len(horizontal_shift) == 1)
+                horizontal_shifts += horizontal_shift
+                vertical_shift = list(set(nxentry.sample.z_translation[field_indices]))
+                assert(len(vertical_shift) == 1)
+                vertical_shifts += vertical_shift
+                sequence_numbers = nxentry.instrument.detector.sequence_number[field_indices]
+                if thetas is None:
+                    thetas = np.asarray(
+                            nxentry.sample.rotation_angle[field_indices])[sequence_numbers]
+                else:
+                    assert(all(thetas[j] == nxentry.sample.rotation_angle[field_indices[index]]
+                            for j, index in enumerate(sequence_numbers)))
+                assert(sorted(set(sequence_numbers)) == list(range(len(sequence_numbers))))
+                if list(sequence_numbers) == list(range(len(sequence_numbers))):
+                    tomo_stack = np.asarray(nxentry.instrument.detector.data[field_indices])
+                else:
+                    raise ValueError('Unable to load the tomography images')
+                tomo_stacks.append(tomo_stack)
+        else:
+            tomo_field_scans = nxentry.spec_scans.tomo_fields
+            tomo_fields = TomoField.construct_from_nxcollection(tomo_field_scans)
+            horizontal_shifts = tomo_fields.get_horizontal_shifts()
+            vertical_shifts = tomo_fields.get_vertical_shifts()
+            prefix = str(nxentry.instrument.detector.local_name)
+            t0 = time()
+            tomo_stacks = tomo_fields.get_detector_data(prefix)
+            logger.debug(f'Getting all tomography images took {time()-t0:.2f} seconds')
+            thetas = np.linspace(tomo_fields.theta_range['start'], tomo_fields.theta_range['end'],
+                    tomo_fields.theta_range['num'])
+            if not isinstance(tomo_stacks, list):
+                horizontal_shifts = [horizontal_shifts]
+                vertical_shifts = [vertical_shifts]
+                tomo_stacks = [tomo_stacks]
+
+        reduced_tomo_stacks = []
+        if self.galaxy_flag:
+            path = 'tomo_reduce_plots'
+        else:
+            path = self.output_folder
+        for i, tomo_stack in enumerate(tomo_stacks):
+            # Resize the tomography images
+            # Right now the range is the same for each set in the image stack.
+            # Compare against the original bright field shape; tbf itself may already be cropped
+            if img_x_bounds != (0, tbf_shape[0]) or img_y_bounds != (0, tbf_shape[1]):
+                t0 = time()
+                tomo_stack = tomo_stack[:,img_x_bounds[0]:img_x_bounds[1],
+                        img_y_bounds[0]:img_y_bounds[1]].astype('float64')
+                logger.debug(f'Resizing tomography images took {time()-t0:.2f} seconds')
+
+            # Subtract dark field
+            if tdf is not None:
+                t0 = time()
+                with set_numexpr_threads(self.num_core):
+                    ne.evaluate('tomo_stack-tdf', out=tomo_stack)
+                logger.debug(f'Subtracting dark field took {time()-t0:.2f} seconds')
+
+            # Normalize
+            t0 = time()
+            with set_numexpr_threads(self.num_core):
+                ne.evaluate('tomo_stack/tbf', out=tomo_stack, truediv=True)
+            logger.debug(f'Normalizing took {time()-t0:.2f} seconds')
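+            # numexpr is used only to thread the evaluation; a plain single-threaded
+            # NumPy equivalent would be: tomo_stack /= tbf (tbf broadcasts over theta)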
+
+            # Remove non-positive values and linearize data
+            t0 = time()
+            cutoff = 1.e-6
+            with set_numexpr_threads(self.num_core):
+                ne.evaluate('where(tomo_stack<cutoff, cutoff, tomo_stack)', out=tomo_stack)
+            with set_numexpr_threads(self.num_core):
+                ne.evaluate('-log(tomo_stack)', out=tomo_stack)
+            logger.debug('Removing non-positive values and linearizing data took '+
+                    f'{time()-t0:.2f} seconds')
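+            # After flat-field normalization, -log of the transmission gives the line
+            # integral of the attenuation coefficient (Beer-Lambert), which is the
+            # quantity the reconstruction step inverts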
+
+            # Get rid of nans/infs that may be introduced by normalization
+            t0 = time()
+            tomo_stack = np.where(np.isfinite(tomo_stack), tomo_stack, 0.)
+            logger.debug(f'Removing nans/infs took {time()-t0:.2f} seconds')
+
+            # Convert the tomography stack to float32 to reduce the memory footprint
+            # TODO: use theta_skip as well
+            tomo_stack = tomo_stack.astype('float32')
+            if not self.test_mode:
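+                # The +0 in the titles below maps a possible -0.0 from round() to 0.0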
+                if len(tomo_stacks) == 1:
+                    title = f'red fullres theta {round(thetas[0], 2)+0}'
+                else:
+                    title = f'red stack {i+1} fullres theta {round(thetas[0], 2)+0}'
+                quick_imshow(tomo_stack[0,:,:], title=title, path=path, save_fig=self.save_figs,
+                        save_only=self.save_only, block=self.block)
+#                if not self.block:
+#                    clear_imshow(title)
+            # Optional zoom step, currently disabled (note: zoom_perc is not set in this method)
+            if False and zoom_perc != 100:
+                t0 = time()
+                logger.debug('Zooming in ...')
+                tomo_zoom_list = []
+                for j in range(tomo_stack.shape[0]):
+                    tomo_zoom = spi.zoom(tomo_stack[j,:,:], 0.01*zoom_perc)
+                    tomo_zoom_list.append(tomo_zoom)
+                tomo_stack = np.stack(tomo_zoom_list)
+                logger.debug(f'... done in {time()-t0:.2f} seconds')
+                logger.info(f'Zooming in took {time()-t0:.2f} seconds')
+                del tomo_zoom_list
+                if not self.test_mode:
+                    title = f'red stack {zoom_perc}p theta {round(thetas[0], 2)+0}'
+                    quick_imshow(tomo_stack[0,:,:], title=title, path=path, save_fig=self.save_figs,
+                            save_only=self.save_only, block=self.block)
+#                    if not self.block:
+#                        clear_imshow(title)
+
+            # Save test data to file
+            if self.test_mode:
+#                row_index = int(tomo_stack.shape[0]/2)
+#                np.savetxt(f'{self.output_folder}/red_stack_{i+1}.txt', tomo_stack[row_index,:,:],
+#                        fmt='%.6e')
+                row_index = int(tomo_stack.shape[1]/2)
+                np.savetxt(f'{self.output_folder}/red_stack_{i+1}.txt', tomo_stack[:,row_index,:],
+                        fmt='%.6e')
+
+            # Combine resized stacks
+            reduced_tomo_stacks.append(tomo_stack)
+
+        # Add tomo field info to reduced data NXprocess
+        reduced_data['rotation_angle'] = thetas
+        reduced_data['x_translation'] = np.asarray(horizontal_shifts)
+        reduced_data['z_translation'] = np.asarray(vertical_shifts)
+        reduced_data.data['tomo_fields'] = np.asarray(reduced_tomo_stacks)
+
+        if tdf is not None:
+            del tdf
+        del tbf
+
+        return(reduced_data)
+
+    def _find_center_one_plane(self, sinogram, row, thetas, eff_pixel_size, cross_sectional_dim,
+            path=None, tol=0.1, num_core=1):
+        """Find center for a single tomography plane.
+        """
+        # Try automatic center finding routines for initial value
+        # sinogram index order: theta,column
+        # need column,theta for iradon, so take transpose
+        sinogram = np.asarray(sinogram)
+        sinogram_T = sinogram.T
+        center = sinogram.shape[1]/2
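+        # All center offsets below are measured in pixels relative to this column midpoint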
+
+        # Try using Nghia Vo’s method
+        t0 = time()
+        if num_core > num_core_tomopy_limit:
+            logger.debug(f'Running find_center_vo on {num_core_tomopy_limit} cores ...')
+            tomo_center = tomopy.find_center_vo(sinogram, ncore=num_core_tomopy_limit)
+        else:
+            logger.debug(f'Running find_center_vo on {num_core} cores ...')
+            tomo_center = tomopy.find_center_vo(sinogram, ncore=num_core)
+        logger.debug(f'... done in {time()-t0:.2f} seconds')
+        logger.info(f'Finding the center using Nghia Vo’s method took {time()-t0:.2f} seconds')
+        center_offset_vo = tomo_center-center
+        logger.info(f'Center at row {row} using Nghia Vo’s method = {center_offset_vo:.2f}')
+        t0 = time()
+        logger.debug(f'Running _reconstruct_one_plane on {self.num_core} cores ...')
+        recon_plane = self._reconstruct_one_plane(sinogram_T, tomo_center, thetas,
+                eff_pixel_size, cross_sectional_dim, False, num_core)
+        logger.debug(f'... done in {time()-t0:.2f} seconds')
+        logger.info(f'Reconstructing row {row} took {time()-t0:.2f} seconds')
+
+        title = f'edges row{row} center offset{center_offset_vo:.2f} Vo'
+        self._plot_edges_one_plane(recon_plane, title, path=path)
+
+        # Try using phase correlation method
+#        if input_yesno('Try finding center using phase correlation (y/n)?', 'n'):
+#            t0 = time()
+#            logger.debug(f'Running find_center_pc ...')
+#            tomo_center = tomopy.find_center_pc(sinogram, sinogram, tol=0.1, rotc_guess=tomo_center)
+#            error = 1.
+#            while error > tol:
+#                prev = tomo_center
+#                tomo_center = tomopy.find_center_pc(sinogram, sinogram, tol=tol,
+#                        rotc_guess=tomo_center)
+#                error = np.abs(tomo_center-prev)
+#            logger.debug(f'... done in {time()-t0:.2f} seconds')
+#            logger.info('Finding the center using the phase correlation method took '+
+#                    f'{time()-t0:.2f} seconds')
+#            center_offset = tomo_center-center
+#            print(f'Center at row {row} using phase correlation = {center_offset:.2f}')
+#            t0 = time()
+#            logger.debug(f'Running _reconstruct_one_plane on {self.num_core} cores ...')
+#            recon_plane = self._reconstruct_one_plane(sinogram_T, tomo_center, thetas,
+#                    eff_pixel_size, cross_sectional_dim, False, num_core)
+#            logger.debug(f'... done in {time()-t0:.2f} seconds')
+#            logger.info(f'Reconstructing row {row} took {time()-t0:.2f} seconds')
+#
+#            title = f'edges row{row} center_offset{center_offset:.2f} PC'
+#            self._plot_edges_one_plane(recon_plane, title, path=path)
+
+        # Select center location
+#        if input_yesno('Accept a center location (y) or continue search (n)?', 'y'):
+        if True:
+#            center_offset = input_num('    Enter chosen center offset', ge=-center, le=center,
+#                    default=center_offset_vo)
+            center_offset = center_offset_vo
+            del sinogram_T
+            del recon_plane
+            return float(center_offset)
+
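+        # The interactive search below is retained for reference but is currently
+        # unreachable, since the hard-coded branch above always returns the Vo center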
+        # perform center finding search
+        while True:
+            center_offset_low = input_int('\nEnter lower bound for center offset', ge=-center,
+                    le=center)
+            center_offset_upp = input_int('Enter upper bound for center offset',
+                    ge=center_offset_low, le=center)
+            if center_offset_upp == center_offset_low:
+                center_offset_step = 1
+            else:
+                center_offset_step = input_int('Enter step size for center offset search', ge=1,
+                        le=center_offset_upp-center_offset_low)
+            num_center_offset = 1+int((center_offset_upp-center_offset_low)/center_offset_step)
+            center_offsets = np.linspace(center_offset_low, center_offset_upp, num_center_offset)
+            for center_offset in center_offsets:
+                if center_offset == center_offset_vo:
+                    continue
+                t0 = time()
+                logger.debug(f'Running _reconstruct_one_plane on {num_core} cores ...')
+                recon_plane = self._reconstruct_one_plane(sinogram_T, center_offset+center, thetas,
+                        eff_pixel_size, cross_sectional_dim, False, num_core)
+                logger.debug(f'... done in {time()-t0:.2f} seconds')
+                logger.info(f'Reconstructing center_offset {center_offset} took '+
+                        f'{time()-t0:.2f} seconds')
+                title = f'edges row{row} center_offset{center_offset:.2f}'
+                self._plot_edges_one_plane(recon_plane, title, path=path)
+            if input_int('\nContinue (0) or end the search (1)', ge=0, le=1):
+                break
+
+        del sinogram_T
+        del recon_plane
+        center_offset = input_num('    Enter chosen center offset', ge=-center, le=center)
+        return float(center_offset)
+
+    def _reconstruct_one_plane(self, tomo_plane_T, center, thetas, eff_pixel_size,
+            cross_sectional_dim, plot_sinogram=True, num_core=1):
+        """Invert the sinogram for a single tomography plane.
+        """
+        # tomo_plane_T index order: column,theta
+        assert(0 <= center < tomo_plane_T.shape[0])
+        center_offset = center-tomo_plane_T.shape[0]/2
+        two_offset = 2*int(np.round(center_offset))
+        two_offset_abs = np.abs(two_offset)
+        max_rad = int(0.55*(cross_sectional_dim/eff_pixel_size)) # 10% slack to avoid edge effects
+        if max_rad > 0.5*tomo_plane_T.shape[0]:
+            max_rad = 0.5*tomo_plane_T.shape[0]
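+        # Trim the sinogram symmetrically about the (rounded) rotation axis, so that the
+        # axis sits at the column center of the array passed to iradon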
+        dist_from_edge = max(1, int(np.floor((tomo_plane_T.shape[0]-two_offset_abs)/2.)-max_rad))
+        if two_offset >= 0:
+            logger.debug(f'sinogram range = [{two_offset+dist_from_edge}, {-dist_from_edge}]')
+            sinogram = tomo_plane_T[two_offset+dist_from_edge:-dist_from_edge,:]
+        else:
+            logger.debug(f'sinogram range = [{dist_from_edge}, {two_offset-dist_from_edge}]')
+            sinogram = tomo_plane_T[dist_from_edge:two_offset-dist_from_edge,:]
+        if not self.galaxy_flag and plot_sinogram:
+            quick_imshow(sinogram.T, f'sinogram center offset{center_offset:.2f}', aspect='auto',
+                    path=self.output_folder, save_fig=self.save_figs, save_only=self.save_only,
+                    block=self.block)
+
+        # Inverting sinogram
+        t0 = time()
+        recon_sinogram = iradon(sinogram, theta=thetas, circle=True)
+        logger.debug(f'Inverting sinogram took {time()-t0:.2f} seconds')
+        del sinogram
+
+        # Performing Gaussian filtering and removing ring artifacts
+        recon_parameters = None  # self.config.get('recon_parameters')
+        if recon_parameters is None:
+            sigma = 1.0
+            ring_width = 15
+        else:
+            sigma = recon_parameters.get('gaussian_sigma', 1.0)
+            if not is_num(sigma, ge=0.0):
+                logger.warning(f'Invalid gaussian_sigma ({sigma}) in _reconstruct_one_plane, '+
+                        'set to a default value of 1.0')
+                sigma = 1.0
+            ring_width = recon_parameters.get('ring_width', 15)
+            if not is_int(ring_width, ge=0):
+                logger.warning(f'Invalid ring_width ({ring_width}) in _reconstruct_one_plane, '+
+                        'set to a default value of 15')
+                ring_width = 15
+        t0 = time()
+        recon_sinogram = spi.gaussian_filter(recon_sinogram, sigma, mode='nearest')
+        recon_clean = np.expand_dims(recon_sinogram, axis=0)
+        del recon_sinogram
+        recon_clean = tomopy.misc.corr.remove_ring(recon_clean, rwidth=ring_width, ncore=num_core)
+        logger.debug(f'Filtering and removing ring artifacts took {time()-t0:.2f} seconds')
+
+        return recon_clean
+
+    def _plot_edges_one_plane(self, recon_plane, title, path=None):
+        vis_parameters = None  # self.config.get('vis_parameters')
+        if vis_parameters is None:
+            weight = 0.1
+        else:
+            weight = vis_parameters.get('denoise_weight', 0.1)
+            if not is_num(weight, ge=0.0):
+                logger.warning(f'Invalid weight ({weight}) in _plot_edges_one_plane, '+
+                        'set to a default value of 0.1')
+                weight = 0.1
+        edges = denoise_tv_chambolle(recon_plane, weight=weight)
+        vmax = np.max(edges[0,:,:])
+        vmin = -vmax
+        if path is None:
+            path = self.output_folder
+        quick_imshow(edges[0,:,:], f'{title} coolwarm', path=path, cmap='coolwarm',
+                save_fig=self.save_figs, save_only=self.save_only, block=self.block)
+        quick_imshow(edges[0,:,:], f'{title} gray', path=path, cmap='gray', vmin=vmin, vmax=vmax,
+                save_fig=self.save_figs, save_only=self.save_only, block=self.block)
+        del edges
+
+    def _reconstruct_one_tomo_stack(self, tomo_stack, thetas, center_offsets=[], num_core=1,
+            algorithm='gridrec'):
+        """Reconstruct a single tomography stack.
+        """
+        # tomo_stack order: row,theta,column
+        # input thetas must be in degrees
+        # center_offsets: tomography axis shift in pixels relative to column center
+        # RV should we remove stripes?
+        # https://tomopy.readthedocs.io/en/latest/api/tomopy.prep.stripe.html
+        # RV should we remove rings?
+        # https://tomopy.readthedocs.io/en/latest/api/tomopy.misc.corr.html
+        # RV: Add an option to do (extra) secondary iterations later or to do some sort of convergence test?
+        if not len(center_offsets):
+            centers = np.zeros((tomo_stack.shape[0]))
+        elif len(center_offsets) == 2:
+            centers = np.linspace(center_offsets[0], center_offsets[1], tomo_stack.shape[0])
+        else:
+            if len(center_offsets) != tomo_stack.shape[0]:
+                raise ValueError('center_offsets dimension mismatch in _reconstruct_one_tomo_stack')
+            centers = np.asarray(center_offsets, dtype=np.float64)
+        centers += tomo_stack.shape[2]/2
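+        # tomopy.recon expects absolute rotation-axis positions in pixels, hence the
+        # shift of the offsets by the column midpoint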
+
+        # Get reconstruction parameters
+        recon_parameters = None  # self.config.get('recon_parameters')
+        if recon_parameters is None:
+            sigma = 2.0
+            secondary_iters = 0
+            ring_width = 15
+        else:
+            sigma = recon_parameters.get('stripe_fw_sigma', 2.0)
+            if not is_num(sigma, ge=0):
+                logger.warning(f'Invalid stripe_fw_sigma ({sigma}) in '+
+                        '_reconstruct_one_tomo_stack, set to a default value of 2.0')
+                sigma = 2.0
+            secondary_iters = recon_parameters.get('secondary_iters', 0)
+            if not is_int(secondary_iters, ge=0):
+                logger.warning(f'Invalid secondary_iters ({secondary_iters}) in '+
+                        '_reconstruct_one_tomo_stack, set to a default value of 0 (skip them)')
+                secondary_iters = 0
+            ring_width = recon_parameters.get('ring_width', 15)
+            if not is_int(ring_width, ge=0):
+                logger.warning(f'Invalid ring_width ({ring_width}) in '+
+                        '_reconstruct_one_tomo_stack, set to a default value of 15')
+                ring_width = 15
+
+        # Remove horizontal stripe
+        t0 = time()
+        if num_core > num_core_tomopy_limit:
+            logger.debug(f'Running remove_stripe_fw on {num_core_tomopy_limit} cores ...')
+            tomo_stack = tomopy.prep.stripe.remove_stripe_fw(tomo_stack, sigma=sigma,
+                    ncore=num_core_tomopy_limit)
+        else:
+            logger.debug(f'Running remove_stripe_fw on {num_core} cores ...')
+            tomo_stack = tomopy.prep.stripe.remove_stripe_fw(tomo_stack, sigma=sigma,
+                    ncore=num_core)
+        logger.debug(f'... tomopy.prep.stripe.remove_stripe_fw took {time()-t0:.2f} seconds')
+
+        # Perform initial image reconstruction
+        logger.debug('Performing initial image reconstruction')
+        t0 = time()
+        logger.debug(f'Running recon on {num_core} cores ...')
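+        # sinogram_order=True: tomo_stack is ordered (row, theta, column), i.e. a stack
+        # of sinograms rather than of projection images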
+        tomo_recon_stack = tomopy.recon(tomo_stack, np.radians(thetas), centers,
+                sinogram_order=True, algorithm=algorithm, ncore=num_core)
+        logger.debug(f'... done in {time()-t0:.2f} seconds')
+        logger.info(f'Performing initial image reconstruction took {time()-t0:.2f} seconds')
+
+        # Run optional secondary iterations
+        if secondary_iters > 0:
+            logger.debug(f'Running {secondary_iters} secondary iterations')
+            #options = {'method':'SIRT_CUDA', 'proj_type':'cuda', 'num_iter':secondary_iters}
+            #RV: doesn't work for me:
+            #"Error: CUDA error 803: system has unsupported display driver/cuda driver combination."
+            #options = {'method':'SIRT', 'proj_type':'linear', 'MinConstraint': 0, 'num_iter':secondary_iters}
+            #SIRT did not finish while running overnight
+            #options = {'method':'SART', 'proj_type':'linear', 'num_iter':secondary_iters}
+            options = {'method':'SART', 'proj_type':'linear', 'MinConstraint': 0,
+                    'num_iter':secondary_iters}
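+        # Note that algorithm=tomopy.astra delegates to the ASTRA toolbox wrapper, so
+        # these secondary iterations require the astra-toolbox package to be installed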
+            t0 = time()
+            logger.debug(f'Running recon on {num_core} cores ...')
+            tomo_recon_stack = tomopy.recon(tomo_stack, np.radians(thetas), centers,
+                    init_recon=tomo_recon_stack, options=options, sinogram_order=True,
+                    algorithm=tomopy.astra, ncore=num_core)
+            logger.debug(f'... done in {time()-t0:.2f} seconds')
+            logger.info(f'Performing secondary iterations took {time()-t0:.2f} seconds')
+
+        # Remove ring artifacts
+        t0 = time()
+        tomopy.misc.corr.remove_ring(tomo_recon_stack, rwidth=ring_width, out=tomo_recon_stack,
+                ncore=num_core)
+        logger.debug(f'Removing ring artifacts took {time()-t0:.2f} seconds')
+
+        return tomo_recon_stack
+
+    def _resize_reconstructed_data(self, data, z_only=False):
+        """Resize the reconstructed tomography data.
+        """
+        # Data order: row(z),x,y or stack,row(z),x,y
+        if isinstance(data, list):
+            for stack in data:
+                assert(stack.ndim == 3)
+            num_tomo_stacks = len(data)
+            tomo_recon_stacks = data
+        else:
+            assert(data.ndim == 3)
+            num_tomo_stacks = 1
+            tomo_recon_stacks = [data]
+
+        if z_only:
+            x_bounds = None
+            y_bounds = None
+        else:
+            # Selecting x bounds (in yz-plane)
+            tomosum = 0
+            for i in range(num_tomo_stacks):
+                tomosum += np.sum(tomo_recon_stacks[i], axis=(0,2))
+            select_x_bounds = input_yesno('\nDo you want to change the image x-bounds (y/n)?', 'y')
+            if not select_x_bounds:
+                x_bounds = None
+            else:
+                accept = False
+                index_ranges = None
+                while not accept:
+                    mask, x_bounds = draw_mask_1d(tomosum, current_index_ranges=index_ranges,
+                            title='select x data range', legend='recon stack sum yz')
+                    while len(x_bounds) != 1:
+                        print('Please select exactly one continuous range')
+                        mask, x_bounds = draw_mask_1d(tomosum, title='select x data range',
+                                legend='recon stack sum yz')
+                    x_bounds = x_bounds[0]
+#                    quick_plot(tomosum, vlines=x_bounds, title='recon stack sum yz')
+#                    print(f'x_bounds = {x_bounds} (lower bound inclusive, upper bound '+
+#                            'exclusive)')
+#                    accept = input_yesno('Accept these bounds (y/n)?', 'y')
+                    accept = True
+            logger.debug(f'x_bounds = {x_bounds}')
+
+            # Selecting y bounds (in xz-plane)
+            tomosum = 0
+            for i in range(num_tomo_stacks):
+                tomosum += np.sum(tomo_recon_stacks[i], axis=(0,1))
+            select_y_bounds = input_yesno('\nDo you want to change the image y-bounds (y/n)?', 'y')
+            if not select_y_bounds:
+                y_bounds = None
+            else:
+                accept = False
+                index_ranges = None
+                while not accept:
+                    mask, y_bounds = draw_mask_1d(tomosum, current_index_ranges=index_ranges,
+                            title='select y data range', legend='recon stack sum xz')
+                    while len(y_bounds) != 1:
+                        print('Please select exactly one continuous range')
+                        mask, y_bounds = draw_mask_1d(tomosum, title='select y data range',
+                                legend='recon stack sum xz')
+                    y_bounds = y_bounds[0]
+#                    quick_plot(tomosum, vlines=y_bounds, title='recon stack sum xz')
+#                    print(f'y_bounds = {y_bounds} (lower bound inclusive, upper bound '+
+#                            'exclusive)')
+#                    accept = input_yesno('Accept these bounds (y/n)?', 'y')
+                    accept = True
+            logger.debug(f'y_bounds = {y_bounds}')
+
+        # Selecting z bounds (in xy-plane) (only valid for a single image stack)
+        if num_tomo_stacks != 1:
+            z_bounds = None
+        else:
+            tomosum = 0
+            for i in range(num_tomo_stacks):
+                tomosum += np.sum(tomo_recon_stacks[i], axis=(1,2))
+            select_z_bounds = input_yesno('Do you want to change the image z-bounds (y/n)?', 'n')
+            if not select_z_bounds:
+                z_bounds = None
+            else:
+                accept = False
+                index_ranges = None
+                while not accept:
+                    mask, z_bounds = draw_mask_1d(tomosum, current_index_ranges=index_ranges,
+                            title='select z data range', legend='recon stack sum xy')
+                    while len(z_bounds) != 1:
+                        print('Please select exactly one continuous range')
+                        mask, z_bounds = draw_mask_1d(tomosum, title='select z data range',
+                                legend='recon stack sum xy')
+                    z_bounds = z_bounds[0]
+#                    quick_plot(tomosum, vlines=z_bounds, title='recon stack sum xy')
+#                    print(f'z_bounds = {z_bounds} (lower bound inclusive, upper bound '+
+#                            'exclusive)')
+#                    accept = input_yesno('Accept these bounds (y/n)?', 'y')
+                    accept = True
+            logger.debug(f'z_bounds = {z_bounds}')
+
+        return(x_bounds, y_bounds, z_bounds)
+
+
+def run_tomo(input_file:str, output_file:str, modes:list[str], center_file=None, num_core=-1,
+        output_folder='.', save_figs='no', test_mode=False) -> None:
+
+    if test_mode:
+        logging_format = '%(asctime)s : %(levelname)s - %(module)s : %(funcName)s - %(message)s'
+        level = logging.getLevelName('INFO')
+        logging.basicConfig(filename=f'{output_folder}/tomo.log', filemode='w',
+                format=logging_format, level=level, force=True)
+    logger.info(f'input_file = {input_file}')
+    logger.info(f'center_file = {center_file}')
+    logger.info(f'output_file = {output_file}')
+    logger.debug(f'modes = {modes}')
+    logger.debug(f'num_core = {num_core}')
+    logger.info(f'output_folder = {output_folder}')
+    logger.info(f'save_figs = {save_figs}')
+    logger.info(f'test_mode = {test_mode}')
+
+    # Check the requested processing modes
+    legal_modes = ['reduce_data', 'find_center', 'reconstruct_data', 'combine_data', 'all']
+    if modes is None:
+        modes = ['all']
+    if not all(mode in legal_modes for mode in modes):
+        raise ValueError(f'Invalid parameter modes ({modes})')
+
+    # Instantiate Tomo object
+    tomo = Tomo(num_core=num_core, output_folder=output_folder, save_figs=save_figs,
+            test_mode=test_mode)
+
+    # Read input file
+    data = tomo.read(input_file)
+
+    # Generate reduced tomography images
+    if 'reduce_data' in modes or 'all' in modes:
+        data = tomo.gen_reduced_data(data)
+
+    # Find rotation axis centers for the tomography stacks.
+    center_data = None
+    if 'find_center' in modes or 'all' in modes:
+        center_data = tomo.find_centers(data)
+
+    # Reconstruct tomography stacks
+    if 'reconstruct_data' in modes or 'all' in modes:
+        if center_data is None:
+            # Read input file
+            center_data = tomo.read(center_file)
+        data = tomo.reconstruct_data(data, center_data)
+        center_data = None
+
+    # Combine reconstructed tomography stacks
+    if 'combine_data' in modes or 'all' in modes:
+        data = tomo.combine_data(data)
+
+    # Write output file
+    if data is not None and not test_mode:
+        if center_data is None:
+            data = tomo.write(data, output_file)
+        else:
+            data = tomo.write(center_data, output_file)
+
+    logger.info(f'Completed modes: {modes}')
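+
+# Example invocation (hypothetical file names, for illustration only):
+#
+#     run_tomo('tomo_input.nex', 'tomo_output.nex', modes=['reduce_data', 'find_center'],
+#             num_core=4, output_folder='reduced', save_figs='no')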