Mercurial > repos > rv43 > tomo
diff workflow/models.py @ 69:fba792d5f83b draft
planemo upload for repository https://github.com/rolfverberg/galaxytools commit ab9f412c362a4ab986d00e21d5185cfcf82485c1
author | rv43 |
---|---|
date | Fri, 10 Mar 2023 16:02:04 +0000 |
parents | |
children | 1cf15b61cd83 |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/workflow/models.py Fri Mar 10 16:02:04 2023 +0000 @@ -0,0 +1,1077 @@ +#!/usr/bin/env python3 + +import logging +logger = logging.getLogger(__name__) + +import logging + +import numpy as np +import os +import yaml + +from functools import cache +from pathlib import PosixPath +from pydantic import validator, ValidationError, conint, confloat, constr, \ + conlist, FilePath, PrivateAttr +from pydantic import BaseModel as PydanticBaseModel +from nexusformat.nexus import * +from time import time +from typing import Optional, Literal +from typing_extensions import TypedDict +try: + from pyspec.file.spec import FileSpec +except: + pass + +from msnctools.general import is_int, is_num, input_int, input_int_list, input_num, input_yesno, \ + input_menu, index_nearest, string_to_list, file_exists_and_readable + + +def import_scanparser(station): + if station in ('id1a3', 'id3a'): + from msnctools.scanparsers import SMBRotationScanParser + globals()['ScanParser'] = SMBRotationScanParser + elif station in ('id3b'): + from msnctools.scanparsers import FMBRotationScanParser + globals()['ScanParser'] = FMBRotationScanParser + else: + raise RuntimeError(f'Invalid station: {station}') + +@cache +def get_available_scan_numbers(spec_file:str): + scans = FileSpec(spec_file).scans + scan_numbers = list(scans.keys()) + for scan_number in scan_numbers.copy(): + try: + parser = ScanParser(spec_file, scan_number) + try: + scan_type = parser.get_scan_type() + except: + scan_type = None + pass + except: + scan_numbers.remove(scan_number) + return(scan_numbers) + +@cache +def get_scanparser(spec_file:str, scan_number:int): + if scan_number not in get_available_scan_numbers(spec_file): + return(None) + else: + return(ScanParser(spec_file, scan_number)) + + +class BaseModel(PydanticBaseModel): + class Config: + validate_assignment = True + arbitrary_types_allowed = True + + @classmethod + def construct_from_cli(cls): + obj = cls.construct() + obj.cli() + return(obj) + + @classmethod + def construct_from_yaml(cls, filename): + try: + with open(filename, 'r') as infile: + indict = yaml.load(infile, Loader=yaml.CLoader) + except: + raise ValueError(f'Could not load a dictionary from {filename}') + else: + obj = cls(**indict) + return(obj) + + @classmethod + def construct_from_file(cls, filename): + file_exists_and_readable(filename) + filename = os.path.abspath(filename) + fileformat = os.path.splitext(filename)[1] + yaml_extensions = ('.yaml','.yml') + nexus_extensions = ('.nxs','.nx5','.h5','.hdf5') + t0 = time() + if fileformat.lower() in yaml_extensions: + obj = cls.construct_from_yaml(filename) + logger.info(f'Constructed a model from {filename} in {time()-t0:.2f} seconds.') + return(obj) + elif fileformat.lower() in nexus_extensions: + obj = cls.construct_from_nexus(filename) + logger.info(f'Constructed a model from {filename} in {time()-t0:.2f} seconds.') + return(obj) + else: + logger.error(f'Unsupported file extension for constructing a model: {fileformat}') + raise TypeError(f'Unrecognized file extension: {fileformat}') + + def dict_for_yaml(self, exclude_fields=[]): + yaml_dict = {} + for field_name in self.__fields__: + if field_name in exclude_fields: + continue + else: + field_value = getattr(self, field_name, None) + if field_value is not None: + if isinstance(field_value, BaseModel): + yaml_dict[field_name] = field_value.dict_for_yaml() + elif isinstance(field_value,list) and all(isinstance(item,BaseModel) + for item in field_value): + yaml_dict[field_name] = [item.dict_for_yaml() for item in field_value] + elif isinstance(field_value, PosixPath): + yaml_dict[field_name] = str(field_value) + else: + yaml_dict[field_name] = field_value + else: + continue + return(yaml_dict) + + def write_to_yaml(self, filename=None): + yaml_dict = self.dict_for_yaml() + if filename is None: + logger.info('Printing yaml representation here:\n'+ + f'{yaml.dump(yaml_dict, sort_keys=False)}') + else: + try: + with open(filename, 'w') as outfile: + yaml.dump(yaml_dict, outfile, sort_keys=False) + logger.info(f'Successfully wrote this model to {filename}') + except: + logger.error(f'Unknown error -- could not write to {filename} in yaml format.') + logger.info('Printing yaml representation here:\n'+ + f'{yaml.dump(yaml_dict, sort_keys=False)}') + + def write_to_file(self, filename, force_overwrite=False): + file_writeable, fileformat = self.output_file_valid(filename, + force_overwrite=force_overwrite) + if fileformat == 'yaml': + if file_writeable: + self.write_to_yaml(filename=filename) + else: + self.write_to_yaml() + elif fileformat == 'nexus': + if file_writeable: + self.write_to_nexus(filename=filename) + + def output_file_valid(self, filename, force_overwrite=False): + filename = os.path.abspath(filename) + fileformat = os.path.splitext(filename)[1] + yaml_extensions = ('.yaml','.yml') + nexus_extensions = ('.nxs','.nx5','.h5','.hdf5') + if fileformat.lower() not in (*yaml_extensions, *nexus_extensions): + return(False, None) # Only yaml and NeXus files allowed for output now. + elif fileformat.lower() in yaml_extensions: + fileformat = 'yaml' + elif fileformat.lower() in nexus_extensions: + fileformat = 'nexus' + if os.path.isfile(filename): + if os.access(filename, os.W_OK): + if not force_overwrite: + logger.error(f'{filename} will not be overwritten.') + return(False, fileformat) + else: + logger.error(f'Cannot access {filename} for writing.') + return(False, fileformat) + if os.path.isdir(os.path.dirname(filename)): + if os.access(os.path.dirname(filename), os.W_OK): + return(True, fileformat) + else: + logger.error(f'Cannot access {os.path.dirname(filename)} for writing.') + return(False, fileformat) + else: + try: + os.makedirs(os.path.dirname(filename)) + return(True, fileformat) + except: + logger.error(f'Cannot create {os.path.dirname(filename)} for output.') + return(False, fileformat) + + def set_single_attr_cli(self, attr_name, attr_desc='unknown attribute', list_flag=False, + **cli_kwargs): + if cli_kwargs.get('chain_attr_desc', False): + cli_kwargs['attr_desc'] = attr_desc + try: + attr = getattr(self, attr_name, None) + if attr is None: + attr = self.__fields__[attr_name].type_.construct() + if cli_kwargs.get('chain_attr_desc', False): + cli_kwargs['attr_desc'] = attr_desc + input_accepted = False + while not input_accepted: + try: + attr.cli(**cli_kwargs) + except ValidationError as e: + print(e) + print(f'Removing {attr_desc} configuration') + attr = self.__fields__[attr_name].type_.construct() + continue + except KeyboardInterrupt as e: + raise e + except BaseException as e: + print(f'{type(e).__name__}: {e}') + print(f'Removing {attr_desc} configuration') + attr = self.__fields__[attr_name].type_.construct() + continue + try: + setattr(self, attr_name, attr) + except ValidationError as e: + print(e) + except KeyboardInterrupt as e: + raise e + except BaseException as e: + print(f'{type(e).__name__}: {e}') + else: + input_accepted = True + except: + input_accepted = False + while not input_accepted: + attr = getattr(self, attr_name, None) + if attr is None: + input_value = input(f'Type and enter a value for {attr_desc}: ') + else: + input_value = input(f'Type and enter a new value for {attr_desc} or press '+ + f'enter to keep the current one ({attr}): ') + if list_flag: + input_value = string_to_list(input_value, remove_duplicates=False, sort=False) + if len(input_value) == 0: + input_value = getattr(self, attr_name, None) + try: + setattr(self, attr_name, input_value) + except ValidationError as e: + print(e) + except KeyboardInterrupt as e: + raise e + except BaseException as e: + print(f'Unexpected {type(e).__name__}: {e}') + else: + input_accepted = True + + def set_list_attr_cli(self, attr_name, attr_desc='unknown attribute', **cli_kwargs): + if cli_kwargs.get('chain_attr_desc', False): + cli_kwargs['attr_desc'] = attr_desc + attr = getattr(self, attr_name, None) + if attr is not None: + # Check existing items + for item in attr: + item_accepted = False + while not item_accepted: + item.cli(**cli_kwargs) + try: + setattr(self, attr_name, attr) + except ValidationError as e: + print(e) + except KeyboardInterrupt as e: + raise e + except BaseException as e: + print(f'{type(e).__name__}: {e}') + else: + item_accepted = True + else: + # Initialize list for new attribute & starting item + attr = [] + item = self.__fields__[attr_name].type_.construct() + # Append (optional) additional items + append = input_yesno(f'Add a {attr_desc} configuration? (y/n)', 'n') + while append: + attr.append(item.__class__.construct_from_cli()) + try: + setattr(self, attr_name, attr) + except ValidationError as e: + print(e) + print(f'Removing last {attr_desc} configuration from the list') + attr.pop() + except KeyboardInterrupt as e: + raise e + except BaseException as e: + print(f'{type(e).__name__}: {e}') + print(f'Removing last {attr_desc} configuration from the list') + attr.pop() + else: + append = input_yesno(f'Add another {attr_desc} configuration? (y/n)', 'n') + + +class Detector(BaseModel): + prefix: constr(strip_whitespace=True, min_length=1) + rows: conint(gt=0) + columns: conint(gt=0) + pixel_size: conlist(item_type=confloat(gt=0), min_items=1, max_items=2) + lens_magnification: confloat(gt=0) = 1.0 + + @property + def get_pixel_size(self): + return(list(np.asarray(self.pixel_size)/self.lens_magnification)) + + def construct_from_yaml(self, filename): + try: + with open(filename, 'r') as infile: + indict = yaml.load(infile, Loader=yaml.CLoader) + detector = indict['detector'] + self.prefix = detector['id'] + pixels = detector['pixels'] + self.rows = pixels['rows'] + self.columns = pixels['columns'] + self.pixel_size = pixels['size'] + self.lens_magnification = indict['lens_magnification'] + except: + logging.warning(f'Could not load a dictionary from {filename}') + return(False) + else: + return(True) + + def cli(self): + print('\n -- Configure the detector -- ') + self.set_single_attr_cli('prefix', 'detector ID') + self.set_single_attr_cli('rows', 'number of pixel rows') + self.set_single_attr_cli('columns', 'number of pixel columns') + self.set_single_attr_cli('pixel_size', 'pixel size in mm (enter either a single value for '+ + 'square pixels or a pair of values for the size in the respective row and column '+ + 'directions)', list_flag=True) + self.set_single_attr_cli('lens_magnification', 'lens magnification') + + def construct_nxdetector(self): + nxdetector = NXdetector() + nxdetector.local_name = self.prefix + pixel_size = self.get_pixel_size + if len(pixel_size) == 1: + nxdetector.x_pixel_size = pixel_size[0] + nxdetector.y_pixel_size = pixel_size[0] + else: + nxdetector.x_pixel_size = pixel_size[0] + nxdetector.y_pixel_size = pixel_size[1] + nxdetector.x_pixel_size.attrs['units'] = 'mm' + nxdetector.y_pixel_size.attrs['units'] = 'mm' + return(nxdetector) + + +class ScanInfo(TypedDict): + scan_number: int + starting_image_offset: conint(ge=0) + num_image: conint(gt=0) + ref_x: float + ref_z: float + +class SpecScans(BaseModel): + spec_file: FilePath + scan_numbers: conlist(item_type=conint(gt=0), min_items=1) + stack_info: conlist(item_type=ScanInfo, min_items=1) = [] + + @validator('spec_file') + def validate_spec_file(cls, spec_file): + try: + spec_file = os.path.abspath(spec_file) + sspec_file = FileSpec(spec_file) + except: + raise ValueError(f'Invalid SPEC file {spec_file}') + else: + return(spec_file) + + @validator('scan_numbers') + def validate_scan_numbers(cls, scan_numbers, values): + spec_file = values.get('spec_file') + if spec_file is not None: + spec_scans = FileSpec(spec_file) + for scan_number in scan_numbers: + scan = spec_scans.get_scan_by_number(scan_number) + if scan is None: + raise ValueError(f'There is no scan number {scan_number} in {spec_file}') + return(scan_numbers) + + @validator('stack_info') + def validate_stack_info(cls, stack_info, values): + scan_numbers = values.get('scan_numbers') + assert(len(scan_numbers) == len(stack_info)) + for scan_info in stack_info: + assert(scan_info['scan_number'] in scan_numbers) + is_int(scan_info['starting_image_offset'], ge=0, lt=scan_info['num_image'], + raise_error=True) + return(stack_info) + + @classmethod + def construct_from_nxcollection(cls, nxcollection:NXcollection): + config = {} + config['spec_file'] = nxcollection.attrs['spec_file'] + scan_numbers = [] + stack_info = [] + for nxsubentry_name, nxsubentry in nxcollection.items(): + scan_number = int(nxsubentry_name.split('_')[-1]) + scan_numbers.append(scan_number) + stack_info.append({'scan_number': scan_number, + 'starting_image_offset': int(nxsubentry.instrument.detector.frame_start_number), + 'num_image': len(nxsubentry.sample.rotation_angle), + 'ref_x': float(nxsubentry.sample.x_translation), + 'ref_z': float(nxsubentry.sample.z_translation)}) + config['scan_numbers'] = sorted(scan_numbers) + config['stack_info'] = stack_info + return(cls(**config)) + + @property + def available_scan_numbers(self): + return(get_available_scan_numbers(self.spec_file)) + + def set_from_nxcollection(self, nxcollection:NXcollection): + self.spec_file = nxcollection.attrs['spec_file'] + scan_numbers = [] + stack_info = [] + for nxsubentry_name, nxsubentry in nxcollection.items(): + scan_number = int(nxsubentry_name.split('_')[-1]) + scan_numbers.append(scan_number) + stack_info.append({'scan_number': scan_number, + 'starting_image_offset': int(nxsubentry.instrument.detector.frame_start_number), + 'num_image': len(nxsubentry.sample.rotation_angle), + 'ref_x': float(nxsubentry.sample.x_translation), + 'ref_z': float(nxsubentry.sample.z_translation)}) + self.scan_numbers = sorted(scan_numbers) + self.stack_info = stack_info + + def get_scan_index(self, scan_number): + scan_index = [scan_index for scan_index, scan_info in enumerate(self.stack_info) + if scan_info['scan_number'] == scan_number] + if len(scan_index) > 1: + raise ValueError('Duplicate scan_numbers in image stack') + elif len(scan_index) == 1: + return(scan_index[0]) + else: + return(None) + + def get_scanparser(self, scan_number): + return(get_scanparser(self.spec_file, scan_number)) + +# def get_detector_data(self, detector_prefix, scan_number=None, scan_step_index=None): +# if scan_number is None: +# scan_number = self.scan_numbers[0] +# if scan_step_index is None: +# scan_info = self.stack_info[self.get_scan_index(scan_number)] +# scan_step_index = scan_info['starting_image_offset'] +# parser = self.get_scanparser(scan_number) +# return(parser.get_detector_data(detector_prefix, scan_step_index)) + + def get_detector_data(self, detector_prefix, scan_number=None, scan_step_index=None): + image_stacks = [] + if scan_number is None: + scan_numbers = self.scan_numbers + else: + scan_numbers = [scan_number] + for scan_number in scan_numbers: + parser = self.get_scanparser(scan_number) + scan_info = self.stack_info[self.get_scan_index(scan_number)] + image_offset = scan_info['starting_image_offset'] + if scan_step_index is None: + num_image = scan_info['num_image'] + image_stacks.append(parser.get_detector_data(detector_prefix, + (image_offset, image_offset+num_image))) + else: + image_stacks.append(parser.get_detector_data(detector_prefix, + image_offset+scan_step_index)) + if len(image_stacks) == 1: + return(image_stacks[0]) + else: + return(image_stacks) + + def scan_numbers_cli(self, attr_desc, **kwargs): + available_scan_numbers = self.available_scan_numbers + station = kwargs.get('station') + if (station is not None and station in ('id1a3', 'id3a') and + 'scan_type' in kwargs): + scan_type = kwargs['scan_type'] + if scan_type == 'ts1': + available_scan_numbers = [] + for scan_number in self.available_scan_numbers: + parser = self.get_scanparser(scan_number) + if parser.scan_type == scan_type: + available_scan_numbers.append(scan_number) + elif scan_type == 'df1': + tomo_scan_numbers = kwargs['tomo_scan_numbers'] + available_scan_numbers = [] + for scan_number in tomo_scan_numbers: + parser = self.get_scanparser(scan_number-2) + assert(parser.scan_type == scan_type) + available_scan_numbers.append(scan_number-2) + elif scan_type == 'bf1': + tomo_scan_numbers = kwargs['tomo_scan_numbers'] + available_scan_numbers = [] + for scan_number in tomo_scan_numbers: + parser = self.get_scanparser(scan_number-1) + assert(parser.scan_type == scan_type) + available_scan_numbers.append(scan_number-1) + if len(available_scan_numbers) == 1: + input_mode = 1 + else: + if hasattr(self, 'scan_numbers'): + print(f'Currently selected {attr_desc}scan numbers are: {self.scan_numbers}') + menu_options = [f'Select a subset of the available {attr_desc}scan numbers', + f'Use all available {attr_desc}scan numbers in {self.spec_file}', + f'Keep the currently selected {attr_desc}scan numbers'] + else: + menu_options = [f'Select a subset of the available {attr_desc}scan numbers', + f'Use all available {attr_desc}scan numbers in {self.spec_file}'] + print(f'Available scan numbers in {self.spec_file} are: '+ + f'{available_scan_numbers}') + input_mode = input_menu(menu_options, header='Choose one of the following options '+ + 'for selecting scan numbers') + if input_mode == 0: + accept_scan_numbers = False + while not accept_scan_numbers: + try: + self.scan_numbers = \ + input_int_list(f'Enter a series of {attr_desc}scan numbers') + except ValidationError as e: + print(e) + except KeyboardInterrupt as e: + raise e + except BaseException as e: + print(f'Unexpected {type(e).__name__}: {e}') + else: + accept_scan_numbers = True + elif input_mode == 1: + self.scan_numbers = available_scan_numbers + elif input_mode == 2: + pass + + def cli(self, **cli_kwargs): + if cli_kwargs.get('attr_desc') is not None: + attr_desc = f'{cli_kwargs["attr_desc"]} ' + else: + attr_desc = '' + print(f'\n -- Configure which scans to use from a single {attr_desc}SPEC file') + self.set_single_attr_cli('spec_file', attr_desc+'SPEC file path') + self.scan_numbers_cli(attr_desc) + + def construct_nxcollection(self, image_key, thetas, detector): + nxcollection = NXcollection() + nxcollection.attrs['spec_file'] = str(self.spec_file) + parser = self.get_scanparser(self.scan_numbers[0]) + nxcollection.attrs['date'] = parser.spec_scan.file_date + for scan_number in self.scan_numbers: + # Get scan info + scan_info = self.stack_info[self.get_scan_index(scan_number)] + # Add an NXsubentry to the NXcollection for each scan + entry_name = f'scan_{scan_number}' + nxsubentry = NXsubentry() + nxcollection[entry_name] = nxsubentry + parser = self.get_scanparser(scan_number) + nxsubentry.start_time = parser.spec_scan.date + nxsubentry.spec_command = parser.spec_command + # Add an NXdata for independent dimensions to the scan's NXsubentry + num_image = scan_info['num_image'] + if thetas is None: + thetas = num_image*[0.0] + else: + assert(num_image == len(thetas)) +# nxsubentry.independent_dimensions = NXdata() +# nxsubentry.independent_dimensions.rotation_angle = thetas +# nxsubentry.independent_dimensions.rotation_angle.units = 'degrees' + # Add an NXinstrument to the scan's NXsubentry + nxsubentry.instrument = NXinstrument() + # Add an NXdetector to the NXinstrument to the scan's NXsubentry + nxsubentry.instrument.detector = detector.construct_nxdetector() + nxsubentry.instrument.detector.frame_start_number = scan_info['starting_image_offset'] + nxsubentry.instrument.detector.image_key = image_key + # Add an NXsample to the scan's NXsubentry + nxsubentry.sample = NXsample() + nxsubentry.sample.rotation_angle = thetas + nxsubentry.sample.rotation_angle.units = 'degrees' + nxsubentry.sample.x_translation = scan_info['ref_x'] + nxsubentry.sample.x_translation.units = 'mm' + nxsubentry.sample.z_translation = scan_info['ref_z'] + nxsubentry.sample.z_translation.units = 'mm' + return(nxcollection) + + +class FlatField(SpecScans): + + def image_range_cli(self, attr_desc, detector_prefix): + stack_info = self.stack_info + for scan_number in self.scan_numbers: + # Parse the available image range + parser = self.get_scanparser(scan_number) + image_offset = parser.starting_image_offset + num_image = parser.get_num_image(detector_prefix.upper()) + scan_index = self.get_scan_index(scan_number) + + # Select the image set + last_image_index = image_offset+num_image-1 + print(f'Available good image set index range: [{image_offset}, {last_image_index}]') + image_set_approved = False + if scan_index is not None: + scan_info = stack_info[scan_index] + print(f'Current starting image offset and number of images: '+ + f'{scan_info["starting_image_offset"]} and {scan_info["num_image"]}') + image_set_approved = input_yesno(f'Accept these values (y/n)?', 'y') + if not image_set_approved: + print(f'Default starting image offset and number of images: '+ + f'{image_offset} and {last_image_index-image_offset}') + image_set_approved = input_yesno(f'Accept these values (y/n)?', 'y') + if image_set_approved: + offset = image_offset + num = last_image_index-offset + while not image_set_approved: + offset = input_int(f'Enter the starting image offset', ge=image_offset, + le=last_image_index-1)#, default=image_offset) + num = input_int(f'Enter the number of images', ge=1, + le=last_image_index-offset+1)#, default=last_image_index-offset+1) + print(f'Current starting image offset and number of images: {offset} and {num}') + image_set_approved = input_yesno(f'Accept these values (y/n)?', 'y') + if scan_index is not None: + scan_info['starting_image_offset'] = offset + scan_info['num_image'] = num + scan_info['ref_x'] = parser.horizontal_shift + scan_info['ref_z'] = parser.vertical_shift + else: + stack_info.append({'scan_number': scan_number, 'starting_image_offset': offset, + 'num_image': num, 'ref_x': parser.horizontal_shift, + 'ref_z': parser.vertical_shift}) + self.stack_info = stack_info + + def cli(self, **cli_kwargs): + if cli_kwargs.get('attr_desc') is not None: + attr_desc = f'{cli_kwargs["attr_desc"]} ' + else: + attr_desc = '' + station = cli_kwargs.get('station') + detector = cli_kwargs.get('detector') + print(f'\n -- Configure the location of the {attr_desc}scan data -- ') + if station in ('id1a3', 'id3a'): + self.spec_file = cli_kwargs['spec_file'] + tomo_scan_numbers = cli_kwargs['tomo_scan_numbers'] + scan_type = cli_kwargs['scan_type'] + self.scan_numbers_cli(attr_desc, station=station, tomo_scan_numbers=tomo_scan_numbers, + scan_type=scan_type) + else: + self.set_single_attr_cli('spec_file', attr_desc+'SPEC file path') + self.scan_numbers_cli(attr_desc) + self.image_range_cli(attr_desc, detector.prefix) + + +class TomoField(SpecScans): + theta_range: dict = {} + + @validator('theta_range') + def validate_theta_range(cls, theta_range): + if len(theta_range) != 3 and len(theta_range) != 4: + raise ValueError(f'Invalid theta range {theta_range}') + is_num(theta_range['start'], raise_error=True) + is_num(theta_range['end'], raise_error=True) + is_int(theta_range['num'], gt=1, raise_error=True) + if theta_range['end'] <= theta_range['start']: + raise ValueError(f'Invalid theta range {theta_range}') + if 'start_index' in theta_range: + is_int(theta_range['start_index'], ge=0, raise_error=True) + return(theta_range) + + @classmethod + def construct_from_nxcollection(cls, nxcollection:NXcollection): + #RV Can I derive this from the same classfunction for SpecScans by adding theta_range + config = {} + config['spec_file'] = nxcollection.attrs['spec_file'] + scan_numbers = [] + stack_info = [] + for nxsubentry_name, nxsubentry in nxcollection.items(): + scan_number = int(nxsubentry_name.split('_')[-1]) + scan_numbers.append(scan_number) + stack_info.append({'scan_number': scan_number, + 'starting_image_offset': int(nxsubentry.instrument.detector.frame_start_number), + 'num_image': len(nxsubentry.sample.rotation_angle), + 'ref_x': float(nxsubentry.sample.x_translation), + 'ref_z': float(nxsubentry.sample.z_translation)}) + config['scan_numbers'] = sorted(scan_numbers) + config['stack_info'] = stack_info + for name in nxcollection.entries: + if 'scan_' in name: + thetas = np.asarray(nxcollection[name].sample.rotation_angle) + config['theta_range'] = {'start': thetas[0], 'end': thetas[-1], 'num': thetas.size} + break + return(cls(**config)) + + def get_horizontal_shifts(self, scan_number=None): + horizontal_shifts = [] + if scan_number is None: + scan_numbers = self.scan_numbers + else: + scan_numbers = [scan_number] + for scan_number in scan_numbers: + parser = self.get_scanparser(scan_number) + horizontal_shifts.append(parser.get_horizontal_shift()) + if len(horizontal_shifts) == 1: + return(horizontal_shifts[0]) + else: + return(horizontal_shifts) + + def get_vertical_shifts(self, scan_number=None): + vertical_shifts = [] + if scan_number is None: + scan_numbers = self.scan_numbers + else: + scan_numbers = [scan_number] + for scan_number in scan_numbers: + parser = self.get_scanparser(scan_number) + vertical_shifts.append(parser.get_vertical_shift()) + if len(vertical_shifts) == 1: + return(vertical_shifts[0]) + else: + return(vertical_shifts) + + def theta_range_cli(self, scan_number, attr_desc, station): + # Parse the available theta range + parser = self.get_scanparser(scan_number) + theta_vals = parser.theta_vals + spec_theta_start = theta_vals.get('start') + spec_theta_end = theta_vals.get('end') + spec_num_theta = theta_vals.get('num') + + # Check for consistency of theta ranges between scans + if scan_number != self.scan_numbers[0]: + parser = self.get_scanparser(self.scan_numbers[0]) + if (parser.theta_vals.get('start') != spec_theta_start or + parser.theta_vals.get('end') != spec_theta_end or + parser.theta_vals.get('num') != spec_num_theta): + raise ValueError(f'Incompatible theta ranges between {attr_desc}scans:'+ + f'\n\tScan {scan_number}: {theta_vals}'+ + f'\n\tScan {self.scan_numbers[0]}: {parser.theta_vals}') + return + + # Select the theta range for the tomo reconstruction from the first scan + thetas = np.linspace(spec_theta_start, spec_theta_end, spec_num_theta) + delta_theta = thetas[1]-thetas[0] + theta_range_approved = False + print(f'Theta range obtained from SPEC data: [{spec_theta_start}, {spec_theta_end})') + print(f'Theta step size = {delta_theta}') + print(f'Number of theta values: {spec_num_theta}') + exit('Done') + default_start = None + default_end = None + if station in ('id1a3', 'id3a'): + theta_range_approved = input_yesno(f'Accept this theta range (y/n)?', 'y') + if theta_range_approved: + theta_start = spec_theta_start + theta_end = spec_theta_end + num_theta = spec_num_theta + theta_index_start = 0 + elif station in ('id3b'): + if spec_theta_start <= 0.0 and spec_theta_end >= 180.0: + default_start = 0 + default_end = 180 + elif spec_theta_end-spec_theta_start == 180: + default_start = spec_theta_start + default_end = spec_theta_end + while not theta_range_approved: + theta_start = input_num(f'Enter the first theta (included)', ge=spec_theta_start, + lt=spec_theta_end, default=default_start) + theta_index_start = index_nearest(thetas, theta_start) + theta_start = thetas[theta_index_start] + theta_end = input_num(f'Enter the last theta (excluded)', + ge=theta_start+delta_theta, le=spec_theta_end, default=default_end) + theta_index_end = index_nearest(thetas, theta_end) + theta_end = thetas[theta_index_end] + num_theta = theta_index_end-theta_index_start + print(f'Selected theta range: [{theta_start}, {theta_start+delta_theta}, ..., '+ + f'{theta_end})') + print(f'Number of theta values: {num_theta}') + theta_range_approved = input_yesno(f'Accept this theta range (y/n)?', 'y') + self.thetas = np.linspace(theta_start, theta_end, num_theta) + + def image_range_cli(self, attr_desc, detector_prefix): + stack_info = self.stack_info + for scan_number in self.scan_numbers: + # Parse the available image range + parser = self.get_scanparser(scan_number) + image_offset = parser.starting_image_offset + num_image = parser.get_num_image(detector_prefix.upper()) + scan_index = self.get_scan_index(scan_number) + + # Select the image set matching the theta range + num_theta = self.theta_range['num'] + theta_index_start = self.theta_range['start_index'] + if num_theta > num_image-theta_index_start: + raise ValueError(f'Available {attr_desc}image indices incompatible with thetas:'+ + f'\n\tNumber of thetas and offset = {num_theta} and {theta_index_start}'+ + f'\n\tNumber of available images {num_image}') + if scan_index is not None: + scan_info = stack_info[scan_index] + scan_info['starting_image_offset'] = image_offset+theta_index_start + scan_info['num_image'] = num_theta + scan_info['ref_x'] = parser.horizontal_shift + scan_info['ref_z'] = parser.vertical_shift + else: + stack_info.append({'scan_number': scan_number, + 'starting_image_offset': image_offset+theta_index_start, + 'num_image': num_theta, 'ref_x': parser.horizontal_shift, + 'ref_z': parser.vertical_shift}) + self.stack_info = stack_info + + def cli(self, **cli_kwargs): + if cli_kwargs.get('attr_desc') is not None: + attr_desc = f'{cli_kwargs["attr_desc"]} ' + else: + attr_desc = '' + cycle = cli_kwargs.get('cycle') + btr = cli_kwargs.get('btr') + station = cli_kwargs.get('station') + detector = cli_kwargs.get('detector') + print(f'\n -- Configure the location of the {attr_desc}scan data -- ') + if station in ('id1a3', 'id3a'): + basedir = f'/nfs/chess/{station}/{cycle}/{btr}' + runs = [d for d in os.listdir(basedir) if os.path.isdir(os.path.join(basedir, d))] +#RV index = 15-1 +#RV index = 7-1 + index = input_menu(runs, header='Choose a sample directory') + self.spec_file = f'{basedir}/{runs[index]}/spec.log' + self.scan_numbers_cli(attr_desc, station=station, scan_type='ts1') + else: + self.set_single_attr_cli('spec_file', attr_desc+'SPEC file path') + self.scan_numbers_cli(attr_desc) + for scan_number in self.scan_numbers: + self.theta_range_cli(scan_number, attr_desc, station) + self.image_range_cli(attr_desc, detector.prefix) + + +class Sample(BaseModel): + name: constr(min_length=1) + description: Optional[str] + rotation_angles: Optional[list] + x_translations: Optional[list] + z_translations: Optional[list] + + @classmethod + def construct_from_nxsample(cls, nxsample:NXsample): + config = {} + config['name'] = nxsample.name.nxdata + if 'description' in nxsample: + config['description'] = nxsample.description.nxdata + if 'rotation_angle' in nxsample: + config['rotation_angle'] = nxsample.rotation_angle.nxdata + if 'x_translation' in nxsample: + config['x_translation'] = nxsample.x_translation.nxdata + if 'z_translation' in nxsample: + config['z_translation'] = nxsample.z_translation.nxdata + return(cls(**config)) + + def cli(self): + print('\n -- Configure the sample metadata -- ') +#RV self.name = 'test' +#RV self.name = 'sobhani-3249-A' + self.set_single_attr_cli('name', 'the sample name') +#RV self.description = 'test sample' + self.set_single_attr_cli('description', 'a description of the sample (optional)') + + +class MapConfig(BaseModel): + cycle: constr(strip_whitespace=True, min_length=1) + btr: constr(strip_whitespace=True, min_length=1) + title: constr(strip_whitespace=True, min_length=1) + station: Literal['id1a3', 'id3a', 'id3b'] = None + sample: Sample + detector: Detector = Detector.construct() + tomo_fields: TomoField + dark_field: Optional[FlatField] + bright_field: FlatField + _thetas: list[float] = PrivateAttr() + _field_types = ({'name': 'dark_field', 'image_key': 2}, {'name': 'bright_field', + 'image_key': 1}, {'name': 'tomo_fields', 'image_key': 0}) + + @classmethod + def construct_from_nxentry(cls, nxentry:NXentry): + config = {} + config['cycle'] = nxentry.instrument.source.attrs['cycle'] + config['btr'] = nxentry.instrument.source.attrs['btr'] + config['title'] = nxentry.nxname + config['station'] = nxentry.instrument.source.attrs['station'] + config['sample'] = Sample.construct_from_nxsample(nxentry['sample']) + for nxobject_name, nxobject in nxentry.spec_scans.items(): + if isinstance(nxobject, NXcollection): + config[nxobject_name] = SpecScans.construct_from_nxcollection(nxobject) + return(cls(**config)) + +#FIX cache? + @property + def thetas(self): + try: + return(self._thetas) + except: + theta_range = self.tomo_fields.theta_range + self._thetas = list(np.linspace(theta_range['start'], theta_range['end'], + theta_range['num'])) + return(self._thetas) + + def cli(self): + print('\n -- Configure a map from a set of SPEC scans (dark, bright, and tomo), '+ + 'and / or detector data -- ') +#RV self.cycle = '2021-3' +#RV self.cycle = '2022-2' +#RV self.cycle = '2023-1' + self.set_single_attr_cli('cycle', 'beam cycle') +#RV self.btr = 'z-3234-A' +#RV self.btr = 'sobhani-3249-A' +#RV self.btr = 'przybyla-3606-a' + self.set_single_attr_cli('btr', 'BTR') +#RV self.title = 'z-3234-A' +#RV self.title = 'tomo7C' +#RV self.title = 'cmc-test-dwell-1' + self.set_single_attr_cli('title', 'title for the map entry') +#RV self.station = 'id3a' +#RV self.station = 'id3b' +#RV self.station = 'id1a3' + self.set_single_attr_cli('station', 'name of the station at which scans were collected '+ + '(currently choose from: id1a3, id3a, id3b)') + import_scanparser(self.station) + self.set_single_attr_cli('sample') + use_detector_config = False + if hasattr(self.detector, 'prefix') and len(self.detector.prefix): + use_detector_config = input_yesno(f'Current detector settings:\n{self.detector}\n'+ + f'Keep these settings? (y/n)') + if not use_detector_config: +#RV have_detector_config = True + have_detector_config = input_yesno(f'Is a detector configuration file available? (y/n)') + if have_detector_config: +#RV detector_config_file = 'retiga.yaml' +#RV detector_config_file = 'andor2.yaml' + detector_config_file = input(f'Enter detector configuration file name: ') + have_detector_config = self.detector.construct_from_yaml(detector_config_file) + if not have_detector_config: + self.set_single_attr_cli('detector', 'detector') + self.set_single_attr_cli('tomo_fields', 'Tomo field', chain_attr_desc=True, + cycle=self.cycle, btr=self.btr, station=self.station, detector=self.detector) + if self.station in ('id1a3', 'id3a'): + have_dark_field = True + tomo_spec_file = self.tomo_fields.spec_file + else: + have_dark_field = input_yesno(f'Are Dark field images available? (y/n)') + tomo_spec_file = None + if have_dark_field: + self.set_single_attr_cli('dark_field', 'Dark field', chain_attr_desc=True, + station=self.station, detector=self.detector, spec_file=tomo_spec_file, + tomo_scan_numbers=self.tomo_fields.scan_numbers, scan_type='df1') + self.set_single_attr_cli('bright_field', 'Bright field', chain_attr_desc=True, + station=self.station, detector=self.detector, spec_file=tomo_spec_file, + tomo_scan_numbers=self.tomo_fields.scan_numbers, scan_type='bf1') + + def construct_nxentry(self, nxroot, include_raw_data=True): + # Construct base NXentry + nxentry = NXentry() + + # Add an NXentry to the NXroot + nxroot[self.title] = nxentry + nxroot.attrs['default'] = self.title + nxentry.definition = 'NXtomo' +# nxentry.attrs['default'] = 'data' + + # Add an NXinstrument to the NXentry + nxinstrument = NXinstrument() + nxentry.instrument = nxinstrument + + # Add an NXsource to the NXinstrument + nxsource = NXsource() + nxinstrument.source = nxsource + nxsource.type = 'Synchrotron X-ray Source' + nxsource.name = 'CHESS' + nxsource.probe = 'x-ray' + + # Tag the NXsource with the runinfo (as an attribute) + nxsource.attrs['cycle'] = self.cycle + nxsource.attrs['btr'] = self.btr + nxsource.attrs['station'] = self.station + + # Add an NXdetector to the NXinstrument (don't fill in data fields yet) + nxinstrument.detector = self.detector.construct_nxdetector() + + # Add an NXsample to NXentry (don't fill in data fields yet) + nxsample = NXsample() + nxentry.sample = nxsample + nxsample.name = self.sample.name + nxsample.description = self.sample.description + + # Add an NXcollection to the base NXentry to hold metadata about the spec scans in the map + # Also obtain the data fields in NXsample and NXdetector + nxspec_scans = NXcollection() + nxentry.spec_scans = nxspec_scans + image_keys = [] + sequence_numbers = [] + image_stacks = [] + rotation_angles = [] + x_translations = [] + z_translations = [] + for field_type in self._field_types: + field_name = field_type['name'] + field = getattr(self, field_name) + if field is None: + continue + image_key = field_type['image_key'] + if field_type['name'] == 'tomo_fields': + thetas = self.thetas + else: + thetas = None + # Add the scans in a single spec file + nxspec_scans[field_name] = field.construct_nxcollection(image_key, thetas, + self.detector) + if include_raw_data: + image_stacks.append(field.get_detector_data(self.detector.prefix)) + for scan_number in field.scan_numbers: + parser = field.get_scanparser(scan_number) + scan_info = field.stack_info[field.get_scan_index(scan_number)] + num_image = scan_info['num_image'] + image_keys += num_image*[image_key] + sequence_numbers += [i for i in range(num_image)] + if thetas is None: + rotation_angles += scan_info['num_image']*[0.0] + else: + assert(num_image == len(thetas)) + rotation_angles += thetas + x_translations += scan_info['num_image']*[scan_info['ref_x']] + z_translations += scan_info['num_image']*[scan_info['ref_z']] + + if include_raw_data: + # Add image data to NXdetector + nxinstrument.detector.image_key = image_keys + nxinstrument.detector.sequence_number = sequence_numbers + nxinstrument.detector.data = np.concatenate([image for image in image_stacks]) + + # Add image data to NXsample + nxsample.rotation_angle = rotation_angles + nxsample.rotation_angle.attrs['units'] = 'degrees' + nxsample.x_translation = x_translations + nxsample.x_translation.attrs['units'] = 'mm' + nxsample.z_translation = z_translations + nxsample.z_translation.attrs['units'] = 'mm' + + # Add an NXdata to NXentry + nxdata = NXdata() + nxentry.data = nxdata + nxdata.makelink(nxentry.instrument.detector.data, name='data') + nxdata.makelink(nxentry.instrument.detector.image_key) + nxdata.makelink(nxentry.sample.rotation_angle) + nxdata.makelink(nxentry.sample.x_translation) + nxdata.makelink(nxentry.sample.z_translation) +# nxdata.attrs['axes'] = ['field', 'row', 'column'] +# nxdata.attrs['field_indices'] = 0 +# nxdata.attrs['row_indices'] = 1 +# nxdata.attrs['column_indices'] = 2 + + +class TomoWorkflow(BaseModel): + sample_maps: conlist(item_type=MapConfig, min_items=1) = [MapConfig.construct()] + + @classmethod + def construct_from_nexus(cls, filename): + nxroot = nxload(filename) + sample_maps = [] + config = {'sample_maps': sample_maps} + for nxentry_name, nxentry in nxroot.items(): + sample_maps.append(MapConfig.construct_from_nxentry(nxentry)) + return(cls(**config)) + + def cli(self): + print('\n -- Configure a map -- ') + self.set_list_attr_cli('sample_maps', 'sample map') + + def construct_nxfile(self, filename, mode='w-'): + nxroot = NXroot() + t0 = time() + for sample_map in self.sample_maps: + logger.info(f'Start constructing the {sample_map.title} map.') + import_scanparser(sample_map.station) + sample_map.construct_nxentry(nxroot) + logger.info(f'Constructed all sample maps in {time()-t0:.2f} seconds.') + logger.info(f'Start saving all sample maps to {filename}.') + nxroot.save(filename, mode=mode) + + def write_to_nexus(self, filename): + t0 = time() + self.construct_nxfile(filename, mode='w') + logger.info(f'Saved all sample maps to {filename} in {time()-t0:.2f} seconds.')