changeset 69:fba792d5f83b draft
planemo upload for repository https://github.com/rolfverberg/galaxytools commit ab9f412c362a4ab986d00e21d5185cfcf82485c1
| author | rv43 |
| --- | --- |
| date | Fri, 10 Mar 2023 16:02:04 +0000 |
| parents | ba5866d0251d |
| children | 97c4e2cbbad9 |
| files | detector.py fit.py general.py run_link_to_galaxy run_tomo_all run_tomo_combine run_tomo_find_center run_tomo_reconstruct run_tomo_reduce sobhani-3249-A.yaml tomo.py tomo_combine.py tomo_find_center.py tomo_macros.xml tomo_reconstruct.py tomo_reduce.py tomo_reduce.xml tomo_setup.py tomo_setup.xml workflow/__main__.py workflow/__version__.py workflow/link_to_galaxy.py workflow/models.py workflow/run_tomo.py |
| diffstat | 24 files changed, 6774 insertions(+), 4578 deletions(-) |
--- a/detector.py	Fri Aug 19 20:16:56 2022 +0000
+++ /dev/null	Thu Jan 01 00:00:00 1970 +0000
@@ -1,348 +0,0 @@
-import logging
-import os
-import yaml
-from functools import cache
-from copy import deepcopy
-
-from general import illegal_value, is_int, is_num, input_yesno
-
-#from hexrd.instrument import HEDMInstrument, PlanarDetector
-
-class DetectorConfig:
-    def __init__(self, config_source):
-        self._config_source = config_source
-
-        if isinstance(self._config_source, ((str, bytes, os.PathLike, int))):
-            self._config_file = self._config_source
-            self._config = self._load_config_file()
-        elif isinstance(self._config_source, dict):
-            self._config_file = None
-            self._config = self._config_source
-        else:
-            self._config_file = None
-            self._config = False
-
-        self._valid = self._validate()
-
-        if not self.valid:
-            logging.error(f'Cannot create a valid instance of {self.__class__.__name__} '+
-                f'from {self._config_source}')
-
-    def __repr__(self):
-        return(f'{self.__class__.__name__}({self._config_source.__repr__()})')
-    def __str__(self):
-        return(f'{self.__class__.__name__} generated from {self._config_source}')
-
-    @property
-    def config_file(self):
-        return(self._config_file)
-
-    @property
-    def config(self):
-        return(deepcopy(self._config))
-
-    @property
-    def valid(self):
-        return(self._valid)
-
-    def load_config_file(self):
-        raise(NotImplementedError)
-
-    def validate(self):
-        raise(NotImplementedError)
-
-    def _load_config_file(self):
-        if not os.path.isfile(self.config_file):
-            logging.error(f'{self.config_file} is not a file.')
-            return(False)
-        else:
-            return(self.load_config_file())
-
-    def _validate(self):
-        if not self.config:
-            logging.error('A configuration must be loaded prior to calling Detector._validate')
-            return(False)
-        else:
-            return(self.validate())
-
-    def _write_to_file(self, out_file):
-        out_file = os.path.abspath(out_file)
-
-        current_config_valid = self.validate()
-        if not current_config_valid:
-            write_invalid_config = input_yesno(s=f'This {self.__class__.__name__} is currently '+
-                f'invalid. Write the configuration to {out_file} anyways?', default='no')
-            if not write_invalid_config:
-                logging.info('In accordance with user input, the invalid configuration will '+
-                    f'not be written to {out_file}')
-                return
-
-        if os.access(out_file, os.W_OK):
-            if os.path.exists(out_file):
-                overwrite = input_yesno(s=f'{out_file} already exists. Overwrite?', default='no')
-                if overwrite:
-                    self.write_to_file(out_file)
-                else:
-                    logging.info(f'In accordance with user input, {out_file} will not be '+
-                        'overwritten')
-            else:
-                self.write_to_file(out_file)
-        else:
-            logging.error(f'Insufficient permissions to write to {out_file}')
-
-    def write_to_file(self, out_file):
-        raise(NotImplementedError)
-
-class YamlDetectorConfig(DetectorConfig):
-    def __init__(self, config_source, validate_yaml_pars=[]):
-        self._validate_yaml_pars = validate_yaml_pars
-        super().__init__(config_source)
-
-    def load_config_file(self):
-        if not os.path.splitext(self._config_file)[1]:
-            if os.path.isfile(f'{self._config_file}.yml'):
-                self._config_file = f'{self._config_file}.yml'
-            if os.path.isfile(f'{self._config_file}.yaml'):
-                self._config_file = f'{self._config_file}.yaml'
-        if not os.path.isfile(self._config_file):
-            logging.error(f'Unable to load {self._config_file}')
-            return(False)
-        with open(self._config_file, 'r') as infile:
-            config = yaml.safe_load(infile)
-        if isinstance(config, dict):
-            return(config)
-        else:
-            logging.error(f'Unable to load {self._config_file} as a dictionary')
-            return(False)
-
-    def validate(self):
-        if not self._validate_yaml_pars:
-            logging.warning('There are no required parameters provided for this detector '+
-                'configuration')
-            return(True)
-
-        def validate_nested_pars(config, validate_yaml_par):
-            yaml_par_levels = validate_yaml_par.split(':')
-            first_level_par = yaml_par_levels[0]
-            try:
-                first_level_par = int(first_level_par)
-            except:
-                pass
-            try:
-                next_level_config = config[first_level_par]
-                if len(yaml_par_levels) > 1:
-                    next_level_pars = ':'.join(yaml_par_levels[1:])
-                    return(validate_nested_pars(next_level_config, next_level_pars))
-                else:
-                    return(True)
-            except:
-                return(False)
-
-        pars_missing = [p for p in self._validate_yaml_pars
-                if not validate_nested_pars(self.config, p)]
-        if len(pars_missing) > 0:
-            logging.error(f'Missing item(s) in configuration: {", ".join(pars_missing)}')
-            return(False)
-        else:
-            return(True)
-
-    def write_to_file(self, out_file):
-        with open(out_file, 'w') as outf:
-            yaml.dump(self.config, outf)
-
-
-class TomoDetectorConfig(YamlDetectorConfig):
-    def __init__(self, config_source):
-        validate_yaml_pars = ['detector',
-                'lens_magnification',
-                'detector:pixels:rows',
-                'detector:pixels:columns',
-                *[f'detector:pixels:size:{i}' for i in range(2)]]
-        super().__init__(config_source, validate_yaml_pars=validate_yaml_pars)
-
-    @property
-    @cache
-    def lens_magnification(self):
-        lens_magnification = self.config.get('lens_magnification')
-        if not isinstance(lens_magnification, (int, float)) or lens_magnification <= 0.:
-            illegal_value(lens_magnification, 'lens_magnification', 'detector file')
-            logging.warning('Using default lens_magnification value of 1.0')
-            return(1.0)
-        else:
-            return(lens_magnification)
-
-    @property
-    @cache
-    def pixel_size(self):
-        pixel_size = self.config['detector'].get('pixels').get('size')
-        if isinstance(pixel_size, (int, float)):
-            if pixel_size <= 0.:
-                illegal_value(pixel_size, 'pixel_size', 'detector file')
-                return(None)
-            pixel_size /= self.lens_magnification
-        elif isinstance(pixel_size, list):
-            if ((len(pixel_size) > 2) or
-                    (len(pixel_size) == 2 and pixel_size[0] != pixel_size[1])):
-                illegal_value(pixel_size, 'pixel size', 'detector file')
-                return(None)
-            elif not is_num(pixel_size[0], 0.):
-                illegal_value(pixel_size, 'pixel size', 'detector file')
-                return(None)
-            else:
-                pixel_size = pixel_size[0]/self.lens_magnification
-        else:
-            illegal_value(pixel_size, 'pixel size', 'detector file')
-            return(None)
-
-        return(pixel_size)
-
-    @property
-    @cache
-    def dimensions(self):
-        pixels = self.config['detector'].get('pixels')
-        num_rows = pixels.get('rows')
-        if not is_int(num_rows, 1):
-            illegal_value(num_rows, 'rows', 'detector file')
-            return(None)
-        num_columns = pixels.get('columns')
-        if not is_int(num_columns, 1):
-            illegal_value(num_columns, 'columns', 'detector file')
-            return(None)
-        return(num_rows, num_columns)
-
-
-class EDDDetectorConfig(YamlDetectorConfig):
-    def __init__(self, config_source):
-        validate_yaml_pars = ['num_bins',
-                'max_E',
-                # 'angle', # KLS leave this out for now -- I think it has to do with the relative geometry of sample, beam, and detector (not a property of the detector on its own), so may not belong here in the DetectorConfig object?
-                'tth_angle',
-                'slope',
-                'intercept']
-        super().__init__(config_source, validate_yaml_pars=validate_yaml_pars)
-
-    @property
-    @cache
-    def num_bins(self):
-        try:
-            num_bins = int(self.config['num_bins'])
-            if num_bins <= 0:
-                raise(ValueError)
-            else:
-                return(num_bins)
-        except:
-            illegal_value(self.config['num_bins'], 'num_bins')
-    @property
-    @cache
-    def max_E(self):
-        try:
-            max_E = float(self.config['max_E'])
-            if max_E <= 0:
-                raise(ValueError)
-            else:
-                return(max_E)
-        except:
-            illegal_value(self.config['max_E'], 'max_E')
-            return(None)
-
-    @property
-    def bin_energies(self):
-        return(self.slope * np.linspace(0, self.max_E, self.num_bins, endpoint=False) +
-                self.intercept)
-
-    @property
-    def tth_angle(self):
-        try:
-            return(float(self.config['tth_angle']))
-        except:
-            illegal_value(tth_angle, 'tth_angle')
-            return(None)
-    @tth_angle.setter
-    def tth_angle(self, value):
-        try:
-            self._config['tth_angle'] = float(value)
-        except:
-            illegal_value(value, 'tth_angle')
-
-    @property
-    def slope(self):
-        try:
-            return(float(self.config['slope']))
-        except:
-            illegal_value(slope, 'slope')
-            return(None)
-    @slope.setter
-    def slope(self, value):
-        try:
-            self._config['slope'] = float(value)
-        except:
-            illegal_value(value, 'slope')
-
-    @property
-    def intercept(self):
-        try:
-            return(float(self.config['intercept']))
-        except:
-            illegal_value(intercept, 'intercept')
-            return(None)
-    @intercept.setter
-    def intercept(self, value):
-        try:
-            self._config['intercept'] = float(value)
-        except:
-            illegal_value(value, 'intercept')
-
-
-# class HexrdDetectorConfig(YamlDetectorConfig):
-#     def __init__(self, config_source, detector_names=[]):
-#         self.detector_names = detector_names
-#         validate_yaml_pars_each_detector = [*[f'buffer:{i}' for i in range(2)],
-#             'distortion:function_name',
-#             *[f'distortion:parameters:{i}' for i in range(6)],
-#             'pixels:columns',
-#             'pixels:rows',
-#             *['pixels:size:%i' % i for i in range(2)],
-#             'saturation_level',
-#             *[f'transform:tilt:{i}' for i in range(3)],
-#             *[f'transform:translation:{i}' for i in range(3)]]
-#         validate_yaml_pars = []
-#         for detector_name in self.detector_names:
-#             validate_yaml_pars += [f'detectors:{detector_name}:{par}' for par in validate_yaml_pars_each_detector]
-
-#         super().__init__(config_source, validate_yaml_pars=validate_yaml_pars)
-
-#     def validate(self):
-#         yaml_valid = YamlDetectorConfig.validate(self)
-#         if not yaml_valid:
-#             return(False)
-#         else:
-#             hedm_instrument = HEDMInstrument(instrument_config=self.config)
-#             for detector_name in self.detector_names:
-#                 if detector_name in hedm_instrument.detectors:
-#                     if isinstance(hedm_instrument.detectors[detector_name], PlanarDetector):
-#                         continue
-#                     else:
-#                         return(False)
-#                 else:
-#                     return(False)
-#             return(True)
-
-# class SAXSWAXSDetectorConfig(DetectorConfig):
-#     def __init__(self, config_source):
-#         super().__init__(config_source)
-
-#     @property
-#     def ai(self):
-#         return(self.config)
-#     @ai.setter
-#     def ai(self, value):
-#         if isinstance(value, pyFAI.azimuthalIntegrator.AzimuthalIntegrator):
-#             self.config = ai
-#         else:
-#             illegal_value(value, 'azimuthal integrator')
-
-#     # pyFAI will perform its own error-checking for the mask attribute.
-#     mask = property(self.ai.get_mask, self.ai,set_mask)
-
-
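The removed YamlDetectorConfig validated required entries by walking colon-separated key paths such as 'detector:pixels:size:0' through nested dicts and lists. A minimal standalone sketch of that lookup, assuming only the path convention visible in the diff; the name has_nested_par is illustrative, not from the repository:

```python
# Sketch of the colon-path check used by the removed YamlDetectorConfig.validate();
# integer path segments index into lists, string segments into dicts.
def has_nested_par(config, par_path):
    """Return True if e.g. 'detector:pixels:size:0' resolves inside config."""
    key, _, rest = par_path.partition(':')
    try:
        key = int(key)  # list indices appear as integers in the path
    except ValueError:
        pass
    try:
        value = config[key]
    except (KeyError, IndexError, TypeError):
        return False
    return has_nested_par(value, rest) if rest else True

config = {'detector': {'pixels': {'rows': 2048, 'size': [0.0065, 0.0065]}}}
assert has_nested_par(config, 'detector:pixels:size:1')
assert not has_nested_par(config, 'detector:pixels:columns')
```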
--- a/fit.py Fri Aug 19 20:16:56 2022 +0000 +++ b/fit.py Fri Mar 10 16:02:04 2023 +0000 @@ -7,45 +7,118 @@ @author: rv43 """ -import sys -import re import logging -import numpy as np -from asteval import Interpreter +from asteval import Interpreter, get_ast_names from copy import deepcopy -#from lmfit import Minimizer from lmfit import Model, Parameters +from lmfit.model import ModelResult from lmfit.models import ConstantModel, LinearModel, QuadraticModel, PolynomialModel,\ - StepModel, RectangleModel, GaussianModel, LorentzianModel + ExponentialModel, StepModel, RectangleModel, ExpressionModel, GaussianModel,\ + LorentzianModel +import numpy as np +from os import cpu_count, getpid, listdir, mkdir, path +from re import compile, sub +from shutil import rmtree +from sympy import diff, simplify +try: + from joblib import Parallel, delayed + have_joblib = True +except: + have_joblib = False +try: + import xarray as xr + have_xarray = True +except: + have_xarray = False -from general import is_index, index_nearest, quickPlot +from .general import illegal_value, is_int, is_dict_series, is_index, index_nearest, \ + almost_equal, quick_plot #, eval_expr +#from sys import path as syspath +#syspath.append(f'/nfs/chess/user/rv43/msnctools/msnctools') +#from general import illegal_value, is_int, is_dict_series, is_index, index_nearest, \ +# almost_equal, quick_plot #, eval_expr + +from sys import float_info +float_min = float_info.min +float_max = float_info.max # sigma = fwhm_factor*fwhm fwhm_factor = { - 'gaussian' : f'fwhm/(2*sqrt(2*log(2)))', - 'lorentzian' : f'0.5*fwhm', - 'splitlorentzian' : f'0.5*fwhm', # sigma = sigma_r - 'voight' : f'0.2776*fwhm', # sigma = gamma - 'pseudovoight' : f'0.5*fwhm'} # fraction = 0.5 + 'gaussian': f'fwhm/(2*sqrt(2*log(2)))', + 'lorentzian': f'0.5*fwhm', + 'splitlorentzian': f'0.5*fwhm', # sigma = sigma_r + 'voight': f'0.2776*fwhm', # sigma = gamma + 'pseudovoight': f'0.5*fwhm'} # fraction = 0.5 # amplitude = height_factor*height*fwhm height_factor = { - 'gaussian' : f'height*fwhm*0.5*sqrt(pi/log(2))', - 'lorentzian' : f'height*fwhm*0.5*pi', - 'splitlorentzian' : f'height*fwhm*0.5*pi', # sigma = sigma_r - 'voight' : f'3.334*height*fwhm', # sigma = gamma - 'pseudovoight' : f'1.268*height*fwhm'} # fraction = 0.5 + 'gaussian': f'height*fwhm*0.5*sqrt(pi/log(2))', + 'lorentzian': f'height*fwhm*0.5*pi', + 'splitlorentzian': f'height*fwhm*0.5*pi', # sigma = sigma_r + 'voight': f'3.334*height*fwhm', # sigma = gamma + 'pseudovoight': f'1.268*height*fwhm'} # fraction = 0.5 class Fit: """Wrapper class for lmfit """ - def __init__(self, x, y, models=None, **kwargs): - self._x = x - self._y = y + def __init__(self, y, x=None, models=None, normalize=True, **kwargs): + if not isinstance(normalize, bool): + raise ValueError(f'Invalid parameter normalize ({normalize})') + self._mask = None self._model = None + self._norm = None + self._normalized = False self._parameters = Parameters() + self._parameter_bounds = None + self._parameter_norms = {} + self._linear_parameters = [] + self._nonlinear_parameters = [] self._result = None + self._try_linear_fit = True + self._y = None + self._y_norm = None + self._y_range = None + if 'try_linear_fit' in kwargs: + try_linear_fit = kwargs.pop('try_linear_fit') + if not isinstance(try_linear_fit, bool): + illegal_value(try_linear_fit, 'try_linear_fit', 'Fit.fit', raise_error=True) + self._try_linear_fit = try_linear_fit + if y is not None: + if isinstance(y, (tuple, list, np.ndarray)): + self._x = np.asarray(x) + elif have_xarray and 
isinstance(y, xr.DataArray): + if x is not None: + logging.warning('Ignoring superfluous input x ({x}) in Fit.__init__') + if y.ndim != 1: + illegal_value(y.ndim, 'DataArray dimensions', 'Fit:__init__', raise_error=True) + self._x = np.asarray(y[y.dims[0]]) + else: + illegal_value(y, 'y', 'Fit:__init__', raise_error=True) + self._y = y + if self._x.ndim != 1: + raise ValueError(f'Invalid dimension for input x ({self._x.ndim})') + if self._x.size != self._y.size: + raise ValueError(f'Inconsistent x and y dimensions ({self._x.size} vs '+ + f'{self._y.size})') + if 'mask' in kwargs: + self._mask = kwargs.pop('mask') + if self._mask is None: + y_min = float(self._y.min()) + self._y_range = float(self._y.max())-y_min + if normalize and self._y_range > 0.0: + self._norm = (y_min, self._y_range) + else: + self._mask = np.asarray(self._mask).astype(bool) + if self._x.size != self._mask.size: + raise ValueError(f'Inconsistent x and mask dimensions ({self._x.size} vs '+ + f'{self._mask.size})') + y_masked = np.asarray(self._y)[~self._mask] + y_min = float(y_masked.min()) + self._y_range = float(y_masked.max())-y_min + if normalize and self._y_range > 0.0: + if normalize and self._y_range > 0.0: + self._norm = (y_min, self._y_range) if models is not None: if callable(models) or isinstance(models, str): kwargs = self.add_model(models, **kwargs) @@ -55,73 +128,173 @@ self.fit(**kwargs) @classmethod - def fit_data(cls, x, y, models, **kwargs): - return cls(x, y, models, **kwargs) + def fit_data(cls, y, models, x=None, normalize=True, **kwargs): + return(cls(y, x=x, models=models, normalize=normalize, **kwargs)) @property def best_errors(self): if self._result is None: - return None - errors = {} - names = sorted(self._result.params) - for name in names: - par = self._result.params[name] - errors[name] = par.stderr - return errors + return(None) + return({name:self._result.params[name].stderr for name in sorted(self._result.params) + if name != 'tmp_normalization_offset_c'}) @property def best_fit(self): if self._result is None: - return None - return self._result.best_fit + return(None) + return(self._result.best_fit) @property def best_parameters(self): if self._result is None: - return None - parameters = [] - names = sorted(self._result.params) - for name in names: - par = self._result.params[name] - parameters.append({'name' : par.name, 'value' : par.value, 'error' : par.stderr, - 'init_value' : par.init_value, 'min' : par.min, 'max' : par.max, - 'vary' : par.vary, 'expr' : par.expr}) - return parameters + return(None) + parameters = {} + for name in sorted(self._result.params): + if name != 'tmp_normalization_offset_c': + par = self._result.params[name] + parameters[name] = {'value': par.value, 'error': par.stderr, + 'init_value': par.init_value, 'min': par.min, 'max': par.max, + 'vary': par.vary, 'expr': par.expr} + return(parameters) + + @property + def best_results(self): + """Convert the input data array to a data set and add the fit results. 
+ """ + if self._result is None: + return(None) + if isinstance(self._y, xr.DataArray): + best_results = self._y.to_dataset() + dims = self._y.dims + fit_name = f'{self._y.name}_fit' + else: + coords = {'x': (['x'], self._x)} + dims = ('x') + best_results = xr.Dataset(coords=coords) + best_results['y'] = (dims, self._y) + fit_name = 'y_fit' + best_results[fit_name] = (dims, self.best_fit) + if self._mask is not None: + best_results['mask'] = self._mask + best_results.coords['par_names'] = ('peak', [name for name in self.best_values.keys()]) + best_results['best_values'] = (['par_names'], [v for v in self.best_values.values()]) + best_results['best_errors'] = (['par_names'], [v for v in self.best_errors.values()]) + best_results.attrs['components'] = self.components + return(best_results) @property def best_values(self): if self._result is None: - return None - return self._result.params.valuesdict() + return(None) + return({name:self._result.params[name].value for name in sorted(self._result.params) + if name != 'tmp_normalization_offset_c'}) @property def chisqr(self): - return self._result.chisqr + if self._result is None: + return(None) + return(self._result.chisqr) + + @property + def components(self): + components = {} + if self._result is None: + logging.warning('Unable to collect components in Fit.components') + return(components) + for component in self._result.components: + if 'tmp_normalization_offset_c' in component.param_names: + continue + parameters = {} + for name in component.param_names: + par = self._parameters[name] + parameters[name] = {'free': par.vary, 'value': self._result.params[name].value} + if par.expr is not None: + parameters[name]['expr'] = par.expr + expr = None + if isinstance(component, ExpressionModel): + name = component._name + if name[-1] == '_': + name = name[:-1] + expr = component.expr + else: + prefix = component.prefix + if len(prefix): + if prefix[-1] == '_': + prefix = prefix[:-1] + name = f'{prefix} ({component._name})' + else: + name = f'{component._name}' + if expr is None: + components[name] = {'parameters': parameters} + else: + components[name] = {'expr': expr, 'parameters': parameters} + return(components) @property def covar(self): - return self._result.covar + if self._result is None: + return(None) + return(self._result.covar) + + @property + def init_parameters(self): + if self._result is None or self._result.init_params is None: + return(None) + parameters = {} + for name in sorted(self._result.init_params): + if name != 'tmp_normalization_offset_c': + par = self._result.init_params[name] + parameters[name] = {'value': par.value, 'min': par.min, 'max': par.max, + 'vary': par.vary, 'expr': par.expr} + return(parameters) @property def init_values(self): + if self._result is None or self._result.init_params is None: + return(None) + return({name:self._result.init_params[name].value for name in + sorted(self._result.init_params) if name != 'tmp_normalization_offset_c'}) + + @property + def normalization_offset(self): if self._result is None: - return None - return self._result.init_values + return(None) + if self._norm is None: + return(0.0) + else: + if self._result.init_params is not None: + normalization_offset = self._result.init_params['tmp_normalization_offset_c'] + else: + normalization_offset = self._result.params['tmp_normalization_offset_c'] + return(normalization_offset) @property def num_func_eval(self): - return self._result.nfev + if self._result is None: + return(None) + return(self._result.nfev) + + @property + def 
parameters(self): + return({name:{'min': par.min, 'max': par.max, 'vary': par.vary, 'expr': par.expr} + for name, par in self._parameters.items() if name != 'tmp_normalization_offset_c'}) @property def redchi(self): - return self._result.redchi + if self._result is None: + return(None) + return(self._result.redchi) @property def residual(self): - return self._result.residual + if self._result is None: + return(None) + return(self._result.residual) @property def success(self): + if self._result is None: + return(None) if not self._result.success: # print(f'ier = {self._result.ier}') # print(f'lmdif_message = {self._result.lmdif_message}') @@ -133,183 +306,692 @@ logging.warning(f'ier = {self._result.ier}: {self._result.message}') else: logging.warning(f'ier = {self._result.ier}: {self._result.message}') - return True + return(True) # self.print_fit_report() # self.plot() - return self._result.success + return(self._result.success) @property def var_names(self): """Intended to be used with covar """ if self._result is None: - return None - return self._result.var_names + return(None) + return(getattr(self._result, 'var_names', None)) + + @property + def x(self): + return(self._x) - def print_fit_report(self, show_correl=False): - if self._result is not None: - print(self._result.fit_report(show_correl=show_correl)) + @property + def y(self): + return(self._y) + + def print_fit_report(self, result=None, show_correl=False): + if result is None: + result = self._result + if result is not None: + print(result.fit_report(show_correl=show_correl)) def add_parameter(self, **parameter): if not isinstance(parameter, dict): - illegal_value(parameter, 'parameter', 'add_parameter') - return + raise ValueError(f'Invalid parameter ({parameter})') + if parameter.get('expr') is not None: + raise KeyError(f'Illegal "expr" key in parameter {parameter}') + name = parameter['name'] + if not isinstance(name, str): + raise ValueError(f'Illegal "name" value ({name}) in parameter {parameter}') + if parameter.get('norm') is None: + self._parameter_norms[name] = False + else: + norm = parameter.pop('norm') + if self._norm is None: + logging.warning(f'Ignoring norm in parameter {name} in '+ + f'Fit.add_parameter (normalization is turned off)') + self._parameter_norms[name] = False + else: + if not isinstance(norm, bool): + raise ValueError(f'Illegal "norm" value ({norm}) in parameter {parameter}') + self._parameter_norms[name] = norm + vary = parameter.get('vary') + if vary is not None: + if not isinstance(vary, bool): + raise ValueError(f'Illegal "vary" value ({vary}) in parameter {parameter}') + if not vary: + if 'min' in parameter: + logging.warning(f'Ignoring min in parameter {name} in '+ + f'Fit.add_parameter (vary = {vary})') + parameter.pop('min') + if 'max' in parameter: + logging.warning(f'Ignoring max in parameter {name} in '+ + f'Fit.add_parameter (vary = {vary})') + parameter.pop('max') + if self._norm is not None and name not in self._parameter_norms: + raise ValueError(f'Missing parameter normalization type for paremeter {name}') self._parameters.add(**parameter) - def add_model(self, model, prefix=None, parameters=None, **kwargs): + def add_model(self, model, prefix=None, parameters=None, parameter_norms=None, **kwargs): # Create the new model -# print('\nAt start adding model:') -# self._parameters.pretty_print() +# print(f'at start add_model:\nself._parameters:\n{self._parameters}') +# print(f'at start add_model: kwargs = {kwargs}') +# print(f'parameters = {parameters}') +# print(f'parameter_norms = 
{parameter_norms}') +# if len(self._parameters.keys()): +# print('\nAt start adding model:') +# self._parameters.pretty_print() +# print(f'parameter_norms:\n{self._parameter_norms}') if prefix is not None and not isinstance(prefix, str): logging.warning('Ignoring illegal prefix: {model} {type(model)}') prefix = None + if prefix is None: + pprefix = '' + else: + pprefix = prefix + if parameters is not None: + if isinstance(parameters, dict): + parameters = (parameters, ) + elif not is_dict_series(parameters): + illegal_value(parameters, 'parameters', 'Fit.add_model', raise_error=True) + parameters = deepcopy(parameters) + if parameter_norms is not None: + if isinstance(parameter_norms, dict): + parameter_norms = (parameter_norms, ) + if not is_dict_series(parameter_norms): + illegal_value(parameter_norms, 'parameter_norms', 'Fit.add_model', raise_error=True) + new_parameter_norms = {} if callable(model): + # Linear fit not yet implemented for callable models + self._try_linear_fit = False + if parameter_norms is None: + if parameters is None: + raise ValueError('Either "parameters" or "parameter_norms" is required in '+ + f'{model}') + for par in parameters: + name = par['name'] + if not isinstance(name, str): + raise ValueError(f'Illegal "name" value ({name}) in input parameters') + if par.get('norm') is not None: + norm = par.pop('norm') + if not isinstance(norm, bool): + raise ValueError(f'Illegal "norm" value ({norm}) in input parameters') + new_parameter_norms[f'{pprefix}{name}'] = norm + else: + for par in parameter_norms: + name = par['name'] + if not isinstance(name, str): + raise ValueError(f'Illegal "name" value ({name}) in input parameters') + norm = par.get('norm') + if norm is None or not isinstance(norm, bool): + raise ValueError(f'Illegal "norm" value ({norm}) in input parameters') + new_parameter_norms[f'{pprefix}{name}'] = norm + if parameters is not None: + for par in parameters: + if par.get('expr') is not None: + raise KeyError(f'Illegal "expr" key ({par.get("expr")}) in parameter '+ + f'{name} for a callable model {model}') + name = par['name'] + if not isinstance(name, str): + raise ValueError(f'Illegal "name" value ({name}) in input parameters') +# RV FIX callable model will need partial deriv functions for any linear pars to get the linearized matrix, so for now skip linear solution option newmodel = Model(model, prefix=prefix) elif isinstance(model, str): - if model == 'constant': + if model == 'constant': # Par: c newmodel = ConstantModel(prefix=prefix) - elif model == 'linear': + new_parameter_norms[f'{pprefix}c'] = True + self._linear_parameters.append(f'{pprefix}c') + elif model == 'linear': # Par: slope, intercept newmodel = LinearModel(prefix=prefix) - elif model == 'quadratic': + new_parameter_norms[f'{pprefix}slope'] = True + new_parameter_norms[f'{pprefix}intercept'] = True + self._linear_parameters.append(f'{pprefix}slope') + self._linear_parameters.append(f'{pprefix}intercept') + elif model == 'quadratic': # Par: a, b, c newmodel = QuadraticModel(prefix=prefix) - elif model == 'gaussian': + new_parameter_norms[f'{pprefix}a'] = True + new_parameter_norms[f'{pprefix}b'] = True + new_parameter_norms[f'{pprefix}c'] = True + self._linear_parameters.append(f'{pprefix}a') + self._linear_parameters.append(f'{pprefix}b') + self._linear_parameters.append(f'{pprefix}c') + elif model == 'gaussian': # Par: amplitude, center, sigma (fwhm, height) newmodel = GaussianModel(prefix=prefix) - elif model == 'step': + new_parameter_norms[f'{pprefix}amplitude'] = True + 
new_parameter_norms[f'{pprefix}center'] = False + new_parameter_norms[f'{pprefix}sigma'] = False + self._linear_parameters.append(f'{pprefix}amplitude') + self._nonlinear_parameters.append(f'{pprefix}center') + self._nonlinear_parameters.append(f'{pprefix}sigma') + # parameter norms for height and fwhm are needed to get correct errors + new_parameter_norms[f'{pprefix}height'] = True + new_parameter_norms[f'{pprefix}fwhm'] = False + elif model == 'lorentzian': # Par: amplitude, center, sigma (fwhm, height) + newmodel = LorentzianModel(prefix=prefix) + new_parameter_norms[f'{pprefix}amplitude'] = True + new_parameter_norms[f'{pprefix}center'] = False + new_parameter_norms[f'{pprefix}sigma'] = False + self._linear_parameters.append(f'{pprefix}amplitude') + self._nonlinear_parameters.append(f'{pprefix}center') + self._nonlinear_parameters.append(f'{pprefix}sigma') + # parameter norms for height and fwhm are needed to get correct errors + new_parameter_norms[f'{pprefix}height'] = True + new_parameter_norms[f'{pprefix}fwhm'] = False + elif model == 'exponential': # Par: amplitude, decay + newmodel = ExponentialModel(prefix=prefix) + new_parameter_norms[f'{pprefix}amplitude'] = True + new_parameter_norms[f'{pprefix}decay'] = False + self._linear_parameters.append(f'{pprefix}amplitude') + self._nonlinear_parameters.append(f'{pprefix}decay') + elif model == 'step': # Par: amplitude, center, sigma form = kwargs.get('form') if form is not None: - del kwargs['form'] + kwargs.pop('form') if form is None or form not in ('linear', 'atan', 'arctan', 'erf', 'logistic'): - logging.error(f'Illegal form parameter for build-in step model ({form})') - return kwargs + raise ValueError(f'Invalid parameter form for build-in step model ({form})') newmodel = StepModel(prefix=prefix, form=form) - elif model == 'rectangle': + new_parameter_norms[f'{pprefix}amplitude'] = True + new_parameter_norms[f'{pprefix}center'] = False + new_parameter_norms[f'{pprefix}sigma'] = False + self._linear_parameters.append(f'{pprefix}amplitude') + self._nonlinear_parameters.append(f'{pprefix}center') + self._nonlinear_parameters.append(f'{pprefix}sigma') + elif model == 'rectangle': # Par: amplitude, center1, center2, sigma1, sigma2 form = kwargs.get('form') if form is not None: - del kwargs['form'] + kwargs.pop('form') if form is None or form not in ('linear', 'atan', 'arctan', 'erf', 'logistic'): - logging.error(f'Illegal form parameter for build-in rectangle model ({form})') - return kwargs + raise ValueError('Invalid parameter form for build-in rectangle model '+ + f'({form})') newmodel = RectangleModel(prefix=prefix, form=form) + new_parameter_norms[f'{pprefix}amplitude'] = True + new_parameter_norms[f'{pprefix}center1'] = False + new_parameter_norms[f'{pprefix}center2'] = False + new_parameter_norms[f'{pprefix}sigma1'] = False + new_parameter_norms[f'{pprefix}sigma2'] = False + self._linear_parameters.append(f'{pprefix}amplitude') + self._nonlinear_parameters.append(f'{pprefix}center1') + self._nonlinear_parameters.append(f'{pprefix}center2') + self._nonlinear_parameters.append(f'{pprefix}sigma1') + self._nonlinear_parameters.append(f'{pprefix}sigma2') + elif model == 'expression': # Par: by expression + expr = kwargs['expr'] + if not isinstance(expr, str): + raise ValueError(f'Illegal "expr" value ({expr}) in {model}') + kwargs.pop('expr') + if parameter_norms is not None: + logging.warning('Ignoring parameter_norms (normalization determined from '+ + 'linearity)}') + if parameters is not None: + for par in parameters: + if 
par.get('expr') is not None: + raise KeyError(f'Illegal "expr" key ({par.get("expr")}) in parameter '+ + f'({par}) for an expression model') + if par.get('norm') is not None: + logging.warning(f'Ignoring "norm" key in parameter ({par}) '+ + '(normalization determined from linearity)}') + par.pop('norm') + name = par['name'] + if not isinstance(name, str): + raise ValueError(f'Illegal "name" value ({name}) in input parameters') + ast = Interpreter() + expr_parameters = [name for name in get_ast_names(ast.parse(expr)) + if name != 'x' and name not in self._parameters + and name not in ast.symtable] +# print(f'\nexpr_parameters: {expr_parameters}') +# print(f'expr = {expr}') + if prefix is None: + newmodel = ExpressionModel(expr=expr) + else: + for name in expr_parameters: + expr = sub(rf'\b{name}\b', f'{prefix}{name}', expr) + expr_parameters = [f'{prefix}{name}' for name in expr_parameters] +# print(f'\nexpr_parameters: {expr_parameters}') +# print(f'expr = {expr}') + newmodel = ExpressionModel(expr=expr, name=name) +# print(f'\nnewmodel = {newmodel.__dict__}') +# print(f'params_names = {newmodel._param_names}') +# print(f'params_names = {newmodel.param_names}') + # Remove already existing names + for name in newmodel.param_names.copy(): + if name not in expr_parameters: + newmodel._func_allargs.remove(name) + newmodel._param_names.remove(name) +# print(f'params_names = {newmodel._param_names}') +# print(f'params_names = {newmodel.param_names}') else: - logging.error('Unknown build-in fit model') - return kwargs + raise ValueError(f'Unknown build-in fit model ({model})') else: - illegal_value(model, 'model', 'add_model') - return kwargs + illegal_value(model, 'model', 'Fit.add_model', raise_error=True) # Add the new model to the current one +# print('\nBefore adding model:') +# print(f'\nnewmodel = {newmodel.__dict__}') +# if len(self._parameters): +# self._parameters.pretty_print() if self._model is None: self._model = newmodel else: self._model += newmodel - if self._parameters is None: - self._parameters = newmodel.make_params() - else: - self._parameters += newmodel.make_params() + new_parameters = newmodel.make_params() + self._parameters += new_parameters # print('\nAfter adding model:') +# print(f'\nnewmodel = {newmodel.__dict__}') +# print(f'\nnew_parameters = {new_parameters}') # self._parameters.pretty_print() - # Initialize the model parameters + # Check linearity of expression model paremeters + if isinstance(newmodel, ExpressionModel): + for name in newmodel.param_names: + if not diff(newmodel.expr, name, name): + if name not in self._linear_parameters: + self._linear_parameters.append(name) + new_parameter_norms[name] = True +# print(f'\nADDING {name} TO LINEAR') + else: + if name not in self._nonlinear_parameters: + self._nonlinear_parameters.append(name) + new_parameter_norms[name] = False +# print(f'\nADDING {name} TO NONLINEAR') +# print(f'new_parameter_norms:\n{new_parameter_norms}') + + # Scale the default initial model parameters + if self._norm is not None: + for name, norm in new_parameter_norms.copy().items(): + par = self._parameters.get(name) + if par is None: + new_parameter_norms.pop(name) + continue + if par.expr is None and norm: + value = par.value*self._norm[1] + _min = par.min + _max = par.max + if not np.isinf(_min) and abs(_min) != float_min: + _min *= self._norm[1] + if not np.isinf(_max) and abs(_max) != float_min: + _max *= self._norm[1] + par.set(value=value, min=_min, max=_max) +# print('\nAfter norm defaults:') +# self._parameters.pretty_print() +# 
print(f'parameters:\n{parameters}') +# print(f'all_parameters:\n{list(self.parameters)}') +# print(f'new_parameter_norms:\n{new_parameter_norms}') +# print(f'parameter_norms:\n{self._parameter_norms}') + + # Initialize the model parameters from parameters if prefix is None: prefix = "" if parameters is not None: - if not isinstance(parameters, (tuple, list)): - illegal_value(parameters, 'parameters', 'add_model') - return kwargs for parameter in parameters: - if not isinstance(parameter, dict): - illegal_value(parameter, 'parameter in parameters', 'add_model') - return kwargs - parameter['name'] = prefix+parameter['name'] - self._parameters.add(**parameter) - for name, value in kwargs.items(): - if isinstance(value, (int, float)): - self._parameters.add(prefix+name, value=value) -# print('\nAt end add_model:') + name = parameter['name'] + if not isinstance(name, str): + raise ValueError(f'Illegal "name" value ({name}) in input parameters') + if name not in new_parameters: + name = prefix+name + parameter['name'] = name + if name not in new_parameters: + logging.warning(f'Ignoring superfluous parameter info for {name}') + continue + if name in self._parameters: + parameter.pop('name') + if 'norm' in parameter: + if not isinstance(parameter['norm'], bool): + illegal_value(parameter['norm'], 'norm', 'Fit.add_model', + raise_error=True) + new_parameter_norms[name] = parameter['norm'] + parameter.pop('norm') + if parameter.get('expr') is not None: + if 'value' in parameter: + logging.warning(f'Ignoring value in parameter {name} '+ + f'(set by expression: {parameter["expr"]})') + parameter.pop('value') + if 'vary' in parameter: + logging.warning(f'Ignoring vary in parameter {name} '+ + f'(set by expression: {parameter["expr"]})') + parameter.pop('vary') + if 'min' in parameter: + logging.warning(f'Ignoring min in parameter {name} '+ + f'(set by expression: {parameter["expr"]})') + parameter.pop('min') + if 'max' in parameter: + logging.warning(f'Ignoring max in parameter {name} '+ + f'(set by expression: {parameter["expr"]})') + parameter.pop('max') + if 'vary' in parameter: + if not isinstance(parameter['vary'], bool): + illegal_value(parameter['vary'], 'vary', 'Fit.add_model', + raise_error=True) + if not parameter['vary']: + if 'min' in parameter: + logging.warning(f'Ignoring min in parameter {name} in '+ + f'Fit.add_model (vary = {parameter["vary"]})') + parameter.pop('min') + if 'max' in parameter: + logging.warning(f'Ignoring max in parameter {name} in '+ + f'Fit.add_model (vary = {parameter["vary"]})') + parameter.pop('max') + self._parameters[name].set(**parameter) + parameter['name'] = name + else: + illegal_value(parameter, 'parameter name', 'Fit.model', raise_error=True) + self._parameter_norms = {**self._parameter_norms, **new_parameter_norms} +# print('\nAfter parameter init:') # self._parameters.pretty_print() +# print(f'parameters:\n{parameters}') +# print(f'new_parameter_norms:\n{new_parameter_norms}') +# print(f'parameter_norms:\n{self._parameter_norms}') +# print(f'kwargs:\n{kwargs}') - return kwargs + # Initialize the model parameters from kwargs + for name, value in {**kwargs}.items(): + full_name = f'{pprefix}{name}' + if full_name in new_parameter_norms and isinstance(value, (int, float)): + kwargs.pop(name) + if self._parameters[full_name].expr is None: + self._parameters[full_name].set(value=value) + else: + logging.warning(f'Ignoring parameter {name} in Fit.fit (set by expression: '+ + f'{self._parameters[full_name].expr})') +# print('\nAfter kwargs init:') +# 
self._parameters.pretty_print() +# print(f'parameter_norms:\n{self._parameter_norms}') +# print(f'kwargs:\n{kwargs}') + + # Check parameter norms (also need it for expressions to renormalize the errors) + if self._norm is not None and (callable(model) or model == 'expression'): + missing_norm = False + for name in new_parameters.valuesdict(): + if name not in self._parameter_norms: + print(f'new_parameters:\n{new_parameters.valuesdict()}') + print(f'self._parameter_norms:\n{self._parameter_norms}') + logging.error(f'Missing parameter normalization type for {name} in {model}') + missing_norm = True + if missing_norm: + raise ValueError + +# print(f'at end add_model:\nself._parameters:\n{list(self.parameters)}') +# print(f'at end add_model: kwargs = {kwargs}') +# print(f'\nat end add_model: newmodel:\n{newmodel.__dict__}\n') + return(kwargs) def fit(self, interactive=False, guess=False, **kwargs): + # Check inputs if self._model is None: logging.error('Undefined fit model') return - # Current parameter values - pars = self._parameters.valuesdict() - # Apply parameter updates through keyword arguments - for par in set(pars) & set(kwargs): - pars[par] = kwargs.pop(par) - self._parameters[par].set(value=pars[par]) - # Check for uninitialized parameters - for par, value in pars.items(): - if value is None or np.isinf(value) or np.isnan(value): - if interactive: - self._parameters[par].set(value= - input_num(f'Enter an initial value for {par}: ')) - else: - self._parameters[par].set(value=1.0) -# print('\nAt start actual fit:') -# print(f'kwargs = {kwargs}') -# self._parameters.pretty_print() -# print(f'parameters:\n{self._parameters}') -# print(f'x = {self._x}') -# print(f'len(x) = {len(self._x)}') -# print(f'y = {self._y}') -# print(f'len(y) = {len(self._y)}') + if not isinstance(interactive, bool): + illegal_value(interactive, 'interactive', 'Fit.fit', raise_error=True) + if not isinstance(guess, bool): + illegal_value(guess, 'guess', 'Fit.fit', raise_error=True) + if 'try_linear_fit' in kwargs: + try_linear_fit = kwargs.pop('try_linear_fit') + if not isinstance(try_linear_fit, bool): + illegal_value(try_linear_fit, 'try_linear_fit', 'Fit.fit', raise_error=True) + if not self._try_linear_fit: + logging.warning('Ignore superfluous keyword argument "try_linear_fit" (not '+ + 'yet supported for callable models)') + else: + self._try_linear_fit = try_linear_fit +# if self._result is None: +# if 'parameters' in kwargs: +# raise ValueError('Invalid parameter parameters ({kwargs["parameters"]})') +# else: + if self._result is not None: + if guess: + logging.warning('Ignoring input parameter guess in Fit.fit during refitting') + guess = False + + # Check for circular expressions + # FIX TODO +# for name1, par1 in self._parameters.items(): +# if par1.expr is not None: + + # Apply mask if supplied: + if 'mask' in kwargs: + self._mask = kwargs.pop('mask') + if self._mask is not None: + self._mask = np.asarray(self._mask).astype(bool) + if self._x.size != self._mask.size: + raise ValueError(f'Inconsistent x and mask dimensions ({self._x.size} vs '+ + f'{self._mask.size})') + + # Estimate initial parameters with build-in lmfit guess method (only for a single model) +# print(f'\nat start fit: kwargs = {kwargs}') +#RV print('\nAt start of fit:') +#RV self._parameters.pretty_print() +# print(f'parameter_norms:\n{self._parameter_norms}') if guess: - self._parameters = self._model.guess(self._y, x=self._x) - self._result = self._model.fit(self._y, self._parameters, x=self._x, **kwargs) -# print('\nAt end actual 
fit:') -# print(f'var_names:\n{self._result.var_names}') -# print(f'stderr:\n{np.sqrt(np.diagonal(self._result.covar))}') -# self._parameters.pretty_print() -# print(f'parameters:\n{self._parameters}') + if self._mask is None: + self._parameters = self._model.guess(self._y, x=self._x) + else: + self._parameters = self._model.guess(np.asarray(self._y)[~self._mask], + x=self._x[~self._mask]) +# print('\nAfter guess:') +# self._parameters.pretty_print() + + # Add constant offset for a normalized model + if self._result is None and self._norm is not None and self._norm[0]: + self.add_model('constant', prefix='tmp_normalization_offset_', parameters={'name': 'c', + 'value': -self._norm[0], 'vary': False, 'norm': True}) + #'value': -self._norm[0]/self._norm[1], 'vary': False, 'norm': False}) + + # Adjust existing parameters for refit: + if 'parameters' in kwargs: + parameters = kwargs.pop('parameters') + if isinstance(parameters, dict): + parameters = (parameters, ) + elif not is_dict_series(parameters): + illegal_value(parameters, 'parameters', 'Fit.fit', raise_error=True) + for par in parameters: + name = par['name'] + if name not in self._parameters: + raise ValueError(f'Unable to match {name} parameter {par} to an existing one') + if self._parameters[name].expr is not None: + raise ValueError(f'Unable to modify {name} parameter {par} (currently an '+ + 'expression)') + if par.get('expr') is not None: + raise KeyError(f'Illegal "expr" key in {name} parameter {par}') + self._parameters[name].set(vary=par.get('vary')) + self._parameters[name].set(min=par.get('min')) + self._parameters[name].set(max=par.get('max')) + self._parameters[name].set(value=par.get('value')) +#RV print('\nAfter adjust:') +#RV self._parameters.pretty_print() + + # Apply parameter updates through keyword arguments +# print(f'kwargs = {kwargs}') +# print(f'parameter_norms = {self._parameter_norms}') + for name in set(self._parameters) & set(kwargs): + value = kwargs.pop(name) + if self._parameters[name].expr is None: + self._parameters[name].set(value=value) + else: + logging.warning(f'Ignoring parameter {name} in Fit.fit (set by expression: '+ + f'{self._parameters[name].expr})') + + # Check for uninitialized parameters + for name, par in self._parameters.items(): + if par.expr is None: + value = par.value + if value is None or np.isinf(value) or np.isnan(value): + if interactive: + value = input_num(f'Enter an initial value for {name}', default=1.0) + else: + value = 1.0 + if self._norm is None or name not in self._parameter_norms: + self._parameters[name].set(value=value) + elif self._parameter_norms[name]: + self._parameters[name].set(value=value*self._norm[1]) - def plot(self): - if self._result is None: + # Check if model is linear + try: + linear_model = self._check_linearity_model() + except: + linear_model = False +# print(f'\n\n--------> linear_model = {linear_model}\n') + if kwargs.get('check_only_linearity') is not None: + return(linear_model) + + # Normalize the data and initial parameters +#RV print('\nBefore normalization:') +#RV self._parameters.pretty_print() +# print(f'parameter_norms:\n{self._parameter_norms}') + self._normalize() +# print(f'norm = {self._norm}') +#RV print('\nAfter normalization:') +#RV self._parameters.pretty_print() +# self.print_fit_report() +# print(f'parameter_norms:\n{self._parameter_norms}') + + if linear_model: + # Perform a linear fit by direct matrix solution with numpy + try: + if self._mask is None: + self._fit_linear_model(self._x, self._y_norm) + else: + 
self._fit_linear_model(self._x[~self._mask], + np.asarray(self._y_norm)[~self._mask]) + except: + linear_model = False + if not linear_model: + # Perform a non-linear fit with lmfit + # Prevent initial values from sitting at boundaries + self._parameter_bounds = {name:{'min': par.min, 'max': par.max} for name, par in + self._parameters.items() if par.vary} + for par in self._parameters.values(): + if par.vary: + par.set(value=self._reset_par_at_boundary(par, par.value)) +# print('\nAfter checking boundaries:') +# self._parameters.pretty_print() + + # Perform the fit +# fit_kws = None +# if 'Dfun' in kwargs: +# fit_kws = {'Dfun': kwargs.pop('Dfun')} +# self._result = self._model.fit(self._y_norm, self._parameters, x=self._x, +# fit_kws=fit_kws, **kwargs) + if self._mask is None: + self._result = self._model.fit(self._y_norm, self._parameters, x=self._x, **kwargs) + else: + self._result = self._model.fit(np.asarray(self._y_norm)[~self._mask], + self._parameters, x=self._x[~self._mask], **kwargs) +#RV print('\nAfter fit:') +# print(f'\nself._result ({self._result}):\n\t{self._result.__dict__}') +#RV self._parameters.pretty_print() +# self.print_fit_report() + + # Set internal parameter values to fit results upon success + if self.success: + for name, par in self._parameters.items(): + if par.expr is None and par.vary: + par.set(value=self._result.params[name].value) +# print('\nAfter update parameter values:') +# self._parameters.pretty_print() + + # Renormalize the data and results + self._renormalize() +#RV print('\nAfter renormalization:') +#RV self._parameters.pretty_print() +# self.print_fit_report() + + def plot(self, y=None, y_title=None, result=None, skip_init=False, plot_comp_legends=False, + plot_residual=False, plot_masked_data=True, **kwargs): + if result is None: + result = self._result + if result is None: return - components = self._result.eval_components() - plots = ((self._x, self._y, '.'), (self._x, self._result.best_fit, 'k-'), - (self._x, self._result.init_fit, 'g-')) - legend = ['data', 'best fit', 'init'] - if len(components) > 1: + plots = [] + legend = [] + if self._mask is None: + mask = np.zeros(self._x.size).astype(bool) + plot_masked_data = False + else: + mask = self._mask + if y is not None: + if not isinstance(y, (tuple, list, np.ndarray)): + illegal_value(y, 'y', 'Fit.plot') + if len(y) != len(self._x): + logging.warning('Ignoring parameter y in Fit.plot (wrong dimension)') + y = None + if y is not None: + if y_title is None or not isinstance(y_title, str): + y_title = 'data' + plots += [(self._x, y, '.')] + legend += [y_title] + if self._y is not None: + plots += [(self._x, np.asarray(self._y), 'b.')] + legend += ['data'] + if plot_masked_data: + plots += [(self._x[mask], np.asarray(self._y)[mask], 'bx')] + legend += ['masked data'] + if isinstance(plot_residual, bool) and plot_residual: + plots += [(self._x[~mask], result.residual, 'k-')] + legend += ['residual'] + plots += [(self._x[~mask], result.best_fit, 'k-')] + legend += ['best fit'] + if not skip_init and hasattr(result, 'init_fit'): + plots += [(self._x[~mask], result.init_fit, 'g-')] + legend += ['init'] + components = result.eval_components(x=self._x[~mask]) + num_components = len(components) + if 'tmp_normalization_offset_' in components: + num_components -= 1 + if num_components > 1: + eval_index = 0 for modelname, y in components.items(): + if modelname == 'tmp_normalization_offset_': + continue + if modelname == '_eval': + modelname = f'eval{eval_index}' + if len(modelname) > 20: + modelname = 
f'{modelname[0:16]} ...' if isinstance(y, (int, float)): - y *= np.ones(len(self._x)) - plots += ((self._x, y, '--'),) -# if modelname[-1] == '_': -# legend.append(modelname[:-1]) -# else: -# legend.append(modelname) - quickPlot(plots, legend=legend, block=True) + y *= np.ones(self._x[~mask].size) + plots += [(self._x[~mask], y, '--')] + if plot_comp_legends: + if modelname[-1] == '_': + legend.append(modelname[:-1]) + else: + legend.append(modelname) + title = kwargs.get('title') + if title is not None: + kwargs.pop('title') + quick_plot(tuple(plots), legend=legend, title=title, block=True, **kwargs) @staticmethod def guess_init_peak(x, y, *args, center_guess=None, use_max_for_center=True): """ Return a guess for the initial height, center and fwhm for a peak """ +# print(f'\n\nargs = {args}') +# print(f'center_guess = {center_guess}') +# quick_plot(x, y, vlines=center_guess, block=True) center_guesses = None + x = np.asarray(x) + y = np.asarray(y) if len(x) != len(y): - logging.error(f'Illegal x and y lengths ({len(x)}, {len(y)}), skip initial guess') - return None, None, None + logging.error(f'Invalid x and y lengths ({len(x)}, {len(y)}), skip initial guess') + return(None, None, None) if isinstance(center_guess, (int, float)): if len(args): logging.warning('Ignoring additional arguments for single center_guess value') + center_guesses = [center_guess] elif isinstance(center_guess, (tuple, list, np.ndarray)): if len(center_guess) == 1: logging.warning('Ignoring additional arguments for single center_guess value') if not isinstance(center_guess[0], (int, float)): - raise ValueError(f'Illegal center_guess type ({type(center_guess[0])})') + raise ValueError(f'Invalid parameter center_guess ({type(center_guess[0])})') center_guess = center_guess[0] else: if len(args) != 1: - raise ValueError(f'Illegal number of arguments ({len(args)})') + raise ValueError(f'Invalid number of arguments ({len(args)})') n = args[0] if not is_index(n, 0, len(center_guess)): - raise ValueError('Illegal argument') + raise ValueError('Invalid argument') center_guesses = center_guess center_guess = center_guesses[n] elif center_guess is not None: - raise ValueError(f'Illegal center_guess type ({type(center_guess)})') + raise ValueError(f'Invalid center_guess type ({type(center_guess)})') +# print(f'x = {x}') +# print(f'y = {y}') +# print(f'center_guess = {center_guess}') # Sort the inputs index = np.argsort(x) @@ -319,12 +1001,21 @@ # print(f'miny = {miny}') # print(f'x_range = {x[0]} {x[-1]} {len(x)}') # print(f'y_range = {y[0]} {y[-1]} {len(y)}') +# quick_plot(x, y, vlines=center_guess, block=True) # xx = x # yy = y # Set range for current peak +# print(f'n = {n}') # print(f'center_guesses = {center_guesses}') if center_guesses is not None: + if len(center_guesses) > 1: + index = np.argsort(center_guesses) + n = list(index).index(n) +# print(f'n = {n}') +# print(f'index = {index}') + center_guesses = np.asarray(center_guesses)[index] +# print(f'center_guesses = {center_guesses}') if n == 0: low = 0 upp = index_nearest(x, (center_guesses[0]+center_guesses[1])/2) @@ -338,7 +1029,7 @@ # print(f'upp = {upp}') x = x[low:upp] y = y[low:upp] -# quickPlot(x, y, vlines=(x[0], center_guess, x[-1]), block=True) +# quick_plot(x, y, vlines=(x[0], center_guess, x[-1]), block=True) # Estimate FHHM maxy = y.max() @@ -377,7 +1068,7 @@ fwhm_index2 = i break # print(f'fwhm_index2 = {fwhm_index2} {x[fwhm_index2]}') -# quickPlot((x,y,'o'), vlines=(x[fwhm_index1], center, x[fwhm_index2]), block=True) +# quick_plot((x,y,'o'), 
vlines=(x[fwhm_index1], center, x[fwhm_index2]), block=True) if fwhm_index1 == 0 and fwhm_index2 < len(x)-1: fwhm = 2*(x[fwhm_index2]-center) elif fwhm_index1 > 0 and fwhm_index2 == len(x)-1: @@ -389,53 +1080,442 @@ # print(f'fwhm = {fwhm}') # Return height, center and FWHM -# quickPlot((x,y,'o'), (xx,yy), vlines=(x[fwhm_index1], center, x[fwhm_index2]), block=True) - return height, center, fwhm +# quick_plot((x,y,'o'), (xx,yy), vlines=(x[fwhm_index1], center, x[fwhm_index2]), block=True) + return(height, center, fwhm) + + def _check_linearity_model(self): + """Identify the linearity of all model parameters and check if the model is linear or not + """ + if not self._try_linear_fit: + logging.info('Skip linearity check (not yet supported for callable models)') + return(False) + free_parameters = [name for name, par in self._parameters.items() if par.vary] + for component in self._model.components: + if 'tmp_normalization_offset_c' in component.param_names: + continue + if isinstance(component, ExpressionModel): + for name in free_parameters: + if diff(component.expr, name, name): +# print(f'\t\t{component.expr} is non-linear in {name}') + self._nonlinear_parameters.append(name) + if name in self._linear_parameters: + self._linear_parameters.remove(name) + else: + model_parameters = component.param_names.copy() + for basename, hint in component.param_hints.items(): + name = f'{component.prefix}{basename}' + if hint.get('expr') is not None: + model_parameters.remove(name) + for name in model_parameters: + expr = self._parameters[name].expr + if expr is not None: + for nname in free_parameters: + if name in self._nonlinear_parameters: + if diff(expr, nname): +# print(f'\t\t{component} is non-linear in {nname} (through {name} = "{expr}")') + self._nonlinear_parameters.append(nname) + if nname in self._linear_parameters: + self._linear_parameters.remove(nname) + else: + assert(name in self._linear_parameters) +# print(f'\n\nexpr ({type(expr)}) = {expr}\nnname ({type(nname)}) = {nname}\n\n') + if diff(expr, nname, nname): +# print(f'\t\t{component} is non-linear in {nname} (through {name} = "{expr}")') + self._nonlinear_parameters.append(nname) + if nname in self._linear_parameters: + self._linear_parameters.remove(nname) +# print(f'\nfree parameters:\n\t{free_parameters}') +# print(f'linear parameters:\n\t{self._linear_parameters}') +# print(f'nonlinear parameters:\n\t{self._nonlinear_parameters}\n') + if any(True for name in self._nonlinear_parameters if self._parameters[name].vary): + return(False) + return(True) + + def _fit_linear_model(self, x, y): + """Perform a linear fit by direct matrix solution with numpy + """ + # Construct the matrix and the free parameter vector +# print(f'\nparameters:') +# self._parameters.pretty_print() +# print(f'\nparameter_norms:\n\t{self._parameter_norms}') +# print(f'\nlinear_parameters:\n\t{self._linear_parameters}') +# print(f'nonlinear_parameters:\n\t{self._nonlinear_parameters}') + free_parameters = [name for name, par in self._parameters.items() if par.vary] +# print(f'free parameters:\n\t{free_parameters}\n') + expr_parameters = {name:par.expr for name, par in self._parameters.items() + if par.expr is not None} + model_parameters = [] + for component in self._model.components: + if 'tmp_normalization_offset_c' in component.param_names: + continue + model_parameters += component.param_names + for basename, hint in component.param_hints.items(): + name = f'{component.prefix}{basename}' + if hint.get('expr') is not None: + expr_parameters.pop(name) + 
model_parameters.remove(name) +# print(f'expr parameters:\n{expr_parameters}') +# print(f'model parameters:\n\t{model_parameters}\n') + norm = 1.0 + if self._normalized: + norm = self._norm[1] +# print(f'\n\nself._normalized = {self._normalized}\nnorm = {norm}\nself._norm = {self._norm}\n') + # Add expression parameters to asteval + ast = Interpreter() +# print(f'Adding to asteval sym table:') + for name, expr in expr_parameters.items(): +# print(f'\tadding {name} {expr}') + ast.symtable[name] = expr + # Add constant parameters to asteval + # (renormalize to use correctly in evaluation of expression models) + for name, par in self._parameters.items(): + if par.expr is None and not par.vary: + if self._parameter_norms[name]: +# print(f'\tadding {name} {par.value*norm}') + ast.symtable[name] = par.value*norm + else: +# print(f'\tadding {name} {par.value}') + ast.symtable[name] = par.value + A = np.zeros((len(x), len(free_parameters)), dtype='float64') + y_const = np.zeros(len(x), dtype='float64') + have_expression_model = False + for component in self._model.components: + if isinstance(component, ConstantModel): + name = component.param_names[0] +# print(f'\nConstant model: {name} {self._parameters[name]}\n') + if name in free_parameters: +# print(f'\t\t{name} is a free constant set matrix column {free_parameters.index(name)} to 1.0') + A[:,free_parameters.index(name)] = 1.0 + else: + if self._parameter_norms[name]: + delta_y_const = self._parameters[name]*np.ones(len(x)) + else: + delta_y_const = (self._parameters[name]*norm)*np.ones(len(x)) + y_const += delta_y_const +# print(f'\ndelta_y_const ({type(delta_y_const)}):\n{delta_y_const}\n') + elif isinstance(component, ExpressionModel): + have_expression_model = True + const_expr = component.expr +# print(f'\nExpression model:\nconst_expr: {const_expr}\n') + for name in free_parameters: + dexpr_dname = diff(component.expr, name) + if dexpr_dname: + const_expr = f'{const_expr}-({str(dexpr_dname)})*{name}' +# print(f'\tconst_expr: {const_expr}') + if not self._parameter_norms[name]: + dexpr_dname = f'({dexpr_dname})/{norm}' +# print(f'\t{component.expr} is linear in {name}\n\t\tadd "{str(dexpr_dname)}" to matrix as column {free_parameters.index(name)}') + fx = [(lambda _: ast.eval(str(dexpr_dname)))(ast(f'x={v}')) for v in x] +# print(f'\tfx:\n{fx}') + if len(ast.error): + raise ValueError(f'Unable to evaluate {dexpr_dname}') + A[:,free_parameters.index(name)] += fx +# if self._parameter_norms[name]: +# print(f'\t\t{component.expr} is linear in {name} add "{str(dexpr_dname)}" to matrix as column {free_parameters.index(name)}') +# A[:,free_parameters.index(name)] += fx +# else: +# print(f'\t\t{component.expr} is linear in {name} add "({str(dexpr_dname)})/{norm}" to matrix as column {free_parameters.index(name)}') +# A[:,free_parameters.index(name)] += np.asarray(fx)/norm + # FIX: find another solution if expr not supported by simplify + const_expr = str(simplify(f'({const_expr})/{norm}')) +# print(f'\nconst_expr: {const_expr}') + delta_y_const = [(lambda _: ast.eval(const_expr))(ast(f'x = {v}')) for v in x] + y_const += delta_y_const +# print(f'\ndelta_y_const ({type(delta_y_const)}):\n{delta_y_const}\n') + if len(ast.error): + raise ValueError(f'Unable to evaluate {const_expr}') + else: + free_model_parameters = [name for name in component.param_names + if name in free_parameters or name in expr_parameters] +# print(f'\nBuild-in model ({component}):\nfree_model_parameters: {free_model_parameters}\n') + if not len(free_model_parameters): + 
y_const += component.eval(params=self._parameters, x=x) + elif isinstance(component, LinearModel): + if f'{component.prefix}slope' in free_model_parameters: + A[:,free_parameters.index(f'{component.prefix}slope')] = x + else: + y_const += self._parameters[f'{component.prefix}slope'].value*x + if f'{component.prefix}intercept' in free_model_parameters: + A[:,free_parameters.index(f'{component.prefix}intercept')] = 1.0 + else: + y_const += self._parameters[f'{component.prefix}intercept'].value* \ + np.ones(len(x)) + elif isinstance(component, QuadraticModel): + if f'{component.prefix}a' in free_model_parameters: + A[:,free_parameters.index(f'{component.prefix}a')] = x**2 + else: + y_const += self._parameters[f'{component.prefix}a'].value*x**2 + if f'{component.prefix}b' in free_model_parameters: + A[:,free_parameters.index(f'{component.prefix}b')] = x + else: + y_const += self._parameters[f'{component.prefix}b'].value*x + if f'{component.prefix}c' in free_model_parameters: + A[:,free_parameters.index(f'{component.prefix}c')] = 1.0 + else: + y_const += self._parameters[f'{component.prefix}c'].value*np.ones(len(x)) + else: + # At this point each built-in model must be strictly proportional to each linear + # model parameter. Without this assumption, the model equation itself would be needed. + # For the current built-in lmfit models, this can only ever be the amplitude + assert(len(free_model_parameters) == 1) + name = f'{component.prefix}amplitude' + assert(free_model_parameters[0] == name) + assert(self._parameter_norms[name]) + expr = self._parameters[name].expr + if expr is None: +# print(f'\t{component} is linear in {name} add to matrix as column {free_parameters.index(name)}') + parameters = deepcopy(self._parameters) + parameters[name].set(value=1.0) + index = free_parameters.index(name) + A[:,free_parameters.index(name)] += component.eval(params=parameters, x=x) + else: + const_expr = expr +# print(f'\tconst_expr: {const_expr}') + parameters = deepcopy(self._parameters) + parameters[name].set(value=1.0) + dcomp_dname = component.eval(params=parameters, x=x) +# print(f'\tdcomp_dname ({type(dcomp_dname)}):\n{dcomp_dname}') + for nname in free_parameters: + dexpr_dnname = diff(expr, nname) + if dexpr_dnname: + assert(self._parameter_norms[name]) +# print(f'\t\td({expr})/d{nname} = {dexpr_dnname}') +# print(f'\t\t{component} is linear in {nname} (through {name} = "{expr}", add to matrix as column {free_parameters.index(nname)})') + fx = np.asarray(dexpr_dnname*dcomp_dname, dtype='float64') +# print(f'\t\tfx ({type(fx)}): {fx}') +# print(f'free_parameters.index({nname}): {free_parameters.index(nname)}') + if self._parameter_norms[nname]: + A[:,free_parameters.index(nname)] += fx + else: + A[:,free_parameters.index(nname)] += fx/norm + const_expr = f'{const_expr}-({dexpr_dnname})*{nname}' +# print(f'\t\tconst_expr: {const_expr}') + const_expr = str(simplify(f'({const_expr})/{norm}')) +# print(f'\tconst_expr: {const_expr}') + fx = [(lambda _: ast.eval(const_expr))(ast(f'x = {v}')) for v in x] +# print(f'\tfx: {fx}') + delta_y_const = np.multiply(fx, dcomp_dname) + y_const += delta_y_const +# print(f'\ndelta_y_const ({type(delta_y_const)}):\n{delta_y_const}\n') +# print(A) +# print(y_const) + solution, residual, rank, s = np.linalg.lstsq(A, y-y_const, rcond=None) +# print(f'\nsolution ({type(solution)} {solution.shape}):\n\t{solution}') +# print(f'\nresidual ({type(residual)} {residual.shape}):\n\t{residual}') +# print(f'\nrank ({type(rank)} {rank.shape}):\n\t{rank}') +# print(f'\ns ({type(s)} 
{s.shape}):\n\t{s}\n') + + # Assemble result (compensate for normalization in expression models) + for name, value in zip(free_parameters, solution): + self._parameters[name].set(value=value) + if self._normalized and (have_expression_model or len(expr_parameters)): + for name, norm in self._parameter_norms.items(): + par = self._parameters[name] + if par.expr is None and norm: + self._parameters[name].set(value=par.value*self._norm[1]) +# self._parameters.pretty_print() +# print(f'\nself._parameter_norms:\n\t{self._parameter_norms}') + self._result = ModelResult(self._model, deepcopy(self._parameters)) + self._result.best_fit = self._model.eval(params=self._parameters, x=x) + if self._normalized and (have_expression_model or len(expr_parameters)): + if 'tmp_normalization_offset_c' in self._parameters: + offset = self._parameters['tmp_normalization_offset_c'] + else: + offset = 0.0 + self._result.best_fit = (self._result.best_fit-offset-self._norm[0])/self._norm[1] + if self._normalized: + for name, norm in self._parameter_norms.items(): + par = self._parameters[name] + if par.expr is None and norm: + value = par.value/self._norm[1] + self._parameters[name].set(value=value) + self._result.params[name].set(value=value) +# self._parameters.pretty_print() + self._result.residual = self._result.best_fit-y + self._result.components = self._model.components + self._result.init_params = None +# quick_plot((x, y, '.'), (x, y_const, 'g'), (x, self._result.best_fit, 'k'), (x, self._result.residual, 'r'), block=True) + + def _normalize(self): + """Normalize the data and initial parameters + """ + if self._normalized: + return + if self._norm is None: + if self._y is not None and self._y_norm is None: + self._y_norm = np.asarray(self._y) + else: + if self._y is not None and self._y_norm is None: + self._y_norm = (np.asarray(self._y)-self._norm[0])/self._norm[1] + self._y_range = 1.0 + for name, norm in self._parameter_norms.items(): + par = self._parameters[name] + if par.expr is None and norm: + value = par.value/self._norm[1] + _min = par.min + _max = par.max + if not np.isinf(_min) and abs(_min) != float_min: + _min /= self._norm[1] + if not np.isinf(_max) and abs(_max) != float_min: + _max /= self._norm[1] + par.set(value=value, min=_min, max=_max) + self._normalized = True + + def _renormalize(self): + """Renormalize the data and results + """ + if self._norm is None or not self._normalized: + return + self._normalized = False + for name, norm in self._parameter_norms.items(): + par = self._parameters[name] + if par.expr is None and norm: + value = par.value*self._norm[1] + _min = par.min + _max = par.max + if not np.isinf(_min) and abs(_min) != float_min: + _min *= self._norm[1] + if not np.isinf(_max) and abs(_max) != float_min: + _max *= self._norm[1] + par.set(value=value, min=_min, max=_max) + if self._result is None: + return + self._result.best_fit = self._result.best_fit*self._norm[1]+self._norm[0] + for name, par in self._result.params.items(): + if self._parameter_norms.get(name, False): + if par.stderr is not None: + par.stderr *= self._norm[1] + if par.expr is None: + _min = par.min + _max = par.max + value = par.value*self._norm[1] + if par.init_value is not None: + par.init_value *= self._norm[1] + if not np.isinf(_min) and abs(_min) != float_min: + _min *= self._norm[1] + if not np.isinf(_max) and abs(_max) != float_min: + _max *= self._norm[1] + par.set(value=value, min=_min, max=_max) + if hasattr(self._result, 'init_fit'): + self._result.init_fit = 
self._result.init_fit*self._norm[1]+self._norm[0] + if hasattr(self._result, 'init_values'): + init_values = {} + for name, value in self._result.init_values.items(): + if name not in self._parameter_norms or self._parameters[name].expr is not None: + init_values[name] = value + elif self._parameter_norms[name]: + init_values[name] = value*self._norm[1] + self._result.init_values = init_values + for name, par in self._result.init_params.items(): + if par.expr is None and self._parameter_norms.get(name, False): + value = par.value + _min = par.min + _max = par.max + value *= self._norm[1] + if not np.isinf(_min) and abs(_min) != float_min: + _min *= self._norm[1] + if not np.isinf(_max) and abs(_max) != float_min: + _max *= self._norm[1] + par.set(value=value, min=_min, max=_max) + par.init_value = par.value + # Don't renormalize chisqr, it has no useful meaning in physical units + #self._result.chisqr *= self._norm[1]*self._norm[1] + if self._result.covar is not None: + for i, name in enumerate(self._result.var_names): + if self._parameter_norms.get(name, False): + for j in range(len(self._result.var_names)): + if self._result.covar[i,j] is not None: + self._result.covar[i,j] *= self._norm[1] + if self._result.covar[j,i] is not None: + self._result.covar[j,i] *= self._norm[1] + # Don't renormalize redchi, it has no useful meaning in physical units + #self._result.redchi *= self._norm[1]*self._norm[1] + if self._result.residual is not None: + self._result.residual *= self._norm[1] + + def _reset_par_at_boundary(self, par, value): + assert(par.vary) + name = par.name + _min = self._parameter_bounds[name]['min'] + _max = self._parameter_bounds[name]['max'] + if np.isinf(_min): + if not np.isinf(_max): + if self._parameter_norms.get(name, False): + upp = _max-0.1*self._y_range + elif _max == 0.0: + upp = _max-0.1 + else: + upp = _max-0.1*abs(_max) + if value >= upp: + return(upp) + else: + if np.isinf(_max): + if self._parameter_norms.get(name, False): + low = _min+0.1*self._y_range + elif _min == 0.0: + low = _min+0.1 + else: + low = _min+0.1*abs(_min) + if value <= low: + return(low) + else: + low = 0.9*_min+0.1*_max + upp = 0.1*_min+0.9*_max + if value <= low: + return(low) + elif value >= upp: + return(upp) + return(value) class FitMultipeak(Fit): """Fit data with multiple peaks """ - def __init__(self, x, y, normalize=True): - super().__init__(x, deepcopy(y)) - self._norm = None + def __init__(self, y, x=None, normalize=True): + super().__init__(y, x=x, normalize=normalize) self._fwhm_max = None self._sigma_max = None - if normalize: - self._normalize() - #quickPlot((self._x,self._y), block=True) @classmethod - def fit_multipeak(cls, x, y, centers, peak_models='gaussian', center_exprs=None, fit_type=None, - background_order=None, fwhm_max=None, plot_components=None): + def fit_multipeak(cls, y, centers, x=None, normalize=True, peak_models='gaussian', + center_exprs=None, fit_type=None, background_order=None, background_exp=False, + fwhm_max=None, plot_components=False): """Make sure that centers and fwhm_max are in the correct units and consistent with expr for a uniform fit (fit_type == 'uniform') """ - fit = cls(x, y) + fit = cls(y, x=x, normalize=normalize) success = fit.fit(centers, fit_type=fit_type, peak_models=peak_models, fwhm_max=fwhm_max, center_exprs=center_exprs, background_order=background_order, - plot_components=plot_components) + background_exp=background_exp, plot_components=plot_components) if success: - return fit.best_fit, fit.residual, fit.best_values, 
fit.best_errors, fit.redchi, \ - fit.success + return(fit.best_fit, fit.residual, fit.best_values, fit.best_errors, fit.redchi, \ + fit.success) else: - return np.array([]), np.array([]), {}, {}, sys.float_info.max, False + return(np.array([]), np.array([]), {}, {}, float_max, False) def fit(self, centers, fit_type=None, peak_models=None, center_exprs=None, fwhm_max=None, - background_order=None, plot_components=None, param_constraint=False): + background_order=None, background_exp=False, plot_components=False, + param_constraint=False): self._fwhm_max = fwhm_max # Create the multipeak model self._create_model(centers, fit_type, peak_models, center_exprs, background_order, - param_constraint) + background_exp, param_constraint) + + # RV: Obsolete Normalize the data and results +# print('\nBefore fit before normalization in FitMultipeak:') +# self._parameters.pretty_print() +# self._normalize() +# print('\nBefore fit after normalization in FitMultipeak:') +# self._parameters.pretty_print() # Perform the fit try: if param_constraint: - super().fit(fit_kws={'xtol' : 1.e-5, 'ftol' : 1.e-5, 'gtol' : 1.e-5}) + super().fit(fit_kws={'xtol': 1.e-5, 'ftol': 1.e-5, 'gtol': 1.e-5}) else: super().fit() except: - return False + return(False) # Check for valid fit parameter results fit_failure = self._check_validity() @@ -447,21 +1527,27 @@ else: logging.info(' -> Retry fitting with constraints') self.fit(centers, fit_type, peak_models, center_exprs, fwhm_max=fwhm_max, - background_order=background_order, plot_components=plot_components, - param_constraint=True) + background_order=background_order, background_exp=background_exp, + plot_components=plot_components, param_constraint=True) else: - # Renormalize the data and results - self._renormalize() + # RV: Obsolete Renormalize the data and results +# print('\nAfter fit before renormalization in FitMultipeak:') +# self._parameters.pretty_print() +# self.print_fit_report() +# self._renormalize() +# print('\nAfter fit after renormalization in FitMultipeak:') +# self._parameters.pretty_print() +# self.print_fit_report() # Print report and plot components if requested - if plot_components is not None: + if plot_components: self.print_fit_report() self.plot() - return success + return(success) def _create_model(self, centers, fit_type=None, peak_models=None, center_exprs=None, - background_order=None, param_constraint=False): + background_order=None, background_exp=False, param_constraint=False): """Create the multipeak model """ if isinstance(centers, (int, float)): @@ -476,10 +1562,10 @@ f'{num_peaks})') if num_peaks == 1: if fit_type is not None: - logging.warning('Ignoring fit_type input for fitting one peak') + logging.debug('Ignoring fit_type input for fitting one peak') fit_type = None if center_exprs is not None: - logging.warning('Ignoring center_exprs input for fitting one peak') + logging.debug('Ignoring center_exprs input for fitting one peak') center_exprs = None else: if fit_type == 'uniform': @@ -493,10 +1579,10 @@ logging.warning('Ignoring center_exprs input for unconstrained fit') center_exprs = None else: - raise ValueError(f'Illegal fit_type in fit_multigaussian {fit_type}') + raise ValueError(f'Invalid fit_type in fit_multigaussian {fit_type}') self._sigma_max = None if param_constraint: - min_value = sys.float_info.min + min_value = float_min if self._fwhm_max is not None: self._sigma_max = np.zeros(num_peaks) else: @@ -510,13 +1596,18 @@ # Add background model if background_order is not None: if background_order == 0: - 
self.add_model('constant', prefix='background', c=0.0) + self.add_model('constant', prefix='background', parameters= + {'name': 'c', 'value': float_min, 'min': min_value}) elif background_order == 1: self.add_model('linear', prefix='background', slope=0.0, intercept=0.0) elif background_order == 2: self.add_model('quadratic', prefix='background', a=0.0, b=0.0, c=0.0) else: - raise ValueError(f'background_order = {background_order}') + raise ValueError(f'Invalid parameter background_order ({background_order})') + if background_exp: + self.add_model('exponential', prefix='background', parameters=( + {'name': 'amplitude', 'value': float_min, 'min': min_value}, + {'name': 'decay', 'value': float_min, 'min': min_value})) # Add peaks and guess initial fit parameters ast = Interpreter() @@ -534,9 +1625,9 @@ sig_max = ast(fwhm_factor[peak_models[0]]) self._sigma_max[0] = sig_max self.add_model(peak_models[0], parameters=( - {'name' : 'amplitude', 'value' : amp_init, 'min' : min_value}, - {'name' : 'center', 'value' : cen_init, 'min' : min_value}, - {'name' : 'sigma', 'value' : sig_init, 'min' : min_value, 'max' : sig_max})) + {'name': 'amplitude', 'value': amp_init, 'min': min_value}, + {'name': 'center', 'value': cen_init, 'min': min_value}, + {'name': 'sigma', 'value': sig_init, 'min': min_value, 'max': sig_max})) else: if fit_type == 'uniform': self.add_parameter(name='scale_factor', value=1.0) @@ -556,107 +1647,921 @@ self._sigma_max[i] = sig_max if fit_type == 'uniform': self.add_model(peak_models[i], prefix=f'peak{i+1}_', parameters=( - {'name' : 'amplitude', 'value' : amp_init, 'min' : min_value}, - {'name' : 'center', 'expr' : center_exprs[i], 'min' : min_value}, - {'name' : 'sigma', 'value' : sig_init, 'min' : min_value, - 'max' : sig_max})) + {'name': 'amplitude', 'value': amp_init, 'min': min_value}, + {'name': 'center', 'expr': center_exprs[i]}, + {'name': 'sigma', 'value': sig_init, 'min': min_value, + 'max': sig_max})) else: self.add_model('gaussian', prefix=f'peak{i+1}_', parameters=( - {'name' : 'amplitude', 'value' : amp_init, 'min' : min_value}, - {'name' : 'center', 'value' : cen_init, 'min' : min_value}, - {'name' : 'sigma', 'value' : sig_init, 'min' : min_value, - 'max' : sig_max})) + {'name': 'amplitude', 'value': amp_init, 'min': min_value}, + {'name': 'center', 'value': cen_init, 'min': min_value}, + {'name': 'sigma', 'value': sig_init, 'min': min_value, + 'max': sig_max})) def _check_validity(self): """Check for valid fit parameter results """ fit_failure = False - index = re.compile(r'\d+') - for parameter in self.best_parameters: - name = parameter['name'] - if ((('amplitude' in name or 'height' in name) and parameter['value'] <= 0.0) or - (('sigma' in name or 'fwhm' in name) and parameter['value'] <= 0.0) or - ('center' in name and parameter['value'] <= 0.0) or - (name == 'scale_factor' and parameter['value'] <= 0.0)): - logging.info(f'Invalid fit result for {name} ({parameter["value"]})') + index = compile(r'\d+') + for name, par in self.best_parameters.items(): + if 'background' in name: +# if ((name == 'backgroundc' and par['value'] <= 0.0) or +# (name.endswith('amplitude') and par['value'] <= 0.0) or + if ((name.endswith('amplitude') and par['value'] <= 0.0) or + (name.endswith('decay') and par['value'] <= 0.0)): + logging.info(f'Invalid fit result for {name} ({par["value"]})') + fit_failure = True + elif (((name.endswith('amplitude') or name.endswith('height')) and + par['value'] <= 0.0) or + ((name.endswith('sigma') or name.endswith('fwhm')) and par['value'] <= 
0.0) or + (name.endswith('center') and par['value'] <= 0.0) or + (name == 'scale_factor' and par['value'] <= 0.0)): + logging.info(f'Invalid fit result for {name} ({par["value"]})') fit_failure = True - if 'sigma' in name and self._sigma_max is not None: + if name.endswith('sigma') and self._sigma_max is not None: if name == 'sigma': sigma_max = self._sigma_max[0] else: sigma_max = self._sigma_max[int(index.search(name).group())-1] - i = int(index.search(name).group())-1 - if parameter['value'] > sigma_max: - logging.info(f'Invalid fit result for {name} ({parameter["value"]})') + if par['value'] > sigma_max: + logging.info(f'Invalid fit result for {name} ({par["value"]})') + fit_failure = True + elif par['value'] == sigma_max: + logging.warning(f'Edge result for {name} ({par["value"]})') + if name.endswith('fwhm') and self._fwhm_max is not None: + if par['value'] > self._fwhm_max: + logging.info(f'Invalid fit result for {name} ({par["value"]})') fit_failure = True - elif parameter['value'] == sigma_max: - logging.warning(f'Edge result on for {name} ({parameter["value"]})') - if 'fwhm' in name and self._fwhm_max is not None: - if parameter['value'] > self._fwhm_max: - logging.info(f'Invalid fit result for {name} ({parameter["value"]})') - fit_failure = True - elif parameter['value'] == self._fwhm_max: - logging.warning(f'Edge result on for {name} ({parameter["value"]})') - return fit_failure + elif par['value'] == self._fwhm_max: + logging.warning(f'Edge result for {name} ({par["value"]})') + return(fit_failure) + + +class FitMap(Fit): + """Fit a map of data + """ + def __init__(self, ymap, x=None, models=None, normalize=True, transpose=None, **kwargs): + super().__init__(None) + self._best_errors = None + self._best_fit = None + self._best_parameters = None + self._best_values = None + self._inv_transpose = None + self._max_nfev = None + self._memfolder = None + self._new_parameters = None + self._out_of_bounds = None + self._plot = False + self._print_report = False + self._redchi = None + self._redchi_cutoff = 0.1 + self._skip_init = True + self._success = None + self._transpose = None + self._try_no_bounds = True + + # At this point the fastest index should always be the signal dimension so that the slowest + # ndim-1 dimensions are the map dimensions + if isinstance(ymap, (tuple, list, np.ndarray)): + self._x = np.asarray(x) + elif have_xarray and isinstance(ymap, xr.DataArray): + if x is not None: + logging.warning(f'Ignoring superfluous input x ({x}) in Fit.__init__') + self._x = np.asarray(ymap[ymap.dims[-1]]) + else: + illegal_value(ymap, 'ymap', 'FitMap:__init__', raise_error=True) + self._ymap = ymap + + # Verify the input parameters + if self._x.ndim != 1: + raise ValueError(f'Invalid dimension for input x {self._x.ndim}') + if self._ymap.ndim < 2: + raise ValueError('Invalid number of dimensions of the input dataset '+ + f'{self._ymap.ndim}') + if self._x.size != self._ymap.shape[-1]: + raise ValueError(f'Inconsistent x and y dimensions ({self._x.size} vs '+ + f'{self._ymap.shape[-1]})') + if not isinstance(normalize, bool): + logging.warning(f'Invalid value for normalize ({normalize}) in Fit.__init__: '+ + 'setting normalize to True') + normalize = True + if isinstance(transpose, bool) and not transpose: + transpose = None + if transpose is not None and self._ymap.ndim < 3: + logging.warning(f'Transpose meaningless for {self._ymap.ndim-1}D data maps: ignoring '+ + 'transpose') + if transpose is not None: + if self._ymap.ndim == 3 and isinstance(transpose, bool) and 
transpose: + self._transpose = (1, 0) + elif not isinstance(transpose, (tuple, list)): + logging.warning(f'Invalid data type for transpose ({transpose}, '+ + f'{type(transpose)}) in Fit.__init__: setting transpose to False') + elif len(transpose) != self._ymap.ndim-1: + logging.warning(f'Invalid dimension for transpose ({transpose}, must be equal to '+ + f'{self._ymap.ndim-1}) in Fit.__init__: setting transpose to False') + elif any(i not in transpose for i in range(len(transpose))): + logging.warning(f'Invalid index in transpose ({transpose}) '+ + f'in Fit.__init__: setting transpose to False') + elif not all(i==transpose[i] for i in range(self._ymap.ndim-1)): + self._transpose = transpose + if self._transpose is not None: + self._inv_transpose = tuple(self._transpose.index(i) + for i in range(len(self._transpose))) + + # Flatten the map (transpose if requested) + # Store the flattened map in self._ymap_norm, whether normalized or not + if self._transpose is not None: + self._ymap_norm = np.transpose(np.asarray(self._ymap), list(self._transpose)+ + [len(self._transpose)]) + else: + self._ymap_norm = np.asarray(self._ymap) + self._map_dim = int(self._ymap_norm.size/self._x.size) + self._map_shape = self._ymap_norm.shape[:-1] + self._ymap_norm = np.reshape(self._ymap_norm, (self._map_dim, self._x.size)) + + # Check if a mask is provided + if 'mask' in kwargs: + self._mask = kwargs.pop('mask') + if self._mask is None: + ymap_min = float(self._ymap_norm.min()) + ymap_max = float(self._ymap_norm.max()) + else: + self._mask = np.asarray(self._mask).astype(bool) + if self._x.size != self._mask.size: + raise ValueError(f'Inconsistent mask dimension ({self._x.size} vs '+ + f'{self._mask.size})') + ymap_masked = np.asarray(self._ymap_norm)[:,~self._mask] + ymap_min = float(ymap_masked.min()) + ymap_max = float(ymap_masked.max()) + + # Normalize the data + self._y_range = ymap_max-ymap_min + if normalize and self._y_range > 0.0: + self._norm = (ymap_min, self._y_range) + self._ymap_norm = (self._ymap_norm-self._norm[0])/self._norm[1] + else: + self._redchi_cutoff *= self._y_range**2 + if models is not None: + if callable(models) or isinstance(models, str): + kwargs = self.add_model(models, **kwargs) + elif isinstance(models, (tuple, list)): + for model in models: + kwargs = self.add_model(model, **kwargs) + self.fit(**kwargs) + + @classmethod + def fit_map(cls, x, ymap, models, normalize=True, **kwargs): + return(cls(x, ymap, models, normalize=normalize, **kwargs)) + + @property + def best_errors(self): + return(self._best_errors) + + @property + def best_fit(self): + return(self._best_fit) - def _normalize(self): - """Normalize the data + @property + def best_results(self): + """Convert the input data array to a data set and add the fit results. 
""" - y_min = self._y.min() - self._norm = (y_min, self._y.max()-y_min) - if self._norm[1] == 0.0: - self._norm = None + if self.best_values is None or self.best_errors is None or self.best_fit is None: + return(None) + if not have_xarray: + logging.warning('Unable to load xarray module') + return(None) + best_values = self.best_values + best_errors = self.best_errors + if isinstance(self._ymap, xr.DataArray): + best_results = self._ymap.to_dataset() + dims = self._ymap.dims + fit_name = f'{self._ymap.name}_fit' else: - self._y = (self._y-self._norm[0])/self._norm[1] + coords = {f'dim{n}_index':([f'dim{n}_index'], range(self._ymap.shape[n])) + for n in range(self._ymap.ndim-1)} + coords['x'] = (['x'], self._x) + dims = list(coords.keys()) + best_results = xr.Dataset(coords=coords) + best_results['y'] = (dims, self._ymap) + fit_name = 'y_fit' + best_results[fit_name] = (dims, self.best_fit) + if self._mask is not None: + best_results['mask'] = self._mask + for n in range(best_values.shape[0]): + best_results[f'{self._best_parameters[n]}_values'] = (dims[:-1], best_values[n]) + best_results[f'{self._best_parameters[n]}_errors'] = (dims[:-1], best_errors[n]) + best_results.attrs['components'] = self.components + return(best_results) + + @property + def best_values(self): + return(self._best_values) + + @property + def chisqr(self): + logging.warning('property chisqr not defined for fit.FitMap') + return(None) + + @property + def components(self): + components = {} + if self._result is None: + logging.warning('Unable to collect components in FitMap.components') + return(components) + for component in self._result.components: + if 'tmp_normalization_offset_c' in component.param_names: + continue + parameters = {} + for name in component.param_names: + if self._parameters[name].vary: + parameters[name] = {'free': True} + elif self._parameters[name].expr is not None: + parameters[name] = {'free': False, 'expr': self._parameters[name].expr} + else: + parameters[name] = {'free': False, 'value': self.init_parameters[name]['value']} + expr = None + if isinstance(component, ExpressionModel): + name = component._name + if name[-1] == '_': + name = name[:-1] + expr = component.expr + else: + prefix = component.prefix + if len(prefix): + if prefix[-1] == '_': + prefix = prefix[:-1] + name = f'{prefix} ({component._name})' + else: + name = f'{component._name}' + if expr is None: + components[name] = {'parameters': parameters} + else: + components[name] = {'expr': expr, 'parameters': parameters} + return(components) + + @property + def covar(self): + logging.warning('property covar not defined for fit.FitMap') + return(None) + + @property + def max_nfev(self): + return(self._max_nfev) + + @property + def num_func_eval(self): + logging.warning('property num_func_eval not defined for fit.FitMap') + return(None) - def _renormalize(self): - """Renormalize the data and results - """ - if self._norm is None: + @property + def out_of_bounds(self): + return(self._out_of_bounds) + + @property + def redchi(self): + return(self._redchi) + + @property + def residual(self): + if self.best_fit is None: + return(None) + if self._mask is None: + return(np.asarray(self._ymap)-self.best_fit) + else: + ymap_flat = np.reshape(np.asarray(self._ymap), (self._map_dim, self._x.size)) + ymap_flat_masked = ymap_flat[:,~self._mask] + ymap_masked = np.reshape(ymap_flat_masked, + list(self._map_shape)+[ymap_flat_masked.shape[-1]]) + return(ymap_masked-self.best_fit) + + @property + def success(self): + return(self._success) + + 
@property + def var_names(self): + logging.warning('property var_names not defined for fit.FitMap') + return(None) + + @property + def y(self): + logging.warning('property y not defined for fit.FitMap') + return(None) + + @property + def ymap(self): + return(self._ymap) + + def best_parameters(self, dims=None): + if dims is None: + return(self._best_parameters) + if not isinstance(dims, (list, tuple)) or len(dims) != len(self._map_shape): + illegal_value(dims, 'dims', 'FitMap.best_parameters', raise_error=True) + if self.best_values is None or self.best_errors is None: + logging.warning(f'Unable to obtain best parameter values for dims = {dims} in '+ + 'FitMap.best_parameters') + return({}) + # Create current parameters + parameters = deepcopy(self._parameters) + for n, name in enumerate(self._best_parameters): + if self._parameters[name].vary: + parameters[name].set(value=self.best_values[n][dims]) + parameters[name].stderr = self.best_errors[n][dims] + parameters_dict = {} + for name in sorted(parameters): + if name != 'tmp_normalization_offset_c': + par = parameters[name] + parameters_dict[name] = {'value': par.value, 'error': par.stderr, + 'init_value': self.init_parameters[name]['value'], 'min': par.min, + 'max': par.max, 'vary': par.vary, 'expr': par.expr} + return(parameters_dict) + + def freemem(self): + if self._memfolder is None: + return + try: + rmtree(self._memfolder) + self._memfolder = None + except: + logging.warning('Could not clean-up automatically.') + + def plot(self, dims, y_title=None, plot_residual=False, plot_comp_legends=False, + plot_masked_data=True): + if not isinstance(dims, (list, tuple)) or len(dims) != len(self._map_shape): + illegal_value(dims, 'dims', 'FitMap.plot', raise_error=True) + if self._result is None or self.best_fit is None or self.best_values is None: + logging.warning(f'Unable to plot fit for dims = {dims} in FitMap.plot') return - self._y = self._norm[0]+self._norm[1]*self._y - self._result.best_fit = self._norm[0]+self._norm[1]*self._result.best_fit - for name in self._result.params: - par = self._result.params[name] - if 'amplitude' in name or 'height' in name or 'background' in name: - par.value *= self._norm[1] - if par.stderr is not None: - par.stderr *= self._norm[1] - if par.init_value is not None: - par.init_value *= self._norm[1] - if par.min is not None and not np.isinf(par.min): - par.min *= self._norm[1] - if par.max is not None and not np.isinf(par.max): - par.max *= self._norm[1] - if 'intercept' in name or 'backgroundc' in name: - par.value += self._norm[0] - if par.init_value is not None: - par.init_value += self._norm[0] - if par.min is not None and not np.isinf(par.min): - par.min += self._norm[0] - if par.max is not None and not np.isinf(par.max): - par.max += self._norm[0] - self._result.init_fit = self._norm[0]+self._norm[1]*self._result.init_fit - init_values = {} - for name in self._result.init_values: - init_values[name] = self._result.init_values[name] - if init_values[name] is None: + if y_title is None or not isinstance(y_title, str): + y_title = 'data' + if self._mask is None: + mask = np.zeros(self._x.size).astype(bool) + plot_masked_data = False + else: + mask = self._mask + plots = [(self._x, np.asarray(self._ymap[dims]), 'b.')] + legend = [y_title] + if plot_masked_data: + plots += [(self._x[mask], np.asarray(self._ymap)[(*dims,mask)], 'bx')] + legend += ['masked data'] + plots += [(self._x[~mask], self.best_fit[dims], 'k-')] + legend += ['best fit'] + if plot_residual: + plots += [(self._x[~mask], 
self.residual[dims], 'k--')] + legend += ['residual'] + # Create current parameters + parameters = deepcopy(self._parameters) + for name in self._best_parameters: + if self._parameters[name].vary: + parameters[name].set(value= + self.best_values[self._best_parameters.index(name)][dims]) + for component in self._result.components: + if 'tmp_normalization_offset_c' in component.param_names: continue - if 'amplitude' in name or 'height' in name or 'background' in name: - init_values[name] *= self._norm[1] - if 'intercept' in name or 'backgroundc' in name: - init_values[name] += self._norm[0] - self._result.init_values = init_values - # Don't renormalized chisqr, it has no useful meaning in physical units - #self._result.chisqr *= self._norm[1]*self._norm[1] - if self._result.covar is not None: - for i, name in enumerate(self._result.var_names): - if 'amplitude' in name or 'height' in name or 'background' in name: - for j in range(len(self._result.var_names)): - if self._result.covar[i,j] is not None: - self._result.covar[i,j] *= self._norm[1] - if self._result.covar[j,i] is not None: - self._result.covar[j,i] *= self._norm[1] - # Don't renormalized redchi, it has no useful meaning in physical units - #self._result.redchi *= self._norm[1]*self._norm[1] - self._result.residual *= self._norm[1] + if isinstance(component, ExpressionModel): + prefix = component._name + if prefix[-1] == '_': + prefix = prefix[:-1] + modelname = f'{prefix}: {component.expr}' + else: + prefix = component.prefix + if len(prefix): + if prefix[-1] == '_': + prefix = prefix[:-1] + modelname = f'{prefix} ({component._name})' + else: + modelname = f'{component._name}' + if len(modelname) > 20: + modelname = f'{modelname[0:16]} ...' + y = component.eval(params=parameters, x=self._x[~mask]) + if isinstance(y, (int, float)): + y *= np.ones(self._x[~mask].size) + plots += [(self._x[~mask], y, '--')] + if plot_comp_legends: + legend.append(modelname) + quick_plot(tuple(plots), legend=legend, title=str(dims), block=True) + + def fit(self, **kwargs): +# t0 = time() + # Check input parameters + if self._model is None: + logging.error('Undefined fit model') + if 'num_proc' in kwargs: + num_proc = kwargs.pop('num_proc') + if not is_int(num_proc, ge=1): + illegal_value(num_proc, 'num_proc', 'FitMap.fit', raise_error=True) + else: + num_proc = cpu_count() + if num_proc > 1 and not have_joblib: + logging.warning(f'Missing joblib in the conda environment, running FitMap serially') + num_proc = 1 + if num_proc > cpu_count(): + logging.warning(f'The requested number of processors ({num_proc}) exceeds the maximum '+ + f'number of processors, num_proc reduced to ({cpu_count()})') + num_proc = cpu_count() + if 'try_no_bounds' in kwargs: + self._try_no_bounds = kwargs.pop('try_no_bounds') + if not isinstance(self._try_no_bounds, bool): + illegal_value(self._try_no_bounds, 'try_no_bounds', 'FitMap.fit', raise_error=True) + if 'redchi_cutoff' in kwargs: + self._redchi_cutoff = kwargs.pop('redchi_cutoff') + if not is_num(self._redchi_cutoff, gt=0): + illegal_value(self._redchi_cutoff, 'redchi_cutoff', 'FitMap.fit', raise_error=True) + if 'print_report' in kwargs: + self._print_report = kwargs.pop('print_report') + if not isinstance(self._print_report, bool): + illegal_value(self._print_report, 'print_report', 'FitMap.fit', raise_error=True) + if 'plot' in kwargs: + self._plot = kwargs.pop('plot') + if not isinstance(self._plot, bool): + illegal_value(self._plot, 'plot', 'FitMap.fit', raise_error=True) + if 'skip_init' in kwargs: + 
self._skip_init = kwargs.pop('skip_init') + if not isinstance(self._skip_init, bool): + illegal_value(self._skip_init, 'skip_init', 'FitMap.fit', raise_error=True) + + # Apply mask if supplied: + if 'mask' in kwargs: + self._mask = kwargs.pop('mask') + if self._mask is not None: + self._mask = np.asarray(self._mask).astype(bool) + if self._x.size != self._mask.size: + raise ValueError(f'Inconsistent x and mask dimensions ({self._x.size} vs '+ + f'{self._mask.size})') + + # Add constant offset for a normalized single component model + if self._result is None and self._norm is not None and self._norm[0]: + self.add_model('constant', prefix='tmp_normalization_offset_', parameters={'name': 'c', + 'value': -self._norm[0], 'vary': False, 'norm': True}) + #'value': -self._norm[0]/self._norm[1], 'vary': False, 'norm': False}) + + # Adjust existing parameters for refit: + if 'parameters' in kwargs: +# print('\nIn FitMap before adjusting existing parameters for refit:') +# self._parameters.pretty_print() +# if self._result is None: +# raise ValueError('Invalid parameter parameters ({parameters})') +# if self._best_values is None: +# raise ValueError('Valid self._best_values required for refitting in FitMap.fit') + parameters = kwargs.pop('parameters') +# print(f'\nparameters:\n{parameters}') + if isinstance(parameters, dict): + parameters = (parameters, ) + elif not is_dict_series(parameters): + illegal_value(parameters, 'parameters', 'Fit.fit', raise_error=True) + for par in parameters: + name = par['name'] + if name not in self._parameters: + raise ValueError(f'Unable to match {name} parameter {par} to an existing one') + if self._parameters[name].expr is not None: + raise ValueError(f'Unable to modify {name} parameter {par} (currently an '+ + 'expression)') + value = par.get('value') + vary = par.get('vary') + if par.get('expr') is not None: + raise KeyError(f'Illegal "expr" key in {name} parameter {par}') + self._parameters[name].set(value=value, vary=vary, min=par.get('min'), + max=par.get('max')) + # Overwrite existing best values for fixed parameters when a value is specified +# print(f'best values befored resetting:\n{self._best_values}') + if isinstance(value, (int, float)) and vary is False: + for i, nname in enumerate(self._best_parameters): + if nname == name: + self._best_values[i] = value +# print(f'best values after resetting (value={value}, vary={vary}):\n{self._best_values}') +#RV print('\nIn FitMap after adjusting existing parameters for refit:') +#RV self._parameters.pretty_print() + + # Check for uninitialized parameters + for name, par in self._parameters.items(): + if par.expr is None: + value = par.value + if value is None or np.isinf(value) or np.isnan(value): + value = 1.0 + if self._norm is None or name not in self._parameter_norms: + self._parameters[name].set(value=value) + elif self._parameter_norms[name]: + self._parameters[name].set(value=value*self._norm[1]) + + # Create the best parameter list, consisting of all varying parameters plus the expression + # parameters in order to collect their errors + if self._result is None: + # Initial fit + assert(self._best_parameters is None) + self._best_parameters = [name for name, par in self._parameters.items() + if par.vary or par.expr is not None] + num_new_parameters = 0 + else: + # Refit + assert(len(self._best_parameters)) + self._new_parameters = [name for name, par in self._parameters.items() + if name != 'tmp_normalization_offset_c' and name not in self._best_parameters and + (par.vary or par.expr is not None)] + 
num_new_parameters = len(self._new_parameters) + num_best_parameters = len(self._best_parameters) + + # Flatten and normalize the best values of the previous fit, remove the remaining results + # of the previous fit + if self._result is not None: +# print('\nBefore flatten and normalize:') +# print(f'self._best_values:\n{self._best_values}') + self._out_of_bounds = None + self._max_nfev = None + self._redchi = None + self._success = None + self._best_fit = None + self._best_errors = None + assert(self._best_values is not None) + assert(self._best_values.shape[0] == num_best_parameters) + assert(self._best_values.shape[1:] == self._map_shape) + if self._transpose is not None: + self._best_values = np.transpose(self._best_values, + [0]+[i+1 for i in self._transpose]) + self._best_values = [np.reshape(self._best_values[i], self._map_dim) + for i in range(num_best_parameters)] + if self._norm is not None: + for i, name in enumerate(self._best_parameters): + if self._parameter_norms.get(name, False): + self._best_values[i] /= self._norm[1] +#RV print('\nAfter flatten and normalize:') +#RV print(f'self._best_values:\n{self._best_values}') + + # Normalize the initial parameters (and best values for a refit) +# print('\nIn FitMap before normalize:') +# self._parameters.pretty_print() +# print(f'\nparameter_norms:\n{self._parameter_norms}\n') + self._normalize() +# print('\nIn FitMap after normalize:') +# self._parameters.pretty_print() +# print(f'\nparameter_norms:\n{self._parameter_norms}\n') + + # Prevent initial values from sitting at boundaries + self._parameter_bounds = {name:{'min': par.min, 'max': par.max} + for name, par in self._parameters.items() if par.vary} + for name, par in self._parameters.items(): + if par.vary: + par.set(value=self._reset_par_at_boundary(par, par.value)) +# print('\nAfter checking boundaries:') +# self._parameters.pretty_print() + + # Set parameter bounds to unbound (only use bounds when fit fails) + if self._try_no_bounds: + for name in self._parameter_bounds.keys(): + self._parameters[name].set(min=-np.inf, max=np.inf) + + # Allocate memory to store fit results + if self._mask is None: + x_size = self._x.size + else: + x_size = self._x[~self._mask].size + if num_proc == 1: + self._out_of_bounds_flat = np.zeros(self._map_dim, dtype=bool) + self._max_nfev_flat = np.zeros(self._map_dim, dtype=bool) + self._redchi_flat = np.zeros(self._map_dim, dtype=np.float64) + self._success_flat = np.zeros(self._map_dim, dtype=bool) + self._best_fit_flat = np.zeros((self._map_dim, x_size), + dtype=self._ymap_norm.dtype) + self._best_errors_flat = [np.zeros(self._map_dim, dtype=np.float64) + for _ in range(num_best_parameters+num_new_parameters)] + if self._result is None: + self._best_values_flat = [np.zeros(self._map_dim, dtype=np.float64) + for _ in range(num_best_parameters)] + else: + self._best_values_flat = self._best_values + self._best_values_flat += [np.zeros(self._map_dim, dtype=np.float64) + for _ in range(num_new_parameters)] + else: + self._memfolder = './joblib_memmap' + try: + mkdir(self._memfolder) + except FileExistsError: + pass + filename_memmap = path.join(self._memfolder, 'out_of_bounds_memmap') + self._out_of_bounds_flat = np.memmap(filename_memmap, dtype=bool, + shape=(self._map_dim), mode='w+') + filename_memmap = path.join(self._memfolder, 'max_nfev_memmap') + self._max_nfev_flat = np.memmap(filename_memmap, dtype=bool, + shape=(self._map_dim), mode='w+') + filename_memmap = path.join(self._memfolder, 'redchi_memmap') + self._redchi_flat = 
np.memmap(filename_memmap, dtype=np.float64, + shape=(self._map_dim), mode='w+') + filename_memmap = path.join(self._memfolder, 'success_memmap') + self._success_flat = np.memmap(filename_memmap, dtype=bool, + shape=(self._map_dim), mode='w+') + filename_memmap = path.join(self._memfolder, 'best_fit_memmap') + self._best_fit_flat = np.memmap(filename_memmap, dtype=self._ymap_norm.dtype, + shape=(self._map_dim, x_size), mode='w+') + self._best_errors_flat = [] + for i in range(num_best_parameters+num_new_parameters): + filename_memmap = path.join(self._memfolder, f'best_errors_memmap_{i}') + self._best_errors_flat.append(np.memmap(filename_memmap, dtype=np.float64, + shape=self._map_dim, mode='w+')) + self._best_values_flat = [] + for i in range(num_best_parameters): + filename_memmap = path.join(self._memfolder, f'best_values_memmap_{i}') + self._best_values_flat.append(np.memmap(filename_memmap, dtype=np.float64, + shape=self._map_dim, mode='w+')) + if self._result is not None: + self._best_values_flat[i][:] = self._best_values[i][:] + for i in range(num_new_parameters): + filename_memmap = path.join(self._memfolder, + f'best_values_memmap_{i+num_best_parameters}') + self._best_values_flat.append(np.memmap(filename_memmap, dtype=np.float64, + shape=self._map_dim, mode='w+')) + + # Update the best parameter list + if num_new_parameters: + self._best_parameters += self._new_parameters + + # Perform the first fit to get model component info and initial parameters + current_best_values = {} +# print(f'0 before:\n{current_best_values}') +# t1 = time() + self._result = self._fit(0, current_best_values, return_result=True, **kwargs) +# t2 = time() +# print(f'0 after:\n{current_best_values}') +# print('\nAfter the first fit:') +# self._parameters.pretty_print() +# print(self._result.fit_report(show_correl=False)) + + # Remove all irrelevant content from self._result + for attr in ('_abort', 'aborted', 'aic', 'best_fit', 'best_values', 'bic', 'calc_covar', + 'call_kws', 'chisqr', 'ci_out', 'col_deriv', 'covar', 'data', 'errorbars', + 'flatchain', 'ier', 'init_vals', 'init_fit', 'iter_cb', 'jacfcn', 'kws', + 'last_internal_values', 'lmdif_message', 'message', 'method', 'nan_policy', + 'ndata', 'nfev', 'nfree', 'params', 'redchi', 'reduce_fcn', 'residual', 'result', + 'scale_covar', 'show_candidates', 'calc_covar', 'success', 'userargs', 'userfcn', + 'userkws', 'values', 'var_names', 'weights', 'user_options'): + try: + delattr(self._result, attr) + except AttributeError: +# logging.warning(f'Unknown attribute {attr} in fit.FtMap._cleanup_result') + pass + +# t3 = time() + if num_proc == 1: + # Perform the remaining fits serially + for n in range(1, self._map_dim): +# print(f'{n} before:\n{current_best_values}') + self._fit(n, current_best_values, **kwargs) +# print(f'{n} after:\n{current_best_values}') + else: + # Perform the remaining fits in parallel + num_fit = self._map_dim-1 +# print(f'num_fit = {num_fit}') + if num_proc > num_fit: + logging.warning(f'The requested number of processors ({num_proc}) exceeds the '+ + f'number of fits, num_proc reduced to ({num_fit})') + num_proc = num_fit + num_fit_per_proc = 1 + else: + num_fit_per_proc = round((num_fit)/num_proc) + if num_proc*num_fit_per_proc < num_fit: + num_fit_per_proc +=1 +# print(f'num_fit_per_proc = {num_fit_per_proc}') + num_fit_batch = min(num_fit_per_proc, 40) +# print(f'num_fit_batch = {num_fit_batch}') + with Parallel(n_jobs=num_proc) as parallel: + parallel(delayed(self._fit_parallel)(current_best_values, num_fit_batch, + n_start, 
**kwargs) for n_start in range(1, self._map_dim, num_fit_batch)) +# t4 = time() + + # Renormalize the initial parameters for external use + if self._norm is not None and self._normalized: + init_values = {} + for name, value in self._result.init_values.items(): + if name not in self._parameter_norms or self._parameters[name].expr is not None: + init_values[name] = value + elif self._parameter_norms[name]: + init_values[name] = value*self._norm[1] + self._result.init_values = init_values + for name, par in self._result.init_params.items(): + if par.expr is None and self._parameter_norms.get(name, False): + _min = par.min + _max = par.max + value = par.value*self._norm[1] + if not np.isinf(_min) and abs(_min) != float_min: + _min *= self._norm[1] + if not np.isinf(_max) and abs(_max) != float_min: + _max *= self._norm[1] + par.set(value=value, min=_min, max=_max) + par.init_value = par.value + + # Remap the best results +# t5 = time() + self._out_of_bounds = np.copy(np.reshape(self._out_of_bounds_flat, self._map_shape)) + self._max_nfev = np.copy(np.reshape(self._max_nfev_flat, self._map_shape)) + self._redchi = np.copy(np.reshape(self._redchi_flat, self._map_shape)) + self._success = np.copy(np.reshape(self._success_flat, self._map_shape)) + self._best_fit = np.copy(np.reshape(self._best_fit_flat, + list(self._map_shape)+[x_size])) + self._best_values = np.asarray([np.reshape(par, list(self._map_shape)) + for par in self._best_values_flat]) + self._best_errors = np.asarray([np.reshape(par, list(self._map_shape)) + for par in self._best_errors_flat]) + if self._inv_transpose is not None: + self._out_of_bounds = np.transpose(self._out_of_bounds, self._inv_transpose) + self._max_nfev = np.transpose(self._max_nfev, self._inv_transpose) + self._redchi = np.transpose(self._redchi, self._inv_transpose) + self._success = np.transpose(self._success, self._inv_transpose) + self._best_fit = np.transpose(self._best_fit, + list(self._inv_transpose)+[len(self._inv_transpose)]) + self._best_values = np.transpose(self._best_values, + [0]+[i+1 for i in self._inv_transpose]) + self._best_errors = np.transpose(self._best_errors, + [0]+[i+1 for i in self._inv_transpose]) + del self._out_of_bounds_flat + del self._max_nfev_flat + del self._redchi_flat + del self._success_flat + del self._best_fit_flat + del self._best_values_flat + del self._best_errors_flat +# t6 = time() + + # Restore parameter bounds and renormalize the parameters + for name, par in self._parameter_bounds.items(): + self._parameters[name].set(min=par['min'], max=par['max']) + self._normalized = False + if self._norm is not None: + for name, norm in self._parameter_norms.items(): + par = self._parameters[name] + if par.expr is None and norm: + value = par.value*self._norm[1] + _min = par.min + _max = par.max + if not np.isinf(_min) and abs(_min) != float_min: + _min *= self._norm[1] + if not np.isinf(_max) and abs(_max) != float_min: + _max *= self._norm[1] + par.set(value=value, min=_min, max=_max) +# t7 = time() +# print(f'total run time in fit: {t7-t0:.2f} seconds') +# print(f'run time first fit: {t2-t1:.2f} seconds') +# print(f'run time remaining fits: {t4-t3:.2f} seconds') +# print(f'run time remapping results: {t6-t5:.2f} seconds') + +# print('\n\nAt end fit:') +# self._parameters.pretty_print() +# print(f'self._best_values:\n{self._best_values}\n\n') + + # Free the shared memory + self.freemem() + + def _fit_parallel(self, current_best_values, num, n_start, **kwargs): + num = min(num, self._map_dim-n_start) + for n in range(num): +# 
print(f'{n_start+n} before:\n{current_best_values}') + self._fit(n_start+n, current_best_values, **kwargs) +# print(f'{n_start+n} after:\n{current_best_values}') + + def _fit(self, n, current_best_values, return_result=False, **kwargs): +#RV print(f'\n\nstart FitMap._fit {n}\n') +#RV print(f'current_best_values = {current_best_values}') +#RV print(f'self._best_parameters = {self._best_parameters}') +#RV print(f'self._new_parameters = {self._new_parameters}\n\n') +# self._parameters.pretty_print() + # Set parameters to current best values, but prevent them from sitting at boundaries + if self._new_parameters is None: + # Initial fit + for name, value in current_best_values.items(): + par = self._parameters[name] + par.set(value=self._reset_par_at_boundary(par, value)) + else: + # Refit + for i, name in enumerate(self._best_parameters): + par = self._parameters[name] + if name in self._new_parameters: + if name in current_best_values: + par.set(value=self._reset_par_at_boundary(par, current_best_values[name])) + elif par.expr is None: + par.set(value=self._best_values[i][n]) +#RV print(f'\nbefore fit {n}') +#RV self._parameters.pretty_print() + if self._mask is None: + result = self._model.fit(self._ymap_norm[n], self._parameters, x=self._x, **kwargs) + else: + result = self._model.fit(self._ymap_norm[n][~self._mask], self._parameters, + x=self._x[~self._mask], **kwargs) +# print(f'\nafter fit {n}') +# self._parameters.pretty_print() +# print(result.fit_report(show_correl=False)) + out_of_bounds = False + for name, par in self._parameter_bounds.items(): + value = result.params[name].value + if not np.isinf(par['min']) and value < par['min']: + out_of_bounds = True + break + if not np.isinf(par['max']) and value > par['max']: + out_of_bounds = True + break + self._out_of_bounds_flat[n] = out_of_bounds + if self._try_no_bounds and out_of_bounds: + # Rerun fit with parameter bounds in place + for name, par in self._parameter_bounds.items(): + self._parameters[name].set(min=par['min'], max=par['max']) + # Set parameters to current best values, but prevent them from sitting at boundaries + if self._new_parameters is None: + # Initial fit + for name, value in current_best_values.items(): + par = self._parameters[name] + par.set(value=self._reset_par_at_boundary(par, value)) + else: + # Refit + for i, name in enumerate(self._best_parameters): + par = self._parameters[name] + if name in self._new_parameters: + if name in current_best_values: + par.set(value=self._reset_par_at_boundary(par, + current_best_values[name])) + elif par.expr is None: + par.set(value=self._best_values[i][n]) +# print('\nbefore fit') +# self._parameters.pretty_print() +# print(result.fit_report(show_correl=False)) + if self._mask is None: + result = self._model.fit(self._ymap_norm[n], self._parameters, x=self._x, **kwargs) + else: + result = self._model.fit(self._ymap_norm[n][~self._mask], self._parameters, + x=self._x[~self._mask], **kwargs) +# print(f'\nafter fit {n}') +# self._parameters.pretty_print() +# print(result.fit_report(show_correl=False)) + out_of_bounds = False + for name, par in self._parameter_bounds.items(): + value = result.params[name].value + if not np.isinf(par['min']) and value < par['min']: + out_of_bounds = True + break + if not np.isinf(par['max']) and value > par['max']: + out_of_bounds = True + break +# print(f'{n} redchi < redchi_cutoff = {result.redchi < self._redchi_cutoff} success = {result.success} out_of_bounds = {out_of_bounds}') + # Reset parameters back to unbound + for name in 
self._parameter_bounds.keys(): + self._parameters[name].set(min=-np.inf, max=np.inf) + assert(not out_of_bounds) + if result.redchi >= self._redchi_cutoff: + result.success = False + if result.nfev == result.max_nfev: +# print(f'Maximum number of function evaluations reached for n = {n}') +# logging.warning(f'Maximum number of function evaluations reached for n = {n}') + if result.redchi < self._redchi_cutoff: + result.success = True + self._max_nfev_flat[n] = True + if result.success: + assert(all(True for par in current_best_values if par in result.params.values())) + for par in result.params.values(): + if par.vary: + current_best_values[par.name] = par.value + else: + logging.warning(f'Fit for n = {n} failed: {result.lmdif_message}') + # Renormalize the data and results + self._renormalize(n, result) + if self._print_report: + print(result.fit_report(show_correl=False)) + if self._plot: + dims = np.unravel_index(n, self._map_shape) + if self._inv_transpose is not None: + dims= tuple(dims[self._inv_transpose[i]] for i in range(len(dims))) + super().plot(result=result, y=np.asarray(self._ymap[dims]), plot_comp_legends=True, + skip_init=self._skip_init, title=str(dims)) +#RV print(f'\n\nend FitMap._fit {n}\n') +#RV print(f'current_best_values = {current_best_values}') +# self._parameters.pretty_print() +# print(result.fit_report(show_correl=False)) +#RV print(f'\nself._best_values_flat:\n{self._best_values_flat}\n\n') + if return_result: + return(result) + + def _renormalize(self, n, result): + self._redchi_flat[n] = np.float64(result.redchi) + self._success_flat[n] = result.success + if self._norm is None or not self._normalized: + self._best_fit_flat[n] = result.best_fit + for i, name in enumerate(self._best_parameters): + self._best_values_flat[i][n] = np.float64(result.params[name].value) + self._best_errors_flat[i][n] = np.float64(result.params[name].stderr) + else: + pars = set(self._parameter_norms) & set(self._best_parameters) + for name, par in result.params.items(): + if name in pars and self._parameter_norms[name]: + if par.stderr is not None: + par.stderr *= self._norm[1] + if par.expr is None: + par.value *= self._norm[1] + if self._print_report: + if par.init_value is not None: + par.init_value *= self._norm[1] + if not np.isinf(par.min) and abs(par.min) != float_min: + par.min *= self._norm[1] + if not np.isinf(par.max) and abs(par.max) != float_min: + par.max *= self._norm[1] + self._best_fit_flat[n] = result.best_fit*self._norm[1]+self._norm[0] + for i, name in enumerate(self._best_parameters): + self._best_values_flat[i][n] = np.float64(result.params[name].value) + self._best_errors_flat[i][n] = np.float64(result.params[name].stderr) + if self._plot: + if not self._skip_init: + result.init_fit = result.init_fit*self._norm[1]+self._norm[0] + result.best_fit = np.copy(self._best_fit_flat[n])
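The _fit_linear_model method above reduces any model that is linear in its free parameters to a design-matrix problem and solves it in one step with np.linalg.lstsq. A minimal standalone sketch of that technique, using synthetic data and illustrative parameter names rather than the class's actual API:

import numpy as np

# Synthetic data for y = slope*x + quad*x**2 + const
x = np.linspace(0.0, 10.0, 101)
y = 2.5*x + 0.7*x**2 + 1.3

# One design-matrix column per free parameter, in a fixed order
A = np.stack([x, x**2, np.ones(x.size)], axis=1)

# Solve the linear least-squares problem directly, as _fit_linear_model does
solution, residual, rank, s = np.linalg.lstsq(A, y, rcond=None)
slope, quad, const = solution   # recovers 2.5, 0.7, 1.3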
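The FitMultipeak.fit_multipeak classmethod introduced in this diff is the intended entry point for multipeak fits. A hypothetical call consistent with the new signature, assuming fit.py is on the import path; the data and peak positions below are made up for illustration:

import numpy as np
from fit import FitMultipeak   # fit.py from this changeset

# Two synthetic Gaussian peaks on a flat background
x = np.linspace(0.0, 20.0, 201)
y = (100.0*np.exp(-0.5*((x-5.0)/0.4)**2)
     + 60.0*np.exp(-0.5*((x-12.0)/0.5)**2) + 10.0)

best_fit, residual, best_values, best_errors, redchi, success = \
    FitMultipeak.fit_multipeak(y, [5.0, 12.0], x=x, peak_models='gaussian',
                               background_order=0, fwhm_max=3.0)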
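FitMap.fit writes its per-pixel results into np.memmap arrays under ./joblib_memmap so that joblib worker processes can fill disjoint slots without shipping the whole map between processes; the results are copied out afterwards and the backing files removed in freemem. A minimal sketch of that shared-memory pattern, with fit_one standing in for the per-pixel lmfit call:

import os
from shutil import rmtree

import numpy as np
from joblib import Parallel, delayed

memfolder = './joblib_memmap'
os.makedirs(memfolder, exist_ok=True)
map_dim = 1000

# Workers write into disjoint slots of a disk-backed array
redchi = np.memmap(os.path.join(memfolder, 'redchi_memmap'),
                   dtype=np.float64, shape=(map_dim,), mode='w+')

def fit_one(redchi, n):
    # Stand-in for the per-pixel fit; FitMap._fit stores redchi, success,
    # best-fit values and errors the same way
    redchi[n] = float(n)

Parallel(n_jobs=4)(delayed(fit_one)(redchi, n) for n in range(map_dim))
result = np.copy(redchi)   # copy results out before deleting the backing files
del redchi
rmtree(memfolder)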
--- a/general.py Fri Aug 19 20:16:56 2022 +0000 +++ b/general.py Fri Mar 10 16:02:04 2023 +0000 @@ -1,5 +1,8 @@ #!/usr/bin/env python3 +#FIX write a function that returns a list of peak indices for a given plot +#FIX use raise_error concept on more functions to optionally raise an error + # -*- coding: utf-8 -*- """ Created on Mon Dec 6 15:36:22 2021 @@ -8,11 +11,15 @@ """ import logging +logger=logging.getLogger(__name__) import os import sys import re -import yaml +try: + from yaml import safe_load, safe_dump +except: + pass try: import h5py except: @@ -20,314 +27,480 @@ import numpy as np try: import matplotlib.pyplot as plt + import matplotlib.lines as mlines + from matplotlib import transforms from matplotlib.widgets import Button except: pass from ast import literal_eval +try: + from asteval import Interpreter, get_ast_names +except: + pass from copy import deepcopy +try: + from sympy import diff, simplify +except: + pass from time import time -def depth_list(L): return isinstance(L, list) and max(map(depth_list, L))+1 -def depth_tuple(T): return isinstance(T, tuple) and max(map(depth_tuple, T))+1 +def depth_list(L): return(isinstance(L, list) and max(map(depth_list, L))+1) +def depth_tuple(T): return(isinstance(T, tuple) and max(map(depth_tuple, T))+1) def unwrap_tuple(T): if depth_tuple(T) > 1 and len(T) == 1: T = unwrap_tuple(*T) - return T - -def illegal_value(value, name, location=None, exit_flag=False): + return(T) + +def illegal_value(value, name, location=None, raise_error=False, log=True): if not isinstance(location, str): location = '' else: location = f'in {location} ' if isinstance(name, str): - logging.error(f'Illegal value for {name} {location}({value}, {type(value)})') + error_msg = f'Illegal value for {name} {location}({value}, {type(value)})' else: - logging.error(f'Illegal value {location}({value}, {type(value)})') - if exit_flag: - raise ValueError + error_msg = f'Illegal value {location}({value}, {type(value)})' + if log: + logger.error(error_msg) + if raise_error: + raise ValueError(error_msg) -def is_int(v, v_min=None, v_max=None): - """Value is an integer in range v_min <= v <= v_max. - """ - if not isinstance(v, int): - return False - if v_min is not None and not isinstance(v_min, int): - illegal_value(v_min, 'v_min', 'is_int') - return False - if v_max is not None and not isinstance(v_max, int): - illegal_value(v_max, 'v_max', 'is_int') - return False - if v_min is not None and v_max is not None and v_min > v_max: - logging.error(f'Illegal v_min, v_max combination ({v_min}, {v_max})') - return False - if (v_min is not None and v < v_min) or (v_max is not None and v > v_max): - return False - return True +def illegal_combination(value1, name1, value2, name2, location=None, raise_error=False, + log=True): + if not isinstance(location, str): + location = '' + else: + location = f'in {location} ' + if isinstance(name1, str): + error_msg = f'Illegal combination for {name1} and {name2} {location}'+ \ + f'({value1}, {type(value1)} and {value2}, {type(value2)})' + else: + error_msg = f'Illegal combination {location}'+ \ + f'({value1}, {type(value1)} and {value2}, {type(value2)})' + if log: + logger.error(error_msg) + if raise_error: + raise ValueError(error_msg) -def is_int_pair(v, v_min=None, v_max=None): - """Value is an integer pair, each in range v_min <= v[i] <= v_max or - v_min[i] <= v[i] <= v_max[i]. 
+def test_ge_gt_le_lt(ge, gt, le, lt, func, location=None, raise_error=False, log=True): + """Check individual and mutual validity of ge, gt, le, lt qualifiers + func: is_int or is_num to test for int or numbers + Return: True upon success or False when mutually exlusive """ - if not (isinstance(v, (tuple, list)) and len(v) == 2 and isinstance(v[0], int) and - isinstance(v[1], int)): - return False - if v_min is not None or v_max is not None: - if (v_min is None or isinstance(v_min, int)) and (v_max is None or isinstance(v_max, int)): - if True in [True if not is_int(vi, v_min=v_min, v_max=v_max) else False for vi in v]: - return False - elif is_int_pair(v_min) and is_int_pair(v_max): - if True in [True if v_min[i] > v_max[i] else False for i in range(2)]: - logging.error(f'Illegal v_min, v_max combination ({v_min}, {v_max})') - return False - if True in [True if not is_int(v[i], v_min[i], v_max[i]) else False for i in range(2)]: - return False - elif is_int_pair(v_min) and (v_max is None or isinstance(v_max, int)): - if True in [True if not is_int(v[i], v_min=v_min[i], v_max=v_max) else False - for i in range(2)]: - return False - elif (v_min is None or isinstance(v_min, int)) and is_int_pair(v_max): - if True in [True if not is_int(v[i], v_min=v_min, v_max=v_max[i]) else False - for i in range(2)]: - return False + if ge is None and gt is None and le is None and lt is None: + return(True) + if ge is not None: + if not func(ge): + illegal_value(ge, 'ge', location, raise_error, log) + return(False) + if gt is not None: + illegal_combination(ge, 'ge', gt, 'gt', location, raise_error, log) + return(False) + elif gt is not None and not func(gt): + illegal_value(gt, 'gt', location, raise_error, log) + return(False) + if le is not None: + if not func(le): + illegal_value(le, 'le', location, raise_error, log) + return(False) + if lt is not None: + illegal_combination(le, 'le', lt, 'lt', location, raise_error, log) + return(False) + elif lt is not None and not func(lt): + illegal_value(lt, 'lt', location, raise_error, log) + return(False) + if ge is not None: + if le is not None and ge > le: + illegal_combination(ge, 'ge', le, 'le', location, raise_error, log) + return(False) + elif lt is not None and ge >= lt: + illegal_combination(ge, 'ge', lt, 'lt', location, raise_error, log) + return(False) + elif gt is not None: + if le is not None and gt >= le: + illegal_combination(gt, 'gt', le, 'le', location, raise_error, log) + return(False) + elif lt is not None and gt >= lt: + illegal_combination(gt, 'gt', lt, 'lt', location, raise_error, log) + return(False) + return(True) + +def range_string_ge_gt_le_lt(ge=None, gt=None, le=None, lt=None): + """Return a range string representation matching the ge, gt, le, lt qualifiers + Does not validate the inputs, do that as needed before calling + """ + range_string = '' + if ge is not None: + if le is None and lt is None: + range_string += f'>= {ge}' else: - logging.error(f'Illegal v_min or v_max input ({v_min} {type(v_min)} and '+ - f'{v_max} {type(v_max)})') - return False - return True + range_string += f'[{ge}, ' + elif gt is not None: + if le is None and lt is None: + range_string += f'> {gt}' + else: + range_string += f'({gt}, ' + if le is not None: + if ge is None and gt is None: + range_string += f'<= {le}' + else: + range_string += f'{le}]' + elif lt is not None: + if ge is None and gt is None: + range_string += f'< {lt}' + else: + range_string += f'{lt})' + return(range_string) -def is_int_series(l, v_min=None, v_max=None): - """Value is a tuple or 
list of integers, each in range v_min <= l[i] <= v_max. +def is_int(v, ge=None, gt=None, le=None, lt=None, raise_error=False, log=True): + """Value is an integer in range ge <= v <= le or gt < v < lt or some combination. + Return: True if yes or False is no """ - if v_min is not None and not isinstance(v_min, int): - illegal_value(v_min, 'v_min', 'is_int_series') - return False - if v_max is not None and not isinstance(v_max, int): - illegal_value(v_max, 'v_max', 'is_int_series') - return False - if not isinstance(l, (tuple, list)): - return False - if True in [True if not is_int(v, v_min=v_min, v_max=v_max) else False for v in l]: - return False - return True + return(_is_int_or_num(v, 'int', ge, gt, le, lt, raise_error, log)) -def is_num(v, v_min=None, v_max=None): - """Value is a number in range v_min <= v <= v_max. +def is_num(v, ge=None, gt=None, le=None, lt=None, raise_error=False, log=True): + """Value is a number in range ge <= v <= le or gt < v < lt or some combination. + Return: True if yes or False is no """ - if not isinstance(v, (int, float)): - return False - if v_min is not None and not isinstance(v_min, (int, float)): - illegal_value(v_min, 'v_min', 'is_num') - return False - if v_max is not None and not isinstance(v_max, (int, float)): - illegal_value(v_max, 'v_max', 'is_num') - return False - if v_min is not None and v_max is not None and v_min > v_max: - logging.error(f'Illegal v_min, v_max combination ({v_min}, {v_max})') - return False - if (v_min is not None and v < v_min) or (v_max is not None and v > v_max): - return False - return True + return(_is_int_or_num(v, 'num', ge, gt, le, lt, raise_error, log)) -def is_num_pair(v, v_min=None, v_max=None): - """Value is a number pair, each in range v_min <= v[i] <= v_max or - v_min[i] <= v[i] <= v_max[i]. +def _is_int_or_num(v, type_str, ge=None, gt=None, le=None, lt=None, raise_error=False, + log=True): + if type_str == 'int': + if not isinstance(v, int): + illegal_value(v, 'v', '_is_int_or_num', raise_error, log) + return(False) + if not test_ge_gt_le_lt(ge, gt, le, lt, is_int, '_is_int_or_num', raise_error, log): + return(False) + elif type_str == 'num': + if not isinstance(v, (int, float)): + illegal_value(v, 'v', '_is_int_or_num', raise_error, log) + return(False) + if not test_ge_gt_le_lt(ge, gt, le, lt, is_num, '_is_int_or_num', raise_error, log): + return(False) + else: + illegal_value(type_str, 'type_str', '_is_int_or_num', raise_error, log) + return(False) + if ge is None and gt is None and le is None and lt is None: + return(True) + error = False + if ge is not None and v < ge: + error = True + error_msg = f'Value {v} out of range: {v} !>= {ge}' + if not error and gt is not None and v <= gt: + error = True + error_msg = f'Value {v} out of range: {v} !> {gt}' + if not error and le is not None and v > le: + error = True + error_msg = f'Value {v} out of range: {v} !<= {le}' + if not error and lt is not None and v >= lt: + error = True + error_msg = f'Value {v} out of range: {v} !< {lt}' + if error: + if log: + logger.error(error_msg) + if raise_error: + raise ValueError(error_msg) + return(False) + return(True) + +def is_int_pair(v, ge=None, gt=None, le=None, lt=None, raise_error=False, log=True): + """Value is an integer pair, each in range ge <= v[i] <= le or gt < v[i] < lt or + ge[i] <= v[i] <= le[i] or gt[i] < v[i] < lt[i] or some combination. 
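Together, `test_ge_gt_le_lt`, `range_string_ge_gt_le_lt`, and the `_is_int_or_num` core give all the validators one bounds vocabulary: `ge`/`le` are inclusive, `gt`/`lt` exclusive. A few illustrative calls whose results follow from the code above (assuming `general.py` is on the import path):

    from general import is_int, is_num, range_string_ge_gt_le_lt

    is_int(5, ge=0, lt=10)             # True:  0 <= 5 < 10
    is_int(10, ge=0, lt=10)            # False: lt is exclusive
    is_num(3.2, gt=0.0, le=3.2)        # True:  0.0 < 3.2 <= 3.2
    is_int(5, ge=6, raise_error=True)  # raises ValueError instead of returning False

    range_string_ge_gt_le_lt(ge=0, lt=10)  # '[0, 10)'
    range_string_ge_gt_le_lt(gt=0)         # '> 0'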
+ Return: True if yes or False is no + """ + return(_is_int_or_num_pair(v, 'int', ge, gt, le, lt, raise_error, log)) + +def is_num_pair(v, ge=None, gt=None, le=None, lt=None, raise_error=False, log=True): + """Value is a number pair, each in range ge <= v[i] <= le or gt < v[i] < lt or + ge[i] <= v[i] <= le[i] or gt[i] < v[i] < lt[i] or some combination. + Return: True if yes or False is no """ - if not (isinstance(v, (tuple, list)) and len(v) == 2 and isinstance(v[0], (int, float)) and - isinstance(v[1], (int, float))): - return False - if v_min is not None or v_max is not None: - if ((v_min is None or isinstance(v_min, (int, float))) and - (v_max is None or isinstance(v_max, (int, float)))): - if True in [True if not is_num(vi, v_min=v_min, v_max=v_max) else False for vi in v]: - return False - elif is_num_pair(v_min) and is_num_pair(v_max): - if True in [True if v_min[i] > v_max[i] else False for i in range(2)]: - logging.error(f'Illegal v_min, v_max combination ({v_min}, {v_max})') - return False - if True in [True if not is_num(v[i], v_min[i], v_max[i]) else False for i in range(2)]: - return False - elif is_num_pair(v_min) and (v_max is None or isinstance(v_max, (int, float))): - if True in [True if not is_num(v[i], v_min=v_min[i], v_max=v_max) else False - for i in range(2)]: - return False - elif (v_min is None or isinstance(v_min, (int, float))) and is_num_pair(v_max): - if True in [True if not is_num(v[i], v_min=v_min, v_max=v_max[i]) else False - for i in range(2)]: - return False + return(_is_int_or_num_pair(v, 'num', ge, gt, le, lt, raise_error, log)) + +def _is_int_or_num_pair(v, type_str, ge=None, gt=None, le=None, lt=None, raise_error=False, + log=True): + if type_str == 'int': + if not (isinstance(v, (tuple, list)) and len(v) == 2 and isinstance(v[0], int) and + isinstance(v[1], int)): + illegal_value(v, 'v', '_is_int_or_num_pair', raise_error, log) + return(False) + func = is_int + elif type_str == 'num': + if not (isinstance(v, (tuple, list)) and len(v) == 2 and isinstance(v[0], (int, float)) and + isinstance(v[1], (int, float))): + illegal_value(v, 'v', '_is_int_or_num_pair', raise_error, log) + return(False) + func = is_num + else: + illegal_value(type_str, 'type_str', '_is_int_or_num_pair', raise_error, log) + return(False) + if ge is None and gt is None and le is None and lt is None: + return(True) + if ge is None or func(ge, log=True): + ge = 2*[ge] + elif not _is_int_or_num_pair(ge, type_str, raise_error=raise_error, log=log): + return(False) + if gt is None or func(gt, log=True): + gt = 2*[gt] + elif not _is_int_or_num_pair(gt, type_str, raise_error=raise_error, log=log): + return(False) + if le is None or func(le, log=True): + le = 2*[le] + elif not _is_int_or_num_pair(le, type_str, raise_error=raise_error, log=log): + return(False) + if lt is None or func(lt, log=True): + lt = 2*[lt] + elif not _is_int_or_num_pair(lt, type_str, raise_error=raise_error, log=log): + return(False) + if (not func(v[0], ge[0], gt[0], le[0], lt[0], raise_error, log) or + not func(v[1], ge[1], gt[1], le[1], lt[1], raise_error, log)): + return(False) + return(True) + +def is_int_series(l, ge=None, gt=None, le=None, lt=None, raise_error=False, log=True): + """Value is a tuple or list of integers, each in range ge <= l[i] <= le or + gt < l[i] < lt or some combination. 
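`_is_int_or_num_pair` extends the same qualifiers to two-element tuples or lists, accepting either a scalar bound applied to both elements or a pair of per-element bounds. Illustrative calls (again assuming `general.py` is importable):

    from general import is_int_pair, is_num_pair

    is_int_pair((2, 5), ge=0, le=10)           # True: both elements in [0, 10]
    is_num_pair([0.5, 2.0], ge=0.0)            # True: scalar bound checks both elements
    is_int_pair((2, 5), ge=(0, 3), le=(4, 8))  # True: 0 <= 2 <= 4 and 3 <= 5 <= 8
    is_int_pair((2, 5.0))                      # False: the second element is not an int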
+ """ + if not test_ge_gt_le_lt(ge, gt, le, lt, is_int, 'is_int_series', raise_error, log): + return(False) + if not isinstance(l, (tuple, list)): + illegal_value(l, 'l', 'is_int_series', raise_error, log) + return(False) + if any(True if not is_int(v, ge, gt, le, lt, raise_error, log) else False for v in l): + return(False) + return(True) + +def is_num_series(l, ge=None, gt=None, le=None, lt=None, raise_error=False, log=True): + """Value is a tuple or list of numbers, each in range ge <= l[i] <= le or + gt < l[i] < lt or some combination. + """ + if not test_ge_gt_le_lt(ge, gt, le, lt, is_int, 'is_int_series', raise_error, log): + return(False) + if not isinstance(l, (tuple, list)): + illegal_value(l, 'l', 'is_num_series', raise_error, log) + return(False) + if any(True if not is_num(v, ge, gt, le, lt, raise_error, log) else False for v in l): + return(False) + return(True) + +def is_str_series(l, raise_error=False, log=True): + """Value is a tuple or list of strings. + """ + if (not isinstance(l, (tuple, list)) or + any(True if not isinstance(s, str) else False for s in l)): + illegal_value(l, 'l', 'is_str_series', raise_error, log) + return(False) + return(True) + +def is_dict_series(l, raise_error=False, log=True): + """Value is a tuple or list of dictionaries. + """ + if (not isinstance(l, (tuple, list)) or + any(True if not isinstance(d, dict) else False for d in l)): + illegal_value(l, 'l', 'is_dict_series', raise_error, log) + return(False) + return(True) + +def is_dict_nums(l, raise_error=False, log=True): + """Value is a dictionary with single number values + """ + if (not isinstance(l, dict) or + any(True if not is_num(v, log=False) else False for v in l.values())): + illegal_value(l, 'l', 'is_dict_nums', raise_error, log) + return(False) + return(True) + +def is_dict_strings(l, raise_error=False, log=True): + """Value is a dictionary with single string values + """ + if (not isinstance(l, dict) or + any(True if not isinstance(v, str) else False for v in l.values())): + illegal_value(l, 'l', 'is_dict_strings', raise_error, log) + return(False) + return(True) + +def is_index(v, ge=0, lt=None, raise_error=False, log=True): + """Value is an array index in range ge <= v < lt. + NOTE lt IS NOT included! + """ + if isinstance(lt, int): + if lt <= ge: + illegal_combination(ge, 'ge', lt, 'lt', 'is_index', raise_error, log) + return(False) + return(is_int(v, ge=ge, lt=lt, raise_error=raise_error, log=log)) + +def is_index_range(v, ge=0, le=None, lt=None, raise_error=False, log=True): + """Value is an array index range in range ge <= v[0] <= v[1] <= le or ge <= v[0] <= v[1] < lt. + NOTE le IS included! + """ + if not is_int_pair(v, raise_error=raise_error, log=log): + return(False) + if not test_ge_gt_le_lt(ge, None, le, lt, is_int, 'is_index_range', raise_error, log): + return(False) + if not ge <= v[0] <= v[1] or (le is not None and v[1] > le) or (lt is not None and v[1] >= lt): + if le is not None: + error_msg = f'Value {v} out of range: !({ge} <= {v[0]} <= {v[1]} <= {le})' else: - logging.error(f'Illegal v_min or v_max input ({v_min} {type(v_min)} and '+ - f'{v_max} {type(v_max)})') - return False - return True - -def is_num_series(l, v_min=None, v_max=None): - """Value is a tuple or list of numbers, each in range v_min <= l[i] <= v_max. 
- """ - if v_min is not None and not isinstance(v_min, (int, float)): - illegal_value(v_min, 'v_min', 'is_num_series') - return False - if v_max is not None and not isinstance(v_max, (int, float)): - illegal_value(v_max, 'v_max', 'is_num_series') - return False - if not isinstance(l, (tuple, list)): - return False - if True in [True if not is_num(v, v_min=v_min, v_max=v_max) else False for v in l]: - return False - return True - -def is_index(v, v_min=0, v_max=None): - """Value is an array index in range v_min <= v < v_max. - NOTE v_max IS NOT included! - """ - if isinstance(v_max, int): - if v_max <= v_min: - logging.error(f'Illegal v_min, v_max combination ({v_min}, {v_max})') - return False - v_max -= 1 - return is_int(v, v_min, v_max) - -def is_index_range(v, v_min=0, v_max=None): - """Value is an array index range in range v_min <= v[0] <= v[1] <= v_max. - NOTE v_max IS included! - """ - if not is_int_pair(v): - return False - if not isinstance(v_min, int): - illegal_value(v_min, 'v_min', 'is_index_range') - return False - if v_max is not None: - if not isinstance(v_max, int): - illegal_value(v_max, 'v_max', 'is_index_range') - return False - if v_max < v_min: - logging.error(f'Illegal v_min, v_max combination ({v_min}, {v_max})') - return False - if not v_min <= v[0] <= v[1] or (v_max is not None and v[1] > v_max): - return False - return True + error_msg = f'Value {v} out of range: !({ge} <= {v[0]} <= {v[1]} < {lt})' + if log: + logger.error(error_msg) + if raise_error: + raise ValueError(error_msg) + return(False) + return(True) def index_nearest(a, value): a = np.asarray(a) if a.ndim > 1: - logging.warning(f'Illegal input array ({a}, {type(a)})') + raise ValueError(f'Invalid array dimension for parameter a ({a.ndim}, {a})') # Round up for .5 value *= 1.0+sys.float_info.epsilon - return (int)(np.argmin(np.abs(a-value))) + return((int)(np.argmin(np.abs(a-value)))) def index_nearest_low(a, value): a = np.asarray(a) if a.ndim > 1: - logging.warning(f'Illegal input array ({a}, {type(a)})') + raise ValueError(f'Invalid array dimension for parameter a ({a.ndim}, {a})') index = int(np.argmin(np.abs(a-value))) if value < a[index] and index > 0: index -= 1 - return index + return(index) def index_nearest_upp(a, value): a = np.asarray(a) if a.ndim > 1: - logging.warning(f'Illegal input array ({a}, {type(a)})') + raise ValueError(f'Invalid array dimension for parameter a ({a.ndim}, {a})') index = int(np.argmin(np.abs(a-value))) if value > a[index] and index < a.size-1: index += 1 - return index + return(index) def round_to_n(x, n=1): if x == 0.0: - return 0 + return(0) else: - return round(x, n-1-int(np.floor(np.log10(abs(x))))) + return(type(x)(round(x, n-1-int(np.floor(np.log10(abs(x))))))) def round_up_to_n(x, n=1): xr = round_to_n(x, n) if abs(x/xr) > 1.0: xr += np.sign(x)*10**(np.floor(np.log10(abs(x)))+1-n) - return xr + return(type(x)(xr)) def trunc_to_n(x, n=1): xr = round_to_n(x, n) if abs(xr/x) > 1.0: xr -= np.sign(x)*10**(np.floor(np.log10(abs(x)))+1-n) - return xr + return(type(x)(xr)) -def string_to_list(s): +def almost_equal(a, b, sig_figs): + if is_num(a) and is_num(b): + return(abs(round_to_n(a-b, sig_figs)) < pow(10, -sig_figs+1)) + else: + raise ValueError(f'Invalid value for a or b in almost_equal (a: {a}, {type(a)}, '+ + f'b: {b}, {type(b)})') + return(False) + +def string_to_list(s, split_on_dash=True, remove_duplicates=True, sort=True): """Return a list of numbers by splitting/expanding a string on any combination of - dashes, commas, and/or whitespaces - e.g: '1, 3, 
5-8,12 ' -> [1, 3, 5, 6, 7, 8, 12] + commas, whitespaces, or dashes (when split_on_dash=True) + e.g: '1, 3, 5-8, 12 ' -> [1, 3, 5, 6, 7, 8, 12] """ if not isinstance(s, str): illegal_value(s, location='string_to_list') - return None + return(None) if not len(s): - return [] - try: - list1 = [x for x in re.split('\s+,\s+|\s+,|,\s+|\s+|,', s.strip())] - except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError): - return None + return([]) try: - l = [] - for l1 in list1: - l2 = [literal_eval(x) for x in re.split('\s+-\s+|\s+-|-\s+|\s+|-', l1)] - if len(l2) == 1: - l += l2 - elif len(l2) == 2 and l2[1] > l2[0]: - l += [i for i in range(l2[0], l2[1]+1)] - else: - raise ValueError + ll = [x for x in re.split('\s+,\s+|\s+,|,\s+|\s+|,', s.strip())] except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError): - return None - return sorted(set(l)) + return(None) + if split_on_dash: + try: + l = [] + for l1 in ll: + l2 = [literal_eval(x) for x in re.split('\s+-\s+|\s+-|-\s+|\s+|-', l1)] + if len(l2) == 1: + l += l2 + elif len(l2) == 2 and l2[1] > l2[0]: + l += [i for i in range(l2[0], l2[1]+1)] + else: + raise ValueError + except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError): + return(None) + else: + l = [literal_eval(x) for x in ll] + if remove_duplicates: + l = list(dict.fromkeys(l)) + if sort: + l = sorted(l) + return(l) def get_trailing_int(string): indexRegex = re.compile(r'\d+$') mo = indexRegex.search(string) if mo is None: - return None + return(None) else: - return int(mo.group()) + return(int(mo.group())) + +def input_int(s=None, ge=None, gt=None, le=None, lt=None, default=None, inset=None, + raise_error=False, log=True): + return(_input_int_or_num('int', s, ge, gt, le, lt, default, inset, raise_error, log)) + +def input_num(s=None, ge=None, gt=None, le=None, lt=None, default=None, raise_error=False, + log=True): + return(_input_int_or_num('num', s, ge, gt, le, lt, default, None, raise_error,log)) -def input_int(s=None, v_min=None, v_max=None, default=None, inset=None): +def _input_int_or_num(type_str, s=None, ge=None, gt=None, le=None, lt=None, default=None, + inset=None, raise_error=False, log=True): + if type_str == 'int': + if not test_ge_gt_le_lt(ge, gt, le, lt, is_int, '_input_int_or_num', raise_error, log): + return(None) + elif type_str == 'num': + if not test_ge_gt_le_lt(ge, gt, le, lt, is_num, '_input_int_or_num', raise_error, log): + return(None) + else: + illegal_value(type_str, 'type_str', '_input_int_or_num', raise_error, log) + return(None) if default is not None: - if not isinstance(default, int): - illegal_value(default, 'default', 'input_int') - return None + if not _is_int_or_num(default, type_str, raise_error=raise_error, log=log): + return(None) + if ge is not None and default < ge: + illegal_combination(ge, 'ge', default, 'default', '_input_int_or_num', raise_error, + log) + return(None) + if gt is not None and default <= gt: + illegal_combination(gt, 'gt', default, 'default', '_input_int_or_num', raise_error, + log) + return(None) + if le is not None and default > le: + illegal_combination(le, 'le', default, 'default', '_input_int_or_num', raise_error, + log) + return(None) + if lt is not None and default >= lt: + illegal_combination(lt, 'lt', default, 'default', '_input_int_or_num', raise_error, + log) + return(None) default_string = f' [{default}]' else: default_string = '' - if v_min is not None: - if not isinstance(v_min, int): - illegal_value(v_min, 'v_min', 'input_int') - return None - if default is not None and 
default < v_min: - logging.error('Illegal v_min, default combination ({v_min}, {default})') - return None - if v_max is not None: - if not isinstance(v_max, int): - illegal_value(v_max, 'v_max', 'input_int') - return None - if v_min is not None and v_min > v_max: - logging.error(f'Illegal v_min, v_max combination ({v_min}, {v_max})') - return None - if default is not None and default > v_max: - logging.error('Illegal default, v_max combination ({default}, {v_max})') - return None if inset is not None: - if (not isinstance(inset, (tuple, list)) or False in [True if isinstance(i, int) else - False for i in inset]): - illegal_value(inset, 'inset', 'input_int') - return None - if v_min is not None and v_max is not None: - v_range = f' ({v_min}, {v_max})' - elif v_min is not None: - v_range = f' (>= {v_min})' - elif v_max is not None: - v_range = f' (<= {v_max})' - else: - v_range = '' + if (not isinstance(inset, (tuple, list)) or any(True if not isinstance(i, int) else + False for i in inset)): + illegal_value(inset, 'inset', '_input_int_or_num', raise_error, log) + return(None) + v_range = f'{range_string_ge_gt_le_lt(ge, gt, le, lt)}' + if len(v_range): + v_range = f' {v_range}' if s is None: - print(f'Enter an integer{v_range}{default_string}: ') + if type_str == 'int': + print(f'Enter an integer{v_range}{default_string}: ') + else: + print(f'Enter a number{v_range}{default_string}: ') else: print(f'{s}{v_range}{default_string}: ') try: @@ -342,116 +515,90 @@ except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError): v = None except: - print('Unexpected error') - raise - if not is_int(v, v_min, v_max): - print('Illegal input, enter a valid integer') - v = input_int(s, v_min, v_max, default) - return v + if log: + logger.error('Unexpected error') + if raise_error: + raise ValueError('Unexpected error') + if not _is_int_or_num(v, type_str, ge, gt, le, lt): + v = _input_int_or_num(type_str, s, ge, gt, le, lt, default, inset, raise_error, log) + return(v) + +def input_int_list(s=None, ge=None, le=None, split_on_dash=True, remove_duplicates=True, + sort=True, raise_error=False, log=True): + """Prompt the user to input a list of interger and split the entered string on any combination + of commas, whitespaces, or dashes (when split_on_dash is True) + e.g: '1 3,5-8 , 12 ' -> [1, 3, 5, 6, 7, 8, 12] + remove_duplicates: removes duplicates if True (may also change the order) + sort: sort in ascending order if True + return None upon an illegal input + """ + return(_input_int_or_num_list('int', s, ge, le, split_on_dash, remove_duplicates, sort, + raise_error, log)) -def input_num(s=None, v_min=None, v_max=None, default=None): - if default is not None: - if not isinstance(default, (int, float)): - illegal_value(default, 'default', 'input_num') - return None - default_string = f' [{default}]' - else: - default_string = '' - if v_min is not None: - if not isinstance(v_min, (int, float)): - illegal_value(vmin, 'vmin', 'input_num') - return None - if default is not None and default < v_min: - logging.error('Illegal v_min, default combination ({v_min}, {default})') - return None - if v_max is not None: - if not isinstance(v_max, (int, float)): - illegal_value(vmax, 'vmax', 'input_num') - return None - if v_min is not None and v_max < v_min: - logging.error(f'Illegal v_min, v_max combination ({v_min}, {v_max})') - return None - if default is not None and default > v_max: - logging.error('Illegal default, v_max combination ({default}, {v_max})') - return None - if v_min is not None and v_max is 
not None: - v_range = f' ({v_min}, {v_max})' - elif v_min is not None: - v_range = f' (>= {v_min})' - elif v_max is not None: - v_range = f' (<= {v_max})' +def input_num_list(s=None, ge=None, le=None, remove_duplicates=True, sort=True, raise_error=False, + log=True): + """Prompt the user to input a list of numbers and split the entered string on any combination + of commas or whitespaces + e.g: '1.0, 3, 5.8, 12 ' -> [1.0, 3.0, 5.8, 12.0] + remove_duplicates: removes duplicates if True (may also change the order) + sort: sort in ascending order if True + return None upon an illegal input + """ + return(_input_int_or_num_list('num', s, ge, le, False, remove_duplicates, sort, raise_error, + log)) + +def _input_int_or_num_list(type_str, s=None, ge=None, le=None, split_on_dash=True, + remove_duplicates=True, sort=True, raise_error=False, log=True): + #FIX do we want a limit on max dimension? + if type_str == 'int': + if not test_ge_gt_le_lt(ge, None, le, None, is_int, 'input_int_or_num_list', raise_error, + log): + return(None) + elif type_str == 'num': + if not test_ge_gt_le_lt(ge, None, le, None, is_num, 'input_int_or_num_list', raise_error, + log): + return(None) else: - v_range = '' - if s is None: - print(f'Enter a number{v_range}{default_string}: ') - else: - print(f'{s}{v_range}{default_string}: ') - try: - i = input() - if isinstance(i, str) and not len(i): - v = default - print(f'{v}') - else: - v = literal_eval(i) - except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError): - v = None - except: - print('Unexpected error') - raise - if not is_num(v, v_min, v_max): - print('Illegal input, enter a valid number') - v = input_num(s, v_min, v_max, default) - return v - -def input_int_list(s=None, v_min=None, v_max=None): - if v_min is not None and not isinstance(v_min, int): - illegal_value(vmin, 'vmin', 'input_int_list') - return None - if v_max is not None: - if not isinstance(v_max, int): - illegal_value(vmax, 'vmax', 'input_int_list') - return None - if v_max < v_min: - logging.error(f'Illegal v_min, v_max combination ({v_min}, {v_max})') - return None - if v_min is not None and v_max is not None: - v_range = f' (each value in ({v_min}, {v_max}))' - elif v_min is not None: - v_range = f' (each value >= {v_min})' - elif v_max is not None: - v_range = f' (each value <= {v_max})' - else: - v_range = '' + illegal_value(type_str, 'type_str', '_input_int_or_num_list') + return(None) + v_range = f'{range_string_ge_gt_le_lt(ge=ge, le=le)}' + if len(v_range): + v_range = f' (each value in {v_range})' if s is None: print(f'Enter a series of integers{v_range}: ') else: print(f'{s}{v_range}: ') try: - l = string_to_list(input()) + l = string_to_list(input(), split_on_dash, remove_duplicates, sort) except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError): l = None except: print('Unexpected error') raise if (not isinstance(l, list) or - True in [True if not is_int(v, v_min, v_max) else False for v in l]): - print('Illegal input: enter a valid set of dash/comma/whitespace separated integers '+ - 'e.g. 2,3,5-8,10') - l = input_int_list(s, v_min, v_max) - return l + any(True if not _is_int_or_num(v, type_str, ge=ge, le=le) else False for v in l)): + if split_on_dash: + print('Invalid input: enter a valid set of dash/comma/whitespace separated integers '+ + 'e.g. 1 3,5-8 , 12') + else: + print('Invalid input: enter a valid set of comma/whitespace separated integers '+ + 'e.g. 
1 3,5 8 , 12') + l = _input_int_or_num_list(type_str, s, ge, le, split_on_dash, remove_duplicates, sort, + raise_error, log) + return(l) def input_yesno(s=None, default=None): if default is not None: if not isinstance(default, str): illegal_value(default, 'default', 'input_yesno') - return None + return(None) if default.lower() in 'yes': default = 'y' elif default.lower() in 'no': default = 'n' else: illegal_value(default, 'default', 'input_yesno') - return None + return(None) default_string = f' [{default}]' else: default_string = '' @@ -468,19 +615,19 @@ elif i is not None and i.lower() in 'no': v = False else: - print('Illegal input, enter yes or no') + print('Invalid input, enter yes or no') v = input_yesno(s, default) - return v + return(v) def input_menu(items, default=None, header=None): - if not isinstance(items, (tuple, list)) or False in [True if isinstance(i, str) else False - for i in items]: + if not isinstance(items, (tuple, list)) or any(True if not isinstance(i, str) else False + for i in items): illegal_value(items, 'items', 'input_menu') - return None + return(None) if default is not None: if not (isinstance(default, str) and default in items): - logging.error(f'Illegal value for default ({default}), must be in {items}') - return None + logger.error(f'Invalid value for default ({default}), must be in {items}') + return(None) default_string = f' [{items.index(default)+1}]' else: default_string = '' @@ -507,38 +654,283 @@ print('Unexpected error') raise if choice is None: - print(f'Illegal choice, enter a number between 1 and {len(items)}') + print(f'Invalid choice, enter a number between 1 and {len(items)}') choice = input_menu(items, default) - return choice + return(choice) + +def assert_no_duplicates_in_list_of_dicts(l: list, raise_error=False) -> list: + if not isinstance(l, list): + illegal_value(l, 'l', 'assert_no_duplicates_in_list_of_dicts', raise_error) + return(None) + if any(True if not isinstance(d, dict) else False for d in l): + illegal_value(l, 'l', 'assert_no_duplicates_in_list_of_dicts', raise_error) + return(None) + if len(l) != len([dict(t) for t in {tuple(sorted(d.items())) for d in l}]): + if raise_error: + raise ValueError(f'Duplicate items found in {l}') + else: + logger.error(f'Duplicate items found in {l}') + return(None) + else: + return(l) -def create_mask(x, bounds=None, reverse_mask=False, current_mask=None): +def assert_no_duplicate_key_in_list_of_dicts(l: list, key: str, raise_error=False) -> list: + if not isinstance(key, str): + illegal_value(key, 'key', 'assert_no_duplicate_key_in_list_of_dicts', raise_error) + return(None) + if not isinstance(l, list): + illegal_value(l, 'l', 'assert_no_duplicate_key_in_list_of_dicts', raise_error) + return(None) + if any(True if not isinstance(d, dict) else False for d in l): + illegal_value(l, 'l', 'assert_no_duplicates_in_list_of_dicts', raise_error) + return(None) + keys = [d.get(key, None) for d in l] + if None in keys or len(set(keys)) != len(l): + if raise_error: + raise ValueError(f'Duplicate or missing key ({key}) found in {l}') + else: + logger.error(f'Duplicate or missing key ({key}) found in {l}') + return(None) + else: + return(l) + +def assert_no_duplicate_attr_in_list_of_objs(l: list, attr: str, raise_error=False) -> list: + if not isinstance(attr, str): + illegal_value(attr, 'attr', 'assert_no_duplicate_attr_in_list_of_objs', raise_error) + return(None) + if not isinstance(l, list): + illegal_value(l, 'l', 'assert_no_duplicate_key_in_list_of_objs', raise_error) + return(None) + attrs = 
[getattr(obj, attr, None) for obj in l] + if None in attrs or len(set(attrs)) != len(l): + if raise_error: + raise ValueError(f'Duplicate or missing attr ({attr}) found in {l}') + else: + logger.error(f'Duplicate or missing attr ({attr}) found in {l}') + return(None) + else: + return(l) + +def file_exists_and_readable(path): + if not os.path.isfile(path): + raise ValueError(f'{path} is not a valid file') + elif not os.access(path, os.R_OK): + raise ValueError(f'{path} is not accessible for reading') + else: + return(path) + +def create_mask(x, bounds=None, exclude_bounds=False, current_mask=None): # bounds is a pair of number in the same units a x if not isinstance(x, (tuple, list, np.ndarray)) or not len(x): - logging.warning(f'Illegal input array ({x}, {type(x)})') - return None + logger.warning(f'Invalid input array ({x}, {type(x)})') + return(None) if bounds is not None and not is_num_pair(bounds): - logging.warning(f'Illegal bounds parameter ({bounds} {type(bounds)}, input ignored') + logger.warning(f'Invalid bounds parameter ({bounds} {type(bounds)}, input ignored') bounds = None if bounds is not None: - if not reverse_mask: + if exclude_bounds: + mask = np.logical_or(x < min(bounds), x > max(bounds)) + else: mask = np.logical_and(x > min(bounds), x < max(bounds)) - else: - mask = np.logical_or(x < min(bounds), x > max(bounds)) else: mask = np.ones(len(x), dtype=bool) if current_mask is not None: if not isinstance(current_mask, (tuple, list, np.ndarray)) or len(current_mask) != len(x): - logging.warning(f'Illegal current_mask ({current_mask}, {type(current_mask)}), '+ + logger.warning(f'Invalid current_mask ({current_mask}, {type(current_mask)}), '+ 'input ignored') else: - mask = np.logical_and(mask, current_mask) + mask = np.logical_or(mask, current_mask) if not True in mask: - logging.warning('Entire data array is masked') - return mask + logger.warning('Entire data array is masked') + return(mask) + +def eval_expr(name, expr, expr_variables, user_variables=None, max_depth=10, raise_error=False, + log=True, **kwargs): + """Evaluate an expression of expressions + """ + if not isinstance(name, str): + illegal_value(name, 'name', 'eval_expr', raise_error, log) + return(None) + if not isinstance(expr, str): + illegal_value(expr, 'expr', 'eval_expr', raise_error, log) + return(None) + if not is_dict_strings(expr_variables, log=False): + illegal_value(expr_variables, 'expr_variables', 'eval_expr', raise_error, log) + return(None) + if user_variables is not None and not is_dict_nums(user_variables, log=False): + illegal_value(user_variables, 'user_variables', 'eval_expr', raise_error, log) + return(None) + if not is_int(max_depth, gt=1, log=False): + illegal_value(max_depth, 'max_depth', 'eval_expr', raise_error, log) + return(None) + if not isinstance(raise_error, bool): + illegal_value(raise_error, 'raise_error', 'eval_expr', raise_error, log) + return(None) + if not isinstance(log, bool): + illegal_value(log, 'log', 'eval_expr', raise_error, log) + return(None) +# print(f'\nEvaluate the full expression for {expr}') + if 'chain' in kwargs: + chain = kwargs.pop('chain') + if not is_str_series(chain): + illegal_value(chain, 'chain', 'eval_expr', raise_error, log) + return(None) + else: + chain = [] + if len(chain) > max_depth: + error_msg = 'Exceeded maximum depth ({max_depth}) in eval_expr' + if log: + logger.error(error_msg) + if raise_error: + raise ValueError(error_msg) + return(None) + if name not in chain: + chain.append(name) +# print(f'start: chain = {chain}') + if 'ast' in 
kwargs: + ast = kwargs.pop('ast') + else: + ast = Interpreter() + if user_variables is not None: + ast.symtable.update(user_variables) + chain_vars = [var for var in get_ast_names(ast.parse(expr)) + if var in expr_variables and var not in ast.symtable] +# print(f'chain_vars: {chain_vars}') + save_chain = chain.copy() + for var in chain_vars: +# print(f'\n\tname = {name}, var = {var}:\n\t\t{expr_variables[var]}') +# print(f'\tchain = {chain}') + if var in chain: + error_msg = f'Circular variable {var} in eval_expr' + if log: + logger.error(error_msg) + if raise_error: + raise ValueError(error_msg) + return(None) +# print(f'\tknown symbols:\n\t\t{ast.user_defined_symbols()}\n') + if var in ast.user_defined_symbols(): + val = ast.symtable[var] + else: + #val = eval_expr(var, expr_variables[var], expr_variables, user_variables=user_variables, + val = eval_expr(var, expr_variables[var], expr_variables, max_depth=max_depth, + raise_error=raise_error, log=log, chain=chain, ast=ast) + if val is None: + return(None) + ast.symtable[var] = val +# print(f'\tval = {val}') +# print(f'\t{var} = {ast.symtable[var]}') + chain = save_chain.copy() +# print(f'\treset loop for {var}: chain = {chain}') + val = ast.eval(expr) +# print(f'return val for {expr} = {val}\n') + return(val) + +def full_gradient(expr, x, expr_name=None, expr_variables=None, valid_variables=None, max_depth=10, + raise_error=False, log=True, **kwargs): + """Compute the full gradient dexpr/dx + """ + if not isinstance(x, str): + illegal_value(x, 'x', 'full_gradient', raise_error, log) + return(None) + if expr_name is not None and not isinstance(expr_name, str): + illegal_value(expr_name, 'expr_name', 'eval_expr', raise_error, log) + return(None) + if expr_variables is not None and not is_dict_strings(expr_variables, log=False): + illegal_value(expr_variables, 'expr_variables', 'full_gradient', raise_error, log) + return(None) + if valid_variables is not None and not is_str_series(valid_variables, log=False): + illegal_value(valid_variables, 'valid_variables', 'full_gradient', raise_error, log) + if not is_int(max_depth, gt=1, log=False): + illegal_value(max_depth, 'max_depth', 'eval_expr', raise_error, log) + return(None) + if not isinstance(raise_error, bool): + illegal_value(raise_error, 'raise_error', 'eval_expr', raise_error, log) + return(None) + if not isinstance(log, bool): + illegal_value(log, 'log', 'eval_expr', raise_error, log) + return(None) +# print(f'\nGet full gradient of {expr_name} = {expr} with respect to {x}') + if expr_name is not None and expr_name == x: + return(1.0) + if 'chain' in kwargs: + chain = kwargs.pop('chain') + if not is_str_series(chain): + illegal_value(chain, 'chain', 'eval_expr', raise_error, log) + return(None) + else: + chain = [] + if len(chain) > max_depth: + error_msg = 'Exceeded maximum depth ({max_depth}) in eval_expr' + if log: + logger.error(error_msg) + if raise_error: + raise ValueError(error_msg) + return(None) + if expr_name is not None and expr_name not in chain: + chain.append(expr_name) +# print(f'start ({x}): chain = {chain}') + ast = Interpreter() + if expr_variables is None: + chain_vars = [] + else: + chain_vars = [var for var in get_ast_names(ast.parse(f'{expr}')) + if var in expr_variables and var != x and var not in ast.symtable] +# print(f'chain_vars: {chain_vars}') + if valid_variables is not None: + unknown_vars = [var for var in chain_vars if var not in valid_variables] + if len(unknown_vars): + error_msg = f'Unknown variable {unknown_vars} in {expr}' + if log: + 
logger.error(error_msg) + if raise_error: + raise ValueError(error_msg) + return(None) + dexpr_dx = diff(expr, x) +# print(f'direct gradient: d({expr})/d({x}) = {dexpr_dx} ({type(dexpr_dx)})') + save_chain = chain.copy() + for var in chain_vars: +# print(f'\n\texpr_name = {expr_name}, var = {var}:\n\t\t{expr}') +# print(f'\tchain = {chain}') + if var in chain: + error_msg = f'Circular variable {var} in full_gradient' + if log: + logger.error(error_msg) + if raise_error: + raise ValueError(error_msg) + return(None) + dexpr_dvar = diff(expr, var) +# print(f'\td({expr})/d({var}) = {dexpr_dvar}') + if dexpr_dvar: + dvar_dx = full_gradient(expr_variables[var], x, expr_name=var, + expr_variables=expr_variables, valid_variables=valid_variables, + max_depth=max_depth, raise_error=raise_error, log=log, chain=chain) +# print(f'\t\td({var})/d({x}) = {dvar_dx}') + if dvar_dx: + dexpr_dx = f'{dexpr_dx}+({dexpr_dvar})*({dvar_dx})' +# print(f'\t\t2: chain = {chain}') + chain = save_chain.copy() +# print(f'\treset loop for {var}: chain = {chain}') +# print(f'full gradient: d({expr})/d({x}) = {dexpr_dx} ({type(dexpr_dx)})') +# print(f'reset end: chain = {chain}\n\n') + return(simplify(dexpr_dx)) + +def bounds_from_mask(mask, return_include_bounds:bool=True): + bounds = [] + for i, m in enumerate(mask): + if m == return_include_bounds: + if len(bounds) == 0 or type(bounds[-1]) == tuple: + bounds.append(i) + else: + if len(bounds) > 0 and isinstance(bounds[-1], int): + bounds[-1] = (bounds[-1], i-1) + if len(bounds) > 0 and isinstance(bounds[-1], int): + bounds[-1] = (bounds[-1], mask.size-1) + return(bounds) def draw_mask_1d(ydata, xdata=None, current_index_ranges=None, current_mask=None, select_mask=True, num_index_ranges_max=None, title=None, legend=None, test_mode=False): - def draw_selections(ax): + #FIX make color blind friendly + def draw_selections(ax, current_include, current_exclude, selected_index_ranges): ax.clear() ax.set_title(title) ax.legend([legend]) @@ -570,26 +962,32 @@ selected_index_ranges[-1] = (selected_index_ranges[-1], event.xdata) else: selected_index_ranges[-1] = (event.xdata, selected_index_ranges[-1]) - draw_selections(event.inaxes) + draw_selections(event.inaxes, current_include, current_exclude, selected_index_ranges) else: selected_index_ranges.pop(-1) def confirm_selection(event): plt.close() - + def clear_last_selection(event): if len(selected_index_ranges): selected_index_ranges.pop(-1) - draw_selections(ax) + else: + while len(current_include): + current_include.pop() + while len(current_exclude): + current_exclude.pop() + selected_mask.fill(False) + draw_selections(ax, current_include, current_exclude, selected_index_ranges) - def update_mask(mask): + def update_mask(mask, selected_index_ranges, unselected_index_ranges): for (low, upp) in selected_index_ranges: selected_mask = np.logical_and(xdata >= xdata[low], xdata <= xdata[upp]) mask = np.logical_or(mask, selected_mask) for (low, upp) in unselected_index_ranges: unselected_mask = np.logical_and(xdata >= xdata[low], xdata <= xdata[upp]) mask[unselected_mask] = False - return mask + return(mask) def update_index_ranges(mask): # Update the currently included index ranges (where mask is True) @@ -603,34 +1001,34 @@ current_include[-1] = (current_include[-1], i-1) if len(current_include) > 0 and isinstance(current_include[-1], int): current_include[-1] = (current_include[-1], num_data-1) - return current_include + return(current_include) - # Check for valid inputs + # Check inputs ydata = np.asarray(ydata) if ydata.ndim > 
1: - logging.warning(f'Illegal ydata dimension ({ydata.ndim})') - return None, None + logger.warning(f'Invalid ydata dimension ({ydata.ndim})') + return(None, None) num_data = ydata.size if xdata is None: xdata = np.arange(num_data) else: xdata = np.asarray(xdata, dtype=np.float64) if xdata.ndim > 1 or xdata.size != num_data: - logging.warning(f'Illegal xdata shape ({xdata.shape})') - return None, None + logger.warning(f'Invalid xdata shape ({xdata.shape})') + return(None, None) if not np.all(xdata[:-1] < xdata[1:]): - logging.warning('Illegal xdata: must be monotonically increasing') - return None, None + logger.warning('Invalid xdata: must be monotonically increasing') + return(None, None) if current_index_ranges is not None: if not isinstance(current_index_ranges, (tuple, list)): - logging.warning('Illegal current_index_ranges parameter ({current_index_ranges}, '+ + logger.warning('Invalid current_index_ranges parameter ({current_index_ranges}, '+ f'{type(current_index_ranges)})') - return None, None + return(None, None) if not isinstance(select_mask, bool): - logging.warning('Illegal select_mask parameter ({select_mask}, {type(select_mask)})') - return None, None + logger.warning('Invalid select_mask parameter ({select_mask}, {type(select_mask)})') + return(None, None) if num_index_ranges_max is not None: - logging.warning('num_index_ranges_max input not yet implemented in draw_mask_1d') + logger.warning('num_index_ranges_max input not yet implemented in draw_mask_1d') if title is None: title = 'select ranges of data' elif not isinstance(title, str): @@ -668,7 +1066,7 @@ if upp >= num_data: upp = num_data-1 selected_index_ranges.append((low, upp)) - selected_mask = update_mask(selected_mask) + selected_mask = update_mask(selected_mask, selected_index_ranges, unselected_index_ranges) if current_index_ranges is not None and current_mask is not None: selected_mask = np.logical_and(current_mask, selected_mask) if current_mask is not None: @@ -697,7 +1095,7 @@ plt.close('all') fig, ax = plt.subplots() plt.subplots_adjust(bottom=0.2) - draw_selections(ax) + draw_selections(ax, current_include, current_exclude, selected_index_ranges) # Set up event handling for click-and-drag range selection cid_click = fig.canvas.mpl_connect('button_press_event', onclick) @@ -724,251 +1122,364 @@ selected_index_ranges # Update the mask with the currently selected/unselected x-ranges - selected_mask = update_mask(selected_mask) + selected_mask = update_mask(selected_mask, selected_index_ranges, unselected_index_ranges) # Update the currently included index ranges (where mask is True) current_include = update_index_ranges(selected_mask) + + return(selected_mask, current_include) - return selected_mask, current_include +def select_peaks(ydata:np.ndarray, x_values:np.ndarray=None, x_mask:np.ndarray=None, + peak_x_values:np.ndarray=np.array([]), peak_x_indices:np.ndarray=np.array([]), + return_peak_x_values:bool=False, return_peak_x_indices:bool=False, + return_peak_input_indices:bool=False, return_sorted:bool=False, + title:str=None, xlabel:str=None, ylabel:str=None) -> list : + + # Check arguments + if (len(peak_x_values) > 0 or return_peak_x_values) and not len(x_values) > 0: + raise RuntimeError('Cannot use peak_x_values or return_peak_x_values without x_values') + if not ((len(peak_x_values) > 0) ^ (len(peak_x_indices) > 0)): + raise RuntimeError('Use exactly one of peak_x_values or peak_x_indices') + return_format_iter = iter((return_peak_x_values, return_peak_x_indices, return_peak_input_indices)) + if not 
(any(return_format_iter) and not any(return_format_iter)): + raise RuntimeError('Exactly one of return_peak_x_values, return_peak_x_indices, or '+ + 'return_peak_input_indices must be True') + + EXCLUDE_PEAK_PROPERTIES = {'color': 'black', 'linestyle': '--','linewidth': 1, + 'marker': 10, 'markersize': 5, 'fillstyle': 'none'} + INCLUDE_PEAK_PROPERTIES = {'color': 'green', 'linestyle': '-', 'linewidth': 2, + 'marker': 10, 'markersize': 10, 'fillstyle': 'full'} + MASKED_PEAK_PROPERTIES = {'color': 'gray', 'linestyle': ':', 'linewidth': 1} + + # Setup reference data & plot + x_indices = np.arange(len(ydata)) + if x_values is None: + x_values = x_indices + if x_mask is None: + x_mask = np.full(x_values.shape, True, dtype=bool) + fig, ax = plt.subplots() + handles = ax.plot(x_values, ydata, label='Reference data') + handles.append(mlines.Line2D([], [], label='Excluded / unselected HKL', **EXCLUDE_PEAK_PROPERTIES)) + handles.append(mlines.Line2D([], [], label='Included / selected HKL', **INCLUDE_PEAK_PROPERTIES)) + handles.append(mlines.Line2D([], [], label='HKL in masked region (unselectable)', **MASKED_PEAK_PROPERTIES)) + ax.legend(handles=handles, loc='upper right') + ax.set(title=title, xlabel=xlabel, ylabel=ylabel) + + + # Plot vertical line at each peak + value_to_index = lambda x_value: int(np.argmin(abs(x_values - x_value))) + if len(peak_x_indices) > 0: + peak_x_values = x_values[peak_x_indices] + else: + peak_x_indices = np.array(list(map(value_to_index, peak_x_values))) + peak_vlines = [] + for loc in peak_x_values: + nearest_index = value_to_index(loc) + if nearest_index in x_indices[x_mask]: + peak_vline = ax.axvline(loc, **EXCLUDE_PEAK_PROPERTIES) + peak_vline.set_picker(5) + else: + peak_vline = ax.axvline(loc, **MASKED_PEAK_PROPERTIES) + peak_vlines.append(peak_vline) -def findImageFiles(path, filetype, name=None): + # Indicate masked regions by gray-ing out the axes facecolor + mask_exclude_bounds = bounds_from_mask(x_mask, return_include_bounds=False) + for (low, upp) in mask_exclude_bounds: + xlow = x_values[low] + xupp = x_values[upp] + ax.axvspan(xlow, xupp, facecolor='gray', alpha=0.5) + + # Setup peak picking + selected_peak_input_indices = [] + def onpick(event): + try: + peak_index = peak_vlines.index(event.artist) + except: + pass + else: + peak_vline = event.artist + if peak_index in selected_peak_input_indices: + peak_vline.set(**EXCLUDE_PEAK_PROPERTIES) + selected_peak_input_indices.remove(peak_index) + else: + peak_vline.set(**INCLUDE_PEAK_PROPERTIES) + selected_peak_input_indices.append(peak_index) + plt.draw() + cid_pick_peak = fig.canvas.mpl_connect('pick_event', onpick) + + # Setup "Confirm" button + def confirm_selection(event): + plt.close() + plt.subplots_adjust(bottom=0.2) + confirm_b = Button(plt.axes([0.75, 0.05, 0.15, 0.075]), 'Confirm') + cid_confirm = confirm_b.on_clicked(confirm_selection) + + # Show figure for user interaction + plt.show() + + # Disconnect callbacks when figure is closed + fig.canvas.mpl_disconnect(cid_pick_peak) + confirm_b.disconnect(cid_confirm) + + if return_peak_input_indices: + selected_peaks = np.array(selected_peak_input_indices) + if return_peak_x_values: + selected_peaks = peak_x_values[selected_peak_input_indices] + if return_peak_x_indices: + selected_peaks = peak_x_indices[selected_peak_input_indices] + + if return_sorted: + selected_peaks.sort() + + return(selected_peaks) + +def find_image_files(path, filetype, name=None): if isinstance(name, str): - name = f' {name} ' + name = f'{name.strip()} ' else: - name = ' ' + 
name = '' # Find available index range if filetype == 'tif': if not isinstance(path, str) or not os.path.isdir(path): - illegal_value(path, 'path', 'findImageFiles') - return -1, 0, [] + illegal_value(path, 'path', 'find_image_files') + return(-1, 0, []) indexRegex = re.compile(r'\d+') # At this point only tiffs files = sorted([f for f in os.listdir(path) if os.path.isfile(os.path.join(path, f)) and f.endswith('.tif') and indexRegex.search(f)]) - num_imgs = len(files) - if num_imgs < 1: - logging.warning('No available'+name+'files') - return -1, 0, [] + num_img = len(files) + if num_img < 1: + logger.warning(f'No available {name}files') + return(-1, 0, []) first_index = indexRegex.search(files[0]).group() last_index = indexRegex.search(files[-1]).group() if first_index is None or last_index is None: - logging.error('Unable to find correctly indexed'+name+'images') - return -1, 0, [] + logger.error(f'Unable to find correctly indexed {name}images') + return(-1, 0, []) first_index = int(first_index) last_index = int(last_index) - if num_imgs != last_index-first_index+1: - logging.error('Non-consecutive set of indices for'+name+'images') - return -1, 0, [] + if num_img != last_index-first_index+1: + logger.error(f'Non-consecutive set of indices for {name}images') + return(-1, 0, []) paths = [os.path.join(path, f) for f in files] elif filetype == 'h5': if not isinstance(path, str) or not os.path.isfile(path): - illegal_value(path, 'path', 'findImageFiles') - return -1, 0, [] + illegal_value(path, 'path', 'find_image_files') + return(-1, 0, []) # At this point only h5 in alamo2 detector style first_index = 0 with h5py.File(path, 'r') as f: - num_imgs = f['entry/instrument/detector/data'].shape[0] - last_index = num_imgs-1 + num_img = f['entry/instrument/detector/data'].shape[0] + last_index = num_img-1 paths = [path] else: - illegal_value(filetype, 'filetype', 'findImageFiles') - return -1, 0, [] - logging.debug('\nNumber of available'+name+f'images: {num_imgs}') - logging.debug('Index range of available'+name+f'images: [{first_index}, '+ + illegal_value(filetype, 'filetype', 'find_image_files') + return(-1, 0, []) + logger.info(f'Number of available {name}images: {num_img}') + logger.info(f'Index range of available {name}images: [{first_index}, '+ f'{last_index}]') - return first_index, num_imgs, paths + return(first_index, num_img, paths) -def selectImageRange(first_index, offset, num_imgs, name=None, num_required=None): +def select_image_range(first_index, offset, num_available, num_img=None, name=None, + num_required=None): if isinstance(name, str): - name = f' {name} ' + name = f'{name.strip()} ' else: - name = ' ' + name = '' # Check existing values - use_input = False - if (is_int(first_index, 0) and is_int(offset, 0) and is_int(num_imgs, 1)): - if offset < 0: - use_input = input_yesno(f'\nCurrent{name}first index = {first_index}, '+ - 'use this value (y/n)?', 'y') + if not is_int(num_available, gt=0): + logger.warning(f'No available {name}images') + return(0, 0, 0) + if num_img is not None and not is_int(num_img, ge=0): + illegal_value(num_img, 'num_img', 'select_image_range') + return(0, 0, 0) + if is_int(first_index, ge=0) and is_int(offset, ge=0): + if num_required is None: + if input_yesno(f'\nCurrent {name}first image index/offset = {first_index}/{offset},'+ + 'use these values (y/n)?', 'y'): + if num_img is not None: + if input_yesno(f'Current number of {name}images = {num_img}, '+ + 'use this value (y/n)? 
', 'y'): + return(first_index, offset, num_img) + else: + if input_yesno(f'Number of available {name}images = {num_available}, '+ + 'use all (y/n)? ', 'y'): + return(first_index, offset, num_available) else: - use_input = input_yesno(f'\nCurrent{name}first index/offset = '+ - f'{first_index}/{offset}, use these values (y/n)?', 'y') - if num_required is None: - if use_input: - use_input = input_yesno(f'Current number of{name}images = '+ - f'{num_imgs}, use this value (y/n)? ', 'y') - if use_input: - return first_index, offset, num_imgs + if input_yesno(f'\nCurrent {name}first image offset = {offset}, '+ + f'use this values (y/n)?', 'y'): + return(first_index, offset, num_required) # Check range against requirements - if num_imgs < 1: - logging.warning('No available'+name+'images') - return -1, -1, 0 if num_required is None: - if num_imgs == 1: - return first_index, 0, 1 + if num_available == 1: + return(first_index, 0, 1) else: - if not is_int(num_required, 1): - illegal_value(num_required, 'num_required', 'selectImageRange') - return -1, -1, 0 - if num_imgs < num_required: - logging.error('Unable to find the required'+name+ - f'images ({num_imgs} out of {num_required})') - return -1, -1, 0 + if not is_int(num_required, ge=1): + illegal_value(num_required, 'num_required', 'select_image_range') + return(0, 0, 0) + if num_available < num_required: + logger.error(f'Unable to find the required {name}images ({num_available} out of '+ + f'{num_required})') + return(0, 0, 0) # Select index range - print('\nThe number of available'+name+f'images is {num_imgs}') + print(f'\nThe number of available {name}images is {num_available}') if num_required is None: - last_index = first_index+num_imgs + last_index = first_index+num_available use_all = f'Use all ([{first_index}, {last_index}])' - pick_offset = 'Pick a first index offset and a number of images' - pick_bounds = 'Pick the first and last index' + pick_offset = 'Pick the first image index offset and the number of images' + pick_bounds = 'Pick the first and last image index' choice = input_menu([use_all, pick_offset, pick_bounds], default=pick_offset) if not choice: offset = 0 + num_img = num_available elif choice == 1: - offset = input_int('Enter the first index offset', 0, last_index-first_index) - first_index += offset - if first_index == last_index: - num_imgs = 1 + offset = input_int('Enter the first index offset', ge=0, le=last_index-first_index) + if first_index+offset == last_index: + num_img = 1 else: - num_imgs = input_int('Enter the number of images', 1, num_imgs-offset) + num_img = input_int('Enter the number of images', ge=1, le=num_available-offset) else: - offset = input_int('Enter the first index', first_index, last_index) - first_index += offset - num_imgs = input_int('Enter the last index', first_index, last_index)-first_index+1 + offset = input_int('Enter the first index', ge=first_index, le=last_index) + num_img = 1-offset+input_int('Enter the last index', ge=offset, le=last_index) + offset -= first_index else: use_all = f'Use ([{first_index}, {first_index+num_required-1}])' pick_offset = 'Pick the first index offset' choice = input_menu([use_all, pick_offset], pick_offset) offset = 0 if choice == 1: - offset = input_int('Enter the first index offset', 0, num_imgs-num_required) - first_index += offset - num_imgs = num_required + offset = input_int('Enter the first index offset', ge=0, le=num_available-num_required) + num_img = num_required - return first_index, offset, num_imgs + return(first_index, offset, num_img) -def 
loadImage(f, img_x_bounds=None, img_y_bounds=None): +def load_image(f, img_x_bounds=None, img_y_bounds=None): """Load a single image from file. """ if not os.path.isfile(f): - logging.error(f'Unable to load {f}') - return None + logger.error(f'Unable to load {f}') + return(None) img_read = plt.imread(f) if not img_x_bounds: img_x_bounds = (0, img_read.shape[0]) else: if (not isinstance(img_x_bounds, (tuple, list)) or len(img_x_bounds) != 2 or not (0 <= img_x_bounds[0] < img_x_bounds[1] <= img_read.shape[0])): - logging.error(f'inconsistent row dimension in {f}') - return None + logger.error(f'inconsistent row dimension in {f}') + return(None) if not img_y_bounds: img_y_bounds = (0, img_read.shape[1]) else: if (not isinstance(img_y_bounds, list) or len(img_y_bounds) != 2 or not (0 <= img_y_bounds[0] < img_y_bounds[1] <= img_read.shape[1])): - logging.error(f'inconsistent column dimension in {f}') - return None - return img_read[img_x_bounds[0]:img_x_bounds[1],img_y_bounds[0]:img_y_bounds[1]] + logger.error(f'inconsistent column dimension in {f}') + return(None) + return(img_read[img_x_bounds[0]:img_x_bounds[1],img_y_bounds[0]:img_y_bounds[1]]) -def loadImageStack(files, filetype, img_offset, num_imgs, num_img_skip=0, +def load_image_stack(files, filetype, img_offset, num_img, num_img_skip=0, img_x_bounds=None, img_y_bounds=None): """Load a set of images and return them as a stack. """ - logging.debug(f'img_offset = {img_offset}') - logging.debug(f'num_imgs = {num_imgs}') - logging.debug(f'num_img_skip = {num_img_skip}') - logging.debug(f'\nfiles:\n{files}\n') + logger.debug(f'img_offset = {img_offset}') + logger.debug(f'num_img = {num_img}') + logger.debug(f'num_img_skip = {num_img_skip}') + logger.debug(f'\nfiles:\n{files}\n') img_stack = np.array([]) if filetype == 'tif': img_read_stack = [] i = 1 t0 = time() - for f in files[img_offset:img_offset+num_imgs:num_img_skip+1]: + for f in files[img_offset:img_offset+num_img:num_img_skip+1]: if not i%20: - logging.info(f' loading {i}/{num_imgs}: {f}') + logger.info(f' loading {i}/{num_img}: {f}') else: - logging.debug(f' loading {i}/{num_imgs}: {f}') - img_read = loadImage(f, img_x_bounds, img_y_bounds) + logger.debug(f' loading {i}/{num_img}: {f}') + img_read = load_image(f, img_x_bounds, img_y_bounds) img_read_stack.append(img_read) i += num_img_skip+1 img_stack = np.stack([img_read for img_read in img_read_stack]) - logging.info(f'... done in {time()-t0:.2f} seconds!') - logging.debug(f'img_stack shape = {np.shape(img_stack)}') + logger.info(f'... 
done in {time()-t0:.2f} seconds!') + logger.debug(f'img_stack shape = {np.shape(img_stack)}') del img_read_stack, img_read elif filetype == 'h5': if not isinstance(files[0], str) and not os.path.isfile(files[0]): - illegal_value(files[0], 'files[0]', 'loadImageStack') - return img_stack + illegal_value(files[0], 'files[0]', 'load_image_stack') + return(img_stack) t0 = time() - logging.info(f'Loading {files[0]}') + logger.info(f'Loading {files[0]}') with h5py.File(files[0], 'r') as f: shape = f['entry/instrument/detector/data'].shape if len(shape) != 3: - logging.error(f'inconsistent dimensions in {files[0]}') + logger.error(f'inconsistent dimensions in {files[0]}') if not img_x_bounds: img_x_bounds = (0, shape[1]) else: if (not isinstance(img_x_bounds, (tuple, list)) or len(img_x_bounds) != 2 or not (0 <= img_x_bounds[0] < img_x_bounds[1] <= shape[1])): - logging.error(f'inconsistent row dimension in {files[0]} {img_x_bounds} '+ + logger.error(f'inconsistent row dimension in {files[0]} {img_x_bounds} '+ f'{shape[1]}') if not img_y_bounds: img_y_bounds = (0, shape[2]) else: if (not isinstance(img_y_bounds, list) or len(img_y_bounds) != 2 or not (0 <= img_y_bounds[0] < img_y_bounds[1] <= shape[2])): - logging.error(f'inconsistent column dimension in {files[0]}') + logger.error(f'inconsistent column dimension in {files[0]}') img_stack = f.get('entry/instrument/detector/data')[ - img_offset:img_offset+num_imgs:num_img_skip+1, + img_offset:img_offset+num_img:num_img_skip+1, img_x_bounds[0]:img_x_bounds[1],img_y_bounds[0]:img_y_bounds[1]] - logging.info(f'... done in {time()-t0:.2f} seconds!') + logger.info(f'... done in {time()-t0:.2f} seconds!') else: - illegal_value(filetype, 'filetype', 'loadImageStack') - return img_stack + illegal_value(filetype, 'filetype', 'load_image_stack') + return(img_stack) -def combine_tiffs_in_h5(files, num_imgs, h5_filename): - img_stack = loadImageStack(files, 'tif', 0, num_imgs) +def combine_tiffs_in_h5(files, num_img, h5_filename): + img_stack = load_image_stack(files, 'tif', 0, num_img) with h5py.File(h5_filename, 'w') as f: f.create_dataset('entry/instrument/detector/data', data=img_stack) del img_stack - return [h5_filename] + return([h5_filename]) -def clearImshow(title=None): +def clear_imshow(title=None): plt.ioff() if title is None: title = 'quick imshow' elif not isinstance(title, str): - illegal_value(title, 'title', 'clearImshow') + illegal_value(title, 'title', 'clear_imshow') return plt.close(fig=title) -def clearPlot(title=None): +def clear_plot(title=None): plt.ioff() if title is None: title = 'quick plot' elif not isinstance(title, str): - illegal_value(title, 'title', 'clearPlot') + illegal_value(title, 'title', 'clear_plot') return plt.close(fig=title) -def quickImshow(a, title=None, path=None, name=None, save_fig=False, save_only=False, - clear=True, extent=None, show_grid=False, grid_color='w', grid_linewidth=1, **kwargs): +def quick_imshow(a, title=None, path=None, name=None, save_fig=False, save_only=False, + clear=True, extent=None, show_grid=False, grid_color='w', grid_linewidth=1, + block=False, **kwargs): if title is not None and not isinstance(title, str): - illegal_value(title, 'title', 'quickImshow') + illegal_value(title, 'title', 'quick_imshow') return if path is not None and not isinstance(path, str): - illegal_value(path, 'path', 'quickImshow') + illegal_value(path, 'path', 'quick_imshow') return if not isinstance(save_fig, bool): - illegal_value(save_fig, 'save_fig', 'quickImshow') + illegal_value(save_fig, 'save_fig', 
'quick_imshow') return if not isinstance(save_only, bool): - illegal_value(save_only, 'save_only', 'quickImshow') + illegal_value(save_only, 'save_only', 'quick_imshow') return if not isinstance(clear, bool): - illegal_value(clear, 'clear', 'quickImshow') + illegal_value(clear, 'clear', 'quick_imshow') + return + if not isinstance(block, bool): + illegal_value(block, 'block', 'quick_imshow') return if not title: title='quick imshow' @@ -985,12 +1496,30 @@ path = name else: path = f'{path}/{name}' + if 'cmap' in kwargs and a.ndim == 3 and (a.shape[2] == 3 or a.shape[2] == 4): + use_cmap = True + if a.shape[2] == 4 and a[:,:,-1].min() != a[:,:,-1].max(): + use_cmap = False + if any(True if a[i,j,0] != a[i,j,1] and a[i,j,0] != a[i,j,2] else False + for i in range(a.shape[0]) for j in range(a.shape[1])): + use_cmap = False + if use_cmap: + a = a[:,:,0] + else: + logger.warning('Image incompatible with cmap option, ignore cmap') + kwargs.pop('cmap') if extent is None: extent = (0, a.shape[1], a.shape[0], 0) if clear: - plt.close(fig=title) + try: + plt.close(fig=title) + except: + pass if not save_only: - plt.ion() + if block: + plt.ioff() + else: + plt.ion() plt.figure(title) plt.imshow(a, extent=extent, **kwargs) if show_grid: @@ -1004,45 +1533,47 @@ else: if save_fig: plt.savefig(path) + if block: + plt.show(block=block) -def quickPlot(*args, xerr=None, yerr=None, vlines=None, title=None, xlim=None, ylim=None, - xlabel=None, ylabel=None, legend=None, path=None, name=None, show_grid=False, +def quick_plot(*args, xerr=None, yerr=None, vlines=None, title=None, xlim=None, ylim=None, + xlabel=None, ylabel=None, legend=None, path=None, name=None, show_grid=False, save_fig=False, save_only=False, clear=True, block=False, **kwargs): if title is not None and not isinstance(title, str): - illegal_value(title, 'title', 'quickPlot') + illegal_value(title, 'title', 'quick_plot') title = None if xlim is not None and not isinstance(xlim, (tuple, list)) and len(xlim) != 2: - illegal_value(xlim, 'xlim', 'quickPlot') + illegal_value(xlim, 'xlim', 'quick_plot') xlim = None if ylim is not None and not isinstance(ylim, (tuple, list)) and len(ylim) != 2: - illegal_value(ylim, 'ylim', 'quickPlot') + illegal_value(ylim, 'ylim', 'quick_plot') ylim = None if xlabel is not None and not isinstance(xlabel, str): - illegal_value(xlabel, 'xlabel', 'quickPlot') + illegal_value(xlabel, 'xlabel', 'quick_plot') xlabel = None if ylabel is not None and not isinstance(ylabel, str): - illegal_value(ylabel, 'ylabel', 'quickPlot') + illegal_value(ylabel, 'ylabel', 'quick_plot') ylabel = None if legend is not None and not isinstance(legend, (tuple, list)): - illegal_value(legend, 'legend', 'quickPlot') + illegal_value(legend, 'legend', 'quick_plot') legend = None if path is not None and not isinstance(path, str): - illegal_value(path, 'path', 'quickPlot') + illegal_value(path, 'path', 'quick_plot') return if not isinstance(show_grid, bool): - illegal_value(show_grid, 'show_grid', 'quickPlot') + illegal_value(show_grid, 'show_grid', 'quick_plot') return if not isinstance(save_fig, bool): - illegal_value(save_fig, 'save_fig', 'quickPlot') + illegal_value(save_fig, 'save_fig', 'quick_plot') return if not isinstance(save_only, bool): - illegal_value(save_only, 'save_only', 'quickPlot') + illegal_value(save_only, 'save_only', 'quick_plot') return if not isinstance(clear, bool): - illegal_value(clear, 'clear', 'quickPlot') + illegal_value(clear, 'clear', 'quick_plot') return if not isinstance(block, bool): - illegal_value(block, 'block', 
'quickPlot') + illegal_value(block, 'block', 'quick_plot') return if title is None: title = 'quick plot' @@ -1060,10 +1591,13 @@ else: path = f'{path}/{name}' if clear: - plt.close(fig=title) + try: + plt.close(fig=title) + except: + pass args = unwrap_tuple(args) if depth_tuple(args) > 1 and (xerr is not None or yerr is not None): - logging.warning('Error bars ignored form multiple curves') + logger.warning('Error bars ignored for multiple curves') if not save_only: if block: plt.ioff() @@ -1079,6 +1613,8 @@ else: plt.errorbar(*args, xerr=xerr, yerr=yerr, **kwargs) if vlines is not None: + if isinstance(vlines, (int, float)): + vlines = [vlines] for v in vlines: plt.axvline(v, color='r', linestyle='--', **kwargs) # if vlines is not None: @@ -1106,108 +1642,97 @@ if block: plt.show(block=block) -def selectArrayBounds(a, x_low=None, x_upp=None, num_x_min=None, ask_bounds=False, +def select_array_bounds(a, x_low=None, x_upp=None, num_x_min=None, ask_bounds=False, title='select array bounds'): """Interactively select the lower and upper data bounds for a numpy array. """ if isinstance(a, (tuple, list)): a = np.array(a) if not isinstance(a, np.ndarray) or a.ndim != 1: - illegal_value(a.ndim, 'array type or dimension', 'selectArrayBounds') - return None + illegal_value(a.ndim, 'array type or dimension', 'select_array_bounds') + return(None) len_a = len(a) if num_x_min is None: num_x_min = 1 else: if num_x_min < 2 or num_x_min > len_a: - logging.warning('Illegal value for num_x_min in selectArrayBounds, input ignored') + logger.warning('Invalid value for num_x_min in select_array_bounds, input ignored') num_x_min = 1 # Ask to use current bounds if ask_bounds and (x_low is not None or x_upp is not None): if x_low is None: x_low = 0 - if not is_int(x_low, 0, len_a-num_x_min): - illegal_value(x_low, 'x_low', 'selectArrayBounds') - return None + if not is_int(x_low, ge=0, le=len_a-num_x_min): + illegal_value(x_low, 'x_low', 'select_array_bounds') + return(None) if x_upp is None: x_upp = len_a - if not is_int(x_upp, x_low+num_x_min, len_a): - illegal_value(x_upp, 'x_upp', 'selectArrayBounds') - return None - quickPlot((range(len_a), a), vlines=(x_low,x_upp), title=title) + if not is_int(x_upp, ge=x_low+num_x_min, le=len_a): + illegal_value(x_upp, 'x_upp', 'select_array_bounds') + return(None) + quick_plot((range(len_a), a), vlines=(x_low,x_upp), title=title) if not input_yesno(f'\nCurrent array bounds: [{x_low}, {x_upp}] '+ 'use these values (y/n)?', 'y'): x_low = None x_upp = None else: - clearPlot(title) - return x_low, x_upp + clear_plot(title) + return(x_low, x_upp) if x_low is None: x_min = 0 x_max = len_a x_low_max = len_a-num_x_min while True: - quickPlot(range(x_min, x_max), a[x_min:x_max], title=title) + quick_plot(range(x_min, x_max), a[x_min:x_max], title=title) zoom_flag = input_yesno('Set lower data bound (y) or zoom in (n)?', 'y') if zoom_flag: - x_low = input_int(' Set lower data bound', 0, x_low_max) + x_low = input_int(' Set lower data bound', ge=0, le=x_low_max) break else: - x_min = input_int(' Set lower zoom index', 0, x_low_max) - x_max = input_int(' Set upper zoom index', x_min+1, x_low_max+1) + x_min = input_int(' Set lower zoom index', ge=0, le=x_low_max) + x_max = input_int(' Set upper zoom index', ge=x_min+1, le=x_low_max+1) else: - if not is_int(x_low, 0, len_a-num_x_min): - illegal_value(x_low, 'x_low', 'selectArrayBounds') - return None + if not is_int(x_low, ge=0, le=len_a-num_x_min): + illegal_value(x_low, 'x_low', 'select_array_bounds') + return(None) if x_upp is 
None: x_min = x_low+num_x_min x_max = len_a x_upp_min = x_min while True: - quickPlot(range(x_min, x_max), a[x_min:x_max], title=title) + quick_plot(range(x_min, x_max), a[x_min:x_max], title=title) zoom_flag = input_yesno('Set upper data bound (y) or zoom in (n)?', 'y') if zoom_flag: - x_upp = input_int(' Set upper data bound', x_upp_min, len_a) + x_upp = input_int(' Set upper data bound', ge=x_upp_min, le=len_a) break else: - x_min = input_int(' Set upper zoom index', x_upp_min, len_a-1) - x_max = input_int(' Set upper zoom index', x_min+1, len_a) + x_min = input_int(' Set upper zoom index', ge=x_upp_min, le=len_a-1) + x_max = input_int(' Set upper zoom index', ge=x_min+1, le=len_a) else: - if not is_int(x_upp, x_low+num_x_min, len_a): - illegal_value(x_upp, 'x_upp', 'selectArrayBounds') - return None + if not is_int(x_upp, ge=x_low+num_x_min, le=len_a): + illegal_value(x_upp, 'x_upp', 'select_array_bounds') + return(None) print(f'lower bound = {x_low} (inclusive)\nupper bound = {x_upp} (exclusive)]') - quickPlot((range(len_a), a), vlines=(x_low,x_upp), title=title) + quick_plot((range(len_a), a), vlines=(x_low,x_upp), title=title) if not input_yesno('Accept these bounds (y/n)?', 'y'): - x_low, x_upp = selectArrayBounds(a, None, None, num_x_min, title=title) - clearPlot(title) - return x_low, x_upp + x_low, x_upp = select_array_bounds(a, None, None, num_x_min, title=title) + clear_plot(title) + return(x_low, x_upp) -def selectImageBounds(a, axis, low=None, upp=None, num_min=None, - title='select array bounds'): +def select_image_bounds(a, axis, low=None, upp=None, num_min=None, title='select array bounds', + raise_error=False): """Interactively select the lower and upper data bounds for a 2D numpy array. """ - if isinstance(a, np.ndarray): - if a.ndim != 2: - illegal_value(a.ndim, 'array dimension', 'selectImageBounds') - return None - elif isinstance(a, (tuple, list)): - if len(a) != 2: - illegal_value(len(a), 'array dimension', 'selectImageBounds') - return None - if len(a[0]) != len(a[1]) or not (isinstance(a[0], (tuple, list, np.ndarray)) and - isinstance(a[1], (tuple, list, np.ndarray))): - logging.error(f'Illegal array type in selectImageBounds ({type(a[0])} {type(a[1])})') - return None - a = np.array(a) - else: - illegal_value(a, 'array type', 'selectImageBounds') - return None + a = np.asarray(a) + if a.ndim != 2: + illegal_value(a.ndim, 'array dimension', location='select_image_bounds', + raise_error=raise_error) + return(None) if axis < 0 or axis >= a.ndim: - illegal_value(axis, 'axis', 'selectImageBounds') - return None + illegal_value(axis, 'axis', location='select_image_bounds', raise_error=raise_error) + return(None) low_save = low upp_save = upp num_min_save = num_min @@ -1215,7 +1740,7 @@ num_min = 1 else: if num_min < 2 or num_min > a.shape[axis]: - logging.warning('Illegal input for num_min in selectImageBounds, input ignored') + logger.warning('Invalid input for num_min in select_image_bounds, input ignored') num_min = 1 if low is None: min_ = 0 @@ -1223,44 +1748,44 @@ low_max = a.shape[axis]-num_min while True: if axis: - quickImshow(a[:,min_:max_], title=title, aspect='auto', + quick_imshow(a[:,min_:max_], title=title, aspect='auto', extent=[min_,max_,a.shape[0],0]) else: - quickImshow(a[min_:max_,:], title=title, aspect='auto', + quick_imshow(a[min_:max_,:], title=title, aspect='auto', extent=[0,a.shape[1], max_,min_]) zoom_flag = input_yesno('Set lower data bound (y) or zoom in (n)?', 'y') if zoom_flag: - low = input_int(' Set lower data bound', 0, low_max) + low 
= input_int(' Set lower data bound', ge=0, le=low_max) break else: - min_ = input_int(' Set lower zoom index', 0, low_max) - max_ = input_int(' Set upper zoom index', min_+1, low_max+1) + min_ = input_int(' Set lower zoom index', ge=0, le=low_max) + max_ = input_int(' Set upper zoom index', ge=min_+1, le=low_max+1) else: - if not is_int(low, 0, a.shape[axis]-num_min): - illegal_value(low, 'low', 'selectImageBounds') - return None + if not is_int(low, ge=0, le=a.shape[axis]-num_min): + illegal_value(low, 'low', location='select_image_bounds', raise_error=raise_error) + return(None) if upp is None: min_ = low+num_min max_ = a.shape[axis] upp_min = min_ while True: if axis: - quickImshow(a[:,min_:max_], title=title, aspect='auto', + quick_imshow(a[:,min_:max_], title=title, aspect='auto', extent=[min_,max_,a.shape[0],0]) else: - quickImshow(a[min_:max_,:], title=title, aspect='auto', + quick_imshow(a[min_:max_,:], title=title, aspect='auto', extent=[0,a.shape[1], max_,min_]) zoom_flag = input_yesno('Set upper data bound (y) or zoom in (n)?', 'y') if zoom_flag: - upp = input_int(' Set upper data bound', upp_min, a.shape[axis]) + upp = input_int(' Set upper data bound', ge=upp_min, le=a.shape[axis]) break else: - min_ = input_int(' Set upper zoom index', upp_min, a.shape[axis]-1) - max_ = input_int(' Set upper zoom index', min_+1, a.shape[axis]) + min_ = input_int(' Set upper zoom index', ge=upp_min, le=a.shape[axis]-1) + max_ = input_int(' Set upper zoom index', ge=min_+1, le=a.shape[axis]) else: - if not is_int(upp, low+num_min, a.shape[axis]): - illegal_value(upp, 'upp', 'selectImageBounds') - return None + if not is_int(upp, ge=low+num_min, le=a.shape[axis]): + illegal_value(upp, 'upp', location='select_image_bounds', raise_error=raise_error) + return(None) bounds = (low, upp) a_tmp = np.copy(a) a_tmp_max = a.max() @@ -1271,12 +1796,64 @@ a_tmp[bounds[0],:] = a_tmp_max a_tmp[bounds[1]-1,:] = a_tmp_max print(f'lower bound = {low} (inclusive)\nupper bound = {upp} (exclusive)') - quickImshow(a_tmp, title=title) + quick_imshow(a_tmp, title=title, aspect='auto') del a_tmp if not input_yesno('Accept these bounds (y/n)?', 'y'): - bounds = selectImageBounds(a, axis, low=low_save, upp=upp_save, num_min=num_min_save, + bounds = select_image_bounds(a, axis, low=low_save, upp=upp_save, num_min=num_min_save, title=title) - return bounds + return(bounds) + +def select_one_image_bound(a, axis, bound=None, bound_name=None, title='select array bounds', + default='y', raise_error=False): + """Interactively select a data boundary for a 2D numpy array. 
+ """ + a = np.asarray(a) + if a.ndim != 2: + illegal_value(a.ndim, 'array dimension', location='select_one_image_bound', + raise_error=raise_error) + return(None) + if axis < 0 or axis >= a.ndim: + illegal_value(axis, 'axis', location='select_one_image_bound', raise_error=raise_error) + return(None) + if bound_name is None: + bound_name = 'data bound' + if bound is None: + min_ = 0 + max_ = a.shape[axis] + bound_max = a.shape[axis]-1 + while True: + if axis: + quick_imshow(a[:,min_:max_], title=title, aspect='auto', + extent=[min_,max_,a.shape[0],0]) + else: + quick_imshow(a[min_:max_,:], title=title, aspect='auto', + extent=[0,a.shape[1], max_,min_]) + zoom_flag = input_yesno(f'Set {bound_name} (y) or zoom in (n)?', 'y') + if zoom_flag: + bound = input_int(f' Set {bound_name}', ge=0, le=bound_max) + clear_imshow(title) + break + else: + min_ = input_int(' Set lower zoom index', ge=0, le=bound_max) + max_ = input_int(' Set upper zoom index', ge=min_+1, le=bound_max+1) + + elif not is_int(bound, ge=0, le=a.shape[axis]-1): + illegal_value(bound, 'bound', location='select_one_image_bound', raise_error=raise_error) + return(None) + else: + print(f'Current {bound_name} = {bound}') + a_tmp = np.copy(a) + a_tmp_max = a.max() + if axis: + a_tmp[:,bound] = a_tmp_max + else: + a_tmp[bound,:] = a_tmp_max + quick_imshow(a_tmp, title=title, aspect='auto') + del a_tmp + if not input_yesno(f'Accept this {bound_name} (y/n)?', default): + bound = select_one_image_bound(a, axis, bound_name=bound_name, title=title) + clear_imshow(title) + return(bound) class Config: @@ -1289,79 +1866,79 @@ # Load config file if config_file is not None and config_dict is not None: - logging.warning('Ignoring config_dict (both config_file and config_dict are specified)') + logger.warning('Ignoring config_dict (both config_file and config_dict are specified)') if config_file is not None: - self.loadFile(config_file) + self.load_file(config_file) elif config_dict is not None: - self.loadDict(config_dict) + self.load_dict(config_dict) - def loadFile(self, config_file): + def load_file(self, config_file): """Load a config file. """ if self.load_flag: - logging.warning('Overwriting any previously loaded config file') + logger.warning('Overwriting any previously loaded config file') self.config = {} # Ensure config file exists if not os.path.isfile(config_file): - logging.error(f'Unable to load {config_file}') + logger.error(f'Unable to load {config_file}') return # Load config file (for now for Galaxy, allow .dat extension) self.suffix = os.path.splitext(config_file)[1] if self.suffix == '.yml' or self.suffix == '.yaml' or self.suffix == '.dat': with open(config_file, 'r') as f: - self.config = yaml.safe_load(f) + self.config = safe_load(f) elif self.suffix == '.txt': with open(config_file, 'r') as f: lines = f.read().splitlines() self.config = {item[0].strip():literal_eval(item[1].strip()) for item in [line.split('#')[0].split('=') for line in lines if '=' in line.split('#')[0]]} else: - illegal_value(self.suffix, 'config file extension', 'Config.loadFile') + illegal_value(self.suffix, 'config file extension', 'Config.load_file') # Make sure config file was correctly loaded if isinstance(self.config, dict): self.load_flag = True else: - logging.error(f'Unable to load dictionary from config file: {config_file}') + logger.error(f'Unable to load dictionary from config file: {config_file}') self.config = {} - def loadDict(self, config_dict): + def load_dict(self, config_dict): """Takes a dictionary and places it into self.config. 
""" if self.load_flag: - logging.warning('Overwriting the previously loaded config file') + logger.warning('Overwriting the previously loaded config file') if isinstance(config_dict, dict): self.config = config_dict self.load_flag = True else: - illegal_value(config_dict, 'dictionary config object', 'Config.loadDict') + illegal_value(config_dict, 'dictionary config object', 'Config.load_dict') self.config = {} - def saveFile(self, config_file): + def save_file(self, config_file): """Save the config file (as a yaml file only right now). """ suffix = os.path.splitext(config_file)[1] if suffix != '.yml' and suffix != '.yaml': - illegal_value(suffix, 'config file extension', 'Config.saveFile') + illegal_value(suffix, 'config file extension', 'Config.save_file') # Check if config file exists if os.path.isfile(config_file): - logging.info(f'Updating {config_file}') + logger.info(f'Updating {config_file}') else: - logging.info(f'Saving {config_file}') + logger.info(f'Saving {config_file}') # Save config file with open(config_file, 'w') as f: - yaml.safe_dump(self.config, f) + safe_dump(self.config, f) def validate(self, pars_required, pars_missing=None): """Returns False if any required keys are missing. """ if not self.load_flag: - logging.error('Load a config file prior to calling Config.validate') + logger.error('Load a config file prior to calling Config.validate') def validate_nested_pars(config, par): par_levels = par.split(':') @@ -1374,15 +1951,15 @@ next_level_config = config[first_level_par] if len(par_levels) > 1: next_level_par = ':'.join(par_levels[1:]) - return validate_nested_pars(next_level_config, next_level_par) + return(validate_nested_pars(next_level_config, next_level_par)) else: - return True + return(True) except: - return False + return(False) pars_missing = [p for p in pars_required if not validate_nested_pars(self.config, p)] if len(pars_missing) > 0: - logging.error(f'Missing item(s) in configuration: {", ".join(pars_missing)}') - return False + logger.error(f'Missing item(s) in configuration: {", ".join(pars_missing)}') + return(False) else: - return True + return(True)
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/run_link_to_galaxy	Fri Mar 10 16:02:04 2023 +0000
@@ -0,0 +1,6 @@
+#!/bin/bash
+
+#python -m workflow link_to_galaxy -i tenstom_1304r-1.nxs -g 'https://galaxy.classe.cornell.edu' -a 'fbea44f58986b87b40bb9315c496e687'
+#python -m workflow link_to_galaxy -i tenstom_1304r-1.nxs -g 'https://galaxy-dev.classe.cornell.edu' -a 'bd404baf78eef76657277f33021d408f'
+
+python -m workflow link_to_galaxy -i sobhani-3249-A.yaml -g 'https://galaxy-dev.classe.cornell.edu' -a 'bd404baf78eef76657277f33021d408f'
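
Note: run_link_to_galaxy hardcodes Galaxy API keys on the command line. A safer pattern (an assumption, not part of this changeset) reads the key from the environment; the workflow CLI flags are the ones used above, while the variable name GALAXY_API_KEY is hypothetical:

    import os
    import subprocess

    # Keep the API key out of the repository; export GALAXY_API_KEY beforehand.
    api_key = os.environ['GALAXY_API_KEY']
    subprocess.run(
        ['python', '-m', 'workflow', 'link_to_galaxy',
         '-i', 'sobhani-3249-A.yaml',
         '-g', 'https://galaxy-dev.classe.cornell.edu',
         '-a', api_key],
        check=True)
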
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/run_tomo_all	Fri Mar 10 16:02:04 2023 +0000
@@ -0,0 +1,6 @@
+#!/bin/bash
+
+# From workflow
+#python -m workflow run_tomo -i sobhani-3249-A.yaml -o sobhani-3249-A.nxs -n 48 -s 'no'
+#python -m workflow run_tomo -i tenstom_1304r-1_one_stack.yaml -o tenstom_1304r-1_one_stack.nxs -n 48 -s 'no'
+python -m workflow run_tomo -i tenstom_1304r-1.yaml -o tenstom_1304r-1.nxs -n 48 -s 'no'
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/run_tomo_combine	Fri Mar 10 16:02:04 2023 +0000
@@ -0,0 +1,4 @@
+#!/bin/bash
+
+# As Galaxy tool:
+python tomo_combine.py -i tenstom_1304r-1_recon.nxs -o tenstom_1304r-1_recon_combined.nxs --log_level INFO
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/run_tomo_find_center	Fri Mar 10 16:02:04 2023 +0000
@@ -0,0 +1,9 @@
+#!/bin/bash
+
+# From workflow
+python -m workflow run_tomo -i sobhani-3249-A_reduce.nxs -o sobhani-3249-A_centers.yaml -n 48 -s 'only' --find_center
+
+# As Galaxy tool:
+#python tomo_find_center.py -i sobhani-3249-A_reduce.nxs -o sobhani-3249-A_centers.yaml --log_level INFO -l tomo.log --galaxy_flag --center_rows 50 270
+#python tomo_find_center.py -i tenstom_1304r-1_one_stack_reduce.nxs -o tenstom_1304r-1_one_stack_centers.yaml --log_level INFO
+#python tomo_find_center.py -i tenstom_1304r-1_reduce.nxs -o tenstom_1304r-1_centers.yaml --log_level INFO
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/run_tomo_reconstruct	Fri Mar 10 16:02:04 2023 +0000
@@ -0,0 +1,9 @@
+#!/bin/bash
+
+# From workflow
+python -m workflow run_tomo -i sobhani-3249-A_reduce.nxs -c sobhani-3249-A_centers.yaml -o sobhani-3249-A_recon.nxs -n 48 -s 'only' --reconstruct_data
+
+# As Galaxy tool:
+#python tomo_reconstruct.py -i sobhani-3249-A_reduce.nxs -c sobhani-3249-A_centers.yaml -o sobhani-3249-A_recon.nxs --log_level INFO -l tomo.log --galaxy_flag --x_bounds 650 1050 --y_bounds 270 1430
+#python tomo_reconstruct.py -i tenstom_1304r-1_one_stack_reduce.nxs -c tenstom_1304r-1_one_stack_centers.yaml -o tenstom_1304r-1_one_stack_recon.nxs --log_level INFO
+#python tomo_reconstruct.py -i tenstom_1304r-1_reduce.nxs -c tenstom_1304r-1_centers.yaml -o tenstom_1304r-1_recon.nxs --log_level INFO
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/run_tomo_reduce	Fri Mar 10 16:02:04 2023 +0000
@@ -0,0 +1,9 @@
+#!/bin/bash
+
+# From workflow
+python -m workflow run_tomo -i sobhani-3249-A.yaml -o sobhani-3249-A_reduce.nxs -n 48 -s 'only' --reduce_data
+
+# As Galaxy tool:
+#python tomo_reduce.py -i sobhani-3249-A.yaml -o sobhani-3249-A_reduce.nxs --log_level INFO -l tomo.log --galaxy_flag --img_x_bounds 620 950
+#python tomo_reduce.py -i tenstom_1304r-1_one_stack.yaml -o tenstom_1304r-1_one_stack_reduce.nxs --log_level INFO -l tomo.log --galaxy_flag
+#python tomo_reduce.py -i tenstom_1304r-1.yaml -o tenstom_1304r-1_reduce.nxs --log_level INFO -l tomo.log --galaxy_flag #--img_x_bounds 713 1388
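
Note: the run scripts above define a four-stage pipeline: tomo_reduce.py writes the reduced stacks, tomo_find_center.py derives rotation-axis centers from them, tomo_reconstruct.py consumes both to reconstruct, and tomo_combine.py merges the reconstructed stacks. A sketch chaining the stages with the tenstom_1304r-1 filenames from the scripts (an illustration; the flags and filenames are taken from the scripts above):

    import subprocess

    def run_stage(args):
        # Run one Galaxy tool as a subprocess; raise on a non-zero exit code.
        subprocess.run(['python', *args, '--log_level', 'INFO'], check=True)

    run_stage(['tomo_reduce.py', '-i', 'tenstom_1304r-1.yaml',
               '-o', 'tenstom_1304r-1_reduce.nxs'])
    run_stage(['tomo_find_center.py', '-i', 'tenstom_1304r-1_reduce.nxs',
               '-o', 'tenstom_1304r-1_centers.yaml'])
    run_stage(['tomo_reconstruct.py', '-i', 'tenstom_1304r-1_reduce.nxs',
               '-c', 'tenstom_1304r-1_centers.yaml',
               '-o', 'tenstom_1304r-1_recon.nxs'])
    run_stage(['tomo_combine.py', '-i', 'tenstom_1304r-1_recon.nxs',
               '-o', 'tenstom_1304r-1_recon_combined.nxs'])
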
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/sobhani-3249-A.yaml	Fri Mar 10 16:02:04 2023 +0000
@@ -0,0 +1,40 @@
+sample_maps:
+- cycle: 2022-1
+  btr: sobhani-3249-A
+  title: sobhani-3249-A
+  station: id3b
+  sample:
+    name: tomo7C
+  detector:
+    prefix: andor2
+    rows: 1436
+    columns: 1700
+    pixel_size:
+    - 0.0065
+    - 0.0065
+    lens_magnification: 5.0
+  tomo_fields:
+    spec_file: /nfs/chess/scratch/user/rv43/2022-1/id3b/sobhani-3249-A/tomo7C/tomo7C
+    scan_numbers:
+    - 1
+    stack_info:
+    - scan_number: 1
+      starting_image_offset: 5
+      num_image: 360
+      ref_x: 0.0
+      ref_z: 6.1845
+      theta_range:
+        start: 0.0
+        end: 180.0
+        num: 360
+        start_index: 4
+  bright_field:
+    spec_file: /nfs/chess/scratch/user/rv43/2022-1/id3b/sobhani-3249-A/tomo7C/tomo7C_flat
+    scan_numbers:
+    - 1
+    stack_info:
+    - scan_number: 1
+      starting_image_offset: 1
+      num_image: 20
+      ref_x: 0.0
+      ref_z: 6.1845
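
Note: the theta_range block above specifies 360 projection angles over 0-180 degrees, with start_index marking the first usable image. A minimal sketch for reading the range back (it assumes the nesting shown above; excluding the endpoint is an assumption, not stated in the file):

    import numpy as np
    from yaml import safe_load

    with open('sobhani-3249-A.yaml') as f:
        sample_map = safe_load(f)['sample_maps'][0]

    theta_range = sample_map['tomo_fields']['stack_info'][0]['theta_range']
    # 360 angles from 0.0 to 180.0 degrees, endpoint excluded -> 0.5 degree steps
    thetas = np.linspace(theta_range['start'], theta_range['end'],
                         theta_range['num'], endpoint=False)
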
--- a/tomo.py Fri Aug 19 20:16:56 2022 +0000 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,2680 +0,0 @@ -#!/usr/bin/env python3 - -# -*- coding: utf-8 -*- -""" -Created on Fri Dec 10 09:54:37 2021 - -@author: rv43 -""" - -import logging - -import os -import sys -import getopt -import re -import io -import argparse -import numpy as np -try: - import numexpr as ne -except: - pass -import multiprocessing as mp -try: - import scipy.ndimage as spi -except: - pass -try: - import tomopy -except: - pass -from time import time -try: - from skimage.transform import iradon -except: - pass -try: - from skimage.restoration import denoise_tv_chambolle -except: - pass - -from detector import TomoDetectorConfig -from fit import Fit -from general import illegal_value, is_int, is_num, is_index_range, get_trailing_int, \ - input_int, input_num, input_yesno, input_menu, findImageFiles, loadImageStack, clearPlot, \ - draw_mask_1d, quickPlot, clearImshow, quickImshow, combine_tiffs_in_h5, Config -from general import selectImageRange, selectImageBounds - -# the following tomopy routines don't run with more than 24 cores on Galaxy-Dev -# - tomopy.find_center_vo -# - tomopy.prep.stripe.remove_stripe_fw -num_core_tomopy_limit = 24 - -class set_numexpr_threads: - - def __init__(self, num_core): - cpu_count = mp.cpu_count() - logging.debug(f'start: num_core={num_core} cpu_count={cpu_count}') - if num_core is None or num_core < 1 or num_core > cpu_count: - self.num_core = cpu_count - else: - self.num_core = num_core - logging.debug(f'self.num_core={self.num_core}') - - def __enter__(self): - self.num_core_org = ne.set_num_threads(self.num_core) - logging.debug(f'self.num_core={self.num_core}') - - def __exit__(self, exc_type, exc_value, traceback): - ne.set_num_threads(self.num_core_org) - -class ConfigTomo(Config): - """Class for processing a config file. - """ - - def __init__(self, config_file=None, config_dict=None): - super().__init__(config_file, config_dict) - - def _validate_txt(self): - """Returns False if any required config parameter is illegal or missing. 
- """ - is_valid = True - - # Check for required first-level keys - pars_required = ['tdf_data_path', 'tbf_data_path', 'detector_id'] - pars_missing = [] - is_valid = super().validate(pars_required, pars_missing) - if len(pars_missing) > 0: - logging.error(f'Missing item(s) in config file: {", ".join(pars_missing)}') - self.detector_id = self.config.get('detector_id') - - # Find tomography dark field images file/folder - self.tdf_data_path = self.config.get('tdf_data_path') - - # Find tomography bright field images file/folder - self.tbf_data_path = self.config.get('tbf_data_path') - - # Check number of tomography image stacks - self.num_tomo_stacks = self.config.get('num_tomo_stacks', 1) - if not is_int(self.num_tomo_stacks, 1): - self.num_tomo_stacks = None - illegal_value(self.num_tomo_stacks, 'num_tomo_stacks', 'config file') - return False - logging.info(f'num_tomo_stacks = {self.num_tomo_stacks}') - - # Find tomography images file/folders and stack parameters - tomo_data_paths_indices = sorted({key:value for key,value in self.config.items() - if 'tomo_data_path' in key}.items()) - if len(tomo_data_paths_indices) != self.num_tomo_stacks: - logging.error(f'Incorrect number of tomography data path names in config file') - is_valid = False - self.tomo_data_paths = [tomo_data_paths_indices[i][1] for i in range(self.num_tomo_stacks)] - self.tomo_data_indices = [get_trailing_int(tomo_data_paths_indices[i][0]) - if get_trailing_int(tomo_data_paths_indices[i][0]) else None - for i in range(self.num_tomo_stacks)] - tomo_ref_height_indices = sorted({key:value for key,value in self.config.items() - if 'z_pos' in key}.items()) - if self.num_tomo_stacks > 1 and len(tomo_ref_height_indices) != self.num_tomo_stacks: - logging.error(f'Incorrect number of tomography reference heights in config file') - is_valid = False - if len(tomo_ref_height_indices): - self.tomo_ref_heights = [ - tomo_ref_height_indices[i][1] for i in range(self.num_tomo_stacks)] - else: - self.tomo_ref_heights = [0.0]*self.num_tomo_stacks - - # Check tomo angle (theta) range - self.start_theta = self.config.get('start_theta', 0.) - if not is_num(self.start_theta, 0.): - illegal_value(self.start_theta, 'start_theta', 'config file') - is_valid = False - logging.debug(f'start_theta = {self.start_theta}') - self.end_theta = self.config.get('end_theta', 180.) - if not is_num(self.end_theta, self.start_theta): - illegal_value(self.end_theta, 'end_theta', 'config file') - is_valid = False - logging.debug(f'end_theta = {self.end_theta}') - self.num_thetas = self.config.get('num_thetas') - if not (self.num_thetas is None or is_int(self.num_thetas, 1)): - illegal_value(self.num_thetas, 'num_thetas', 'config file') - self.num_thetas = None - is_valid = False - logging.debug(f'num_thetas = {self.num_thetas}') - - return is_valid - - def _validate_yaml(self): - """Returns False if any required config parameter is illegal or missing. 
- """ - is_valid = True - - # Check for required first-level keys - pars_required = ['dark_field', 'bright_field', 'stack_info', 'detector'] - pars_missing = [] - is_valid = super().validate(pars_required, pars_missing) - if len(pars_missing) > 0: - logging.error(f'Missing item(s) in config file: {", ".join(pars_missing)}') - self.detector_id = self.config['detector'].get('id') - - # Find tomography dark field images file/folder - self.tdf_data_path = self.config['dark_field'].get('data_path') - - # Find tomography bright field images file/folder - self.tbf_data_path = self.config['bright_field'].get('data_path') - - # Check number of tomography image stacks - stack_info = self.config['stack_info'] - self.num_tomo_stacks = stack_info.get('num', 1) - if not is_int(self.num_tomo_stacks, 1): - self.num_tomo_stacks = None - illegal_value(self.num_tomo_stacks, 'num_tomo_stacks', 'config file') - return False - logging.info(f'num_tomo_stacks = {self.num_tomo_stacks}') - - # Find tomography images file/folders and stack parameters - stacks = stack_info.get('stacks') - if stacks is None or len(stacks) is not self.num_tomo_stacks: - illegal_value(stacks, 'stacks', 'config file') - return False - self.tomo_data_paths = [] - self.tomo_data_indices = [] - self.tomo_ref_heights = [] - for stack in stacks: - self.tomo_data_paths.append(stack.get('data_path')) - self.tomo_data_indices.append(stack.get('index')) - self.tomo_ref_heights.append(stack.get('ref_height')) - - # Check tomo angle (theta) range - theta_range = self.config.get('theta_range') - if theta_range is None: - self.start_theta = 0. - self.end_theta = 180. - self.num_thetas = None - else: - self.start_theta = theta_range.get('start', 0.) - if not is_num(self.start_theta, 0.): - illegal_value(self.start_theta, 'theta_range:start', 'config file') - is_valid = False - logging.debug(f'start_theta = {self.start_theta}') - self.end_theta = theta_range.get('end', 180.) - if not is_num(self.end_theta, self.start_theta): - illegal_value(self.end_theta, 'theta_range:end', 'config file') - is_valid = False - logging.debug(f'end_theta = {self.end_theta}') - self.num_thetas = theta_range.get('num') - if self.num_thetas and not is_int(self.num_thetas, 1): - illegal_value(self.num_thetas, 'theta_range:num', 'config file') - self.num_thetas = None - is_valid = False - logging.debug(f'num_thetas = {self.num_thetas}') - - return is_valid - - def validate(self): - """Returns False if any required config parameter is illegal or missing. 
- """ - is_valid = True - - # Check work_folder (shared by both file formats) - work_folder = os.path.abspath(self.config.get('work_folder', '')) - if not os.path.isdir(work_folder): - illegal_value(work_folder, 'work_folder', 'config file') - is_valid = False - logging.info(f'work_folder: {work_folder}') - - # Check data filetype (shared by both file formats) - self.data_filetype = self.config.get('data_filetype', 'tif') - if not isinstance(self.data_filetype, str) or (self.data_filetype != 'tif' and - self.data_filetype != 'h5'): - illegal_value(self.data_filetype, 'data_filetype', 'config file') - - if self.suffix == '.yml' or self.suffix == '.yaml': - is_valid = self._validate_yaml() - elif self.suffix == '.txt': - is_valid = self._validate_txt() - else: - logging.error(f'Undefined or illegal config file extension: {self.suffix}') - is_valid = False - - # Find tomography bright field images file/folder - if self.tdf_data_path: - if self.data_filetype == 'h5': - if isinstance(self.tdf_data_path, str): - if not os.path.isabs(self.tdf_data_path): - self.tdf_data_path = os.path.abspath( - f'{work_folder}/{self.tdf_data_path}') - else: - illegal_value(tdf_data_fil, 'tdf_data_path', 'config file') - is_valid = False - else: - if isinstance(self.tdf_data_path, int): - self.tdf_data_path = os.path.abspath( - f'{work_folder}/{self.tdf_data_path}/nf') - elif isinstance(self.tdf_data_path, str): - if not os.path.isabs(self.tdf_data_path): - self.tdf_data_path = os.path.abspath( - f'{work_folder}/{self.tdf_data_path}') - else: - illegal_value(self.tdf_data_path, 'tdf_data_path', 'config file') - is_valid = False - logging.info(f'dark field images path = {self.tdf_data_path}') - - # Find tomography bright field images file/folder - if self.tbf_data_path: - if self.data_filetype == 'h5': - if isinstance(self.tbf_data_path, str): - if not os.path.isabs(self.tbf_data_path): - self.tbf_data_path = os.path.abspath( - f'{work_folder}/{self.tbf_data_path}') - else: - illegal_value(tbf_data_fil, 'tbf_data_path', 'config file') - is_valid = False - else: - if isinstance(self.tbf_data_path, int): - self.tbf_data_path = os.path.abspath( - f'{work_folder}/{self.tbf_data_path}/nf') - elif isinstance(self.tbf_data_path, str): - if not os.path.isabs(self.tbf_data_path): - self.tbf_data_path = os.path.abspath( - f'{work_folder}/{self.tbf_data_path}') - else: - illegal_value(self.tbf_data_path, 'tbf_data_path', 'config file') - is_valid = False - logging.info(f'bright field images path = {self.tbf_data_path}') - - # Find tomography images file/folders and stack parameters - tomo_data_paths = [] - tomo_data_indices = [] - tomo_ref_heights = [] - for data_path, index, ref_height in zip(self.tomo_data_paths, self.tomo_data_indices, - self.tomo_ref_heights): - if self.data_filetype == 'h5': - if isinstance(data_path, str): - if not os.path.isabs(data_path): - data_path = os.path.abspath(f'{work_folder}/{data_path}') - else: - illegal_value(data_path, 'stack_info:stacks:data_path', 'config file') - is_valid = False - data_path = None - else: - if isinstance(data_path, int): - data_path = os.path.abspath(f'{work_folder}/{data_path}/nf') - elif isinstance(data_path, str): - if not os.path.isabs(data_path): - data_path = os.path.abspath(f'{work_folder}/{data_path}') - else: - illegal_value(data_path, 'stack_info:stacks:data_path', 'config file') - is_valid = False - data_path = None - tomo_data_paths.append(data_path) - if index is None: - if self.num_tomo_stacks > 1: - logging.error('Missing stack_info:stacks:index in 
config file') - is_valid = False - index = None - else: - index = 1 - elif not isinstance(index, int): - illegal_value(index, 'stack_info:stacks:index', 'config file') - is_valid = False - index = None - tomo_data_indices.append(index) - if ref_height is None: - if self.num_tomo_stacks > 1: - logging.error('Missing stack_info:stacks:ref_height in config file') - is_valid = False - ref_height = None - else: - ref_height = 0. - elif not is_num(ref_height): - illegal_value(ref_height, 'stack_info:stacks:ref_height', 'config file') - is_valid = False - ref_height = None - # Set reference heights relative to first stack - if (len(tomo_ref_heights) and is_num(ref_height) and - is_num(tomo_ref_heights[0])): - ref_height = (round(ref_height-tomo_ref_heights[0], 3)) - tomo_ref_heights.append(ref_height) - tomo_ref_heights[0] = 0.0 - logging.info('tomography data paths:') - for i in range(self.num_tomo_stacks): - logging.info(f' {tomo_data_paths[i]}') - logging.info(f'tomography data path indices: {tomo_data_indices}') - logging.info(f'tomography reference heights: {tomo_ref_heights}') - - # Update config in memory - if self.suffix == '.txt': - self.config = {} - dark_field = self.config.get('dark_field') - if dark_field is None: - self.config['dark_field'] = {'data_path' : self.tdf_data_path} - else: - self.config['dark_field']['data_path'] = self.tdf_data_path - bright_field = self.config.get('bright_field') - if bright_field is None: - self.config['bright_field'] = {'data_path' : self.tbf_data_path} - else: - self.config['bright_field']['data_path'] = self.tbf_data_path - detector = self.config.get('detector') - if detector is None: - self.config['detector'] = {'id' : self.detector_id} - else: - detector['id'] = self.detector_id - self.config['work_folder'] = work_folder - self.config['data_filetype'] = self.data_filetype - stack_info = self.config.get('stack_info') - if stack_info is None: - stacks = [] - for i in range(self.num_tomo_stacks): - stacks.append({'data_path' : tomo_data_paths[i], 'index' : tomo_data_indices[i], - 'ref_height' : tomo_ref_heights[i]}) - self.config['stack_info'] = {'num' : self.num_tomo_stacks, 'stacks' : stacks} - else: - stack_info['num'] = self.num_tomo_stacks - stacks = stack_info.get('stacks') - for i,stack in enumerate(stacks): - stack['data_path'] = tomo_data_paths[i] - stack['index'] = tomo_data_indices[i] - stack['ref_height'] = tomo_ref_heights[i] - if self.num_thetas: - theta_range = {'start' : self.start_theta, 'end' : self.end_theta, - 'num' : self.num_thetas} - else: - theta_range = {'start' : self.start_theta, 'end' : self.end_theta} - self.config['theta_range'] = theta_range - - # Cleanup temporary validation variables - del self.tdf_data_path - del self.tbf_data_path - del self.detector_id - del self.data_filetype - del self.num_tomo_stacks - del self.tomo_data_paths - del self.tomo_data_indices - del self.tomo_ref_heights - del self.start_theta - del self.end_theta - del self.num_thetas - - return is_valid - -class Tomo: - """Processing tomography data with misalignment. - """ - - def __init__(self, config_file=None, config_dict=None, config_out=None, output_folder='.', - log_level='INFO', log_stream='tomo.log', galaxy_flag=False, test_mode=False, - num_core=-1): - """Initialize with optional config input file or dictionary - """ - self.num_core = None - self.config_out = config_out - self.output_folder = output_folder - self.galaxy_flag = galaxy_flag - self.test_mode = test_mode - self.save_plots = True # Make input argument? 
- self.save_plots_only = True # Make input argument? - self.cf = None - self.config = None - self.is_valid = True - self.tdf = np.array([]) - self.tbf = np.array([]) - self.tomo_stacks = [] - self.tomo_recon_stacks = [] - - # Validate input parameters - if config_file is not None and not os.path.isfile(config_file): - raise OSError(f'Invalid config_file input {config_file} {type(config_file)}') - if config_dict is not None and not isinstance(config_dict, dict): - raise ValueError(f'Invalid config_dict input {config_dict} {type(config_dict)}') - if self.config_out is not None and not isinstance(self.config_out, str): - raise OSError(f'Invalid config_out input {self.config_out} {type(self.config_out)}') - if not os.path.isdir(output_folder): - os.mkdir(os.path.abspath(output_folder)) - if isinstance(log_stream, str): - path = os.path.split(log_stream)[0] - if path and not os.path.isdir(path): - raise OSError(f'Invalid log_stream path') - if not os.path.isabs(path): - log_stream = f'{output_folder}/{log_stream}' - if not isinstance(galaxy_flag, bool): - raise ValueError(f'Invalid galaxy_flag input {galaxy_flag} {type(galaxy_flag)}') - if not isinstance(self.test_mode, bool): - raise ValueError(f'Invalid test_mode input {self.test_mode} {type(self.test_mode)}') - if not isinstance(num_core, int) or num_core < -1 or num_core == 0: - raise ValueError(f'Invalid num_core input {num_core} {type(num_core)}') - if num_core == -1: - self.num_core = mp.cpu_count() - else: - self.num_core = num_core - - # Set log configuration - logging_format = '%(asctime)s : %(levelname)s - %(module)s : %(funcName)s - %(message)s' - if self.test_mode: - self.save_plots_only = True - if isinstance(log_stream, str): - logging.basicConfig(filename=f'{log_stream}', filemode='w', - format=logging_format, level=logging.INFO, force=True) - #format=logging_format, level=logging.WARNING, force=True) - elif isinstance(log_stream, io.TextIOWrapper): - #logging.basicConfig(filemode='w', format=logging_format, level=logging.WARNING, - logging.basicConfig(filemode='w', format=logging_format, level=logging.INFO, - stream=log_stream, force=True) - else: - raise ValueError(f'Invalid log_stream: {log_stream}') - logging.warning('Ignoring log_level argument in test mode') - else: - level = getattr(logging, log_level.upper(), None) - if not isinstance(level, int): - raise ValueError(f'Invalid log_level: {log_level}') - if log_stream is sys.stdout: - logging.basicConfig(format=logging_format, level=level, force=True, - handlers=[logging.StreamHandler()]) - else: - if isinstance(log_stream, str): - logging.basicConfig(filename=f'{log_stream}', filemode='w', - format=logging_format, level=level, force=True) - elif isinstance(log_stream, io.TextIOWrapper): - logging.basicConfig(filemode='w', format=logging_format, level=level, - stream=log_stream, force=True) - else: - raise ValueError(f'Invalid log_stream: {log_stream}') - stream_handler = logging.StreamHandler() - logging.getLogger().addHandler(stream_handler) - stream_handler.setLevel(logging.WARNING) - stream_handler.setFormatter(logging.Formatter(logging_format)) - - # Check/set output config file name - if self.config_out is None: - self.config_out = f'{self.output_folder}/config.yaml' - elif (self.config_out is os.path.basename(self.config_out) and - not os.path.isabs(self.config_out)): - self.config_out = f'{self.output_folder}/{self.config_out}' - - # Create config object and load config file - self.cf = ConfigTomo(config_file, config_dict) - if not self.cf.load_flag: - 
self.is_valid = False - return - - if self.galaxy_flag: - assert(self.output_folder == '.') - assert(self.test_mode is False) - self.save_plots = True - self.save_plots_only = True - else: - # Input validation is already performed during link_data_to_galaxy - - # Check config file parameters - self.is_valid = self.cf.validate() - - # Load detector info file - df = TomoDetectorConfig(self.cf.config['detector']['id']) - - # Check detector info file parameters - if df.valid: - pixel_size = df.pixel_size - num_rows, num_columns = df.dimensions - if not pixel_size or not num_rows or not num_columns: - self.is_valid = False - else: - pixel_size = None - num_rows = None - num_columns = None - self.is_valid = False - - # Update config - self.cf.config['detector']['pixel_size'] = pixel_size - self.cf.config['detector']['rows'] = num_rows - self.cf.config['detector']['columns'] = num_columns - logging.debug(f'pixel_size = self.cf.config["detector"]["pixel_size"]') - logging.debug(f'num_rows: {self.cf.config["detector"]["rows"]}') - logging.debug(f'num_columns: {self.cf.config["detector"]["columns"]}') - - # Safe config to file - if self.is_valid: - self.cf.saveFile(self.config_out) - - # Initialize shortcut to config - self.config = self.cf.config - - # Initialize tomography stack - num_tomo_stacks = self.config['stack_info']['num'] - if num_tomo_stacks: - self.tomo_stacks = [np.array([]) for _ in range(num_tomo_stacks)] - self.tomo_recon_stacks = [np.array([]) for _ in range(num_tomo_stacks)] - - logging.debug(f'num_core = {self.num_core}') - logging.debug(f'config_file = {config_file}') - logging.debug(f'config_dict = {config_dict}') - logging.debug(f'config_out = {self.config_out}') - logging.debug(f'output_folder = {self.output_folder}') - logging.debug(f'log_stream = {log_stream}') - logging.debug(f'log_level = {log_level}') - logging.debug(f'galaxy_flag = {self.galaxy_flag}') - logging.debug(f'test_mode = {self.test_mode}') - logging.debug(f'save_plots = {self.save_plots}') - logging.debug(f'save_plots_only = {self.save_plots_only}') - - def _selectImageRanges(self, available_stacks=None): - """Select image files to be included in analysis. 
- """ - self.is_valid = True - stack_info = self.config['stack_info'] - if available_stacks is None: - available_stacks = [False]*stack_info['num'] - elif len(available_stacks) != stack_info['num']: - logging.warning('Illegal dimension of available_stacks in getImageFiles '+ - f'({len(available_stacks)}'); - available_stacks = [False]*stack_info['num'] - - # Check number of tomography angles/thetas - num_thetas = self.config['theta_range'].get('num') - if num_thetas is None: - num_thetas = input_int('\nEnter the number of thetas', 1) - elif not is_int(num_thetas, 0): - illegal_value(num_thetas, 'num_thetas', 'config file') - self.is_valid = False - return - self.config['theta_range']['num'] = num_thetas - logging.debug(f'num_thetas = {self.config["theta_range"]["num"]}') - - # Find tomography dark field images - dark_field = self.config['dark_field'] - img_start = dark_field.get('img_start', -1) - img_offset = dark_field.get('img_offset', -1) - num_imgs = dark_field.get('num', 0) - if not self.test_mode: - img_start, img_offset, num_imgs = selectImageRange(img_start, img_offset, - num_imgs, 'dark field') - if img_start < 0 or num_imgs < 1: - logging.error('Unable to find suitable dark field images') - if dark_field['data_path']: - self.is_valid = False - dark_field['img_start'] = img_start - dark_field['img_offset'] = img_offset - dark_field['num'] = num_imgs - logging.debug(f'Dark field image start index: {dark_field["img_start"]}') - logging.debug(f'Dark field image offset: {dark_field["img_offset"]}') - logging.debug(f'Number of dark field images: {dark_field["num"]}') - - # Find tomography bright field images - bright_field = self.config['bright_field'] - img_start = bright_field.get('img_start', -1) - img_offset = bright_field.get('img_offset', -1) - num_imgs = bright_field.get('num', 0) - if not self.test_mode: - img_start, img_offset, num_imgs = selectImageRange(img_start, img_offset, - num_imgs, 'bright field') - if img_start < 0 or num_imgs < 1: - logging.error('Unable to find suitable bright field images') - self.is_valid = False - bright_field['img_start'] = img_start - bright_field['img_offset'] = img_offset - bright_field['num'] = num_imgs - logging.debug(f'Bright field image start index: {bright_field["img_start"]}') - logging.debug(f'Bright field image offset: {bright_field["img_offset"]}') - logging.debug(f'Number of bright field images: {bright_field["num"]}') - - # Find tomography images - for i,stack in enumerate(stack_info['stacks']): - # Check if stack is already loaded or available - if self.tomo_stacks[i].size or available_stacks[i]: - continue - index = stack['index'] - img_start = stack.get('img_start', -1) - img_offset = stack.get('img_offset', -1) - num_imgs = stack.get('num', 0) - if not self.test_mode: - img_start, img_offset, num_imgs = selectImageRange(img_start, img_offset, - num_imgs, f'tomography stack {index}', num_thetas) - if img_start < 0 or num_imgs != num_thetas: - logging.error('Unable to find suitable tomography images') - self.is_valid = False - stack['img_start'] = img_start - stack['img_offset'] = img_offset - stack['num'] = num_imgs - logging.debug(f'Tomography stack {index} image start index: {stack["img_start"]}') - logging.debug(f'Tomography stack {index} image offset: {stack["img_offset"]}') - logging.debug(f'Number of tomography images for stack {index}: {stack["num"]}') - - # Safe updated config to file - if self.is_valid: - self.cf.saveFile(self.config_out) - - return - - def _genDark(self, tdf_files): - """Generate dark field. 
- """ - # Load the dark field images - logging.debug('Loading dark field...') - dark_field = self.config['dark_field'] - tdf_stack = loadImageStack(tdf_files, self.config['data_filetype'], - dark_field['img_offset'], dark_field['num']) - - # Take median - self.tdf = np.median(tdf_stack, axis=0) - del tdf_stack - - # Remove dark field intensities above the cutoff - tdf_cutoff = dark_field.get('cutoff') - if tdf_cutoff is not None: - if not is_num(tdf_cutoff, 0): - logging.warning(f'Ignoring illegal value of tdf_cutoff {tdf_cutoff}') - else: - self.tdf[self.tdf > tdf_cutoff] = np.nan - logging.debug(f'tdf_cutoff = {tdf_cutoff}') - - tdf_mean = np.nanmean(self.tdf) - logging.debug(f'tdf_mean = {tdf_mean}') - np.nan_to_num(self.tdf, copy=False, nan=tdf_mean, posinf=tdf_mean, neginf=0.) - if self.galaxy_flag: - quickImshow(self.tdf, title='dark field', path='setup_pngs', - save_fig=True, save_only=True) - elif not self.test_mode: - quickImshow(self.tdf, title='dark field', path=self.output_folder, - save_fig=self.save_plots, save_only=self.save_plots_only) - - def _genBright(self, tbf_files): - """Generate bright field. - """ - # Load the bright field images - logging.debug('Loading bright field...') - bright_field = self.config['bright_field'] - tbf_stack = loadImageStack(tbf_files, self.config['data_filetype'], - bright_field['img_offset'], bright_field['num']) - - # Take median - """Median or mean: It may be best to try the median because of some image - artifacts that arise due to crinkles in the upstream kapton tape windows - causing some phase contrast images to appear on the detector. - One thing that also may be useful in a future implementation is to do a - brightfield adjustment on EACH frame of the tomo based on a ROI in the - corner of the frame where there is no sample but there is the direct X-ray - beam because there is frame to frame fluctuations from the incoming beam. - We don’t typically account for them but potentially could. - """ - self.tbf = np.median(tbf_stack, axis=0) - del tbf_stack - - # Subtract dark field - if self.tdf.size: - self.tbf -= self.tdf - else: - logging.warning('Dark field unavailable') - if self.galaxy_flag: - quickImshow(self.tbf, title='bright field', path='setup_pngs', - save_fig=True, save_only=True) - elif not self.test_mode: - quickImshow(self.tbf, title='bright field', path=self.output_folder, - save_fig=self.save_plots, save_only=self.save_plots_only) - - def _setDetectorBounds(self, tomo_stack_files): - """Set vertical detector bounds for image stack. 
- """ - preprocess = self.config.get('preprocess') - if preprocess is None: - img_x_bounds = [None, None] - else: - img_x_bounds = preprocess.get('img_x_bounds', [0, self.tbf.shape[0]]) - if img_x_bounds[0] is not None and img_x_bounds[1] is not None: - if img_x_bounds[0] < 0: - illegal_value(img_x_bounds[0], 'preprocess:img_x_bounds[0]', 'config file') - img_x_bounds[0] = 0 - if not is_index_range(img_x_bounds, 0, self.tbf.shape[0]): - illegal_value(img_x_bounds[1], 'preprocess:img_x_bounds[1]', 'config file') - img_x_bounds[1] = self.tbf.shape[0] - if self.test_mode: - # Update config and save to file - if preprocess is None: - self.cf.config['preprocess'] = {'img_x_bounds' : [0, self.tbf.shape[0]]} - else: - preprocess['img_x_bounds'] = img_x_bounds - self.cf.saveFile(self.config_out) - return - - # Check reference heights - pixel_size = self.config['detector']['pixel_size'] - if pixel_size is None: - raise ValueError('Detector pixel size unavailable') - if not self.tbf.size: - raise ValueError('Bright field unavailable') - num_x_min = None - num_tomo_stacks = self.config['stack_info']['num'] - stacks = self.config['stack_info']['stacks'] - if num_tomo_stacks > 1: - delta_z = stacks[1]['ref_height']-stacks[0]['ref_height'] - for i in range(2, num_tomo_stacks): - delta_z = min(delta_z, stacks[i]['ref_height']-stacks[i-1]['ref_height']) - logging.debug(f'delta_z = {delta_z}') - num_x_min = int((delta_z-0.5*pixel_size)/pixel_size) - logging.debug(f'num_x_min = {num_x_min}') - if num_x_min > self.tbf.shape[0]: - logging.warning('Image bounds and pixel size prevent seamless stacking') - num_x_min = None - - # Select image bounds - if self.galaxy_flag: - x_sum = np.sum(self.tbf, 1) - x_sum_min = x_sum.min() - x_sum_max = x_sum.max() - x_low = 0 - x_upp = x_sum.size - if num_x_min is not None: - fit = Fit.fit_data(np.array(range(len(x_sum))), x_sum, models='rectangle', - form='atan', guess=True) - parameters = fit.best_values - x_low = parameters.get('center1', None) - x_upp = parameters.get('center2', None) - sig_low = parameters.get('sigma1', None) - sig_upp = parameters.get('sigma2', None) - if (x_low is not None and x_upp is not None and sig_low is not None and - sig_upp is not None and 0 <= x_low < x_upp <= x_sum.size and - (sig_low+sig_upp)/(x_upp-x_low) < 0.1): - if num_tomo_stacks == 1 or num_x_min is None: - x_low = int(x_low-(x_upp-x_low)/10) - x_upp = int(x_upp+(x_upp-x_low)/10) - else: - x_low = int((x_low+x_upp)/2-num_x_min/2) - x_upp = x_low+num_x_min - if x_low < 0: - x_low = 0 - if x_upp > x_sum.size: - x_upp = x_sum.size - else: - x_low = 0 - x_upp = x_sum.size - quickPlot((range(x_sum.size), x_sum), - ([x_low, x_low], [x_sum_min, x_sum_max], 'r-'), - ([x_upp-1, x_upp-1], [x_sum_min, x_sum_max], 'r-'), - title=f'sum bright field over theta/y (row bounds: [{x_low}, {x_upp}])', - path='setup_pngs', name='detectorbounds.png', save_fig=True, save_only=True, - show_grid=True) - for i,stack in enumerate(stacks): - tomo_stack = loadImageStack(tomo_stack_files[i], self.config['data_filetype'], - stack['img_offset'], 1) - tomo_stack = tomo_stack[0,:,:] - if num_x_min is not None: - tomo_stack_max = tomo_stack.max() - tomo_stack[x_low,:] = tomo_stack_max - tomo_stack[x_upp-1,:] = tomo_stack_max - title = f'tomography image at theta={self.config["theta_range"]["start"]}' - quickImshow(tomo_stack, title=title, path='setup_pngs', - name=f'tomo_{stack["index"]}.png', save_fig=True, save_only=True, - show_grid=True) - del tomo_stack - - # Update config and save to file - img_x_bounds = 
[x_low, x_upp] - logging.debug(f'img_x_bounds: {img_x_bounds}') - if preprocess is None: - self.cf.config['preprocess'] = {'img_x_bounds' : img_x_bounds} - else: - preprocess['img_x_bounds'] = img_x_bounds - self.cf.saveFile(self.config_out) - del x_sum - return - - # For one tomography stack only: load the first image - title = None - quickImshow(self.tbf, title='bright field') - if num_tomo_stacks == 1: - tomo_stack = loadImageStack(tomo_stack_files[0], self.config['data_filetype'], - stacks[0]['img_offset'], 1) - title = f'tomography image at theta={self.config["theta_range"]["start"]}' - quickImshow(tomo_stack[0,:,:], title=title) - tomo_or_bright = input_menu(['bright field', 'first tomography image'], - header='\nSelect image bounds from') - else: - print('\nSelect image bounds from bright field') - tomo_or_bright = 0 - if tomo_or_bright: - x_sum = np.sum(tomo_stack[0,:,:], 1) - use_bounds = False - if img_x_bounds[0] is not None and img_x_bounds[1] is not None: - tmp = np.copy(tomo_stack[0,:,:]) - tmp_max = tmp.max() - tmp[img_x_bounds[0],:] = tmp_max - tmp[img_x_bounds[1]-1,:] = tmp_max - title = f'tomography image at theta={self.config["theta_range"]["start"]}' - quickImshow(tmp, title=title) - del tmp - x_sum_min = x_sum.min() - x_sum_max = x_sum.max() - quickPlot((range(x_sum.size), x_sum), - ([img_x_bounds[0], img_x_bounds[0]], [x_sum_min, x_sum_max], 'r-'), - ([img_x_bounds[1]-1, img_x_bounds[1]-1], [x_sum_min, x_sum_max], 'r-'), - title='sum over theta and y') - print(f'lower bound = {img_x_bounds[0]} (inclusive)\n'+ - f'upper bound = {img_x_bounds[1]} (exclusive)]') - use_bounds = input_yesno('Accept these bounds (y/n)?', 'y') - if not use_bounds: - img_x_bounds = list(selectImageBounds(tomo_stack[0,:,:], 0, - img_x_bounds[0], img_x_bounds[1], num_x_min, - f'tomography image at theta={self.config["theta_range"]["start"]}')) - if num_x_min is not None and img_x_bounds[1]-img_x_bounds[0]+1 < num_x_min: - logging.warning('Image bounds and pixel size prevent seamless stacking') - title = f'tomography image at theta={self.config["theta_range"]["start"]}' - quickImshow(tomo_stack[0,:,:], title=title, path=self.output_folder, - save_fig=self.save_plots, save_only=True) - quickPlot(range(img_x_bounds[0], img_x_bounds[1]), - x_sum[img_x_bounds[0]:img_x_bounds[1]], - title='sum over theta and y', path=self.output_folder, - save_fig=self.save_plots, save_only=True) - else: - x_sum = np.sum(self.tbf, 1) - x_sum_min = x_sum.min() - x_sum_max = x_sum.max() - use_bounds = False - if img_x_bounds[0] is not None and img_x_bounds[1] is not None: - tmp = np.copy(self.tbf) - tmp_max = tmp.max() - tmp[img_x_bounds[0],:] = tmp_max - tmp[img_x_bounds[1]-1,:] = tmp_max - title = 'bright field' - quickImshow(tmp, title=title) - del tmp - quickPlot((range(x_sum.size), x_sum), - ([img_x_bounds[0], img_x_bounds[0]], [x_sum_min, x_sum_max], 'r-'), - ([img_x_bounds[1]-1, img_x_bounds[1]-1], [x_sum_min, x_sum_max], 'r-'), - title='sum over theta and y') - print(f'lower bound = {img_x_bounds[0]} (inclusive)\n'+ - f'upper bound = {img_x_bounds[1]} (exclusive)]') - use_bounds = input_yesno('Accept these bounds (y/n)?', 'y') - if not use_bounds: - use_fit = False - fit = Fit.fit_data(np.array(range(len(x_sum))), x_sum, models='rectangle', - form='atan', guess=True) - parameters = fit.best_values - x_low = parameters.get('center1', None) - x_upp = parameters.get('center2', None) - sig_low = parameters.get('sigma1', None) - sig_upp = parameters.get('sigma2', None) - if (x_low is not None and x_upp is not None 
and sig_low is not None and - sig_upp is not None and 0 <= x_low < x_upp <= x_sum.size and - (sig_low+sig_upp)/(x_upp-x_low) < 0.1): - if num_tomo_stacks == 1 or num_x_min is None: - x_low = int(x_low-(x_upp-x_low)/10) - x_upp = int(x_upp+(x_upp-x_low)/10) - else: - x_low = int((x_low+x_upp)/2-num_x_min/2) - x_upp = x_low+num_x_min - if x_low < 0: - x_low = 0 - if x_upp > x_sum.size: - x_upp = x_sum.size - tmp = np.copy(self.tbf) - tmp_max = tmp.max() - tmp[x_low,:] = tmp_max - tmp[x_upp-1,:] = tmp_max - title = 'bright field' - quickImshow(tmp, title=title) - del tmp - quickPlot((range(x_sum.size), x_sum), - ([x_low, x_low], [x_sum_min, x_sum_max], 'r-'), - ([x_upp, x_upp], [x_sum_min, x_sum_max], 'r-'), - title='sum over theta and y') - print(f'lower bound = {x_low} (inclusive)') - print(f'upper bound = {x_upp} (exclusive)]') - use_fit = input_yesno('Accept these bounds (y/n)?', 'y') - if use_fit: - img_x_bounds = [x_low, x_upp] - else: - accept = False - while not accept: - mask, img_x_bounds = draw_mask_1d(x_sum, title='select x data range', - legend='sum over theta and y') - print(f'img_x_bounds = {img_x_bounds}') - while (len(img_x_bounds) != 1 or (len(x_sum) >= num_x_min and - img_x_bounds[0][1]-img_x_bounds[0][0]+1 < num_x_min)): - exit('Should not be here') - print('Please select exactly one continuous range') - mask, img_x_bounds = draw_mask_1d(x_sum, title='select x data range', - legend='sum over theta and y') - img_x_bounds = list(img_x_bounds[0]) - quickPlot(x_sum, vlines=img_x_bounds, title='sum over theta and y') - print(f'img_x_bounds = {img_x_bounds} (lower bound inclusive, upper bound '+ - 'exclusive)') - accept = input_yesno('Accept these bounds (y/n)?', 'y') - if not accept: - img_x_bounds = None - if num_x_min is not None and img_x_bounds[1]-img_x_bounds[0]+1 < num_x_min: - logging.warning('Image bounds and pixel size prevent seamless stacking') - #quickPlot(range(img_x_bounds[0], img_x_bounds[1]), - # x_sum[img_x_bounds[0]:img_x_bounds[1]], - # title='sum over theta and y', path=self.output_folder, - # save_fig=self.save_plots, save_only=True) - quickPlot((range(x_sum.size), x_sum), - ([img_x_bounds[0], img_x_bounds[0]], [x_sum_min, x_sum_max], 'r-'), - ([img_x_bounds[1], img_x_bounds[1]], [x_sum_min, x_sum_max], 'r-'), - title='sum over theta and y', path=self.output_folder, - save_fig=self.save_plots, save_only=True) - del x_sum - for i,stack in enumerate(stacks): - tomo_stack = loadImageStack(tomo_stack_files[i], self.config['data_filetype'], - stack['img_offset'], 1) - tomo_stack = tomo_stack[0,:,:] - if num_x_min is not None: - tomo_stack_max = tomo_stack.max() - tomo_stack[img_x_bounds[0],:] = tomo_stack_max - tomo_stack[img_x_bounds[1]-1,:] = tomo_stack_max - title = f'tomography image at theta={self.config["theta_range"]["start"]}' - if self.galaxy_flag: - quickImshow(tomo_stack, title=title, path='setup_pngs', - name=f'tomo_{stack["index"]}.png', save_fig=True, save_only=True, - show_grid=True) - else: - quickImshow(tomo_stack, title=title, path=self.output_folder, - name=f'tomo_{stack["index"]}.png', save_fig=self.save_plots, - save_only=True, show_grid=True) - del tomo_stack - logging.debug(f'img_x_bounds: {img_x_bounds}') - - if self.save_plots_only: - clearImshow('bright field') - clearPlot('sum over theta and y') - if title: - clearPlot(title) - - # Update config and save to file - if preprocess is None: - self.cf.config['preprocess'] = {'img_x_bounds' : img_x_bounds} - else: - preprocess['img_x_bounds'] = img_x_bounds - 
self.cf.saveFile(self.config_out)
-
-    def _setZoomOrSkip(self):
-        """Set zoom and/or theta skip to reduce the memory requirement for the analysis.
-        """
-        preprocess = self.config.get('preprocess')
-        zoom_perc = 100
-        if not self.galaxy_flag:
-            if preprocess is None or 'zoom_perc' not in preprocess:
-                if input_yesno('\nDo you want to zoom in to reduce memory requirement (y/n)?',
-                        'n'):
-                    zoom_perc = input_int(' Enter zoom percentage', 1, 100)
-            else:
-                zoom_perc = preprocess['zoom_perc']
-                if is_num(zoom_perc, 1., 100.):
-                    zoom_perc = int(zoom_perc)
-                else:
-                    illegal_value(zoom_perc, 'preprocess:zoom_perc', 'config file')
-                    zoom_perc = 100
-        num_theta_skip = 0
-        if not self.galaxy_flag:
-            if preprocess is None or 'num_theta_skip' not in preprocess:
-                if input_yesno('Do you want to skip thetas to reduce memory requirement (y/n)?',
-                        'n'):
-                    num_theta_skip = input_int(' Enter the theta skip interval', 0,
-                            self.num_thetas-1)
-            else:
-                num_theta_skip = preprocess['num_theta_skip']
-                if not is_int(num_theta_skip, 0):
-                    illegal_value(num_theta_skip, 'preprocess:num_theta_skip', 'config file')
-                    num_theta_skip = 0
-        logging.debug(f'zoom_perc = {zoom_perc}')
-        logging.debug(f'num_theta_skip = {num_theta_skip}')
-
-        # Update config and save to file
-        if preprocess is None:
-            self.cf.config['preprocess'] = {'zoom_perc' : zoom_perc,
-                    'num_theta_skip' : num_theta_skip}
-        else:
-            preprocess['zoom_perc'] = zoom_perc
-            preprocess['num_theta_skip'] = num_theta_skip
-        self.cf.saveFile(self.config_out)
-
-    def _loadTomo(self, base_name, index, required=False):
-        """Load a tomography stack.
-        """
-        # stack order: row,theta,column
-        zoom_perc = None
-        preprocess = self.config.get('preprocess')
-        if preprocess:
-            zoom_perc = preprocess.get('zoom_perc')
-        if zoom_perc is None or zoom_perc == 100:
-            title = f'{base_name} fullres'
-        else:
-            title = f'{base_name} {zoom_perc}p'
-        title += f'_{index}'
-        tomo_file = re.sub(r"\s+", '_', f'{self.output_folder}/{title}.npy')
-        load_flag = False
-        available = False
-        if os.path.isfile(tomo_file):
-            available = True
-            if required:
-                load_flag = True
-            else:
-                load_flag = input_yesno(f'\nDo you want to load {tomo_file} (y/n)?')
-        stack = np.array([])
-        if load_flag:
-            t0 = time()
-            logging.info(f'Loading {tomo_file} ...')
-            try:
-                stack = np.load(tomo_file)
-            except (IOError, ValueError):
-                stack = np.array([])
-                logging.error(f'Error loading {tomo_file}')
-            logging.info(f'... done in {time()-t0:.2f} seconds!')
-        if stack.size:
-            quickImshow(stack[:,0,:], title=title, path=self.output_folder,
-                    save_fig=self.save_plots, save_only=self.save_plots_only)
-        return stack, available
-
-    def _saveTomo(self, base_name, stack, index=None):
-        """Save a tomography stack.
-        """
-        zoom_perc = None
-        preprocess = self.config.get('preprocess')
-        if preprocess:
-            zoom_perc = preprocess.get('zoom_perc')
-        if zoom_perc is None or zoom_perc == 100:
-            title = f'{base_name} fullres'
-        else:
-            title = f'{base_name} {zoom_perc}p'
-        if index:
-            title += f'_{index}'
-        tomo_file = re.sub(r"\s+", '_', f'{self.output_folder}/{title}.npy')
-        t0 = time()
-        logging.info(f'Saving {tomo_file} ...')
-        np.save(tomo_file, stack)
-        logging.info(f'... done in {time()-t0:.2f} seconds!')
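For reference: the reduction `_genTomo` performs below is dark-field subtraction, bright-field normalization, clipping of non-positive values, and a -log linearization of the transmission signal. A minimal plain-numpy sketch of the same arithmetic (hypothetical arrays; the tool itself runs these steps through numexpr for multithreading):

import numpy as np

def reduce_stack(stack, bright, dark=None, cutoff=1.e-6):
    # Sketch of the per-stack reduction in _genTomo, in plain numpy
    data = stack.astype('float64')
    if dark is not None:
        data -= dark                                 # dark-field subtraction
    data /= bright                                   # bright-field normalization
    data = np.where(data < cutoff, cutoff, data)     # clip non-positive values
    data = -np.log(data)                             # linearize: absorption = -log(transmission)
    return np.where(np.isfinite(data), data, 0.)     # clear any nans/infs from the division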
- """ - if num_core is None: - num_core = self.num_core - stacks = self.config['stack_info']['stacks'] - assert(len(self.tomo_stacks) == self.config['stack_info']['num']) - assert(len(self.tomo_stacks) == len(stacks)) - if len(available_stacks) != len(stacks): - logging.warning('Illegal dimension of available_stacks in _genTomo'+ - f'({len(available_stacks)}'); - available_stacks = [False]*self.num_tomo_stacks - - preprocess = self.config.get('preprocess') - if preprocess is None: - img_x_bounds = [0, self.tbf.shape[0]] - img_y_bounds = [0, self.tbf.shape[1]] - zoom_perc = 100 - num_theta_skip = 0 - else: - img_x_bounds = preprocess.get('img_x_bounds', [0, self.tbf.shape[0]]) - img_y_bounds = preprocess.get('img_y_bounds', [0, self.tbf.shape[1]]) - zoom_perc = preprocess.get('zoom_perc', 100) - num_theta_skip = preprocess.get('num_theta_skip', 0) - - if self.tdf.size: - tdf = self.tdf[img_x_bounds[0]:img_x_bounds[1],img_y_bounds[0]:img_y_bounds[1]] - else: - logging.warning('Dark field unavailable') - if not self.tbf.size: - raise ValueError('Bright field unavailable') - tbf = self.tbf[img_x_bounds[0]:img_x_bounds[1],img_y_bounds[0]:img_y_bounds[1]] - - for i,stack in enumerate(stacks): - # Check if stack is already loaded or available - if self.tomo_stacks[i].size or available_stacks[i]: - continue - - # Load a stack of tomography images - index = stack['index'] - t0 = time() - tomo_stack = loadImageStack(tomo_stack_files[i], self.config['data_filetype'], - stack['img_offset'], self.config['theta_range']['num'], num_theta_skip, - img_x_bounds, img_y_bounds) - tomo_stack = tomo_stack.astype('float64') - logging.debug(f'loading stack {index} took {time()-t0:.2f} seconds!') - - # Subtract dark field - if self.tdf.size: - t0 = time() - with set_numexpr_threads(self.num_core): - ne.evaluate('tomo_stack-tdf', out=tomo_stack) - logging.debug(f'subtracting dark field took {time()-t0:.2f} seconds!') - - # Normalize - t0 = time() - with set_numexpr_threads(self.num_core): - ne.evaluate('tomo_stack/tbf', out=tomo_stack, truediv=True) - logging.debug(f'normalizing took {time()-t0:.2f} seconds!') - - # Remove non-positive values and linearize data - t0 = time() - cutoff = 1.e-6 - with set_numexpr_threads(self.num_core): - ne.evaluate('where(tomo_stack<cutoff, cutoff, tomo_stack)', out=tomo_stack) - with set_numexpr_threads(self.num_core): - ne.evaluate('-log(tomo_stack)', out=tomo_stack) - logging.debug('removing non-positive values and linearizing data took '+ - f'{time()-t0:.2f} seconds!') - - # Get rid of nans/infs that may be introduced by normalization - t0 = time() - np.where(np.isfinite(tomo_stack), tomo_stack, 0.) - logging.debug(f'remove nans/infs took {time()-t0:.2f} seconds!') - - # Downsize tomography stack to smaller size - tomo_stack = tomo_stack.astype('float32') - if not self.galaxy_flag: - title = f'red stack fullres {index}' - if not self.test_mode: - quickImshow(tomo_stack[0,:,:], title=title, path=self.output_folder, - save_fig=self.save_plots, save_only=self.save_plots_only) - if zoom_perc != 100: - t0 = time() - logging.info(f'Zooming in ...') - tomo_zoom_list = [] - for j in range(tomo_stack.shape[0]): - tomo_zoom = spi.zoom(tomo_stack[j,:,:], 0.01*zoom_perc) - tomo_zoom_list.append(tomo_zoom) - tomo_stack = np.stack([tomo_zoom for tomo_zoom in tomo_zoom_list]) - logging.info(f'... 
-    def _reconstructOnePlane(self, tomo_plane_T, center, thetas_deg, eff_pixel_size,
-            cross_sectional_dim, plot_sinogram=True, num_core=1):
-        """Invert the sinogram for a single tomography plane.
-        """
-        # tomo_plane_T index order: column,theta
-        assert(0 <= center < tomo_plane_T.shape[0])
-        center_offset = center-tomo_plane_T.shape[0]/2
-        two_offset = 2*int(np.round(center_offset))
-        two_offset_abs = np.abs(two_offset)
-        max_rad = int(0.5*(cross_sectional_dim/eff_pixel_size)*1.1) # 10% slack to avoid edge effects
-        if max_rad > 0.5*tomo_plane_T.shape[0]:
-            max_rad = 0.5*tomo_plane_T.shape[0]
-        dist_from_edge = max(1, int(np.floor((tomo_plane_T.shape[0]-two_offset_abs)/2.)-max_rad))
-        if two_offset >= 0:
-            logging.debug(f'sinogram range = [{two_offset+dist_from_edge}, {-dist_from_edge}]')
-            sinogram = tomo_plane_T[two_offset+dist_from_edge:-dist_from_edge,:]
-        else:
-            logging.debug(f'sinogram range = [{dist_from_edge}, {two_offset-dist_from_edge}]')
-            sinogram = tomo_plane_T[dist_from_edge:two_offset-dist_from_edge,:]
-        if plot_sinogram and not self.test_mode:
-            quickImshow(sinogram.T, f'sinogram center offset{center_offset:.2f}',
-                    path=self.output_folder, save_fig=self.save_plots,
-                    save_only=self.save_plots_only, aspect='auto')
-
-        # Inverting sinogram
-        t0 = time()
-        recon_sinogram = iradon(sinogram, theta=thetas_deg, circle=True)
-        logging.debug(f'inverting sinogram took {time()-t0:.2f} seconds!')
-        del sinogram
-
-        # Performing Gaussian filtering and removing ring artifacts
-        recon_parameters = self.config.get('recon_parameters')
-        if recon_parameters is None:
-            sigma = 1.0
-            ring_width = 15
-        else:
-            sigma = recon_parameters.get('gaussian_sigma', 1.0)
-            if not is_num(sigma, 0.0):
-                logging.warning(f'Illegal gaussian_sigma ({sigma}) in _reconstructOnePlane, '+
-                        'set to a default value of 1.0')
-                sigma = 1.0
-            ring_width = recon_parameters.get('ring_width', 15)
-            if not is_int(ring_width, 0):
-                logging.warning(f'Illegal ring_width ({ring_width}) in _reconstructOnePlane, '+
-                        'set to a default value of 15')
-                ring_width = 15
-        t0 = time()
-        recon_sinogram = spi.gaussian_filter(recon_sinogram, sigma, mode='nearest')
-        recon_clean = np.expand_dims(recon_sinogram, axis=0)
-        del recon_sinogram
-        t1 = time()
-        logging.debug(f'running remove_ring on {num_core} cores ...')
-        recon_clean = tomopy.misc.corr.remove_ring(recon_clean, rwidth=ring_width,
-                ncore=num_core)
-        logging.debug(f'... remove_ring took {time()-t1:.2f} seconds!')
-        logging.debug(f'filtering and removing ring artifact took {time()-t0:.2f} seconds!')
-        return recon_clean
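`_plotEdgesOnePlane` below scores a trial center visually: total-variation denoising flattens reconstruction noise while keeping edges sharp, and symmetric color limits make the doubled edges of a mis-centered reconstruction stand out. A short sketch of that step on hypothetical input:

import numpy as np
from skimage.restoration import denoise_tv_chambolle

recon_plane = np.random.rand(1, 128, 128)              # hypothetical (1, N, N) reconstruction
edges = denoise_tv_chambolle(recon_plane, weight=0.1)  # small weight keeps edges sharp
vmax = np.max(edges[0,:,:])
vmin = -vmax                                           # symmetric limits, as in the plots below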
-
-    def _plotEdgesOnePlane(self, recon_plane, title, path=None):
-        vis_parameters = self.config.get('vis_parameters')
-        if vis_parameters is None:
-            weight = 0.1
-        else:
-            weight = vis_parameters.get('denoise_weight', 0.1)
-            if not is_num(weight, 0.0):
-                logging.warning(f'Illegal weight ({weight}) in _plotEdgesOnePlane, '+
-                        'set to a default value of 0.1')
-                weight = 0.1
-        edges = denoise_tv_chambolle(recon_plane, weight=weight)
-        vmax = np.max(edges[0,:,:])
-        vmin = -vmax
-        if self.galaxy_flag:
-            quickImshow(edges[0,:,:], title, path=path, save_fig=True, save_only=True,
-                    cmap='gray', vmin=vmin, vmax=vmax)
-        else:
-            if path is None:
-                path = self.output_folder
-            quickImshow(edges[0,:,:], f'{title} coolwarm', path=path,
-                    save_fig=self.save_plots, save_only=self.save_plots_only, cmap='coolwarm')
-            quickImshow(edges[0,:,:], f'{title} gray', path=path,
-                    save_fig=self.save_plots, save_only=self.save_plots_only, cmap='gray',
-                    vmin=vmin, vmax=vmax)
-        del edges
-
-    def _findCenterOnePlane(self, sinogram, row, thetas_deg, eff_pixel_size, cross_sectional_dim,
-            tol=0.1, num_core=1, galaxy_param=None):
-        """Find center for a single tomography plane.
-        """
-        if self.galaxy_flag:
-            assert(isinstance(galaxy_param, dict))
-            if not os.path.exists('find_center_pngs'):
-                os.mkdir('find_center_pngs')
-        # sinogram index order: theta,column
-        # need index order column,theta for iradon, so take transpose
-        sinogram_T = sinogram.T
-        center = sinogram.shape[1]/2
-
-        # try automatic center finding routines for initial value
-        t0 = time()
-        if num_core > num_core_tomopy_limit:
-            logging.debug(f'running find_center_vo on {num_core_tomopy_limit} cores ...')
-            tomo_center = tomopy.find_center_vo(sinogram, ncore=num_core_tomopy_limit)
-        else:
-            logging.debug(f'running find_center_vo on {num_core} cores ...')
-            tomo_center = tomopy.find_center_vo(sinogram, ncore=num_core)
-        logging.debug(f'... find_center_vo took {time()-t0:.2f} seconds!')
-        center_offset_vo = tomo_center-center
-        if self.test_mode:
-            logging.info(f'Center at row {row} using Nghia Vo’s method = {center_offset_vo:.2f}')
-            del sinogram_T
-            return float(center_offset_vo)
-        elif self.galaxy_flag:
-            logging.info(f'Center at row {row} using Nghia Vo’s method = {center_offset_vo:.2f}')
-            t0 = time()
-            logging.debug(f'running _reconstructOnePlane on {num_core} cores ...')
-            recon_plane = self._reconstructOnePlane(sinogram_T, tomo_center, thetas_deg,
-                    eff_pixel_size, cross_sectional_dim, False, num_core)
-            logging.debug(f'... 
_reconstructOnePlane took {time()-t0:.2f} seconds!') - title = f'edges row{row} center offset{center_offset_vo:.2f} Vo' - self._plotEdgesOnePlane(recon_plane, title, path='find_center_pngs') - del recon_plane - if not galaxy_param['center_type_selector']: - del sinogram_T - return float(center_offset_vo) - else: - print(f'Center at row {row} using Nghia Vo’s method = {center_offset_vo:.2f}') - recon_plane = self._reconstructOnePlane(sinogram_T, tomo_center, thetas_deg, - eff_pixel_size, cross_sectional_dim, False, num_core) - title = f'edges row{row} center offset{center_offset_vo:.2f} Vo' - self._plotEdgesOnePlane(recon_plane, title) - if not self.galaxy_flag: - if input_yesno('Try finding center using phase correlation (y/n)?', 'n'): - tomo_center = tomopy.find_center_pc(sinogram, sinogram, tol=0.1, - rotc_guess=tomo_center) - error = 1. - while error > tol: - prev = tomo_center - tomo_center = tomopy.find_center_pc(sinogram, sinogram, tol=tol, - rotc_guess=tomo_center) - error = np.abs(tomo_center-prev) - center_offset = tomo_center-center - print(f'Center at row {row} using phase correlation = {center_offset:.2f}') - recon_plane = self._reconstructOnePlane(sinogram_T, tomo_center, thetas_deg, - eff_pixel_size, cross_sectional_dim, False, num_core) - title = f'edges row{row} center_offset{center_offset:.2f} PC' - self._plotEdgesOnePlane(recon_plane, title) - if input_yesno('Accept a center location (y) or continue search (n)?', 'y'): - center_offset = input_num(' Enter chosen center offset', -center, center, - center_offset_vo) - del sinogram_T - del recon_plane - return float(center_offset) - - # perform center finding search - while True: - if self.galaxy_flag and galaxy_param and galaxy_param['center_type_selector']: - set_center = center_offset_vo - if galaxy_param['center_type_selector'] == 'user': - set_center = galaxy_param['set_center'] - set_range = galaxy_param['set_range'] - center_offset_step = galaxy_param['set_step'] - if (not is_num(set_range, 0) or not is_num(center_offset_step) or - center_offset_step <= 0): - logging.warning('Illegal center finding search parameter, skip search') - del sinogram_T - return float(center_offset_vo) - set_range = center_offset_step*max(1, int(set_range/center_offset_step)) - center_offset_low = set_center-set_range - center_offset_upp = set_center+set_range - else: - center_offset_low = input_int('\nEnter lower bound for center offset', - -center, center) - center_offset_upp = input_int('Enter upper bound for center offset', - center_offset_low, center) - if center_offset_upp == center_offset_low: - center_offset_step = 1 - else: - center_offset_step = input_int('Enter step size for center offset search', - 1, center_offset_upp-center_offset_low) - num_center_offset = 1+int((center_offset_upp-center_offset_low)/center_offset_step) - center_offsets = np.linspace(center_offset_low, center_offset_upp, num_center_offset) - for center_offset in center_offsets: - if center_offset == center_offset_vo: - continue - t0 = time() - logging.debug(f'running _reconstructOnePlane on {num_core} cores ...') - recon_plane = self._reconstructOnePlane(sinogram_T, center_offset+center, - thetas_deg, eff_pixel_size, cross_sectional_dim, False, num_core) - logging.debug(f'... 
_reconstructOnePlane took {time()-t0:.2f} seconds!')
-                title = f'edges row{row} center_offset{center_offset:.2f}'
-                if self.galaxy_flag:
-                    self._plotEdgesOnePlane(recon_plane, title, path='find_center_pngs')
-                else:
-                    self._plotEdgesOnePlane(recon_plane, title)
-            if self.galaxy_flag or input_int('\nContinue (0) or end the search (1)', 0, 1):
-                break
-
-        del sinogram_T
-        del recon_plane
-        if self.galaxy_flag:
-            center_offset = center_offset_vo
-        else:
-            center_offset = input_num(' Enter chosen center offset', -center, center)
-        return float(center_offset)
-
-    def _reconstructOneTomoStack(self, tomo_stack, thetas, row_bounds=None, center_offsets=[],
-            num_core=1, algorithm='gridrec'):
-        """Reconstruct a single tomography stack.
-        """
-        # stack order: row,theta,column
-        # thetas must be in radians
-        # centers_offset: tomography axis shift in pixels relative to column center
-        # RV should we remove stripes?
-        # https://tomopy.readthedocs.io/en/latest/api/tomopy.prep.stripe.html
-        # RV should we remove rings?
-        # https://tomopy.readthedocs.io/en/latest/api/tomopy.misc.corr.html
-        # RV: Add an option to do (extra) secondary iterations later or to do some sort of convergence test?
-        if row_bounds is None:
-            row_bounds = [0, tomo_stack.shape[0]]
-        else:
-            if not is_index_range(row_bounds, 0, tomo_stack.shape[0]):
-                raise ValueError('Illegal row bounds in reconstructOneTomoStack')
-        if thetas.size != tomo_stack.shape[1]:
-            raise ValueError('theta dimension mismatch in reconstructOneTomoStack')
-        if not len(center_offsets):
-            centers = np.zeros((row_bounds[1]-row_bounds[0]))
-        elif len(center_offsets) == 2:
-            centers = np.linspace(center_offsets[0], center_offsets[1],
-                    row_bounds[1]-row_bounds[0])
-        else:
-            if center_offsets.size != row_bounds[1]-row_bounds[0]:
-                raise ValueError('center_offsets dimension mismatch in reconstructOneTomoStack')
-            centers = center_offsets
-        centers += tomo_stack.shape[2]/2
-
-        # Removing horizontal stripe and ring artifacts and perform secondary iterations
-        recon_parameters = self.config.get('recon_parameters')
-        if recon_parameters is None:
-            sigma = 2.0
-            secondary_iters = 0
-            ring_width = 15
-        else:
-            sigma = recon_parameters.get('stripe_fw_sigma', 2.0)
-            if not is_num(sigma, 0):
-                logging.warning(f'Illegal stripe_fw_sigma ({sigma}) in '+
-                        '_reconstructOneTomoStack, set to a default value of 2.0')
-                sigma = 2.0
-            secondary_iters = recon_parameters.get('secondary_iters', 0)
-            if not is_int(secondary_iters, 0):
-                logging.warning(f'Illegal secondary_iters ({secondary_iters}) in '+
-                        '_reconstructOneTomoStack, set to a default value of 0 (skip them)')
-                secondary_iters = 0
-            ring_width = recon_parameters.get('ring_width', 15)
-            if not is_int(ring_width, 0):
-                logging.warning(f'Illegal ring_width ({ring_width}) in '+
-                        '_reconstructOneTomoStack, set to a default value of 15')
-                ring_width = 15
-        if True:
-            t0 = time()
-            if num_core > num_core_tomopy_limit:
-                logging.debug(f'running remove_stripe_fw on {num_core_tomopy_limit} cores ...')
-                tomo_stack = tomopy.prep.stripe.remove_stripe_fw(
-                        tomo_stack[row_bounds[0]:row_bounds[1]], sigma=sigma,
-                        ncore=num_core_tomopy_limit)
-            else:
-                logging.debug(f'running remove_stripe_fw on {num_core} cores ...')
-                tomo_stack = tomopy.prep.stripe.remove_stripe_fw(
-                        tomo_stack[row_bounds[0]:row_bounds[1]], sigma=sigma, ncore=num_core)
-            logging.debug(f'... tomopy.prep.stripe.remove_stripe_fw took '+
-                    f'{time()-t0:.2f} seconds!')
-        else:
-            tomo_stack = tomo_stack[row_bounds[0]:row_bounds[1]]
-        logging.debug('performing initial reconstruction')
-        t0 = time()
-        logging.debug(f'running recon on {num_core} cores ...')
-        tomo_recon_stack = tomopy.recon(tomo_stack, thetas, centers, sinogram_order=True,
-                algorithm=algorithm, ncore=num_core)
-        logging.debug(f'... recon took {time()-t0:.2f} seconds!')
-        if secondary_iters > 0:
-            logging.debug(f'running {secondary_iters} secondary iterations')
-            #options = {'method':'SIRT_CUDA', 'proj_type':'cuda', 'num_iter':secondary_iters}
-            #RV: doesn't work for me: "Error: CUDA error 803: system has unsupported display driver / cuda driver combination."
-            #options = {'method':'SIRT', 'proj_type':'linear', 'MinConstraint': 0, 'num_iter':secondary_iters}
-            #SIRT did not finish while running overnight
-            #options = {'method':'SART', 'proj_type':'linear', 'num_iter':secondary_iters}
-            options = {'method':'SART', 'proj_type':'linear', 'MinConstraint': 0,
-                    'num_iter':secondary_iters}
-            t0 = time()
-            logging.debug(f'running recon on {num_core} cores ...')
-            tomo_recon_stack = tomopy.recon(tomo_stack, thetas, centers,
-                    init_recon=tomo_recon_stack, options=options, sinogram_order=True,
-                    algorithm=tomopy.astra, ncore=num_core)
-            logging.debug(f'... recon took {time()-t0:.2f} seconds!')
-        if True:
-            t0 = time()
-            logging.debug(f'running remove_ring on {num_core} cores ...')
-            tomopy.misc.corr.remove_ring(tomo_recon_stack, rwidth=ring_width,
-                    out=tomo_recon_stack, ncore=num_core)
-            logging.debug(f'... remove_ring took {time()-t0:.2f} seconds!')
-        return tomo_recon_stack
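The reconstruction in `_reconstructOneTomoStack` above is two-stage: a fast analytic `gridrec` pass, optionally refined by SART iterations through tomopy's astra wrapper. A minimal sketch on hypothetical data (assumes tomopy and, for the second stage, the astra toolbox are installed; the parameters are illustrative, not the tool's defaults):

import numpy as np
import tomopy

proj = np.random.rand(180, 16, 128).astype('float32')   # theta,row,column projections
thetas = tomopy.angles(proj.shape[0], 0., 180.)          # projection angles in radians
recon = tomopy.recon(proj, thetas, algorithm='gridrec')  # analytic first pass
options = {'method': 'SART', 'proj_type': 'linear', 'MinConstraint': 0, 'num_iter': 2}
recon = tomopy.recon(proj, thetas, init_recon=recon, options=options,
        algorithm=tomopy.astra)                          # iterative refinement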
-
-    def findImageFiles(self, tiff_to_h5_flag=False):
-        """Find all available image files.
-        """
-        self.is_valid = True
-
-        # Find dark field images
-        dark_field = self.config['dark_field']
-        img_start, num_imgs, dark_files = findImageFiles(
-                dark_field['data_path'], self.config['data_filetype'], 'dark field')
-        if img_start < 0 or num_imgs < 1:
-            logging.error('Unable to find suitable dark field images')
-            if dark_field['data_path']:
-                self.is_valid = False
-        img_start_old = dark_field.get('img_start')
-        num_imgs_old = dark_field.get('num')
-        if num_imgs_old is None:
-            dark_field['num'] = num_imgs
-        else:
-            if num_imgs_old > num_imgs:
-                logging.error('Inconsistent number of available dark field images')
-                if dark_field['data_path']:
-                    self.is_valid = False
-        if img_start_old is None:
-            dark_field['img_start'] = img_start
-        else:
-            if img_start_old < img_start:
-                logging.error('Inconsistent image start index for dark field images')
-                if dark_field['data_path']:
-                    self.is_valid = False
-        logging.info(f'Number of dark field images = {dark_field["num"]}')
-        logging.info(f'Dark field image start index = {dark_field["img_start"]}')
-        if num_imgs and tiff_to_h5_flag and self.config['data_filetype'] == 'tif':
-            dark_files = combine_tiffs_in_h5(dark_files, num_imgs,
-                    f'{self.config["work_folder"]}/dark_field.h5')
-            dark_field['data_path'] = dark_files[0]
-
-        # Find bright field images
-        bright_field = self.config['bright_field']
-        img_start, num_imgs, bright_files = findImageFiles(
-                bright_field['data_path'], self.config['data_filetype'], 'bright field')
-        if img_start < 0 or num_imgs < 1:
-            logging.error('Unable to find suitable bright field images')
-            self.is_valid = False
-        img_start_old = bright_field.get('img_start')
-        num_imgs_old = bright_field.get('num')
-        if num_imgs_old is None:
-            bright_field['num'] = num_imgs
-        else:
-            if num_imgs_old > num_imgs:
-                logging.error('Inconsistent number of available bright field images')
-                self.is_valid = False
-        if img_start_old is None:
-            bright_field['img_start'] = img_start
-        else:
-            if img_start_old < img_start:
-                logging.warning('Inconsistent image start index for bright field images')
-                self.is_valid = False
-        logging.info(f'Number of bright field images = {bright_field["num"]}')
-        logging.info(f'Bright field image start index = {bright_field["img_start"]}')
-        if num_imgs and tiff_to_h5_flag and self.config['data_filetype'] == 'tif':
-            bright_files = combine_tiffs_in_h5(bright_files, num_imgs,
-                    f'{self.config["work_folder"]}/bright_field.h5')
-            bright_field['data_path'] = bright_files[0]
-
-        # Find tomography images
-        tomo_stack_files = []
-        for stack in self.config['stack_info']['stacks']:
-            index = stack['index']
-            img_start, num_imgs, tomo_files = findImageFiles(
-                    stack['data_path'], self.config['data_filetype'], f'tomography set {index}')
-            if img_start < 0 or num_imgs < 1:
-                logging.error('Unable to find suitable tomography images')
-                self.is_valid = False
-            img_start_old = stack.get('img_start')
-            num_imgs_old = stack.get('num')
-            if num_imgs_old is None:
-                stack['num'] = num_imgs
-            else:
-                if num_imgs_old > num_imgs:
-                    logging.error('Inconsistent number of available tomography images')
-                    self.is_valid = False
-            if img_start_old is None:
-                stack['img_start'] = img_start
-            else:
-                if img_start_old < img_start:
-                    logging.warning('Inconsistent image start index for tomography images')
-                    self.is_valid = False
-            logging.info(f'Number of tomography images for set {index} = {stack["num"]}')
-            logging.info(f'Tomography set {index} image start index = {stack["img_start"]}')
-            if num_imgs and tiff_to_h5_flag and self.config['data_filetype'] == 'tif':
-                tomo_files = combine_tiffs_in_h5(tomo_files, num_imgs,
-                        f'{self.config["work_folder"]}/tomo_field_{index}.h5')
-                stack['data_path'] = tomo_files[0]
-            tomo_stack_files.append(tomo_files)
-            del tomo_files
-
-        # Save updated config
-        if tiff_to_h5_flag:
-            self.config['data_filetype'] = 'h5'
-        if self.is_valid:
-            self.cf.saveFile(self.config_out)
-
-        return dark_files, bright_files, tomo_stack_files
-
-    def loadTomoStacks(self, input_name, recon_flag=False):
-        """Load tomography stacks (only for Galaxy).
-        """
-        assert(self.galaxy_flag)
-        t0 = time()
-        logging.info(f'Loading preprocessed tomography stack from {input_name} ...')
-        stack_info = self.config['stack_info']
-        stacks = stack_info.get('stacks')
-        assert(len(self.tomo_stacks) == stack_info['num'])
-        with np.load(input_name) as f:
-            if recon_flag:
-                for i,stack in enumerate(stacks):
-                    self.tomo_recon_stacks[i] = f[f'set_{stack["index"]}']
-                    logging.info(f'loaded stack {i}: index = {stack["index"]}, shape = '+
-                            f'{self.tomo_recon_stacks[i].shape}')
-            else:
-                for i,stack in enumerate(stacks):
-                    self.tomo_stacks[i] = f[f'set_{stack["index"]}']
-                    logging.info(f'loaded stack {i}: index = {stack["index"]}, shape = '+
-                            f'{self.tomo_stacks[i].shape}')
-        logging.info(f'... done in {time()-t0:.2f} seconds!')
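`loadTomoStacks` above and `genTomoStacks` below exchange data through a single `.npz` file holding one `set_<index>` array per stack, in row,theta,column order. A round-trip sketch of that layout with hypothetical shapes:

import numpy as np

stacks = [{'index': 1}, {'index': 2}]
tomo_stacks = [np.zeros((8, 180, 64)) for _ in stacks]   # row,theta,column
np.savez('preprocessed.npz',
        **{f'set_{s["index"]}': a for s, a in zip(stacks, tomo_stacks)})
with np.load('preprocessed.npz') as f:
    loaded = [f[f'set_{s["index"]}'] for s in stacks]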
- """ - if num_core is None: - num_core = self.num_core - logging.info(f'num_core = {num_core}') - # Try loading any already preprocessed stacks (skip in Galaxy) - # preprocessed stack order for each one in stack: row,theta,column - stack_info = self.config['stack_info'] - stacks = stack_info['stacks'] - num_tomo_stacks = stack_info['num'] - assert(num_tomo_stacks == len(self.tomo_stacks)) - available_stacks = [False]*num_tomo_stacks - if self.galaxy_flag: - assert(isinstance(galaxy_param, dict)) - tdf_files = galaxy_param['tdf_files'] - tbf_files = galaxy_param['tbf_files'] - tomo_stack_files = galaxy_param['tomo_stack_files'] - assert(num_tomo_stacks == len(tomo_stack_files)) - if not os.path.exists('setup_pngs'): - os.mkdir('setup_pngs') - else: - if galaxy_param: - logging.warning('Ignoring galaxy_param in genTomoStacks (only for Galaxy)') - galaxy_param = None - tdf_files, tbf_files, tomo_stack_files = self.findImageFiles() - if not self.is_valid: - return - if self.test_mode: - required = True - else: - required = False - for i,stack in enumerate(stacks): - if not self.tomo_stacks[i].size and stack.get('preprocessed', False): - self.tomo_stacks[i], available_stacks[i] = \ - self._loadTomo('red stack', stack['index'], required=required) - - # Preprocess any unloaded stacks - if False in available_stacks: - logging.debug('Preprocessing tomography images') - - # Check required image files (skip in Galaxy) - if not self.galaxy_flag: - self._selectImageRanges(available_stacks) - if not self.is_valid: - return - - # Generate dark field - if tdf_files: - self._genDark(tdf_files) - - # Generate bright field - self._genBright(tbf_files) - - # Set vertical detector bounds for image stack - self._setDetectorBounds(tomo_stack_files) - - # Set zoom and/or theta skip to reduce memory the requirement - self._setZoomOrSkip() - - # Generate tomography fields - self._genTomo(tomo_stack_files, available_stacks, num_core) - - # Save tomography stack to file - if self.galaxy_flag: - t0 = time() - output_name = galaxy_param['output_name'] - logging.info(f'Saving preprocessed tomography stack to {output_name} ...') - save_stacks = {f'set_{stack["index"]}':tomo_stack - for stack,tomo_stack in zip(stacks,self.tomo_stacks)} - np.savez(output_name, **save_stacks) - logging.info(f'... done in {time()-t0:.2f} seconds!') - - del available_stacks - - # Adjust sample reference height, update config and save to file - preprocess = self.config.get('preprocess') - if preprocess is None: - img_x_bounds = [0, self.tbf.shape[0]] - else: - img_x_bounds = preprocess.get('img_x_bounds', [0, self.tbf.shape[0]]) - pixel_size = self.config['detector']['pixel_size'] - if pixel_size is None: - raise ValueError('Detector pixel size unavailable') - pixel_size *= img_x_bounds[0] - for stack in stacks: - stack['ref_height'] = stack['ref_height']+pixel_size - self.cf.saveFile(self.config_out) - - def findCenters(self, galaxy_param=None, num_core=None): - """Find rotation axis centers for the tomography stacks. 
- """ - if num_core is None: - num_core = self.num_core - logging.info(f'num_core = {num_core}') - logging.debug('Find centers for tomography stacks') - stacks = self.config['stack_info']['stacks'] - available_stacks = [stack['index'] for stack in stacks if stack.get('preprocessed', False)] - logging.debug('Available stacks: {available_stacks}') - if self.galaxy_flag: - row_bounds = galaxy_param['row_bounds'] - center_rows = galaxy_param['center_rows'] - center_type_selector = galaxy_param['center_type_selector'] - if center_type_selector: - if center_type_selector == 'vo': - set_center = None - elif center_type_selector == 'user': - set_center = galaxy_param['set_center'] - else: - logging.error('Illegal center_type_selector entry in galaxy_param '+ - f'({center_type_selector})') - galaxy_param['center_type_selector'] = None - logging.debug(f'row_bounds = {row_bounds}') - logging.debug(f'center_rows = {center_rows}') - logging.debug(f'center_type_selector = {center_type_selector}') - else: - if galaxy_param: - logging.warning('Ignoring galaxy_param in findCenters (only for Galaxy)') - galaxy_param = None - - # Check for valid available center stack index - find_center = self.config.get('find_center') - center_stack_index = None - if find_center and 'center_stack_index' in find_center: - center_stack_index = find_center['center_stack_index'] - if (not isinstance(center_stack_index, int) or - center_stack_index not in available_stacks): - illegal_value(center_stack_index, 'find_center:center_stack_index', 'config file') - else: - if self.test_mode: - assert(find_center.get('completed', False) == False) - else: - print('\nFound calibration center offset info for stack '+ - f'{center_stack_index}') - if (input_yesno('Do you want to use this again (y/n)?', 'y') - and find_center.get('completed', False) == True): - return - - # Load the required preprocessed stack - # preprocessed stack order: row,theta,column - num_tomo_stacks = self.config['stack_info']['num'] - assert(len(stacks) == num_tomo_stacks) - assert(len(self.tomo_stacks) == num_tomo_stacks) - if num_tomo_stacks == 1: - if not center_stack_index: - center_stack_index = stacks[0]['index'] - elif center_stack_index != stacks[0]['index']: - raise ValueError(f'Inconsistent center_stack_index {center_stack_index}') - if not self.tomo_stacks[0].size: - self.tomo_stacks[0], available = self._loadTomo('red stack', center_stack_index, - required=True) - center_stack = self.tomo_stacks[0] - if not center_stack.size: - stacks[0]['preprocessed'] = False - raise OSError('Unable to load the required preprocessed tomography stack') - assert(stacks[0].get('preprocessed', False) == True) - elif self.galaxy_flag: - center_stack_index = stacks[int(num_tomo_stacks/2)]['index'] - tomo_stack_index = available_stacks.index(center_stack_index) - center_stack = self.tomo_stacks[tomo_stack_index] - if not center_stack.size: - stacks[tomo_stack_index]['preprocessed'] = False - raise OSError('Unable to load the required preprocessed tomography stack') - assert(stacks[tomo_stack_index].get('preprocessed', False) == True) - else: - while True: - if not center_stack_index: - center_stack_index = input_int('\nEnter tomography stack index to get ' - f'rotation axis centers ({available_stacks})', inset=available_stacks) - while center_stack_index not in available_stacks: - center_stack_index = input_int('\nEnter tomography stack index to get ' - f'rotation axis centers ({available_stacks})', inset=available_stacks) - tomo_stack_index = 
available_stacks.index(center_stack_index) - if not self.tomo_stacks[tomo_stack_index].size: - self.tomo_stacks[tomo_stack_index], available = self._loadTomo( - 'red stack', center_stack_index, required=True) - center_stack = self.tomo_stacks[tomo_stack_index] - if not center_stack.size: - stacks[tomo_stack_index]['preprocessed'] = False - logging.error(f'Unable to load the {center_stack_index}th '+ - 'preprocessed tomography stack, pick another one') - else: - break - assert(stacks[tomo_stack_index].get('preprocessed', False) == True) - if find_center is None: - self.config['find_center'] = {'center_stack_index' : center_stack_index} - find_center = self.config['find_center'] - else: - find_center['center_stack_index'] = center_stack_index - if not self.galaxy_flag: - row_bounds = find_center.get('row_bounds', None) - center_rows = [find_center.get('lower_row', None), - find_center.get('upper_row', None)] - if row_bounds is None: - row_bounds = [0, center_stack.shape[0]] - if row_bounds[0] == -1: - row_bounds[0] = 0 - if row_bounds[1] == -1: - row_bounds[1] = center_stack.shape[0] - if center_rows[0] == -1: - center_rows[0] = 0 - if center_rows[1] == -1: - center_rows[1] = center_stack.shape[0]-1 - if not is_index_range(row_bounds, 0, center_stack.shape[0]): - illegal_value(row_bounds, 'row_bounds', 'Tomo:findCenters') - return - - # Set thetas (in degrees) - theta_range = self.config['theta_range'] - theta_start = theta_range['start'] - theta_end = theta_range['end'] - num_theta = theta_range['num'] - num_theta_skip = self.config['preprocess'].get('num_theta_skip', 0) - thetas_deg = np.linspace(theta_start, theta_end, int(num_theta/(num_theta_skip+1)), - endpoint=False) - - # Get non-overlapping sample row boundaries - zoom_perc = self.config['preprocess'].get('zoom_perc', 100) - pixel_size = self.config['detector']['pixel_size'] - if pixel_size is None: - raise ValueError('Detector pixel size unavailable') - eff_pixel_size = 100.*pixel_size/zoom_perc - logging.debug(f'eff_pixel_size = {eff_pixel_size}') - if num_tomo_stacks == 1: - accept = True - if not self.test_mode and not self.galaxy_flag: - accept = False - print('\nSelect bounds for image reconstruction') - if is_index_range(row_bounds, 0, center_stack.shape[0]): - a_tmp = np.copy(center_stack[:,0,:]) - a_tmp_max = a_tmp.max() - a_tmp[row_bounds[0],:] = a_tmp_max - a_tmp[row_bounds[1]-1,:] = a_tmp_max - print(f'lower bound = {row_bounds[0]} (inclusive)') - print(f'upper bound = {row_bounds[1]} (exclusive)') - quickImshow(a_tmp, title=f'center stack theta={theta_start}', - aspect='auto') - del a_tmp - accept = input_yesno('Accept these bounds (y/n)?', 'y') - if accept: - n1 = row_bounds[0] - n2 = row_bounds[1] - else: - n1, n2 = selectImageBounds(center_stack[:,0,:], 0, - title=f'center stack theta={theta_start}') - else: - tomo_ref_heights = [stack['ref_height'] for stack in stacks] - n1 = int((1.+(tomo_ref_heights[0]+center_stack.shape[0]*eff_pixel_size- - tomo_ref_heights[1])/eff_pixel_size)/2) - n2 = center_stack.shape[0]-n1 - logging.debug(f'n1 = {n1}, n2 = {n2} (n2-n1) = {(n2-n1)*eff_pixel_size:.3f} mm') - if not self.test_mode and not self.galaxy_flag: - tmp = center_stack[:,0,:] - tmp_max = tmp.max() - tmp[n1,:] = tmp_max - tmp[n2-1,:] = tmp_max - if is_index_range(center_rows, 0, tmp.shape[0]): - tmp[center_rows[0],:] = tmp_max - tmp[center_rows[1]-1,:] = tmp_max - extent = [0, tmp.shape[1], tmp.shape[0], 0] - quickImshow(tmp, title=f'center stack theta={theta_start}', - path=self.output_folder, extent=extent, 
save_fig=self.save_plots, - save_only=self.save_plots_only, aspect='auto') - del tmp - #extent = [0, center_stack.shape[2], n2, n1] - #quickImshow(center_stack[n1:n2,0,:], title=f'center stack theta={theta_start}', - # path=self.output_folder, extent=extent, save_fig=self.save_plots, - # save_only=self.save_plots_only, show_grid=True, aspect='auto') - - # Get cross sectional diameter in mm - cross_sectional_dim = center_stack.shape[2]*eff_pixel_size - logging.debug(f'cross_sectional_dim = {cross_sectional_dim}') - - # Determine center offset at sample row boundaries - logging.info('Determine center offset at sample row boundaries') - - # Lower row center - use_row = False - use_center = False - row = center_rows[0] - if self.test_mode or self.galaxy_flag: - assert(is_int(row, n1, n2-2)) - if is_int(row, n1, n2-2): - if self.test_mode or self.galaxy_flag: - use_row = True - else: - quickImshow(center_stack[:,0,:], title=f'theta={theta_start}', aspect='auto') - use_row = input_yesno('\nCurrent row index for lower center = ' - f'{row}, use this value (y/n)?', 'y') - if self.save_plots_only: - clearImshow(f'theta={theta_start}') - if use_row: - center_offset = find_center.get('lower_center_offset') - if is_num(center_offset): - use_center = input_yesno('Current lower center offset = '+ - f'{center_offset}, use this value (y/n)?', 'y') - if not use_center: - if not use_row: - if not self.test_mode: - quickImshow(center_stack[:,0,:], title=f'theta={theta_start}', - aspect='auto') - row = input_int('\nEnter row index to find lower center', n1, n2-2, n1) - if row == '': - row = n1 - if self.save_plots_only: - clearImshow(f'theta={theta_start}') - # center_stack order: row,theta,column - center_offset = self._findCenterOnePlane(center_stack[row,:,:], row, thetas_deg, - eff_pixel_size, cross_sectional_dim, num_core=num_core, - galaxy_param=galaxy_param) - logging.info(f'lower_center_offset = {center_offset:.2f} {type(center_offset)}') - - # Update config and save to file - find_center['row_bounds'] = [n1, n2] - find_center['lower_row'] = row - find_center['lower_center_offset'] = center_offset - self.cf.saveFile(self.config_out) - lower_row = row - - # Upper row center - use_row = False - use_center = False - row = center_rows[1] - if self.test_mode or self.galaxy_flag: - assert(is_int(row, lower_row+1, n2-1)) - if is_int(row, lower_row+1, n2-1): - if self.test_mode or self.galaxy_flag: - use_row = True - else: - quickImshow(center_stack[:,0,:], title=f'theta={theta_start}', aspect='auto') - use_row = input_yesno('\nCurrent row index for upper center = ' - f'{row}, use this value (y/n)?', 'y') - if self.save_plots_only: - clearImshow(f'theta={theta_start}') - if use_row: - center_offset = find_center.get('upper_center_offset') - if is_num(center_offset): - use_center = input_yesno('Current upper center offset = '+ - f'{center_offset}, use this value (y/n)?', 'y') - if not use_center: - if not use_row: - if not self.test_mode: - quickImshow(center_stack[:,0,:], title=f'theta={theta_start}', - aspect='auto') - row = input_int('\nEnter row index to find upper center', lower_row+1, n2-1, n2-1) - if row == '': - row = n2-1 - if self.save_plots_only: - clearImshow(f'theta={theta_start}') - # center_stack order: row,theta,column - center_offset = self._findCenterOnePlane(center_stack[row,:,:], row, thetas_deg, - eff_pixel_size, cross_sectional_dim, num_core=num_core, - galaxy_param=galaxy_param) - logging.info(f'upper_center_offset = {center_offset:.2f}') - del center_stack - - # Update config and save to 
file - find_center['upper_row'] = row - find_center['upper_center_offset'] = center_offset - find_center['completed'] = True - self.cf.saveFile(self.config_out) - - def checkCenters(self): - """Check centers for the tomography stacks. - """ - #RV TODO load all stacks and check at all stack boundaries - return - logging.debug('Check centers for tomography stacks') - center_stack_index = self.config.get('center_stack_index') - if center_stack_index is None: - raise ValueError('Unable to read center_stack_index from config') - center_stack_index = self.tomo_stacks[self.tomo_data_indices.index(center_stack_index)] - lower_row = self.config.get('lower_row') - if lower_row is None: - raise ValueError('Unable to read lower_row from config') - lower_center_offset = self.config.get('lower_center_offset') - if lower_center_offset is None: - raise ValueError('Unable to read lower_center_offset from config') - upper_row = self.config.get('upper_row') - if upper_row is None: - raise ValueError('Unable to read upper_row from config') - upper_center_offset = self.config.get('upper_center_offset') - if upper_center_offset is None: - raise ValueError('Unable to read upper_center_offset from config') - center_slope = (upper_center_offset-lower_center_offset)/(upper_row-lower_row) - shift = upper_center_offset-lower_center_offset - if lower_row == 0: - logging.warning(f'lower_row == 0: one row offset between both planes') - else: - lower_row -= 1 - lower_center_offset -= center_slope - - # stack order: stack,row,theta,column - if center_stack_index: - stack1 = self.tomo_stacks[center_stack_index-1] - stack2 = self.tomo_stacks[center_stack_index] - if not stack1.size: - logging.error(f'Unable to load required tomography stack {stack1}') - elif not stack2.size: - logging.error(f'Unable to load required tomography stack {stack1}') - else: - assert(0 <= lower_row < stack2.shape[0]) - assert(0 <= upper_row < stack1.shape[0]) - plane1 = stack1[upper_row,:] - plane2 = stack2[lower_row,:] - for i in range(-2, 3): - shift_i = shift+2*i - plane1_shifted = spi.shift(plane2, [0, shift_i]) - quickPlot((plane1[0,:],), (plane1_shifted[0,:],), - title=f'stacks {stack1} {stack2} shifted {2*i} theta={self.start_theta}', - path=self.output_folder, save_fig=self.save_plots, - save_only=self.save_plots_only) - if center_stack_index < self.num_tomo_stacks-1: - stack1 = self.tomo_stacks[center_stack_index] - stack2 = self.tomo_stacks[center_stack_index+1] - if not stack1.size: - logging.error(f'Unable to load required tomography stack {stack1}') - elif not stack2.size: - logging.error(f'Unable to load required tomography stack {stack1}') - else: - assert(0 <= lower_row < stack2.shape[0]) - assert(0 <= upper_row < stack1.shape[0]) - plane1 = stack1[upper_row,:] - plane2 = stack2[lower_row,:] - for i in range(-2, 3): - shift_i = -shift+2*i - plane1_shifted = spi.shift(plane2, [0, shift_i]) - quickPlot((plane1[0,:],), (plane1_shifted[0,:],), - title=f'stacks {stack1} {stack2} shifted {2*i} theta={start_theta}', - path=self.output_folder, save_fig=self.save_plots, - save_only=self.save_plots_only) - del plane1, plane2, plane1_shifted - - # Update config file - self.config = update('config.txt', 'check_centers', True, 'find_centers') - - def reconstructTomoStacks(self, galaxy_param=None, num_core=None): - """Reconstruct tomography stacks. 
- """ - if num_core is None: - num_core = self.num_core - logging.info(f'num_core = {num_core}') - if self.galaxy_flag: - assert(galaxy_param) - if not os.path.exists('center_slice_pngs'): - os.mkdir('center_slice_pngs') - logging.debug('Reconstruct tomography stacks') - stacks = self.config['stack_info']['stacks'] - assert(len(self.tomo_stacks) == self.config['stack_info']['num']) - assert(len(self.tomo_stacks) == len(stacks)) - assert(len(self.tomo_recon_stacks) == len(stacks)) - if self.galaxy_flag: - assert(isinstance(galaxy_param, dict)) - # Get rotation axis centers - center_offsets = galaxy_param['center_offsets'] - assert(isinstance(center_offsets, list) and len(center_offsets) == 2) - lower_center_offset = center_offsets[0] - assert(is_num(lower_center_offset)) - upper_center_offset = center_offsets[1] - assert(is_num(upper_center_offset)) - else: - if galaxy_param: - logging.warning('Ignoring galaxy_param in reconstructTomoStacks (only for Galaxy)') - galaxy_param = None - lower_center_offset = None - upper_center_offset = None - - # Get rotation axis rows and centers - find_center = self.config['find_center'] - lower_row = find_center.get('lower_row') - if lower_row is None: - logging.error('Unable to read lower_row from config') - return - upper_row = find_center.get('upper_row') - if upper_row is None: - logging.error('Unable to read upper_row from config') - return - logging.debug(f'lower_row = {lower_row} upper_row = {upper_row}') - assert(lower_row < upper_row) - if lower_center_offset is None: - lower_center_offset = find_center.get('lower_center_offset') - if lower_center_offset is None: - logging.error('Unable to read lower_center_offset from config') - return - if upper_center_offset is None: - upper_center_offset = find_center.get('upper_center_offset') - if upper_center_offset is None: - logging.error('Unable to read upper_center_offset from config') - return - center_slope = (upper_center_offset-lower_center_offset)/(upper_row-lower_row) - - # Set thetas (in radians) - theta_range = self.config['theta_range'] - theta_start = theta_range['start'] - theta_end = theta_range['end'] - num_theta = theta_range['num'] - num_theta_skip = self.config['preprocess'].get('num_theta_skip', 0) - thetas = np.radians(np.linspace(theta_start, theta_end, - int(num_theta/(num_theta_skip+1)), endpoint=False)) - - # Reconstruct tomo stacks - zoom_perc = self.config['preprocess'].get('zoom_perc', 100) - if zoom_perc == 100: - basetitle = 'recon stack fullres' - else: - basetitle = f'recon stack {zoom_perc}p' - load_error = False - for i,stack in enumerate(stacks): - # Check if stack can be loaded - # reconstructed stack order for each one in stack : row/z,x,y - # preprocessed stack order for each one in stack: row,theta,column - index = stack['index'] - if not self.galaxy_flag: - available = False - if stack.get('reconstructed', False): - self.tomo_recon_stacks[i], available = self._loadTomo('recon stack', index) - if self.tomo_recon_stacks[i].size or available: - if self.tomo_stacks[i].size: - self.tomo_stacks[i] = np.array([]) - assert(stack.get('preprocessed', False) == True) - assert(stack.get('reconstructed', False) == True) - continue - stack['reconstructed'] = False - if not self.tomo_stacks[i].size: - self.tomo_stacks[i], available = self._loadTomo('red stack', index, - required=True) - if not self.tomo_stacks[i].size: - logging.error(f'Unable to load tomography stack {index} for reconstruction') - stack[i]['preprocessed'] = False - load_error = True - continue - assert(0 <= 
lower_row < upper_row < self.tomo_stacks[i].shape[0]) - center_offsets = [lower_center_offset-lower_row*center_slope, - upper_center_offset+(self.tomo_stacks[i].shape[0]-1-upper_row)*center_slope] - t0 = time() - logging.debug(f'running _reconstructOneTomoStack on {num_core} cores ...') - self.tomo_recon_stacks[i]= self._reconstructOneTomoStack(self.tomo_stacks[i], - thetas, center_offsets=center_offsets, num_core=num_core, - algorithm='gridrec') - logging.debug(f'... _reconstructOneTomoStack took {time()-t0:.2f} seconds!') - logging.info(f'Reconstruction of stack {index} took {time()-t0:.2f} seconds!') - if self.galaxy_flag: - x_slice = int(self.tomo_recon_stacks[i].shape[0]/2) - title = f'{basetitle} {index} xslice{x_slice}' - quickImshow(self.tomo_recon_stacks[i][x_slice,:,:], title=title, - path='center_slice_pngs', save_fig=True, save_only=True) - y_slice = int(self.tomo_recon_stacks[i].shape[1]/2) - title = f'{basetitle} {index} yslice{y_slice}' - quickImshow(self.tomo_recon_stacks[i][:,y_slice,:], title=title, - path='center_slice_pngs', save_fig=True, save_only=True) - z_slice = int(self.tomo_recon_stacks[i].shape[2]/2) - title = f'{basetitle} {index} zslice{z_slice}' - quickImshow(self.tomo_recon_stacks[i][:,:,z_slice], title=title, - path='center_slice_pngs', save_fig=True, save_only=True) - else: - x_slice = int(self.tomo_recon_stacks[i].shape[0]/2) - title = f'{basetitle} {index} xslice{x_slice}' - quickImshow(self.tomo_recon_stacks[i][x_slice,:,:], title=title, - path=self.output_folder, save_fig=self.save_plots, - save_only=self.save_plots_only) - y_slice = int(self.tomo_recon_stacks[i].shape[1]/2) - title = f'{basetitle} {index} yslice{y_slice}' - quickImshow(self.tomo_recon_stacks[i][:,y_slice,:], title=title, - path=self.output_folder, save_fig=self.save_plots, - save_only=self.save_plots_only) - z_slice = int(self.tomo_recon_stacks[i].shape[2]/2) - title = f'{basetitle} {index} zslice{z_slice}' - quickImshow(self.tomo_recon_stacks[i][:,:,z_slice], title=title, - path=self.output_folder, save_fig=self.save_plots, - save_only=self.save_plots_only) -# quickPlot(self.tomo_recon_stacks[i] -# [x_slice,int(self.tomo_recon_stacks[i].shape[1]/2),:], -# title=f'{title} cut{int(self.tomo_recon_stacks[i].shape[1]/2)}', -# path=self.output_folder, save_fig=self.save_plots, -# save_only=self.save_plots_only) - if not self.test_mode: - self._saveTomo('recon stack', self.tomo_recon_stacks[i], index) - self.tomo_stacks[i] = np.array([]) - - # Update config and save to file - stack['reconstructed'] = True - combine_stacks = self.config.get('combine_stacks') - if combine_stacks and index in combine_stacks.get('stacks', []): - combine_stacks['stacks'].remove(index) - self.cf.saveFile(self.config_out) - - if self.galaxy_flag: - # Save reconstructed tomography stack to file - t0 = time() - output_name = galaxy_param['output_name'] - logging.info(f'Saving reconstructed tomography stack to {output_name} ...') - save_stacks = {f'set_{stack["index"]}':tomo_stack - for stack,tomo_stack in zip(stacks,self.tomo_recon_stacks)} - np.savez(output_name, **save_stacks) - logging.info(f'... 
done in {time()-t0:.2f} seconds!')
-
-            # Create cross section profile in yz-plane
-            tomosum = 0
-            [tomosum := tomosum+np.sum(tomo_recon_stack, axis=(0,2)) for tomo_recon_stack in
-                    self.tomo_recon_stacks]
-            quickPlot(tomosum, title='recon stack sum yz', path='center_slice_pngs',
-                    save_fig=True, save_only=True)
-
-            # Create cross section profile in xz-plane
-            tomosum = 0
-            [tomosum := tomosum+np.sum(tomo_recon_stack, axis=(0,1)) for tomo_recon_stack in
-                    self.tomo_recon_stacks]
-            quickPlot(tomosum, title='recon stack sum xz', path='center_slice_pngs',
-                    save_fig=True, save_only=True)
-
-            # Create cross section profile in xy-plane
-            num_tomo_stacks = len(stacks)
-            row_bounds = self.config['find_center']['row_bounds']
-            if not is_index_range(row_bounds, 0, self.tomo_recon_stacks[0].shape[0]):
-                illegal_value(row_bounds, 'find_center:row_bounds', 'config file')
-                return
-            if num_tomo_stacks == 1:
-                low_bound = row_bounds[0]
-            else:
-                low_bound = 0
-            tomosum = np.sum(self.tomo_recon_stacks[0][low_bound:row_bounds[1],:,:], axis=(1,2))
-            if num_tomo_stacks > 2:
-                tomosum = np.concatenate([tomosum]+
-                        [np.sum(self.tomo_recon_stacks[i][row_bounds[0]:row_bounds[1],:,:],
-                        axis=(1,2)) for i in range(1, num_tomo_stacks-1)])
-            if num_tomo_stacks > 1:
-                tomosum = np.concatenate([tomosum,
-                        np.sum(self.tomo_recon_stacks[num_tomo_stacks-1][row_bounds[0]:,:,:],
-                        axis=(1,2))])
-            quickPlot(tomosum, title='recon stack sum xy', path='center_slice_pngs',
-                    save_fig=True, save_only=True)
-
-    def combineTomoStacks(self, galaxy_param=None):
-        """Combine the reconstructed tomography stacks.
-        """
-        # stack order: stack,row(z),x,y
-        if self.galaxy_flag:
-            assert(galaxy_param)
-            if not os.path.exists('combine_pngs'):
-                os.mkdir('combine_pngs')
-        logging.debug('Combine reconstructed tomography stacks')
-        stack_info = self.config['stack_info']
-        stacks = stack_info['stacks']
-        assert(len(self.tomo_recon_stacks) == stack_info['num'])
-        assert(len(self.tomo_recon_stacks) == len(stacks))
-        if self.galaxy_flag:
-            assert(isinstance(galaxy_param, dict))
-            # Get image bounds
-            x_bounds = galaxy_param['x_bounds']
-            assert(isinstance(x_bounds, list) and len(x_bounds) == 2)
-            y_bounds = galaxy_param['y_bounds']
-            assert(isinstance(y_bounds, list) and len(y_bounds) == 2)
-            z_bounds = galaxy_param['z_bounds']
-            assert(isinstance(z_bounds, list) and len(z_bounds) == 2)
-        else:
-            if galaxy_param:
-                logging.warning('Ignoring galaxy_param in combineTomoStacks (only for Galaxy)')
-                galaxy_param = None
-
-        # Load any unloaded reconstructed stacks
-        for i,stack in enumerate(stacks):
-            available = False
-            if not self.tomo_recon_stacks[i].size and stack.get('reconstructed', False):
-                self.tomo_recon_stacks[i], available = self._loadTomo('recon stack',
-                        stack['index'], required=True)
-            if not self.tomo_recon_stacks[i].size:
-                logging.error(f'Unable to load reconstructed stack {stack["index"]}')
-                stack['reconstructed'] = False
-                return
-            if i:
-                if (self.tomo_recon_stacks[i].shape[1] != self.tomo_recon_stacks[0].shape[1] or
-                        self.tomo_recon_stacks[i].shape[2] != self.tomo_recon_stacks[0].shape[2]):
-                    logging.error('Incompatible reconstructed tomography stack dimensions')
-                    return
-
-        # Get center stack boundaries
-        row_bounds = self.config['find_center']['row_bounds']
-        if not is_index_range(row_bounds, 0, self.tomo_recon_stacks[0].shape[0]):
-            illegal_value(row_bounds, 'find_center:row_bounds', 'config file')
-            return
-
-        # Selecting x bounds (in yz-plane)
-        tomosum = 0
-        #RV FIX :=
-        [tomosum := tomosum+np.sum(tomo_recon_stack, 
axis=(0,2)) for tomo_recon_stack in - self.tomo_recon_stacks] - combine_stacks = self.config.get('combine_stacks') - if self.galaxy_flag: - if x_bounds[0] == -1: - x_bounds[0] = 0 - if x_bounds[1] == -1: - x_bounds[1] = tomosum.size - if not is_index_range(x_bounds, 0, tomosum.size): - illegal_value(x_bounds, 'x_bounds', 'galaxy input') - tomosum_min = tomosum.min() - tomosum_max = tomosum.max() - quickPlot((range(tomosum.size), tomosum), - ([x_bounds[0], x_bounds[0]], [tomosum_min, tomosum_max], 'r-'), - ([x_bounds[1]-1, x_bounds[1]-1], [tomosum_min, tomosum_max], 'r-'), - title=f'recon stack sum yz', path='combine_pngs', save_fig=True, save_only=True) - else: - x_bounds = None - change_x_bounds = 'y' - if combine_stacks and 'x_bounds' in combine_stacks: - x_bounds = combine_stacks['x_bounds'] - if is_index_range(x_bounds, 0, tomosum.size): - if not self.test_mode: - quickPlot(tomosum, vlines=x_bounds, title='recon stack sum yz') - print(f'x_bounds = {x_bounds} (lower bound inclusive, upper bound '+ - 'exclusive)') - change_x_bounds = 'n' - else: - illegal_value(x_bounds, 'combine_stacks:x_bounds', 'config file') - x_bounds = None - if self.test_mode: - if x_bounds is None: - x_bounds = [0, tomosum.size] - else: - if not input_yesno('\nDo you want to change the image x-bounds (y/n)?', - change_x_bounds): - if x_bounds is None: - x_bounds = [0, tomosum.size] - else: - accept = False - if x_bounds is None: - index_ranges = None - else: - index_ranges = [x_bounds] - while not accept: - mask, x_bounds = draw_mask_1d(tomosum, current_index_ranges=index_ranges, - title='select x data range', legend='recon stack sum yz') - while len(x_bounds) != 1: - print('Please select exactly one continuous range') - mask, x_bounds = draw_mask_1d(tomosum, title='select x data range', - legend='recon stack sum yz') - x_bounds = list(x_bounds[0]) - quickPlot(tomosum, vlines=x_bounds, title='recon stack sum yz') - print(f'x_bounds = {x_bounds} (lower bound inclusive, upper bound '+ - 'exclusive)') - accept = input_yesno('Accept these bounds (y/n)?', 'y') - if self.save_plots_only: - clearPlot('recon stack sum yz') - logging.info(f'x_bounds = {x_bounds}') - - # Selecting y bounds (in xz-plane) - tomosum = 0 - #RV FIX := - [tomosum := tomosum+np.sum(tomo_recon_stack, axis=(0,1)) for tomo_recon_stack in - self.tomo_recon_stacks] - if self.galaxy_flag: - if y_bounds[0] == -1: - y_bounds[0] = 0 - if y_bounds[1] == -1: - y_bounds[1] = tomosum.size - if not is_index_range(y_bounds, 0, tomosum.size): - illegal_value(y_bounds, 'y_bounds', 'galaxy input') - tomosum_min = tomosum.min() - tomosum_max = tomosum.max() - quickPlot((range(tomosum.size), tomosum), - ([y_bounds[0], y_bounds[0]], [tomosum_min, tomosum_max], 'r-'), - ([y_bounds[1]-1, y_bounds[1]-1], [tomosum_min, tomosum_max], 'r-'), - title=f'recon stack sum xz', path='combine_pngs', save_fig=True, save_only=True) - else: - y_bounds = None - change_y_bounds = 'y' - if combine_stacks and 'y_bounds' in combine_stacks: - y_bounds = combine_stacks['y_bounds'] - if is_index_range(y_bounds, 0, tomosum.size): - if not self.test_mode: - quickPlot(tomosum, vlines=y_bounds, title='recon stack sum xz') - print(f'y_bounds = {y_bounds} (lower bound inclusive, upper bound '+ - 'exclusive)') - change_y_bounds = 'n' - else: - illegal_value(y_bounds, 'combine_stacks:y_bounds', 'config file') - y_bounds = None - if self.test_mode: - if y_bounds is None: - y_bounds = [0, tomosum.size] - else: - if not input_yesno('\nDo you want to change the image y-bounds (y/n)?', - 
change_y_bounds): - if y_bounds is None: - y_bounds = [0, tomosum.size] - else: - accept = False - if y_bounds is None: - index_ranges = None - else: - index_ranges = [y_bounds] - while not accept: - mask, y_bounds = draw_mask_1d(tomosum, current_index_ranges=index_ranges, - title='select x data range', legend='recon stack sum xz') - while len(y_bounds) != 1: - print('Please select exactly one continuous range') - mask, y_bounds = draw_mask_1d(tomosum, title='select x data range', - legend='recon stack sum xz') - y_bounds = list(y_bounds[0]) - quickPlot(tomosum, vlines=y_bounds, title='recon stack sum xz') - print(f'y_bounds = {y_bounds} (lower bound inclusive, upper bound '+ - 'exclusive)') - accept = input_yesno('Accept these bounds (y/n)?', 'y') - if self.save_plots_only: - clearPlot('recon stack sum xz') - logging.info(f'y_bounds = {y_bounds}') - - # Combine reconstructed tomography stacks - logging.info(f'Combining reconstructed stacks ...') - t0 = time() - num_tomo_stacks = len(stacks) - if num_tomo_stacks == 1: - low_bound = row_bounds[0] - else: - low_bound = 0 - tomo_recon_combined = self.tomo_recon_stacks[0][low_bound:row_bounds[1], - x_bounds[0]:x_bounds[1],y_bounds[0]:y_bounds[1]] - if num_tomo_stacks > 2: - tomo_recon_combined = np.concatenate([tomo_recon_combined]+ - [self.tomo_recon_stacks[i][row_bounds[0]:row_bounds[1], - x_bounds[0]:x_bounds[1],y_bounds[0]:y_bounds[1]] - for i in range(1, num_tomo_stacks-1)]) - if num_tomo_stacks > 1: - tomo_recon_combined = np.concatenate([tomo_recon_combined]+ - [self.tomo_recon_stacks[num_tomo_stacks-1][row_bounds[0]:, - x_bounds[0]:x_bounds[1],y_bounds[0]:y_bounds[1]]]) - logging.info(f'... done in {time()-t0:.2f} seconds!') - combined_stacks = [stack['index'] for stack in stacks] - - # Selecting z bounds (in xy-plane) - tomosum = np.sum(tomo_recon_combined, axis=(1,2)) - if self.galaxy_flag: - if z_bounds[0] == -1: - z_bounds[0] = 0 - if z_bounds[1] == -1: - z_bounds[1] = tomosum.size - if not is_index_range(z_bounds, 0, tomosum.size): - illegal_value(z_bounds, 'z_bounds', 'galaxy input') - tomosum_min = tomosum.min() - tomosum_max = tomosum.max() - quickPlot((range(tomosum.size), tomosum), - ([z_bounds[0], z_bounds[0]], [tomosum_min, tomosum_max], 'r-'), - ([z_bounds[1]-1, z_bounds[1]-1], [tomosum_min, tomosum_max], 'r-'), - title=f'recon stack sum xy', path='combine_pngs', save_fig=True, save_only=True) - else: - z_bounds = None - if combine_stacks and 'z_bounds' in combine_stacks: - z_bounds = combine_stacks['z_bounds'] - if is_index_range(z_bounds, 0, tomosum.size): - if not self.test_mode: - quickPlot(tomosum, vlines=z_bounds, title='recon stack sum xy') - print(f'z_bounds = {z_bounds} (lower bound inclusive, upper bound '+ - 'exclusive)') - else: - illegal_value(z_bounds, 'combine_stacks:z_bounds', 'config file') - z_bounds = None - if self.test_mode: - if z_bounds is None: - z_bounds = [0, tomosum.size] - else: - if not input_yesno('\nDo you want to change the image z-bounds (y/n)?', 'n'): - if z_bounds is None: - z_bounds = [0, tomosum.size] - else: - accept = False - if z_bounds is None: - index_ranges = None - else: - index_ranges = [z_bounds] - while not accept: - mask, z_bounds = draw_mask_1d(tomosum, current_index_ranges=index_ranges, - title='select x data range', legend='recon stack sum xy') - while len(z_bounds) != 1: - print('Please select exactly one continuous range') - mask, z_bounds = draw_mask_1d(tomosum, title='select x data range', - legend='recon stack sum xy') - z_bounds = list(z_bounds[0]) - 
quickPlot(tomosum, vlines=z_bounds, title='recon stack sum xy') - print(f'z_bounds = {z_bounds} (lower bound inclusive, upper bound '+ - 'exclusive)') - accept = input_yesno('Accept these bounds (y/n)?', 'y') - if self.save_plots_only: - clearPlot('recon stack sum xy') - logging.info(f'z_bounds = {z_bounds}') - - # Plot center slices - if self.galaxy_flag: - path = 'combine_pngs' - save_fig = True - save_only = True - else: - path = self.output_folder - save_fig = self.save_plots - save_only = self.save_plots_only - quickImshow(tomo_recon_combined[int(tomo_recon_combined.shape[0]/2),:,:], - title=f'recon combined xslice{int(tomo_recon_combined.shape[0]/2)}', - path=path, save_fig=save_fig, save_only=save_only) - quickImshow(tomo_recon_combined[:,int(tomo_recon_combined.shape[1]/2),:], - title=f'recon combined yslice{int(tomo_recon_combined.shape[1]/2)}', - path=path, save_fig=save_fig, save_only=save_only) - quickImshow(tomo_recon_combined[:,:,int(tomo_recon_combined.shape[2]/2)], - title=f'recon combined zslice{int(tomo_recon_combined.shape[2]/2)}', - path=path, save_fig=save_fig, save_only=save_only) - - # Save combined reconstructed tomography stack or test mode data to file - if self.galaxy_flag: - t0 = time() - output_name = galaxy_param['output_name'] - logging.info(f'Saving combined reconstructed tomography stack to {output_name} ...') - np.save(output_name, tomo_recon_combined) - logging.info(f'... done in {time()-t0:.2f} seconds!') - elif self.test_mode: - zoom_perc = self.config['preprocess'].get('zoom_perc', 100) - filename = 'recon combined sum xy' - if zoom_perc is None or zoom_perc == 100: - filename += ' fullres.dat' - else: - filename += f' {zoom_perc}p.dat' - quickPlot(tomosum, title='recon combined sum xy', - path=self.output_folder, save_fig=self.save_plots, - save_only=self.save_plots_only) - np.savetxt(f'{self.output_folder}/recon_combined.txt', - tomo_recon_combined[int(tomosum.size/2),:,:], fmt='%.6e') - else: - base_name = 'recon combined' - for stack in stacks: - base_name += f' {stack["index"]}' - self._saveTomo(base_name, tomo_recon_combined) - - # Update config and save to file - if combine_stacks: - combine_stacks['x_bounds'] = x_bounds - combine_stacks['y_bounds'] = y_bounds - combine_stacks['z_bounds'] = z_bounds - combine_stacks['stacks'] = combined_stacks - else: - self.config['combine_stacks'] = {'x_bounds' : x_bounds, 'y_bounds' : y_bounds, - 'z_bounds' : z_bounds, 'stacks' : combined_stacks} - self.cf.saveFile(self.config_out) - -def runTomo(config_file=None, config_dict=None, output_folder='.', log_level='INFO', - test_mode=False, num_core=-1): - """Run a tomography analysis. 
- """ - # Instantiate Tomo object - tomo = Tomo(config_file=config_file, output_folder=output_folder, log_level=log_level, - test_mode=test_mode, num_core=num_core) - if not tomo.is_valid: - raise ValueError('Invalid config and/or detector file provided.') - - # Preprocess the image files - assert(tomo.config['stack_info']) - num_tomo_stacks = tomo.config['stack_info']['num'] - assert(num_tomo_stacks == len(tomo.tomo_stacks)) - preprocessed_stacks = [] - if not tomo.test_mode: - preprocess = tomo.config.get('preprocess', None) - if preprocess: - preprocessed_stacks = [stack['index'] for stack in tomo.config['stack_info']['stacks'] - if stack.get('preprocessed', False)] - if len(preprocessed_stacks) != num_tomo_stacks: - tomo.genTomoStacks() - if not tomo.is_valid: - IOError('Unable to load all required image files.') - tomo.cf.saveFile(tomo.config_out) - - # Find centers - find_center = tomo.config.get('find_center') - if find_center is None or not find_center.get('completed', False): - tomo.findCenters() - - # Check centers - #if num_tomo_stacks > 1 and not tomo.config.get('check_centers', False): - # tomo.checkCenters() - - # Reconstruct tomography stacks - assert(tomo.config['stack_info']['stacks']) - reconstructed_stacks = [stack['index'] for stack in tomo.config['stack_info']['stacks'] - if stack.get('reconstructed', False)] - if len(reconstructed_stacks) != num_tomo_stacks: - tomo.reconstructTomoStacks() - - # Combine reconstructed tomography stacks - reconstructed_stacks = [stack['index'] for stack in tomo.config['stack_info']['stacks'] - if stack.get('reconstructed', False)] - combine_stacks = tomo.config.get('combine_stacks') - if len(reconstructed_stacks) and (combine_stacks is None or - combine_stacks.get('stacks') != reconstructed_stacks): - tomo.combineTomoStacks() - -#%%============================================================================ -if __name__ == '__main__': - - # Parse command line arguments - parser = argparse.ArgumentParser( - description='Tomography reconstruction') - parser.add_argument('-c', '--config', - default=None, - help='Input config') - parser.add_argument('-o', '--output_folder', - default='.', - help='Output folder') - parser.add_argument('-l', '--log_level', - default='INFO', - help='Log level') - parser.add_argument('-t', '--test_mode', - action='store_true', - default=False, - help='Test mode flag') - parser.add_argument('--num_core', - type=int, - default=-1, - help='Number of cores') - args = parser.parse_args() - - if args.config is None: - if os.path.isfile('config.yaml'): - args.config = 'config.yaml' - else: - args.config = 'config.txt' - - # Set basic log configuration - logging_format = '%(asctime)s : %(levelname)s - %(module)s : %(funcName)s - %(message)s' - if not args.test_mode: - level = getattr(logging, args.log_level.upper(), None) - if not isinstance(level, int): - raise ValueError(f'Invalid log_level: {args.log_level}') - logging.basicConfig(format=logging_format, level=level, force=True, - handlers=[logging.StreamHandler()]) - - logging.debug(f'config = {args.config}') - logging.debug(f'output_folder = {args.output_folder}') - logging.debug(f'log_level = {args.log_level}') - logging.debug(f'test_mode = {args.test_mode}') - logging.debug(f'num_core = {args.num_core}') - - # Run tomography analysis - runTomo(config_file=args.config, output_folder=args.output_folder, log_level=args.log_level, - test_mode=args.test_mode, num_core=args.num_core) - -#%%============================================================================ -# 
input('Press any key to continue') -#%%============================================================================
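Review note: the cross-section profiles in the deleted combineTomoStacks above are accumulated through a side-effecting list comprehension (the walrus pattern the author flagged with `#RV FIX :=`). A minimal sketch of the same accumulation without the side effect, assuming each reconstructed stack is a 3-D numpy array ordered row(z), x, y:

    import numpy as np

    # Stand-in data; the real code iterates over self.tomo_recon_stacks.
    tomo_recon_stacks = [np.ones((8, 16, 16)), np.ones((8, 16, 16))]

    # Profile in the yz-plane: sum out axes 0 and 2 of each stack, then
    # accumulate across stacks (sum() starts from 0, and 0 + array is valid).
    tomosum = sum(np.sum(stack, axis=(0, 2)) for stack in tomo_recon_stacks)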
--- a/tomo_combine.py Fri Aug 19 20:16:56 2022 +0000 +++ b/tomo_combine.py Fri Mar 10 16:02:04 2023 +0000 @@ -2,81 +2,85 @@ import logging +import argparse +import pathlib import sys -import argparse -import tracemalloc +#import tracemalloc -from tomo import Tomo +from workflow.run_tomo import Tomo +#from memory_profiler import profile +#@profile def __main__(): - # Parse command line arguments parser = argparse.ArgumentParser( - description='Combine reconstructed tomography stacks') - parser.add_argument('-i', '--input_stacks', - help='Reconstructed image file stacks') - parser.add_argument('-c', '--config', - help='Input config') - parser.add_argument('--x_bounds', - required=True, nargs=2, type=int, help='Reconstructed range in x direction') - parser.add_argument('--y_bounds', - required=True, nargs=2, type=int, help='Reconstructed range in y direction') - parser.add_argument('--z_bounds', - required=True, nargs=2, type=int, help='Reconstructed range in z direction') - parser.add_argument('--output_config', - help='Output config') - parser.add_argument('--output_data', - help='Combined tomography stacks') - parser.add_argument('-l', '--log', - type=argparse.FileType('w'), + description='Reduce tomography data') + parser.add_argument('-i', '--input_file', + required=True, + type=pathlib.Path, + help='''Full or relative path to the input file (in Nexus format).''') + parser.add_argument('-o', '--output_file', + required=False, + type=pathlib.Path, + help='''Full or relative path to the output file (in yaml format).''') + parser.add_argument('-l', '--log', +# type=argparse.FileType('w'), default=sys.stdout, - help='Log file') + help='Logging stream or filename') + parser.add_argument('--log_level', + choices=logging._nameToLevel.keys(), + default='INFO', + help='''Specify a preferred logging level.''') args = parser.parse_args() + # Set log configuration + # When logging to file, the stdout log level defaults to WARNING + logging_format = '%(asctime)s : %(levelname)s - %(module)s : %(funcName)s - %(message)s' + level = logging.getLevelName(args.log_level) + if args.log is sys.stdout: + logging.basicConfig(format=logging_format, level=level, force=True, + handlers=[logging.StreamHandler()]) + else: + if isinstance(args.log, str): + logging.basicConfig(filename=f'{args.log}', filemode='w', + format=logging_format, level=level, force=True) + elif isinstance(args.log, io.TextIOWrapper): + logging.basicConfig(filemode='w', format=logging_format, level=level, + stream=args.log, force=True) + else: + raise(ValueError(f'Invalid argument --log: {args.log}')) + stream_handler = logging.StreamHandler() + logging.getLogger().addHandler(stream_handler) + stream_handler.setLevel(logging.WARNING) + stream_handler.setFormatter(logging.Formatter(logging_format)) + # Starting memory monitoring - tracemalloc.start() +# tracemalloc.start() - # Set basic log configuration - logging_format = '%(asctime)s : %(levelname)s - %(module)s : %(funcName)s - %(message)s' - log_level = 'INFO' - level = getattr(logging, log_level.upper(), None) - if not isinstance(level, int): - raise ValueError(f'Invalid log_level: {log_level}') - logging.basicConfig(format=logging_format, level=level, force=True, - handlers=[logging.StreamHandler()]) - - logging.debug(f'config = {args.config}') - logging.debug(f'input_stacks = {args.input_stacks}') - logging.debug(f'x_bounds = {args.x_bounds} {type(args.x_bounds)}') - logging.debug(f'y_bounds = {args.y_bounds} {type(args.y_bounds)}') - logging.debug(f'z_bounds = {args.z_bounds} 
{type(args.z_bounds)}') - logging.debug(f'output_config = {args.output_config}') - logging.debug(f'output_data = {args.output_data}') + # Log command line arguments + logging.info(f'input_file = {args.input_file}') + logging.info(f'output_file = {args.output_file}') logging.debug(f'log = {args.log}') logging.debug(f'is log stdout? {args.log is sys.stdout}') + logging.debug(f'log_level = {args.log_level}') # Instantiate Tomo object - tomo = Tomo(config_file=args.config, config_out=args.output_config, log_level=log_level, - log_stream=args.log, galaxy_flag=True) - if not tomo.is_valid: - raise ValueError('Invalid config file provided.') - logging.debug(f'config:\n{tomo.config}') + tomo = Tomo() + + # Read input file + data = tomo.read(args.input_file) - # Load reconstructed image files - tomo.loadTomoStacks(args.input_stacks, recon_flag=True) + # Combine the reconstructed tomography stacks + data = tomo.combine_data(data) - # Combined reconstructed tomography stacks - galaxy_param = {'x_bounds' : args.x_bounds, 'y_bounds' : args.y_bounds, - 'z_bounds' : args.z_bounds, 'output_name' : args.output_data} - logging.debug(f'galaxy_param = {galaxy_param}') - tomo.combineTomoStacks(galaxy_param) + # Write output file + data = tomo.write(data, args.output_file) # Displaying memory usage - logging.info(f'Memory usage: {tracemalloc.get_traced_memory()}') - +# logging.info(f'Memory usage: {tracemalloc.get_traced_memory()}') + # stopping memory monitoring - tracemalloc.stop() +# tracemalloc.stop() if __name__ == "__main__": __main__() -
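Review note: the four new driver scripts (tomo_combine.py, tomo_find_center.py, tomo_reconstruct.py and tomo_reduce.py) all use the argparse description 'Reduce tomography data' (a copy-paste that only fits tomo_reduce.py) and repeat the same logging block, whose file branch tests isinstance(args.log, io.TextIOWrapper) without importing io. A sketch of a shared helper the drivers could import instead; the helper name and its placement are assumptions:

    import io
    import logging
    import sys

    def setup_logging(log, log_level='INFO'):
        """Hypothetical shared helper mirroring the inline blocks above.

        When logging to a file, the stdout handler defaults to WARNING.
        """
        logging_format = ('%(asctime)s : %(levelname)s - %(module)s : '
                          '%(funcName)s - %(message)s')
        level = logging.getLevelName(log_level)
        if log is sys.stdout:
            logging.basicConfig(format=logging_format, level=level, force=True,
                                handlers=[logging.StreamHandler()])
            return
        if isinstance(log, str):
            logging.basicConfig(filename=log, filemode='w',
                                format=logging_format, level=level, force=True)
        elif isinstance(log, io.TextIOWrapper):
            logging.basicConfig(filemode='w', format=logging_format,
                                level=level, stream=log, force=True)
        else:
            raise ValueError(f'Invalid argument --log: {log}')
        # Mirror warnings and errors to the console as well.
        stream_handler = logging.StreamHandler()
        stream_handler.setLevel(logging.WARNING)
        stream_handler.setFormatter(logging.Formatter(logging_format))
        logging.getLogger().addHandler(stream_handler)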
--- a/tomo_find_center.py Fri Aug 19 20:16:56 2022 +0000 +++ b/tomo_find_center.py Fri Mar 10 16:02:04 2023 +0000 @@ -2,85 +2,95 @@ import logging +import argparse +import pathlib import sys -import argparse -import tracemalloc +#import tracemalloc -from tomo import Tomo +from workflow.run_tomo import Tomo +#from memory_profiler import profile +#@profile def __main__(): - # Parse command line arguments parser = argparse.ArgumentParser( - description='Find the center axis for a tomography reconstruction') - parser.add_argument('-i', '--input_stacks', - required=True, help='Preprocessed image file stacks') - parser.add_argument('-c', '--config', - required=True, help='Input config') - parser.add_argument('--row_bounds', - required=True, nargs=2, type=int, help='Reconstruction row bounds') + description='Reduce tomography data') + parser.add_argument('-i', '--input_file', + required=True, + type=pathlib.Path, + help='''Full or relative path to the input file (in Nexus format).''') + parser.add_argument('-o', '--output_file', + required=False, + type=pathlib.Path, + help='''Full or relative path to the output file (in yaml format).''') parser.add_argument('--center_rows', - required=True, nargs=2, type=int, help='Center finding rows') - parser.add_argument('--center_type_selector', - help='Reconstruct slices for a set of center positions?') - parser.add_argument('--set_center', - type=int, help='Set center ') - parser.add_argument('--set_range', - type=float, help='Set range') - parser.add_argument('--set_step', - type=float, help='Set step') - parser.add_argument('--output_config', - required=True, help='Output config') - parser.add_argument('-l', '--log', - type=argparse.FileType('w'), default=sys.stdout, help='Log file') + required=True, + nargs=2, + type=int, + help='''Center finding rows.''') + parser.add_argument('--galaxy_flag', + action='store_true', + help='''Use this flag to run the scripts as a galaxy tool.''') + parser.add_argument('-l', '--log', +# type=argparse.FileType('w'), + default=sys.stdout, + help='Logging stream or filename') + parser.add_argument('--log_level', + choices=logging._nameToLevel.keys(), + default='INFO', + help='''Specify a preferred logging level.''') args = parser.parse_args() + # Set log configuration + # When logging to file, the stdout log level defaults to WARNING + logging_format = '%(asctime)s : %(levelname)s - %(module)s : %(funcName)s - %(message)s' + level = logging.getLevelName(args.log_level) + if args.log is sys.stdout: + logging.basicConfig(format=logging_format, level=level, force=True, + handlers=[logging.StreamHandler()]) + else: + if isinstance(args.log, str): + logging.basicConfig(filename=f'{args.log}', filemode='w', + format=logging_format, level=level, force=True) + elif isinstance(args.log, io.TextIOWrapper): + logging.basicConfig(filemode='w', format=logging_format, level=level, + stream=args.log, force=True) + else: + raise(ValueError(f'Invalid argument --log: {args.log}')) + stream_handler = logging.StreamHandler() + logging.getLogger().addHandler(stream_handler) + stream_handler.setLevel(logging.WARNING) + stream_handler.setFormatter(logging.Formatter(logging_format)) + # Starting memory monitoring - tracemalloc.start() +# tracemalloc.start() - # Set basic log configuration - logging_format = '%(asctime)s : %(levelname)s - %(module)s : %(funcName)s - %(message)s' - log_level = 'INFO' - level = getattr(logging, log_level.upper(), None) - if not isinstance(level, int): - raise ValueError(f'Invalid log_level: {log_level}') - 
logging.basicConfig(format=logging_format, level=level, force=True, - handlers=[logging.StreamHandler()]) - - logging.debug(f'config = {args.config}') - logging.debug(f'input_stacks = {args.input_stacks}') - logging.debug(f'row_bounds = {args.row_bounds} {type(args.row_bounds)}') - logging.debug(f'center_rows = {args.center_rows} {type(args.center_rows)}') - logging.debug(f'center_type_selector = {args.center_type_selector}') - logging.debug(f'set_center = {args.set_center}') - logging.debug(f'set_range = {args.set_range}') - logging.debug(f'set_step = {args.set_step}') - logging.debug(f'output_config = {args.output_config}') + # Log command line arguments + logging.info(f'input_file = {args.input_file}') + logging.info(f'output_file = {args.output_file}') + logging.info(f'center_rows = {args.center_rows}') + logging.info(f'galaxy_flag = {args.galaxy_flag}') logging.debug(f'log = {args.log}') logging.debug(f'is log stdout? {args.log is sys.stdout}') + logging.debug(f'log_level = {args.log_level}') # Instantiate Tomo object - tomo = Tomo(config_file=args.config, config_out=args.output_config, log_level=log_level, - log_stream=args.log, galaxy_flag=True) - if not tomo.is_valid: - raise ValueError('Invalid config file provided.') - logging.debug(f'config:\n{tomo.config}') + tomo = Tomo(galaxy_flag=args.galaxy_flag) + + # Read input file + data = tomo.read(args.input_file) - # Load preprocessed image files - tomo.loadTomoStacks(args.input_stacks) + # Find the calibrated center axis info + data = tomo.find_centers(data, center_rows=tuple(args.center_rows)) - # Find centers - galaxy_param = {'row_bounds' : args.row_bounds, 'center_rows' : args.center_rows, - 'center_type_selector' : args.center_type_selector, 'set_center' : args.set_center, - 'set_range' : args.set_range, 'set_step' : args.set_step} - tomo.findCenters(galaxy_param) + # Write output file + data = tomo.write(data, args.output_file) # Displaying memory usage - logging.info(f'Memory usage: {tracemalloc.get_traced_memory()}') - +# logging.info(f'Memory usage: {tracemalloc.get_traced_memory()}') + # stopping memory monitoring - tracemalloc.stop() +# tracemalloc.stop() if __name__ == "__main__": __main__() -
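Review note: the --log_level choices are built from logging._nameToLevel, a private CPython mapping with no stability guarantee. A version-safe sketch using an explicit tuple of the standard level names (the tuple itself is an assumption, mirroring the stock levels):

    import argparse
    import logging

    LOG_LEVELS = ('CRITICAL', 'ERROR', 'WARNING', 'INFO', 'DEBUG', 'NOTSET')

    parser = argparse.ArgumentParser(description='Find rotation axis centers')
    parser.add_argument('--log_level',
        choices=LOG_LEVELS,
        default='INFO',
        help='Specify a preferred logging level.')

    args = parser.parse_args(['--log_level', 'DEBUG'])
    assert logging.getLevelName(args.log_level) == logging.DEBUG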
--- a/tomo_macros.xml	Fri Aug 19 20:16:56 2022 +0000
+++ b/tomo_macros.xml	Fri Mar 10 16:02:04 2023 +0000
@@ -1,24 +1,27 @@
 <macros>
     <xml name="requirements">
         <requirements>
+            <requirement type="package" version="1.0.3">lmfit</requirement>
+            <requirement type="package" version="1.0.0">nexusformat</requirement>
             <requirement type="package" version="1.11.0">tomopy</requirement>
-            <requirement type="package" version="3.6.0">h5py</requirement>
-            <requirement type="package" version="1.0.3">lmfit</requirement>
         </requirements>
     </xml>
     <xml name="citations">
         <citations>
             <citation type="bibtex">
-@misc{githubsum_files,
+@misc{github_files,
 author = {Verberg, Rolf},
 year = {2022},
 title = {Tomo Reconstruction},
 }</citation>
         </citations>
     </xml>
+    <!--
     <xml name="common_inputs">
-        <param name="config" type='data' format='tomo.config.yaml' optional='false' label="Input config"/>
+        <param name="config" type='data' format='yaml' optional='true' label="Input config"/>
+        <param name="config" type='data' format='tomo.config.yaml' optional='true' label="Input config"/>
     </xml>
+    -->
     <xml name="common_outputs">
         <data name="log" format="txt" label="Log"/>
     </xml>
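Review note: the requirements macro now pins lmfit 1.0.3 and nexusformat 1.0.0 alongside tomopy 1.11.0 and drops the h5py pin. A quick sanity check that a deployed environment matches these pins, using only the standard library:

    from importlib.metadata import version

    PINNED = {'lmfit': '1.0.3', 'nexusformat': '1.0.0', 'tomopy': '1.11.0'}
    for package, pin in PINNED.items():
        installed = version(package)
        assert installed == pin, f'{package}: installed {installed}, pinned {pin}'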
--- a/tomo_reconstruct.py Fri Aug 19 20:16:56 2022 +0000 +++ b/tomo_reconstruct.py Fri Mar 10 16:02:04 2023 +0000 @@ -2,84 +2,109 @@ import logging +import argparse +import pathlib import sys -import argparse -import tracemalloc +#import tracemalloc -from tomo import Tomo +from workflow.run_tomo import Tomo +#from memory_profiler import profile +#@profile def __main__(): - # Parse command line arguments parser = argparse.ArgumentParser( - description='Perfrom a tomography reconstruction') - parser.add_argument('-i', '--input_stacks', - help='Preprocessed image file stacks') - parser.add_argument('-c', '--config', - help='Input config') - parser.add_argument('--center_offsets', - nargs=2, type=float, help='Reconstruction center axis offsets') - parser.add_argument('--output_config', - help='Output config') - parser.add_argument('--output_data', - help='Reconstructed tomography data') - parser.add_argument('-l', '--log', - type=argparse.FileType('w'), + description='Reduce tomography data') + parser.add_argument('-i', '--input_file', + required=True, + type=pathlib.Path, + help='''Full or relative path to the input file (in Nexus format).''') + parser.add_argument('-c', '--center_file', + required=True, + type=pathlib.Path, + help='''Full or relative path to the center info file (in Nexus format).''') + parser.add_argument('-o', '--output_file', + required=False, + type=pathlib.Path, + help='''Full or relative path to the output file (in yaml format).''') + parser.add_argument('--galaxy_flag', + action='store_true', + help='''Use this flag to run the scripts as a galaxy tool.''') + parser.add_argument('-l', '--log', +# type=argparse.FileType('w'), default=sys.stdout, - help='Log file') + help='Logging stream or filename') + parser.add_argument('--log_level', + choices=logging._nameToLevel.keys(), + default='INFO', + help='''Specify a preferred logging level.''') + parser.add_argument('--x_bounds', + required=False, + nargs=2, + type=int, + help='''Boundaries of reconstructed images in x-direction.''') + parser.add_argument('--y_bounds', + required=False, + nargs=2, + type=int, + help='''Boundaries of reconstructed images in y-direction.''') args = parser.parse_args() + # Set log configuration + # When logging to file, the stdout log level defaults to WARNING + logging_format = '%(asctime)s : %(levelname)s - %(module)s : %(funcName)s - %(message)s' + level = logging.getLevelName(args.log_level) + if args.log is sys.stdout: + logging.basicConfig(format=logging_format, level=level, force=True, + handlers=[logging.StreamHandler()]) + else: + if isinstance(args.log, str): + logging.basicConfig(filename=f'{args.log}', filemode='w', + format=logging_format, level=level, force=True) + elif isinstance(args.log, io.TextIOWrapper): + logging.basicConfig(filemode='w', format=logging_format, level=level, + stream=args.log, force=True) + else: + raise(ValueError(f'Invalid argument --log: {args.log}')) + stream_handler = logging.StreamHandler() + logging.getLogger().addHandler(stream_handler) + stream_handler.setLevel(logging.WARNING) + stream_handler.setFormatter(logging.Formatter(logging_format)) + # Starting memory monitoring - tracemalloc.start() +# tracemalloc.start() - # Set basic log configuration - logging_format = '%(asctime)s : %(levelname)s - %(module)s : %(funcName)s - %(message)s' - log_level = 'INFO' - level = getattr(logging, log_level.upper(), None) - if not isinstance(level, int): - raise ValueError(f'Invalid log_level: {log_level}') - logging.basicConfig(format=logging_format, 
level=level, force=True, - handlers=[logging.StreamHandler()]) - - logging.debug(f'config = {args.config}') - logging.debug(f'input_stacks = {args.input_stacks}') - logging.debug(f'center_offsets = {args.center_offsets} {type(args.center_offsets)}') - logging.debug(f'output_config = {args.output_config}') - logging.debug(f'output_data = {args.output_data}') + # Log command line arguments + logging.info(f'input_file = {args.input_file}') + logging.info(f'center_file = {args.center_file}') + logging.info(f'output_file = {args.output_file}') + logging.info(f'galaxy_flag = {args.galaxy_flag}') logging.debug(f'log = {args.log}') logging.debug(f'is log stdout? {args.log is sys.stdout}') + logging.debug(f'log_level = {args.log_level}') + logging.info(f'x_bounds = {args.x_bounds}') + logging.info(f'y_bounds = {args.y_bounds}') # Instantiate Tomo object - tomo = Tomo(config_file=args.config, config_out=args.output_config, log_level=log_level, - log_stream=args.log, galaxy_flag=True) - if not tomo.is_valid: - raise ValueError('Invalid config file provided.') - logging.debug(f'config:\n{tomo.config}') + tomo = Tomo(galaxy_flag=args.galaxy_flag) + + # Read input file + data = tomo.read(args.input_file) - # Set reconstruction center axis offsets - if args.center_offsets is None: - find_center = tomo.config.get('find_center') - if find_center is None: - raise ValueError('Invalid config file provided (missing find_center).') - center_offsets = [float(find_center.get('lower_center_offset')), - float(find_center.get('upper_center_offset'))] - else: - center_offsets = args.center_offsets + # Read center data + center_data = tomo.read(args.center_file) - # Load preprocessed image files - tomo.loadTomoStacks(args.input_stacks) + # Find the calibrated center axis info + data = tomo.reconstruct_data(data, center_data, x_bounds=args.x_bounds, y_bounds=args.y_bounds) - # Reconstruct tomography stacks - galaxy_param = {'center_offsets' : center_offsets, 'output_name' : args.output_data} - logging.debug(f'galaxy_param = {galaxy_param}') - tomo.reconstructTomoStacks(galaxy_param) + # Write output file + data = tomo.write(data, args.output_file) # Displaying memory usage - logging.info(f'Memory usage: {tracemalloc.get_traced_memory()}') - +# logging.info(f'Memory usage: {tracemalloc.get_traced_memory()}') + # stopping memory monitoring - tracemalloc.stop() +# tracemalloc.stop() if __name__ == "__main__": __main__() -
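Review note: the comment preceding tomo.reconstruct_data above still reads '# Find the calibrated center axis info', carried over from tomo_find_center.py; the call actually performs the reconstruction. Outside Galaxy, the new read/reconstruct/write API chains as in the sketch below; the file names are placeholders and the keyword defaults follow the argparse declarations above:

    from workflow.run_tomo import Tomo

    tomo = Tomo(galaxy_flag=False)
    data = tomo.read('reduced_data.nxs')        # placeholder path
    center_data = tomo.read('center_info.nxs')  # placeholder path
    data = tomo.reconstruct_data(data, center_data, x_bounds=None, y_bounds=None)
    tomo.write(data, 'reconstructed_data.nxs')  # placeholder path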
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/tomo_reduce.py Fri Mar 10 16:02:04 2023 +0000 @@ -0,0 +1,96 @@ +#!/usr/bin/env python3 + +import logging + +import argparse +import pathlib +import sys +#import tracemalloc + +from workflow.run_tomo import Tomo + +#from memory_profiler import profile +#@profile +def __main__(): + # Parse command line arguments + parser = argparse.ArgumentParser( + description='Reduce tomography data') + parser.add_argument('-i', '--input_file', + required=True, + type=pathlib.Path, + help='''Full or relative path to the input file (in yaml format).''') + parser.add_argument('-o', '--output_file', + required=False, + type=pathlib.Path, + help='''Full or relative path to the output file (in Nexus format).''') + parser.add_argument('--galaxy_flag', + action='store_true', + help='''Use this flag to run the scripts as a galaxy tool.''') + parser.add_argument('--img_x_bounds', + required=False, + nargs=2, + type=int, + help='Vertical data reduction image range') + parser.add_argument('-l', '--log', +# type=argparse.FileType('w'), + default=sys.stdout, + help='Logging stream or filename') + parser.add_argument('--log_level', + choices=logging._nameToLevel.keys(), + default='INFO', + help='''Specify a preferred logging level.''') + args = parser.parse_args() + + # Set log configuration + # When logging to file, the stdout log level defaults to WARNING + logging_format = '%(asctime)s : %(levelname)s - %(module)s : %(funcName)s - %(message)s' + level = logging.getLevelName(args.log_level) + if args.log is sys.stdout: + logging.basicConfig(format=logging_format, level=level, force=True, + handlers=[logging.StreamHandler()]) + else: + if isinstance(args.log, str): + logging.basicConfig(filename=f'{args.log}', filemode='w', + format=logging_format, level=level, force=True) + elif isinstance(args.log, io.TextIOWrapper): + logging.basicConfig(filemode='w', format=logging_format, level=level, + stream=args.log, force=True) + else: + raise(ValueError(f'Invalid argument --log: {args.log}')) + stream_handler = logging.StreamHandler() + logging.getLogger().addHandler(stream_handler) + stream_handler.setLevel(logging.WARNING) + stream_handler.setFormatter(logging.Formatter(logging_format)) + + # Start memory monitoring +# tracemalloc.start() + + # Log command line arguments + logging.info(f'input_file = {args.input_file}') + logging.info(f'output_file = {args.output_file}') + logging.info(f'galaxy_flag = {args.galaxy_flag}') + logging.info(f'img_x_bounds = {args.img_x_bounds}') + logging.debug(f'log = {args.log}') + logging.debug(f'is log stdout? {args.log is sys.stdout}') + logging.debug(f'log_level = {args.log_level}') + + # Instantiate Tomo object + tomo = Tomo(galaxy_flag=args.galaxy_flag) + + # Read input file + data = tomo.read(args.input_file) + + # Generate reduced tomography images + data = tomo.gen_reduced_data(data, img_x_bounds=args.img_x_bounds) + + # Write output file + data = tomo.write(data, args.output_file) + + # Displaying memory usage +# logging.info(f'Memory usage: {tracemalloc.get_traced_memory()}') + + # Stop memory monitoring +# tracemalloc.stop() + +if __name__ == "__main__": + __main__()
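Review note: all of the new drivers keep the tracemalloc bookkeeping only as commented-out lines. If re-enabled around the read/reduce/write calls, the minimal working form is:

    import logging
    import tracemalloc

    tracemalloc.start()
    # ... tomo.read(...), tomo.gen_reduced_data(...), tomo.write(...) ...
    current, peak = tracemalloc.get_traced_memory()
    logging.info(f'Memory usage: current={current} B, peak={peak} B')
    tracemalloc.stop()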
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/tomo_reduce.xml	Fri Mar 10 16:02:04 2023 +0000
@@ -0,0 +1,39 @@
+<tool id="tomo_reduce" name="Tomo Reduce" version="0.1.0" python_template_version="3.9">
+    <description>Reduce tomography images</description>
+    <macros>
+        <import>tomo_macros.xml</import>
+    </macros>
+    <expand macro="requirements" />
+    <command detect_errors="exit_code">
+        <![CDATA[
+        mkdir tomo_reduce_plots;
+        $__tool_directory__/tomo_reduce.py
+        --input_file '$input_file'
+        --output_file 'output.nxs'
+        --galaxy_flag
+        --img_x_bounds $x_bounds.x_bound_low $x_bounds.x_bound_upp
+        -l '$log'
+        ]]>
+    </command>
+    <inputs>
+        <param name="input_file" type="data" optional="false" label="Input file"/>
+        <!-- x_bounds parameters consumed by the img_x_bounds option in the command above; a value of -1 selects the full range -->
+        <section name="x_bounds" title="Vertical data reduction image bounds" expanded="true">
+            <param name="x_bound_low" type="integer" value="-1" label="Lower image x-bound (-1 for full range)"/>
+            <param name="x_bound_upp" type="integer" value="-1" label="Upper image x-bound (-1 for full range)"/>
+        </section>
+    </inputs>
+    <outputs>
+        <expand macro="common_outputs"/>
+        <collection name="tomo_reduce_plots" type="list" label="Tomo data reduction images">
+            <discover_datasets pattern="__name_and_ext__" directory="tomo_reduce_plots"/>
+        </collection>
+        <data name="output_file" format="nxs" label="Reduced tomography data" from_work_dir="output.nxs"/>
+    </outputs>
+    <help>
+        <![CDATA[
+        Reduce tomography images.
+        ]]>
+    </help>
+    <expand macro="citations"/>
+</tool>
--- a/tomo_setup.py Fri Aug 19 20:16:56 2022 +0000 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,304 +0,0 @@ -#!/usr/bin/env python3 - -import logging - -import os -import sys -import re -import yaml -import argparse -import numpy as np -import tracemalloc - -from tomo import Tomo -from general import get_trailing_int - -#from memory_profiler import profile -#@profile -def __main__(): - - # Parse command line arguments - parser = argparse.ArgumentParser( - description='Setup tomography reconstruction') - parser.add_argument('--inputconfig', - default='inputconfig.txt', - help='Input config from tool form') - parser.add_argument('--inputfiles', - default='inputfiles.txt', - help='Input file collections') - parser.add_argument('-c', '--config', - help='Input config file') - parser.add_argument('--detector', - help='Detector info (number of rows and columns, and pixel size)') - parser.add_argument('--num_theta', - help='Number of theta angles') - parser.add_argument('--theta_range', - help='Theta range (lower bound, upper bound)') - parser.add_argument('--output_config', - help='Output config') - parser.add_argument('--output_data', - help='Preprocessed tomography data') - parser.add_argument('-l', '--log', - type=argparse.FileType('w'), - default=sys.stdout, - help='Log file') - args = parser.parse_args() - - # Starting memory monitoring - tracemalloc.start() - - # Set basic log configuration - logging_format = '%(asctime)s : %(levelname)s - %(module)s : %(funcName)s - %(message)s' - log_level = 'INFO' - level = getattr(logging, log_level.upper(), None) - if not isinstance(level, int): - raise ValueError(f'Invalid log_level: {log_level}') - logging.basicConfig(format=logging_format, level=level, force=True, - handlers=[logging.StreamHandler()]) - - # Check command line arguments - logging.debug(f'config = {args.config}') - if args.detector is None: - logging.debug(f'detector = {args.detector}') - else: - logging.debug(f'detector = {args.detector.split()}') - logging.debug(f'num_theta = {args.num_theta}') - if args.theta_range is None: - logging.debug(f'theta_range = {args.theta_range}') - else: - logging.debug(f'theta_range = {args.theta_range.split()}') - logging.debug(f'output_config = {args.output_config}') - logging.debug(f'output_data = {args.output_data}') - logging.debug(f'log = {args.log}') - logging.debug(f'is log stdout? 
{args.log is sys.stdout}') - if args.detector is not None and len(args.detector.split()) != 3: - raise ValueError(f'Invalid detector: {args.detector}') - if args.num_theta is None or int(args.num_theta) < 1: - raise ValueError(f'Invalid num_theta: {args.num_theta}') - if args.theta_range is not None and len(args.theta_range.split()) != 2: - raise ValueError(f'Invalid theta_range: {args.theta_range}') - num_theta = int(args.num_theta) - - # Read and check tool config input - inputconfig = [] - with open(args.inputconfig) as f: - inputconfig = [line.strip() for line in f if line.strip() and not line.startswith('#')] - assert(len(inputconfig) >= 6) - config_type = inputconfig[0] - input_type = inputconfig[1] - num_stack = int(inputconfig[2]) - stack_types = [x.strip() for x in inputconfig[3].split()] - num_imgs = [int(x.strip()) for x in inputconfig[4].split()] - img_offsets = [int(x.strip()) for x in inputconfig[5].split()] - if config_type == 'config_manual': - assert(len(inputconfig) == 7) - ref_heights = [float(x.strip()) for x in inputconfig[6].split()] - assert(args.detector is not None) - assert(args.theta_range is not None) - else: - ref_heights = None - logging.debug(f'config_type = {config_type} {type(config_type)}') - logging.debug(f'input_type = {input_type} {type(input_type)}') - logging.debug(f'num_stack = {num_stack} {type(num_stack)}') - logging.debug(f'stack_types = {stack_types} {type(stack_types)}') - logging.debug(f'num_imgs = {num_imgs} {type(num_imgs)}') - logging.debug(f'img_offsets = {img_offsets} {type(img_offsets)}') - logging.debug(f'ref_heights = {ref_heights} {type(ref_heights)}') - if config_type != 'config_file' and config_type != 'config_manual': - raise ValueError('Invalid input config provided.') - if input_type != 'collections' and input_type != 'files': - raise ValueError('Invalid input config provided.') - if len(stack_types) != num_stack: - raise ValueError('Invalid input config provided.') - if len(num_imgs) != num_stack: - raise ValueError('Invalid input config provided.') - if len(img_offsets) != num_stack: - raise ValueError('Invalid input config provided.') - if ref_heights is not None and len(ref_heights) != num_stack: - raise ValueError('Invalid input config provided.') - - # Read input files and collect data files info - datasets = [] - with open(args.inputfiles) as f: - for line in f: - if not line.strip() or line.startswith('#'): - continue - fields = [x.strip() for x in line.split('\t')] - filepath = fields[0] - element_identifier = fields[1] if len(fields) > 1 else fields[0].split('/')[-1] - datasets.append({'element_identifier' : element_identifier, 'filepath' : filepath}) - logging.debug(f'datasets:\n{datasets}') - if input_type == 'files' and len(datasets) != num_stack: - raise ValueError('Inconsistent number of input files provided.') - - # Read and sort data files - collections = [] - stack_index = 1 - for i, dataset in enumerate(datasets): - if input_type == 'collections': - element_identifier = [x.strip() for x in dataset['element_identifier'].split('_')] - if len(element_identifier) > 1: - name = element_identifier[0] - else: - name = 'other' - else: - if stack_types[i] == 'tdf' or stack_types[i] == 'tbf': - name = stack_types[i] - elif stack_types[i] == 'data': - name = f'set{stack_index}' - stack_index += 1 - else: - raise ValueError('Invalid input config provided.') - filepath = dataset['filepath'] - if not len(collections): - collections = [{'name' : name, 'filepaths' : [filepath]}] - else: - collection = [c for c in collections if 
c['name'] == name] - if len(collection): - collection[0]['filepaths'].append(filepath) - else: - collection = {'name' : name, 'filepaths' : [filepath]} - collections.append(collection) - logging.debug(f'collections:\n{collections}') - - # Instantiate Tomo object - tomo = Tomo(config_file=args.config, config_out=args.output_config, log_level=log_level, - log_stream=args.log, galaxy_flag=True) - if config_type == 'config_file': - if not tomo.is_valid: - raise ValueError('Invalid config file provided.') - else: - assert(tomo.config is None) - tomo.config = {} - logging.debug(f'config:\n{tomo.config}') - - # Set detector inputs - if config_type == 'config_manual': - detector = args.detector.split() - tomo.config['detector'] = {'rows' : int(detector[0]), - 'columns' : int(detector[1]), 'pixel_size' : float(detector[2])} - - # Set theta inputs - config_theta_range = tomo.config.get('theta_range') - if config_theta_range is None: - tomo.config['theta_range'] = {'num' : num_theta} - config_theta_range = tomo.config['theta_range'] - else: - config_theta_range['num'] = num_theta - if config_type == 'config_manual': - theta_range = args.theta_range.split() - config_theta_range['start'] = float(theta_range[0]) - config_theta_range['end'] = float(theta_range[1]) - - # Find dark field files - dark_field = tomo.config.get('dark_field') - tdf_files = [c['filepaths'] for c in collections if c['name'] == 'tdf'] - if len(tdf_files) != 1 or len(tdf_files[0]) < 1: - logging.warning('Unable to obtain dark field files') - if config_type == 'config_file': - assert(dark_field is not None) - assert(dark_field['data_path'] is None) - if dark_field.get('img_start') is None or dark_field['img_start'] != -1: - dark_field['img_start'] = -1 - if dark_field.get('num') is None or dark_field['num'] != 0: - dark_field['num'] = 0 - else: - tomo.config['dark_field'] = {'data_path' : None, 'img_start' : -1, 'num' : 0} - tdf_files = [None] - num_collections = 0 - else: - if config_type == 'config_file': - assert(dark_field is not None) - assert(dark_field['data_path'] is not None) - if dark_field.get('img_start') is None: - dark_field['img_start'] = 0 - else: - tomo.config['dark_field'] = {'data_path' : tdf_files[0], 'img_start' : 0} - dark_field = tomo.config['dark_field'] - tdf_index = [i for i,c in enumerate(collections) if c['name'] == 'tdf'] - tdf_index_check = [i for i,s in enumerate(stack_types) if s == 'tdf'] - if tdf_index != tdf_index_check: - raise ValueError(f'Inconsistent tdf_index ({tdf_index} vs. {tdf_index_check}).') - tdf_index = tdf_index[0] - dark_field['img_offset'] = img_offsets[tdf_index] - dark_field['num'] = num_imgs[tdf_index] - num_collections = 1 - - # Find bright field files - bright_field = tomo.config.get('bright_field') - tbf_files = [c['filepaths'] for c in collections if c['name'] == 'tbf'] - if len(tbf_files) != 1 or len(tbf_files[0]) < 1: - exit('Unable to obtain bright field files') - if config_type == 'config_file': - assert(bright_field is not None) - assert(bright_field['data_path'] is not None) - if bright_field.get('img_start') is None: - bright_field['img_start'] = 0 - else: - tomo.config['bright_field'] = {'data_path' : tbf_files[0], 'img_start' : 0} - bright_field = tomo.config['bright_field'] - tbf_index = [i for i,c in enumerate(collections) if c['name'] == 'tbf'] - tbf_index_check = [i for i,s in enumerate(stack_types) if s == 'tbf'] - if tbf_index != tbf_index_check: - raise ValueError(f'Inconsistent tbf_index ({tbf_index} vs. 
{tbf_index_check}).') - tbf_index = tbf_index[0] - bright_field['img_offset'] = img_offsets[tbf_index] - bright_field['num'] = num_imgs[tbf_index] - num_collections += 1 - - # Find tomography files - stack_info = tomo.config.get('stack_info') - if config_type == 'config_file': - assert(stack_info is not None) - if stack_info['num'] != len(collections) - num_collections: - raise ValueError('Inconsistent number of tomography data image sets') - assert(stack_info.get('stacks') is not None) - for stack in stack_info['stacks']: - assert(stack['data_path'] is not None) - if stack.get('img_start') is None: - stack['img_start'] = 0 - assert(stack.get('index') is not None) - assert(stack.get('ref_height') is not None) - else: - tomo.config['stack_info'] = {'num' : len(collections) - num_collections, 'stacks' : []} - stack_info = tomo.config['stack_info'] - for i in range(stack_info['num']): - stack_info['stacks'].append({'img_start' : 0, 'index' : i+1}) - tomo_stack_files = [] - for stack in stack_info['stacks']: - index = stack['index'] - tomo_files = [c['filepaths'] for c in collections if c['name'] == f'set{index}'] - if len(tomo_files) != 1 or len(tomo_files[0]) < 1: - exit(f'Unable to obtain tomography images for set {index}') - tomo_index = [i for i,c in enumerate(collections) if c['name'] == f'set{index}'] - if len(tomo_index) != 1: - raise ValueError(f'Illegal tomo_index ({tomo_index}).') - tomo_index = tomo_index[0] - stack['img_offset'] = img_offsets[tomo_index] - assert(num_imgs[tomo_index] == -1) - stack['num'] = num_theta - if config_type == 'config_manual': - if len(tomo_files) == 1: - stack['data_path'] = tomo_files[0] - stack['ref_height'] = ref_heights[tomo_index] - tomo_stack_files.append(tomo_files[0]) - num_collections += 1 - if num_collections != num_stack: - raise ValueError('Inconsistent number of data image sets') - - # Preprocess the image files - galaxy_param = {'tdf_files' : tdf_files[0], 'tbf_files' : tbf_files[0], - 'tomo_stack_files' : tomo_stack_files, 'output_name' : args.output_data} - tomo.genTomoStacks(galaxy_param) - if not tomo.is_valid: - IOError('Unable to load all required image files.') - - # Displaying memory usage - logging.info(f'Memory usage: {tracemalloc.get_traced_memory()}') - - # stopping memory monitoring - tracemalloc.stop() - -if __name__ == "__main__": - __main__() -
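Review note: the grouping loop removed with tomo_setup.py ('Read and sort data files') built the collections list by linear search over a growing list. For reference, the same grouping in idiomatic form; this sketch assumes the same datasets records parsed from inputfiles.txt:

    from collections import defaultdict

    datasets = [  # stand-in records in the parsed format
        {'element_identifier': 'tdf_001', 'filepath': '/data/tdf_001.h5'},
        {'element_identifier': 'plain', 'filepath': '/data/plain.h5'},
    ]

    grouped = defaultdict(list)
    for dataset in datasets:
        parts = [x.strip() for x in dataset['element_identifier'].split('_')]
        name = parts[0] if len(parts) > 1 else 'other'
        grouped[name].append(dataset['filepath'])
    collections = [{'name': name, 'filepaths': paths}
                   for name, paths in grouped.items()]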
--- a/tomo_setup.xml Fri Aug 19 20:16:56 2022 +0000 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,228 +0,0 @@ -<tool id="tomo_setup" name="Tomo Setup" version="0.2.2" python_template_version="3.9"> - <description>Preprocess tomography images</description> - <macros> - <import>tomo_macros.xml</import> - </macros> - <expand macro="requirements" /> - <command detect_errors="exit_code"> - <![CDATA[ - mkdir setup_pngs; - cp '$inputconfig' inputconfig.txt && - cp '$inputfiles' inputfiles.txt && - $__tool_directory__/tomo_setup.py - --inputconfig inputconfig.txt - --inputfiles inputfiles.txt - #if str($config_type.config_selector) == "config_file" - -c '$config' - --num_theta '$config_type.num_theta' - #else - --detector '$config_type.detector.num_row $config_type.detector.num_column $config_type.detector.pixel_size' - --num_theta '$config_type.thetas.num_theta' - --theta_range '$config_type.thetas.theta_start $config_type.thetas.theta_end' - #end if - --output_config 'output_config.yaml' - --output_data 'output_data.npz' - -l '$log' - ]]> - </command> - <configfiles> - <configfile name="inputconfig"> - <![CDATA[#slurp - #set $count = 0 - #for $s in $config_type.input.tomo_sets - #set $count += 1 - #end for - #echo str($config_type.config_selector) # - #echo str($config_type.input.type_selector) # - #echo str($count) # - #for $s in $config_type.input.tomo_sets - #echo ' ' + str($s.set_type.set_selector) - #end for - #echo '\n' - #for $s in $config_type.input.tomo_sets - #if str($s.set_type.set_selector) == "data" - #echo ' ' + '-1' - #else - #echo ' ' + str($s.set_type.num) - #end if - #end for - #echo '\n' - #for $s in $config_type.input.tomo_sets - #echo ' ' + str($s.set_type.offset) - #end for - #echo '\n' - #if str($config_type.config_selector) == "config_manual" - #for $s in $config_type.input.tomo_sets - #if str($s.set_type.set_selector) == "data" - #echo ' ' + str($s.set_type.ref_height) - #else - #echo ' ' + '0.0' - #end if - #end for - #echo '\n' - #end if - ]]> - </configfile> - <configfile name="inputfiles"> - <![CDATA[#slurp - #if str($config_type.input.type_selector) == "collections" - #for $s in $config_type.input.tomo_sets - #for $input in $s.inputs - #echo str($input) + '\t' + $input.element_identifier # - #end for - #end for - #else - #for $s in $config_type.input.tomo_sets - #echo str($s.inputs) # - #end for - #end if - ]]> - </configfile> - </configfiles> - <inputs> - <conditional name="config_type"> - <param name="config_selector" type="select" label="Read config from file or enter manually"> - <option value="config_file" selected="true">Read config from file</option> - <option value="config_manual">Manually enter config parameters</option> - </param> - <when value="config_file"> - <expand macro="common_inputs"/> - <param name="num_theta" type="integer" min="1" value="0" optional="false" label="Number of angles"/> - <conditional name="input"> - <param name="type_selector" type="select" label="Choose the dataset format"> - <option value="collections">datasets as collections</option> - <option value="files">datasets as files</option> - </param> - <when value="collections"> - <repeat name='tomo_sets' title="Tomography image collections"> - <param name="inputs" type="data_collection" label="Image file collection"/> - <conditional name="set_type"> - <param name="set_selector" type="select" label="Choose the dataset type"> - <option value="tdf">dark field</option> - <option value="tbf">bright field</option> - <option value="data">tomography field</option> - </param> - <when 
value="tdf"> - <param name="offset" type="integer" min="0" value="0" label="Image index offset"/> - <param name="num" type="integer" min="1" value="1" label="Number of images"/> - </when> - <when value="tbf"> - <param name="offset" type="integer" min="0" value="0" label="Image index offset"/> - <param name="num" type="integer" min="1" value="1" label="Number of images"/> - </when> - <when value="data"> - <param name="offset" type="integer" min="0" value="0" label="Image index offset"/> - </when> - </conditional> - </repeat> - </when> - <when value="files"> - <repeat name='tomo_sets' title="Tomography image datasets"> - <param name="inputs" type="data" format='h5' optional='false' label="Image file"/> - <conditional name="set_type"> - <param name="set_selector" type="select" label="Choose the dataset type"> - <option value="tdf">dark field</option> - <option value="tbf">bright field</option> - <option value="data">tomography field</option> - </param> - <when value="tdf"> - <param name="offset" type="integer" min="0" value="0" label="Image index offset"/> - <param name="num" type="integer" min="1" value="1" label="Number of images"/> - </when> - <when value="tbf"> - <param name="offset" type="integer" min="0" value="0" label="Image index offset"/> - <param name="num" type="integer" min="1" value="1" label="Number of images"/> - </when> - <when value="data"> - <param name="offset" type="integer" min="0" value="0" label="Image index offset"/> - </when> - </conditional> - </repeat> - </when> - </conditional> - </when> - <when value="config_manual"> - <section name="thetas" title="Tomography angles"> - <param name="num_theta" type="integer" min="1" value="0" optional="false" label="Number of angles"/> - <param name="theta_start" type="float" min="0.0" max="360.0" value="0.0" optional="false" label="Start angle"/> - <param name="theta_end" type="float" min="0.0" max="360.0" value="180.0" optional="false" label="End angle"/> - </section> - <section name="detector" title="Detector parameters"> - <param name="num_row" type="integer" min="1" value="0" optional="false" label="Number of pixel rows"/> - <param name="num_column" type="integer" min="1" value="0" optional="false" label="Number of pixel columns"/> - <param name="pixel_size" type="float" min="0.0" value="0.0" optional="false" label="Pixel size (corrected for lens magnification)(RV: needs unit)"/> - </section> - <conditional name="input"> - <param name="type_selector" type="select" label="Choose the dataset format"> - <option value="collections">datasets as collections</option> - <option value="files">datasets as files</option> - </param> - <when value="collections"> - <repeat name='tomo_sets' title="Tomography image collections"> - <param name="inputs" type="data_collection" label="Image file collection"/> - <conditional name="set_type"> - <param name="set_selector" type="select" label="Choose the dataset type"> - <option value="tdf">dark field</option> - <option value="tbf">bright field</option> - <option value="data">tomography field</option> - </param> - <when value="tdf"> - <param name="offset" type="integer" min="0" value="0" label="Image index offset"/> - <param name="num" type="integer" min="1" value="1" label="Number of images"/> - </when> - <when value="tbf"> - <param name="offset" type="integer" min="0" value="0" label="Image index offset"/> - <param name="num" type="integer" min="1" value="1" label="Number of images"/> - </when> - <when value="data"> - <param name="offset" type="integer" min="0" value="0" label="Image index 
offset"/> - <param name="ref_height" type="float" value="0.0" label="Reference height"/> - </when> - </conditional> - </repeat> - </when> - <when value="files"> - <repeat name='tomo_sets' title="Tomography image datasets"> - <param name="inputs" type="data" format='h5' optional='false' label="Image file"/> - <conditional name="set_type"> - <param name="set_selector" type="select" label="Choose the dataset type"> - <option value="tdf">dark field</option> - <option value="tbf">bright field</option> - <option value="data">tomography field</option> - </param> - <when value="tdf"> - <param name="offset" type="integer" min="0" value="0" label="Image index offset"/> - <param name="num" type="integer" min="1" value="1" label="Number of images"/> - </when> - <when value="tbf"> - <param name="offset" type="integer" min="0" value="0" label="Image index offset"/> - <param name="num" type="integer" min="1" value="1" label="Number of images"/> - </when> - <when value="data"> - <param name="offset" type="integer" min="0" value="0" label="Image index offset"/> - <param name="ref_height" type="float" value="0.0" label="Reference height"/> - </when> - </conditional> - </repeat> - </when> - </conditional> - </when> - </conditional> - </inputs> - <outputs> - <expand macro="common_outputs"/> - <data name="inputconfig" format="txt" label="Input config" from_work_dir="inputconfig.txt" hidden="false"/> - <data name="inputfiles" format="txt" label="Input files" from_work_dir="inputfiles.txt" hidden="false"/> - <collection name="setup_pngs" type="list" label="Tomo setup images"> - <discover_datasets pattern="__name_and_ext__" directory="setup_pngs"/> - </collection> - <data name="output_config" format="tomo.config.yaml" label="Output config setup" from_work_dir="output_config.yaml"/> - <data name="output_data" format="npz" label="Preprocessed tomography data" from_work_dir="output_data.npz"/> - </outputs> - <help> - <![CDATA[ - Preprocess tomography images. - ]]> - </help> - <expand macro="citations"/> -</tool>
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/workflow/__main__.py Fri Mar 10 16:02:04 2023 +0000 @@ -0,0 +1,236 @@ +#!/usr/bin/env python3 + +import logging +logging.getLogger(__name__) + +import argparse +import pathlib +import sys + +from .models import TomoWorkflow as Workflow +try: + from deepdiff import DeepDiff +except: + pass + +parser = argparse.ArgumentParser(description='''Operate on representations of + Tomo data workflows saved to files.''') +parser.add_argument('-l', '--log', +# type=argparse.FileType('w'), + default=sys.stdout, + help='Logging stream or filename') +parser.add_argument('--log_level', + choices=logging._nameToLevel.keys(), + default='INFO', + help='''Specify a preferred logging level.''') +subparsers = parser.add_subparsers(title='subcommands', required=True)#, dest='command') + + +# CONSTRUCT +def construct(args:list) -> None: + if args.template_file is not None: + wf = Workflow.construct_from_file(args.template_file) + wf.cli() + else: + wf = Workflow.construct_from_cli() + wf.write_to_file(args.output_file, force_overwrite=args.force_overwrite) + +construct_parser = subparsers.add_parser('construct', help='''Construct a valid Tomo + workflow representation on the command line and save it to a file. Optionally use + an existing file as a template and/or preform the reconstruction or transfer to Galaxy.''') +construct_parser.set_defaults(func=construct) +construct_parser.add_argument('-t', '--template_file', + type=pathlib.Path, + required=False, + help='''Full or relative template file path for the constructed workflow.''') +construct_parser.add_argument('-f', '--force_overwrite', + action='store_true', + help='''Use this flag to overwrite the output file if it already exists.''') +construct_parser.add_argument('-o', '--output_file', + type=pathlib.Path, + help='''Full or relative file path to which the constructed workflow will be written.''') + + +# VALIDATE +def validate(args:list) -> bool: + try: + wf = Workflow.construct_from_file(args.input_file) + logger.info(f'Success: {args.input_file} represents a valid Tomo workflow configuration.') + return(True) + except BaseException as e: + logger.error(f'{e.__class__.__name__}: {str(e)}') + logger.info(f'''Failure: {args.input_file} does not represent a valid Tomo workflow + configuration.''') + return(False) + +validate_parser = subparsers.add_parser('validate', + help='''Validate a file as a representation of a Tomo workflow (this is most useful + after a .yaml file has been manually edited).''') +validate_parser.set_defaults(func=validate) +validate_parser.add_argument('input_file', + type=pathlib.Path, + help='''Full or relative file path to validate as a Tomo workflow.''') + + +# CONVERT +def convert(args:list) -> None: + wf = Workflow.construct_from_file(args.input_file) + wf.write_to_file(args.output_file, force_overwrite=args.force_overwrite) + +convert_parser = subparsers.add_parser('convert', help='''Convert one Tomo workflow + representation to another. 
+convert_parser.set_defaults(func=convert)
+convert_parser.add_argument('-f', '--force_overwrite',
+        action='store_true',
+        help='''Use this flag to overwrite the output file if it already exists.''')
+convert_parser.add_argument('-i', '--input_file',
+        type=pathlib.Path,
+        required=True,
+        help='''Full or relative input file path to be converted.''')
+convert_parser.add_argument('-o', '--output_file',
+        type=pathlib.Path,
+        required=True,
+        help='''Full or relative file path to which the converted input will be written.''')
+
+
+# DIFF / COMPARE
+def diff(args:list) -> bool:
+    raise ValueError('diff not tested')
+#    wf1 = Workflow.construct_from_file(args.file1).dict_for_yaml()
+#    wf2 = Workflow.construct_from_file(args.file2).dict_for_yaml()
+#    diff = DeepDiff(wf1,wf2,
+#            ignore_order_func=lambda level:'independent_dimensions' not in level.path(),
+#            report_repetition=True,
+#            ignore_string_type_changes=True,
+#            ignore_numeric_type_changes=True)
+#    diff_report = diff.pretty()
+#    if len(diff_report) > 0:
+#        logger.info(f'The configurations in {args.file1} and {args.file2} are not identical.')
+#        print(diff_report)
+#        return(True)
+#    else:
+#        logger.info(f'The configurations in {args.file1} and {args.file2} are identical.')
+#        return(False)
+
+diff_parser = subparsers.add_parser('diff', aliases=['compare'], help='''Print a comparison of
+    two Tomo workflow representations stored in files. The files may have different formats.''')
+diff_parser.set_defaults(func=diff)
+diff_parser.add_argument('file1',
+        type=pathlib.Path,
+        help='''Full or relative path to the first file for comparison.''')
+diff_parser.add_argument('file2',
+        type=pathlib.Path,
+        help='''Full or relative path to the second file for comparison.''')
+
+
+# LINK TO GALAXY
+def link_to_galaxy(args:list) -> None:
+    from .link_to_galaxy import link_to_galaxy
+    link_to_galaxy(args.input_file, galaxy=args.galaxy, user=args.user,
+            password=args.password, api_key=args.api_key)
+
+link_parser = subparsers.add_parser('link_to_galaxy', help='''Construct a Galaxy history and link
+    to an existing Tomo workflow representation in a NeXus file.''')
+link_parser.set_defaults(func=link_to_galaxy)
+link_parser.add_argument('-i', '--input_file',
+        type=pathlib.Path,
+        required=True,
+        help='''Full or relative input file path to the existing Tomo workflow representation as
+        a NeXus file.''')
+link_parser.add_argument('-g', '--galaxy',
+        required=True,
+        help='Target Galaxy instance URL/IP address')
+link_parser.add_argument('-u', '--user',
+        default=None,
+        help='Galaxy user email address')
+link_parser.add_argument('-p', '--password',
+        default=None,
+        help='Password for the Galaxy user')
+link_parser.add_argument('-a', '--api_key',
+        default=None,
+        help='Galaxy admin user API key (required if not defined in the tools list file)')
+
+
+# RUN THE RECONSTRUCTION
+def run_tomo(args:list) -> None:
+    from .run_tomo import run_tomo
+    run_tomo(args.input_file, args.output_file, args.modes, center_file=args.center_file,
+            num_core=args.num_core, output_folder=args.output_folder, save_figs=args.save_figs)
+
+tomo_parser = subparsers.add_parser('run_tomo', help='''Construct and add reconstructed tomography
+    data to an existing Tomo workflow representation in a NeXus file.''')
+tomo_parser.set_defaults(func=run_tomo)
+tomo_parser.add_argument('-i', '--input_file',
+        required=True,
+        type=pathlib.Path,
+        help='''Full or relative input file path containing raw and/or reduced data.''')
+tomo_parser.add_argument('-o', '--output_file',
+        required=True,
+        type=pathlib.Path,
+        help='''Full or relative output file path to which the reconstructed data will be
+        written.''')
+tomo_parser.add_argument('-c', '--center_file',
+        type=pathlib.Path,
+        help='''Full or relative input file path containing the rotation axis centers info.''')
+#tomo_parser.add_argument('-f', '--force_overwrite',
+#        action='store_true',
+#        help='''Use this flag to overwrite any existing reduced data.''')
+tomo_parser.add_argument('-n', '--num_core',
+        type=int,
+        default=-1,
+        help='''Specify the number of processors to use.''')
+tomo_parser.add_argument('--output_folder',
+        type=pathlib.Path,
+        default='.',
+        help='Full or relative path to an output folder')
+tomo_parser.add_argument('-s', '--save_figs',
+        choices=['yes', 'no', 'only'],
+        default='no',
+        help='''Specify whether to display and save figures ('yes'), display them without
+        saving ('no'), or save them without displaying ('only').''')
+tomo_parser.add_argument('--reduce_data',
+        dest='modes',
+        const='reduce_data',
+        action='append_const',
+        help='''Use this flag to create and add reduced data to the input file.''')
+tomo_parser.add_argument('--find_center',
+        dest='modes',
+        const='find_center',
+        action='append_const',
+        help='''Use this flag to find and add the calibrated center axis info to the input file.''')
+tomo_parser.add_argument('--reconstruct_data',
+        dest='modes',
+        const='reconstruct_data',
+        action='append_const',
+        help='''Use this flag to create and add reconstructed data to the input file.''')
+tomo_parser.add_argument('--combine_datas',
+        dest='modes',
+        const='combine_datas',
+        action='append_const',
+        help='''Use this flag to combine reconstructed data and add it to the input file.''')
+
+
+if __name__ == '__main__':
+    args = parser.parse_args(sys.argv[1:])
+
+    # Set log configuration
+    # When logging to file, the stdout log level defaults to WARNING
+    logging_format = '%(asctime)s : %(levelname)s - %(module)s : %(funcName)s - %(message)s'
+    level = logging.getLevelName(args.log_level)
+    if args.log is sys.stdout:
+        logging.basicConfig(format=logging_format, level=level, force=True,
+                handlers=[logging.StreamHandler()])
+    else:
+        if isinstance(args.log, str):
+            logging.basicConfig(filename=f'{args.log}', filemode='w',
+                    format=logging_format, level=level, force=True)
+        elif isinstance(args.log, io.TextIOWrapper):
+            logging.basicConfig(filemode='w', format=logging_format, level=level,
+                    stream=args.log, force=True)
+        else:
+            raise ValueError(f'Invalid argument --log: {args.log}')
+        stream_handler = logging.StreamHandler()
+        logging.getLogger().addHandler(stream_handler)
+        stream_handler.setLevel(logging.WARNING)
+        stream_handler.setFormatter(logging.Formatter(logging_format))
+
+    args.func(args)
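The run_tomo subcommand collects its processing steps through argparse's append_const action: each mode flag appends its constant to the shared modes destination in the order the flags are given, and the resulting list is handed to run_tomo(). A standalone sketch of the pattern (names mirror the flags above):

import argparse

p = argparse.ArgumentParser()
p.add_argument('--reduce_data', dest='modes', const='reduce_data', action='append_const')
p.add_argument('--find_center', dest='modes', const='find_center', action='append_const')

# Flags accumulate in command-line order into a single list
args = p.parse_args(['--reduce_data', '--find_center'])
assert args.modes == ['reduce_data', 'find_center']

# Caveat: with no mode flag at all, args.modes is None rather than []
assert p.parse_args([]).modes is None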
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/workflow/__version__.py	Fri Mar 10 16:02:04 2023 +0000
@@ -0,0 +1,1 @@
+__version__ = '2022.3.0'
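Keeping the version string in its own one-line module lets the rest of the package report it without circular imports; run_tomo.py below pulls it in with this exact import:

# How other modules consume the version string
from workflow.__version__ import __version__
print(f'tomo workflow version {__version__}')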
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/workflow/link_to_galaxy.py Fri Mar 10 16:02:04 2023 +0000 @@ -0,0 +1,116 @@ +#!/usr/bin/env python3 + +import logging +logger = logging.getLogger(__name__) + +from bioblend.galaxy import GalaxyInstance +from nexusformat.nexus import * +from os import path +from yaml import safe_load + +from .models import import_scanparser, TomoWorkflow + +def get_folder_id(gi, path): + library_id = None + folder_id = None + folder_names = path[1:] if len(path) > 1 else [] + new_folders = folder_names + libs = gi.libraries.get_libraries(name=path[0]) + if libs: + for lib in libs: + library_id = lib['id'] + folders = gi.libraries.get_folders(library_id, folder_id=None, name=None) + for i, folder in enumerate(folders): + fid = folder['id'] + details = gi.libraries.show_folder(library_id, fid) + library_path = details['library_path'] + if library_path == folder_names: + return (library_id, fid, []) + elif len(library_path) < len(folder_names): + if library_path == folder_names[:len(library_path)]: + nf = folder_names[len(library_path):] + if len(nf) < len(new_folders): + folder_id = fid + new_folders = nf + return (library_id, folder_id, new_folders) + +def link_to_galaxy(filename:str, galaxy=None, user=None, password=None, api_key=None) -> None: + # Read input file + extension = path.splitext(filename)[1] + if extension == '.yml' or extension == '.yaml': + with open(filename, 'r') as f: + data = safe_load(f) + elif extension == '.nxs': + with NXFile(filename, mode='r') as nxfile: + data = nxfile.readfile() + else: + raise ValueError(f'Invalid filename extension ({extension})') + if isinstance(data, dict): + # Create Nexus format object from input dictionary + wf = TomoWorkflow(**data) + if len(wf.sample_maps) > 1: + raise ValueError(f'Multiple sample maps not yet implemented') + nxroot = NXroot() + for sample_map in wf.sample_maps: + import_scanparser(sample_map.station) + sample_map.construct_nxentry(nxroot, include_raw_data=False) + nxentry = nxroot[nxroot.attrs['default']] + elif isinstance(data, NXroot): + nxentry = data[data.attrs['default']] + else: + raise ValueError(f'Invalid input file data ({data})') + + # Get a Galaxy instance + if user is not None and password is not None : + gi = GalaxyInstance(url=galaxy, email=user, password=password) + elif api_key is not None: + gi = GalaxyInstance(url=galaxy, key=api_key) + else: + exit('Please specify either a valid Galaxy username/password or an API key.') + + cycle = nxentry.instrument.source.attrs['cycle'] + btr = nxentry.instrument.source.attrs['btr'] + sample = nxentry.sample.name + + # Create a Galaxy work library/folder + # Combine the cycle, BTR and sample name as the base library name + lib_path = [p.strip() for p in f'{cycle}/{btr}/{sample}'.split('/')] + (library_id, folder_id, folder_names) = get_folder_id(gi, lib_path) + if not library_id: + library = gi.libraries.create_library(lib_path[0], description=None, synopsis=None) + library_id = library['id'] +# if user: +# gi.libraries.set_library_permissions(library_id, access_ids=user, +# manage_ids=user, modify_ids=user) + logger.info(f'Created Library:\n{library}') + if len(folder_names): + folder = gi.libraries.create_folder(library_id, folder_names[0], description=None, + base_folder_id=folder_id)[0] + folder_id = folder['id'] + logger.info(f'Created Folder:\n{folder}') + folder_names.pop(0) + while len(folder_names): + folder = gi.folders.create_folder(folder['id'], folder_names[0], + description=None) + folder_id = folder['id'] + 
logger.info(f'Created Folder:\n{folder}') + folder_names.pop(0) + + # Create a sym link for the Nexus file + dataset = gi.libraries.upload_from_galaxy_filesystem(library_id, path.abspath(filename), + folder_id=folder_id, file_type='auto', dbkey='?', link_data_only='link_to_files', + roles='', preserve_dirs=False, tag_using_filenames=False, tags=None)[0] + + # Make a history for the data + history_name = f'tomo {btr} {sample}' + history = gi.histories.create_history(name=history_name) + logger.info(f'Created history:\n{history}') + history_id = history['id'] + gi.histories.copy_dataset(history_id, dataset['id'], source='library') + +# TODO add option to either +# get a URL to share the history +# or to share with specific users +# This might require using: +# https://bioblend.readthedocs.io/en/latest/api_docs/galaxy/docs.html#using-bioblend-for-raw-api-calls +
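All Galaxy interaction in this module goes through bioblend's GalaxyInstance; the library, folder, and upload calls above are thin wrappers around its REST client. A minimal standalone sketch of the connection and history-creation calls it builds on (the URL, key, and names are placeholders, not values from this repository):

from bioblend.galaxy import GalaxyInstance

# Authenticate with an API key (the module also supports email/password)
gi = GalaxyInstance(url='https://galaxy.example.org', key='YOUR_API_KEY')

# bioblend returns plain dicts describing the created objects
history = gi.histories.create_history(name='tomo example-btr example-sample')
print(history['id'])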
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/workflow/models.py	Fri Mar 10 16:02:04 2023 +0000
@@ -0,0 +1,1074 @@
+#!/usr/bin/env python3
+
+import logging
+logger = logging.getLogger(__name__)
+
+import numpy as np
+import os
+import yaml
+
+from functools import cache
+from pathlib import PosixPath
+from pydantic import validator, ValidationError, conint, confloat, constr, \
+        conlist, FilePath, PrivateAttr
+from pydantic import BaseModel as PydanticBaseModel
+from nexusformat.nexus import *
+from time import time
+from typing import Optional, Literal
+from typing_extensions import TypedDict
+try:
+    from pyspec.file.spec import FileSpec
+except:
+    pass
+
+from msnctools.general import is_int, is_num, input_int, input_int_list, input_num, input_yesno, \
+        input_menu, index_nearest, string_to_list, file_exists_and_readable
+
+
+def import_scanparser(station):
+    if station in ('id1a3', 'id3a'):
+        from msnctools.scanparsers import SMBRotationScanParser
+        globals()['ScanParser'] = SMBRotationScanParser
+    elif station == 'id3b':
+        from msnctools.scanparsers import FMBRotationScanParser
+        globals()['ScanParser'] = FMBRotationScanParser
+    else:
+        raise RuntimeError(f'Invalid station: {station}')
+
+@cache
+def get_available_scan_numbers(spec_file:str):
+    scans = FileSpec(spec_file).scans
+    scan_numbers = list(scans.keys())
+    for scan_number in scan_numbers.copy():
+        try:
+            parser = ScanParser(spec_file, scan_number)
+            try:
+                scan_type = parser.get_scan_type()
+            except:
+                scan_type = None
+        except:
+            scan_numbers.remove(scan_number)
+    return(scan_numbers)
+
+@cache
+def get_scanparser(spec_file:str, scan_number:int):
+    if scan_number not in get_available_scan_numbers(spec_file):
+        return(None)
+    else:
+        return(ScanParser(spec_file, scan_number))
+
+
+class BaseModel(PydanticBaseModel):
+    class Config:
+        validate_assignment = True
+        arbitrary_types_allowed = True
+
+    @classmethod
+    def construct_from_cli(cls):
+        obj = cls.construct()
+        obj.cli()
+        return(obj)
+
+    @classmethod
+    def construct_from_yaml(cls, filename):
+        try:
+            with open(filename, 'r') as infile:
+                indict = yaml.load(infile, Loader=yaml.CLoader)
+        except:
+            raise ValueError(f'Could not load a dictionary from {filename}')
+        else:
+            obj = cls(**indict)
+            return(obj)
+
+    @classmethod
+    def construct_from_file(cls, filename):
+        file_exists_and_readable(filename)
+        filename = os.path.abspath(filename)
+        fileformat = os.path.splitext(filename)[1]
+        yaml_extensions = ('.yaml','.yml')
+        nexus_extensions = ('.nxs','.nx5','.h5','.hdf5')
+        t0 = time()
+        if fileformat.lower() in yaml_extensions:
+            obj = cls.construct_from_yaml(filename)
+            logger.info(f'Constructed a model from {filename} in {time()-t0:.2f} seconds.')
+            return(obj)
+        elif fileformat.lower() in nexus_extensions:
+            obj = cls.construct_from_nexus(filename)
+            logger.info(f'Constructed a model from {filename} in {time()-t0:.2f} seconds.')
+            return(obj)
+        else:
+            logger.error(f'Unsupported file extension for constructing a model: {fileformat}')
+            raise TypeError(f'Unrecognized file extension: {fileformat}')
+
+    def dict_for_yaml(self, exclude_fields=[]):
+        yaml_dict = {}
+        for field_name in self.__fields__:
+            if field_name in exclude_fields:
+                continue
+            else:
+                field_value = getattr(self, field_name, None)
+                if field_value is not None:
+                    if isinstance(field_value, BaseModel):
+                        yaml_dict[field_name] = field_value.dict_for_yaml()
+                    elif isinstance(field_value,list) and all(isinstance(item,BaseModel)
+                            for item in field_value):
+                        yaml_dict[field_name] = 
[item.dict_for_yaml() for item in field_value] + elif isinstance(field_value, PosixPath): + yaml_dict[field_name] = str(field_value) + else: + yaml_dict[field_name] = field_value + else: + continue + return(yaml_dict) + + def write_to_yaml(self, filename=None): + yaml_dict = self.dict_for_yaml() + if filename is None: + logger.info('Printing yaml representation here:\n'+ + f'{yaml.dump(yaml_dict, sort_keys=False)}') + else: + try: + with open(filename, 'w') as outfile: + yaml.dump(yaml_dict, outfile, sort_keys=False) + logger.info(f'Successfully wrote this model to {filename}') + except: + logger.error(f'Unknown error -- could not write to {filename} in yaml format.') + logger.info('Printing yaml representation here:\n'+ + f'{yaml.dump(yaml_dict, sort_keys=False)}') + + def write_to_file(self, filename, force_overwrite=False): + file_writeable, fileformat = self.output_file_valid(filename, + force_overwrite=force_overwrite) + if fileformat == 'yaml': + if file_writeable: + self.write_to_yaml(filename=filename) + else: + self.write_to_yaml() + elif fileformat == 'nexus': + if file_writeable: + self.write_to_nexus(filename=filename) + + def output_file_valid(self, filename, force_overwrite=False): + filename = os.path.abspath(filename) + fileformat = os.path.splitext(filename)[1] + yaml_extensions = ('.yaml','.yml') + nexus_extensions = ('.nxs','.nx5','.h5','.hdf5') + if fileformat.lower() not in (*yaml_extensions, *nexus_extensions): + return(False, None) # Only yaml and NeXus files allowed for output now. + elif fileformat.lower() in yaml_extensions: + fileformat = 'yaml' + elif fileformat.lower() in nexus_extensions: + fileformat = 'nexus' + if os.path.isfile(filename): + if os.access(filename, os.W_OK): + if not force_overwrite: + logger.error(f'{filename} will not be overwritten.') + return(False, fileformat) + else: + logger.error(f'Cannot access {filename} for writing.') + return(False, fileformat) + if os.path.isdir(os.path.dirname(filename)): + if os.access(os.path.dirname(filename), os.W_OK): + return(True, fileformat) + else: + logger.error(f'Cannot access {os.path.dirname(filename)} for writing.') + return(False, fileformat) + else: + try: + os.makedirs(os.path.dirname(filename)) + return(True, fileformat) + except: + logger.error(f'Cannot create {os.path.dirname(filename)} for output.') + return(False, fileformat) + + def set_single_attr_cli(self, attr_name, attr_desc='unknown attribute', list_flag=False, + **cli_kwargs): + if cli_kwargs.get('chain_attr_desc', False): + cli_kwargs['attr_desc'] = attr_desc + try: + attr = getattr(self, attr_name, None) + if attr is None: + attr = self.__fields__[attr_name].type_.construct() + if cli_kwargs.get('chain_attr_desc', False): + cli_kwargs['attr_desc'] = attr_desc + input_accepted = False + while not input_accepted: + try: + attr.cli(**cli_kwargs) + except ValidationError as e: + print(e) + print(f'Removing {attr_desc} configuration') + attr = self.__fields__[attr_name].type_.construct() + continue + except KeyboardInterrupt as e: + raise e + except BaseException as e: + print(f'{type(e).__name__}: {e}') + print(f'Removing {attr_desc} configuration') + attr = self.__fields__[attr_name].type_.construct() + continue + try: + setattr(self, attr_name, attr) + except ValidationError as e: + print(e) + except KeyboardInterrupt as e: + raise e + except BaseException as e: + print(f'{type(e).__name__}: {e}') + else: + input_accepted = True + except: + input_accepted = False + while not input_accepted: + attr = getattr(self, attr_name, None) + 
if attr is None: + input_value = input(f'Type and enter a value for {attr_desc}: ') + else: + input_value = input(f'Type and enter a new value for {attr_desc} or press '+ + f'enter to keep the current one ({attr}): ') + if list_flag: + input_value = string_to_list(input_value, remove_duplicates=False, sort=False) + if len(input_value) == 0: + input_value = getattr(self, attr_name, None) + try: + setattr(self, attr_name, input_value) + except ValidationError as e: + print(e) + except KeyboardInterrupt as e: + raise e + except BaseException as e: + print(f'Unexpected {type(e).__name__}: {e}') + else: + input_accepted = True + + def set_list_attr_cli(self, attr_name, attr_desc='unknown attribute', **cli_kwargs): + if cli_kwargs.get('chain_attr_desc', False): + cli_kwargs['attr_desc'] = attr_desc + attr = getattr(self, attr_name, None) + if attr is not None: + # Check existing items + for item in attr: + item_accepted = False + while not item_accepted: + item.cli(**cli_kwargs) + try: + setattr(self, attr_name, attr) + except ValidationError as e: + print(e) + except KeyboardInterrupt as e: + raise e + except BaseException as e: + print(f'{type(e).__name__}: {e}') + else: + item_accepted = True + else: + # Initialize list for new attribute & starting item + attr = [] + item = self.__fields__[attr_name].type_.construct() + # Append (optional) additional items + append = input_yesno(f'Add a {attr_desc} configuration? (y/n)', 'n') + while append: + attr.append(item.__class__.construct_from_cli()) + try: + setattr(self, attr_name, attr) + except ValidationError as e: + print(e) + print(f'Removing last {attr_desc} configuration from the list') + attr.pop() + except KeyboardInterrupt as e: + raise e + except BaseException as e: + print(f'{type(e).__name__}: {e}') + print(f'Removing last {attr_desc} configuration from the list') + attr.pop() + else: + append = input_yesno(f'Add another {attr_desc} configuration? 
(y/n)', 'n') + + +class Detector(BaseModel): + prefix: constr(strip_whitespace=True, min_length=1) + rows: conint(gt=0) + columns: conint(gt=0) + pixel_size: conlist(item_type=confloat(gt=0), min_items=1, max_items=2) + lens_magnification: confloat(gt=0) = 1.0 + + @property + def get_pixel_size(self): + return(list(np.asarray(self.pixel_size)/self.lens_magnification)) + + def construct_from_yaml(self, filename): + try: + with open(filename, 'r') as infile: + indict = yaml.load(infile, Loader=yaml.CLoader) + detector = indict['detector'] + self.prefix = detector['id'] + pixels = detector['pixels'] + self.rows = pixels['rows'] + self.columns = pixels['columns'] + self.pixel_size = pixels['size'] + self.lens_magnification = indict['lens_magnification'] + except: + logging.warning(f'Could not load a dictionary from {filename}') + return(False) + else: + return(True) + + def cli(self): + print('\n -- Configure the detector -- ') + self.set_single_attr_cli('prefix', 'detector ID') + self.set_single_attr_cli('rows', 'number of pixel rows') + self.set_single_attr_cli('columns', 'number of pixel columns') + self.set_single_attr_cli('pixel_size', 'pixel size in mm (enter either a single value for '+ + 'square pixels or a pair of values for the size in the respective row and column '+ + 'directions)', list_flag=True) + self.set_single_attr_cli('lens_magnification', 'lens magnification') + + def construct_nxdetector(self): + nxdetector = NXdetector() + nxdetector.local_name = self.prefix + pixel_size = self.get_pixel_size + if len(pixel_size) == 1: + nxdetector.x_pixel_size = pixel_size[0] + nxdetector.y_pixel_size = pixel_size[0] + else: + nxdetector.x_pixel_size = pixel_size[0] + nxdetector.y_pixel_size = pixel_size[1] + nxdetector.x_pixel_size.attrs['units'] = 'mm' + nxdetector.y_pixel_size.attrs['units'] = 'mm' + return(nxdetector) + + +class ScanInfo(TypedDict): + scan_number: int + starting_image_offset: conint(ge=0) + num_image: conint(gt=0) + ref_x: float + ref_z: float + +class SpecScans(BaseModel): + spec_file: FilePath + scan_numbers: conlist(item_type=conint(gt=0), min_items=1) + stack_info: conlist(item_type=ScanInfo, min_items=1) = [] + + @validator('spec_file') + def validate_spec_file(cls, spec_file): + try: + spec_file = os.path.abspath(spec_file) + sspec_file = FileSpec(spec_file) + except: + raise ValueError(f'Invalid SPEC file {spec_file}') + else: + return(spec_file) + + @validator('scan_numbers') + def validate_scan_numbers(cls, scan_numbers, values): + spec_file = values.get('spec_file') + if spec_file is not None: + spec_scans = FileSpec(spec_file) + for scan_number in scan_numbers: + scan = spec_scans.get_scan_by_number(scan_number) + if scan is None: + raise ValueError(f'There is no scan number {scan_number} in {spec_file}') + return(scan_numbers) + + @validator('stack_info') + def validate_stack_info(cls, stack_info, values): + scan_numbers = values.get('scan_numbers') + assert(len(scan_numbers) == len(stack_info)) + for scan_info in stack_info: + assert(scan_info['scan_number'] in scan_numbers) + is_int(scan_info['starting_image_offset'], ge=0, lt=scan_info['num_image'], + raise_error=True) + return(stack_info) + + @classmethod + def construct_from_nxcollection(cls, nxcollection:NXcollection): + config = {} + config['spec_file'] = nxcollection.attrs['spec_file'] + scan_numbers = [] + stack_info = [] + for nxsubentry_name, nxsubentry in nxcollection.items(): + scan_number = int(nxsubentry_name.split('_')[-1]) + scan_numbers.append(scan_number) + 
stack_info.append({'scan_number': scan_number, + 'starting_image_offset': int(nxsubentry.instrument.detector.frame_start_number), + 'num_image': len(nxsubentry.sample.rotation_angle), + 'ref_x': float(nxsubentry.sample.x_translation), + 'ref_z': float(nxsubentry.sample.z_translation)}) + config['scan_numbers'] = sorted(scan_numbers) + config['stack_info'] = stack_info + return(cls(**config)) + + @property + def available_scan_numbers(self): + return(get_available_scan_numbers(self.spec_file)) + + def set_from_nxcollection(self, nxcollection:NXcollection): + self.spec_file = nxcollection.attrs['spec_file'] + scan_numbers = [] + stack_info = [] + for nxsubentry_name, nxsubentry in nxcollection.items(): + scan_number = int(nxsubentry_name.split('_')[-1]) + scan_numbers.append(scan_number) + stack_info.append({'scan_number': scan_number, + 'starting_image_offset': int(nxsubentry.instrument.detector.frame_start_number), + 'num_image': len(nxsubentry.sample.rotation_angle), + 'ref_x': float(nxsubentry.sample.x_translation), + 'ref_z': float(nxsubentry.sample.z_translation)}) + self.scan_numbers = sorted(scan_numbers) + self.stack_info = stack_info + + def get_scan_index(self, scan_number): + scan_index = [scan_index for scan_index, scan_info in enumerate(self.stack_info) + if scan_info['scan_number'] == scan_number] + if len(scan_index) > 1: + raise ValueError('Duplicate scan_numbers in image stack') + elif len(scan_index) == 1: + return(scan_index[0]) + else: + return(None) + + def get_scanparser(self, scan_number): + return(get_scanparser(self.spec_file, scan_number)) + +# def get_detector_data(self, detector_prefix, scan_number=None, scan_step_index=None): +# if scan_number is None: +# scan_number = self.scan_numbers[0] +# if scan_step_index is None: +# scan_info = self.stack_info[self.get_scan_index(scan_number)] +# scan_step_index = scan_info['starting_image_offset'] +# parser = self.get_scanparser(scan_number) +# return(parser.get_detector_data(detector_prefix, scan_step_index)) + + def get_detector_data(self, detector_prefix, scan_number=None, scan_step_index=None): + image_stacks = [] + if scan_number is None: + scan_numbers = self.scan_numbers + else: + scan_numbers = [scan_number] + for scan_number in scan_numbers: + parser = self.get_scanparser(scan_number) + scan_info = self.stack_info[self.get_scan_index(scan_number)] + image_offset = scan_info['starting_image_offset'] + if scan_step_index is None: + num_image = scan_info['num_image'] + image_stacks.append(parser.get_detector_data(detector_prefix, + (image_offset, image_offset+num_image))) + else: + image_stacks.append(parser.get_detector_data(detector_prefix, + image_offset+scan_step_index)) + if len(image_stacks) == 1: + return(image_stacks[0]) + else: + return(image_stacks) + + def scan_numbers_cli(self, attr_desc, **kwargs): + available_scan_numbers = self.available_scan_numbers + station = kwargs.get('station') + if (station is not None and station in ('id1a3', 'id3a') and + 'scan_type' in kwargs): + scan_type = kwargs['scan_type'] + if scan_type == 'ts1': + available_scan_numbers = [] + for scan_number in self.available_scan_numbers: + parser = self.get_scanparser(scan_number) + if parser.scan_type == scan_type: + available_scan_numbers.append(scan_number) + elif scan_type == 'df1': + tomo_scan_numbers = kwargs['tomo_scan_numbers'] + available_scan_numbers = [] + for scan_number in tomo_scan_numbers: + parser = self.get_scanparser(scan_number-2) + assert(parser.scan_type == scan_type) + 
available_scan_numbers.append(scan_number-2) + elif scan_type == 'bf1': + tomo_scan_numbers = kwargs['tomo_scan_numbers'] + available_scan_numbers = [] + for scan_number in tomo_scan_numbers: + parser = self.get_scanparser(scan_number-1) + assert(parser.scan_type == scan_type) + available_scan_numbers.append(scan_number-1) + if len(available_scan_numbers) == 1: + input_mode = 1 + else: + if hasattr(self, 'scan_numbers'): + print(f'Currently selected {attr_desc}scan numbers are: {self.scan_numbers}') + menu_options = [f'Select a subset of the available {attr_desc}scan numbers', + f'Use all available {attr_desc}scan numbers in {self.spec_file}', + f'Keep the currently selected {attr_desc}scan numbers'] + else: + menu_options = [f'Select a subset of the available {attr_desc}scan numbers', + f'Use all available {attr_desc}scan numbers in {self.spec_file}'] + print(f'Available scan numbers in {self.spec_file} are: '+ + f'{available_scan_numbers}') + input_mode = input_menu(menu_options, header='Choose one of the following options '+ + 'for selecting scan numbers') + if input_mode == 0: + accept_scan_numbers = False + while not accept_scan_numbers: + try: + self.scan_numbers = \ + input_int_list(f'Enter a series of {attr_desc}scan numbers') + except ValidationError as e: + print(e) + except KeyboardInterrupt as e: + raise e + except BaseException as e: + print(f'Unexpected {type(e).__name__}: {e}') + else: + accept_scan_numbers = True + elif input_mode == 1: + self.scan_numbers = available_scan_numbers + elif input_mode == 2: + pass + + def cli(self, **cli_kwargs): + if cli_kwargs.get('attr_desc') is not None: + attr_desc = f'{cli_kwargs["attr_desc"]} ' + else: + attr_desc = '' + print(f'\n -- Configure which scans to use from a single {attr_desc}SPEC file') + self.set_single_attr_cli('spec_file', attr_desc+'SPEC file path') + self.scan_numbers_cli(attr_desc) + + def construct_nxcollection(self, image_key, thetas, detector): + nxcollection = NXcollection() + nxcollection.attrs['spec_file'] = str(self.spec_file) + parser = self.get_scanparser(self.scan_numbers[0]) + nxcollection.attrs['date'] = parser.spec_scan.file_date + for scan_number in self.scan_numbers: + # Get scan info + scan_info = self.stack_info[self.get_scan_index(scan_number)] + # Add an NXsubentry to the NXcollection for each scan + entry_name = f'scan_{scan_number}' + nxsubentry = NXsubentry() + nxcollection[entry_name] = nxsubentry + parser = self.get_scanparser(scan_number) + nxsubentry.start_time = parser.spec_scan.date + nxsubentry.spec_command = parser.spec_command + # Add an NXdata for independent dimensions to the scan's NXsubentry + num_image = scan_info['num_image'] + if thetas is None: + thetas = num_image*[0.0] + else: + assert(num_image == len(thetas)) +# nxsubentry.independent_dimensions = NXdata() +# nxsubentry.independent_dimensions.rotation_angle = thetas +# nxsubentry.independent_dimensions.rotation_angle.units = 'degrees' + # Add an NXinstrument to the scan's NXsubentry + nxsubentry.instrument = NXinstrument() + # Add an NXdetector to the NXinstrument to the scan's NXsubentry + nxsubentry.instrument.detector = detector.construct_nxdetector() + nxsubentry.instrument.detector.frame_start_number = scan_info['starting_image_offset'] + nxsubentry.instrument.detector.image_key = image_key + # Add an NXsample to the scan's NXsubentry + nxsubentry.sample = NXsample() + nxsubentry.sample.rotation_angle = thetas + nxsubentry.sample.rotation_angle.units = 'degrees' + nxsubentry.sample.x_translation = scan_info['ref_x'] + 
nxsubentry.sample.x_translation.units = 'mm' + nxsubentry.sample.z_translation = scan_info['ref_z'] + nxsubentry.sample.z_translation.units = 'mm' + return(nxcollection) + + +class FlatField(SpecScans): + + def image_range_cli(self, attr_desc, detector_prefix): + stack_info = self.stack_info + for scan_number in self.scan_numbers: + # Parse the available image range + parser = self.get_scanparser(scan_number) + image_offset = parser.starting_image_offset + num_image = parser.get_num_image(detector_prefix.upper()) + scan_index = self.get_scan_index(scan_number) + + # Select the image set + last_image_index = image_offset+num_image-1 + print(f'Available good image set index range: [{image_offset}, {last_image_index}]') + image_set_approved = False + if scan_index is not None: + scan_info = stack_info[scan_index] + print(f'Current starting image offset and number of images: '+ + f'{scan_info["starting_image_offset"]} and {scan_info["num_image"]}') + image_set_approved = input_yesno(f'Accept these values (y/n)?', 'y') + if not image_set_approved: + print(f'Default starting image offset and number of images: '+ + f'{image_offset} and {last_image_index-image_offset}') + image_set_approved = input_yesno(f'Accept these values (y/n)?', 'y') + if image_set_approved: + offset = image_offset + num = last_image_index-offset + while not image_set_approved: + offset = input_int(f'Enter the starting image offset', ge=image_offset, + le=last_image_index-1)#, default=image_offset) + num = input_int(f'Enter the number of images', ge=1, + le=last_image_index-offset+1)#, default=last_image_index-offset+1) + print(f'Current starting image offset and number of images: {offset} and {num}') + image_set_approved = input_yesno(f'Accept these values (y/n)?', 'y') + if scan_index is not None: + scan_info['starting_image_offset'] = offset + scan_info['num_image'] = num + scan_info['ref_x'] = parser.horizontal_shift + scan_info['ref_z'] = parser.vertical_shift + else: + stack_info.append({'scan_number': scan_number, 'starting_image_offset': offset, + 'num_image': num, 'ref_x': parser.horizontal_shift, + 'ref_z': parser.vertical_shift}) + self.stack_info = stack_info + + def cli(self, **cli_kwargs): + if cli_kwargs.get('attr_desc') is not None: + attr_desc = f'{cli_kwargs["attr_desc"]} ' + else: + attr_desc = '' + station = cli_kwargs.get('station') + detector = cli_kwargs.get('detector') + print(f'\n -- Configure the location of the {attr_desc}scan data -- ') + if station in ('id1a3', 'id3a'): + self.spec_file = cli_kwargs['spec_file'] + tomo_scan_numbers = cli_kwargs['tomo_scan_numbers'] + scan_type = cli_kwargs['scan_type'] + self.scan_numbers_cli(attr_desc, station=station, tomo_scan_numbers=tomo_scan_numbers, + scan_type=scan_type) + else: + self.set_single_attr_cli('spec_file', attr_desc+'SPEC file path') + self.scan_numbers_cli(attr_desc) + self.image_range_cli(attr_desc, detector.prefix) + + +class TomoField(SpecScans): + theta_range: dict = {} + + @validator('theta_range') + def validate_theta_range(cls, theta_range): + if len(theta_range) != 3 and len(theta_range) != 4: + raise ValueError(f'Invalid theta range {theta_range}') + is_num(theta_range['start'], raise_error=True) + is_num(theta_range['end'], raise_error=True) + is_int(theta_range['num'], gt=1, raise_error=True) + if theta_range['end'] <= theta_range['start']: + raise ValueError(f'Invalid theta range {theta_range}') + if 'start_index' in theta_range: + is_int(theta_range['start_index'], ge=0, raise_error=True) + return(theta_range) + + @classmethod + 
    def construct_from_nxcollection(cls, nxcollection:NXcollection):
+        #RV Can I derive this from the same classfunction for SpecScans by adding theta_range
+        config = {}
+        config['spec_file'] = nxcollection.attrs['spec_file']
+        scan_numbers = []
+        stack_info = []
+        for nxsubentry_name, nxsubentry in nxcollection.items():
+            scan_number = int(nxsubentry_name.split('_')[-1])
+            scan_numbers.append(scan_number)
+            stack_info.append({'scan_number': scan_number,
+                    'starting_image_offset': int(nxsubentry.instrument.detector.frame_start_number),
+                    'num_image': len(nxsubentry.sample.rotation_angle),
+                    'ref_x': float(nxsubentry.sample.x_translation),
+                    'ref_z': float(nxsubentry.sample.z_translation)})
+        config['scan_numbers'] = sorted(scan_numbers)
+        config['stack_info'] = stack_info
+        for name in nxcollection.entries:
+            if 'scan_' in name:
+                thetas = np.asarray(nxcollection[name].sample.rotation_angle)
+                config['theta_range'] = {'start': thetas[0], 'end': thetas[-1], 'num': thetas.size}
+                break
+        return(cls(**config))
+
+    def get_horizontal_shifts(self, scan_number=None):
+        horizontal_shifts = []
+        if scan_number is None:
+            scan_numbers = self.scan_numbers
+        else:
+            scan_numbers = [scan_number]
+        for scan_number in scan_numbers:
+            parser = self.get_scanparser(scan_number)
+            horizontal_shifts.append(parser.get_horizontal_shift())
+        if len(horizontal_shifts) == 1:
+            return(horizontal_shifts[0])
+        else:
+            return(horizontal_shifts)
+
+    def get_vertical_shifts(self, scan_number=None):
+        vertical_shifts = []
+        if scan_number is None:
+            scan_numbers = self.scan_numbers
+        else:
+            scan_numbers = [scan_number]
+        for scan_number in scan_numbers:
+            parser = self.get_scanparser(scan_number)
+            vertical_shifts.append(parser.get_vertical_shift())
+        if len(vertical_shifts) == 1:
+            return(vertical_shifts[0])
+        else:
+            return(vertical_shifts)
+
+    def theta_range_cli(self, scan_number, attr_desc, station):
+        # Parse the available theta range
+        parser = self.get_scanparser(scan_number)
+        theta_vals = parser.theta_vals
+        spec_theta_start = theta_vals.get('start')
+        spec_theta_end = theta_vals.get('end')
+        spec_num_theta = theta_vals.get('num')
+
+        # Check for consistency of theta ranges between scans
+        if scan_number != self.scan_numbers[0]:
+            parser = self.get_scanparser(self.scan_numbers[0])
+            if (parser.theta_vals.get('start') != spec_theta_start or
+                    parser.theta_vals.get('end') != spec_theta_end or
+                    parser.theta_vals.get('num') != spec_num_theta):
+                raise ValueError(f'Incompatible theta ranges between {attr_desc}scans:'+
+                        f'\n\tScan {scan_number}: {theta_vals}'+
+                        f'\n\tScan {self.scan_numbers[0]}: {parser.theta_vals}')
+            return
+
+        # Select the theta range for the tomo reconstruction from the first scan
+        thetas = np.linspace(spec_theta_start, spec_theta_end, spec_num_theta)
+        delta_theta = thetas[1]-thetas[0]
+        theta_range_approved = False
+        print(f'Theta range obtained from SPEC data: [{spec_theta_start}, {spec_theta_end})')
+        print(f'Theta step size = {delta_theta}')
+        print(f'Number of theta values: {spec_num_theta}')
+        default_start = None
+        default_end = None
+        if station in ('id1a3', 'id3a'):
+            theta_range_approved = input_yesno(f'Accept this theta range (y/n)?', 'y')
+            if theta_range_approved:
+                theta_start = spec_theta_start
+                theta_end = spec_theta_end
+                num_theta = spec_num_theta
+                theta_index_start = 0
+        elif station == 'id3b':
+            if spec_theta_start <= 0.0 and spec_theta_end >= 180.0:
+                default_start = 0
+                default_end = 180
+            elif spec_theta_end-spec_theta_start == 180:
+                default_start = spec_theta_start
+                default_end = spec_theta_end
+        while not theta_range_approved:
+            theta_start = input_num(f'Enter the first theta (included)', ge=spec_theta_start,
+                    lt=spec_theta_end, default=default_start)
+            theta_index_start = index_nearest(thetas, theta_start)
+            theta_start = thetas[theta_index_start]
+            theta_end = input_num(f'Enter the last theta (excluded)',
+                    ge=theta_start+delta_theta, le=spec_theta_end, default=default_end)
+            theta_index_end = index_nearest(thetas, theta_end)
+            theta_end = thetas[theta_index_end]
+            num_theta = theta_index_end-theta_index_start
+            print(f'Selected theta range: [{theta_start}, {theta_start+delta_theta}, ..., '+
+                    f'{theta_end})')
+            print(f'Number of theta values: {num_theta}')
+            theta_range_approved = input_yesno(f'Accept this theta range (y/n)?', 'y')
+        self.theta_range = {'start': float(theta_start), 'end': float(theta_end),
+                'num': num_theta, 'start_index': theta_index_start}
+
+    def image_range_cli(self, attr_desc, detector_prefix):
+        stack_info = self.stack_info
+        for scan_number in self.scan_numbers:
+            # Parse the available image range
+            parser = self.get_scanparser(scan_number)
+            image_offset = parser.starting_image_offset
+            num_image = parser.get_num_image(detector_prefix.upper())
+            scan_index = self.get_scan_index(scan_number)
+
+            # Select the image set matching the theta range
+            num_theta = self.theta_range['num']
+            theta_index_start = self.theta_range['start_index']
+            if num_theta > num_image-theta_index_start:
+                raise ValueError(f'Available {attr_desc}image indices incompatible with thetas:'+
+                        f'\n\tNumber of thetas and offset = {num_theta} and {theta_index_start}'+
+                        f'\n\tNumber of available images {num_image}')
+            if scan_index is not None:
+                scan_info = stack_info[scan_index]
+                scan_info['starting_image_offset'] = image_offset+theta_index_start
+                scan_info['num_image'] = num_theta
+                scan_info['ref_x'] = parser.horizontal_shift
+                scan_info['ref_z'] = parser.vertical_shift
+            else:
+                stack_info.append({'scan_number': scan_number,
+                        'starting_image_offset': image_offset+theta_index_start,
+                        'num_image': num_theta, 'ref_x': parser.horizontal_shift,
+                        'ref_z': parser.vertical_shift})
+        self.stack_info = stack_info
+
+    def cli(self, **cli_kwargs):
+        if cli_kwargs.get('attr_desc') is not None:
+            attr_desc = f'{cli_kwargs["attr_desc"]} '
+        else:
+            attr_desc = ''
+        cycle = cli_kwargs.get('cycle')
+        btr = cli_kwargs.get('btr')
+        station = cli_kwargs.get('station')
+        detector = cli_kwargs.get('detector')
+        print(f'\n -- Configure the location of the {attr_desc}scan data -- ')
+        if station in ('id1a3', 'id3a'):
+            basedir = f'/nfs/chess/{station}/{cycle}/{btr}'
+            runs = [d for d in os.listdir(basedir) if os.path.isdir(os.path.join(basedir, d))]
+#RV         index = 15-1
+#RV         index = 7-1
+            index = input_menu(runs, header='Choose a sample directory')
+            self.spec_file = f'{basedir}/{runs[index]}/spec.log'
+            self.scan_numbers_cli(attr_desc, station=station, scan_type='ts1')
+        else:
+            self.set_single_attr_cli('spec_file', attr_desc+'SPEC file path')
+            self.scan_numbers_cli(attr_desc)
+        for scan_number in self.scan_numbers:
+            self.theta_range_cli(scan_number, attr_desc, station)
+        self.image_range_cli(attr_desc, detector.prefix)
+
+
+class Sample(BaseModel):
+    name: constr(min_length=1)
+    description: Optional[str]
+    rotation_angles: Optional[list]
+    x_translations: Optional[list]
+    z_translations: Optional[list]
+
+    @classmethod
+    def construct_from_nxsample(cls, nxsample:NXsample):
+        config = {}
+        config['name'] = nxsample.name.nxdata
+        if 'description' in nxsample:
+            config['description'] = nxsample.description.nxdata
+
if 'rotation_angle' in nxsample: + config['rotation_angle'] = nxsample.rotation_angle.nxdata + if 'x_translation' in nxsample: + config['x_translation'] = nxsample.x_translation.nxdata + if 'z_translation' in nxsample: + config['z_translation'] = nxsample.z_translation.nxdata + return(cls(**config)) + + def cli(self): + print('\n -- Configure the sample metadata -- ') +#RV self.name = 'test' +#RV self.name = 'sobhani-3249-A' + self.set_single_attr_cli('name', 'the sample name') +#RV self.description = 'test sample' + self.set_single_attr_cli('description', 'a description of the sample (optional)') + + +class MapConfig(BaseModel): + cycle: constr(strip_whitespace=True, min_length=1) + btr: constr(strip_whitespace=True, min_length=1) + title: constr(strip_whitespace=True, min_length=1) + station: Literal['id1a3', 'id3a', 'id3b'] = None + sample: Sample + detector: Detector = Detector.construct() + tomo_fields: TomoField + dark_field: Optional[FlatField] + bright_field: FlatField + _thetas: list[float] = PrivateAttr() + _field_types = ({'name': 'dark_field', 'image_key': 2}, {'name': 'bright_field', + 'image_key': 1}, {'name': 'tomo_fields', 'image_key': 0}) + + @classmethod + def construct_from_nxentry(cls, nxentry:NXentry): + config = {} + config['cycle'] = nxentry.instrument.source.attrs['cycle'] + config['btr'] = nxentry.instrument.source.attrs['btr'] + config['title'] = nxentry.nxname + config['station'] = nxentry.instrument.source.attrs['station'] + config['sample'] = Sample.construct_from_nxsample(nxentry['sample']) + for nxobject_name, nxobject in nxentry.spec_scans.items(): + if isinstance(nxobject, NXcollection): + config[nxobject_name] = SpecScans.construct_from_nxcollection(nxobject) + return(cls(**config)) + +#FIX cache? + @property + def thetas(self): + try: + return(self._thetas) + except: + theta_range = self.tomo_fields.theta_range + self._thetas = list(np.linspace(theta_range['start'], theta_range['end'], + theta_range['num'])) + return(self._thetas) + + def cli(self): + print('\n -- Configure a map from a set of SPEC scans (dark, bright, and tomo), '+ + 'and / or detector data -- ') +#RV self.cycle = '2021-3' +#RV self.cycle = '2022-2' +#RV self.cycle = '2023-1' + self.set_single_attr_cli('cycle', 'beam cycle') +#RV self.btr = 'z-3234-A' +#RV self.btr = 'sobhani-3249-A' +#RV self.btr = 'przybyla-3606-a' + self.set_single_attr_cli('btr', 'BTR') +#RV self.title = 'z-3234-A' +#RV self.title = 'tomo7C' +#RV self.title = 'cmc-test-dwell-1' + self.set_single_attr_cli('title', 'title for the map entry') +#RV self.station = 'id3a' +#RV self.station = 'id3b' +#RV self.station = 'id1a3' + self.set_single_attr_cli('station', 'name of the station at which scans were collected '+ + '(currently choose from: id1a3, id3a, id3b)') + import_scanparser(self.station) + self.set_single_attr_cli('sample') + use_detector_config = False + if hasattr(self.detector, 'prefix') and len(self.detector.prefix): + use_detector_config = input_yesno(f'Current detector settings:\n{self.detector}\n'+ + f'Keep these settings? (y/n)') + if not use_detector_config: +#RV have_detector_config = True + have_detector_config = input_yesno(f'Is a detector configuration file available? 
(y/n)') + if have_detector_config: +#RV detector_config_file = 'retiga.yaml' +#RV detector_config_file = 'andor2.yaml' + detector_config_file = input(f'Enter detector configuration file name: ') + have_detector_config = self.detector.construct_from_yaml(detector_config_file) + if not have_detector_config: + self.set_single_attr_cli('detector', 'detector') + self.set_single_attr_cli('tomo_fields', 'Tomo field', chain_attr_desc=True, + cycle=self.cycle, btr=self.btr, station=self.station, detector=self.detector) + if self.station in ('id1a3', 'id3a'): + have_dark_field = True + tomo_spec_file = self.tomo_fields.spec_file + else: + have_dark_field = input_yesno(f'Are Dark field images available? (y/n)') + tomo_spec_file = None + if have_dark_field: + self.set_single_attr_cli('dark_field', 'Dark field', chain_attr_desc=True, + station=self.station, detector=self.detector, spec_file=tomo_spec_file, + tomo_scan_numbers=self.tomo_fields.scan_numbers, scan_type='df1') + self.set_single_attr_cli('bright_field', 'Bright field', chain_attr_desc=True, + station=self.station, detector=self.detector, spec_file=tomo_spec_file, + tomo_scan_numbers=self.tomo_fields.scan_numbers, scan_type='bf1') + + def construct_nxentry(self, nxroot, include_raw_data=True): + # Construct base NXentry + nxentry = NXentry() + + # Add an NXentry to the NXroot + nxroot[self.title] = nxentry + nxroot.attrs['default'] = self.title + nxentry.definition = 'NXtomo' +# nxentry.attrs['default'] = 'data' + + # Add an NXinstrument to the NXentry + nxinstrument = NXinstrument() + nxentry.instrument = nxinstrument + + # Add an NXsource to the NXinstrument + nxsource = NXsource() + nxinstrument.source = nxsource + nxsource.type = 'Synchrotron X-ray Source' + nxsource.name = 'CHESS' + nxsource.probe = 'x-ray' + + # Tag the NXsource with the runinfo (as an attribute) + nxsource.attrs['cycle'] = self.cycle + nxsource.attrs['btr'] = self.btr + nxsource.attrs['station'] = self.station + + # Add an NXdetector to the NXinstrument (don't fill in data fields yet) + nxinstrument.detector = self.detector.construct_nxdetector() + + # Add an NXsample to NXentry (don't fill in data fields yet) + nxsample = NXsample() + nxentry.sample = nxsample + nxsample.name = self.sample.name + nxsample.description = self.sample.description + + # Add an NXcollection to the base NXentry to hold metadata about the spec scans in the map + # Also obtain the data fields in NXsample and NXdetector + nxspec_scans = NXcollection() + nxentry.spec_scans = nxspec_scans + image_keys = [] + sequence_numbers = [] + image_stacks = [] + rotation_angles = [] + x_translations = [] + z_translations = [] + for field_type in self._field_types: + field_name = field_type['name'] + field = getattr(self, field_name) + if field is None: + continue + image_key = field_type['image_key'] + if field_type['name'] == 'tomo_fields': + thetas = self.thetas + else: + thetas = None + # Add the scans in a single spec file + nxspec_scans[field_name] = field.construct_nxcollection(image_key, thetas, + self.detector) + if include_raw_data: + image_stacks.append(field.get_detector_data(self.detector.prefix)) + for scan_number in field.scan_numbers: + parser = field.get_scanparser(scan_number) + scan_info = field.stack_info[field.get_scan_index(scan_number)] + num_image = scan_info['num_image'] + image_keys += num_image*[image_key] + sequence_numbers += [i for i in range(num_image)] + if thetas is None: + rotation_angles += scan_info['num_image']*[0.0] + else: + assert(num_image == len(thetas)) + 
rotation_angles += thetas + x_translations += scan_info['num_image']*[scan_info['ref_x']] + z_translations += scan_info['num_image']*[scan_info['ref_z']] + + if include_raw_data: + # Add image data to NXdetector + nxinstrument.detector.image_key = image_keys + nxinstrument.detector.sequence_number = sequence_numbers + nxinstrument.detector.data = np.concatenate([image for image in image_stacks]) + + # Add image data to NXsample + nxsample.rotation_angle = rotation_angles + nxsample.rotation_angle.attrs['units'] = 'degrees' + nxsample.x_translation = x_translations + nxsample.x_translation.attrs['units'] = 'mm' + nxsample.z_translation = z_translations + nxsample.z_translation.attrs['units'] = 'mm' + + # Add an NXdata to NXentry + nxdata = NXdata() + nxentry.data = nxdata + nxdata.makelink(nxentry.instrument.detector.data, name='data') + nxdata.makelink(nxentry.instrument.detector.image_key) + nxdata.makelink(nxentry.sample.rotation_angle) + nxdata.makelink(nxentry.sample.x_translation) + nxdata.makelink(nxentry.sample.z_translation) +# nxdata.attrs['axes'] = ['field', 'row', 'column'] +# nxdata.attrs['field_indices'] = 0 +# nxdata.attrs['row_indices'] = 1 +# nxdata.attrs['column_indices'] = 2 + + +class TomoWorkflow(BaseModel): + sample_maps: conlist(item_type=MapConfig, min_items=1) = [MapConfig.construct()] + + @classmethod + def construct_from_nexus(cls, filename): + nxroot = nxload(filename) + sample_maps = [] + config = {'sample_maps': sample_maps} + for nxentry_name, nxentry in nxroot.items(): + sample_maps.append(MapConfig.construct_from_nxentry(nxentry)) + return(cls(**config)) + + def cli(self): + print('\n -- Configure a map -- ') + self.set_list_attr_cli('sample_maps', 'sample map') + + def construct_nxfile(self, filename, mode='w-'): + nxroot = NXroot() + t0 = time() + for sample_map in self.sample_maps: + logger.info(f'Start constructing the {sample_map.title} map.') + import_scanparser(sample_map.station) + sample_map.construct_nxentry(nxroot) + logger.info(f'Constructed all sample maps in {time()-t0:.2f} seconds.') + logger.info(f'Start saving all sample maps to {filename}.') + nxroot.save(filename, mode=mode) + + def write_to_nexus(self, filename): + t0 = time() + self.construct_nxfile(filename, mode='w') + logger.info(f'Saved all sample maps to {filename} in {time()-t0:.2f} seconds.')
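All of the models above inherit a Config with validate_assignment = True, so pydantic's constrained types (conint, confloat, conlist, constr) are enforced both at construction time and on every later attribute assignment; this is what makes the interactive set_single_attr_cli loop work, since a bad value raises ValidationError and the prompt repeats. A minimal standalone sketch of the pattern under pydantic v1 (the values are illustrative):

from pydantic import BaseModel, ValidationError, conint, confloat, conlist, constr

class MiniDetector(BaseModel):
    class Config:
        validate_assignment = True  # re-validate on every assignment

    prefix: constr(strip_whitespace=True, min_length=1)
    rows: conint(gt=0)
    columns: conint(gt=0)
    pixel_size: conlist(item_type=confloat(gt=0), min_items=1, max_items=2)
    lens_magnification: confloat(gt=0) = 1.0

d = MiniDetector(prefix='andor2', rows=2048, columns=2048, pixel_size=[0.0065])
d.lens_magnification = 5.0        # accepted and re-validated
try:
    d.rows = 0                    # rejected: violates conint(gt=0)
except ValidationError as e:
    print(e)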
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/workflow/run_tomo.py Fri Mar 10 16:02:04 2023 +0000 @@ -0,0 +1,1589 @@ +#!/usr/bin/env python3 + +import logging +logger = logging.getLogger(__name__) + +import numpy as np +try: + import numexpr as ne +except: + pass +try: + import scipy.ndimage as spi +except: + pass + +from multiprocessing import cpu_count +from nexusformat.nexus import * +from os import mkdir +from os import path as os_path +try: + from skimage.transform import iradon +except: + pass +try: + from skimage.restoration import denoise_tv_chambolle +except: + pass +from time import time +try: + import tomopy +except: + pass +from yaml import safe_load, safe_dump + +from msnctools.fit import Fit +from msnctools.general import illegal_value, is_int, is_int_pair, is_num, is_index_range, \ + input_int, input_num, input_yesno, input_menu, draw_mask_1d, select_image_bounds, \ + select_one_image_bound, clear_imshow, quick_imshow, clear_plot, quick_plot + +from workflow.models import import_scanparser, FlatField, TomoField, TomoWorkflow +from workflow.__version__ import __version__ + +num_core_tomopy_limit = 24 + +def nxcopy(nxobject:NXobject, exclude_nxpaths:list[str]=[], nxpath_prefix:str='') -> NXobject: + '''Function that returns a copy of a nexus object, optionally exluding certain child items. + + :param nxobject: the original nexus object to return a "copy" of + :type nxobject: nexusformat.nexus.NXobject + :param exlude_nxpaths: a list of paths to child nexus objects that + should be exluded from the returned "copy", defaults to `[]` + :type exclude_nxpaths: list[str], optional + :param nxpath_prefix: For use in recursive calls from inside this + function only! + :type nxpath_prefix: str + :return: a copy of `nxobject` with some children optionally exluded. + :rtype: NXobject + ''' + + nxobject_copy = nxobject.__class__() + if not len(nxpath_prefix): + if 'default' in nxobject.attrs: + nxobject_copy.attrs['default'] = nxobject.attrs['default'] + else: + for k, v in nxobject.attrs.items(): + nxobject_copy.attrs[k] = v + + for k, v in nxobject.items(): + nxpath = os_path.join(nxpath_prefix, k) + + if nxpath in exclude_nxpaths: + continue + + if isinstance(v, NXgroup): + nxobject_copy[k] = nxcopy(v, exclude_nxpaths=exclude_nxpaths, + nxpath_prefix=os_path.join(nxpath_prefix, k)) + else: + nxobject_copy[k] = v + + return(nxobject_copy) + +class set_numexpr_threads: + + def __init__(self, num_core): + if num_core is None or num_core < 1 or num_core > cpu_count(): + self.num_core = cpu_count() + else: + self.num_core = num_core + + def __enter__(self): + self.num_core_org = ne.set_num_threads(self.num_core) + + def __exit__(self, exc_type, exc_value, traceback): + ne.set_num_threads(self.num_core_org) + +class Tomo: + """Processing tomography data with misalignment. + """ + def __init__(self, galaxy_flag=False, num_core=-1, output_folder='.', save_figs=None, + test_mode=False): + """Initialize with optional config input file or dictionary + """ + if not isinstance(galaxy_flag, bool): + raise ValueError(f'Invalid parameter galaxy_flag ({galaxy_flag})') + self.galaxy_flag = galaxy_flag + self.num_core = num_core + if self.galaxy_flag: + if output_folder != '.': + logger.warning('Ignoring output_folder in galaxy mode') + self.output_folder = '.' 
+            if test_mode:
+                logger.warning('Ignoring test_mode in galaxy mode')
+            self.test_mode = False
+            if save_figs is not None:
+                logger.warning('Ignoring save_figs in galaxy mode')
+            save_figs = 'only'
+        else:
+            self.output_folder = os_path.abspath(output_folder)
+            if not os_path.isdir(output_folder):
+                mkdir(os_path.abspath(output_folder))
+            if not isinstance(test_mode, bool):
+                raise ValueError(f'Invalid parameter test_mode ({test_mode})')
+            self.test_mode = test_mode
+            if save_figs is None:
+                save_figs = 'no'
+        self.test_config = {}
+        if self.test_mode:
+            if save_figs != 'only':
+                logger.warning('Ignoring save_figs in test mode')
+            save_figs = 'only'
+        if save_figs == 'only':
+            self.save_only = True
+            self.save_figs = True
+        elif save_figs == 'yes':
+            self.save_only = False
+            self.save_figs = True
+        elif save_figs == 'no':
+            self.save_only = False
+            self.save_figs = False
+        else:
+            raise ValueError(f'Invalid parameter save_figs ({save_figs})')
+        if self.save_only:
+            self.block = False
+        else:
+            self.block = True
+        if self.num_core == -1:
+            self.num_core = cpu_count()
+        if not is_int(self.num_core, gt=0, log=False):
+            raise ValueError(f'Invalid parameter num_core ({num_core})')
+        if self.num_core > cpu_count():
+            logger.warning(f'num_core = {self.num_core} is larger than the number of available '
+                    f'processors and reduced to {cpu_count()}')
+            self.num_core = cpu_count()
+
+    def read(self, filename):
+        extension = os_path.splitext(filename)[1]
+        if extension == '.yml' or extension == '.yaml':
+            with open(filename, 'r') as f:
+                config = safe_load(f)
+#            if len(config) > 1:
+#                raise ValueError(f'Multiple root entries in {filename} not yet implemented')
+#            if len(list(config.values())[0]) > 1:
+#                raise ValueError(f'Multiple sample maps in {filename} not yet implemented')
+            return(config)
+        elif extension == '.nxs':
+            with NXFile(filename, mode='r') as nxfile:
+                nxroot = nxfile.readfile()
+            return(nxroot)
+        else:
+            raise ValueError(f'Invalid filename extension ({extension})')
+
+    def write(self, data, filename):
+        extension = os_path.splitext(filename)[1]
+        if extension == '.yml' or extension == '.yaml':
+            with open(filename, 'w') as f:
+                safe_dump(data, f)
+        elif extension == '.nxs':
+            data.save(filename, mode='w')
+        elif extension == '.nc':
+            data.to_netcdf(path=filename)
+        else:
+            raise ValueError(f'Invalid filename extension ({extension})')
+
+    def gen_reduced_data(self, data, img_x_bounds=None):
+        """Generate the reduced tomography images.
+ """ + logger.info('Generate the reduced tomography images') + + # Create plot galaxy path directory if needed + if self.galaxy_flag and not os_path.exists('tomo_reduce_plots'): + mkdir('tomo_reduce_plots') + + if isinstance(data, dict): + # Create Nexus format object from input dictionary + wf = TomoWorkflow(**data) + if len(wf.sample_maps) > 1: + raise ValueError(f'Multiple sample maps not yet implemented') +# print(f'\nwf:\n{wf}\n') + nxroot = NXroot() + t0 = time() + for sample_map in wf.sample_maps: + logger.info(f'Start constructing the {sample_map.title} map.') + import_scanparser(sample_map.station) + sample_map.construct_nxentry(nxroot, include_raw_data=False) + logger.info(f'Constructed all sample maps in {time()-t0:.2f} seconds.') + nxentry = nxroot[nxroot.attrs['default']] + # Get test mode configuration info + if self.test_mode: + self.test_config = data['sample_maps'][0]['test_mode'] + elif isinstance(data, NXroot): + nxentry = data[data.attrs['default']] + else: + raise ValueError(f'Invalid parameter data ({data})') + + # Create an NXprocess to store data reduction (meta)data + reduced_data = NXprocess() + + # Generate dark field + if 'dark_field' in nxentry['spec_scans']: + reduced_data = self._gen_dark(nxentry, reduced_data) + + # Generate bright field + reduced_data = self._gen_bright(nxentry, reduced_data) + + # Set vertical detector bounds for image stack + img_x_bounds = self._set_detector_bounds(nxentry, reduced_data, img_x_bounds=img_x_bounds) + logger.info(f'img_x_bounds = {img_x_bounds}') + reduced_data['img_x_bounds'] = img_x_bounds + + # Set zoom and/or theta skip to reduce memory the requirement + zoom_perc, num_theta_skip = self._set_zoom_or_skip() + if zoom_perc is not None: + reduced_data.attrs['zoom_perc'] = zoom_perc + if num_theta_skip is not None: + reduced_data.attrs['num_theta_skip'] = num_theta_skip + + # Generate reduced tomography fields + reduced_data = self._gen_tomo(nxentry, reduced_data) + + # Create a copy of the input Nexus object and remove raw and any existing reduced data + if isinstance(data, NXroot): + exclude_items = [f'{nxentry._name}/reduced_data/data', + f'{nxentry._name}/instrument/detector/data', + f'{nxentry._name}/instrument/detector/image_key', + f'{nxentry._name}/instrument/detector/sequence_number', + f'{nxentry._name}/sample/rotation_angle', + f'{nxentry._name}/sample/x_translation', + f'{nxentry._name}/sample/z_translation', + f'{nxentry._name}/data/data', + f'{nxentry._name}/data/image_key', + f'{nxentry._name}/data/rotation_angle', + f'{nxentry._name}/data/x_translation', + f'{nxentry._name}/data/z_translation'] + nxroot = nxcopy(data, exclude_nxpaths=exclude_items) + nxentry = nxroot[nxroot.attrs['default']] + + # Add the reduced data NXprocess + nxentry.reduced_data = reduced_data + + if 'data' not in nxentry: + nxentry.data = NXdata() + nxentry.attrs['default'] = 'data' + nxentry.data.makelink(nxentry.reduced_data.data.tomo_fields, name='reduced_data') + nxentry.data.makelink(nxentry.reduced_data.rotation_angle, name='rotation_angle') + nxentry.data.attrs['signal'] = 'reduced_data' + + return(nxroot) + + def find_centers(self, nxroot, center_rows=None): + """Find the calibrated center axis info + """ + logger.info('Find the calibrated center axis info') + + if not isinstance(nxroot, NXroot): + raise ValueError(f'Invalid parameter nxroot ({nxroot})') + nxentry = nxroot[nxroot.attrs['default']] + if not isinstance(nxentry, NXentry): + raise ValueError(f'Invalid nxentry ({nxentry})') + if self.galaxy_flag: + if center_rows 
+
+    def find_centers(self, nxroot, center_rows=None):
+        """Find the calibrated center axis info.
+        """
+        logger.info('Find the calibrated center axis info')
+
+        if not isinstance(nxroot, NXroot):
+            raise ValueError(f'Invalid parameter nxroot ({nxroot})')
+        nxentry = nxroot[nxroot.attrs['default']]
+        if not isinstance(nxentry, NXentry):
+            raise ValueError(f'Invalid nxentry ({nxentry})')
+        if self.galaxy_flag:
+            if center_rows is None:
+                raise ValueError(f'Missing parameter center_rows ({center_rows})')
+            if not is_int_pair(center_rows):
+                raise ValueError(f'Invalid parameter center_rows ({center_rows})')
+        elif center_rows is not None:
+            logger.warning(f'Ignoring parameter center_rows ({center_rows})')
+            center_rows = None
+
+        # Create plot galaxy path directory and path if needed
+        if self.galaxy_flag:
+            if not os_path.exists('tomo_find_centers_plots'):
+                mkdir('tomo_find_centers_plots')
+            path = 'tomo_find_centers_plots'
+        else:
+            path = self.output_folder
+
+        # Check if reduced data is available
+        if ('reduced_data' not in nxentry or 'reduced_data' not in nxentry.data):
+            raise KeyError(f'Unable to find valid reduced data in {nxentry}.')
+
+        # Select the image stack to calibrate the center axis
+        #   reduced data axes order: stack,row,theta,column
+        # Note: Nexus cannot follow a link if the data it points to is too big,
+        # so get the data from the actual place, not from nxentry.data
+        num_tomo_stacks = nxentry.reduced_data.data.tomo_fields.shape[0]
+        if num_tomo_stacks == 1:
+            center_stack_index = 0
+            center_stack = np.asarray(nxentry.reduced_data.data.tomo_fields[0])
+            if not center_stack.size:
+                raise KeyError('Unable to load the required reduced tomography stack')
+            default = 'n'
+        else:
+            if self.test_mode:
+                center_stack_index = self.test_config['center_stack_index']-1  # make offset 0
+            else:
+                center_stack_index = input_int('\nEnter tomography stack index to calibrate the '
+                    'center axis', ge=0, le=num_tomo_stacks-1, default=int(num_tomo_stacks/2))
+            center_stack = np.asarray(nxentry.reduced_data.data.tomo_fields[center_stack_index])
+            if not center_stack.size:
+                raise KeyError('Unable to load the required reduced tomography stack')
+            default = 'y'
+
+        # Get thetas (in degrees)
+        thetas = np.asarray(nxentry.reduced_data.rotation_angle)
+
+        # Get effective pixel_size
+        if 'zoom_perc' in nxentry.reduced_data:
+            eff_pixel_size = 100.*(nxentry.instrument.detector.x_pixel_size/
+                nxentry.reduced_data.attrs['zoom_perc'])
+        else:
+            eff_pixel_size = nxentry.instrument.detector.x_pixel_size
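+
+        # Illustrative arithmetic for the effective pixel size above
+        # (hypothetical numbers): with x_pixel_size = 0.65 and data zoomed to
+        # zoom_perc = 50, eff_pixel_size = 100.*0.65/50 = 1.3, i.e. zooming
+        # to 50% doubles the effective pixel size.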
+
+        # Get cross sectional diameter
+        cross_sectional_dim = center_stack.shape[2]*eff_pixel_size
+        logger.debug(f'cross_sectional_dim = {cross_sectional_dim}')
+
+        # Determine center offset at sample row boundaries
+        logger.info('Determine center offset at sample row boundaries')
+
+        # Lower row center
+        #   center_stack order: row,theta,column
+        if self.test_mode:
+            lower_row = self.test_config['lower_row']
+        elif self.galaxy_flag:
+            lower_row = min(center_rows)
+            if not 0 <= lower_row < center_stack.shape[0]-1:
+                raise ValueError(f'Invalid parameter center_rows ({center_rows})')
+        else:
+            lower_row = select_one_image_bound(center_stack[:,0,:], 0, bound=0,
+                title=f'theta={round(thetas[0], 2)+0}',
+                bound_name='row index to find lower center', default=default)
+        lower_center_offset = self._find_center_one_plane(center_stack[lower_row,:,:], lower_row,
+            thetas, eff_pixel_size, cross_sectional_dim, path=path, num_core=self.num_core)
+        logger.debug(f'lower_row = {lower_row}')
+        logger.debug(f'lower_center_offset = {lower_center_offset:.2f}')
+
+        # Upper row center
+        if self.test_mode:
+            upper_row = self.test_config['upper_row']
+        elif self.galaxy_flag:
+            upper_row = max(center_rows)
+            if not lower_row < upper_row < center_stack.shape[0]:
+                raise ValueError(f'Invalid parameter center_rows ({center_rows})')
+        else:
+            upper_row = select_one_image_bound(center_stack[:,0,:], 0,
+                bound=center_stack.shape[0]-1, title=f'theta={round(thetas[0], 2)+0}',
+                bound_name='row index to find upper center', default=default)
+        upper_center_offset = self._find_center_one_plane(center_stack[upper_row,:,:], upper_row,
+            thetas, eff_pixel_size, cross_sectional_dim, path=path, num_core=self.num_core)
+        logger.debug(f'upper_row = {upper_row}')
+        logger.debug(f'upper_center_offset = {upper_center_offset:.2f}')
+        del center_stack
+
+        center_config = {'lower_row': lower_row, 'lower_center_offset': lower_center_offset,
+            'upper_row': upper_row, 'upper_center_offset': upper_center_offset}
+        if num_tomo_stacks > 1:
+            center_config['center_stack_index'] = center_stack_index+1  # save as offset 1
+
+        # Save test data to file
+        if self.test_mode:
+            with open(f'{self.output_folder}/center_config.yaml', 'w') as f:
+                safe_dump(center_config, f)
+
+        return(center_config)
+
+    def reconstruct_data(self, nxroot, center_info, x_bounds=None, y_bounds=None):
+        """Reconstruct the tomography data.
+        """
+        logger.info('Reconstruct the tomography data')
+
+        if not isinstance(nxroot, NXroot):
+            raise ValueError(f'Invalid parameter nxroot ({nxroot})')
+        nxentry = nxroot[nxroot.attrs['default']]
+        if not isinstance(nxentry, NXentry):
+            raise ValueError(f'Invalid nxentry ({nxentry})')
+        if not isinstance(center_info, dict):
+            raise ValueError(f'Invalid parameter center_info ({center_info})')
+
+        # Create plot galaxy path directory and path if needed
+        if self.galaxy_flag:
+            if not os_path.exists('tomo_reconstruct_plots'):
+                mkdir('tomo_reconstruct_plots')
+            path = 'tomo_reconstruct_plots'
+        else:
+            path = self.output_folder
+
+        # Check if reduced data is available
+        if ('reduced_data' not in nxentry or 'reduced_data' not in nxentry.data):
+            raise KeyError(f'Unable to find valid reduced data in {nxentry}.')
+
+        # Create an NXprocess to store image reconstruction (meta)data
+#        if 'reconstructed_data' in nxentry:
+#            logger.warning(f'Existing reconstructed data in {nxentry} will be overwritten.')
+#            del nxentry['reconstructed_data']
+#        if 'data' in nxentry and 'reconstructed_data' in nxentry.data:
+#            del nxentry.data['reconstructed_data']
+        nxprocess = NXprocess()
+
+        # Get rotation axis rows and centers
+        lower_row = center_info.get('lower_row')
+        lower_center_offset = center_info.get('lower_center_offset')
+        upper_row = center_info.get('upper_row')
+        upper_center_offset = center_info.get('upper_center_offset')
+        if (lower_row is None or lower_center_offset is None or upper_row is None or
+                upper_center_offset is None):
+            raise KeyError(f'Unable to find valid calibrated center axis info in {center_info}.')
+        center_slope = (upper_center_offset-lower_center_offset)/(upper_row-lower_row)
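+
+        # The center offset is modeled as linear in detector row, so the two
+        # calibrated rows extrapolate to any row of a stack; a sketch with
+        # hypothetical numbers:
+        #   lower_row, lower_center_offset = 10, 1.0
+        #   upper_row, upper_center_offset = 90, 3.0
+        #   center_slope = (3.0-1.0)/(90-10)   # = 0.025 pixel/row
+        #   offset_at_row_0 = 1.0-10*0.025     # = 0.75, as used below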
< upper_row < tomo_stack.shape[0]) + center_offsets = [lower_center_offset-lower_row*center_slope, + upper_center_offset+(tomo_stack.shape[0]-1-upper_row)*center_slope] + t0 = time() + logger.debug(f'Running _reconstruct_one_tomo_stack on {self.num_core} cores ...') + tomo_recon_stack = self._reconstruct_one_tomo_stack(tomo_stack, thetas, + center_offsets=center_offsets, num_core=self.num_core, algorithm='gridrec') + logger.debug(f'... done in {time()-t0:.2f} seconds') + logger.info(f'Reconstruction of stack {i} took {time()-t0:.2f} seconds') + + # Combine stacks + tomo_recon_stacks[i] = tomo_recon_stack + + # Resize the reconstructed tomography data + # reconstructed data order in each stack: row/z,x,y + if self.test_mode: + x_bounds = self.test_config.get('x_bounds') + y_bounds = self.test_config.get('y_bounds') + z_bounds = None + elif self.galaxy_flag: + if x_bounds is not None and not is_int_pair(x_bounds, ge=0, + lt=tomo_recon_stacks[0].shape[1]): + raise ValueError(f'Invalid parameter x_bounds ({x_bounds})') + if y_bounds is not None and not is_int_pair(y_bounds, ge=0, + lt=tomo_recon_stacks[0].shape[1]): + raise ValueError(f'Invalid parameter y_bounds ({y_bounds})') + z_bounds = None + else: + x_bounds, y_bounds, z_bounds = self._resize_reconstructed_data(tomo_recon_stacks) + if x_bounds is None: + x_range = (0, tomo_recon_stacks[0].shape[1]) + x_slice = int(x_range[1]/2) + else: + x_range = (min(x_bounds), max(x_bounds)) + x_slice = int((x_bounds[0]+x_bounds[1])/2) + if y_bounds is None: + y_range = (0, tomo_recon_stacks[0].shape[2]) + y_slice = int(y_range[1]/2) + else: + y_range = (min(y_bounds), max(y_bounds)) + y_slice = int((y_bounds[0]+y_bounds[1])/2) + if z_bounds is None: + z_range = (0, tomo_recon_stacks[0].shape[0]) + z_slice = int(z_range[1]/2) + else: + z_range = (min(z_bounds), max(z_bounds)) + z_slice = int((z_bounds[0]+z_bounds[1])/2) + + # Plot a few reconstructed image slices + if num_tomo_stacks == 1: + basetitle = 'recon' + else: + basetitle = f'recon stack {i}' + for i, stack in enumerate(tomo_recon_stacks): + title = f'{basetitle} {res_title} xslice{x_slice}' + quick_imshow(stack[z_range[0]:z_range[1],x_slice,y_range[0]:y_range[1]], + title=title, path=path, save_fig=self.save_figs, save_only=self.save_only, + block=self.block) + title = f'{basetitle} {res_title} yslice{y_slice}' + quick_imshow(stack[z_range[0]:z_range[1],x_range[0]:x_range[1],y_slice], + title=title, path=path, save_fig=self.save_figs, save_only=self.save_only, + block=self.block) + title = f'{basetitle} {res_title} zslice{z_slice}' + quick_imshow(stack[z_slice,x_range[0]:x_range[1],y_range[0]:y_range[1]], + title=title, path=path, save_fig=self.save_figs, save_only=self.save_only, + block=self.block) + + # Save test data to file + # reconstructed data order in each stack: row/z,x,y + if self.test_mode: + for i, stack in enumerate(tomo_recon_stacks): + np.savetxt(f'{self.output_folder}/recon_stack_{i+1}.txt', + stack[z_slice,x_range[0]:x_range[1],y_range[0]:y_range[1]], fmt='%.6e') + + # Add image reconstruction to reconstructed data NXprocess + # reconstructed data order in each stack: row/z,x,y + nxprocess.data = NXdata() + nxprocess.attrs['default'] = 'data' + for k, v in center_info.items(): + nxprocess[k] = v + if x_bounds is not None: + nxprocess.x_bounds = x_bounds + if y_bounds is not None: + nxprocess.y_bounds = y_bounds + if z_bounds is not None: + nxprocess.z_bounds = z_bounds + nxprocess.data['reconstructed_data'] = np.asarray([stack[z_range[0]:z_range[1], + 
+
+        # Add image reconstruction to reconstructed data NXprocess
+        #   reconstructed data order in each stack: row/z,x,y
+        nxprocess.data = NXdata()
+        nxprocess.attrs['default'] = 'data'
+        for k, v in center_info.items():
+            nxprocess[k] = v
+        if x_bounds is not None:
+            nxprocess.x_bounds = x_bounds
+        if y_bounds is not None:
+            nxprocess.y_bounds = y_bounds
+        if z_bounds is not None:
+            nxprocess.z_bounds = z_bounds
+        nxprocess.data['reconstructed_data'] = np.asarray([stack[z_range[0]:z_range[1],
+            x_range[0]:x_range[1],y_range[0]:y_range[1]] for stack in tomo_recon_stacks])
+        nxprocess.data.attrs['signal'] = 'reconstructed_data'
+
+        # Create a copy of the input Nexus object and remove reduced data
+        exclude_items = [f'{nxentry._name}/reduced_data/data',
+            f'{nxentry._name}/data/reduced_data']
+        nxroot_copy = nxcopy(nxroot, exclude_nxpaths=exclude_items)
+
+        # Add the reconstructed data NXprocess to the new Nexus object
+        nxentry_copy = nxroot_copy[nxroot_copy.attrs['default']]
+        nxentry_copy.reconstructed_data = nxprocess
+        if 'data' not in nxentry_copy:
+            nxentry_copy.data = NXdata()
+        nxentry_copy.attrs['default'] = 'data'
+        nxentry_copy.data.makelink(nxprocess.data.reconstructed_data, name='reconstructed_data')
+        nxentry_copy.data.attrs['signal'] = 'reconstructed_data'
+
+        return(nxroot_copy)
+
+    def combine_data(self, nxroot):
+        """Combine the reconstructed tomography stacks.
+        """
+        logger.info('Combine the reconstructed tomography stacks')
+
+        if not isinstance(nxroot, NXroot):
+            raise ValueError(f'Invalid parameter nxroot ({nxroot})')
+        nxentry = nxroot[nxroot.attrs['default']]
+        if not isinstance(nxentry, NXentry):
+            raise ValueError(f'Invalid nxentry ({nxentry})')
+
+        # Create plot galaxy path directory and path if needed
+        if self.galaxy_flag:
+            if not os_path.exists('tomo_combine_plots'):
+                mkdir('tomo_combine_plots')
+            path = 'tomo_combine_plots'
+        else:
+            path = self.output_folder
+
+        # Check if reconstructed image data is available
+        if ('reconstructed_data' not in nxentry or 'reconstructed_data' not in nxentry.data):
+            raise KeyError(f'Unable to find valid reconstructed image data in {nxentry}.')
+
+        # Create an NXprocess to store combined image reconstruction (meta)data
+#        if 'combined_data' in nxentry:
+#            logger.warning(f'Existing combined data in {nxentry} will be overwritten.')
+#            del nxentry['combined_data']
+#        if 'data' in nxentry and 'combined_data' in nxentry.data:
+#            del nxentry.data['combined_data']
+        nxprocess = NXprocess()
+
+        # Get the reconstructed data
+        #   reconstructed data order: stack,row(z),x,y
+        # Note: Nexus cannot follow a link if the data it points to is too big,
+        # so get the data from the actual place, not from nxentry.data
+        tomo_recon_stacks = np.asarray(nxentry.reconstructed_data.data.reconstructed_data)
+        num_tomo_stacks = tomo_recon_stacks.shape[0]
+        if num_tomo_stacks == 1:
+            return(nxroot)
+        t0 = time()
+        logger.debug('Combining the reconstructed stacks ...')
+        tomo_recon_combined = tomo_recon_stacks[0,:,:,:]
+        if num_tomo_stacks > 2:
+            tomo_recon_combined = np.concatenate([tomo_recon_combined]+
+                [tomo_recon_stacks[i,:,:,:] for i in range(1, num_tomo_stacks-1)])
+        if num_tomo_stacks > 1:
+            tomo_recon_combined = np.concatenate([tomo_recon_combined]+
+                [tomo_recon_stacks[num_tomo_stacks-1,:,:,:]])
+        logger.debug(f'... done in {time()-t0:.2f} seconds')
+        logger.info(f'Combining the reconstructed stacks took {time()-t0:.2f} seconds')
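+
+        # Note: for stacks of shape (nz_i, nx, ny) the stepwise concatenation
+        # above is equivalent to a single call along the z axis, e.g.
+        #   tomo_recon_combined = np.concatenate(
+        #       [tomo_recon_stacks[i,:,:,:] for i in range(num_tomo_stacks)])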
+
+        # Resize the combined tomography data set
+        #   combined data order: row/z,x,y
+        if self.test_mode:
+            x_bounds = None
+            y_bounds = None
+            z_bounds = self.test_config.get('z_bounds')
+        elif self.galaxy_flag:
+            exit('TODO')
+            if x_bounds is not None and not is_int_pair(x_bounds, ge=0,
+                    lt=tomo_recon_stacks[0].shape[1]):
+                raise ValueError(f'Invalid parameter x_bounds ({x_bounds})')
+            if y_bounds is not None and not is_int_pair(y_bounds, ge=0,
+                    lt=tomo_recon_stacks[0].shape[2]):
+                raise ValueError(f'Invalid parameter y_bounds ({y_bounds})')
+            z_bounds = None
+        else:
+            x_bounds, y_bounds, z_bounds = self._resize_reconstructed_data(tomo_recon_combined,
+                z_only=True)
+        if x_bounds is None:
+            x_range = (0, tomo_recon_combined.shape[1])
+            x_slice = int(x_range[1]/2)
+        else:
+            x_range = x_bounds
+            x_slice = int((x_bounds[0]+x_bounds[1])/2)
+        if y_bounds is None:
+            y_range = (0, tomo_recon_combined.shape[2])
+            y_slice = int(y_range[1]/2)
+        else:
+            y_range = y_bounds
+            y_slice = int((y_bounds[0]+y_bounds[1])/2)
+        if z_bounds is None:
+            z_range = (0, tomo_recon_combined.shape[0])
+            z_slice = int(z_range[1]/2)
+        else:
+            z_range = z_bounds
+            z_slice = int((z_bounds[0]+z_bounds[1])/2)
+
+        # Plot a few combined image slices
+        quick_imshow(tomo_recon_combined[z_range[0]:z_range[1],x_slice,y_range[0]:y_range[1]],
+            title=f'recon combined xslice{x_slice}', path=path,
+            save_fig=self.save_figs, save_only=self.save_only, block=self.block)
+        quick_imshow(tomo_recon_combined[z_range[0]:z_range[1],x_range[0]:x_range[1],y_slice],
+            title=f'recon combined yslice{y_slice}', path=path,
+            save_fig=self.save_figs, save_only=self.save_only, block=self.block)
+        quick_imshow(tomo_recon_combined[z_slice,x_range[0]:x_range[1],y_range[0]:y_range[1]],
+            title=f'recon combined zslice{z_slice}', path=path,
+            save_fig=self.save_figs, save_only=self.save_only, block=self.block)
+
+        # Save test data to file
+        #   combined data order: row/z,x,y
+        if self.test_mode:
+            np.savetxt(f'{self.output_folder}/recon_combined.txt', tomo_recon_combined[
+                z_slice,x_range[0]:x_range[1],y_range[0]:y_range[1]], fmt='%.6e')
+
+        # Add image reconstruction to reconstructed data NXprocess
+        #   combined data order: row/z,x,y
+        nxprocess.data = NXdata()
+        nxprocess.attrs['default'] = 'data'
+        if x_bounds is not None:
+            nxprocess.x_bounds = x_bounds
+        if y_bounds is not None:
+            nxprocess.y_bounds = y_bounds
+        if z_bounds is not None:
+            nxprocess.z_bounds = z_bounds
+        nxprocess.data['combined_data'] = tomo_recon_combined
+        nxprocess.data.attrs['signal'] = 'combined_data'
+
+        # Create a copy of the input Nexus object and remove reconstructed data
+        exclude_items = [f'{nxentry._name}/reconstructed_data/data',
+            f'{nxentry._name}/data/reconstructed_data']
+        nxroot_copy = nxcopy(nxroot, exclude_nxpaths=exclude_items)
+
+        # Add the combined data NXprocess to the new Nexus object
+        nxentry_copy = nxroot_copy[nxroot_copy.attrs['default']]
+        nxentry_copy.combined_data = nxprocess
+        if 'data' not in nxentry_copy:
+            nxentry_copy.data = NXdata()
+        nxentry_copy.attrs['default'] = 'data'
+        nxentry_copy.data.makelink(nxprocess.data.combined_data, name='combined_data')
+        nxentry_copy.data.attrs['signal'] = 'combined_data'
+
+        return(nxroot_copy)
+ """ + # Get the dark field images + image_key = nxentry.instrument.detector.get('image_key', None) + if image_key and 'data' in nxentry.instrument.detector: + field_indices = [index for index, key in enumerate(image_key) if key == 2] + tdf_stack = nxentry.instrument.detector.data[field_indices,:,:] + # RV the default NXtomo form does not accomodate bright or dark field stacks + else: + dark_field_scans = nxentry.spec_scans.dark_field + dark_field = FlatField.construct_from_nxcollection(dark_field_scans) + prefix = str(nxentry.instrument.detector.local_name) + tdf_stack = dark_field.get_detector_data(prefix) + if isinstance(tdf_stack, list): + exit('TODO') + + # Take median + if tdf_stack.ndim == 2: + tdf = tdf_stack + elif tdf_stack.ndim == 3: + tdf = np.median(tdf_stack, axis=0) + del tdf_stack + else: + raise ValueError(f'Invalid tdf_stack shape ({tdf_stack.shape})') + + # Remove dark field intensities above the cutoff +#RV tdf_cutoff = None + tdf_cutoff = tdf.min()+2*(np.median(tdf)-tdf.min()) + logger.debug(f'tdf_cutoff = {tdf_cutoff}') + if tdf_cutoff is not None: + if not is_num(tdf_cutoff, ge=0): + logger.warning(f'Ignoring illegal value of tdf_cutoff {tdf_cutoff}') + else: + tdf[tdf > tdf_cutoff] = np.nan + logger.debug(f'tdf_cutoff = {tdf_cutoff}') + + # Remove nans + tdf_mean = np.nanmean(tdf) + logger.debug(f'tdf_mean = {tdf_mean}') + np.nan_to_num(tdf, copy=False, nan=tdf_mean, posinf=tdf_mean, neginf=0.) + + # Plot dark field + if self.galaxy_flag: + quick_imshow(tdf, title='dark field', path='tomo_reduce_plots', save_fig=self.save_figs, + save_only=self.save_only) + elif not self.test_mode: + quick_imshow(tdf, title='dark field', path=self.output_folder, save_fig=self.save_figs, + save_only=self.save_only) + clear_imshow('dark field') +# quick_imshow(tdf, title='dark field', block=True) + + # Add dark field to reduced data NXprocess + reduced_data.data = NXdata() + reduced_data.data['dark_field'] = tdf + + return(reduced_data) + + def _gen_bright(self, nxentry, reduced_data): + """Generate bright field. + """ + # Get the bright field images + image_key = nxentry.instrument.detector.get('image_key', None) + if image_key and 'data' in nxentry.instrument.detector: + field_indices = [index for index, key in enumerate(image_key) if key == 1] + tbf_stack = nxentry.instrument.detector.data[field_indices,:,:] + # RV the default NXtomo form does not accomodate bright or dark field stacks + else: + bright_field_scans = nxentry.spec_scans.bright_field + bright_field = FlatField.construct_from_nxcollection(bright_field_scans) + prefix = str(nxentry.instrument.detector.local_name) + tbf_stack = bright_field.get_detector_data(prefix) + if isinstance(tbf_stack, list): + exit('TODO') + + # Take median if more than one image + """Median or mean: It may be best to try the median because of some image + artifacts that arise due to crinkles in the upstream kapton tape windows + causing some phase contrast images to appear on the detector. + One thing that also may be useful in a future implementation is to do a + brightfield adjustment on EACH frame of the tomo based on a ROI in the + corner of the frame where there is no sample but there is the direct X-ray + beam because there is frame to frame fluctuations from the incoming beam. + We don’t typically account for them but potentially could. 
+ """ + if tbf_stack.ndim == 2: + tbf = tbf_stack + elif tbf_stack.ndim == 3: + tbf = np.median(tbf_stack, axis=0) + del tbf_stack + else: + raise ValueError(f'Invalid tbf_stack shape ({tbf_stacks.shape})') + + # Subtract dark field + if 'data' in reduced_data and 'dark_field' in reduced_data.data: + tbf -= reduced_data.data.dark_field + else: + logger.warning('Dark field unavailable') + + # Set any non-positive values to one + # (avoid negative bright field values for spikes in dark field) + tbf[tbf < 1] = 1 + + # Plot bright field + if self.galaxy_flag: + quick_imshow(tbf, title='bright field', path='tomo_reduce_plots', + save_fig=self.save_figs, save_only=self.save_only) + elif not self.test_mode: + quick_imshow(tbf, title='bright field', path=self.output_folder, + save_fig=self.save_figs, save_only=self.save_only) + clear_imshow('bright field') +# quick_imshow(tbf, title='bright field', block=True) + + # Add bright field to reduced data NXprocess + if 'data' not in reduced_data: + reduced_data.data = NXdata() + reduced_data.data['bright_field'] = tbf + + return(reduced_data) + + def _set_detector_bounds(self, nxentry, reduced_data, img_x_bounds=None): + """Set vertical detector bounds for each image stack. + Right now the range is the same for each set in the image stack. + """ + if self.test_mode: + return(tuple(self.test_config['img_x_bounds'])) + + # Get the first tomography image and the reference heights + image_key = nxentry.instrument.detector.get('image_key', None) + if image_key and 'data' in nxentry.instrument.detector: + field_indices = [index for index, key in enumerate(image_key) if key == 0] + first_image = np.asarray(nxentry.instrument.detector.data[field_indices[0],:,:]) + theta = float(nxentry.sample.rotation_angle[field_indices[0]]) + z_translation_all = nxentry.sample.z_translation[field_indices] + z_translation_levels = sorted(list(set(z_translation_all))) + num_tomo_stacks = len(z_translation_levels) + else: + tomo_field_scans = nxentry.spec_scans.tomo_fields + tomo_fields = TomoField.construct_from_nxcollection(tomo_field_scans) + vertical_shifts = tomo_fields.get_vertical_shifts() + if not isinstance(vertical_shifts, list): + vertical_shifts = [vertical_shifts] + prefix = str(nxentry.instrument.detector.local_name) + t0 = time() + first_image = tomo_fields.get_detector_data(prefix, tomo_fields.scan_numbers[0], 0) + logger.debug(f'Getting first image took {time()-t0:.2f} seconds') + num_tomo_stacks = len(tomo_fields.scan_numbers) + theta = tomo_fields.theta_range['start'] + + # Select image bounds + title = f'tomography image at theta={round(theta, 2)+0}' + if (img_x_bounds is not None and not is_index_range(img_x_bounds, ge=0, + le=first_image.shape[0])): + raise ValueError(f'Invalid parameter img_x_bounds ({img_x_bounds})') + if nxentry.instrument.source.attrs['station'] in ('id1a3', 'id3a'): + pixel_size = nxentry.instrument.detector.x_pixel_size + # Try to get a fit from the bright field + tbf = np.asarray(reduced_data.data.bright_field) + tbf_shape = tbf.shape + x_sum = np.sum(tbf, 1) + x_sum_min = x_sum.min() + x_sum_max = x_sum.max() + fit = Fit.fit_data(x_sum, 'rectangle', x=np.array(range(len(x_sum))), form='atan', + guess=True) + parameters = fit.best_values + x_low_fit = parameters.get('center1', None) + x_upp_fit = parameters.get('center2', None) + sig_low = parameters.get('sigma1', None) + sig_upp = parameters.get('sigma2', None) + have_fit = fit.success and x_low_fit is not None and x_upp_fit is not None and \ + sig_low is not None and sig_upp is not 
+            have_fit = fit.success and x_low_fit is not None and x_upp_fit is not None and \
+                sig_low is not None and sig_upp is not None and \
+                0 <= x_low_fit < x_upp_fit <= x_sum.size and \
+                (sig_low+sig_upp)/(x_upp_fit-x_low_fit) < 0.1
+            if have_fit:
+                # Set a 5% margin on each side
+                margin = 0.05*(x_upp_fit-x_low_fit)
+                x_low_fit = max(0, x_low_fit-margin)
+                x_upp_fit = min(tbf_shape[0], x_upp_fit+margin)
+            if num_tomo_stacks == 1:
+                if have_fit:
+                    # Set the default range to enclose the full fitted window
+                    x_low = int(x_low_fit)
+                    x_upp = int(x_upp_fit)
+                else:
+                    # Center a default range of 1 mm (RV: can we get this from the slits?)
+                    num_x_min = int((1.0-0.5*pixel_size)/pixel_size)
+                    x_low = int(0.5*(tbf_shape[0]-num_x_min))
+                    x_upp = x_low+num_x_min
+            else:
+                # Get the default range from the reference heights
+                delta_z = vertical_shifts[1]-vertical_shifts[0]
+                for i in range(2, num_tomo_stacks):
+                    delta_z = min(delta_z, vertical_shifts[i]-vertical_shifts[i-1])
+                logger.debug(f'delta_z = {delta_z}')
+                num_x_min = int((delta_z-0.5*pixel_size)/pixel_size)
+                logger.debug(f'num_x_min = {num_x_min}')
+                if num_x_min > tbf_shape[0]:
+                    logger.warning('Image bounds and pixel size prevent seamless stacking')
+                if have_fit:
+                    # Center the default range relative to the fitted window
+                    x_low = int(0.5*(x_low_fit+x_upp_fit-num_x_min))
+                    x_upp = x_low+num_x_min
+                else:
+                    # Center the default range
+                    x_low = int(0.5*(tbf_shape[0]-num_x_min))
+                    x_upp = x_low+num_x_min
+            if self.galaxy_flag:
+                img_x_bounds = (x_low, x_upp)
+            else:
+                tmp = np.copy(tbf)
+                tmp_max = tmp.max()
+                tmp[x_low,:] = tmp_max
+                tmp[x_upp-1,:] = tmp_max
+                quick_imshow(tmp, title='bright field')
+                tmp = np.copy(first_image)
+                tmp_max = tmp.max()
+                tmp[x_low,:] = tmp_max
+                tmp[x_upp-1,:] = tmp_max
+                quick_imshow(tmp, title=title)
+                del tmp
+                quick_plot((range(x_sum.size), x_sum),
+                    ([x_low, x_low], [x_sum_min, x_sum_max], 'r-'),
+                    ([x_upp, x_upp], [x_sum_min, x_sum_max], 'r-'),
+                    title='sum over theta and y')
+                print(f'lower bound = {x_low} (inclusive)')
+                print(f'upper bound = {x_upp} (exclusive)')
+                accept = input_yesno('Accept these bounds (y/n)?', 'y')
+                clear_imshow('bright field')
+                clear_imshow(title)
+                clear_plot('sum over theta and y')
+                if accept:
+                    img_x_bounds = (x_low, x_upp)
+                else:
+                    while True:
+                        mask, img_x_bounds = draw_mask_1d(x_sum, title='select x data range',
+                            legend='sum over theta and y')
+                        if len(img_x_bounds) == 1:
+                            break
+                        else:
+                            print('Choose a single connected data range')
+                    img_x_bounds = tuple(img_x_bounds[0])
+            if (num_tomo_stacks > 1 and img_x_bounds[1]-img_x_bounds[0]+1 <
+                    int((delta_z-0.5*pixel_size)/pixel_size)):
+                logger.warning('Image bounds and pixel size prevent seamless stacking')
+        else:
+            if num_tomo_stacks > 1:
+                raise NotImplementedError('Selecting image bounds for multiple stacks on FMB')
+            # For FMB: use the first tomography image to select range
+            # RV: revisit if they do tomography with multiple stacks
+            x_sum = np.sum(first_image, 1)
+            x_sum_min = x_sum.min()
+            x_sum_max = x_sum.max()
+            if self.galaxy_flag:
+                if img_x_bounds is None:
+                    img_x_bounds = (0, first_image.shape[0])
+            else:
+                quick_imshow(first_image, title=title)
+                print('Select vertical data reduction range from first tomography image')
+                img_x_bounds = select_image_bounds(first_image, 0, title=title)
+                clear_imshow(title)
+                if img_x_bounds is None:
+                    raise ValueError('Unable to select image bounds')
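+
+        # Note on the bound convention: img_x_bounds is a half-open row range,
+        # so rows img_x_bounds[0] .. img_x_bounds[1]-1 are kept, matching the
+        # slicing tomo_stack[:,img_x_bounds[0]:img_x_bounds[1],:] in _gen_tomo.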
+
+        # Plot results
+        if self.galaxy_flag:
+            path = 'tomo_reduce_plots'
+        else:
+            path = self.output_folder
+        x_low = img_x_bounds[0]
+        x_upp = img_x_bounds[1]
+        tmp = np.copy(first_image)
+        tmp_max = tmp.max()
+        tmp[x_low,:] = tmp_max
+        tmp[x_upp-1,:] = tmp_max
+        quick_imshow(tmp, title=title, path=path, save_fig=self.save_figs,
+            save_only=self.save_only, block=self.block)
+        del tmp
+        quick_plot((range(x_sum.size), x_sum),
+            ([x_low, x_low], [x_sum_min, x_sum_max], 'r-'),
+            ([x_upp, x_upp], [x_sum_min, x_sum_max], 'r-'),
+            title='sum over theta and y', path=path, save_fig=self.save_figs,
+            save_only=self.save_only, block=self.block)
+
+        return(img_x_bounds)
+
+    def _set_zoom_or_skip(self):
+        """Set zoom and/or theta skip to reduce the memory requirement for the analysis.
+        """
+#        if input_yesno('\nDo you want to zoom in to reduce memory requirement (y/n)?', 'n'):
+#            zoom_perc = input_int('    Enter zoom percentage', ge=1, le=100)
+#        else:
+#            zoom_perc = None
+        zoom_perc = None
+#        if input_yesno('Do you want to skip thetas to reduce memory requirement (y/n)?', 'n'):
+#            num_theta_skip = input_int('    Enter the theta skip interval', ge=0, lt=num_theta)
+#        else:
+#            num_theta_skip = None
+        num_theta_skip = None
+        logger.debug(f'zoom_perc = {zoom_perc}')
+        logger.debug(f'num_theta_skip = {num_theta_skip}')
+
+        return(zoom_perc, num_theta_skip)
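+
+    # Illustrative sketch of what re-enabling the zoom option above would do
+    # to one image (scipy.ndimage is aliased as spi in this module):
+    #   img = np.ones((1000, 2000), dtype='float32')
+    #   spi.zoom(img, 0.01*50).shape   # -> (500, 1000), a 4x memory reduction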
+
+    def _gen_tomo(self, nxentry, reduced_data):
+        """Generate tomography fields.
+        """
+        # Get full bright field
+        tbf = np.asarray(reduced_data.data.bright_field)
+        tbf_shape = tbf.shape
+
+        # Get image bounds
+        img_x_bounds = tuple(reduced_data.get('img_x_bounds', (0, tbf_shape[0])))
+        img_y_bounds = tuple(reduced_data.get('img_y_bounds', (0, tbf_shape[1])))
+
+        # Get resized dark field
+#        if 'dark_field' in data:
+#            tdf = np.asarray(reduced_data.data.dark_field[
+#                img_x_bounds[0]:img_x_bounds[1],img_y_bounds[0]:img_y_bounds[1]])
+#        else:
+#            logger.warning('Dark field unavailable')
+#            tdf = None
+        tdf = None
+
+        # Resize bright field
+        if img_x_bounds != (0, tbf.shape[0]) or img_y_bounds != (0, tbf.shape[1]):
+            tbf = tbf[img_x_bounds[0]:img_x_bounds[1],img_y_bounds[0]:img_y_bounds[1]]
+
+        # Get the tomography images
+        image_key = nxentry.instrument.detector.get('image_key', None)
+        if image_key and 'data' in nxentry.instrument.detector:
+            field_indices_all = [index for index, key in enumerate(image_key) if key == 0]
+            z_translation_all = nxentry.sample.z_translation[field_indices_all]
+            z_translation_levels = sorted(list(set(z_translation_all)))
+            num_tomo_stacks = len(z_translation_levels)
+            horizontal_shifts = []
+            vertical_shifts = []
+            thetas = None
+            tomo_stacks = []
+            for i, z_translation in enumerate(z_translation_levels):
+                field_indices = [field_indices_all[index]
+                    for index, z in enumerate(z_translation_all) if z == z_translation]
+                horizontal_shift = list(set(nxentry.sample.x_translation[field_indices]))
+                assert(len(horizontal_shift) == 1)
+                horizontal_shifts += horizontal_shift
+                vertical_shift = list(set(nxentry.sample.z_translation[field_indices]))
+                assert(len(vertical_shift) == 1)
+                vertical_shifts += vertical_shift
+                sequence_numbers = nxentry.instrument.detector.sequence_number[field_indices]
+                if thetas is None:
+                    thetas = np.asarray(
+                        nxentry.sample.rotation_angle[field_indices])[sequence_numbers]
+                else:
+                    assert(all(thetas[i] == nxentry.sample.rotation_angle[field_indices[index]]
+                        for i, index in enumerate(sequence_numbers)))
+                assert(list(set(sequence_numbers)) == [i for i in range(len(sequence_numbers))])
+                if list(sequence_numbers) == [i for i in range(len(sequence_numbers))]:
+                    tomo_stack = np.asarray(nxentry.instrument.detector.data[field_indices])
+                else:
+                    raise ValueError('Unable to load the tomography images')
+                tomo_stacks.append(tomo_stack)
+        else:
+            tomo_field_scans = nxentry.spec_scans.tomo_fields
+            tomo_fields = TomoField.construct_from_nxcollection(tomo_field_scans)
+            horizontal_shifts = tomo_fields.get_horizontal_shifts()
+            vertical_shifts = tomo_fields.get_vertical_shifts()
+            prefix = str(nxentry.instrument.detector.local_name)
+            t0 = time()
+            tomo_stacks = tomo_fields.get_detector_data(prefix)
+            logger.debug(f'Getting all tomography images took {time()-t0:.2f} seconds')
+            thetas = np.linspace(tomo_fields.theta_range['start'],
+                tomo_fields.theta_range['end'], tomo_fields.theta_range['num'])
+            if not isinstance(tomo_stacks, list):
+                horizontal_shifts = [horizontal_shifts]
+                vertical_shifts = [vertical_shifts]
+                tomo_stacks = [tomo_stacks]
+
+        reduced_tomo_stacks = []
+        if self.galaxy_flag:
+            path = 'tomo_reduce_plots'
+        else:
+            path = self.output_folder
+        for i, tomo_stack in enumerate(tomo_stacks):
+            # Resize the tomography images
+            # Right now the range is the same for each set in the image stack.
+            if img_x_bounds != (0, tbf.shape[0]) or img_y_bounds != (0, tbf.shape[1]):
+                t0 = time()
+                tomo_stack = tomo_stack[:,img_x_bounds[0]:img_x_bounds[1],
+                    img_y_bounds[0]:img_y_bounds[1]].astype('float64')
+                logger.debug(f'Resizing tomography images took {time()-t0:.2f} seconds')
+
+            # Subtract dark field
+            if tdf is not None:
+                t0 = time()
+                with set_numexpr_threads(self.num_core):
+                    ne.evaluate('tomo_stack-tdf', out=tomo_stack)
+                logger.debug(f'Subtracting dark field took {time()-t0:.2f} seconds')
+
+            # Normalize
+            t0 = time()
+            with set_numexpr_threads(self.num_core):
+                ne.evaluate('tomo_stack/tbf', out=tomo_stack, truediv=True)
+            logger.debug(f'Normalizing took {time()-t0:.2f} seconds')
+
+            # Remove non-positive values and linearize data
+            t0 = time()
+            cutoff = 1.e-6
+            with set_numexpr_threads(self.num_core):
+                ne.evaluate('where(tomo_stack<cutoff, cutoff, tomo_stack)', out=tomo_stack)
+            with set_numexpr_threads(self.num_core):
+                ne.evaluate('-log(tomo_stack)', out=tomo_stack)
+            logger.debug('Removing non-positive values and linearizing data took '+
+                f'{time()-t0:.2f} seconds')
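+
+            # Note on the linearization above: after normalization the stack
+            # holds transmission I/I0 and -log(I/I0) is proportional to the
+            # line integral of the attenuation coefficient (Beer-Lambert),
+            # the quantity tomographic reconstruction inverts; e.g. a pixel
+            # with 90% transmission maps to -log(0.9) ~ 0.105.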
+
+            # Get rid of nans/infs that may be introduced by normalization
+            t0 = time()
+            tomo_stack = np.where(np.isfinite(tomo_stack), tomo_stack, 0.)
+            logger.debug(f'Removing nans/infs took {time()-t0:.2f} seconds')
+
+            # Downsize tomography stack to smaller size
+            # TODO use theta_skip as well
+            tomo_stack = tomo_stack.astype('float32')
+            if not self.test_mode:
+                if len(tomo_stacks) == 1:
+                    title = f'red fullres theta {round(thetas[0], 2)+0}'
+                else:
+                    title = f'red stack {i} fullres theta {round(thetas[0], 2)+0}'
+                quick_imshow(tomo_stack[0,:,:], title=title, path=path, save_fig=self.save_figs,
+                    save_only=self.save_only, block=self.block)
+#                if not self.block:
+#                    clear_imshow(title)
+            if False and zoom_perc != 100:
+                t0 = time()
+                logger.debug('Zooming in ...')
+                tomo_zoom_list = []
+                for j in range(tomo_stack.shape[0]):
+                    tomo_zoom = spi.zoom(tomo_stack[j,:,:], 0.01*zoom_perc)
+                    tomo_zoom_list.append(tomo_zoom)
+                tomo_stack = np.stack(tomo_zoom_list)
+                logger.debug(f'... done in {time()-t0:.2f} seconds')
+                logger.info(f'Zooming in took {time()-t0:.2f} seconds')
+                del tomo_zoom_list
+                if not self.test_mode:
+                    title = f'red stack {zoom_perc}p theta {round(thetas[0], 2)+0}'
+                    quick_imshow(tomo_stack[0,:,:], title=title, path=path,
+                        save_fig=self.save_figs, save_only=self.save_only, block=self.block)
+#                    if not self.block:
+#                        clear_imshow(title)
+
+            # Convert tomography stack from theta,row,column to row,theta,column
+            t0 = time()
+            tomo_stack = np.swapaxes(tomo_stack, 0, 1)
+            logger.debug(f'Converting coordinate order took {time()-t0:.2f} seconds')
+
+            # Save test data to file
+            if self.test_mode:
+                row_index = int(tomo_stack.shape[0]/2)
+                np.savetxt(f'{self.output_folder}/red_stack_{i+1}.txt',
+                    tomo_stack[row_index,:,:], fmt='%.6e')
+
+            # Combine resized stacks
+            reduced_tomo_stacks.append(tomo_stack)
+
+        # Add tomo field info to reduced data NXprocess
+        reduced_data['rotation_angle'] = thetas
+        reduced_data['x_translation'] = np.asarray(horizontal_shifts)
+        reduced_data['z_translation'] = np.asarray(vertical_shifts)
+        reduced_data.data['tomo_fields'] = np.asarray(reduced_tomo_stacks)
+
+        if tdf is not None:
+            del tdf
+        del tbf
+
+        return(reduced_data)
+
+    def _find_center_one_plane(self, sinogram, row, thetas, eff_pixel_size, cross_sectional_dim,
+            path=None, tol=0.1, num_core=1):
+        """Find center for a single tomography plane.
+        """
+        # Try automatic center finding routines for initial value
+        #   sinogram index order: theta,column
+        #   need column,theta for iradon, so take transpose
+        sinogram_T = sinogram.T
+        center = sinogram.shape[1]/2
+
+        # Try using Nghia Vo's method
+        t0 = time()
+        if num_core > num_core_tomopy_limit:
+            logger.debug(f'Running find_center_vo on {num_core_tomopy_limit} cores ...')
+            tomo_center = tomopy.find_center_vo(sinogram, ncore=num_core_tomopy_limit)
+        else:
+            logger.debug(f'Running find_center_vo on {num_core} cores ...')
+            tomo_center = tomopy.find_center_vo(sinogram, ncore=num_core)
+        logger.debug(f'... done in {time()-t0:.2f} seconds')
+        logger.info(f'Finding the center using Nghia Vo’s method took {time()-t0:.2f} seconds')
+        center_offset_vo = tomo_center-center
+        logger.info(f'Center at row {row} using Nghia Vo’s method = {center_offset_vo:.2f}')
+        t0 = time()
+        logger.debug(f'Running _reconstruct_one_plane on {self.num_core} cores ...')
+        recon_plane = self._reconstruct_one_plane(sinogram_T, tomo_center, thetas,
+            eff_pixel_size, cross_sectional_dim, False, num_core)
+        logger.debug(f'... done in {time()-t0:.2f} seconds')
+        logger.info(f'Reconstructing row {row} took {time()-t0:.2f} seconds')
+
+        title = f'edges row{row} center offset{center_offset_vo:.2f} Vo'
+        self._plot_edges_one_plane(recon_plane, title, path=path)
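+
+        # Illustrative conversion used above (hypothetical numbers): tomopy's
+        # find_center_vo returns the absolute rotation-axis column, reported
+        # here as an offset from the detector midline:
+        #   tomo_center = tomopy.find_center_vo(sinogram)   # e.g. 1034.25
+        #   center_offset_vo = 1034.25-2048/2               # = +10.25 columns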
+
+        # Try using phase correlation method
+#        if input_yesno('Try finding center using phase correlation (y/n)?', 'n'):
+#            t0 = time()
+#            logger.debug('Running find_center_pc ...')
+#            tomo_center = tomopy.find_center_pc(sinogram, sinogram, tol=0.1,
+#                rotc_guess=tomo_center)
+#            error = 1.
+#            while error > tol:
+#                prev = tomo_center
+#                tomo_center = tomopy.find_center_pc(sinogram, sinogram, tol=tol,
+#                    rotc_guess=tomo_center)
+#                error = np.abs(tomo_center-prev)
+#            logger.debug(f'... done in {time()-t0:.2f} seconds')
+#            logger.info('Finding the center using the phase correlation method took '+
+#                f'{time()-t0:.2f} seconds')
+#            center_offset = tomo_center-center
+#            print(f'Center at row {row} using phase correlation = {center_offset:.2f}')
+#            t0 = time()
+#            logger.debug(f'Running _reconstruct_one_plane on {self.num_core} cores ...')
+#            recon_plane = self._reconstruct_one_plane(sinogram_T, tomo_center, thetas,
+#                eff_pixel_size, cross_sectional_dim, False, num_core)
+#            logger.debug(f'... done in {time()-t0:.2f} seconds')
+#            logger.info(f'Reconstructing row {row} took {time()-t0:.2f} seconds')
+#
+#            title = f'edges row{row} center_offset{center_offset:.2f} PC'
+#            self._plot_edges_one_plane(recon_plane, title, path=path)
+
+        # Select center location
+#        if input_yesno('Accept a center location (y) or continue search (n)?', 'y'):
+        if True:
+#            center_offset = input_num('    Enter chosen center offset', ge=-center, le=center,
+#                default=center_offset_vo)
+            center_offset = center_offset_vo
+            del sinogram_T
+            del recon_plane
+            return float(center_offset)
+
+        # Perform center finding search
+        while True:
+            center_offset_low = input_int('\nEnter lower bound for center offset', ge=-center,
+                le=center)
+            center_offset_upp = input_int('Enter upper bound for center offset',
+                ge=center_offset_low, le=center)
+            if center_offset_upp == center_offset_low:
+                center_offset_step = 1
+            else:
+                center_offset_step = input_int('Enter step size for center offset search', ge=1,
+                    le=center_offset_upp-center_offset_low)
+            num_center_offset = 1+int((center_offset_upp-center_offset_low)/center_offset_step)
+            center_offsets = np.linspace(center_offset_low, center_offset_upp, num_center_offset)
+            for center_offset in center_offsets:
+                if center_offset == center_offset_vo:
+                    continue
+                t0 = time()
+                logger.debug(f'Running _reconstruct_one_plane on {num_core} cores ...')
+                recon_plane = self._reconstruct_one_plane(sinogram_T, center_offset+center,
+                    thetas, eff_pixel_size, cross_sectional_dim, False, num_core)
+                logger.debug(f'... done in {time()-t0:.2f} seconds')
+                logger.info(f'Reconstructing center_offset {center_offset} took '+
+                    f'{time()-t0:.2f} seconds')
+                title = f'edges row{row} center_offset{center_offset:.2f}'
+                self._plot_edges_one_plane(recon_plane, title, path=path)
+            if input_int('\nContinue (0) or end the search (1)', ge=0, le=1):
+                break
+
+        del sinogram_T
+        del recon_plane
+        center_offset = input_num('    Enter chosen center offset', ge=-center, le=center)
+        return float(center_offset)
+ """ + # tomo_plane_T index order: column,theta + assert(0 <= center < tomo_plane_T.shape[0]) + center_offset = center-tomo_plane_T.shape[0]/2 + two_offset = 2*int(np.round(center_offset)) + two_offset_abs = np.abs(two_offset) + max_rad = int(0.55*(cross_sectional_dim/eff_pixel_size)) # 10% slack to avoid edge effects + if max_rad > 0.5*tomo_plane_T.shape[0]: + max_rad = 0.5*tomo_plane_T.shape[0] + dist_from_edge = max(1, int(np.floor((tomo_plane_T.shape[0]-two_offset_abs)/2.)-max_rad)) + if two_offset >= 0: + logger.debug(f'sinogram range = [{two_offset+dist_from_edge}, {-dist_from_edge}]') + sinogram = tomo_plane_T[two_offset+dist_from_edge:-dist_from_edge,:] + else: + logger.debug(f'sinogram range = [{dist_from_edge}, {two_offset-dist_from_edge}]') + sinogram = tomo_plane_T[dist_from_edge:two_offset-dist_from_edge,:] + if not self.galaxy_flag and plot_sinogram: + quick_imshow(sinogram.T, f'sinogram center offset{center_offset:.2f}', aspect='auto', + path=self.output_folder, save_fig=self.save_figs, save_only=self.save_only, + block=self.block) + + # Inverting sinogram + t0 = time() + recon_sinogram = iradon(sinogram, theta=thetas, circle=True) + logger.debug(f'Inverting sinogram took {time()-t0:.2f} seconds') + del sinogram + + # Performing Gaussian filtering and removing ring artifacts + recon_parameters = None#self.config.get('recon_parameters') + if recon_parameters is None: + sigma = 1.0 + ring_width = 15 + else: + sigma = recon_parameters.get('gaussian_sigma', 1.0) + if not is_num(sigma, ge=0.0): + logger.warning(f'Invalid gaussian_sigma ({sigma}) in _reconstruct_one_plane, '+ + 'set to a default value of 1.0') + sigma = 1.0 + ring_width = recon_parameters.get('ring_width', 15) + if not is_int(ring_width, ge=0): + logger.warning(f'Invalid ring_width ({ring_width}) in _reconstruct_one_plane, '+ + 'set to a default value of 15') + ring_width = 15 + t0 = time() + recon_sinogram = spi.gaussian_filter(recon_sinogram, sigma, mode='nearest') + recon_clean = np.expand_dims(recon_sinogram, axis=0) + del recon_sinogram + recon_clean = tomopy.misc.corr.remove_ring(recon_clean, rwidth=ring_width, ncore=num_core) + logger.debug(f'Filtering and removing ring artifacts took {time()-t0:.2f} seconds') + + return recon_clean + + def _plot_edges_one_plane(self, recon_plane, title, path=None): + vis_parameters = None#self.config.get('vis_parameters') + if vis_parameters is None: + weight = 0.1 + else: + weight = vis_parameters.get('denoise_weight', 0.1) + if not is_num(weight, ge=0.0): + logger.warning(f'Invalid weight ({weight}) in _plot_edges_one_plane, '+ + 'set to a default value of 0.1') + weight = 0.1 + edges = denoise_tv_chambolle(recon_plane, weight=weight) + vmax = np.max(edges[0,:,:]) + vmin = -vmax + if path is None: + path = self.output_folder + quick_imshow(edges[0,:,:], f'{title} coolwarm', path=path, cmap='coolwarm', + save_fig=self.save_figs, save_only=self.save_only, block=self.block) + quick_imshow(edges[0,:,:], f'{title} gray', path=path, cmap='gray', vmin=vmin, vmax=vmax, + save_fig=self.save_figs, save_only=self.save_only, block=self.block) + del edges + + def _reconstruct_one_tomo_stack(self, tomo_stack, thetas, center_offsets=[], num_core=1, + algorithm='gridrec'): + """Reconstruct a single tomography stack. + """ + # tomo_stack order: row,theta,column + # input thetas must be in degrees + # centers_offset: tomography axis shift in pixels relative to column center + # RV should we remove stripes? 
+
+    def _reconstruct_one_tomo_stack(self, tomo_stack, thetas, center_offsets=[], num_core=1,
+            algorithm='gridrec'):
+        """Reconstruct a single tomography stack.
+        """
+        # tomo_stack order: row,theta,column
+        # input thetas must be in degrees
+        # center_offsets: tomography axis shift in pixels relative to column center
+        # RV should we remove stripes?
+        #   https://tomopy.readthedocs.io/en/latest/api/tomopy.prep.stripe.html
+        # RV should we remove rings?
+        #   https://tomopy.readthedocs.io/en/latest/api/tomopy.misc.corr.html
+        # RV: Add an option to do (extra) secondary iterations later or to do some sort of
+        #     convergence test?
+        if not len(center_offsets):
+            centers = np.zeros((tomo_stack.shape[0]))
+        elif len(center_offsets) == 2:
+            centers = np.linspace(center_offsets[0], center_offsets[1], tomo_stack.shape[0])
+        else:
+            if center_offsets.size != tomo_stack.shape[0]:
+                raise ValueError(
+                    'center_offsets dimension mismatch in _reconstruct_one_tomo_stack')
+            centers = center_offsets
+        centers += tomo_stack.shape[2]/2
+
+        # Get reconstruction parameters
+        recon_parameters = None  # self.config.get('recon_parameters')
+        if recon_parameters is None:
+            sigma = 2.0
+            secondary_iters = 0
+            ring_width = 15
+        else:
+            sigma = recon_parameters.get('stripe_fw_sigma', 2.0)
+            if not is_num(sigma, ge=0):
+                logger.warning(f'Invalid stripe_fw_sigma ({sigma}) in '+
+                    '_reconstruct_one_tomo_stack, set to a default value of 2.0')
+                sigma = 2.0
+            secondary_iters = recon_parameters.get('secondary_iters', 0)
+            if not is_int(secondary_iters, ge=0):
+                logger.warning(f'Invalid secondary_iters ({secondary_iters}) in '+
+                    '_reconstruct_one_tomo_stack, set to a default value of 0 (skip them)')
+                secondary_iters = 0
+            ring_width = recon_parameters.get('ring_width', 15)
+            if not is_int(ring_width, ge=0):
+                logger.warning(f'Invalid ring_width ({ring_width}) in '+
+                    '_reconstruct_one_tomo_stack, set to a default value of 15')
+                ring_width = 15
+
+        # Remove horizontal stripes
+        t0 = time()
+        if num_core > num_core_tomopy_limit:
+            logger.debug(f'Running remove_stripe_fw on {num_core_tomopy_limit} cores ...')
+            tomo_stack = tomopy.prep.stripe.remove_stripe_fw(tomo_stack, sigma=sigma,
+                ncore=num_core_tomopy_limit)
+        else:
+            logger.debug(f'Running remove_stripe_fw on {num_core} cores ...')
+            tomo_stack = tomopy.prep.stripe.remove_stripe_fw(tomo_stack, sigma=sigma,
+                ncore=num_core)
+        logger.debug(f'... tomopy.prep.stripe.remove_stripe_fw took {time()-t0:.2f} seconds')
+
+        # Perform initial image reconstruction
+        logger.debug('Performing initial image reconstruction')
+        t0 = time()
+        logger.debug(f'Running recon on {num_core} cores ...')
+        tomo_recon_stack = tomopy.recon(tomo_stack, np.radians(thetas), centers,
+            sinogram_order=True, algorithm=algorithm, ncore=num_core)
+        logger.debug(f'... done in {time()-t0:.2f} seconds')
+        logger.info(f'Performing initial image reconstruction took {time()-t0:.2f} seconds')
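+
+        # Illustrative minimal call matching the reconstruction above
+        # (hypothetical shapes; tomopy expects the angles in radians):
+        #   sino = np.random.rand(5, 180, 256).astype('float32')
+        #   rec = tomopy.recon(sino, np.radians(np.linspace(0., 180., 180)),
+        #       256/2, sinogram_order=True, algorithm='gridrec')
+        #   rec.shape   # -> (5, 256, 256) with the default grid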
+
+        # Run optional secondary iterations
+        if secondary_iters > 0:
+            logger.debug(f'Running {secondary_iters} secondary iterations')
+#            options = {'method':'SIRT_CUDA', 'proj_type':'cuda', 'num_iter':secondary_iters}
+#            RV: doesn't work for me:
+#            "Error: CUDA error 803: system has unsupported display driver/cuda driver
+#            combination."
+#            options = {'method':'SIRT', 'proj_type':'linear', 'MinConstraint': 0,
+#                'num_iter':secondary_iters}
+#            RV: SIRT did not finish while running overnight
+#            options = {'method':'SART', 'proj_type':'linear', 'num_iter':secondary_iters}
+            options = {'method':'SART', 'proj_type':'linear', 'MinConstraint': 0,
+                'num_iter':secondary_iters}
+            t0 = time()
+            logger.debug(f'Running recon on {num_core} cores ...')
+            tomo_recon_stack = tomopy.recon(tomo_stack, np.radians(thetas), centers,
+                init_recon=tomo_recon_stack, options=options, sinogram_order=True,
+                algorithm=tomopy.astra, ncore=num_core)
+            logger.debug(f'... done in {time()-t0:.2f} seconds')
+            logger.info(f'Performing secondary iterations took {time()-t0:.2f} seconds')
+
+        # Remove ring artifacts
+        t0 = time()
+        tomopy.misc.corr.remove_ring(tomo_recon_stack, rwidth=ring_width, out=tomo_recon_stack,
+            ncore=num_core)
+        logger.debug(f'Removing ring artifacts took {time()-t0:.2f} seconds')
+
+        return tomo_recon_stack
+
+    def _resize_reconstructed_data(self, data, z_only=False):
+        """Resize the reconstructed tomography data.
+        """
+        # Data order: row(z),x,y or stack,row(z),x,y
+        if isinstance(data, list):
+            for stack in data:
+                assert(stack.ndim == 3)
+            num_tomo_stacks = len(data)
+            tomo_recon_stacks = data
+        else:
+            assert(data.ndim == 3)
+            num_tomo_stacks = 1
+            tomo_recon_stacks = [data]
+
+        if z_only:
+            x_bounds = None
+            y_bounds = None
+        else:
+            # Select x bounds (in yz-plane)
+            tomosum = 0
+            [tomosum := tomosum+np.sum(tomo_recon_stacks[i], axis=(0,2))
+                for i in range(num_tomo_stacks)]
+            select_x_bounds = input_yesno('\nDo you want to change the image x-bounds (y/n)?',
+                'y')
+            if not select_x_bounds:
+                x_bounds = None
+            else:
+                accept = False
+                index_ranges = None
+                while not accept:
+                    mask, x_bounds = draw_mask_1d(tomosum, current_index_ranges=index_ranges,
+                        title='select x data range', legend='recon stack sum yz')
+                    while len(x_bounds) != 1:
+                        print('Please select exactly one continuous range')
+                        mask, x_bounds = draw_mask_1d(tomosum, title='select x data range',
+                            legend='recon stack sum yz')
+                    x_bounds = x_bounds[0]
+#                    quick_plot(tomosum, vlines=x_bounds, title='recon stack sum yz')
+#                    print(f'x_bounds = {x_bounds} (lower bound inclusive, upper bound '+
+#                        'exclusive)')
+#                    accept = input_yesno('Accept these bounds (y/n)?', 'y')
+                    accept = True
+            logger.debug(f'x_bounds = {x_bounds}')
+
+            # Select y bounds (in xz-plane)
+            tomosum = 0
+            [tomosum := tomosum+np.sum(tomo_recon_stacks[i], axis=(0,1))
+                for i in range(num_tomo_stacks)]
+            select_y_bounds = input_yesno('\nDo you want to change the image y-bounds (y/n)?',
+                'y')
+            if not select_y_bounds:
+                y_bounds = None
+            else:
+                accept = False
+                index_ranges = None
+                while not accept:
+                    mask, y_bounds = draw_mask_1d(tomosum, current_index_ranges=index_ranges,
+                        title='select y data range', legend='recon stack sum xz')
+                    while len(y_bounds) != 1:
+                        print('Please select exactly one continuous range')
+                        mask, y_bounds = draw_mask_1d(tomosum, title='select y data range',
+                            legend='recon stack sum xz')
+                    y_bounds = y_bounds[0]
+#                    quick_plot(tomosum, vlines=y_bounds, title='recon stack sum xz')
+#                    print(f'y_bounds = {y_bounds} (lower bound inclusive, upper bound '+
+#                        'exclusive)')
+#                    accept = input_yesno('Accept these bounds (y/n)?', 'y')
+                    accept = True
+            logger.debug(f'y_bounds = {y_bounds}')
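+
+            # Note: summing a (z,x,y) stack over two axes gives a 1D profile
+            # whose support marks the sample extent, e.g.
+            #   np.sum(stack, axis=(0,2))   # length-x profile, as used above
+            # The walrus-style list comprehensions merely accumulate this sum
+            # over all stacks.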
+
+        # Select z bounds (in xy-plane) (only valid for a single image set)
+        if num_tomo_stacks != 1:
+            z_bounds = None
+        else:
+            tomosum = 0
+            [tomosum := tomosum+np.sum(tomo_recon_stacks[i], axis=(1,2))
+                for i in range(num_tomo_stacks)]
+            select_z_bounds = input_yesno('Do you want to change the image z-bounds (y/n)?', 'n')
+            if not select_z_bounds:
+                z_bounds = None
+            else:
+                accept = False
+                index_ranges = None
+                while not accept:
+                    mask, z_bounds = draw_mask_1d(tomosum, current_index_ranges=index_ranges,
+                        title='select z data range', legend='recon stack sum xy')
+                    while len(z_bounds) != 1:
+                        print('Please select exactly one continuous range')
+                        mask, z_bounds = draw_mask_1d(tomosum, title='select z data range',
+                            legend='recon stack sum xy')
+                    z_bounds = z_bounds[0]
+#                    quick_plot(tomosum, vlines=z_bounds, title='recon stack sum xy')
+#                    print(f'z_bounds = {z_bounds} (lower bound inclusive, upper bound '+
+#                        'exclusive)')
+#                    accept = input_yesno('Accept these bounds (y/n)?', 'y')
+                    accept = True
+            logger.debug(f'z_bounds = {z_bounds}')
+
+        return(x_bounds, y_bounds, z_bounds)
+
+
+def run_tomo(input_file:str, output_file:str, modes:list[str], center_file=None, num_core=-1,
+        output_folder='.', save_figs='no', test_mode=False) -> None:
+
+    if test_mode:
+        logging_format = '%(asctime)s : %(levelname)s - %(module)s : %(funcName)s - %(message)s'
+        level = logging.getLevelName('INFO')
+        logging.basicConfig(filename=f'{output_folder}/tomo.log', filemode='w',
+            format=logging_format, level=level, force=True)
+    logger.info(f'input_file = {input_file}')
+    logger.info(f'center_file = {center_file}')
+    logger.info(f'output_file = {output_file}')
+    logger.debug(f'modes = {modes}')
+    logger.debug(f'num_core = {num_core}')
+    logger.info(f'output_folder = {output_folder}')
+    logger.info(f'save_figs = {save_figs}')
+    logger.info(f'test_mode = {test_mode}')
+
+    # Check for correction modes
+    if modes is None:
+        modes = ['all']
+    logger.debug(f'modes {type(modes)} = {modes}')
+
+    # Instantiate Tomo object
+    tomo = Tomo(num_core=num_core, output_folder=output_folder, save_figs=save_figs,
+        test_mode=test_mode)
+
+    # Read input file
+    data = tomo.read(input_file)
+
+    # Generate reduced tomography images
+    if 'reduce_data' in modes or 'all' in modes:
+        data = tomo.gen_reduced_data(data)
+
+    # Find rotation axis centers for the tomography stacks
+    center_data = None
+    if 'find_center' in modes or 'all' in modes:
+        center_data = tomo.find_centers(data)
+
+    # Reconstruct tomography stacks
+    if 'reconstruct_data' in modes or 'all' in modes:
+        if center_data is None:
+            # Read the center axis info from file
+            center_data = tomo.read(center_file)
+        data = tomo.reconstruct_data(data, center_data)
+        center_data = None
+
+    # Combine reconstructed tomography stacks
+    if 'combine_data' in modes or 'all' in modes:
+        data = tomo.combine_data(data)
+
+    # Write output file
+    if not test_mode:
+        if center_data is None:
+            data = tomo.write(data, output_file)
+        else:
+            data = tomo.write(center_data, output_file)
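+
+# Illustrative invocation of the driver above (hypothetical file names; the
+# mode strings are the ones run_tomo tests for):
+#   run_tomo('map.nxs', 'recon.nxs',
+#       ['reduce_data', 'find_center', 'reconstruct_data', 'combine_data'],
+#       num_core=4, output_folder='.', save_figs='no')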