view workflow/models.py @ 71:1cf15b61cd83 draft

planemo upload for repository https://github.com/rolfverberg/galaxytools commit 366e516aef0735af2998c6ff3af037181c8d5213
author rv43
date Mon, 20 Mar 2023 13:56:57 +0000
parents fba792d5f83b
children
line wrap: on
line source

#!/usr/bin/env python3

import logging
logger = logging.getLogger(__name__)

import logging

import numpy as np
import os
import yaml

from functools import cache
from pathlib import PosixPath
from pydantic import BaseModel as PydanticBaseModel
from pydantic import validator, ValidationError, conint, confloat, constr, conlist, FilePath, \
        PrivateAttr
from nexusformat.nexus import *
from time import time
from typing import Optional, Literal
from typing_extensions import TypedDict
try:
    from pyspec.file.spec import FileSpec
except:
    pass

try:
    from msnctools.general import is_int, is_num, input_int, input_int_list, input_num, \
            input_yesno, input_menu, index_nearest, string_to_list, file_exists_and_readable
except:
    from general import is_int, is_num, input_int, input_int_list, input_num, \
            input_yesno, input_menu, index_nearest, string_to_list, file_exists_and_readable


def import_scanparser(station):
    if station in ('id1a3', 'id3a'):
        try:
            from msnctools.scanparsers import SMBRotationScanParser
            globals()['ScanParser'] = SMBRotationScanParser
        except:
            try:
                from scanparsers import SMBRotationScanParser
                globals()['ScanParser'] = SMBRotationScanParser
            except:
                pass
    elif station in ('id3b'):
        try:
            from msnctools.scanparsers import FMBRotationScanParser
            globals()['ScanParser'] = FMBRotationScanParser
        except:
            try:
                from scanparsers import FMBRotationScanParser
                globals()['ScanParser'] = FMBRotationScanParser
            except:
                pass
    else:
        raise RuntimeError(f'Invalid station: {station}')

@cache
def get_available_scan_numbers(spec_file:str):
    scans = FileSpec(spec_file).scans
    scan_numbers = list(scans.keys())
    for scan_number in scan_numbers.copy():
        try:
            parser = ScanParser(spec_file, scan_number)
            try:
                scan_type = parser.scan_type
            except:
                scan_type = None
        except:
            scan_numbers.remove(scan_number)
    return(scan_numbers)

@cache
def get_scanparser(spec_file:str, scan_number:int):
    if scan_number not in get_available_scan_numbers(spec_file):
        return(None)
    else:
        return(ScanParser(spec_file, scan_number))


class BaseModel(PydanticBaseModel):
    class Config:
        validate_assignment = True
        arbitrary_types_allowed = True

    @classmethod
    def construct_from_cli(cls):
        obj = cls.construct()
        obj.cli()
        return(obj)

    @classmethod
    def construct_from_yaml(cls, filename):
        try:
            with open(filename, 'r') as infile:
                indict = yaml.load(infile, Loader=yaml.CLoader)
        except:
            raise ValueError(f'Could not load a dictionary from {filename}')
        else:
            obj = cls(**indict)
            return(obj)

    @classmethod
    def construct_from_file(cls, filename):
        file_exists_and_readable(filename)
        filename = os.path.abspath(filename)
        fileformat = os.path.splitext(filename)[1]
        yaml_extensions = ('.yaml','.yml')
        nexus_extensions = ('.nxs','.nx5','.h5','.hdf5')
        t0 = time()
        if fileformat.lower() in yaml_extensions:
            obj = cls.construct_from_yaml(filename)
            logger.info(f'Constructed a model from {filename} in {time()-t0:.2f} seconds.')
            return(obj)
        elif fileformat.lower() in nexus_extensions:
            obj = cls.construct_from_nexus(filename)
            logger.info(f'Constructed a model from {filename} in {time()-t0:.2f} seconds.')
            return(obj)
        else:
            logger.error(f'Unsupported file extension for constructing a model: {fileformat}')
            raise TypeError(f'Unrecognized file extension: {fileformat}')

    def dict_for_yaml(self, exclude_fields=[]):
        yaml_dict = {}
        for field_name in self.__fields__:
            if field_name in exclude_fields:
                continue
            else:
                field_value = getattr(self, field_name, None)
                if field_value is not None:
                    if isinstance(field_value, BaseModel):
                        yaml_dict[field_name] = field_value.dict_for_yaml()
                    elif isinstance(field_value,list) and all(isinstance(item,BaseModel)
                            for item in field_value):
                        yaml_dict[field_name] = [item.dict_for_yaml() for item in field_value]
                    elif isinstance(field_value, PosixPath):
                        yaml_dict[field_name] = str(field_value)
                    else:
                        yaml_dict[field_name] = field_value
                else:
                    continue
        return(yaml_dict)

    def write_to_yaml(self, filename=None):
        yaml_dict = self.dict_for_yaml()
        if filename is None:
            logger.info('Printing yaml representation here:\n'+
                    f'{yaml.dump(yaml_dict, sort_keys=False)}')
        else:
            try:
                with open(filename, 'w') as outfile:
                    yaml.dump(yaml_dict, outfile, sort_keys=False)
                logger.info(f'Successfully wrote this model to {filename}')
            except:
                logger.error(f'Unknown error -- could not write to {filename} in yaml format.')
                logger.info('Printing yaml representation here:\n'+
                        f'{yaml.dump(yaml_dict, sort_keys=False)}')

    def write_to_file(self, filename, force_overwrite=False):
        file_writeable, fileformat = self.output_file_valid(filename,
                force_overwrite=force_overwrite)
        if fileformat == 'yaml':
            if file_writeable:
                self.write_to_yaml(filename=filename)
            else:
                self.write_to_yaml()
        elif fileformat == 'nexus':
            if file_writeable:
                self.write_to_nexus(filename=filename)

    def output_file_valid(self, filename, force_overwrite=False):
        filename = os.path.abspath(filename)
        fileformat = os.path.splitext(filename)[1]
        yaml_extensions = ('.yaml','.yml')
        nexus_extensions = ('.nxs','.nx5','.h5','.hdf5')
        if fileformat.lower() not in (*yaml_extensions, *nexus_extensions):
            return(False, None) # Only yaml and NeXus files allowed for output now.
        elif fileformat.lower() in yaml_extensions:
            fileformat = 'yaml'
        elif fileformat.lower() in nexus_extensions:
            fileformat = 'nexus'
        if os.path.isfile(filename):
            if os.access(filename, os.W_OK):
                if not force_overwrite:
                    logger.error(f'{filename} will not be overwritten.')
                    return(False, fileformat)
            else:
                logger.error(f'Cannot access {filename} for writing.')
                return(False, fileformat)
        if os.path.isdir(os.path.dirname(filename)):
            if os.access(os.path.dirname(filename), os.W_OK):
                return(True, fileformat)
            else:
                logger.error(f'Cannot access {os.path.dirname(filename)} for writing.')
                return(False, fileformat)
        else:
            try:
                os.makedirs(os.path.dirname(filename))
                return(True, fileformat)
            except:
                logger.error(f'Cannot create {os.path.dirname(filename)} for output.')
                return(False, fileformat)

    def set_single_attr_cli(self, attr_name, attr_desc='unknown attribute', list_flag=False,
            **cli_kwargs):
        if cli_kwargs.get('chain_attr_desc', False):
            cli_kwargs['attr_desc'] = attr_desc
        try:
            attr = getattr(self, attr_name, None)
            if attr is None:
                attr = self.__fields__[attr_name].type_.construct()
            if cli_kwargs.get('chain_attr_desc', False):
                cli_kwargs['attr_desc'] = attr_desc
            input_accepted = False
            while not input_accepted:
                try:
                    attr.cli(**cli_kwargs)
                except ValidationError as e:
                    print(e)
                    print(f'Removing {attr_desc} configuration')
                    attr = self.__fields__[attr_name].type_.construct()
                    continue
                except KeyboardInterrupt as e:
                    raise e
                except BaseException as e:
                    print(f'{type(e).__name__}: {e}')
                    print(f'Removing {attr_desc} configuration')
                    attr = self.__fields__[attr_name].type_.construct()
                    continue
                try:
                    setattr(self, attr_name, attr)
                except ValidationError as e:
                    print(e)
                except KeyboardInterrupt as e:
                    raise e
                except BaseException as e:
                    print(f'{type(e).__name__}: {e}')
                else:
                    input_accepted = True
        except:
            input_accepted = False
            while not input_accepted:
                attr = getattr(self, attr_name, None)
                if attr is None:
                    input_value = input(f'Type and enter a value for {attr_desc}: ')
                else:
                    input_value = input(f'Type and enter a new value for {attr_desc} or press '+
                            f'enter to keep the current one ({attr}): ')
                if list_flag:
                    input_value = string_to_list(input_value, remove_duplicates=False, sort=False)
                if len(input_value) == 0:
                    input_value = getattr(self, attr_name, None)
                try:
                    setattr(self, attr_name, input_value)
                except ValidationError as e:
                    print(e)
                except KeyboardInterrupt as e:
                    raise e
                except BaseException as e:
                    print(f'Unexpected {type(e).__name__}: {e}')
                else:
                    input_accepted = True

    def set_list_attr_cli(self, attr_name, attr_desc='unknown attribute', **cli_kwargs):
        if cli_kwargs.get('chain_attr_desc', False):
            cli_kwargs['attr_desc'] = attr_desc
        attr = getattr(self, attr_name, None)
        if attr is not None:
            # Check existing items
            for item in attr:
                item_accepted = False
                while not item_accepted:
                    item.cli(**cli_kwargs)
                    try:
                        setattr(self, attr_name, attr)
                    except ValidationError as e:
                        print(e)
                    except KeyboardInterrupt as e:
                        raise e
                    except BaseException as e:
                        print(f'{type(e).__name__}: {e}')
                    else:
                        item_accepted = True
        else:
            # Initialize list for new attribute & starting item
            attr = []
            item = self.__fields__[attr_name].type_.construct()
        # Append (optional) additional items
        append = input_yesno(f'Add a {attr_desc} configuration? (y/n)', 'n')
        while append:
            attr.append(item.__class__.construct_from_cli())
            try:
                setattr(self, attr_name, attr)
            except ValidationError as e:
                print(e)
                print(f'Removing last {attr_desc} configuration from the list')
                attr.pop()
            except KeyboardInterrupt as e:
                raise e
            except BaseException as e:
                print(f'{type(e).__name__}: {e}')
                print(f'Removing last {attr_desc} configuration from the list')
                attr.pop()
            else:
                append = input_yesno(f'Add another {attr_desc} configuration? (y/n)', 'n')


class Detector(BaseModel):
    prefix: constr(strip_whitespace=True, min_length=1)
    rows: conint(gt=0)
    columns: conint(gt=0)
    pixel_size: conlist(item_type=confloat(gt=0), min_items=1, max_items=2)
    lens_magnification: confloat(gt=0) = 1.0

    @property
    def get_pixel_size(self):
        return(list(np.asarray(self.pixel_size)/self.lens_magnification))

    def construct_from_yaml(self, filename):
        try:
            with open(filename, 'r') as infile:
                indict = yaml.load(infile, Loader=yaml.CLoader)
            detector = indict['detector']
            self.prefix = detector['id']
            pixels = detector['pixels']
            self.rows = pixels['rows']
            self.columns = pixels['columns']
            self.pixel_size = pixels['size']
            self.lens_magnification = indict['lens_magnification']
        except:
            logging.warning(f'Could not load a dictionary from {filename}')
            return(False)
        else:
            return(True)

    def cli(self):
        print('\n -- Configure the detector -- ')
        self.set_single_attr_cli('prefix', 'detector ID')
        self.set_single_attr_cli('rows', 'number of pixel rows')
        self.set_single_attr_cli('columns', 'number of pixel columns')
        self.set_single_attr_cli('pixel_size', 'pixel size in mm (enter either a single value for '+
                'square pixels or a pair of values for the size in the respective row and column '+
                'directions)', list_flag=True)
        self.set_single_attr_cli('lens_magnification', 'lens magnification')

    def construct_nxdetector(self):
        nxdetector = NXdetector()
        nxdetector.local_name = self.prefix
        pixel_size = self.get_pixel_size
        if len(pixel_size) == 1:
            nxdetector.x_pixel_size = pixel_size[0]
            nxdetector.y_pixel_size = pixel_size[0]
        else:
            nxdetector.x_pixel_size = pixel_size[0]
            nxdetector.y_pixel_size = pixel_size[1]
        nxdetector.x_pixel_size.attrs['units'] = 'mm'
        nxdetector.y_pixel_size.attrs['units'] = 'mm'
        return(nxdetector)


class ScanInfo(TypedDict):
    scan_number: int
    starting_image_offset: conint(ge=0)
    num_image: conint(gt=0)
    ref_x: float
    ref_z: float

class SpecScans(BaseModel):
    spec_file: FilePath
    scan_numbers: conlist(item_type=conint(gt=0), min_items=1)
    stack_info: conlist(item_type=ScanInfo, min_items=1) = []

    @validator('spec_file')
    def validate_spec_file(cls, spec_file):
        try:
            spec_file = os.path.abspath(spec_file)
            sspec_file = FileSpec(spec_file)
        except:
            raise ValueError(f'Invalid SPEC file {spec_file}')
        else:
            return(spec_file)

    @validator('scan_numbers')
    def validate_scan_numbers(cls, scan_numbers, values):
        spec_file = values.get('spec_file')
        if spec_file is not None:
            spec_scans = FileSpec(spec_file)
            for scan_number in scan_numbers:
                scan = spec_scans.get_scan_by_number(scan_number)
                if scan is None:
                    raise ValueError(f'There is no scan number {scan_number} in {spec_file}')
        return(scan_numbers)

    @validator('stack_info')
    def validate_stack_info(cls, stack_info, values):
        scan_numbers = values.get('scan_numbers')
        assert(len(scan_numbers) == len(stack_info))
        for scan_info in stack_info:
            assert(scan_info['scan_number'] in scan_numbers)
            is_int(scan_info['starting_image_offset'], ge=0, lt=scan_info['num_image'],
                    raise_error=True)
        return(stack_info)

    @classmethod
    def construct_from_nxcollection(cls, nxcollection:NXcollection):
        config = {}
        config['spec_file'] = nxcollection.attrs['spec_file']
        scan_numbers = []
        stack_info = []
        for nxsubentry_name, nxsubentry in nxcollection.items():
            scan_number = int(nxsubentry_name.split('_')[-1])
            scan_numbers.append(scan_number)
            stack_info.append({'scan_number': scan_number,
                    'starting_image_offset': int(nxsubentry.instrument.detector.frame_start_number),
                    'num_image': len(nxsubentry.sample.rotation_angle),
                    'ref_x': float(nxsubentry.sample.x_translation),
                    'ref_z': float(nxsubentry.sample.z_translation)})
        config['scan_numbers'] = sorted(scan_numbers)
        config['stack_info'] = stack_info
        return(cls(**config))

    @property
    def available_scan_numbers(self):
        return(get_available_scan_numbers(self.spec_file))

    def set_from_nxcollection(self, nxcollection:NXcollection):
        self.spec_file = nxcollection.attrs['spec_file']
        scan_numbers = []
        stack_info = []
        for nxsubentry_name, nxsubentry in nxcollection.items():
            scan_number = int(nxsubentry_name.split('_')[-1])
            scan_numbers.append(scan_number)
            stack_info.append({'scan_number': scan_number,
                    'starting_image_offset': int(nxsubentry.instrument.detector.frame_start_number),
                    'num_image': len(nxsubentry.sample.rotation_angle),
                    'ref_x': float(nxsubentry.sample.x_translation),
                    'ref_z': float(nxsubentry.sample.z_translation)})
        self.scan_numbers = sorted(scan_numbers)
        self.stack_info = stack_info

    def get_scan_index(self, scan_number):
        scan_index = [scan_index for scan_index, scan_info in enumerate(self.stack_info)
                if scan_info['scan_number'] == scan_number]
        if len(scan_index) > 1:
            raise ValueError('Duplicate scan_numbers in image stack')
        elif len(scan_index) == 1:
            return(scan_index[0])
        else:
            return(None)

    def get_scanparser(self, scan_number):
        return(get_scanparser(self.spec_file, scan_number))

    def get_detector_data(self, detector_prefix, scan_number=None, scan_step_index=None):
        image_stacks = []
        if scan_number is None:
            scan_numbers = self.scan_numbers
        else:
            scan_numbers = [scan_number]
        for scan_number in scan_numbers:
            parser = self.get_scanparser(scan_number)
            scan_info = self.stack_info[self.get_scan_index(scan_number)]
            image_offset = scan_info['starting_image_offset']
            if scan_step_index is None:
                num_image = scan_info['num_image']
                image_stacks.append(parser.get_detector_data(detector_prefix,
                        (image_offset, image_offset+num_image)))
            else:
                image_stacks.append(parser.get_detector_data(detector_prefix,
                        image_offset+scan_step_index))
        if scan_number is not None and scan_step_index is not None:
            # Return a single image for a specific scan_number and scan_step_index request
            return(image_stacks[0])
        else:
            # Return a list otherwise
            return(image_stacks)
        return(image_stacks)

    def scan_numbers_cli(self, attr_desc, **kwargs):
        available_scan_numbers = self.available_scan_numbers
        station = kwargs.get('station')
        if (station is not None and station in ('id1a3', 'id3a') and
                'scan_type' in kwargs):
            scan_type = kwargs['scan_type']
            if scan_type == 'ts1':
                available_scan_numbers = []
                for scan_number in self.available_scan_numbers:
                    parser = self.get_scanparser(scan_number)
                    try:
                        if parser.scan_type == scan_type:
                            available_scan_numbers.append(scan_number)
                    except:
                        pass
            elif scan_type == 'df1':
                tomo_scan_numbers = kwargs['tomo_scan_numbers']
                available_scan_numbers = []
                for scan_number in tomo_scan_numbers:
                    parser = self.get_scanparser(scan_number-2)
                    assert(parser.scan_type == scan_type)
                    available_scan_numbers.append(scan_number-2)
            elif scan_type == 'bf1':
                tomo_scan_numbers = kwargs['tomo_scan_numbers']
                available_scan_numbers = []
                for scan_number in tomo_scan_numbers:
                    parser = self.get_scanparser(scan_number-1)
                    assert(parser.scan_type == scan_type)
                    available_scan_numbers.append(scan_number-1)
        if len(available_scan_numbers) == 1:
            input_mode = 1
        else:
            if hasattr(self, 'scan_numbers'):
                print(f'Currently selected {attr_desc}scan numbers are: {self.scan_numbers}')
                menu_options = [f'Select a subset of the available {attr_desc}scan numbers',
                        f'Use all available {attr_desc}scan numbers in {self.spec_file}',
                        f'Keep the currently selected {attr_desc}scan numbers']
            else:
                menu_options = [f'Select a subset of the available {attr_desc}scan numbers',
                        f'Use all available {attr_desc}scan numbers in {self.spec_file}']
            print(f'Available scan numbers in {self.spec_file} are: '+
                    f'{available_scan_numbers}')
            input_mode = input_menu(menu_options, header='Choose one of the following options '+
                    'for selecting scan numbers')
        if input_mode == 0:
            accept_scan_numbers = False
            while not accept_scan_numbers:
                try:
                    self.scan_numbers = \
                            input_int_list(f'Enter a series of {attr_desc}scan numbers')
                except ValidationError as e:
                    print(e)
                except KeyboardInterrupt as e:
                    raise e
                except BaseException as e:
                    print(f'Unexpected {type(e).__name__}: {e}')
                else:
                    accept_scan_numbers = True
        elif input_mode == 1:
            self.scan_numbers = available_scan_numbers
        elif input_mode == 2:
            pass

    def cli(self, **cli_kwargs):
        if cli_kwargs.get('attr_desc') is not None:
            attr_desc = f'{cli_kwargs["attr_desc"]} '
        else:
            attr_desc = ''
        print(f'\n -- Configure which scans to use from a single {attr_desc}SPEC file')
        self.set_single_attr_cli('spec_file', attr_desc+'SPEC file path')
        self.scan_numbers_cli(attr_desc)

    def construct_nxcollection(self, image_key, thetas, detector):
        nxcollection = NXcollection()
        nxcollection.attrs['spec_file'] = str(self.spec_file)
        parser = self.get_scanparser(self.scan_numbers[0])
        nxcollection.attrs['date'] = parser.spec_scan.file_date
        for scan_number in self.scan_numbers:
            # Get scan info
            scan_info = self.stack_info[self.get_scan_index(scan_number)]
            # Add an NXsubentry to the NXcollection for each scan
            entry_name = f'scan_{scan_number}'
            nxsubentry = NXsubentry()
            nxcollection[entry_name] = nxsubentry
            parser = self.get_scanparser(scan_number)
            nxsubentry.start_time = parser.spec_scan.date
            nxsubentry.spec_command = parser.spec_command
            # Add an NXdata for independent dimensions to the scan's NXsubentry
            num_image = scan_info['num_image']
            if thetas is None:
                thetas = num_image*[0.0]
            else:
                assert(num_image == len(thetas))
#            nxsubentry.independent_dimensions = NXdata()
#            nxsubentry.independent_dimensions.rotation_angle = thetas
#            nxsubentry.independent_dimensions.rotation_angle.units = 'degrees'
            # Add an NXinstrument to the scan's NXsubentry
            nxsubentry.instrument = NXinstrument()
            # Add an NXdetector to the NXinstrument to the scan's NXsubentry
            nxsubentry.instrument.detector = detector.construct_nxdetector()
            nxsubentry.instrument.detector.frame_start_number = scan_info['starting_image_offset']
            nxsubentry.instrument.detector.image_key = image_key
            # Add an NXsample to the scan's NXsubentry
            nxsubentry.sample = NXsample()
            nxsubentry.sample.rotation_angle = thetas
            nxsubentry.sample.rotation_angle.units = 'degrees'
            nxsubentry.sample.x_translation = scan_info['ref_x']
            nxsubentry.sample.x_translation.units = 'mm'
            nxsubentry.sample.z_translation = scan_info['ref_z']
            nxsubentry.sample.z_translation.units = 'mm'
        return(nxcollection)


class FlatField(SpecScans):

    def image_range_cli(self, attr_desc, detector_prefix):
        stack_info = self.stack_info
        for scan_number in self.scan_numbers:
            # Parse the available image range
            parser = self.get_scanparser(scan_number)
            image_offset = parser.starting_image_offset
            num_image = parser.get_num_image(detector_prefix.upper())
            scan_index = self.get_scan_index(scan_number)

            # Select the image set
            last_image_index = image_offset+num_image
            print(f'Available good image set index range: [{image_offset}, {last_image_index})')
            image_set_approved = False
            if scan_index is not None:
                scan_info = stack_info[scan_index]
                print(f'Current starting image offset and number of images: '+
                        f'{scan_info["starting_image_offset"]} and {scan_info["num_image"]}')
                image_set_approved = input_yesno(f'Accept these values (y/n)?', 'y')
            if not image_set_approved:
                print(f'Default starting image offset and number of images: '+
                        f'{image_offset} and {num_image}')
                image_set_approved = input_yesno(f'Accept these values (y/n)?', 'y')
                if image_set_approved:
                    offset = image_offset
                    num = last_image_index-offset
                while not image_set_approved:
                    offset = input_int(f'Enter the starting image offset', ge=image_offset,
                            lt=last_image_index)#, default=image_offset)
                    num = input_int(f'Enter the number of images', ge=1,
                            le=last_image_index-offset)#, default=last_image_index-offset)
                    print(f'Current starting image offset and number of images: {offset} and {num}')
                    image_set_approved = input_yesno(f'Accept these values (y/n)?', 'y')
                if scan_index is not None:
                    scan_info['starting_image_offset'] = offset
                    scan_info['num_image'] = num
                    scan_info['ref_x'] = parser.horizontal_shift
                    scan_info['ref_z'] = parser.vertical_shift
                else:
                    stack_info.append({'scan_number': scan_number, 'starting_image_offset': offset,
                            'num_image': num, 'ref_x': parser.horizontal_shift,
                            'ref_z': parser.vertical_shift})
        self.stack_info = stack_info

    def cli(self, **cli_kwargs):
        if cli_kwargs.get('attr_desc') is not None:
            attr_desc = f'{cli_kwargs["attr_desc"]} '
        else:
            attr_desc = ''
        station = cli_kwargs.get('station')
        detector = cli_kwargs.get('detector')
        print(f'\n -- Configure the location of the {attr_desc}scan data -- ')
        if station in ('id1a3', 'id3a'):
            self.spec_file = cli_kwargs['spec_file']
            tomo_scan_numbers = cli_kwargs['tomo_scan_numbers']
            scan_type = cli_kwargs['scan_type']
            self.scan_numbers_cli(attr_desc, station=station, tomo_scan_numbers=tomo_scan_numbers,
                    scan_type=scan_type)
        else:
            self.set_single_attr_cli('spec_file', attr_desc+'SPEC file path')
            self.scan_numbers_cli(attr_desc)
        self.image_range_cli(attr_desc, detector.prefix)


class TomoField(SpecScans):
    theta_range: dict = {}

    @validator('theta_range')
    def validate_theta_range(cls, theta_range):
        if len(theta_range) != 3 and len(theta_range) != 4:
            raise ValueError(f'Invalid theta range {theta_range}')
        is_num(theta_range['start'], raise_error=True)
        is_num(theta_range['end'], raise_error=True)
        is_int(theta_range['num'], gt=1, raise_error=True)
        if theta_range['end'] <= theta_range['start']:
            raise ValueError(f'Invalid theta range {theta_range}')
        if 'start_index' in theta_range:
            is_int(theta_range['start_index'], ge=0, raise_error=True)
        return(theta_range)

    @classmethod
    def construct_from_nxcollection(cls, nxcollection:NXcollection):
        #RV Can I derive this from the same classfunction for SpecScans by adding theta_range
        config = {}
        config['spec_file'] = nxcollection.attrs['spec_file']
        scan_numbers = []
        stack_info = []
        for nxsubentry_name, nxsubentry in nxcollection.items():
            scan_number = int(nxsubentry_name.split('_')[-1])
            scan_numbers.append(scan_number)
            stack_info.append({'scan_number': scan_number,
                    'starting_image_offset': int(nxsubentry.instrument.detector.frame_start_number),
                    'num_image': len(nxsubentry.sample.rotation_angle),
                    'ref_x': float(nxsubentry.sample.x_translation),
                    'ref_z': float(nxsubentry.sample.z_translation)})
        config['scan_numbers'] = sorted(scan_numbers)
        config['stack_info'] = stack_info
        for name in nxcollection.entries:
            if 'scan_' in name:
                thetas = np.asarray(nxcollection[name].sample.rotation_angle)
                config['theta_range'] = {'start': thetas[0], 'end': thetas[-1], 'num': thetas.size}
                break
        return(cls(**config))

    def get_horizontal_shifts(self, scan_number=None):
        horizontal_shifts = []
        if scan_number is None:
            scan_numbers = self.scan_numbers
        else:
            scan_numbers = [scan_number]
        for scan_number in scan_numbers:
            parser = self.get_scanparser(scan_number)
            horizontal_shifts.append(parser.horizontal_shift)
        if len(horizontal_shifts) == 1:
            return(horizontal_shifts[0])
        else:
            return(horizontal_shifts)

    def get_vertical_shifts(self, scan_number=None):
        vertical_shifts = []
        if scan_number is None:
            scan_numbers = self.scan_numbers
        else:
            scan_numbers = [scan_number]
        for scan_number in scan_numbers:
            parser = self.get_scanparser(scan_number)
            vertical_shifts.append(parser.vertical_shift)
        if len(vertical_shifts) == 1:
            return(vertical_shifts[0])
        else:
            return(vertical_shifts)

    def theta_range_cli(self, scan_number, attr_desc, station):
        # Parse the available theta range
        parser = self.get_scanparser(scan_number)
        theta_vals = parser.theta_vals
        spec_theta_start = theta_vals.get('start')
        spec_theta_end = theta_vals.get('end')
        spec_num_theta = theta_vals.get('num')

        # Check for consistency of theta ranges between scans
        if scan_number != self.scan_numbers[0]:
            parser = self.get_scanparser(self.scan_numbers[0])
            if (parser.theta_vals.get('start') != spec_theta_start or
                    parser.theta_vals.get('end') != spec_theta_end or
                    parser.theta_vals.get('num') != spec_num_theta):
                raise ValueError(f'Incompatible theta ranges between {attr_desc}scans:'+
                        f'\n\tScan {scan_number}: {theta_vals}'+
                        f'\n\tScan {self.scan_numbers[0]}: {parser.theta_vals}')
            return

        # Select the theta range for the tomo reconstruction from the first scan
        theta_range_approved = False
        thetas = np.linspace(spec_theta_start, spec_theta_end, spec_num_theta)
        delta_theta = thetas[1]-thetas[0]
        print(f'Theta range obtained from SPEC data: [{spec_theta_start}, {spec_theta_end}]')
        print(f'Theta step size = {delta_theta}')
        print(f'Number of theta values: {spec_num_theta}')
        default_start = None
        default_end = None
        if station in ('id1a3', 'id3a'):
            theta_range_approved = input_yesno(f'Accept this theta range (y/n)?', 'y')
            if theta_range_approved:
                self.theta_range = {'start': float(spec_theta_start), 'end': float(spec_theta_end),
                        'num': int(spec_num_theta), 'start_index': 0}
                return
        elif station in ('id3b'):
            if spec_theta_start <= 0.0 and spec_theta_end >= 180.0:
                default_start = 0
                default_end = 180
            elif spec_theta_end-spec_theta_start == 180:
                default_start = spec_theta_start
                default_end = spec_theta_end
        while not theta_range_approved:
            theta_start = input_num(f'Enter the first theta (included)', ge=spec_theta_start,
                    lt=spec_theta_end, default=default_start)
            theta_index_start = index_nearest(thetas, theta_start)
            theta_start = thetas[theta_index_start]
            theta_end = input_num(f'Enter the last theta (excluded)',
                    ge=theta_start+delta_theta, le=spec_theta_end, default=default_end)
            theta_index_end = index_nearest(thetas, theta_end)
            theta_end = thetas[theta_index_end]
            num_theta = theta_index_end-theta_index_start
            print(f'Selected theta range: [{theta_start}, {theta_start+delta_theta}, ..., '+
                    f'{theta_end})')
            print(f'Number of theta values: {num_theta}')
            theta_range_approved = input_yesno(f'Accept this theta range (y/n)?', 'y')
        self.theta_range = {'start': float(theta_start), 'end': float(theta_end),
                'num': int(num_theta), 'start_index': int(theta_index_start)}

    def image_range_cli(self, attr_desc, detector_prefix):
        stack_info = self.stack_info
        for scan_number in self.scan_numbers:
            # Parse the available image range
            parser = self.get_scanparser(scan_number)
            image_offset = parser.starting_image_offset
            num_image = parser.get_num_image(detector_prefix.upper())
            scan_index = self.get_scan_index(scan_number)

            # Select the image set matching the theta range
            num_theta = self.theta_range['num']
            theta_index_start = self.theta_range['start_index']
            if num_theta > num_image-theta_index_start:
                raise ValueError(f'Available {attr_desc}image indices incompatible with thetas:'+
                        f'\n\tNumber of thetas and offset = {num_theta} and {theta_index_start}'+
                        f'\n\tNumber of available images {num_image}')
            if scan_index is not None:
                scan_info = stack_info[scan_index]
                scan_info['starting_image_offset'] = image_offset+theta_index_start
                scan_info['num_image'] = num_theta
                scan_info['ref_x'] = parser.horizontal_shift
                scan_info['ref_z'] = parser.vertical_shift
            else:
                stack_info.append({'scan_number': scan_number,
                        'starting_image_offset': image_offset+theta_index_start,
                        'num_image': num_theta, 'ref_x': parser.horizontal_shift,
                        'ref_z': parser.vertical_shift})
        self.stack_info = stack_info

    def cli(self, **cli_kwargs):
        if cli_kwargs.get('attr_desc') is not None:
            attr_desc = f'{cli_kwargs["attr_desc"]} '
        else:
            attr_desc = ''
        cycle = cli_kwargs.get('cycle')
        btr = cli_kwargs.get('btr')
        station = cli_kwargs.get('station')
        detector = cli_kwargs.get('detector')
        sample_name = cli_kwargs.get('sample_name')
        print(f'\n -- Configure the location of the {attr_desc}scan data -- ')
        if station in ('id1a3', 'id3a'):
            basedir = f'/nfs/chess/{station}/{cycle}/{btr}'
            runs = [d for d in os.listdir(basedir) if os.path.isdir(os.path.join(basedir, d))]
#RV            index = 15-1
#RV            index = 7-1
            if sample_name is not None and sample_name in runs:
                index = runs.index(sample_name)
            else:
                index = input_menu(runs, header='Choose a sample directory')
            self.spec_file = f'{basedir}/{runs[index]}/spec.log'
            self.scan_numbers_cli(attr_desc, station=station, scan_type='ts1')
        else:
            self.set_single_attr_cli('spec_file', attr_desc+'SPEC file path')
            self.scan_numbers_cli(attr_desc)
        for scan_number in self.scan_numbers:
            self.theta_range_cli(scan_number, attr_desc, station)
        self.image_range_cli(attr_desc, detector.prefix)


class Sample(BaseModel):
    name: constr(min_length=1)
    description: Optional[str]
    rotation_angles: Optional[list]
    x_translations: Optional[list]
    z_translations: Optional[list]

    @classmethod
    def construct_from_nxsample(cls, nxsample:NXsample):
        config = {}
        config['name'] = nxsample.name.nxdata
        if 'description' in nxsample:
            config['description'] = nxsample.description.nxdata
        if 'rotation_angle' in nxsample:
            config['rotation_angle'] = nxsample.rotation_angle.nxdata
        if 'x_translation' in nxsample:
            config['x_translation'] = nxsample.x_translation.nxdata
        if 'z_translation' in nxsample:
            config['z_translation'] = nxsample.z_translation.nxdata
        return(cls(**config))

    def cli(self):
        print('\n -- Configure the sample metadata -- ')
#RV        self.name = 'sobhani-3249-A'
#RV        self.name = 'tenstom_1304r-1'
        self.set_single_attr_cli('name', 'the sample name')
#RV        self.description = 'test sample'
        self.set_single_attr_cli('description', 'a description of the sample (optional)')


class MapConfig(BaseModel):
    cycle: constr(strip_whitespace=True, min_length=1)
    btr: constr(strip_whitespace=True, min_length=1)
    title: constr(strip_whitespace=True, min_length=1)
    station: Literal['id1a3', 'id3a', 'id3b'] = None
    sample: Sample
    detector: Detector = Detector.construct()
    tomo_fields: TomoField
    dark_field: Optional[FlatField]
    bright_field: FlatField
    _thetas: list[float] = PrivateAttr()
    _field_types = ({'name': 'dark_field', 'image_key': 2}, {'name': 'bright_field',
            'image_key': 1}, {'name': 'tomo_fields', 'image_key': 0})

    @classmethod
    def construct_from_nxentry(cls, nxentry:NXentry):
        config = {}
        config['cycle'] = nxentry.instrument.source.attrs['cycle']
        config['btr'] = nxentry.instrument.source.attrs['btr']
        config['title'] = nxentry.nxname
        config['station'] = nxentry.instrument.source.attrs['station']
        config['sample'] = Sample.construct_from_nxsample(nxentry['sample'])
        for nxobject_name, nxobject in nxentry.spec_scans.items():
            if isinstance(nxobject, NXcollection):
                config[nxobject_name] = SpecScans.construct_from_nxcollection(nxobject)
        return(cls(**config))

#FIX cache?
    @property
    def thetas(self):
        try:
            return(self._thetas)
        except:
            theta_range = self.tomo_fields.theta_range
            self._thetas = list(np.linspace(theta_range['start'], theta_range['end'],
                    theta_range['num']))
            return(self._thetas)

    def cli(self):
        print('\n -- Configure a map from a set of SPEC scans (dark, bright, and tomo), '+
                'and / or detector data -- ')
#RV        self.cycle = '2021-3'
#RV        self.cycle = '2022-2'
#RV        self.cycle = '2023-1'
        self.set_single_attr_cli('cycle', 'beam cycle')
#RV        self.btr = 'z-3234-A'
#RV        self.btr = 'sobhani-3249-A'
#RV        self.btr = 'przybyla-3606-a'
        self.set_single_attr_cli('btr', 'BTR')
#RV        self.title = 'z-3234-A'
#RV        self.title = 'tomo7C'
#RV        self.title = 'cmc-test-dwell-1'
        self.set_single_attr_cli('title', 'title for the map entry')
#RV        self.station = 'id3a'
#RV        self.station = 'id3b'
#RV        self.station = 'id1a3'
        self.set_single_attr_cli('station', 'name of the station at which scans were collected '+
                '(currently choose from: id1a3, id3a, id3b)')
        import_scanparser(self.station)
        self.set_single_attr_cli('sample')
        use_detector_config = False
        if hasattr(self.detector, 'prefix') and len(self.detector.prefix):
            use_detector_config = input_yesno(f'Current detector settings:\n{self.detector}\n'+
                    f'Keep these settings? (y/n)')
        if not use_detector_config:
            menu_options = ['not listed', 'andor2', 'manta', 'retiga']
            input_mode = input_menu(menu_options, header='Choose one of the following detector '+
                    'configuration options')
            if input_mode:
                detector_config_file = f'{menu_options[input_mode]}.yaml'
                have_detector_config = self.detector.construct_from_yaml(detector_config_file)
            else:
                have_detector_config = False
            if not have_detector_config:
                self.set_single_attr_cli('detector', 'detector')
        self.set_single_attr_cli('tomo_fields', 'Tomo field', chain_attr_desc=True,
                cycle=self.cycle, btr=self.btr, station=self.station, detector=self.detector,
                sample_name=self.sample.name)
        if self.station in ('id1a3', 'id3a'):
            have_dark_field = True
            tomo_spec_file = self.tomo_fields.spec_file
        else:
            have_dark_field = input_yesno(f'Are Dark field images available? (y/n)')
            tomo_spec_file = None
        if have_dark_field:
            self.set_single_attr_cli('dark_field', 'Dark field', chain_attr_desc=True,
                    station=self.station, detector=self.detector, spec_file=tomo_spec_file,
                    tomo_scan_numbers=self.tomo_fields.scan_numbers, scan_type='df1')
        self.set_single_attr_cli('bright_field', 'Bright field', chain_attr_desc=True,
                station=self.station, detector=self.detector, spec_file=tomo_spec_file,
                    tomo_scan_numbers=self.tomo_fields.scan_numbers, scan_type='bf1')

    def construct_nxentry(self, nxroot, include_raw_data=True):
        # Construct base NXentry
        nxentry = NXentry()

        # Add an NXentry to the NXroot
        nxroot[self.title] = nxentry
        nxroot.attrs['default'] = self.title
        nxentry.definition = 'NXtomo'
#        nxentry.attrs['default'] = 'data'

        # Add an NXinstrument to the NXentry
        nxinstrument = NXinstrument()
        nxentry.instrument = nxinstrument

        # Add an NXsource to the NXinstrument
        nxsource = NXsource()
        nxinstrument.source = nxsource
        nxsource.type = 'Synchrotron X-ray Source'
        nxsource.name = 'CHESS'
        nxsource.probe = 'x-ray'
 
        # Tag the NXsource with the runinfo (as an attribute)
        nxsource.attrs['cycle'] = self.cycle
        nxsource.attrs['btr'] = self.btr
        nxsource.attrs['station'] = self.station

        # Add an NXdetector to the NXinstrument (don't fill in data fields yet)
        nxinstrument.detector = self.detector.construct_nxdetector()

        # Add an NXsample to NXentry (don't fill in data fields yet)
        nxsample = NXsample()
        nxentry.sample = nxsample
        nxsample.name = self.sample.name
        nxsample.description = self.sample.description

        # Add an NXcollection to the base NXentry to hold metadata about the spec scans in the map
        # Also obtain the data fields in NXsample and NXdetector
        nxspec_scans = NXcollection()
        nxentry.spec_scans = nxspec_scans
        image_keys = []
        sequence_numbers = []
        image_stacks = []
        rotation_angles = []
        x_translations = []
        z_translations = []
        for field_type in self._field_types:
            field_name = field_type['name']
            field = getattr(self, field_name)
            if field is None:
                continue
            image_key = field_type['image_key']
            if field_type['name'] == 'tomo_fields':
                thetas = self.thetas
            else:
                thetas = None
            # Add the scans in a single spec file
            nxspec_scans[field_name] = field.construct_nxcollection(image_key, thetas,
                    self.detector)
            if include_raw_data:
                image_stacks += field.get_detector_data(self.detector.prefix)
                for scan_number in field.scan_numbers:
                    parser = field.get_scanparser(scan_number)
                    scan_info = field.stack_info[field.get_scan_index(scan_number)]
                    num_image = scan_info['num_image']
                    image_keys += num_image*[image_key]
                    sequence_numbers += [i for i in range(num_image)]
                    if thetas is None:
                        rotation_angles += scan_info['num_image']*[0.0]
                    else:
                        assert(num_image == len(thetas))
                        rotation_angles += thetas
                    x_translations += scan_info['num_image']*[scan_info['ref_x']]
                    z_translations += scan_info['num_image']*[scan_info['ref_z']]

        if include_raw_data:
            # Add image data to NXdetector
            nxinstrument.detector.image_key = image_keys
            nxinstrument.detector.sequence_number = sequence_numbers
            nxinstrument.detector.data = np.concatenate([image for image in image_stacks])

            # Add image data to NXsample
            nxsample.rotation_angle = rotation_angles
            nxsample.rotation_angle.attrs['units'] = 'degrees'
            nxsample.x_translation = x_translations
            nxsample.x_translation.attrs['units'] = 'mm'
            nxsample.z_translation = z_translations
            nxsample.z_translation.attrs['units'] = 'mm'

            # Add an NXdata to NXentry
            nxdata = NXdata()
            nxentry.data = nxdata
            nxdata.makelink(nxentry.instrument.detector.data, name='data')
            nxdata.makelink(nxentry.instrument.detector.image_key)
            nxdata.makelink(nxentry.sample.rotation_angle)
            nxdata.makelink(nxentry.sample.x_translation)
            nxdata.makelink(nxentry.sample.z_translation)
#            nxdata.attrs['axes'] = ['field', 'row', 'column']
#            nxdata.attrs['field_indices'] = 0
#            nxdata.attrs['row_indices'] = 1
#            nxdata.attrs['column_indices'] = 2


class TomoWorkflow(BaseModel):
    sample_maps: conlist(item_type=MapConfig, min_items=1) = [MapConfig.construct()]

    @classmethod
    def construct_from_nexus(cls, filename):
        nxroot = nxload(filename)
        sample_maps = []
        config = {'sample_maps': sample_maps}
        for nxentry_name, nxentry in nxroot.items():
            sample_maps.append(MapConfig.construct_from_nxentry(nxentry))
        return(cls(**config))

    def cli(self):
        print('\n -- Configure a map -- ')
        self.set_list_attr_cli('sample_maps', 'sample map')

    def construct_nxfile(self, filename, mode='w-'):
        nxroot = NXroot()
        t0 = time()
        for sample_map in self.sample_maps:
            logger.info(f'Start constructing the {sample_map.title} map.')
            import_scanparser(sample_map.station)
            sample_map.construct_nxentry(nxroot)
        logger.info(f'Constructed all sample maps in {time()-t0:.2f} seconds.')
        logger.info(f'Start saving all sample maps to {filename}.')
        nxroot.save(filename, mode=mode)

    def write_to_nexus(self, filename):
        t0 = time()
        self.construct_nxfile(filename, mode='w')
        logger.info(f'Saved all sample maps to {filename} in {time()-t0:.2f} seconds.')