Mercurial > repos > rv43 > chess_tomo
changeset 4:9aa288729b9a draft
planemo upload for repository https://github.com/rolfverberg/galaxytools commit f8c4bdb31c20c468045ad5e6eb255a293244bc6c-dirty
| author | rv43 |
|---|---|
| date | Mon, 20 Mar 2023 18:44:23 +0000 |
| parents | fc38431f257f |
| children | 543dba81eb15 |
| files | tomo_reduce.py tomo_reduce.xml workflow/__main__.py workflow/__version__.py workflow/link_to_galaxy.py workflow/models.py workflow/run_tomo.py |
| diffstat | 7 files changed, 3088 insertions(+), 1 deletions(-) [+] |
line wrap: on
line diff
--- a/tomo_reduce.py Mon Mar 20 18:30:26 2023 +0000 +++ b/tomo_reduce.py Mon Mar 20 18:44:23 2023 +0000 @@ -73,6 +73,7 @@ logging.debug(f'log = {args.log}') logging.debug(f'is log stdout? {args.log is sys.stdout}') logging.debug(f'log_level = {args.log_level}') + return # Instantiate Tomo object tomo = Tomo(galaxy_flag=args.galaxy_flag)
--- a/tomo_reduce.xml Mon Mar 20 18:30:26 2023 +0000 +++ b/tomo_reduce.xml Mon Mar 20 18:44:23 2023 +0000 @@ -1,4 +1,4 @@ -<tool id="tomo_reduce" name="Tomo Reduce" version="0.1.0" python_template_version="3.9"> +<tool id="tomo_reduce" name="Tomo Reduce" version="0.1.1" python_template_version="3.9"> <description>Reduce tomography images</description> <macros> <import>tomo_macros.xml</import> @@ -17,6 +17,10 @@ </command> <inputs> <param name="input_file" type="data" optional="false" label="Input file"/> + <section name="x_bounds" title="Reduction x bounds"> + <param name="x_bound_low" type="integer" value="-1" label="Lower bound"/> + <param name="x_bound_upp" type="integer" value="-1" label="Upper bound"/> + </section> </inputs> <outputs> <expand macro="common_outputs"/>
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/workflow/__main__.py Mon Mar 20 18:44:23 2023 +0000 @@ -0,0 +1,236 @@ +#!/usr/bin/env python3 + +import logging +logging.getLogger(__name__) + +import argparse +import pathlib +import sys + +from .models import TomoWorkflow as Workflow +try: + from deepdiff import DeepDiff +except: + pass + +parser = argparse.ArgumentParser(description='''Operate on representations of + Tomo data workflows saved to files.''') +parser.add_argument('-l', '--log', +# type=argparse.FileType('w'), + default=sys.stdout, + help='Logging stream or filename') +parser.add_argument('--log_level', + choices=logging._nameToLevel.keys(), + default='INFO', + help='''Specify a preferred logging level.''') +subparsers = parser.add_subparsers(title='subcommands', required=True)#, dest='command') + + +# CONSTRUCT +def construct(args:list) -> None: + if args.template_file is not None: + wf = Workflow.construct_from_file(args.template_file) + wf.cli() + else: + wf = Workflow.construct_from_cli() + wf.write_to_file(args.output_file, force_overwrite=args.force_overwrite) + +construct_parser = subparsers.add_parser('construct', help='''Construct a valid Tomo + workflow representation on the command line and save it to a file. Optionally use + an existing file as a template and/or preform the reconstruction or transfer to Galaxy.''') +construct_parser.set_defaults(func=construct) +construct_parser.add_argument('-t', '--template_file', + type=pathlib.Path, + required=False, + help='''Full or relative template file path for the constructed workflow.''') +construct_parser.add_argument('-f', '--force_overwrite', + action='store_true', + help='''Use this flag to overwrite the output file if it already exists.''') +construct_parser.add_argument('-o', '--output_file', + type=pathlib.Path, + help='''Full or relative file path to which the constructed workflow will be written.''') + + +# VALIDATE +def validate(args:list) -> bool: + try: + wf = Workflow.construct_from_file(args.input_file) + logger.info(f'Success: {args.input_file} represents a valid Tomo workflow configuration.') + return(True) + except BaseException as e: + logger.error(f'{e.__class__.__name__}: {str(e)}') + logger.info(f'''Failure: {args.input_file} does not represent a valid Tomo workflow + configuration.''') + return(False) + +validate_parser = subparsers.add_parser('validate', + help='''Validate a file as a representation of a Tomo workflow (this is most useful + after a .yaml file has been manually edited).''') +validate_parser.set_defaults(func=validate) +validate_parser.add_argument('input_file', + type=pathlib.Path, + help='''Full or relative file path to validate as a Tomo workflow.''') + + +# CONVERT +def convert(args:list) -> None: + wf = Workflow.construct_from_file(args.input_file) + wf.write_to_file(args.output_file, force_overwrite=args.force_overwrite) + +convert_parser = subparsers.add_parser('convert', help='''Convert one Tomo workflow + representation to another. File format of both input and output files will be + automatically determined from the files' extensions.''') +convert_parser.set_defaults(func=convert) +convert_parser.add_argument('-f', '--force_overwrite', + action='store_true', + help='''Use this flag to overwrite the output file if it already exists.''') +convert_parser.add_argument('-i', '--input_file', + type=pathlib.Path, + required=True, + help='''Full or relative input file path to be converted.''') +convert_parser.add_argument('-o', '--output_file', + type=pathlib.Path, + required=True, + help='''Full or relative file path to which the converted input will be written.''') + + +# DIFF / COMPARE +def diff(args:list) -> bool: + raise ValueError('diff not tested') +# wf1 = Workflow.construct_from_file(args.file1).dict_for_yaml() +# wf2 = Workflow.construct_from_file(args.file2).dict_for_yaml() +# diff = DeepDiff(wf1,wf2, +# ignore_order_func=lambda level:'independent_dimensions' not in level.path(), +# report_repetition=True, +# ignore_string_type_changes=True, +# ignore_numeric_type_changes=True) + diff_report = diff.pretty() + if len(diff_report) > 0: + logger.info(f'The configurations in {args.file1} and {args.file2} are not identical.') + print(diff_report) + return(True) + else: + logger.info(f'The configurations in {args.file1} and {args.file2} are identical.') + return(False) + +diff_parser = subparsers.add_parser('diff', aliases=['compare'], help='''Print a comparison of + two Tomo workflow representations stored in files. The files may have different formats.''') +diff_parser.set_defaults(func=diff) +diff_parser.add_argument('file1', + type=pathlib.Path, + help='''Full or relative path to the first file for comparison.''') +diff_parser.add_argument('file2', + type=pathlib.Path, + help='''Full or relative path to the second file for comparison.''') + + +# LINK TO GALAXY +def link_to_galaxy(args:list) -> None: + from .link_to_galaxy import link_to_galaxy + link_to_galaxy(args.input_file, galaxy=args.galaxy, user=args.user, + password=args.password, api_key=args.api_key) + +link_parser = subparsers.add_parser('link_to_galaxy', help='''Construct a Galaxy history and link + to an existing Tomo workflow representations in a NeXus file.''') +link_parser.set_defaults(func=link_to_galaxy) +link_parser.add_argument('-i', '--input_file', + type=pathlib.Path, + required=True, + help='''Full or relative input file path to the existing Tomo workflow representations as + a NeXus file.''') +link_parser.add_argument('-g', '--galaxy', + required=True, + help='Target Galaxy instance URL/IP address') +link_parser.add_argument('-u', '--user', + default=None, + help='Galaxy user email address') +link_parser.add_argument('-p', '--password', + default=None, + help='Password for the Galaxy user') +link_parser.add_argument('-a', '--api_key', + default=None, + help='Galaxy admin user API key (required if not defined in the tools list file)') + + +# RUN THE RECONSTRUCTION +def run_tomo(args:list) -> None: + from .run_tomo import run_tomo + run_tomo(args.input_file, args.output_file, args.modes, center_file=args.center_file, + num_core=args.num_core, output_folder=args.output_folder, save_figs=args.save_figs) + +tomo_parser = subparsers.add_parser('run_tomo', help='''Construct and add reconstructed tomography + data to an existing Tomo workflow representations in a NeXus file.''') +tomo_parser.set_defaults(func=run_tomo) +tomo_parser.add_argument('-i', '--input_file', + required=True, + type=pathlib.Path, + help='''Full or relative input file path containing raw and/or reduced data.''') +tomo_parser.add_argument('-o', '--output_file', + required=True, + type=pathlib.Path, + help='''Full or relative input file path containing raw and/or reduced data.''') +tomo_parser.add_argument('-c', '--center_file', + type=pathlib.Path, + help='''Full or relative input file path containing the rotation axis centers info.''') +#tomo_parser.add_argument('-f', '--force_overwrite', +# action='store_true', +# help='''Use this flag to overwrite any existing reduced data.''') +tomo_parser.add_argument('-n', '--num_core', + type=int, + default=-1, + help='''Specify the number of processors to use.''') +tomo_parser.add_argument('--output_folder', + type=pathlib.Path, + default='.', + help='Full or relative path to an output folder') +tomo_parser.add_argument('-s', '--save_figs', + choices=['yes', 'no', 'only'], + default='no', + help='''Specify weather to display ('yes' or 'no'), save ('yes'), or only save ('only').''') +tomo_parser.add_argument('--reduce_data', + dest='modes', + const='reduce_data', + action='append_const', + help='''Use this flag to create and add reduced data to the input file.''') +tomo_parser.add_argument('--find_center', + dest='modes', + const='find_center', + action='append_const', + help='''Use this flag to find and add the calibrated center axis info to the input file.''') +tomo_parser.add_argument('--reconstruct_data', + dest='modes', + const='reconstruct_data', + action='append_const', + help='''Use this flag to create and add reconstructed data data to the input file.''') +tomo_parser.add_argument('--combine_data', + dest='modes', + const='combine_data', + action='append_const', + help='''Use this flag to combine reconstructed data data and add to the input file.''') + + +if __name__ == '__main__': + args = parser.parse_args(sys.argv[1:]) + + # Set log configuration + # When logging to file, the stdout log level defaults to WARNING + logging_format = '%(asctime)s : %(levelname)s - %(module)s : %(funcName)s - %(message)s' + level = logging.getLevelName(args.log_level) + if args.log is sys.stdout: + logging.basicConfig(format=logging_format, level=level, force=True, + handlers=[logging.StreamHandler()]) + else: + if isinstance(args.log, str): + logging.basicConfig(filename=f'{args.log}', filemode='w', + format=logging_format, level=level, force=True) + elif isinstance(args.log, io.TextIOWrapper): + logging.basicConfig(filemode='w', format=logging_format, level=level, + stream=args.log, force=True) + else: + raise ValueError(f'Invalid argument --log: {args.log}') + stream_handler = logging.StreamHandler() + logging.getLogger().addHandler(stream_handler) + stream_handler.setLevel(logging.WARNING) + stream_handler.setFormatter(logging.Formatter(logging_format)) + + args.func(args)
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/workflow/__version__.py Mon Mar 20 18:44:23 2023 +0000 @@ -0,0 +1,1 @@ +__version__='2022.3.0'
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/workflow/link_to_galaxy.py Mon Mar 20 18:44:23 2023 +0000 @@ -0,0 +1,120 @@ +#!/usr/bin/env python3 + +import logging +logger = logging.getLogger(__name__) + +from bioblend.galaxy import GalaxyInstance +from nexusformat.nexus import * +from os import path +from yaml import safe_load + +from .models import import_scanparser, TomoWorkflow + +def get_folder_id(gi, path): + library_id = None + folder_id = None + folder_names = path[1:] if len(path) > 1 else [] + new_folders = folder_names + libs = gi.libraries.get_libraries(name=path[0]) + if libs: + for lib in libs: + library_id = lib['id'] + folders = gi.libraries.get_folders(library_id, folder_id=None, name=None) + for i, folder in enumerate(folders): + fid = folder['id'] + details = gi.libraries.show_folder(library_id, fid) + library_path = details['library_path'] + if library_path == folder_names: + return (library_id, fid, []) + elif len(library_path) < len(folder_names): + if library_path == folder_names[:len(library_path)]: + nf = folder_names[len(library_path):] + if len(nf) < len(new_folders): + folder_id = fid + new_folders = nf + return (library_id, folder_id, new_folders) + +def link_to_galaxy(filename:str, galaxy=None, user=None, password=None, api_key=None) -> None: + # Read input file + extension = path.splitext(filename)[1] +# RV yaml input not incorporated yet, since Galaxy can't use pyspec right now +# if extension == '.yml' or extension == '.yaml': +# with open(filename, 'r') as f: +# data = safe_load(f) +# elif extension == '.nxs': + if extension == '.nxs': + with NXFile(filename, mode='r') as nxfile: + data = nxfile.readfile() + else: + raise ValueError(f'Invalid filename extension ({extension})') + if isinstance(data, dict): + # Create Nexus format object from input dictionary + wf = TomoWorkflow(**data) + if len(wf.sample_maps) > 1: + raise ValueError(f'Multiple sample maps not yet implemented') + nxroot = NXroot() + for sample_map in wf.sample_maps: + import_scanparser(sample_map.station) +# RV raw data must be included, since Galaxy can't use pyspec right now +# sample_map.construct_nxentry(nxroot, include_raw_data=False) + sample_map.construct_nxentry(nxroot, include_raw_data=True) + nxentry = nxroot[nxroot.attrs['default']] + elif isinstance(data, NXroot): + nxentry = data[data.attrs['default']] + else: + raise ValueError(f'Invalid input file data ({data})') + + # Get a Galaxy instance + if user is not None and password is not None : + gi = GalaxyInstance(url=galaxy, email=user, password=password) + elif api_key is not None: + gi = GalaxyInstance(url=galaxy, key=api_key) + else: + exit('Please specify either a valid Galaxy username/password or an API key.') + + cycle = nxentry.instrument.source.attrs['cycle'] + btr = nxentry.instrument.source.attrs['btr'] + sample = nxentry.sample.name + + # Create a Galaxy work library/folder + # Combine the cycle, BTR and sample name as the base library name + lib_path = [p.strip() for p in f'{cycle}/{btr}/{sample}'.split('/')] + (library_id, folder_id, folder_names) = get_folder_id(gi, lib_path) + if not library_id: + library = gi.libraries.create_library(lib_path[0], description=None, synopsis=None) + library_id = library['id'] +# if user: +# gi.libraries.set_library_permissions(library_id, access_ids=user, +# manage_ids=user, modify_ids=user) + logger.info(f'Created Library:\n{library}') + if len(folder_names): + folder = gi.libraries.create_folder(library_id, folder_names[0], description=None, + base_folder_id=folder_id)[0] + folder_id = folder['id'] + logger.info(f'Created Folder:\n{folder}') + folder_names.pop(0) + while len(folder_names): + folder = gi.folders.create_folder(folder['id'], folder_names[0], + description=None) + folder_id = folder['id'] + logger.info(f'Created Folder:\n{folder}') + folder_names.pop(0) + + # Create a sym link for the Nexus file + dataset = gi.libraries.upload_from_galaxy_filesystem(library_id, path.abspath(filename), + folder_id=folder_id, file_type='auto', dbkey='?', link_data_only='link_to_files', + roles='', preserve_dirs=False, tag_using_filenames=False, tags=None)[0] + + # Make a history for the data + history_name = f'tomo {btr} {sample}' + history = gi.histories.create_history(name=history_name) + logger.info(f'Created history:\n{history}') + history_id = history['id'] + gi.histories.copy_dataset(history_id, dataset['id'], source='library') + +# TODO add option to either +# get a URL to share the history +# or to share with specific users +# This might require using: +# https://bioblend.readthedocs.io/en/latest/api_docs/galaxy/docs.html#using-bioblend-for-raw-api-calls +
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/workflow/models.py Mon Mar 20 18:44:23 2023 +0000 @@ -0,0 +1,1096 @@ +#!/usr/bin/env python3 + +import logging +logger = logging.getLogger(__name__) + +import logging + +import numpy as np +import os +import yaml + +from functools import cache +from pathlib import PosixPath +from pydantic import BaseModel as PydanticBaseModel +from pydantic import validator, ValidationError, conint, confloat, constr, conlist, FilePath, \ + PrivateAttr +from nexusformat.nexus import * +from time import time +from typing import Optional, Literal +from typing_extensions import TypedDict +try: + from pyspec.file.spec import FileSpec +except: + pass + +try: + from msnctools.general import is_int, is_num, input_int, input_int_list, input_num, \ + input_yesno, input_menu, index_nearest, string_to_list, file_exists_and_readable +except: + from general import is_int, is_num, input_int, input_int_list, input_num, \ + input_yesno, input_menu, index_nearest, string_to_list, file_exists_and_readable + + +def import_scanparser(station): + if station in ('id1a3', 'id3a'): + try: + from msnctools.scanparsers import SMBRotationScanParser + globals()['ScanParser'] = SMBRotationScanParser + except: + try: + from scanparsers import SMBRotationScanParser + globals()['ScanParser'] = SMBRotationScanParser + except: + pass + elif station in ('id3b'): + try: + from msnctools.scanparsers import FMBRotationScanParser + globals()['ScanParser'] = FMBRotationScanParser + except: + try: + from scanparsers import FMBRotationScanParser + globals()['ScanParser'] = FMBRotationScanParser + except: + pass + else: + raise RuntimeError(f'Invalid station: {station}') + +@cache +def get_available_scan_numbers(spec_file:str): + scans = FileSpec(spec_file).scans + scan_numbers = list(scans.keys()) + for scan_number in scan_numbers.copy(): + try: + parser = ScanParser(spec_file, scan_number) + try: + scan_type = parser.scan_type + except: + scan_type = None + except: + scan_numbers.remove(scan_number) + return(scan_numbers) + +@cache +def get_scanparser(spec_file:str, scan_number:int): + if scan_number not in get_available_scan_numbers(spec_file): + return(None) + else: + return(ScanParser(spec_file, scan_number)) + + +class BaseModel(PydanticBaseModel): + class Config: + validate_assignment = True + arbitrary_types_allowed = True + + @classmethod + def construct_from_cli(cls): + obj = cls.construct() + obj.cli() + return(obj) + + @classmethod + def construct_from_yaml(cls, filename): + try: + with open(filename, 'r') as infile: + indict = yaml.load(infile, Loader=yaml.CLoader) + except: + raise ValueError(f'Could not load a dictionary from {filename}') + else: + obj = cls(**indict) + return(obj) + + @classmethod + def construct_from_file(cls, filename): + file_exists_and_readable(filename) + filename = os.path.abspath(filename) + fileformat = os.path.splitext(filename)[1] + yaml_extensions = ('.yaml','.yml') + nexus_extensions = ('.nxs','.nx5','.h5','.hdf5') + t0 = time() + if fileformat.lower() in yaml_extensions: + obj = cls.construct_from_yaml(filename) + logger.info(f'Constructed a model from {filename} in {time()-t0:.2f} seconds.') + return(obj) + elif fileformat.lower() in nexus_extensions: + obj = cls.construct_from_nexus(filename) + logger.info(f'Constructed a model from {filename} in {time()-t0:.2f} seconds.') + return(obj) + else: + logger.error(f'Unsupported file extension for constructing a model: {fileformat}') + raise TypeError(f'Unrecognized file extension: {fileformat}') + + def dict_for_yaml(self, exclude_fields=[]): + yaml_dict = {} + for field_name in self.__fields__: + if field_name in exclude_fields: + continue + else: + field_value = getattr(self, field_name, None) + if field_value is not None: + if isinstance(field_value, BaseModel): + yaml_dict[field_name] = field_value.dict_for_yaml() + elif isinstance(field_value,list) and all(isinstance(item,BaseModel) + for item in field_value): + yaml_dict[field_name] = [item.dict_for_yaml() for item in field_value] + elif isinstance(field_value, PosixPath): + yaml_dict[field_name] = str(field_value) + else: + yaml_dict[field_name] = field_value + else: + continue + return(yaml_dict) + + def write_to_yaml(self, filename=None): + yaml_dict = self.dict_for_yaml() + if filename is None: + logger.info('Printing yaml representation here:\n'+ + f'{yaml.dump(yaml_dict, sort_keys=False)}') + else: + try: + with open(filename, 'w') as outfile: + yaml.dump(yaml_dict, outfile, sort_keys=False) + logger.info(f'Successfully wrote this model to {filename}') + except: + logger.error(f'Unknown error -- could not write to {filename} in yaml format.') + logger.info('Printing yaml representation here:\n'+ + f'{yaml.dump(yaml_dict, sort_keys=False)}') + + def write_to_file(self, filename, force_overwrite=False): + file_writeable, fileformat = self.output_file_valid(filename, + force_overwrite=force_overwrite) + if fileformat == 'yaml': + if file_writeable: + self.write_to_yaml(filename=filename) + else: + self.write_to_yaml() + elif fileformat == 'nexus': + if file_writeable: + self.write_to_nexus(filename=filename) + + def output_file_valid(self, filename, force_overwrite=False): + filename = os.path.abspath(filename) + fileformat = os.path.splitext(filename)[1] + yaml_extensions = ('.yaml','.yml') + nexus_extensions = ('.nxs','.nx5','.h5','.hdf5') + if fileformat.lower() not in (*yaml_extensions, *nexus_extensions): + return(False, None) # Only yaml and NeXus files allowed for output now. + elif fileformat.lower() in yaml_extensions: + fileformat = 'yaml' + elif fileformat.lower() in nexus_extensions: + fileformat = 'nexus' + if os.path.isfile(filename): + if os.access(filename, os.W_OK): + if not force_overwrite: + logger.error(f'{filename} will not be overwritten.') + return(False, fileformat) + else: + logger.error(f'Cannot access {filename} for writing.') + return(False, fileformat) + if os.path.isdir(os.path.dirname(filename)): + if os.access(os.path.dirname(filename), os.W_OK): + return(True, fileformat) + else: + logger.error(f'Cannot access {os.path.dirname(filename)} for writing.') + return(False, fileformat) + else: + try: + os.makedirs(os.path.dirname(filename)) + return(True, fileformat) + except: + logger.error(f'Cannot create {os.path.dirname(filename)} for output.') + return(False, fileformat) + + def set_single_attr_cli(self, attr_name, attr_desc='unknown attribute', list_flag=False, + **cli_kwargs): + if cli_kwargs.get('chain_attr_desc', False): + cli_kwargs['attr_desc'] = attr_desc + try: + attr = getattr(self, attr_name, None) + if attr is None: + attr = self.__fields__[attr_name].type_.construct() + if cli_kwargs.get('chain_attr_desc', False): + cli_kwargs['attr_desc'] = attr_desc + input_accepted = False + while not input_accepted: + try: + attr.cli(**cli_kwargs) + except ValidationError as e: + print(e) + print(f'Removing {attr_desc} configuration') + attr = self.__fields__[attr_name].type_.construct() + continue + except KeyboardInterrupt as e: + raise e + except BaseException as e: + print(f'{type(e).__name__}: {e}') + print(f'Removing {attr_desc} configuration') + attr = self.__fields__[attr_name].type_.construct() + continue + try: + setattr(self, attr_name, attr) + except ValidationError as e: + print(e) + except KeyboardInterrupt as e: + raise e + except BaseException as e: + print(f'{type(e).__name__}: {e}') + else: + input_accepted = True + except: + input_accepted = False + while not input_accepted: + attr = getattr(self, attr_name, None) + if attr is None: + input_value = input(f'Type and enter a value for {attr_desc}: ') + else: + input_value = input(f'Type and enter a new value for {attr_desc} or press '+ + f'enter to keep the current one ({attr}): ') + if list_flag: + input_value = string_to_list(input_value, remove_duplicates=False, sort=False) + if len(input_value) == 0: + input_value = getattr(self, attr_name, None) + try: + setattr(self, attr_name, input_value) + except ValidationError as e: + print(e) + except KeyboardInterrupt as e: + raise e + except BaseException as e: + print(f'Unexpected {type(e).__name__}: {e}') + else: + input_accepted = True + + def set_list_attr_cli(self, attr_name, attr_desc='unknown attribute', **cli_kwargs): + if cli_kwargs.get('chain_attr_desc', False): + cli_kwargs['attr_desc'] = attr_desc + attr = getattr(self, attr_name, None) + if attr is not None: + # Check existing items + for item in attr: + item_accepted = False + while not item_accepted: + item.cli(**cli_kwargs) + try: + setattr(self, attr_name, attr) + except ValidationError as e: + print(e) + except KeyboardInterrupt as e: + raise e + except BaseException as e: + print(f'{type(e).__name__}: {e}') + else: + item_accepted = True + else: + # Initialize list for new attribute & starting item + attr = [] + item = self.__fields__[attr_name].type_.construct() + # Append (optional) additional items + append = input_yesno(f'Add a {attr_desc} configuration? (y/n)', 'n') + while append: + attr.append(item.__class__.construct_from_cli()) + try: + setattr(self, attr_name, attr) + except ValidationError as e: + print(e) + print(f'Removing last {attr_desc} configuration from the list') + attr.pop() + except KeyboardInterrupt as e: + raise e + except BaseException as e: + print(f'{type(e).__name__}: {e}') + print(f'Removing last {attr_desc} configuration from the list') + attr.pop() + else: + append = input_yesno(f'Add another {attr_desc} configuration? (y/n)', 'n') + + +class Detector(BaseModel): + prefix: constr(strip_whitespace=True, min_length=1) + rows: conint(gt=0) + columns: conint(gt=0) + pixel_size: conlist(item_type=confloat(gt=0), min_items=1, max_items=2) + lens_magnification: confloat(gt=0) = 1.0 + + @property + def get_pixel_size(self): + return(list(np.asarray(self.pixel_size)/self.lens_magnification)) + + def construct_from_yaml(self, filename): + try: + with open(filename, 'r') as infile: + indict = yaml.load(infile, Loader=yaml.CLoader) + detector = indict['detector'] + self.prefix = detector['id'] + pixels = detector['pixels'] + self.rows = pixels['rows'] + self.columns = pixels['columns'] + self.pixel_size = pixels['size'] + self.lens_magnification = indict['lens_magnification'] + except: + logging.warning(f'Could not load a dictionary from {filename}') + return(False) + else: + return(True) + + def cli(self): + print('\n -- Configure the detector -- ') + self.set_single_attr_cli('prefix', 'detector ID') + self.set_single_attr_cli('rows', 'number of pixel rows') + self.set_single_attr_cli('columns', 'number of pixel columns') + self.set_single_attr_cli('pixel_size', 'pixel size in mm (enter either a single value for '+ + 'square pixels or a pair of values for the size in the respective row and column '+ + 'directions)', list_flag=True) + self.set_single_attr_cli('lens_magnification', 'lens magnification') + + def construct_nxdetector(self): + nxdetector = NXdetector() + nxdetector.local_name = self.prefix + pixel_size = self.get_pixel_size + if len(pixel_size) == 1: + nxdetector.x_pixel_size = pixel_size[0] + nxdetector.y_pixel_size = pixel_size[0] + else: + nxdetector.x_pixel_size = pixel_size[0] + nxdetector.y_pixel_size = pixel_size[1] + nxdetector.x_pixel_size.attrs['units'] = 'mm' + nxdetector.y_pixel_size.attrs['units'] = 'mm' + return(nxdetector) + + +class ScanInfo(TypedDict): + scan_number: int + starting_image_offset: conint(ge=0) + num_image: conint(gt=0) + ref_x: float + ref_z: float + +class SpecScans(BaseModel): + spec_file: FilePath + scan_numbers: conlist(item_type=conint(gt=0), min_items=1) + stack_info: conlist(item_type=ScanInfo, min_items=1) = [] + + @validator('spec_file') + def validate_spec_file(cls, spec_file): + try: + spec_file = os.path.abspath(spec_file) + sspec_file = FileSpec(spec_file) + except: + raise ValueError(f'Invalid SPEC file {spec_file}') + else: + return(spec_file) + + @validator('scan_numbers') + def validate_scan_numbers(cls, scan_numbers, values): + spec_file = values.get('spec_file') + if spec_file is not None: + spec_scans = FileSpec(spec_file) + for scan_number in scan_numbers: + scan = spec_scans.get_scan_by_number(scan_number) + if scan is None: + raise ValueError(f'There is no scan number {scan_number} in {spec_file}') + return(scan_numbers) + + @validator('stack_info') + def validate_stack_info(cls, stack_info, values): + scan_numbers = values.get('scan_numbers') + assert(len(scan_numbers) == len(stack_info)) + for scan_info in stack_info: + assert(scan_info['scan_number'] in scan_numbers) + is_int(scan_info['starting_image_offset'], ge=0, lt=scan_info['num_image'], + raise_error=True) + return(stack_info) + + @classmethod + def construct_from_nxcollection(cls, nxcollection:NXcollection): + config = {} + config['spec_file'] = nxcollection.attrs['spec_file'] + scan_numbers = [] + stack_info = [] + for nxsubentry_name, nxsubentry in nxcollection.items(): + scan_number = int(nxsubentry_name.split('_')[-1]) + scan_numbers.append(scan_number) + stack_info.append({'scan_number': scan_number, + 'starting_image_offset': int(nxsubentry.instrument.detector.frame_start_number), + 'num_image': len(nxsubentry.sample.rotation_angle), + 'ref_x': float(nxsubentry.sample.x_translation), + 'ref_z': float(nxsubentry.sample.z_translation)}) + config['scan_numbers'] = sorted(scan_numbers) + config['stack_info'] = stack_info + return(cls(**config)) + + @property + def available_scan_numbers(self): + return(get_available_scan_numbers(self.spec_file)) + + def set_from_nxcollection(self, nxcollection:NXcollection): + self.spec_file = nxcollection.attrs['spec_file'] + scan_numbers = [] + stack_info = [] + for nxsubentry_name, nxsubentry in nxcollection.items(): + scan_number = int(nxsubentry_name.split('_')[-1]) + scan_numbers.append(scan_number) + stack_info.append({'scan_number': scan_number, + 'starting_image_offset': int(nxsubentry.instrument.detector.frame_start_number), + 'num_image': len(nxsubentry.sample.rotation_angle), + 'ref_x': float(nxsubentry.sample.x_translation), + 'ref_z': float(nxsubentry.sample.z_translation)}) + self.scan_numbers = sorted(scan_numbers) + self.stack_info = stack_info + + def get_scan_index(self, scan_number): + scan_index = [scan_index for scan_index, scan_info in enumerate(self.stack_info) + if scan_info['scan_number'] == scan_number] + if len(scan_index) > 1: + raise ValueError('Duplicate scan_numbers in image stack') + elif len(scan_index) == 1: + return(scan_index[0]) + else: + return(None) + + def get_scanparser(self, scan_number): + return(get_scanparser(self.spec_file, scan_number)) + + def get_detector_data(self, detector_prefix, scan_number=None, scan_step_index=None): + image_stacks = [] + if scan_number is None: + scan_numbers = self.scan_numbers + else: + scan_numbers = [scan_number] + for scan_number in scan_numbers: + parser = self.get_scanparser(scan_number) + scan_info = self.stack_info[self.get_scan_index(scan_number)] + image_offset = scan_info['starting_image_offset'] + if scan_step_index is None: + num_image = scan_info['num_image'] + image_stacks.append(parser.get_detector_data(detector_prefix, + (image_offset, image_offset+num_image))) + else: + image_stacks.append(parser.get_detector_data(detector_prefix, + image_offset+scan_step_index)) + if scan_number is not None and scan_step_index is not None: + # Return a single image for a specific scan_number and scan_step_index request + return(image_stacks[0]) + else: + # Return a list otherwise + return(image_stacks) + return(image_stacks) + + def scan_numbers_cli(self, attr_desc, **kwargs): + available_scan_numbers = self.available_scan_numbers + station = kwargs.get('station') + if (station is not None and station in ('id1a3', 'id3a') and + 'scan_type' in kwargs): + scan_type = kwargs['scan_type'] + if scan_type == 'ts1': + available_scan_numbers = [] + for scan_number in self.available_scan_numbers: + parser = self.get_scanparser(scan_number) + try: + if parser.scan_type == scan_type: + available_scan_numbers.append(scan_number) + except: + pass + elif scan_type == 'df1': + tomo_scan_numbers = kwargs['tomo_scan_numbers'] + available_scan_numbers = [] + for scan_number in tomo_scan_numbers: + parser = self.get_scanparser(scan_number-2) + assert(parser.scan_type == scan_type) + available_scan_numbers.append(scan_number-2) + elif scan_type == 'bf1': + tomo_scan_numbers = kwargs['tomo_scan_numbers'] + available_scan_numbers = [] + for scan_number in tomo_scan_numbers: + parser = self.get_scanparser(scan_number-1) + assert(parser.scan_type == scan_type) + available_scan_numbers.append(scan_number-1) + if len(available_scan_numbers) == 1: + input_mode = 1 + else: + if hasattr(self, 'scan_numbers'): + print(f'Currently selected {attr_desc}scan numbers are: {self.scan_numbers}') + menu_options = [f'Select a subset of the available {attr_desc}scan numbers', + f'Use all available {attr_desc}scan numbers in {self.spec_file}', + f'Keep the currently selected {attr_desc}scan numbers'] + else: + menu_options = [f'Select a subset of the available {attr_desc}scan numbers', + f'Use all available {attr_desc}scan numbers in {self.spec_file}'] + print(f'Available scan numbers in {self.spec_file} are: '+ + f'{available_scan_numbers}') + input_mode = input_menu(menu_options, header='Choose one of the following options '+ + 'for selecting scan numbers') + if input_mode == 0: + accept_scan_numbers = False + while not accept_scan_numbers: + try: + self.scan_numbers = \ + input_int_list(f'Enter a series of {attr_desc}scan numbers') + except ValidationError as e: + print(e) + except KeyboardInterrupt as e: + raise e + except BaseException as e: + print(f'Unexpected {type(e).__name__}: {e}') + else: + accept_scan_numbers = True + elif input_mode == 1: + self.scan_numbers = available_scan_numbers + elif input_mode == 2: + pass + + def cli(self, **cli_kwargs): + if cli_kwargs.get('attr_desc') is not None: + attr_desc = f'{cli_kwargs["attr_desc"]} ' + else: + attr_desc = '' + print(f'\n -- Configure which scans to use from a single {attr_desc}SPEC file') + self.set_single_attr_cli('spec_file', attr_desc+'SPEC file path') + self.scan_numbers_cli(attr_desc) + + def construct_nxcollection(self, image_key, thetas, detector): + nxcollection = NXcollection() + nxcollection.attrs['spec_file'] = str(self.spec_file) + parser = self.get_scanparser(self.scan_numbers[0]) + nxcollection.attrs['date'] = parser.spec_scan.file_date + for scan_number in self.scan_numbers: + # Get scan info + scan_info = self.stack_info[self.get_scan_index(scan_number)] + # Add an NXsubentry to the NXcollection for each scan + entry_name = f'scan_{scan_number}' + nxsubentry = NXsubentry() + nxcollection[entry_name] = nxsubentry + parser = self.get_scanparser(scan_number) + nxsubentry.start_time = parser.spec_scan.date + nxsubentry.spec_command = parser.spec_command + # Add an NXdata for independent dimensions to the scan's NXsubentry + num_image = scan_info['num_image'] + if thetas is None: + thetas = num_image*[0.0] + else: + assert(num_image == len(thetas)) +# nxsubentry.independent_dimensions = NXdata() +# nxsubentry.independent_dimensions.rotation_angle = thetas +# nxsubentry.independent_dimensions.rotation_angle.units = 'degrees' + # Add an NXinstrument to the scan's NXsubentry + nxsubentry.instrument = NXinstrument() + # Add an NXdetector to the NXinstrument to the scan's NXsubentry + nxsubentry.instrument.detector = detector.construct_nxdetector() + nxsubentry.instrument.detector.frame_start_number = scan_info['starting_image_offset'] + nxsubentry.instrument.detector.image_key = image_key + # Add an NXsample to the scan's NXsubentry + nxsubentry.sample = NXsample() + nxsubentry.sample.rotation_angle = thetas + nxsubentry.sample.rotation_angle.units = 'degrees' + nxsubentry.sample.x_translation = scan_info['ref_x'] + nxsubentry.sample.x_translation.units = 'mm' + nxsubentry.sample.z_translation = scan_info['ref_z'] + nxsubentry.sample.z_translation.units = 'mm' + return(nxcollection) + + +class FlatField(SpecScans): + + def image_range_cli(self, attr_desc, detector_prefix): + stack_info = self.stack_info + for scan_number in self.scan_numbers: + # Parse the available image range + parser = self.get_scanparser(scan_number) + image_offset = parser.starting_image_offset + num_image = parser.get_num_image(detector_prefix.upper()) + scan_index = self.get_scan_index(scan_number) + + # Select the image set + last_image_index = image_offset+num_image + print(f'Available good image set index range: [{image_offset}, {last_image_index})') + image_set_approved = False + if scan_index is not None: + scan_info = stack_info[scan_index] + print(f'Current starting image offset and number of images: '+ + f'{scan_info["starting_image_offset"]} and {scan_info["num_image"]}') + image_set_approved = input_yesno(f'Accept these values (y/n)?', 'y') + if not image_set_approved: + print(f'Default starting image offset and number of images: '+ + f'{image_offset} and {num_image}') + image_set_approved = input_yesno(f'Accept these values (y/n)?', 'y') + if image_set_approved: + offset = image_offset + num = last_image_index-offset + while not image_set_approved: + offset = input_int(f'Enter the starting image offset', ge=image_offset, + lt=last_image_index)#, default=image_offset) + num = input_int(f'Enter the number of images', ge=1, + le=last_image_index-offset)#, default=last_image_index-offset) + print(f'Current starting image offset and number of images: {offset} and {num}') + image_set_approved = input_yesno(f'Accept these values (y/n)?', 'y') + if scan_index is not None: + scan_info['starting_image_offset'] = offset + scan_info['num_image'] = num + scan_info['ref_x'] = parser.horizontal_shift + scan_info['ref_z'] = parser.vertical_shift + else: + stack_info.append({'scan_number': scan_number, 'starting_image_offset': offset, + 'num_image': num, 'ref_x': parser.horizontal_shift, + 'ref_z': parser.vertical_shift}) + self.stack_info = stack_info + + def cli(self, **cli_kwargs): + if cli_kwargs.get('attr_desc') is not None: + attr_desc = f'{cli_kwargs["attr_desc"]} ' + else: + attr_desc = '' + station = cli_kwargs.get('station') + detector = cli_kwargs.get('detector') + print(f'\n -- Configure the location of the {attr_desc}scan data -- ') + if station in ('id1a3', 'id3a'): + self.spec_file = cli_kwargs['spec_file'] + tomo_scan_numbers = cli_kwargs['tomo_scan_numbers'] + scan_type = cli_kwargs['scan_type'] + self.scan_numbers_cli(attr_desc, station=station, tomo_scan_numbers=tomo_scan_numbers, + scan_type=scan_type) + else: + self.set_single_attr_cli('spec_file', attr_desc+'SPEC file path') + self.scan_numbers_cli(attr_desc) + self.image_range_cli(attr_desc, detector.prefix) + + +class TomoField(SpecScans): + theta_range: dict = {} + + @validator('theta_range') + def validate_theta_range(cls, theta_range): + if len(theta_range) != 3 and len(theta_range) != 4: + raise ValueError(f'Invalid theta range {theta_range}') + is_num(theta_range['start'], raise_error=True) + is_num(theta_range['end'], raise_error=True) + is_int(theta_range['num'], gt=1, raise_error=True) + if theta_range['end'] <= theta_range['start']: + raise ValueError(f'Invalid theta range {theta_range}') + if 'start_index' in theta_range: + is_int(theta_range['start_index'], ge=0, raise_error=True) + return(theta_range) + + @classmethod + def construct_from_nxcollection(cls, nxcollection:NXcollection): + #RV Can I derive this from the same classfunction for SpecScans by adding theta_range + config = {} + config['spec_file'] = nxcollection.attrs['spec_file'] + scan_numbers = [] + stack_info = [] + for nxsubentry_name, nxsubentry in nxcollection.items(): + scan_number = int(nxsubentry_name.split('_')[-1]) + scan_numbers.append(scan_number) + stack_info.append({'scan_number': scan_number, + 'starting_image_offset': int(nxsubentry.instrument.detector.frame_start_number), + 'num_image': len(nxsubentry.sample.rotation_angle), + 'ref_x': float(nxsubentry.sample.x_translation), + 'ref_z': float(nxsubentry.sample.z_translation)}) + config['scan_numbers'] = sorted(scan_numbers) + config['stack_info'] = stack_info + for name in nxcollection.entries: + if 'scan_' in name: + thetas = np.asarray(nxcollection[name].sample.rotation_angle) + config['theta_range'] = {'start': thetas[0], 'end': thetas[-1], 'num': thetas.size} + break + return(cls(**config)) + + def get_horizontal_shifts(self, scan_number=None): + horizontal_shifts = [] + if scan_number is None: + scan_numbers = self.scan_numbers + else: + scan_numbers = [scan_number] + for scan_number in scan_numbers: + parser = self.get_scanparser(scan_number) + horizontal_shifts.append(parser.horizontal_shift) + if len(horizontal_shifts) == 1: + return(horizontal_shifts[0]) + else: + return(horizontal_shifts) + + def get_vertical_shifts(self, scan_number=None): + vertical_shifts = [] + if scan_number is None: + scan_numbers = self.scan_numbers + else: + scan_numbers = [scan_number] + for scan_number in scan_numbers: + parser = self.get_scanparser(scan_number) + vertical_shifts.append(parser.vertical_shift) + if len(vertical_shifts) == 1: + return(vertical_shifts[0]) + else: + return(vertical_shifts) + + def theta_range_cli(self, scan_number, attr_desc, station): + # Parse the available theta range + parser = self.get_scanparser(scan_number) + theta_vals = parser.theta_vals + spec_theta_start = theta_vals.get('start') + spec_theta_end = theta_vals.get('end') + spec_num_theta = theta_vals.get('num') + + # Check for consistency of theta ranges between scans + if scan_number != self.scan_numbers[0]: + parser = self.get_scanparser(self.scan_numbers[0]) + if (parser.theta_vals.get('start') != spec_theta_start or + parser.theta_vals.get('end') != spec_theta_end or + parser.theta_vals.get('num') != spec_num_theta): + raise ValueError(f'Incompatible theta ranges between {attr_desc}scans:'+ + f'\n\tScan {scan_number}: {theta_vals}'+ + f'\n\tScan {self.scan_numbers[0]}: {parser.theta_vals}') + return + + # Select the theta range for the tomo reconstruction from the first scan + theta_range_approved = False + thetas = np.linspace(spec_theta_start, spec_theta_end, spec_num_theta) + delta_theta = thetas[1]-thetas[0] + print(f'Theta range obtained from SPEC data: [{spec_theta_start}, {spec_theta_end}]') + print(f'Theta step size = {delta_theta}') + print(f'Number of theta values: {spec_num_theta}') + default_start = None + default_end = None + if station in ('id1a3', 'id3a'): + theta_range_approved = input_yesno(f'Accept this theta range (y/n)?', 'y') + if theta_range_approved: + self.theta_range = {'start': float(spec_theta_start), 'end': float(spec_theta_end), + 'num': int(spec_num_theta), 'start_index': 0} + return + elif station in ('id3b'): + if spec_theta_start <= 0.0 and spec_theta_end >= 180.0: + default_start = 0 + default_end = 180 + elif spec_theta_end-spec_theta_start == 180: + default_start = spec_theta_start + default_end = spec_theta_end + while not theta_range_approved: + theta_start = input_num(f'Enter the first theta (included)', ge=spec_theta_start, + lt=spec_theta_end, default=default_start) + theta_index_start = index_nearest(thetas, theta_start) + theta_start = thetas[theta_index_start] + theta_end = input_num(f'Enter the last theta (excluded)', + ge=theta_start+delta_theta, le=spec_theta_end, default=default_end) + theta_index_end = index_nearest(thetas, theta_end) + theta_end = thetas[theta_index_end] + num_theta = theta_index_end-theta_index_start + print(f'Selected theta range: [{theta_start}, {theta_start+delta_theta}, ..., '+ + f'{theta_end})') + print(f'Number of theta values: {num_theta}') + theta_range_approved = input_yesno(f'Accept this theta range (y/n)?', 'y') + self.theta_range = {'start': float(theta_start), 'end': float(theta_end), + 'num': int(num_theta), 'start_index': int(theta_index_start)} + + def image_range_cli(self, attr_desc, detector_prefix): + stack_info = self.stack_info + for scan_number in self.scan_numbers: + # Parse the available image range + parser = self.get_scanparser(scan_number) + image_offset = parser.starting_image_offset + num_image = parser.get_num_image(detector_prefix.upper()) + scan_index = self.get_scan_index(scan_number) + + # Select the image set matching the theta range + num_theta = self.theta_range['num'] + theta_index_start = self.theta_range['start_index'] + if num_theta > num_image-theta_index_start: + raise ValueError(f'Available {attr_desc}image indices incompatible with thetas:'+ + f'\n\tNumber of thetas and offset = {num_theta} and {theta_index_start}'+ + f'\n\tNumber of available images {num_image}') + if scan_index is not None: + scan_info = stack_info[scan_index] + scan_info['starting_image_offset'] = image_offset+theta_index_start + scan_info['num_image'] = num_theta + scan_info['ref_x'] = parser.horizontal_shift + scan_info['ref_z'] = parser.vertical_shift + else: + stack_info.append({'scan_number': scan_number, + 'starting_image_offset': image_offset+theta_index_start, + 'num_image': num_theta, 'ref_x': parser.horizontal_shift, + 'ref_z': parser.vertical_shift}) + self.stack_info = stack_info + + def cli(self, **cli_kwargs): + if cli_kwargs.get('attr_desc') is not None: + attr_desc = f'{cli_kwargs["attr_desc"]} ' + else: + attr_desc = '' + cycle = cli_kwargs.get('cycle') + btr = cli_kwargs.get('btr') + station = cli_kwargs.get('station') + detector = cli_kwargs.get('detector') + sample_name = cli_kwargs.get('sample_name') + print(f'\n -- Configure the location of the {attr_desc}scan data -- ') + if station in ('id1a3', 'id3a'): + basedir = f'/nfs/chess/{station}/{cycle}/{btr}' + runs = [d for d in os.listdir(basedir) if os.path.isdir(os.path.join(basedir, d))] +#RV index = 15-1 +#RV index = 7-1 + if sample_name is not None and sample_name in runs: + index = runs.index(sample_name) + else: + index = input_menu(runs, header='Choose a sample directory') + self.spec_file = f'{basedir}/{runs[index]}/spec.log' + self.scan_numbers_cli(attr_desc, station=station, scan_type='ts1') + else: + self.set_single_attr_cli('spec_file', attr_desc+'SPEC file path') + self.scan_numbers_cli(attr_desc) + for scan_number in self.scan_numbers: + self.theta_range_cli(scan_number, attr_desc, station) + self.image_range_cli(attr_desc, detector.prefix) + + +class Sample(BaseModel): + name: constr(min_length=1) + description: Optional[str] + rotation_angles: Optional[list] + x_translations: Optional[list] + z_translations: Optional[list] + + @classmethod + def construct_from_nxsample(cls, nxsample:NXsample): + config = {} + config['name'] = nxsample.name.nxdata + if 'description' in nxsample: + config['description'] = nxsample.description.nxdata + if 'rotation_angle' in nxsample: + config['rotation_angle'] = nxsample.rotation_angle.nxdata + if 'x_translation' in nxsample: + config['x_translation'] = nxsample.x_translation.nxdata + if 'z_translation' in nxsample: + config['z_translation'] = nxsample.z_translation.nxdata + return(cls(**config)) + + def cli(self): + print('\n -- Configure the sample metadata -- ') +#RV self.name = 'sobhani-3249-A' +#RV self.name = 'tenstom_1304r-1' + self.set_single_attr_cli('name', 'the sample name') +#RV self.description = 'test sample' + self.set_single_attr_cli('description', 'a description of the sample (optional)') + + +class MapConfig(BaseModel): + cycle: constr(strip_whitespace=True, min_length=1) + btr: constr(strip_whitespace=True, min_length=1) + title: constr(strip_whitespace=True, min_length=1) + station: Literal['id1a3', 'id3a', 'id3b'] = None + sample: Sample + detector: Detector = Detector.construct() + tomo_fields: TomoField + dark_field: Optional[FlatField] + bright_field: FlatField + _thetas: list[float] = PrivateAttr() + _field_types = ({'name': 'dark_field', 'image_key': 2}, {'name': 'bright_field', + 'image_key': 1}, {'name': 'tomo_fields', 'image_key': 0}) + + @classmethod + def construct_from_nxentry(cls, nxentry:NXentry): + config = {} + config['cycle'] = nxentry.instrument.source.attrs['cycle'] + config['btr'] = nxentry.instrument.source.attrs['btr'] + config['title'] = nxentry.nxname + config['station'] = nxentry.instrument.source.attrs['station'] + config['sample'] = Sample.construct_from_nxsample(nxentry['sample']) + for nxobject_name, nxobject in nxentry.spec_scans.items(): + if isinstance(nxobject, NXcollection): + config[nxobject_name] = SpecScans.construct_from_nxcollection(nxobject) + return(cls(**config)) + +#FIX cache? + @property + def thetas(self): + try: + return(self._thetas) + except: + theta_range = self.tomo_fields.theta_range + self._thetas = list(np.linspace(theta_range['start'], theta_range['end'], + theta_range['num'])) + return(self._thetas) + + def cli(self): + print('\n -- Configure a map from a set of SPEC scans (dark, bright, and tomo), '+ + 'and / or detector data -- ') +#RV self.cycle = '2021-3' +#RV self.cycle = '2022-2' +#RV self.cycle = '2023-1' + self.set_single_attr_cli('cycle', 'beam cycle') +#RV self.btr = 'z-3234-A' +#RV self.btr = 'sobhani-3249-A' +#RV self.btr = 'przybyla-3606-a' + self.set_single_attr_cli('btr', 'BTR') +#RV self.title = 'z-3234-A' +#RV self.title = 'tomo7C' +#RV self.title = 'cmc-test-dwell-1' + self.set_single_attr_cli('title', 'title for the map entry') +#RV self.station = 'id3a' +#RV self.station = 'id3b' +#RV self.station = 'id1a3' + self.set_single_attr_cli('station', 'name of the station at which scans were collected '+ + '(currently choose from: id1a3, id3a, id3b)') + import_scanparser(self.station) + self.set_single_attr_cli('sample') + use_detector_config = False + if hasattr(self.detector, 'prefix') and len(self.detector.prefix): + use_detector_config = input_yesno(f'Current detector settings:\n{self.detector}\n'+ + f'Keep these settings? (y/n)') + if not use_detector_config: + menu_options = ['not listed', 'andor2', 'manta', 'retiga'] + input_mode = input_menu(menu_options, header='Choose one of the following detector '+ + 'configuration options') + if input_mode: + detector_config_file = f'{menu_options[input_mode]}.yaml' + have_detector_config = self.detector.construct_from_yaml(detector_config_file) + else: + have_detector_config = False + if not have_detector_config: + self.set_single_attr_cli('detector', 'detector') + self.set_single_attr_cli('tomo_fields', 'Tomo field', chain_attr_desc=True, + cycle=self.cycle, btr=self.btr, station=self.station, detector=self.detector, + sample_name=self.sample.name) + if self.station in ('id1a3', 'id3a'): + have_dark_field = True + tomo_spec_file = self.tomo_fields.spec_file + else: + have_dark_field = input_yesno(f'Are Dark field images available? (y/n)') + tomo_spec_file = None + if have_dark_field: + self.set_single_attr_cli('dark_field', 'Dark field', chain_attr_desc=True, + station=self.station, detector=self.detector, spec_file=tomo_spec_file, + tomo_scan_numbers=self.tomo_fields.scan_numbers, scan_type='df1') + self.set_single_attr_cli('bright_field', 'Bright field', chain_attr_desc=True, + station=self.station, detector=self.detector, spec_file=tomo_spec_file, + tomo_scan_numbers=self.tomo_fields.scan_numbers, scan_type='bf1') + + def construct_nxentry(self, nxroot, include_raw_data=True): + # Construct base NXentry + nxentry = NXentry() + + # Add an NXentry to the NXroot + nxroot[self.title] = nxentry + nxroot.attrs['default'] = self.title + nxentry.definition = 'NXtomo' +# nxentry.attrs['default'] = 'data' + + # Add an NXinstrument to the NXentry + nxinstrument = NXinstrument() + nxentry.instrument = nxinstrument + + # Add an NXsource to the NXinstrument + nxsource = NXsource() + nxinstrument.source = nxsource + nxsource.type = 'Synchrotron X-ray Source' + nxsource.name = 'CHESS' + nxsource.probe = 'x-ray' + + # Tag the NXsource with the runinfo (as an attribute) + nxsource.attrs['cycle'] = self.cycle + nxsource.attrs['btr'] = self.btr + nxsource.attrs['station'] = self.station + + # Add an NXdetector to the NXinstrument (don't fill in data fields yet) + nxinstrument.detector = self.detector.construct_nxdetector() + + # Add an NXsample to NXentry (don't fill in data fields yet) + nxsample = NXsample() + nxentry.sample = nxsample + nxsample.name = self.sample.name + nxsample.description = self.sample.description + + # Add an NXcollection to the base NXentry to hold metadata about the spec scans in the map + # Also obtain the data fields in NXsample and NXdetector + nxspec_scans = NXcollection() + nxentry.spec_scans = nxspec_scans + image_keys = [] + sequence_numbers = [] + image_stacks = [] + rotation_angles = [] + x_translations = [] + z_translations = [] + for field_type in self._field_types: + field_name = field_type['name'] + field = getattr(self, field_name) + if field is None: + continue + image_key = field_type['image_key'] + if field_type['name'] == 'tomo_fields': + thetas = self.thetas + else: + thetas = None + # Add the scans in a single spec file + nxspec_scans[field_name] = field.construct_nxcollection(image_key, thetas, + self.detector) + if include_raw_data: + image_stacks += field.get_detector_data(self.detector.prefix) + for scan_number in field.scan_numbers: + parser = field.get_scanparser(scan_number) + scan_info = field.stack_info[field.get_scan_index(scan_number)] + num_image = scan_info['num_image'] + image_keys += num_image*[image_key] + sequence_numbers += [i for i in range(num_image)] + if thetas is None: + rotation_angles += scan_info['num_image']*[0.0] + else: + assert(num_image == len(thetas)) + rotation_angles += thetas + x_translations += scan_info['num_image']*[scan_info['ref_x']] + z_translations += scan_info['num_image']*[scan_info['ref_z']] + + if include_raw_data: + # Add image data to NXdetector + nxinstrument.detector.image_key = image_keys + nxinstrument.detector.sequence_number = sequence_numbers + nxinstrument.detector.data = np.concatenate([image for image in image_stacks]) + + # Add image data to NXsample + nxsample.rotation_angle = rotation_angles + nxsample.rotation_angle.attrs['units'] = 'degrees' + nxsample.x_translation = x_translations + nxsample.x_translation.attrs['units'] = 'mm' + nxsample.z_translation = z_translations + nxsample.z_translation.attrs['units'] = 'mm' + + # Add an NXdata to NXentry + nxdata = NXdata() + nxentry.data = nxdata + nxdata.makelink(nxentry.instrument.detector.data, name='data') + nxdata.makelink(nxentry.instrument.detector.image_key) + nxdata.makelink(nxentry.sample.rotation_angle) + nxdata.makelink(nxentry.sample.x_translation) + nxdata.makelink(nxentry.sample.z_translation) +# nxdata.attrs['axes'] = ['field', 'row', 'column'] +# nxdata.attrs['field_indices'] = 0 +# nxdata.attrs['row_indices'] = 1 +# nxdata.attrs['column_indices'] = 2 + + +class TomoWorkflow(BaseModel): + sample_maps: conlist(item_type=MapConfig, min_items=1) = [MapConfig.construct()] + + @classmethod + def construct_from_nexus(cls, filename): + nxroot = nxload(filename) + sample_maps = [] + config = {'sample_maps': sample_maps} + for nxentry_name, nxentry in nxroot.items(): + sample_maps.append(MapConfig.construct_from_nxentry(nxentry)) + return(cls(**config)) + + def cli(self): + print('\n -- Configure a map -- ') + self.set_list_attr_cli('sample_maps', 'sample map') + + def construct_nxfile(self, filename, mode='w-'): + nxroot = NXroot() + t0 = time() + for sample_map in self.sample_maps: + logger.info(f'Start constructing the {sample_map.title} map.') + import_scanparser(sample_map.station) + sample_map.construct_nxentry(nxroot) + logger.info(f'Constructed all sample maps in {time()-t0:.2f} seconds.') + logger.info(f'Start saving all sample maps to {filename}.') + nxroot.save(filename, mode=mode) + + def write_to_nexus(self, filename): + t0 = time() + self.construct_nxfile(filename, mode='w') + logger.info(f'Saved all sample maps to {filename} in {time()-t0:.2f} seconds.')
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/workflow/run_tomo.py Mon Mar 20 18:44:23 2023 +0000 @@ -0,0 +1,1629 @@ +#!/usr/bin/env python3 + +import logging +logger = logging.getLogger(__name__) + +import numpy as np +try: + import numexpr as ne +except: + pass +try: + import scipy.ndimage as spi +except: + pass + +from multiprocessing import cpu_count +from nexusformat.nexus import * +from os import mkdir +from os import path as os_path +try: + from skimage.transform import iradon +except: + pass +try: + from skimage.restoration import denoise_tv_chambolle +except: + pass +from time import time +try: + import tomopy +except: + pass +from yaml import safe_load, safe_dump + +try: + from msnctools.fit import Fit +except: + from fit import Fit +try: + from msnctools.general import illegal_value, is_int, is_int_pair, is_num, is_index_range, \ + input_int, input_num, input_yesno, input_menu, draw_mask_1d, select_image_bounds, \ + select_one_image_bound, clear_imshow, quick_imshow, clear_plot, quick_plot +except: + from general import illegal_value, is_int, is_int_pair, is_num, is_index_range, \ + input_int, input_num, input_yesno, input_menu, draw_mask_1d, select_image_bounds, \ + select_one_image_bound, clear_imshow, quick_imshow, clear_plot, quick_plot + +try: + from workflow.models import import_scanparser, FlatField, TomoField, TomoWorkflow + from workflow.__version__ import __version__ +except: + pass + +num_core_tomopy_limit = 24 + +def nxcopy(nxobject:NXobject, exclude_nxpaths:list[str]=[], nxpath_prefix:str='') -> NXobject: + '''Function that returns a copy of a nexus object, optionally exluding certain child items. + + :param nxobject: the original nexus object to return a "copy" of + :type nxobject: nexusformat.nexus.NXobject + :param exlude_nxpaths: a list of paths to child nexus objects that + should be exluded from the returned "copy", defaults to `[]` + :type exclude_nxpaths: list[str], optional + :param nxpath_prefix: For use in recursive calls from inside this + function only! + :type nxpath_prefix: str + :return: a copy of `nxobject` with some children optionally exluded. + :rtype: NXobject + ''' + + nxobject_copy = nxobject.__class__() + if not len(nxpath_prefix): + if 'default' in nxobject.attrs: + nxobject_copy.attrs['default'] = nxobject.attrs['default'] + else: + for k, v in nxobject.attrs.items(): + nxobject_copy.attrs[k] = v + + for k, v in nxobject.items(): + nxpath = os_path.join(nxpath_prefix, k) + + if nxpath in exclude_nxpaths: + continue + + if isinstance(v, NXgroup): + nxobject_copy[k] = nxcopy(v, exclude_nxpaths=exclude_nxpaths, + nxpath_prefix=os_path.join(nxpath_prefix, k)) + else: + nxobject_copy[k] = v + + return(nxobject_copy) + +class set_numexpr_threads: + + def __init__(self, num_core): + if num_core is None or num_core < 1 or num_core > cpu_count(): + self.num_core = cpu_count() + else: + self.num_core = num_core + + def __enter__(self): + self.num_core_org = ne.set_num_threads(self.num_core) + + def __exit__(self, exc_type, exc_value, traceback): + ne.set_num_threads(self.num_core_org) + +class Tomo: + """Processing tomography data with misalignment. + """ + def __init__(self, galaxy_flag=False, num_core=-1, output_folder='.', save_figs=None, + test_mode=False): + """Initialize with optional config input file or dictionary + """ + if not isinstance(galaxy_flag, bool): + raise ValueError(f'Invalid parameter galaxy_flag ({galaxy_flag})') + self.galaxy_flag = galaxy_flag + self.num_core = num_core + if self.galaxy_flag: + if output_folder != '.': + logger.warning('Ignoring output_folder in galaxy mode') + self.output_folder = '.' + if test_mode != False: + logger.warning('Ignoring test_mode in galaxy mode') + self.test_mode = False + if save_figs is not None: + logger.warning('Ignoring save_figs in galaxy mode') + save_figs = 'only' + else: + self.output_folder = os_path.abspath(output_folder) + if not os_path.isdir(output_folder): + mkdir(os_path.abspath(output_folder)) + if not isinstance(test_mode, bool): + raise ValueError(f'Invalid parameter test_mode ({test_mode})') + self.test_mode = test_mode + if save_figs is None: + save_figs = 'no' + self.test_config = {} + if self.test_mode: + if save_figs != 'only': + logger.warning('Ignoring save_figs in test mode') + save_figs = 'only' + if save_figs == 'only': + self.save_only = True + self.save_figs = True + elif save_figs == 'yes': + self.save_only = False + self.save_figs = True + elif save_figs == 'no': + self.save_only = False + self.save_figs = False + else: + raise ValueError(f'Invalid parameter save_figs ({save_figs})') + if self.save_only: + self.block = False + else: + self.block = True + if self.num_core == -1: + self.num_core = cpu_count() + if not is_int(self.num_core, gt=0, log=False): + raise ValueError(f'Invalid parameter num_core ({num_core})') + if self.num_core > cpu_count(): + logger.warning(f'num_core = {self.num_core} is larger than the number of available ' + f'processors and reduced to {cpu_count()}') + self.num_core= cpu_count() + + def read(self, filename): + extension = os_path.splitext(filename)[1] + if extension == '.yml' or extension == '.yaml': + with open(filename, 'r') as f: + config = safe_load(f) +# if len(config) > 1: +# raise ValueError(f'Multiple root entries in {filename} not yet implemented') +# if len(list(config.values())[0]) > 1: +# raise ValueError(f'Multiple sample maps in {filename} not yet implemented') + return(config) + elif extension == '.nxs': + with NXFile(filename, mode='r') as nxfile: + nxroot = nxfile.readfile() + return(nxroot) + else: + raise ValueError(f'Invalid filename extension ({extension})') + + def write(self, data, filename): + extension = os_path.splitext(filename)[1] + if extension == '.yml' or extension == '.yaml': + with open(filename, 'w') as f: + safe_dump(data, f) + elif extension == '.nxs': + data.save(filename, mode='w') + elif extension == '.nc': + data.to_netcdf(os_path=filename) + else: + raise ValueError(f'Invalid filename extension ({extension})') + + def gen_reduced_data(self, data, img_x_bounds=None): + """Generate the reduced tomography images. + """ + logger.info('Generate the reduced tomography images') + + # Create plot galaxy path directory if needed + if self.galaxy_flag and not os_path.exists('tomo_reduce_plots'): + mkdir('tomo_reduce_plots') + + if isinstance(data, dict): + # Create Nexus format object from input dictionary + wf = TomoWorkflow(**data) + if len(wf.sample_maps) > 1: + raise ValueError(f'Multiple sample maps not yet implemented') +# print(f'\nwf:\n{wf}\n') + nxroot = NXroot() + t0 = time() + for sample_map in wf.sample_maps: + logger.info(f'Start constructing the {sample_map.title} map.') + import_scanparser(sample_map.station) + sample_map.construct_nxentry(nxroot, include_raw_data=False) + logger.info(f'Constructed all sample maps in {time()-t0:.2f} seconds.') + nxentry = nxroot[nxroot.attrs['default']] + # Get test mode configuration info + if self.test_mode: + self.test_config = data['sample_maps'][0]['test_mode'] + elif isinstance(data, NXroot): + nxentry = data[data.attrs['default']] + else: + raise ValueError(f'Invalid parameter data ({data})') + + # Create an NXprocess to store data reduction (meta)data + reduced_data = NXprocess() + + # Generate dark field + if 'dark_field' in nxentry['spec_scans']: + reduced_data = self._gen_dark(nxentry, reduced_data) + + # Generate bright field + reduced_data = self._gen_bright(nxentry, reduced_data) + + # Set vertical detector bounds for image stack + img_x_bounds = self._set_detector_bounds(nxentry, reduced_data, img_x_bounds=img_x_bounds) + logger.info(f'img_x_bounds = {img_x_bounds}') + reduced_data['img_x_bounds'] = img_x_bounds + + # Set zoom and/or theta skip to reduce memory the requirement + zoom_perc, num_theta_skip = self._set_zoom_or_skip() + if zoom_perc is not None: + reduced_data.attrs['zoom_perc'] = zoom_perc + if num_theta_skip is not None: + reduced_data.attrs['num_theta_skip'] = num_theta_skip + + # Generate reduced tomography fields + reduced_data = self._gen_tomo(nxentry, reduced_data) + + # Create a copy of the input Nexus object and remove raw and any existing reduced data + if isinstance(data, NXroot): + exclude_items = [f'{nxentry._name}/reduced_data/data', + f'{nxentry._name}/instrument/detector/data', + f'{nxentry._name}/instrument/detector/image_key', + f'{nxentry._name}/instrument/detector/sequence_number', + f'{nxentry._name}/sample/rotation_angle', + f'{nxentry._name}/sample/x_translation', + f'{nxentry._name}/sample/z_translation', + f'{nxentry._name}/data/data', + f'{nxentry._name}/data/image_key', + f'{nxentry._name}/data/rotation_angle', + f'{nxentry._name}/data/x_translation', + f'{nxentry._name}/data/z_translation'] + nxroot = nxcopy(data, exclude_nxpaths=exclude_items) + nxentry = nxroot[nxroot.attrs['default']] + + # Add the reduced data NXprocess + nxentry.reduced_data = reduced_data + + if 'data' not in nxentry: + nxentry.data = NXdata() + nxentry.attrs['default'] = 'data' + nxentry.data.makelink(nxentry.reduced_data.data.tomo_fields, name='reduced_data') + nxentry.data.makelink(nxentry.reduced_data.rotation_angle, name='rotation_angle') + nxentry.data.attrs['signal'] = 'reduced_data' + + return(nxroot) + + def find_centers(self, nxroot, center_rows=None, center_stack_index=None): + """Find the calibrated center axis info + """ + logger.info('Find the calibrated center axis info') + + if not isinstance(nxroot, NXroot): + raise ValueError(f'Invalid parameter nxroot ({nxroot})') + nxentry = nxroot[nxroot.attrs['default']] + if not isinstance(nxentry, NXentry): + raise ValueError(f'Invalid nxentry ({nxentry})') + if self.galaxy_flag: + if center_rows is not None: + center_rows = tuple(center_rows) + if not is_int_pair(center_rows): + raise ValueError(f'Invalid parameter center_rows ({center_rows})') + elif center_rows is not None: + logger.warning(f'Ignoring parameter center_rows ({center_rows})') + center_rows = None + if self.galaxy_flag: + if center_stack_index is not None and not is_int(center_stack_index, ge=0): + raise ValueError(f'Invalid parameter center_stack_index ({center_stack_index})') + elif center_stack_index is not None: + logger.warning(f'Ignoring parameter center_stack_index ({center_stack_index})') + center_stack_index = None + + # Create plot galaxy path directory and path if needed + if self.galaxy_flag: + if not os_path.exists('tomo_find_centers_plots'): + mkdir('tomo_find_centers_plots') + path = 'tomo_find_centers_plots' + else: + path = self.output_folder + + # Check if reduced data is available + if ('reduced_data' not in nxentry or 'reduced_data' not in nxentry.data): + raise KeyError(f'Unable to find valid reduced data in {nxentry}.') + + # Select the image stack to calibrate the center axis + # reduced data axes order: stack,theta,row,column + # Note: Nexus cannot follow a link if the data it points to is too big, + # so get the data from the actual place, not from nxentry.data + tomo_fields_shape = nxentry.reduced_data.data.tomo_fields.shape + if len(tomo_fields_shape) != 4 or any(True for dim in tomo_fields_shape if not dim): + raise KeyError('Unable to load the required reduced tomography stack') + num_tomo_stacks = tomo_fields_shape[0] + if num_tomo_stacks == 1: + center_stack_index = 0 + default = 'n' + else: + if self.test_mode: + center_stack_index = self.test_config['center_stack_index']-1 # make offset 0 + elif self.galaxy_flag: + if center_stack_index is None: + center_stack_index = int(num_tomo_stacks/2) + if center_stack_index >= num_tomo_stacks: + raise ValueError(f'Invalid parameter center_stack_index ({center_stack_index})') + else: + center_stack_index = input_int('\nEnter tomography stack index to calibrate the ' + 'center axis', ge=1, le=num_tomo_stacks, default=int(1+num_tomo_stacks/2)) + center_stack_index -= 1 + default = 'y' + + # Get thetas (in degrees) + thetas = np.asarray(nxentry.reduced_data.rotation_angle) + + # Get effective pixel_size + if 'zoom_perc' in nxentry.reduced_data: + eff_pixel_size = 100.*(nxentry.instrument.detector.x_pixel_size/ + nxentry.reduced_data.attrs['zoom_perc']) + else: + eff_pixel_size = nxentry.instrument.detector.x_pixel_size + + # Get cross sectional diameter + cross_sectional_dim = tomo_fields_shape[3]*eff_pixel_size + logger.debug(f'cross_sectional_dim = {cross_sectional_dim}') + + # Determine center offset at sample row boundaries + logger.info('Determine center offset at sample row boundaries') + + # Lower row center + if self.test_mode: + lower_row = self.test_config['lower_row'] + elif self.galaxy_flag: + if center_rows is None: + lower_row = 0 + else: + lower_row = min(center_rows) + if not 0 <= lower_row < tomo_fields_shape[2]-1: + raise ValueError(f'Invalid parameter center_rows ({center_rows})') + else: + lower_row = select_one_image_bound( + nxentry.reduced_data.data.tomo_fields[center_stack_index,0,:,:], 0, bound=0, + title=f'theta={round(thetas[0], 2)+0}', + bound_name='row index to find lower center', default=default, raise_error=True) + logger.debug('Finding center...') + t0 = time() + lower_center_offset = self._find_center_one_plane( + #np.asarray(nxentry.reduced_data.data.tomo_fields[center_stack_index,:,lower_row,:]), + nxentry.reduced_data.data.tomo_fields[center_stack_index,:,lower_row,:], + lower_row, thetas, eff_pixel_size, cross_sectional_dim, path=path, + num_core=self.num_core) + logger.debug(f'... done in {time()-t0:.2f} seconds') + logger.debug(f'lower_row = {lower_row:.2f}') + logger.debug(f'lower_center_offset = {lower_center_offset:.2f}') + + # Upper row center + if self.test_mode: + upper_row = self.test_config['upper_row'] + elif self.galaxy_flag: + if center_rows is None: + upper_row = tomo_fields_shape[2]-1 + else: + upper_row = max(center_rows) + if not lower_row < upper_row < tomo_fields_shape[2]: + raise ValueError(f'Invalid parameter center_rows ({center_rows})') + else: + upper_row = select_one_image_bound( + nxentry.reduced_data.data.tomo_fields[center_stack_index,0,:,:], 0, + bound=tomo_fields_shape[2]-1, title=f'theta={round(thetas[0], 2)+0}', + bound_name='row index to find upper center', default=default, raise_error=True) + logger.debug('Finding center...') + t0 = time() + upper_center_offset = self._find_center_one_plane( + #np.asarray(nxentry.reduced_data.data.tomo_fields[center_stack_index,:,upper_row,:]), + nxentry.reduced_data.data.tomo_fields[center_stack_index,:,upper_row,:], + upper_row, thetas, eff_pixel_size, cross_sectional_dim, path=path, + num_core=self.num_core) + logger.debug(f'... done in {time()-t0:.2f} seconds') + logger.debug(f'upper_row = {upper_row:.2f}') + logger.debug(f'upper_center_offset = {upper_center_offset:.2f}') + + center_config = {'lower_row': lower_row, 'lower_center_offset': lower_center_offset, + 'upper_row': upper_row, 'upper_center_offset': upper_center_offset} + if num_tomo_stacks > 1: + center_config['center_stack_index'] = center_stack_index+1 # save as offset 1 + + # Save test data to file + if self.test_mode: + with open(f'{self.output_folder}/center_config.yaml', 'w') as f: + safe_dump(center_config, f) + + return(center_config) + + def reconstruct_data(self, nxroot, center_info, x_bounds=None, y_bounds=None): + """Reconstruct the tomography data. + """ + logger.info('Reconstruct the tomography data') + + if not isinstance(nxroot, NXroot): + raise ValueError(f'Invalid parameter nxroot ({nxroot})') + nxentry = nxroot[nxroot.attrs['default']] + if not isinstance(nxentry, NXentry): + raise ValueError(f'Invalid nxentry ({nxentry})') + if not isinstance(center_info, dict): + raise ValueError(f'Invalid parameter center_info ({center_info})') + + # Create plot galaxy path directory and path if needed + if self.galaxy_flag: + if not os_path.exists('tomo_reconstruct_plots'): + mkdir('tomo_reconstruct_plots') + path = 'tomo_reconstruct_plots' + else: + path = self.output_folder + + # Check if reduced data is available + if ('reduced_data' not in nxentry or 'reduced_data' not in nxentry.data): + raise KeyError(f'Unable to find valid reduced data in {nxentry}.') + + # Create an NXprocess to store image reconstruction (meta)data + nxprocess = NXprocess() + + # Get rotation axis rows and centers + lower_row = center_info.get('lower_row') + lower_center_offset = center_info.get('lower_center_offset') + upper_row = center_info.get('upper_row') + upper_center_offset = center_info.get('upper_center_offset') + if (lower_row is None or lower_center_offset is None or upper_row is None or + upper_center_offset is None): + raise KeyError(f'Unable to find valid calibrated center axis info in {center_info}.') + center_slope = (upper_center_offset-lower_center_offset)/(upper_row-lower_row) + + # Get thetas (in degrees) + thetas = np.asarray(nxentry.reduced_data.rotation_angle) + + # Reconstruct tomography data + # reduced data axes order: stack,theta,row,column + # reconstructed data order in each stack: row/z,x,y + # Note: Nexus cannot follow a link if the data it points to is too big, + # so get the data from the actual place, not from nxentry.data + if 'zoom_perc' in nxentry.reduced_data: + res_title = f'{nxentry.reduced_data.attrs["zoom_perc"]}p' + else: + res_title = 'fullres' + load_error = False + num_tomo_stacks = nxentry.reduced_data.data.tomo_fields.shape[0] + tomo_recon_stacks = num_tomo_stacks*[np.array([])] + for i in range(num_tomo_stacks): + # Convert reduced data stack from theta,row,column to row,theta,column + logger.debug(f'Reading reduced data stack {i+1}...') + t0 = time() + tomo_stack = np.asarray(nxentry.reduced_data.data.tomo_fields[i]) + logger.debug(f'... done in {time()-t0:.2f} seconds') + if len(tomo_stack.shape) != 3 or any(True for dim in tomo_stack.shape if not dim): + raise ValueError(f'Unable to load tomography stack {i+1} for reconstruction') + tomo_stack = np.swapaxes(tomo_stack, 0, 1) + assert(len(thetas) == tomo_stack.shape[1]) + assert(0 <= lower_row < upper_row < tomo_stack.shape[0]) + center_offsets = [lower_center_offset-lower_row*center_slope, + upper_center_offset+(tomo_stack.shape[0]-1-upper_row)*center_slope] + t0 = time() + logger.debug(f'Running _reconstruct_one_tomo_stack on {self.num_core} cores ...') + tomo_recon_stack = self._reconstruct_one_tomo_stack(tomo_stack, thetas, + center_offsets=center_offsets, num_core=self.num_core, algorithm='gridrec') + logger.debug(f'... done in {time()-t0:.2f} seconds') + logger.info(f'Reconstruction of stack {i+1} took {time()-t0:.2f} seconds') + + # Combine stacks + tomo_recon_stacks[i] = tomo_recon_stack + + # Resize the reconstructed tomography data + # reconstructed data order in each stack: row/z,x,y + if self.test_mode: + x_bounds = self.test_config.get('x_bounds') + y_bounds = self.test_config.get('y_bounds') + z_bounds = None + elif self.galaxy_flag: + if x_bounds is not None and not is_int_pair(x_bounds, ge=0, + lt=tomo_recon_stacks[0].shape[1]): + raise ValueError(f'Invalid parameter x_bounds ({x_bounds})') + if y_bounds is not None and not is_int_pair(y_bounds, ge=0, + lt=tomo_recon_stacks[0].shape[1]): + raise ValueError(f'Invalid parameter y_bounds ({y_bounds})') + z_bounds = None + else: + x_bounds, y_bounds, z_bounds = self._resize_reconstructed_data(tomo_recon_stacks) + if x_bounds is None: + x_range = (0, tomo_recon_stacks[0].shape[1]) + x_slice = int(x_range[1]/2) + else: + x_range = (min(x_bounds), max(x_bounds)) + x_slice = int((x_bounds[0]+x_bounds[1])/2) + if y_bounds is None: + y_range = (0, tomo_recon_stacks[0].shape[2]) + y_slice = int(y_range[1]/2) + else: + y_range = (min(y_bounds), max(y_bounds)) + y_slice = int((y_bounds[0]+y_bounds[1])/2) + if z_bounds is None: + z_range = (0, tomo_recon_stacks[0].shape[0]) + z_slice = int(z_range[1]/2) + else: + z_range = (min(z_bounds), max(z_bounds)) + z_slice = int((z_bounds[0]+z_bounds[1])/2) + + # Plot a few reconstructed image slices + if num_tomo_stacks == 1: + basetitle = 'recon' + else: + basetitle = f'recon stack {i+1}' + for i, stack in enumerate(tomo_recon_stacks): + title = f'{basetitle} {res_title} xslice{x_slice}' + quick_imshow(stack[z_range[0]:z_range[1],x_slice,y_range[0]:y_range[1]], + title=title, path=path, save_fig=self.save_figs, save_only=self.save_only, + block=self.block) + title = f'{basetitle} {res_title} yslice{y_slice}' + quick_imshow(stack[z_range[0]:z_range[1],x_range[0]:x_range[1],y_slice], + title=title, path=path, save_fig=self.save_figs, save_only=self.save_only, + block=self.block) + title = f'{basetitle} {res_title} zslice{z_slice}' + quick_imshow(stack[z_slice,x_range[0]:x_range[1],y_range[0]:y_range[1]], + title=title, path=path, save_fig=self.save_figs, save_only=self.save_only, + block=self.block) + + # Save test data to file + # reconstructed data order in each stack: row/z,x,y + if self.test_mode: + for i, stack in enumerate(tomo_recon_stacks): + np.savetxt(f'{self.output_folder}/recon_stack_{i+1}.txt', + stack[z_slice,x_range[0]:x_range[1],y_range[0]:y_range[1]], fmt='%.6e') + + # Add image reconstruction to reconstructed data NXprocess + # reconstructed data order in each stack: row/z,x,y + nxprocess.data = NXdata() + nxprocess.attrs['default'] = 'data' + for k, v in center_info.items(): + nxprocess[k] = v + if x_bounds is not None: + nxprocess.x_bounds = x_bounds + if y_bounds is not None: + nxprocess.y_bounds = y_bounds + if z_bounds is not None: + nxprocess.z_bounds = z_bounds + nxprocess.data['reconstructed_data'] = np.asarray([stack[z_range[0]:z_range[1], + x_range[0]:x_range[1],y_range[0]:y_range[1]] for stack in tomo_recon_stacks]) + nxprocess.data.attrs['signal'] = 'reconstructed_data' + + # Create a copy of the input Nexus object and remove reduced data + exclude_items = [f'{nxentry._name}/reduced_data/data', f'{nxentry._name}/data/reduced_data'] + nxroot_copy = nxcopy(nxroot, exclude_nxpaths=exclude_items) + + # Add the reconstructed data NXprocess to the new Nexus object + nxentry_copy = nxroot_copy[nxroot_copy.attrs['default']] + nxentry_copy.reconstructed_data = nxprocess + if 'data' not in nxentry_copy: + nxentry_copy.data = NXdata() + nxentry_copy.attrs['default'] = 'data' + nxentry_copy.data.makelink(nxprocess.data.reconstructed_data, name='reconstructed_data') + nxentry_copy.data.attrs['signal'] = 'reconstructed_data' + + return(nxroot_copy) + + def combine_data(self, nxroot, x_bounds=None, y_bounds=None): + """Combine the reconstructed tomography stacks. + """ + logger.info('Combine the reconstructed tomography stacks') + + if not isinstance(nxroot, NXroot): + raise ValueError(f'Invalid parameter nxroot ({nxroot})') + nxentry = nxroot[nxroot.attrs['default']] + if not isinstance(nxentry, NXentry): + raise ValueError(f'Invalid nxentry ({nxentry})') + + # Create plot galaxy path directory and path if needed + if self.galaxy_flag: + if not os_path.exists('tomo_combine_plots'): + mkdir('tomo_combine_plots') + path = 'tomo_combine_plots' + else: + path = self.output_folder + + # Check if reconstructed image data is available + if ('reconstructed_data' not in nxentry or 'reconstructed_data' not in nxentry.data): + raise KeyError(f'Unable to find valid reconstructed image data in {nxentry}.') + + # Create an NXprocess to store combined image reconstruction (meta)data + nxprocess = NXprocess() + + # Get the reconstructed data + # reconstructed data order: stack,row(z),x,y + # Note: Nexus cannot follow a link if the data it points to is too big, + # so get the data from the actual place, not from nxentry.data + num_tomo_stacks = nxentry.reconstructed_data.data.reconstructed_data.shape[0] + if num_tomo_stacks == 1: + logger.info('Only one stack available: leaving combine_data') + return(None) + + # Combine the reconstructed stacks + # (load one stack at a time to reduce risk of hitting Nexus data access limit) + t0 = time() + logger.debug(f'Combining the reconstructed stacks ...') + tomo_recon_combined = np.asarray(nxentry.reconstructed_data.data.reconstructed_data[0]) + if num_tomo_stacks > 2: + tomo_recon_combined = np.concatenate([tomo_recon_combined]+ + [nxentry.reconstructed_data.data.reconstructed_data[i] + for i in range(1, num_tomo_stacks-1)]) + if num_tomo_stacks > 1: + tomo_recon_combined = np.concatenate([tomo_recon_combined]+ + [nxentry.reconstructed_data.data.reconstructed_data[num_tomo_stacks-1]]) + logger.debug(f'... done in {time()-t0:.2f} seconds') + logger.info(f'Combining the reconstructed stacks took {time()-t0:.2f} seconds') + + # Resize the combined tomography data stacks + # combined data order: row/z,x,y + if self.test_mode: + x_bounds = None + y_bounds = None + z_bounds = self.test_config.get('z_bounds') + elif self.galaxy_flag: + if x_bounds is not None and not is_int_pair(x_bounds, ge=0, + lt=tomo_recon_stacks[0].shape[1]): + raise ValueError(f'Invalid parameter x_bounds ({x_bounds})') + if y_bounds is not None and not is_int_pair(y_bounds, ge=0, + lt=tomo_recon_stacks[0].shape[1]): + raise ValueError(f'Invalid parameter y_bounds ({y_bounds})') + z_bounds = None + else: + x_bounds, y_bounds, z_bounds = self._resize_reconstructed_data(tomo_recon_combined, + z_only=True) + if x_bounds is None: + x_range = (0, tomo_recon_combined.shape[1]) + x_slice = int(x_range[1]/2) + else: + x_range = x_bounds + x_slice = int((x_bounds[0]+x_bounds[1])/2) + if y_bounds is None: + y_range = (0, tomo_recon_combined.shape[2]) + y_slice = int(y_range[1]/2) + else: + y_range = y_bounds + y_slice = int((y_bounds[0]+y_bounds[1])/2) + if z_bounds is None: + z_range = (0, tomo_recon_combined.shape[0]) + z_slice = int(z_range[1]/2) + else: + z_range = z_bounds + z_slice = int((z_bounds[0]+z_bounds[1])/2) + + # Plot a few combined image slices + quick_imshow(tomo_recon_combined[z_range[0]:z_range[1],x_slice,y_range[0]:y_range[1]], + title=f'recon combined xslice{x_slice}', path=path, + save_fig=self.save_figs, save_only=self.save_only, block=self.block) + quick_imshow(tomo_recon_combined[z_range[0]:z_range[1],x_range[0]:x_range[1],y_slice], + title=f'recon combined yslice{y_slice}', path=path, + save_fig=self.save_figs, save_only=self.save_only, block=self.block) + quick_imshow(tomo_recon_combined[z_slice,x_range[0]:x_range[1],y_range[0]:y_range[1]], + title=f'recon combined zslice{z_slice}', path=path, + save_fig=self.save_figs, save_only=self.save_only, block=self.block) + + # Save test data to file + # combined data order: row/z,x,y + if self.test_mode: + np.savetxt(f'{self.output_folder}/recon_combined.txt', tomo_recon_combined[ + z_slice,x_range[0]:x_range[1],y_range[0]:y_range[1]], fmt='%.6e') + + # Add image reconstruction to reconstructed data NXprocess + # combined data order: row/z,x,y + nxprocess.data = NXdata() + nxprocess.attrs['default'] = 'data' + if x_bounds is not None: + nxprocess.x_bounds = x_bounds + if y_bounds is not None: + nxprocess.y_bounds = y_bounds + if z_bounds is not None: + nxprocess.z_bounds = z_bounds + nxprocess.data['combined_data'] = tomo_recon_combined + nxprocess.data.attrs['signal'] = 'combined_data' + + # Create a copy of the input Nexus object and remove reconstructed data + exclude_items = [f'{nxentry._name}/reconstructed_data/data', + f'{nxentry._name}/data/reconstructed_data'] + nxroot_copy = nxcopy(nxroot, exclude_nxpaths=exclude_items) + + # Add the combined data NXprocess to the new Nexus object + nxentry_copy = nxroot_copy[nxroot_copy.attrs['default']] + nxentry_copy.combined_data = nxprocess + if 'data' not in nxentry_copy: + nxentry_copy.data = NXdata() + nxentry_copy.attrs['default'] = 'data' + nxentry_copy.data.makelink(nxprocess.data.combined_data, name='combined_data') + nxentry_copy.data.attrs['signal'] = 'combined_data' + + return(nxroot_copy) + + def _gen_dark(self, nxentry, reduced_data): + """Generate dark field. + """ + # Get the dark field images + image_key = nxentry.instrument.detector.get('image_key', None) + if image_key and 'data' in nxentry.instrument.detector: + field_indices = [index for index, key in enumerate(image_key) if key == 2] + tdf_stack = nxentry.instrument.detector.data[field_indices,:,:] + # RV the default NXtomo form does not accomodate bright or dark field stacks + else: + dark_field_scans = nxentry.spec_scans.dark_field + dark_field = FlatField.construct_from_nxcollection(dark_field_scans) + prefix = str(nxentry.instrument.detector.local_name) + tdf_stack = dark_field.get_detector_data(prefix) + if isinstance(tdf_stack, list): + assert(len(tdf_stack) == 1) # TODO + tdf_stack = tdf_stack[0] + + # Take median + if tdf_stack.ndim == 2: + tdf = tdf_stack + elif tdf_stack.ndim == 3: + tdf = np.median(tdf_stack, axis=0) + del tdf_stack + else: + raise ValueError(f'Invalid tdf_stack shape ({tdf_stack.shape})') + + # Remove dark field intensities above the cutoff +#RV tdf_cutoff = None + tdf_cutoff = tdf.min()+2*(np.median(tdf)-tdf.min()) + logger.debug(f'tdf_cutoff = {tdf_cutoff}') + if tdf_cutoff is not None: + if not is_num(tdf_cutoff, ge=0): + logger.warning(f'Ignoring illegal value of tdf_cutoff {tdf_cutoff}') + else: + tdf[tdf > tdf_cutoff] = np.nan + logger.debug(f'tdf_cutoff = {tdf_cutoff}') + + # Remove nans + tdf_mean = np.nanmean(tdf) + logger.debug(f'tdf_mean = {tdf_mean}') + np.nan_to_num(tdf, copy=False, nan=tdf_mean, posinf=tdf_mean, neginf=0.) + + # Plot dark field + if self.galaxy_flag: + quick_imshow(tdf, title='dark field', path='tomo_reduce_plots', save_fig=self.save_figs, + save_only=self.save_only) + elif not self.test_mode: + quick_imshow(tdf, title='dark field', path=self.output_folder, save_fig=self.save_figs, + save_only=self.save_only) + clear_imshow('dark field') +# quick_imshow(tdf, title='dark field', block=True) + + # Add dark field to reduced data NXprocess + reduced_data.data = NXdata() + reduced_data.data['dark_field'] = tdf + + return(reduced_data) + + def _gen_bright(self, nxentry, reduced_data): + """Generate bright field. + """ + # Get the bright field images + image_key = nxentry.instrument.detector.get('image_key', None) + if image_key and 'data' in nxentry.instrument.detector: + field_indices = [index for index, key in enumerate(image_key) if key == 1] + tbf_stack = nxentry.instrument.detector.data[field_indices,:,:] + # RV the default NXtomo form does not accomodate bright or dark field stacks + else: + bright_field_scans = nxentry.spec_scans.bright_field + bright_field = FlatField.construct_from_nxcollection(bright_field_scans) + prefix = str(nxentry.instrument.detector.local_name) + tbf_stack = bright_field.get_detector_data(prefix) + if isinstance(tbf_stack, list): + assert(len(tbf_stack) == 1) # TODO + tbf_stack = tbf_stack[0] + + # Take median if more than one image + """Median or mean: It may be best to try the median because of some image + artifacts that arise due to crinkles in the upstream kapton tape windows + causing some phase contrast images to appear on the detector. + One thing that also may be useful in a future implementation is to do a + brightfield adjustment on EACH frame of the tomo based on a ROI in the + corner of the frame where there is no sample but there is the direct X-ray + beam because there is frame to frame fluctuations from the incoming beam. + We don’t typically account for them but potentially could. + """ + if tbf_stack.ndim == 2: + tbf = tbf_stack + elif tbf_stack.ndim == 3: + tbf = np.median(tbf_stack, axis=0) + del tbf_stack + else: + raise ValueError(f'Invalid tbf_stack shape ({tbf_stacks.shape})') + + # Subtract dark field + if 'data' in reduced_data and 'dark_field' in reduced_data.data: + tbf -= reduced_data.data.dark_field + else: + logger.warning('Dark field unavailable') + + # Set any non-positive values to one + # (avoid negative bright field values for spikes in dark field) + tbf[tbf < 1] = 1 + + # Plot bright field + if self.galaxy_flag: + quick_imshow(tbf, title='bright field', path='tomo_reduce_plots', + save_fig=self.save_figs, save_only=self.save_only) + elif not self.test_mode: + quick_imshow(tbf, title='bright field', path=self.output_folder, + save_fig=self.save_figs, save_only=self.save_only) + clear_imshow('bright field') +# quick_imshow(tbf, title='bright field', block=True) + + # Add bright field to reduced data NXprocess + if 'data' not in reduced_data: + reduced_data.data = NXdata() + reduced_data.data['bright_field'] = tbf + + return(reduced_data) + + def _set_detector_bounds(self, nxentry, reduced_data, img_x_bounds=None): + """Set vertical detector bounds for each image stack. + Right now the range is the same for each set in the image stack. + """ + if self.test_mode: + return(tuple(self.test_config['img_x_bounds'])) + + # Get the first tomography image and the reference heights + image_key = nxentry.instrument.detector.get('image_key', None) + if image_key and 'data' in nxentry.instrument.detector: + field_indices = [index for index, key in enumerate(image_key) if key == 0] + first_image = np.asarray(nxentry.instrument.detector.data[field_indices[0],:,:]) + theta = float(nxentry.sample.rotation_angle[field_indices[0]]) + z_translation_all = nxentry.sample.z_translation[field_indices] + vertical_shifts = sorted(list(set(z_translation_all))) + num_tomo_stacks = len(vertical_shifts) + else: + tomo_field_scans = nxentry.spec_scans.tomo_fields + tomo_fields = TomoField.construct_from_nxcollection(tomo_field_scans) + vertical_shifts = tomo_fields.get_vertical_shifts() + if not isinstance(vertical_shifts, list): + vertical_shifts = [vertical_shifts] + prefix = str(nxentry.instrument.detector.local_name) + t0 = time() + first_image = tomo_fields.get_detector_data(prefix, tomo_fields.scan_numbers[0], 0) + logger.debug(f'Getting first image took {time()-t0:.2f} seconds') + num_tomo_stacks = len(tomo_fields.scan_numbers) + theta = tomo_fields.theta_range['start'] + + # Select image bounds + title = f'tomography image at theta={round(theta, 2)+0}' + if (img_x_bounds is not None and not is_index_range(img_x_bounds, ge=0, + le=first_image.shape[0])): + raise ValueError(f'Invalid parameter img_x_bounds ({img_x_bounds})') + if nxentry.instrument.source.attrs['station'] in ('id1a3', 'id3a'): + pixel_size = nxentry.instrument.detector.x_pixel_size + # Try to get a fit from the bright field + tbf = np.asarray(reduced_data.data.bright_field) + tbf_shape = tbf.shape + x_sum = np.sum(tbf, 1) + x_sum_min = x_sum.min() + x_sum_max = x_sum.max() + fit = Fit.fit_data(x_sum, 'rectangle', x=np.array(range(len(x_sum))), form='atan', + guess=True) + parameters = fit.best_values + x_low_fit = parameters.get('center1', None) + x_upp_fit = parameters.get('center2', None) + sig_low = parameters.get('sigma1', None) + sig_upp = parameters.get('sigma2', None) + have_fit = fit.success and x_low_fit is not None and x_upp_fit is not None and \ + sig_low is not None and sig_upp is not None and \ + 0 <= x_low_fit < x_upp_fit <= x_sum.size and \ + (sig_low+sig_upp)/(x_upp_fit-x_low_fit) < 0.1 + if have_fit: + # Set a 5% margin on each side + margin = 0.05*(x_upp_fit-x_low_fit) + x_low_fit = max(0, x_low_fit-margin) + x_upp_fit = min(tbf_shape[0], x_upp_fit+margin) + if num_tomo_stacks == 1: + if have_fit: + # Set the default range to enclose the full fitted window + x_low = int(x_low_fit) + x_upp = int(x_upp_fit) + else: + # Center a default range of 1 mm (RV: can we get this from the slits?) + num_x_min = int((1.0-0.5*pixel_size)/pixel_size) + x_low = int(0.5*(tbf_shape[0]-num_x_min)) + x_upp = x_low+num_x_min + else: + # Get the default range from the reference heights + delta_z = vertical_shifts[1]-vertical_shifts[0] + for i in range(2, num_tomo_stacks): + delta_z = min(delta_z, vertical_shifts[i]-vertical_shifts[i-1]) + logger.debug(f'delta_z = {delta_z}') + num_x_min = int((delta_z-0.5*pixel_size)/pixel_size) + logger.debug(f'num_x_min = {num_x_min}') + if num_x_min > tbf_shape[0]: + logger.warning('Image bounds and pixel size prevent seamless stacking') + if have_fit: + # Center the default range relative to the fitted window + x_low = int(0.5*(x_low_fit+x_upp_fit-num_x_min)) + x_upp = x_low+num_x_min + else: + # Center the default range + x_low = int(0.5*(tbf_shape[0]-num_x_min)) + x_upp = x_low+num_x_min + if self.galaxy_flag: + img_x_bounds = (x_low, x_upp) + else: + tmp = np.copy(tbf) + tmp_max = tmp.max() + tmp[x_low,:] = tmp_max + tmp[x_upp-1,:] = tmp_max + quick_imshow(tmp, title='bright field') + tmp = np.copy(first_image) + tmp_max = tmp.max() + tmp[x_low,:] = tmp_max + tmp[x_upp-1,:] = tmp_max + quick_imshow(tmp, title=title) + del tmp + quick_plot((range(x_sum.size), x_sum), + ([x_low, x_low], [x_sum_min, x_sum_max], 'r-'), + ([x_upp, x_upp], [x_sum_min, x_sum_max], 'r-'), + title='sum over theta and y') + print(f'lower bound = {x_low} (inclusive)') + print(f'upper bound = {x_upp} (exclusive)]') + accept = input_yesno('Accept these bounds (y/n)?', 'y') + clear_imshow('bright field') + clear_imshow(title) + clear_plot('sum over theta and y') + if accept: + img_x_bounds = (x_low, x_upp) + else: + while True: + mask, img_x_bounds = draw_mask_1d(x_sum, title='select x data range', + legend='sum over theta and y') + if len(img_x_bounds) == 1: + break + else: + print(f'Choose a single connected data range') + img_x_bounds = tuple(img_x_bounds[0]) + if (num_tomo_stacks > 1 and img_x_bounds[1]-img_x_bounds[0]+1 < + int((delta_z-0.5*pixel_size)/pixel_size)): + logger.warning('Image bounds and pixel size prevent seamless stacking') + else: + if num_tomo_stacks > 1: + raise NotImplementedError('Selecting image bounds for multiple stacks on FMB') + # For FMB: use the first tomography image to select range + # RV: revisit if they do tomography with multiple stacks + x_sum = np.sum(first_image, 1) + x_sum_min = x_sum.min() + x_sum_max = x_sum.max() + if self.galaxy_flag: + if img_x_bounds is None: + img_x_bounds = (0, first_image.shape[0]) + else: + quick_imshow(first_image, title=title) + print('Select vertical data reduction range from first tomography image') + img_x_bounds = select_image_bounds(first_image, 0, title=title) + clear_imshow(title) + if img_x_bounds is None: + raise ValueError('Unable to select image bounds') + + # Plot results + if self.galaxy_flag: + path = 'tomo_reduce_plots' + else: + path = self.output_folder + x_low = img_x_bounds[0] + x_upp = img_x_bounds[1] + tmp = np.copy(first_image) + tmp_max = tmp.max() + tmp[x_low,:] = tmp_max + tmp[x_upp-1,:] = tmp_max + quick_imshow(tmp, title=title, path=path, save_fig=self.save_figs, save_only=self.save_only, + block=self.block) + del tmp + quick_plot((range(x_sum.size), x_sum), + ([x_low, x_low], [x_sum_min, x_sum_max], 'r-'), + ([x_upp, x_upp], [x_sum_min, x_sum_max], 'r-'), + title='sum over theta and y', path=path, save_fig=self.save_figs, + save_only=self.save_only, block=self.block) + + return(img_x_bounds) + + def _set_zoom_or_skip(self): + """Set zoom and/or theta skip to reduce memory the requirement for the analysis. + """ +# if input_yesno('\nDo you want to zoom in to reduce memory requirement (y/n)?', 'n'): +# zoom_perc = input_int(' Enter zoom percentage', ge=1, le=100) +# else: +# zoom_perc = None + zoom_perc = None +# if input_yesno('Do you want to skip thetas to reduce memory requirement (y/n)?', 'n'): +# num_theta_skip = input_int(' Enter the number skip theta interval', ge=0, +# lt=num_theta) +# else: +# num_theta_skip = None + num_theta_skip = None + logger.debug(f'zoom_perc = {zoom_perc}') + logger.debug(f'num_theta_skip = {num_theta_skip}') + + return(zoom_perc, num_theta_skip) + + def _gen_tomo(self, nxentry, reduced_data): + """Generate tomography fields. + """ + # Get full bright field + tbf = np.asarray(reduced_data.data.bright_field) + tbf_shape = tbf.shape + + # Get image bounds + img_x_bounds = tuple(reduced_data.get('img_x_bounds', (0, tbf_shape[0]))) + img_y_bounds = tuple(reduced_data.get('img_y_bounds', (0, tbf_shape[1]))) + + # Get resized dark field +# if 'dark_field' in data: +# tbf = np.asarray(reduced_data.data.dark_field[ +# img_x_bounds[0]:img_x_bounds[1],img_y_bounds[0]:img_y_bounds[1]]) +# else: +# logger.warning('Dark field unavailable') +# tdf = None + tdf = None + + # Resize bright field + if img_x_bounds != (0, tbf.shape[0]) or img_y_bounds != (0, tbf.shape[1]): + tbf = tbf[img_x_bounds[0]:img_x_bounds[1],img_y_bounds[0]:img_y_bounds[1]] + + # Get the tomography images + image_key = nxentry.instrument.detector.get('image_key', None) + if image_key and 'data' in nxentry.instrument.detector: + field_indices_all = [index for index, key in enumerate(image_key) if key == 0] + z_translation_all = nxentry.sample.z_translation[field_indices_all] + z_translation_levels = sorted(list(set(z_translation_all))) + num_tomo_stacks = len(z_translation_levels) + tomo_stacks = num_tomo_stacks*[np.array([])] + horizontal_shifts = [] + vertical_shifts = [] + thetas = None + tomo_stacks = [] + for i, z_translation in enumerate(z_translation_levels): + field_indices = [field_indices_all[index] + for index, z in enumerate(z_translation_all) if z == z_translation] + horizontal_shift = list(set(nxentry.sample.x_translation[field_indices])) + assert(len(horizontal_shift) == 1) + horizontal_shifts += horizontal_shift + vertical_shift = list(set(nxentry.sample.z_translation[field_indices])) + assert(len(vertical_shift) == 1) + vertical_shifts += vertical_shift + sequence_numbers = nxentry.instrument.detector.sequence_number[field_indices] + if thetas is None: + thetas = np.asarray(nxentry.sample.rotation_angle[field_indices]) \ + [sequence_numbers] + else: + assert(all(thetas[i] == nxentry.sample.rotation_angle[field_indices[index]] + for i, index in enumerate(sequence_numbers))) + assert(list(set(sequence_numbers)) == [i for i in range(len(sequence_numbers))]) + if list(sequence_numbers) == [i for i in range(len(sequence_numbers))]: + tomo_stack = np.asarray(nxentry.instrument.detector.data[field_indices]) + else: + raise ValueError('Unable to load the tomography images') + tomo_stacks.append(tomo_stack) + else: + tomo_field_scans = nxentry.spec_scans.tomo_fields + tomo_fields = TomoField.construct_from_nxcollection(tomo_field_scans) + horizontal_shifts = tomo_fields.get_horizontal_shifts() + vertical_shifts = tomo_fields.get_vertical_shifts() + prefix = str(nxentry.instrument.detector.local_name) + t0 = time() + tomo_stacks = tomo_fields.get_detector_data(prefix) + logger.debug(f'Getting tomography images took {time()-t0:.2f} seconds') + logger.debug(f'Getting all images took {time()-t0:.2f} seconds') + thetas = np.linspace(tomo_fields.theta_range['start'], tomo_fields.theta_range['end'], + tomo_fields.theta_range['num']) + if not isinstance(tomo_stacks, list): + horizontal_shifts = [horizontal_shifts] + vertical_shifts = [vertical_shifts] + tomo_stacks = [tomo_stacks] + + reduced_tomo_stacks = [] + if self.galaxy_flag: + path = 'tomo_reduce_plots' + else: + path = self.output_folder + for i, tomo_stack in enumerate(tomo_stacks): + # Resize the tomography images + # Right now the range is the same for each set in the image stack. + if img_x_bounds != (0, tbf.shape[0]) or img_y_bounds != (0, tbf.shape[1]): + t0 = time() + tomo_stack = tomo_stack[:,img_x_bounds[0]:img_x_bounds[1], + img_y_bounds[0]:img_y_bounds[1]].astype('float64') + logger.debug(f'Resizing tomography images took {time()-t0:.2f} seconds') + + # Subtract dark field + if tdf is not None: + t0 = time() + with set_numexpr_threads(self.num_core): + ne.evaluate('tomo_stack-tdf', out=tomo_stack) + logger.debug(f'Subtracting dark field took {time()-t0:.2f} seconds') + + # Normalize + t0 = time() + with set_numexpr_threads(self.num_core): + ne.evaluate('tomo_stack/tbf', out=tomo_stack, truediv=True) + logger.debug(f'Normalizing took {time()-t0:.2f} seconds') + + # Remove non-positive values and linearize data + t0 = time() + cutoff = 1.e-6 + with set_numexpr_threads(self.num_core): + ne.evaluate('where(tomo_stack<cutoff, cutoff, tomo_stack)', out=tomo_stack) + with set_numexpr_threads(self.num_core): + ne.evaluate('-log(tomo_stack)', out=tomo_stack) + logger.debug('Removing non-positive values and linearizing data took '+ + f'{time()-t0:.2f} seconds') + + # Get rid of nans/infs that may be introduced by normalization + t0 = time() + np.where(np.isfinite(tomo_stack), tomo_stack, 0.) + logger.debug(f'Remove nans/infs took {time()-t0:.2f} seconds') + + # Downsize tomography stack to smaller size + # TODO use theta_skip as well + tomo_stack = tomo_stack.astype('float32') + if not self.test_mode: + if len(tomo_stacks) == 1: + title = f'red fullres theta {round(thetas[0], 2)+0}' + else: + title = f'red stack {i+1} fullres theta {round(thetas[0], 2)+0}' + quick_imshow(tomo_stack[0,:,:], title=title, path=path, save_fig=self.save_figs, + save_only=self.save_only, block=self.block) +# if not self.block: +# clear_imshow(title) + if False and zoom_perc != 100: + t0 = time() + logger.debug(f'Zooming in ...') + tomo_zoom_list = [] + for j in range(tomo_stack.shape[0]): + tomo_zoom = spi.zoom(tomo_stack[j,:,:], 0.01*zoom_perc) + tomo_zoom_list.append(tomo_zoom) + tomo_stack = np.stack([tomo_zoom for tomo_zoom in tomo_zoom_list]) + logger.debug(f'... done in {time()-t0:.2f} seconds') + logger.info(f'Zooming in took {time()-t0:.2f} seconds') + del tomo_zoom_list + if not self.test_mode: + title = f'red stack {zoom_perc}p theta {round(thetas[0], 2)+0}' + quick_imshow(tomo_stack[0,:,:], title=title, path=path, save_fig=self.save_figs, + save_only=self.save_only, block=self.block) +# if not self.block: +# clear_imshow(title) + + # Save test data to file + if self.test_mode: +# row_index = int(tomo_stack.shape[0]/2) +# np.savetxt(f'{self.output_folder}/red_stack_{i+1}.txt', tomo_stack[row_index,:,:], +# fmt='%.6e') + row_index = int(tomo_stack.shape[1]/2) + np.savetxt(f'{self.output_folder}/red_stack_{i+1}.txt', tomo_stack[:,row_index,:], + fmt='%.6e') + + # Combine resized stacks + reduced_tomo_stacks.append(tomo_stack) + + # Add tomo field info to reduced data NXprocess + reduced_data['rotation_angle'] = thetas + reduced_data['x_translation'] = np.asarray(horizontal_shifts) + reduced_data['z_translation'] = np.asarray(vertical_shifts) + reduced_data.data['tomo_fields'] = np.asarray(reduced_tomo_stacks) + + if tdf is not None: + del tdf + del tbf + + return(reduced_data) + + def _find_center_one_plane(self, sinogram, row, thetas, eff_pixel_size, cross_sectional_dim, + path=None, tol=0.1, num_core=1): + """Find center for a single tomography plane. + """ + # Try automatic center finding routines for initial value + # sinogram index order: theta,column + # need column,theta for iradon, so take transpose + sinogram = np.asarray(sinogram) + sinogram_T = sinogram.T + center = sinogram.shape[1]/2 + + # Try using Nghia Vo’s method + t0 = time() + if num_core > num_core_tomopy_limit: + logger.debug(f'Running find_center_vo on {num_core_tomopy_limit} cores ...') + tomo_center = tomopy.find_center_vo(sinogram, ncore=num_core_tomopy_limit) + else: + logger.debug(f'Running find_center_vo on {num_core} cores ...') + tomo_center = tomopy.find_center_vo(sinogram, ncore=num_core) + logger.debug(f'... done in {time()-t0:.2f} seconds') + logger.info(f'Finding the center using Nghia Vo’s method took {time()-t0:.2f} seconds') + center_offset_vo = tomo_center-center + logger.info(f'Center at row {row} using Nghia Vo’s method = {center_offset_vo:.2f}') + t0 = time() + logger.debug(f'Running _reconstruct_one_plane on {self.num_core} cores ...') + recon_plane = self._reconstruct_one_plane(sinogram_T, tomo_center, thetas, + eff_pixel_size, cross_sectional_dim, False, num_core) + logger.debug(f'... done in {time()-t0:.2f} seconds') + logger.info(f'Reconstructing row {row} took {time()-t0:.2f} seconds') + + title = f'edges row{row} center offset{center_offset_vo:.2f} Vo' + self._plot_edges_one_plane(recon_plane, title, path=path) + + # Try using phase correlation method +# if input_yesno('Try finding center using phase correlation (y/n)?', 'n'): +# t0 = time() +# logger.debug(f'Running find_center_pc ...') +# tomo_center = tomopy.find_center_pc(sinogram, sinogram, tol=0.1, rotc_guess=tomo_center) +# error = 1. +# while error > tol: +# prev = tomo_center +# tomo_center = tomopy.find_center_pc(sinogram, sinogram, tol=tol, +# rotc_guess=tomo_center) +# error = np.abs(tomo_center-prev) +# logger.debug(f'... done in {time()-t0:.2f} seconds') +# logger.info('Finding the center using the phase correlation method took '+ +# f'{time()-t0:.2f} seconds') +# center_offset = tomo_center-center +# print(f'Center at row {row} using phase correlation = {center_offset:.2f}') +# t0 = time() +# logger.debug(f'Running _reconstruct_one_plane on {self.num_core} cores ...') +# recon_plane = self._reconstruct_one_plane(sinogram_T, tomo_center, thetas, +# eff_pixel_size, cross_sectional_dim, False, num_core) +# logger.debug(f'... done in {time()-t0:.2f} seconds') +# logger.info(f'Reconstructing row {row} took {time()-t0:.2f} seconds') +# +# title = f'edges row{row} center_offset{center_offset:.2f} PC' +# self._plot_edges_one_plane(recon_plane, title, path=path) + + # Select center location +# if input_yesno('Accept a center location (y) or continue search (n)?', 'y'): + if True: +# center_offset = input_num(' Enter chosen center offset', ge=-center, le=center, +# default=center_offset_vo) + center_offset = center_offset_vo + del sinogram_T + del recon_plane + return float(center_offset) + + # perform center finding search + while True: + center_offset_low = input_int('\nEnter lower bound for center offset', ge=-center, + le=center) + center_offset_upp = input_int('Enter upper bound for center offset', + ge=center_offset_low, le=center) + if center_offset_upp == center_offset_low: + center_offset_step = 1 + else: + center_offset_step = input_int('Enter step size for center offset search', ge=1, + le=center_offset_upp-center_offset_low) + num_center_offset = 1+int((center_offset_upp-center_offset_low)/center_offset_step) + center_offsets = np.linspace(center_offset_low, center_offset_upp, num_center_offset) + for center_offset in center_offsets: + if center_offset == center_offset_vo: + continue + t0 = time() + logger.debug(f'Running _reconstruct_one_plane on {num_core} cores ...') + recon_plane = self._reconstruct_one_plane(sinogram_T, center_offset+center, thetas, + eff_pixel_size, cross_sectional_dim, False, num_core) + logger.debug(f'... done in {time()-t0:.2f} seconds') + logger.info(f'Reconstructing center_offset {center_offset} took '+ + f'{time()-t0:.2f} seconds') + title = f'edges row{row} center_offset{center_offset:.2f}' + self._plot_edges_one_plane(recon_plane, title, path=path) + if input_int('\nContinue (0) or end the search (1)', ge=0, le=1): + break + + del sinogram_T + del recon_plane + center_offset = input_num(' Enter chosen center offset', ge=-center, le=center) + return float(center_offset) + + def _reconstruct_one_plane(self, tomo_plane_T, center, thetas, eff_pixel_size, + cross_sectional_dim, plot_sinogram=True, num_core=1): + """Invert the sinogram for a single tomography plane. + """ + # tomo_plane_T index order: column,theta + assert(0 <= center < tomo_plane_T.shape[0]) + center_offset = center-tomo_plane_T.shape[0]/2 + two_offset = 2*int(np.round(center_offset)) + two_offset_abs = np.abs(two_offset) + max_rad = int(0.55*(cross_sectional_dim/eff_pixel_size)) # 10% slack to avoid edge effects + if max_rad > 0.5*tomo_plane_T.shape[0]: + max_rad = 0.5*tomo_plane_T.shape[0] + dist_from_edge = max(1, int(np.floor((tomo_plane_T.shape[0]-two_offset_abs)/2.)-max_rad)) + if two_offset >= 0: + logger.debug(f'sinogram range = [{two_offset+dist_from_edge}, {-dist_from_edge}]') + sinogram = tomo_plane_T[two_offset+dist_from_edge:-dist_from_edge,:] + else: + logger.debug(f'sinogram range = [{dist_from_edge}, {two_offset-dist_from_edge}]') + sinogram = tomo_plane_T[dist_from_edge:two_offset-dist_from_edge,:] + if not self.galaxy_flag and plot_sinogram: + quick_imshow(sinogram.T, f'sinogram center offset{center_offset:.2f}', aspect='auto', + path=self.output_folder, save_fig=self.save_figs, save_only=self.save_only, + block=self.block) + + # Inverting sinogram + t0 = time() + recon_sinogram = iradon(sinogram, theta=thetas, circle=True) + logger.debug(f'Inverting sinogram took {time()-t0:.2f} seconds') + del sinogram + + # Performing Gaussian filtering and removing ring artifacts + recon_parameters = None#self.config.get('recon_parameters') + if recon_parameters is None: + sigma = 1.0 + ring_width = 15 + else: + sigma = recon_parameters.get('gaussian_sigma', 1.0) + if not is_num(sigma, ge=0.0): + logger.warning(f'Invalid gaussian_sigma ({sigma}) in _reconstruct_one_plane, '+ + 'set to a default value of 1.0') + sigma = 1.0 + ring_width = recon_parameters.get('ring_width', 15) + if not is_int(ring_width, ge=0): + logger.warning(f'Invalid ring_width ({ring_width}) in _reconstruct_one_plane, '+ + 'set to a default value of 15') + ring_width = 15 + t0 = time() + recon_sinogram = spi.gaussian_filter(recon_sinogram, sigma, mode='nearest') + recon_clean = np.expand_dims(recon_sinogram, axis=0) + del recon_sinogram + recon_clean = tomopy.misc.corr.remove_ring(recon_clean, rwidth=ring_width, ncore=num_core) + logger.debug(f'Filtering and removing ring artifacts took {time()-t0:.2f} seconds') + + return recon_clean + + def _plot_edges_one_plane(self, recon_plane, title, path=None): + vis_parameters = None#self.config.get('vis_parameters') + if vis_parameters is None: + weight = 0.1 + else: + weight = vis_parameters.get('denoise_weight', 0.1) + if not is_num(weight, ge=0.0): + logger.warning(f'Invalid weight ({weight}) in _plot_edges_one_plane, '+ + 'set to a default value of 0.1') + weight = 0.1 + edges = denoise_tv_chambolle(recon_plane, weight=weight) + vmax = np.max(edges[0,:,:]) + vmin = -vmax + if path is None: + path = self.output_folder + quick_imshow(edges[0,:,:], f'{title} coolwarm', path=path, cmap='coolwarm', + save_fig=self.save_figs, save_only=self.save_only, block=self.block) + quick_imshow(edges[0,:,:], f'{title} gray', path=path, cmap='gray', vmin=vmin, vmax=vmax, + save_fig=self.save_figs, save_only=self.save_only, block=self.block) + del edges + + def _reconstruct_one_tomo_stack(self, tomo_stack, thetas, center_offsets=[], num_core=1, + algorithm='gridrec'): + """Reconstruct a single tomography stack. + """ + # tomo_stack order: row,theta,column + # input thetas must be in degrees + # centers_offset: tomography axis shift in pixels relative to column center + # RV should we remove stripes? + # https://tomopy.readthedocs.io/en/latest/api/tomopy.prep.stripe.html + # RV should we remove rings? + # https://tomopy.readthedocs.io/en/latest/api/tomopy.misc.corr.html + # RV: Add an option to do (extra) secondary iterations later or to do some sort of convergence test? + if not len(center_offsets): + centers = np.zeros((tomo_stack.shape[0])) + elif len(center_offsets) == 2: + centers = np.linspace(center_offsets[0], center_offsets[1], tomo_stack.shape[0]) + else: + if center_offsets.size != tomo_stack.shape[0]: + raise ValueError('center_offsets dimension mismatch in reconstruct_one_tomo_stack') + centers = center_offsets + centers += tomo_stack.shape[2]/2 + + # Get reconstruction parameters + recon_parameters = None#self.config.get('recon_parameters') + if recon_parameters is None: + sigma = 2.0 + secondary_iters = 0 + ring_width = 15 + else: + sigma = recon_parameters.get('stripe_fw_sigma', 2.0) + if not is_num(sigma, ge=0): + logger.warning(f'Invalid stripe_fw_sigma ({sigma}) in '+ + '_reconstruct_one_tomo_stack, set to a default value of 2.0') + ring_width = 15 + secondary_iters = recon_parameters.get('secondary_iters', 0) + if not is_int(secondary_iters, ge=0): + logger.warning(f'Invalid secondary_iters ({secondary_iters}) in '+ + '_reconstruct_one_tomo_stack, set to a default value of 0 (skip them)') + ring_width = 0 + ring_width = recon_parameters.get('ring_width', 15) + if not is_int(ring_width, ge=0): + logger.warning(f'Invalid ring_width ({ring_width}) in _reconstruct_one_plane, '+ + 'set to a default value of 15') + ring_width = 15 + + # Remove horizontal stripe + t0 = time() + if num_core > num_core_tomopy_limit: + logger.debug('Running remove_stripe_fw on {num_core_tomopy_limit} cores ...') + tomo_stack = tomopy.prep.stripe.remove_stripe_fw(tomo_stack, sigma=sigma, + ncore=num_core_tomopy_limit) + else: + logger.debug(f'Running remove_stripe_fw on {num_core} cores ...') + tomo_stack = tomopy.prep.stripe.remove_stripe_fw(tomo_stack, sigma=sigma, + ncore=num_core) + logger.debug(f'... tomopy.prep.stripe.remove_stripe_fw took {time()-t0:.2f} seconds') + + # Perform initial image reconstruction + logger.debug('Performing initial image reconstruction') + t0 = time() + logger.debug(f'Running recon on {num_core} cores ...') + tomo_recon_stack = tomopy.recon(tomo_stack, np.radians(thetas), centers, + sinogram_order=True, algorithm=algorithm, ncore=num_core) + logger.debug(f'... done in {time()-t0:.2f} seconds') + logger.info(f'Performing initial image reconstruction took {time()-t0:.2f} seconds') + + # Run optional secondary iterations + if secondary_iters > 0: + logger.debug(f'Running {secondary_iters} secondary iterations') + #options = {'method':'SIRT_CUDA', 'proj_type':'cuda', 'num_iter':secondary_iters} + #RV: doesn't work for me: + #"Error: CUDA error 803: system has unsupported display driver/cuda driver combination." + #options = {'method':'SIRT', 'proj_type':'linear', 'MinConstraint': 0, 'num_iter':secondary_iters} + #SIRT did not finish while running overnight + #options = {'method':'SART', 'proj_type':'linear', 'num_iter':secondary_iters} + options = {'method':'SART', 'proj_type':'linear', 'MinConstraint': 0, + 'num_iter':secondary_iters} + t0 = time() + logger.debug(f'Running recon on {num_core} cores ...') + tomo_recon_stack = tomopy.recon(tomo_stack, np.radians(thetas), centers, + init_recon=tomo_recon_stack, options=options, sinogram_order=True, + algorithm=tomopy.astra, ncore=num_core) + logger.debug(f'... done in {time()-t0:.2f} seconds') + logger.info(f'Performing secondary iterations took {time()-t0:.2f} seconds') + + # Remove ring artifacts + t0 = time() + tomopy.misc.corr.remove_ring(tomo_recon_stack, rwidth=ring_width, out=tomo_recon_stack, + ncore=num_core) + logger.debug(f'Removing ring artifacts took {time()-t0:.2f} seconds') + + return tomo_recon_stack + + def _resize_reconstructed_data(self, data, z_only=False): + """Resize the reconstructed tomography data. + """ + # Data order: row(z),x,y or stack,row(z),x,y + if isinstance(data, list): + for stack in data: + assert(stack.ndim == 3) + num_tomo_stacks = len(data) + tomo_recon_stacks = data + else: + assert(data.ndim == 3) + num_tomo_stacks = 1 + tomo_recon_stacks = [data] + + if z_only: + x_bounds = None + y_bounds = None + else: + # Selecting x bounds (in yz-plane) + tomosum = 0 + [tomosum := tomosum+np.sum(tomo_recon_stacks[i], axis=(0,2)) + for i in range(num_tomo_stacks)] + select_x_bounds = input_yesno('\nDo you want to change the image x-bounds (y/n)?', 'y') + if not select_x_bounds: + x_bounds = None + else: + accept = False + index_ranges = None + while not accept: + mask, x_bounds = draw_mask_1d(tomosum, current_index_ranges=index_ranges, + title='select x data range', legend='recon stack sum yz') + while len(x_bounds) != 1: + print('Please select exactly one continuous range') + mask, x_bounds = draw_mask_1d(tomosum, title='select x data range', + legend='recon stack sum yz') + x_bounds = x_bounds[0] +# quick_plot(tomosum, vlines=x_bounds, title='recon stack sum yz') +# print(f'x_bounds = {x_bounds} (lower bound inclusive, upper bound '+ +# 'exclusive)') +# accept = input_yesno('Accept these bounds (y/n)?', 'y') + accept = True + logger.debug(f'x_bounds = {x_bounds}') + + # Selecting y bounds (in xz-plane) + tomosum = 0 + [tomosum := tomosum+np.sum(tomo_recon_stacks[i], axis=(0,1)) + for i in range(num_tomo_stacks)] + select_y_bounds = input_yesno('\nDo you want to change the image y-bounds (y/n)?', 'y') + if not select_y_bounds: + y_bounds = None + else: + accept = False + index_ranges = None + while not accept: + mask, y_bounds = draw_mask_1d(tomosum, current_index_ranges=index_ranges, + title='select x data range', legend='recon stack sum xz') + while len(y_bounds) != 1: + print('Please select exactly one continuous range') + mask, y_bounds = draw_mask_1d(tomosum, title='select x data range', + legend='recon stack sum xz') + y_bounds = y_bounds[0] +# quick_plot(tomosum, vlines=y_bounds, title='recon stack sum xz') +# print(f'y_bounds = {y_bounds} (lower bound inclusive, upper bound '+ +# 'exclusive)') +# accept = input_yesno('Accept these bounds (y/n)?', 'y') + accept = True + logger.debug(f'y_bounds = {y_bounds}') + + # Selecting z bounds (in xy-plane) (only valid for a single image stack) + if num_tomo_stacks != 1: + z_bounds = None + else: + tomosum = 0 + [tomosum := tomosum+np.sum(tomo_recon_stacks[i], axis=(1,2)) + for i in range(num_tomo_stacks)] + select_z_bounds = input_yesno('Do you want to change the image z-bounds (y/n)?', 'n') + if not select_z_bounds: + z_bounds = None + else: + accept = False + index_ranges = None + while not accept: + mask, z_bounds = draw_mask_1d(tomosum, current_index_ranges=index_ranges, + title='select x data range', legend='recon stack sum xy') + while len(z_bounds) != 1: + print('Please select exactly one continuous range') + mask, z_bounds = draw_mask_1d(tomosum, title='select x data range', + legend='recon stack sum xy') + z_bounds = z_bounds[0] +# quick_plot(tomosum, vlines=z_bounds, title='recon stack sum xy') +# print(f'z_bounds = {z_bounds} (lower bound inclusive, upper bound '+ +# 'exclusive)') +# accept = input_yesno('Accept these bounds (y/n)?', 'y') + accept = True + logger.debug(f'z_bounds = {z_bounds}') + + return(x_bounds, y_bounds, z_bounds) + + +def run_tomo(input_file:str, output_file:str, modes:list[str], center_file=None, num_core=-1, + output_folder='.', save_figs='no', test_mode=False) -> None: + + if test_mode: + logging_format = '%(asctime)s : %(levelname)s - %(module)s : %(funcName)s - %(message)s' + level = logging.getLevelName('INFO') + logging.basicConfig(filename=f'{output_folder}/tomo.log', filemode='w', + format=logging_format, level=level, force=True) + logger.info(f'input_file = {input_file}') + logger.info(f'center_file = {center_file}') + logger.info(f'output_file = {output_file}') + logger.debug(f'modes= {modes}') + logger.debug(f'num_core= {num_core}') + logger.info(f'output_folder = {output_folder}') + logger.info(f'save_figs = {save_figs}') + logger.info(f'test_mode = {test_mode}') + + # Check for correction modes + legal_modes = ['reduce_data', 'find_center', 'reconstruct_data', 'combine_data', 'all'] + if modes is None: + modes = ['all'] + if not all(True if mode in legal_modes else False for mode in modes): + raise ValueError(f'Invalid parameter modes ({modes})') + + # Instantiate Tomo object + tomo = Tomo(num_core=num_core, output_folder=output_folder, save_figs=save_figs, + test_mode=test_mode) + + # Read input file + data = tomo.read(input_file) + + # Generate reduced tomography images + if 'reduce_data' in modes or 'all' in modes: + data = tomo.gen_reduced_data(data) + + # Find rotation axis centers for the tomography stacks. + center_data = None + if 'find_center' in modes or 'all' in modes: + center_data = tomo.find_centers(data) + + # Reconstruct tomography stacks + if 'reconstruct_data' in modes or 'all' in modes: + if center_data is None: + # Read input file + center_data = tomo.read(center_file) + data = tomo.reconstruct_data(data, center_data) + center_data = None + + # Combine reconstructed tomography stacks + if 'combine_data' in modes or 'all' in modes: + data = tomo.combine_data(data) + + # Write output file + if data is not None and not test_mode: + if center_data is None: + data = tomo.write(data, output_file) + else: + data = tomo.write(center_data, output_file) + + logger.info(f'Completed modes: {modes}')
