# HG changeset patch
# User rv43
# Date 1679415762 0
# Node ID 98e23dff1de2041ab383d3a769b848948749136b
planemo upload for repository https://github.com/rolfverberg/galaxytools commit f8c4bdb31c20c468045ad5e6eb255a293244bc6c-dirty
diff -r 000000000000 -r 98e23dff1de2 fit.py
--- /dev/null Thu Jan 01 00:00:00 1970 +0000
+++ b/fit.py Tue Mar 21 16:22:42 2023 +0000
@@ -0,0 +1,2576 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+Created on Mon Dec 6 15:36:22 2021
+
+@author: rv43
+"""
+
+import logging
+
+from asteval import Interpreter, get_ast_names
+from copy import deepcopy
+from lmfit import Model, Parameters
+from lmfit.model import ModelResult
+from lmfit.models import ConstantModel, LinearModel, QuadraticModel, PolynomialModel,\
+ ExponentialModel, StepModel, RectangleModel, ExpressionModel, GaussianModel,\
+ LorentzianModel
+import numpy as np
+from os import cpu_count, getpid, listdir, mkdir, path
+from re import compile, sub
+from shutil import rmtree
+try:
+ from sympy import diff, simplify
+except ImportError:
+ pass
+try:
+ from joblib import Parallel, delayed
+ have_joblib = True
+except ImportError:
+ have_joblib = False
+try:
+ import xarray as xr
+ have_xarray = True
+except ImportError:
+ have_xarray = False
+
+# input_num is used by Fit.fit (interactive mode) and is assumed to live in the
+# same general utility module as the other helpers imported here
+try:
+    from .general import illegal_value, is_int, is_dict_series, is_index, index_nearest, \
+            almost_equal, input_num, quick_plot #, eval_expr
+except ImportError:
+    try:
+        from sys import path as syspath
+        syspath.append('/nfs/chess/user/rv43/msnctools/msnctools')
+        from general import illegal_value, is_int, is_dict_series, is_index, index_nearest, \
+                almost_equal, input_num, quick_plot #, eval_expr
+    except ImportError:
+        from general import illegal_value, is_int, is_dict_series, is_index, index_nearest, \
+                almost_equal, input_num, quick_plot #, eval_expr
+
+from sys import float_info
+float_min = float_info.min
+float_max = float_info.max
+
+# Expressions for sigma in terms of fwhm (sigma = expression below)
+fwhm_factor = {
+    'gaussian': 'fwhm/(2*sqrt(2*log(2)))',
+    'lorentzian': '0.5*fwhm',
+    'splitlorentzian': '0.5*fwhm', # sigma = sigma_r
+    'voight': '0.2776*fwhm', # sigma = gamma
+    'pseudovoight': '0.5*fwhm'} # fraction = 0.5
+
+# Expressions for amplitude in terms of height and fwhm (amplitude = expression below)
+height_factor = {
+    'gaussian': 'height*fwhm*0.5*sqrt(pi/log(2))',
+    'lorentzian': 'height*fwhm*0.5*pi',
+    'splitlorentzian': 'height*fwhm*0.5*pi', # sigma = sigma_r
+    'voight': '3.334*height*fwhm', # sigma = gamma
+    'pseudovoight': '1.268*height*fwhm'} # fraction = 0.5
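+
+# Example (illustrative only): converting height and fwhm estimates into lmfit's
+# (amplitude, sigma) parametrization for a Gaussian via the expressions above:
+#     height, fwhm = 10.0, 0.5
+#     sigma = fwhm/(2*np.sqrt(2*np.log(2)))
+#     amplitude = height*fwhm*0.5*np.sqrt(np.pi/np.log(2))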
+
+class Fit:
+ """Wrapper class for lmfit
+ """
+ def __init__(self, y, x=None, models=None, normalize=True, **kwargs):
+ if not isinstance(normalize, bool):
+ raise ValueError(f'Invalid parameter normalize ({normalize})')
+ self._mask = None
+ self._model = None
+ self._norm = None
+ self._normalized = False
+ self._parameters = Parameters()
+ self._parameter_bounds = None
+ self._parameter_norms = {}
+ self._linear_parameters = []
+ self._nonlinear_parameters = []
+ self._result = None
+ self._try_linear_fit = True
+ self._y = None
+ self._y_norm = None
+ self._y_range = None
+ if 'try_linear_fit' in kwargs:
+ try_linear_fit = kwargs.pop('try_linear_fit')
+ if not isinstance(try_linear_fit, bool):
+ illegal_value(try_linear_fit, 'try_linear_fit', 'Fit.fit', raise_error=True)
+ self._try_linear_fit = try_linear_fit
+ if y is not None:
+            if isinstance(y, (tuple, list, np.ndarray)):
+                self._x = np.asarray(x)
+                y = np.asarray(y)
+            elif have_xarray and isinstance(y, xr.DataArray):
+                if x is not None:
+                    logging.warning(f'Ignoring superfluous input x ({x}) in Fit.__init__')
+                if y.ndim != 1:
+                    illegal_value(y.ndim, 'DataArray dimensions', 'Fit:__init__', raise_error=True)
+                self._x = np.asarray(y[y.dims[0]])
+            else:
+                illegal_value(y, 'y', 'Fit:__init__', raise_error=True)
+            self._y = y
+ if self._x.ndim != 1:
+ raise ValueError(f'Invalid dimension for input x ({self._x.ndim})')
+ if self._x.size != self._y.size:
+ raise ValueError(f'Inconsistent x and y dimensions ({self._x.size} vs '+
+ f'{self._y.size})')
+ if 'mask' in kwargs:
+ self._mask = kwargs.pop('mask')
+ if self._mask is None:
+ y_min = float(self._y.min())
+ self._y_range = float(self._y.max())-y_min
+ if normalize and self._y_range > 0.0:
+ self._norm = (y_min, self._y_range)
+ else:
+ self._mask = np.asarray(self._mask).astype(bool)
+ if self._x.size != self._mask.size:
+ raise ValueError(f'Inconsistent x and mask dimensions ({self._x.size} vs '+
+ f'{self._mask.size})')
+ y_masked = np.asarray(self._y)[~self._mask]
+ y_min = float(y_masked.min())
+ self._y_range = float(y_masked.max())-y_min
+                if normalize and self._y_range > 0.0:
+                    self._norm = (y_min, self._y_range)
+ if models is not None:
+ if callable(models) or isinstance(models, str):
+ kwargs = self.add_model(models, **kwargs)
+ elif isinstance(models, (tuple, list)):
+ for model in models:
+ kwargs = self.add_model(model, **kwargs)
+ self.fit(**kwargs)
+
+ @classmethod
+ def fit_data(cls, y, models, x=None, normalize=True, **kwargs):
+ return(cls(y, x=x, models=models, normalize=normalize, **kwargs))
+
+ @property
+ def best_errors(self):
+ if self._result is None:
+ return(None)
+ return({name:self._result.params[name].stderr for name in sorted(self._result.params)
+ if name != 'tmp_normalization_offset_c'})
+
+ @property
+ def best_fit(self):
+ if self._result is None:
+ return(None)
+ return(self._result.best_fit)
+
+ @property
+ def best_parameters(self):
+ if self._result is None:
+ return(None)
+ parameters = {}
+ for name in sorted(self._result.params):
+ if name != 'tmp_normalization_offset_c':
+ par = self._result.params[name]
+ parameters[name] = {'value': par.value, 'error': par.stderr,
+ 'init_value': par.init_value, 'min': par.min, 'max': par.max,
+ 'vary': par.vary, 'expr': par.expr}
+ return(parameters)
+
+ @property
+ def best_results(self):
+ """Convert the input data array to a data set and add the fit results.
+ """
+ if self._result is None:
+ return(None)
+ if isinstance(self._y, xr.DataArray):
+ best_results = self._y.to_dataset()
+ dims = self._y.dims
+ fit_name = f'{self._y.name}_fit'
+ else:
+ coords = {'x': (['x'], self._x)}
+            dims = ('x',)
+ best_results = xr.Dataset(coords=coords)
+ best_results['y'] = (dims, self._y)
+ fit_name = 'y_fit'
+ best_results[fit_name] = (dims, self.best_fit)
+ if self._mask is not None:
+ best_results['mask'] = self._mask
+        best_results.coords['par_names'] = ('par_names', [name for name in self.best_values.keys()])
+ best_results['best_values'] = (['par_names'], [v for v in self.best_values.values()])
+ best_results['best_errors'] = (['par_names'], [v for v in self.best_errors.values()])
+ best_results.attrs['components'] = self.components
+ return(best_results)
+
+ @property
+ def best_values(self):
+ if self._result is None:
+ return(None)
+ return({name:self._result.params[name].value for name in sorted(self._result.params)
+ if name != 'tmp_normalization_offset_c'})
+
+ @property
+ def chisqr(self):
+ if self._result is None:
+ return(None)
+ return(self._result.chisqr)
+
+ @property
+ def components(self):
+ components = {}
+ if self._result is None:
+ logging.warning('Unable to collect components in Fit.components')
+ return(components)
+ for component in self._result.components:
+ if 'tmp_normalization_offset_c' in component.param_names:
+ continue
+ parameters = {}
+ for name in component.param_names:
+ par = self._parameters[name]
+ parameters[name] = {'free': par.vary, 'value': self._result.params[name].value}
+ if par.expr is not None:
+ parameters[name]['expr'] = par.expr
+ expr = None
+ if isinstance(component, ExpressionModel):
+ name = component._name
+ if name[-1] == '_':
+ name = name[:-1]
+ expr = component.expr
+ else:
+ prefix = component.prefix
+ if len(prefix):
+ if prefix[-1] == '_':
+ prefix = prefix[:-1]
+ name = f'{prefix} ({component._name})'
+ else:
+ name = f'{component._name}'
+ if expr is None:
+ components[name] = {'parameters': parameters}
+ else:
+ components[name] = {'expr': expr, 'parameters': parameters}
+ return(components)
+
+ @property
+ def covar(self):
+ if self._result is None:
+ return(None)
+ return(self._result.covar)
+
+ @property
+ def init_parameters(self):
+ if self._result is None or self._result.init_params is None:
+ return(None)
+ parameters = {}
+ for name in sorted(self._result.init_params):
+ if name != 'tmp_normalization_offset_c':
+ par = self._result.init_params[name]
+ parameters[name] = {'value': par.value, 'min': par.min, 'max': par.max,
+ 'vary': par.vary, 'expr': par.expr}
+ return(parameters)
+
+ @property
+ def init_values(self):
+ if self._result is None or self._result.init_params is None:
+ return(None)
+ return({name:self._result.init_params[name].value for name in
+ sorted(self._result.init_params) if name != 'tmp_normalization_offset_c'})
+
+ @property
+ def normalization_offset(self):
+ if self._result is None:
+ return(None)
+ if self._norm is None:
+ return(0.0)
+ else:
+            if self._result.init_params is not None:
+                normalization_offset = self._result.init_params['tmp_normalization_offset_c'].value
+            else:
+                normalization_offset = self._result.params['tmp_normalization_offset_c'].value
+            return(normalization_offset)
+
+ @property
+ def num_func_eval(self):
+ if self._result is None:
+ return(None)
+ return(self._result.nfev)
+
+ @property
+ def parameters(self):
+ return({name:{'min': par.min, 'max': par.max, 'vary': par.vary, 'expr': par.expr}
+ for name, par in self._parameters.items() if name != 'tmp_normalization_offset_c'})
+
+ @property
+ def redchi(self):
+ if self._result is None:
+ return(None)
+ return(self._result.redchi)
+
+ @property
+ def residual(self):
+ if self._result is None:
+ return(None)
+ return(self._result.residual)
+
+ @property
+ def success(self):
+ if self._result is None:
+ return(None)
+ if not self._result.success:
+# print(f'ier = {self._result.ier}')
+# print(f'lmdif_message = {self._result.lmdif_message}')
+# print(f'message = {self._result.message}')
+# print(f'nfev = {self._result.nfev}')
+# print(f'redchi = {self._result.redchi}')
+# print(f'success = {self._result.success}')
+            logging.warning(f'ier = {self._result.ier}: {self._result.message}')
+            if self._result.ier != 0 and self._result.ier != 5:
+                # Any other ier is treated as a soft failure and reported as success
+                return(True)
+# self.print_fit_report()
+# self.plot()
+ return(self._result.success)
+
+ @property
+ def var_names(self):
+ """Intended to be used with covar
+ """
+ if self._result is None:
+ return(None)
+ return(getattr(self._result, 'var_names', None))
+
+ @property
+ def x(self):
+ return(self._x)
+
+ @property
+ def y(self):
+ return(self._y)
+
+ def print_fit_report(self, result=None, show_correl=False):
+ if result is None:
+ result = self._result
+ if result is not None:
+ print(result.fit_report(show_correl=show_correl))
+
+ def add_parameter(self, **parameter):
+ if not isinstance(parameter, dict):
+ raise ValueError(f'Invalid parameter ({parameter})')
+ if parameter.get('expr') is not None:
+ raise KeyError(f'Illegal "expr" key in parameter {parameter}')
+ name = parameter['name']
+ if not isinstance(name, str):
+ raise ValueError(f'Illegal "name" value ({name}) in parameter {parameter}')
+ if parameter.get('norm') is None:
+ self._parameter_norms[name] = False
+ else:
+ norm = parameter.pop('norm')
+ if self._norm is None:
+ logging.warning(f'Ignoring norm in parameter {name} in '+
+ f'Fit.add_parameter (normalization is turned off)')
+ self._parameter_norms[name] = False
+ else:
+ if not isinstance(norm, bool):
+ raise ValueError(f'Illegal "norm" value ({norm}) in parameter {parameter}')
+ self._parameter_norms[name] = norm
+ vary = parameter.get('vary')
+ if vary is not None:
+ if not isinstance(vary, bool):
+ raise ValueError(f'Illegal "vary" value ({vary}) in parameter {parameter}')
+ if not vary:
+ if 'min' in parameter:
+ logging.warning(f'Ignoring min in parameter {name} in '+
+ f'Fit.add_parameter (vary = {vary})')
+ parameter.pop('min')
+ if 'max' in parameter:
+ logging.warning(f'Ignoring max in parameter {name} in '+
+ f'Fit.add_parameter (vary = {vary})')
+ parameter.pop('max')
+ if self._norm is not None and name not in self._parameter_norms:
+            raise ValueError(f'Missing parameter normalization type for parameter {name}')
+ self._parameters.add(**parameter)
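+
+    # Example (hypothetical parameter): add a free, normalized scale factor:
+    #     fit.add_parameter(name='scale_factor', value=1.0, min=0.0, norm=True)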
+
+ def add_model(self, model, prefix=None, parameters=None, parameter_norms=None, **kwargs):
+ # Create the new model
+# print(f'at start add_model:\nself._parameters:\n{self._parameters}')
+# print(f'at start add_model: kwargs = {kwargs}')
+# print(f'parameters = {parameters}')
+# print(f'parameter_norms = {parameter_norms}')
+# if len(self._parameters.keys()):
+# print('\nAt start adding model:')
+# self._parameters.pretty_print()
+# print(f'parameter_norms:\n{self._parameter_norms}')
+ if prefix is not None and not isinstance(prefix, str):
+            logging.warning(f'Ignoring illegal prefix: {prefix} {type(prefix)}')
+ prefix = None
+ if prefix is None:
+ pprefix = ''
+ else:
+ pprefix = prefix
+ if parameters is not None:
+ if isinstance(parameters, dict):
+ parameters = (parameters, )
+ elif not is_dict_series(parameters):
+ illegal_value(parameters, 'parameters', 'Fit.add_model', raise_error=True)
+ parameters = deepcopy(parameters)
+ if parameter_norms is not None:
+ if isinstance(parameter_norms, dict):
+ parameter_norms = (parameter_norms, )
+ if not is_dict_series(parameter_norms):
+ illegal_value(parameter_norms, 'parameter_norms', 'Fit.add_model', raise_error=True)
+ new_parameter_norms = {}
+ if callable(model):
+ # Linear fit not yet implemented for callable models
+ self._try_linear_fit = False
+ if parameter_norms is None:
+ if parameters is None:
+ raise ValueError('Either "parameters" or "parameter_norms" is required in '+
+ f'{model}')
+ for par in parameters:
+ name = par['name']
+ if not isinstance(name, str):
+ raise ValueError(f'Illegal "name" value ({name}) in input parameters')
+ if par.get('norm') is not None:
+ norm = par.pop('norm')
+ if not isinstance(norm, bool):
+ raise ValueError(f'Illegal "norm" value ({norm}) in input parameters')
+ new_parameter_norms[f'{pprefix}{name}'] = norm
+ else:
+ for par in parameter_norms:
+ name = par['name']
+ if not isinstance(name, str):
+ raise ValueError(f'Illegal "name" value ({name}) in input parameters')
+ norm = par.get('norm')
+ if norm is None or not isinstance(norm, bool):
+ raise ValueError(f'Illegal "norm" value ({norm}) in input parameters')
+ new_parameter_norms[f'{pprefix}{name}'] = norm
+ if parameters is not None:
+                for par in parameters:
+                    name = par['name']
+                    if not isinstance(name, str):
+                        raise ValueError(f'Illegal "name" value ({name}) in input parameters')
+                    if par.get('expr') is not None:
+                        raise KeyError(f'Illegal "expr" key ({par.get("expr")}) in parameter '+
+                                f'{name} for a callable model {model}')
+# RV FIX callable model will need partial deriv functions for any linear pars to get the linearized matrix, so for now skip linear solution option
+ newmodel = Model(model, prefix=prefix)
+ elif isinstance(model, str):
+ if model == 'constant': # Par: c
+ newmodel = ConstantModel(prefix=prefix)
+ new_parameter_norms[f'{pprefix}c'] = True
+ self._linear_parameters.append(f'{pprefix}c')
+ elif model == 'linear': # Par: slope, intercept
+ newmodel = LinearModel(prefix=prefix)
+ new_parameter_norms[f'{pprefix}slope'] = True
+ new_parameter_norms[f'{pprefix}intercept'] = True
+ self._linear_parameters.append(f'{pprefix}slope')
+ self._linear_parameters.append(f'{pprefix}intercept')
+ elif model == 'quadratic': # Par: a, b, c
+ newmodel = QuadraticModel(prefix=prefix)
+ new_parameter_norms[f'{pprefix}a'] = True
+ new_parameter_norms[f'{pprefix}b'] = True
+ new_parameter_norms[f'{pprefix}c'] = True
+ self._linear_parameters.append(f'{pprefix}a')
+ self._linear_parameters.append(f'{pprefix}b')
+ self._linear_parameters.append(f'{pprefix}c')
+ elif model == 'gaussian': # Par: amplitude, center, sigma (fwhm, height)
+ newmodel = GaussianModel(prefix=prefix)
+ new_parameter_norms[f'{pprefix}amplitude'] = True
+ new_parameter_norms[f'{pprefix}center'] = False
+ new_parameter_norms[f'{pprefix}sigma'] = False
+ self._linear_parameters.append(f'{pprefix}amplitude')
+ self._nonlinear_parameters.append(f'{pprefix}center')
+ self._nonlinear_parameters.append(f'{pprefix}sigma')
+ # parameter norms for height and fwhm are needed to get correct errors
+ new_parameter_norms[f'{pprefix}height'] = True
+ new_parameter_norms[f'{pprefix}fwhm'] = False
+ elif model == 'lorentzian': # Par: amplitude, center, sigma (fwhm, height)
+ newmodel = LorentzianModel(prefix=prefix)
+ new_parameter_norms[f'{pprefix}amplitude'] = True
+ new_parameter_norms[f'{pprefix}center'] = False
+ new_parameter_norms[f'{pprefix}sigma'] = False
+ self._linear_parameters.append(f'{pprefix}amplitude')
+ self._nonlinear_parameters.append(f'{pprefix}center')
+ self._nonlinear_parameters.append(f'{pprefix}sigma')
+ # parameter norms for height and fwhm are needed to get correct errors
+ new_parameter_norms[f'{pprefix}height'] = True
+ new_parameter_norms[f'{pprefix}fwhm'] = False
+ elif model == 'exponential': # Par: amplitude, decay
+ newmodel = ExponentialModel(prefix=prefix)
+ new_parameter_norms[f'{pprefix}amplitude'] = True
+ new_parameter_norms[f'{pprefix}decay'] = False
+ self._linear_parameters.append(f'{pprefix}amplitude')
+ self._nonlinear_parameters.append(f'{pprefix}decay')
+ elif model == 'step': # Par: amplitude, center, sigma
+ form = kwargs.get('form')
+ if form is not None:
+ kwargs.pop('form')
+ if form is None or form not in ('linear', 'atan', 'arctan', 'erf', 'logistic'):
+                raise ValueError(f'Invalid parameter form for built-in step model ({form})')
+ newmodel = StepModel(prefix=prefix, form=form)
+ new_parameter_norms[f'{pprefix}amplitude'] = True
+ new_parameter_norms[f'{pprefix}center'] = False
+ new_parameter_norms[f'{pprefix}sigma'] = False
+ self._linear_parameters.append(f'{pprefix}amplitude')
+ self._nonlinear_parameters.append(f'{pprefix}center')
+ self._nonlinear_parameters.append(f'{pprefix}sigma')
+ elif model == 'rectangle': # Par: amplitude, center1, center2, sigma1, sigma2
+ form = kwargs.get('form')
+ if form is not None:
+ kwargs.pop('form')
+ if form is None or form not in ('linear', 'atan', 'arctan', 'erf', 'logistic'):
+                raise ValueError('Invalid parameter form for built-in rectangle model '+
+ f'({form})')
+ newmodel = RectangleModel(prefix=prefix, form=form)
+ new_parameter_norms[f'{pprefix}amplitude'] = True
+ new_parameter_norms[f'{pprefix}center1'] = False
+ new_parameter_norms[f'{pprefix}center2'] = False
+ new_parameter_norms[f'{pprefix}sigma1'] = False
+ new_parameter_norms[f'{pprefix}sigma2'] = False
+ self._linear_parameters.append(f'{pprefix}amplitude')
+ self._nonlinear_parameters.append(f'{pprefix}center1')
+ self._nonlinear_parameters.append(f'{pprefix}center2')
+ self._nonlinear_parameters.append(f'{pprefix}sigma1')
+ self._nonlinear_parameters.append(f'{pprefix}sigma2')
+ elif model == 'expression': # Par: by expression
+ expr = kwargs['expr']
+ if not isinstance(expr, str):
+ raise ValueError(f'Illegal "expr" value ({expr}) in {model}')
+ kwargs.pop('expr')
+ if parameter_norms is not None:
+                logging.warning('Ignoring parameter_norms (normalization determined from '+
+                        'linearity)')
+ if parameters is not None:
+ for par in parameters:
+ if par.get('expr') is not None:
+ raise KeyError(f'Illegal "expr" key ({par.get("expr")}) in parameter '+
+ f'({par}) for an expression model')
+ if par.get('norm') is not None:
+ logging.warning(f'Ignoring "norm" key in parameter ({par}) '+
+ '(normalization determined from linearity)}')
+ par.pop('norm')
+ name = par['name']
+ if not isinstance(name, str):
+ raise ValueError(f'Illegal "name" value ({name}) in input parameters')
+ ast = Interpreter()
+ expr_parameters = [name for name in get_ast_names(ast.parse(expr))
+ if name != 'x' and name not in self._parameters
+ and name not in ast.symtable]
+# print(f'\nexpr_parameters: {expr_parameters}')
+# print(f'expr = {expr}')
+ if prefix is None:
+ newmodel = ExpressionModel(expr=expr)
+ else:
+ for name in expr_parameters:
+ expr = sub(rf'\b{name}\b', f'{prefix}{name}', expr)
+ expr_parameters = [f'{prefix}{name}' for name in expr_parameters]
+# print(f'\nexpr_parameters: {expr_parameters}')
+# print(f'expr = {expr}')
+                newmodel = ExpressionModel(expr=expr, name=prefix)
+# print(f'\nnewmodel = {newmodel.__dict__}')
+# print(f'params_names = {newmodel._param_names}')
+# print(f'params_names = {newmodel.param_names}')
+ # Remove already existing names
+ for name in newmodel.param_names.copy():
+ if name not in expr_parameters:
+ newmodel._func_allargs.remove(name)
+ newmodel._param_names.remove(name)
+# print(f'params_names = {newmodel._param_names}')
+# print(f'params_names = {newmodel.param_names}')
+ else:
+                raise ValueError(f'Unknown built-in fit model ({model})')
+ else:
+ illegal_value(model, 'model', 'Fit.add_model', raise_error=True)
+
+ # Add the new model to the current one
+# print('\nBefore adding model:')
+# print(f'\nnewmodel = {newmodel.__dict__}')
+# if len(self._parameters):
+# self._parameters.pretty_print()
+ if self._model is None:
+ self._model = newmodel
+ else:
+ self._model += newmodel
+ new_parameters = newmodel.make_params()
+ self._parameters += new_parameters
+# print('\nAfter adding model:')
+# print(f'\nnewmodel = {newmodel.__dict__}')
+# print(f'\nnew_parameters = {new_parameters}')
+# self._parameters.pretty_print()
+
+        # Check linearity of expression model parameters
+ if isinstance(newmodel, ExpressionModel):
+ for name in newmodel.param_names:
+ if not diff(newmodel.expr, name, name):
+ if name not in self._linear_parameters:
+ self._linear_parameters.append(name)
+ new_parameter_norms[name] = True
+# print(f'\nADDING {name} TO LINEAR')
+ else:
+ if name not in self._nonlinear_parameters:
+ self._nonlinear_parameters.append(name)
+ new_parameter_norms[name] = False
+# print(f'\nADDING {name} TO NONLINEAR')
+# print(f'new_parameter_norms:\n{new_parameter_norms}')
+
+ # Scale the default initial model parameters
+ if self._norm is not None:
+ for name, norm in new_parameter_norms.copy().items():
+ par = self._parameters.get(name)
+ if par is None:
+ new_parameter_norms.pop(name)
+ continue
+ if par.expr is None and norm:
+ value = par.value*self._norm[1]
+ _min = par.min
+ _max = par.max
+ if not np.isinf(_min) and abs(_min) != float_min:
+ _min *= self._norm[1]
+ if not np.isinf(_max) and abs(_max) != float_min:
+ _max *= self._norm[1]
+ par.set(value=value, min=_min, max=_max)
+# print('\nAfter norm defaults:')
+# self._parameters.pretty_print()
+# print(f'parameters:\n{parameters}')
+# print(f'all_parameters:\n{list(self.parameters)}')
+# print(f'new_parameter_norms:\n{new_parameter_norms}')
+# print(f'parameter_norms:\n{self._parameter_norms}')
+
+ # Initialize the model parameters from parameters
+ if prefix is None:
+ prefix = ""
+ if parameters is not None:
+ for parameter in parameters:
+ name = parameter['name']
+ if not isinstance(name, str):
+ raise ValueError(f'Illegal "name" value ({name}) in input parameters')
+ if name not in new_parameters:
+ name = prefix+name
+ parameter['name'] = name
+ if name not in new_parameters:
+ logging.warning(f'Ignoring superfluous parameter info for {name}')
+ continue
+ if name in self._parameters:
+ parameter.pop('name')
+ if 'norm' in parameter:
+ if not isinstance(parameter['norm'], bool):
+ illegal_value(parameter['norm'], 'norm', 'Fit.add_model',
+ raise_error=True)
+ new_parameter_norms[name] = parameter['norm']
+ parameter.pop('norm')
+ if parameter.get('expr') is not None:
+ if 'value' in parameter:
+ logging.warning(f'Ignoring value in parameter {name} '+
+ f'(set by expression: {parameter["expr"]})')
+ parameter.pop('value')
+ if 'vary' in parameter:
+ logging.warning(f'Ignoring vary in parameter {name} '+
+ f'(set by expression: {parameter["expr"]})')
+ parameter.pop('vary')
+ if 'min' in parameter:
+ logging.warning(f'Ignoring min in parameter {name} '+
+ f'(set by expression: {parameter["expr"]})')
+ parameter.pop('min')
+ if 'max' in parameter:
+ logging.warning(f'Ignoring max in parameter {name} '+
+ f'(set by expression: {parameter["expr"]})')
+ parameter.pop('max')
+ if 'vary' in parameter:
+ if not isinstance(parameter['vary'], bool):
+ illegal_value(parameter['vary'], 'vary', 'Fit.add_model',
+ raise_error=True)
+ if not parameter['vary']:
+ if 'min' in parameter:
+ logging.warning(f'Ignoring min in parameter {name} in '+
+ f'Fit.add_model (vary = {parameter["vary"]})')
+ parameter.pop('min')
+ if 'max' in parameter:
+ logging.warning(f'Ignoring max in parameter {name} in '+
+ f'Fit.add_model (vary = {parameter["vary"]})')
+ parameter.pop('max')
+ self._parameters[name].set(**parameter)
+ parameter['name'] = name
+ else:
+                illegal_value(parameter, 'parameter name', 'Fit.add_model', raise_error=True)
+ self._parameter_norms = {**self._parameter_norms, **new_parameter_norms}
+# print('\nAfter parameter init:')
+# self._parameters.pretty_print()
+# print(f'parameters:\n{parameters}')
+# print(f'new_parameter_norms:\n{new_parameter_norms}')
+# print(f'parameter_norms:\n{self._parameter_norms}')
+# print(f'kwargs:\n{kwargs}')
+
+ # Initialize the model parameters from kwargs
+ for name, value in {**kwargs}.items():
+ full_name = f'{pprefix}{name}'
+ if full_name in new_parameter_norms and isinstance(value, (int, float)):
+ kwargs.pop(name)
+ if self._parameters[full_name].expr is None:
+ self._parameters[full_name].set(value=value)
+ else:
+ logging.warning(f'Ignoring parameter {name} in Fit.fit (set by expression: '+
+ f'{self._parameters[full_name].expr})')
+# print('\nAfter kwargs init:')
+# self._parameters.pretty_print()
+# print(f'parameter_norms:\n{self._parameter_norms}')
+# print(f'kwargs:\n{kwargs}')
+
+ # Check parameter norms (also need it for expressions to renormalize the errors)
+ if self._norm is not None and (callable(model) or model == 'expression'):
+ missing_norm = False
+ for name in new_parameters.valuesdict():
+ if name not in self._parameter_norms:
+                    logging.debug(f'new_parameters:\n{new_parameters.valuesdict()}')
+                    logging.debug(f'self._parameter_norms:\n{self._parameter_norms}')
+ logging.error(f'Missing parameter normalization type for {name} in {model}')
+ missing_norm = True
+ if missing_norm:
+                raise ValueError(f'Missing parameter normalization type in {model}')
+
+# print(f'at end add_model:\nself._parameters:\n{list(self.parameters)}')
+# print(f'at end add_model: kwargs = {kwargs}')
+# print(f'\nat end add_model: newmodel:\n{newmodel.__dict__}\n')
+ return(kwargs)
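+
+    # Example (hypothetical expression): add a prefixed expression model; any
+    # keyword arguments not consumed are returned to the caller:
+    #     kwargs = fit.add_model('expression', prefix='bkg_', expr='slope*x + offset')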
+
+ def fit(self, interactive=False, guess=False, **kwargs):
+ # Check inputs
+ if self._model is None:
+ logging.error('Undefined fit model')
+ return
+ if not isinstance(interactive, bool):
+ illegal_value(interactive, 'interactive', 'Fit.fit', raise_error=True)
+ if not isinstance(guess, bool):
+ illegal_value(guess, 'guess', 'Fit.fit', raise_error=True)
+ if 'try_linear_fit' in kwargs:
+ try_linear_fit = kwargs.pop('try_linear_fit')
+ if not isinstance(try_linear_fit, bool):
+ illegal_value(try_linear_fit, 'try_linear_fit', 'Fit.fit', raise_error=True)
+ if not self._try_linear_fit:
+                logging.warning('Ignoring superfluous keyword argument "try_linear_fit" (not '+
+ 'yet supported for callable models)')
+ else:
+ self._try_linear_fit = try_linear_fit
+# if self._result is None:
+# if 'parameters' in kwargs:
+# raise ValueError('Invalid parameter parameters ({kwargs["parameters"]})')
+# else:
+ if self._result is not None:
+ if guess:
+ logging.warning('Ignoring input parameter guess in Fit.fit during refitting')
+ guess = False
+
+ # Check for circular expressions
+ # FIX TODO
+# for name1, par1 in self._parameters.items():
+# if par1.expr is not None:
+
+ # Apply mask if supplied:
+ if 'mask' in kwargs:
+ self._mask = kwargs.pop('mask')
+ if self._mask is not None:
+ self._mask = np.asarray(self._mask).astype(bool)
+ if self._x.size != self._mask.size:
+ raise ValueError(f'Inconsistent x and mask dimensions ({self._x.size} vs '+
+ f'{self._mask.size})')
+
+        # Estimate initial parameters with built-in lmfit guess method (only for a single model)
+# print(f'\nat start fit: kwargs = {kwargs}')
+#RV print('\nAt start of fit:')
+#RV self._parameters.pretty_print()
+# print(f'parameter_norms:\n{self._parameter_norms}')
+ if guess:
+ if self._mask is None:
+ self._parameters = self._model.guess(self._y, x=self._x)
+ else:
+ self._parameters = self._model.guess(np.asarray(self._y)[~self._mask],
+ x=self._x[~self._mask])
+# print('\nAfter guess:')
+# self._parameters.pretty_print()
+
+ # Add constant offset for a normalized model
+ if self._result is None and self._norm is not None and self._norm[0]:
+ self.add_model('constant', prefix='tmp_normalization_offset_', parameters={'name': 'c',
+ 'value': -self._norm[0], 'vary': False, 'norm': True})
+ #'value': -self._norm[0]/self._norm[1], 'vary': False, 'norm': False})
+
+ # Adjust existing parameters for refit:
+ if 'parameters' in kwargs:
+ parameters = kwargs.pop('parameters')
+ if isinstance(parameters, dict):
+ parameters = (parameters, )
+ elif not is_dict_series(parameters):
+ illegal_value(parameters, 'parameters', 'Fit.fit', raise_error=True)
+ for par in parameters:
+ name = par['name']
+ if name not in self._parameters:
+ raise ValueError(f'Unable to match {name} parameter {par} to an existing one')
+ if self._parameters[name].expr is not None:
+ raise ValueError(f'Unable to modify {name} parameter {par} (currently an '+
+ 'expression)')
+ if par.get('expr') is not None:
+ raise KeyError(f'Illegal "expr" key in {name} parameter {par}')
+ self._parameters[name].set(vary=par.get('vary'))
+ self._parameters[name].set(min=par.get('min'))
+ self._parameters[name].set(max=par.get('max'))
+ self._parameters[name].set(value=par.get('value'))
+#RV print('\nAfter adjust:')
+#RV self._parameters.pretty_print()
+
+ # Apply parameter updates through keyword arguments
+# print(f'kwargs = {kwargs}')
+# print(f'parameter_norms = {self._parameter_norms}')
+ for name in set(self._parameters) & set(kwargs):
+ value = kwargs.pop(name)
+ if self._parameters[name].expr is None:
+ self._parameters[name].set(value=value)
+ else:
+ logging.warning(f'Ignoring parameter {name} in Fit.fit (set by expression: '+
+ f'{self._parameters[name].expr})')
+
+ # Check for uninitialized parameters
+ for name, par in self._parameters.items():
+ if par.expr is None:
+ value = par.value
+ if value is None or np.isinf(value) or np.isnan(value):
+ if interactive:
+ value = input_num(f'Enter an initial value for {name}', default=1.0)
+ else:
+ value = 1.0
+ if self._norm is None or name not in self._parameter_norms:
+ self._parameters[name].set(value=value)
+ elif self._parameter_norms[name]:
+ self._parameters[name].set(value=value*self._norm[1])
+
+ # Check if model is linear
+ try:
+ linear_model = self._check_linearity_model()
+        except Exception:
+ linear_model = False
+# print(f'\n\n--------> linear_model = {linear_model}\n')
+ if kwargs.get('check_only_linearity') is not None:
+ return(linear_model)
+
+ # Normalize the data and initial parameters
+#RV print('\nBefore normalization:')
+#RV self._parameters.pretty_print()
+# print(f'parameter_norms:\n{self._parameter_norms}')
+ self._normalize()
+# print(f'norm = {self._norm}')
+#RV print('\nAfter normalization:')
+#RV self._parameters.pretty_print()
+# self.print_fit_report()
+# print(f'parameter_norms:\n{self._parameter_norms}')
+
+ if linear_model:
+ # Perform a linear fit by direct matrix solution with numpy
+ try:
+ if self._mask is None:
+ self._fit_linear_model(self._x, self._y_norm)
+ else:
+ self._fit_linear_model(self._x[~self._mask],
+ np.asarray(self._y_norm)[~self._mask])
+            except Exception:
+ linear_model = False
+ if not linear_model:
+ # Perform a non-linear fit with lmfit
+ # Prevent initial values from sitting at boundaries
+ self._parameter_bounds = {name:{'min': par.min, 'max': par.max} for name, par in
+ self._parameters.items() if par.vary}
+ for par in self._parameters.values():
+ if par.vary:
+ par.set(value=self._reset_par_at_boundary(par, par.value))
+# print('\nAfter checking boundaries:')
+# self._parameters.pretty_print()
+
+ # Perform the fit
+# fit_kws = None
+# if 'Dfun' in kwargs:
+# fit_kws = {'Dfun': kwargs.pop('Dfun')}
+# self._result = self._model.fit(self._y_norm, self._parameters, x=self._x,
+# fit_kws=fit_kws, **kwargs)
+ if self._mask is None:
+ self._result = self._model.fit(self._y_norm, self._parameters, x=self._x, **kwargs)
+ else:
+ self._result = self._model.fit(np.asarray(self._y_norm)[~self._mask],
+ self._parameters, x=self._x[~self._mask], **kwargs)
+#RV print('\nAfter fit:')
+# print(f'\nself._result ({self._result}):\n\t{self._result.__dict__}')
+#RV self._parameters.pretty_print()
+# self.print_fit_report()
+
+ # Set internal parameter values to fit results upon success
+ if self.success:
+ for name, par in self._parameters.items():
+ if par.expr is None and par.vary:
+ par.set(value=self._result.params[name].value)
+# print('\nAfter update parameter values:')
+# self._parameters.pretty_print()
+
+ # Renormalize the data and results
+ self._renormalize()
+#RV print('\nAfter renormalization:')
+#RV self._parameters.pretty_print()
+# self.print_fit_report()
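+
+    # Refit sketch (hypothetical parameter name): adjust an existing free
+    # parameter and fit again:
+    #     fit.fit(parameters=({'name': 'amplitude', 'value': 2.0, 'min': 0.0},))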
+
+ def plot(self, y=None, y_title=None, result=None, skip_init=False, plot_comp_legends=False,
+ plot_residual=False, plot_masked_data=True, **kwargs):
+ if result is None:
+ result = self._result
+ if result is None:
+ return
+ plots = []
+ legend = []
+ if self._mask is None:
+ mask = np.zeros(self._x.size).astype(bool)
+ plot_masked_data = False
+ else:
+ mask = self._mask
+        if y is not None:
+            if not isinstance(y, (tuple, list, np.ndarray)):
+                illegal_value(y, 'y', 'Fit.plot')
+                y = None
+            elif len(y) != len(self._x):
+                logging.warning('Ignoring parameter y in Fit.plot (wrong dimension)')
+                y = None
+ if y is not None:
+ if y_title is None or not isinstance(y_title, str):
+ y_title = 'data'
+ plots += [(self._x, y, '.')]
+ legend += [y_title]
+ if self._y is not None:
+ plots += [(self._x, np.asarray(self._y), 'b.')]
+ legend += ['data']
+ if plot_masked_data:
+ plots += [(self._x[mask], np.asarray(self._y)[mask], 'bx')]
+ legend += ['masked data']
+ if isinstance(plot_residual, bool) and plot_residual:
+ plots += [(self._x[~mask], result.residual, 'k-')]
+ legend += ['residual']
+ plots += [(self._x[~mask], result.best_fit, 'k-')]
+ legend += ['best fit']
+ if not skip_init and hasattr(result, 'init_fit'):
+ plots += [(self._x[~mask], result.init_fit, 'g-')]
+ legend += ['init']
+ components = result.eval_components(x=self._x[~mask])
+ num_components = len(components)
+ if 'tmp_normalization_offset_' in components:
+ num_components -= 1
+ if num_components > 1:
+ eval_index = 0
+ for modelname, y in components.items():
+ if modelname == 'tmp_normalization_offset_':
+ continue
+ if modelname == '_eval':
+ modelname = f'eval{eval_index}'
+ if len(modelname) > 20:
+ modelname = f'{modelname[0:16]} ...'
+ if isinstance(y, (int, float)):
+ y *= np.ones(self._x[~mask].size)
+ plots += [(self._x[~mask], y, '--')]
+ if plot_comp_legends:
+ if modelname[-1] == '_':
+ legend.append(modelname[:-1])
+ else:
+ legend.append(modelname)
+ title = kwargs.get('title')
+ if title is not None:
+ kwargs.pop('title')
+ quick_plot(tuple(plots), legend=legend, title=title, block=True, **kwargs)
+
+ @staticmethod
+ def guess_init_peak(x, y, *args, center_guess=None, use_max_for_center=True):
+ """ Return a guess for the initial height, center and fwhm for a peak
+ """
+# print(f'\n\nargs = {args}')
+# print(f'center_guess = {center_guess}')
+# quick_plot(x, y, vlines=center_guess, block=True)
+ center_guesses = None
+ x = np.asarray(x)
+ y = np.asarray(y)
+ if len(x) != len(y):
+ logging.error(f'Invalid x and y lengths ({len(x)}, {len(y)}), skip initial guess')
+ return(None, None, None)
+ if isinstance(center_guess, (int, float)):
+ if len(args):
+ logging.warning('Ignoring additional arguments for single center_guess value')
+ center_guesses = [center_guess]
+ elif isinstance(center_guess, (tuple, list, np.ndarray)):
+ if len(center_guess) == 1:
+ logging.warning('Ignoring additional arguments for single center_guess value')
+ if not isinstance(center_guess[0], (int, float)):
+ raise ValueError(f'Invalid parameter center_guess ({type(center_guess[0])})')
+ center_guess = center_guess[0]
+ else:
+ if len(args) != 1:
+ raise ValueError(f'Invalid number of arguments ({len(args)})')
+ n = args[0]
+ if not is_index(n, 0, len(center_guess)):
+ raise ValueError('Invalid argument')
+ center_guesses = center_guess
+ center_guess = center_guesses[n]
+ elif center_guess is not None:
+ raise ValueError(f'Invalid center_guess type ({type(center_guess)})')
+# print(f'x = {x}')
+# print(f'y = {y}')
+# print(f'center_guess = {center_guess}')
+
+ # Sort the inputs
+ index = np.argsort(x)
+ x = x[index]
+ y = y[index]
+ miny = y.min()
+# print(f'miny = {miny}')
+# print(f'x_range = {x[0]} {x[-1]} {len(x)}')
+# print(f'y_range = {y[0]} {y[-1]} {len(y)}')
+# quick_plot(x, y, vlines=center_guess, block=True)
+
+# xx = x
+# yy = y
+ # Set range for current peak
+# print(f'n = {n}')
+# print(f'center_guesses = {center_guesses}')
+ if center_guesses is not None:
+ if len(center_guesses) > 1:
+ index = np.argsort(center_guesses)
+ n = list(index).index(n)
+# print(f'n = {n}')
+# print(f'index = {index}')
+ center_guesses = np.asarray(center_guesses)[index]
+# print(f'center_guesses = {center_guesses}')
+ if n == 0:
+ low = 0
+ upp = index_nearest(x, (center_guesses[0]+center_guesses[1])/2)
+ elif n == len(center_guesses)-1:
+ low = index_nearest(x, (center_guesses[n-1]+center_guesses[n])/2)
+ upp = len(x)
+ else:
+ low = index_nearest(x, (center_guesses[n-1]+center_guesses[n])/2)
+ upp = index_nearest(x, (center_guesses[n]+center_guesses[n+1])/2)
+# print(f'low = {low}')
+# print(f'upp = {upp}')
+ x = x[low:upp]
+ y = y[low:upp]
+# quick_plot(x, y, vlines=(x[0], center_guess, x[-1]), block=True)
+
+        # Estimate FWHM
+ maxy = y.max()
+# print(f'x_range = {x[0]} {x[-1]} {len(x)}')
+# print(f'y_range = {y[0]} {y[-1]} {len(y)} {miny} {maxy}')
+# print(f'center_guess = {center_guess}')
+ if center_guess is None:
+ center_index = np.argmax(y)
+ center = x[center_index]
+ height = maxy-miny
+ else:
+ if use_max_for_center:
+ center_index = np.argmax(y)
+ center = x[center_index]
+ if center_index < 0.1*len(x) or center_index > 0.9*len(x):
+ center_index = index_nearest(x, center_guess)
+ center = center_guess
+ else:
+ center_index = index_nearest(x, center_guess)
+ center = center_guess
+ height = y[center_index]-miny
+# print(f'center_index = {center_index}')
+# print(f'center = {center}')
+# print(f'height = {height}')
+ half_height = miny+0.5*height
+# print(f'half_height = {half_height}')
+ fwhm_index1 = 0
+ for i in range(center_index, fwhm_index1, -1):
+ if y[i] < half_height:
+ fwhm_index1 = i
+ break
+# print(f'fwhm_index1 = {fwhm_index1} {x[fwhm_index1]}')
+ fwhm_index2 = len(x)-1
+ for i in range(center_index, fwhm_index2):
+ if y[i] < half_height:
+ fwhm_index2 = i
+ break
+# print(f'fwhm_index2 = {fwhm_index2} {x[fwhm_index2]}')
+# quick_plot((x,y,'o'), vlines=(x[fwhm_index1], center, x[fwhm_index2]), block=True)
+ if fwhm_index1 == 0 and fwhm_index2 < len(x)-1:
+ fwhm = 2*(x[fwhm_index2]-center)
+ elif fwhm_index1 > 0 and fwhm_index2 == len(x)-1:
+ fwhm = 2*(center-x[fwhm_index1])
+ else:
+ fwhm = x[fwhm_index2]-x[fwhm_index1]
+# print(f'fwhm_index1 = {fwhm_index1} {x[fwhm_index1]}')
+# print(f'fwhm_index2 = {fwhm_index2} {x[fwhm_index2]}')
+# print(f'fwhm = {fwhm}')
+
+ # Return height, center and FWHM
+# quick_plot((x,y,'o'), (xx,yy), vlines=(x[fwhm_index1], center, x[fwhm_index2]), block=True)
+ return(height, center, fwhm)
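+
+    # Example (hypothetical data): estimate initial height, center and fwhm for
+    # a single synthetic Gaussian peak:
+    #     x = np.linspace(0.0, 10.0, 201)
+    #     y = 5.0*np.exp(-0.5*((x-4.0)/0.3)**2)
+    #     height, center, fwhm = Fit.guess_init_peak(x, y)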
+
+ def _check_linearity_model(self):
+ """Identify the linearity of all model parameters and check if the model is linear or not
+ """
+ if not self._try_linear_fit:
+ logging.info('Skip linearity check (not yet supported for callable models)')
+ return(False)
+ free_parameters = [name for name, par in self._parameters.items() if par.vary]
+ for component in self._model.components:
+ if 'tmp_normalization_offset_c' in component.param_names:
+ continue
+ if isinstance(component, ExpressionModel):
+ for name in free_parameters:
+ if diff(component.expr, name, name):
+# print(f'\t\t{component.expr} is non-linear in {name}')
+ self._nonlinear_parameters.append(name)
+ if name in self._linear_parameters:
+ self._linear_parameters.remove(name)
+ else:
+ model_parameters = component.param_names.copy()
+ for basename, hint in component.param_hints.items():
+ name = f'{component.prefix}{basename}'
+ if hint.get('expr') is not None:
+ model_parameters.remove(name)
+ for name in model_parameters:
+ expr = self._parameters[name].expr
+ if expr is not None:
+ for nname in free_parameters:
+ if name in self._nonlinear_parameters:
+ if diff(expr, nname):
+# print(f'\t\t{component} is non-linear in {nname} (through {name} = "{expr}")')
+ self._nonlinear_parameters.append(nname)
+ if nname in self._linear_parameters:
+ self._linear_parameters.remove(nname)
+ else:
+ assert(name in self._linear_parameters)
+# print(f'\n\nexpr ({type(expr)}) = {expr}\nnname ({type(nname)}) = {nname}\n\n')
+ if diff(expr, nname, nname):
+# print(f'\t\t{component} is non-linear in {nname} (through {name} = "{expr}")')
+ self._nonlinear_parameters.append(nname)
+ if nname in self._linear_parameters:
+ self._linear_parameters.remove(nname)
+# print(f'\nfree parameters:\n\t{free_parameters}')
+# print(f'linear parameters:\n\t{self._linear_parameters}')
+# print(f'nonlinear parameters:\n\t{self._nonlinear_parameters}\n')
+        if any(self._parameters[name].vary for name in self._nonlinear_parameters):
+ return(False)
+ return(True)
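+
+    # Linearity-check sketch: a parameter enters linearly when the second
+    # derivative of the model expression with respect to it vanishes, e.g.
+    #     from sympy import diff
+    #     diff('a*x + b', 'a', 'a')    # 0, so 'a' is linear
+    #     diff('exp(-x/d)', 'd', 'd')  # nonzero, so 'd' is nonlinear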
+
+ def _fit_linear_model(self, x, y):
+ """Perform a linear fit by direct matrix solution with numpy
+ """
+ # Construct the matrix and the free parameter vector
+# print(f'\nparameters:')
+# self._parameters.pretty_print()
+# print(f'\nparameter_norms:\n\t{self._parameter_norms}')
+# print(f'\nlinear_parameters:\n\t{self._linear_parameters}')
+# print(f'nonlinear_parameters:\n\t{self._nonlinear_parameters}')
+ free_parameters = [name for name, par in self._parameters.items() if par.vary]
+# print(f'free parameters:\n\t{free_parameters}\n')
+ expr_parameters = {name:par.expr for name, par in self._parameters.items()
+ if par.expr is not None}
+ model_parameters = []
+ for component in self._model.components:
+ if 'tmp_normalization_offset_c' in component.param_names:
+ continue
+ model_parameters += component.param_names
+ for basename, hint in component.param_hints.items():
+ name = f'{component.prefix}{basename}'
+ if hint.get('expr') is not None:
+ expr_parameters.pop(name)
+ model_parameters.remove(name)
+# print(f'expr parameters:\n{expr_parameters}')
+# print(f'model parameters:\n\t{model_parameters}\n')
+ norm = 1.0
+ if self._normalized:
+ norm = self._norm[1]
+# print(f'\n\nself._normalized = {self._normalized}\nnorm = {norm}\nself._norm = {self._norm}\n')
+ # Add expression parameters to asteval
+ ast = Interpreter()
+# print(f'Adding to asteval sym table:')
+ for name, expr in expr_parameters.items():
+# print(f'\tadding {name} {expr}')
+ ast.symtable[name] = expr
+ # Add constant parameters to asteval
+ # (renormalize to use correctly in evaluation of expression models)
+ for name, par in self._parameters.items():
+ if par.expr is None and not par.vary:
+ if self._parameter_norms[name]:
+# print(f'\tadding {name} {par.value*norm}')
+ ast.symtable[name] = par.value*norm
+ else:
+# print(f'\tadding {name} {par.value}')
+ ast.symtable[name] = par.value
+ A = np.zeros((len(x), len(free_parameters)), dtype='float64')
+ y_const = np.zeros(len(x), dtype='float64')
+ have_expression_model = False
+ for component in self._model.components:
+ if isinstance(component, ConstantModel):
+ name = component.param_names[0]
+# print(f'\nConstant model: {name} {self._parameters[name]}\n')
+ if name in free_parameters:
+# print(f'\t\t{name} is a free constant set matrix column {free_parameters.index(name)} to 1.0')
+ A[:,free_parameters.index(name)] = 1.0
+ else:
+ if self._parameter_norms[name]:
+ delta_y_const = self._parameters[name]*np.ones(len(x))
+ else:
+ delta_y_const = (self._parameters[name]*norm)*np.ones(len(x))
+ y_const += delta_y_const
+# print(f'\ndelta_y_const ({type(delta_y_const)}):\n{delta_y_const}\n')
+ elif isinstance(component, ExpressionModel):
+ have_expression_model = True
+ const_expr = component.expr
+# print(f'\nExpression model:\nconst_expr: {const_expr}\n')
+ for name in free_parameters:
+ dexpr_dname = diff(component.expr, name)
+ if dexpr_dname:
+ const_expr = f'{const_expr}-({str(dexpr_dname)})*{name}'
+# print(f'\tconst_expr: {const_expr}')
+ if not self._parameter_norms[name]:
+ dexpr_dname = f'({dexpr_dname})/{norm}'
+# print(f'\t{component.expr} is linear in {name}\n\t\tadd "{str(dexpr_dname)}" to matrix as column {free_parameters.index(name)}')
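+                    # Evaluate dexpr_dname at each x value: calling ast(f'x={v}')
+                    # first sets x in the asteval symbol table, then the lambda
+                    # evaluates the derivative expression with that x in scope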
+ fx = [(lambda _: ast.eval(str(dexpr_dname)))(ast(f'x={v}')) for v in x]
+# print(f'\tfx:\n{fx}')
+ if len(ast.error):
+ raise ValueError(f'Unable to evaluate {dexpr_dname}')
+ A[:,free_parameters.index(name)] += fx
+# if self._parameter_norms[name]:
+# print(f'\t\t{component.expr} is linear in {name} add "{str(dexpr_dname)}" to matrix as column {free_parameters.index(name)}')
+# A[:,free_parameters.index(name)] += fx
+# else:
+# print(f'\t\t{component.expr} is linear in {name} add "({str(dexpr_dname)})/{norm}" to matrix as column {free_parameters.index(name)}')
+# A[:,free_parameters.index(name)] += np.asarray(fx)/norm
+ # FIX: find another solution if expr not supported by simplify
+ const_expr = str(simplify(f'({const_expr})/{norm}'))
+# print(f'\nconst_expr: {const_expr}')
+ delta_y_const = [(lambda _: ast.eval(const_expr))(ast(f'x = {v}')) for v in x]
+ y_const += delta_y_const
+# print(f'\ndelta_y_const ({type(delta_y_const)}):\n{delta_y_const}\n')
+ if len(ast.error):
+ raise ValueError(f'Unable to evaluate {const_expr}')
+ else:
+ free_model_parameters = [name for name in component.param_names
+ if name in free_parameters or name in expr_parameters]
+# print(f'\nBuild-in model ({component}):\nfree_model_parameters: {free_model_parameters}\n')
+ if not len(free_model_parameters):
+ y_const += component.eval(params=self._parameters, x=x)
+ elif isinstance(component, LinearModel):
+ if f'{component.prefix}slope' in free_model_parameters:
+ A[:,free_parameters.index(f'{component.prefix}slope')] = x
+ else:
+ y_const += self._parameters[f'{component.prefix}slope'].value*x
+ if f'{component.prefix}intercept' in free_model_parameters:
+ A[:,free_parameters.index(f'{component.prefix}intercept')] = 1.0
+ else:
+ y_const += self._parameters[f'{component.prefix}intercept'].value* \
+ np.ones(len(x))
+ elif isinstance(component, QuadraticModel):
+ if f'{component.prefix}a' in free_model_parameters:
+ A[:,free_parameters.index(f'{component.prefix}a')] = x**2
+ else:
+ y_const += self._parameters[f'{component.prefix}a'].value*x**2
+ if f'{component.prefix}b' in free_model_parameters:
+ A[:,free_parameters.index(f'{component.prefix}b')] = x
+ else:
+ y_const += self._parameters[f'{component.prefix}b'].value*x
+ if f'{component.prefix}c' in free_model_parameters:
+ A[:,free_parameters.index(f'{component.prefix}c')] = 1.0
+ else:
+ y_const += self._parameters[f'{component.prefix}c'].value*np.ones(len(x))
+ else:
+                    # At this point each built-in model must be strictly proportional to each
+                    # linear model parameter; without this assumption, the model equation itself
+                    # would be needed. For the current built-in lmfit models, this can only ever
+                    # be the amplitude
+ assert(len(free_model_parameters) == 1)
+ name = f'{component.prefix}amplitude'
+ assert(free_model_parameters[0] == name)
+ assert(self._parameter_norms[name])
+ expr = self._parameters[name].expr
+ if expr is None:
+# print(f'\t{component} is linear in {name} add to matrix as column {free_parameters.index(name)}')
+ parameters = deepcopy(self._parameters)
+ parameters[name].set(value=1.0)
+                        index = free_parameters.index(name)
+                        A[:,index] += component.eval(params=parameters, x=x)
+ else:
+ const_expr = expr
+# print(f'\tconst_expr: {const_expr}')
+ parameters = deepcopy(self._parameters)
+ parameters[name].set(value=1.0)
+ dcomp_dname = component.eval(params=parameters, x=x)
+# print(f'\tdcomp_dname ({type(dcomp_dname)}):\n{dcomp_dname}')
+ for nname in free_parameters:
+ dexpr_dnname = diff(expr, nname)
+ if dexpr_dnname:
+ assert(self._parameter_norms[name])
+# print(f'\t\td({expr})/d{nname} = {dexpr_dnname}')
+# print(f'\t\t{component} is linear in {nname} (through {name} = "{expr}", add to matrix as column {free_parameters.index(nname)})')
+ fx = np.asarray(dexpr_dnname*dcomp_dname, dtype='float64')
+# print(f'\t\tfx ({type(fx)}): {fx}')
+# print(f'free_parameters.index({nname}): {free_parameters.index(nname)}')
+ if self._parameter_norms[nname]:
+ A[:,free_parameters.index(nname)] += fx
+ else:
+ A[:,free_parameters.index(nname)] += fx/norm
+ const_expr = f'{const_expr}-({dexpr_dnname})*{nname}'
+# print(f'\t\tconst_expr: {const_expr}')
+ const_expr = str(simplify(f'({const_expr})/{norm}'))
+# print(f'\tconst_expr: {const_expr}')
+ fx = [(lambda _: ast.eval(const_expr))(ast(f'x = {v}')) for v in x]
+# print(f'\tfx: {fx}')
+ delta_y_const = np.multiply(fx, dcomp_dname)
+ y_const += delta_y_const
+# print(f'\ndelta_y_const ({type(delta_y_const)}):\n{delta_y_const}\n')
+# print(A)
+# print(y_const)
+ solution, residual, rank, s = np.linalg.lstsq(A, y-y_const, rcond=None)
+# print(f'\nsolution ({type(solution)} {solution.shape}):\n\t{solution}')
+# print(f'\nresidual ({type(residual)} {residual.shape}):\n\t{residual}')
+# print(f'\nrank ({type(rank)} {rank.shape}):\n\t{rank}')
+# print(f'\ns ({type(s)} {s.shape}):\n\t{s}\n')
+
+ # Assemble result (compensate for normalization in expression models)
+ for name, value in zip(free_parameters, solution):
+ self._parameters[name].set(value=value)
+ if self._normalized and (have_expression_model or len(expr_parameters)):
+ for name, norm in self._parameter_norms.items():
+ par = self._parameters[name]
+ if par.expr is None and norm:
+ self._parameters[name].set(value=par.value*self._norm[1])
+# self._parameters.pretty_print()
+# print(f'\nself._parameter_norms:\n\t{self._parameter_norms}')
+ self._result = ModelResult(self._model, deepcopy(self._parameters))
+ self._result.best_fit = self._model.eval(params=self._parameters, x=x)
+ if self._normalized and (have_expression_model or len(expr_parameters)):
+ if 'tmp_normalization_offset_c' in self._parameters:
+                offset = self._parameters['tmp_normalization_offset_c'].value
+ else:
+ offset = 0.0
+ self._result.best_fit = (self._result.best_fit-offset-self._norm[0])/self._norm[1]
+ if self._normalized:
+ for name, norm in self._parameter_norms.items():
+ par = self._parameters[name]
+ if par.expr is None and norm:
+ value = par.value/self._norm[1]
+ self._parameters[name].set(value=value)
+ self._result.params[name].set(value=value)
+# self._parameters.pretty_print()
+ self._result.residual = self._result.best_fit-y
+ self._result.components = self._model.components
+ self._result.init_params = None
+# quick_plot((x, y, '.'), (x, y_const, 'g'), (x, self._result.best_fit, 'k'), (x, self._result.residual, 'r'), block=True)
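+
+    # Least-squares sketch: the loops above assemble a design matrix A and a
+    # constant vector y_const such that y ~ A@p + y_const; the free parameters
+    # p then follow from a direct solve, e.g. for a line (slope, intercept):
+    #     A = np.column_stack((x, np.ones(len(x))))
+    #     p, *_ = np.linalg.lstsq(A, y - y_const, rcond=None)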
+
+ def _normalize(self):
+ """Normalize the data and initial parameters
+ """
+ if self._normalized:
+ return
+ if self._norm is None:
+ if self._y is not None and self._y_norm is None:
+ self._y_norm = np.asarray(self._y)
+ else:
+ if self._y is not None and self._y_norm is None:
+ self._y_norm = (np.asarray(self._y)-self._norm[0])/self._norm[1]
+ self._y_range = 1.0
+ for name, norm in self._parameter_norms.items():
+ par = self._parameters[name]
+ if par.expr is None and norm:
+ value = par.value/self._norm[1]
+ _min = par.min
+ _max = par.max
+ if not np.isinf(_min) and abs(_min) != float_min:
+ _min /= self._norm[1]
+ if not np.isinf(_max) and abs(_max) != float_min:
+ _max /= self._norm[1]
+ par.set(value=value, min=_min, max=_max)
+ self._normalized = True
+
+ def _renormalize(self):
+ """Renormalize the data and results
+ """
+ if self._norm is None or not self._normalized:
+ return
+ self._normalized = False
+ for name, norm in self._parameter_norms.items():
+ par = self._parameters[name]
+ if par.expr is None and norm:
+ value = par.value*self._norm[1]
+ _min = par.min
+ _max = par.max
+ if not np.isinf(_min) and abs(_min) != float_min:
+ _min *= self._norm[1]
+ if not np.isinf(_max) and abs(_max) != float_min:
+ _max *= self._norm[1]
+ par.set(value=value, min=_min, max=_max)
+ if self._result is None:
+ return
+ self._result.best_fit = self._result.best_fit*self._norm[1]+self._norm[0]
+ for name, par in self._result.params.items():
+ if self._parameter_norms.get(name, False):
+ if par.stderr is not None:
+ par.stderr *= self._norm[1]
+ if par.expr is None:
+ _min = par.min
+ _max = par.max
+ value = par.value*self._norm[1]
+ if par.init_value is not None:
+ par.init_value *= self._norm[1]
+ if not np.isinf(_min) and abs(_min) != float_min:
+ _min *= self._norm[1]
+ if not np.isinf(_max) and abs(_max) != float_min:
+ _max *= self._norm[1]
+ par.set(value=value, min=_min, max=_max)
+ if hasattr(self._result, 'init_fit'):
+ self._result.init_fit = self._result.init_fit*self._norm[1]+self._norm[0]
+        if hasattr(self._result, 'init_values') and self._result.init_params is not None:
+ init_values = {}
+ for name, value in self._result.init_values.items():
+ if name not in self._parameter_norms or self._parameters[name].expr is not None:
+ init_values[name] = value
+ elif self._parameter_norms[name]:
+ init_values[name] = value*self._norm[1]
+ self._result.init_values = init_values
+ for name, par in self._result.init_params.items():
+ if par.expr is None and self._parameter_norms.get(name, False):
+ value = par.value
+ _min = par.min
+ _max = par.max
+ value *= self._norm[1]
+ if not np.isinf(_min) and abs(_min) != float_min:
+ _min *= self._norm[1]
+ if not np.isinf(_max) and abs(_max) != float_min:
+ _max *= self._norm[1]
+ par.set(value=value, min=_min, max=_max)
+ par.init_value = par.value
+ # Don't renormalize chisqr, it has no useful meaning in physical units
+ #self._result.chisqr *= self._norm[1]*self._norm[1]
+ if self._result.covar is not None:
+ for i, name in enumerate(self._result.var_names):
+ if self._parameter_norms.get(name, False):
+ for j in range(len(self._result.var_names)):
+ if self._result.covar[i,j] is not None:
+ self._result.covar[i,j] *= self._norm[1]
+ if self._result.covar[j,i] is not None:
+ self._result.covar[j,i] *= self._norm[1]
+ # Don't renormalize redchi, it has no useful meaning in physical units
+ #self._result.redchi *= self._norm[1]*self._norm[1]
+ if self._result.residual is not None:
+ self._result.residual *= self._norm[1]
+
+ def _reset_par_at_boundary(self, par, value):
+ assert(par.vary)
+ name = par.name
+ _min = self._parameter_bounds[name]['min']
+ _max = self._parameter_bounds[name]['max']
+ if np.isinf(_min):
+ if not np.isinf(_max):
+ if self._parameter_norms.get(name, False):
+ upp = _max-0.1*self._y_range
+ elif _max == 0.0:
+ upp = _max-0.1
+ else:
+ upp = _max-0.1*abs(_max)
+ if value >= upp:
+ return(upp)
+ else:
+ if np.isinf(_max):
+ if self._parameter_norms.get(name, False):
+ low = _min+0.1*self._y_range
+ elif _min == 0.0:
+ low = _min+0.1
+ else:
+ low = _min+0.1*abs(_min)
+ if value <= low:
+ return(low)
+ else:
+ low = 0.9*_min+0.1*_max
+ upp = 0.1*_min+0.9*_max
+ if value <= low:
+ return(low)
+ elif value >= upp:
+ return(upp)
+ return(value)
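+
+    # Boundary-nudge sketch: with finite bounds the start value is pulled 10%
+    # into the interior of the allowed range, e.g. for min=0.0 and max=1.0 a
+    # value at the lower bound resets to 0.9*0.0+0.1*1.0 = 0.1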
+
+
+class FitMultipeak(Fit):
+ """Fit data with multiple peaks
+ """
+ def __init__(self, y, x=None, normalize=True):
+ super().__init__(y, x=x, normalize=normalize)
+ self._fwhm_max = None
+ self._sigma_max = None
+
+ @classmethod
+ def fit_multipeak(cls, y, centers, x=None, normalize=True, peak_models='gaussian',
+ center_exprs=None, fit_type=None, background_order=None, background_exp=False,
+ fwhm_max=None, plot_components=False):
+ """Make sure that centers and fwhm_max are in the correct units and consistent with expr
+ for a uniform fit (fit_type == 'uniform')
+ """
+ fit = cls(y, x=x, normalize=normalize)
+ success = fit.fit(centers, fit_type=fit_type, peak_models=peak_models, fwhm_max=fwhm_max,
+ center_exprs=center_exprs, background_order=background_order,
+ background_exp=background_exp, plot_components=plot_components)
+ if success:
+            return(fit.best_fit, fit.residual, fit.best_values, fit.best_errors, fit.redchi,
+                    fit.success)
+ else:
+ return(np.array([]), np.array([]), {}, {}, float_max, False)
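+
+    # Example (hypothetical data): fit two Gaussian peaks on a constant
+    # background; a failed fit returns empty arrays and success = False:
+    #     best_fit, residual, best_values, best_errors, redchi, success = \
+    #         FitMultipeak.fit_multipeak(y, (2.0, 6.0), x=x, background_order=0)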
+
+ def fit(self, centers, fit_type=None, peak_models=None, center_exprs=None, fwhm_max=None,
+ background_order=None, background_exp=False, plot_components=False,
+ param_constraint=False):
+ self._fwhm_max = fwhm_max
+ # Create the multipeak model
+ self._create_model(centers, fit_type, peak_models, center_exprs, background_order,
+ background_exp, param_constraint)
+
+ # RV: Obsolete Normalize the data and results
+# print('\nBefore fit before normalization in FitMultipeak:')
+# self._parameters.pretty_print()
+# self._normalize()
+# print('\nBefore fit after normalization in FitMultipeak:')
+# self._parameters.pretty_print()
+
+ # Perform the fit
+ try:
+ if param_constraint:
+ super().fit(fit_kws={'xtol': 1.e-5, 'ftol': 1.e-5, 'gtol': 1.e-5})
+ else:
+ super().fit()
+ except:
+ return(False)
+
+ # Check for valid fit parameter results
+ fit_failure = self._check_validity()
+ success = True
+ if fit_failure:
+ if param_constraint:
+                logging.warning(' -> Fit failure should not occur with param_constraint set: failing the fit')
+ success = False
+ else:
+ logging.info(' -> Retry fitting with constraints')
+                success = self.fit(centers, fit_type, peak_models, center_exprs,
+                        fwhm_max=fwhm_max, background_order=background_order,
+                        background_exp=background_exp, plot_components=plot_components,
+                        param_constraint=True)
+ else:
+ # RV: Obsolete Renormalize the data and results
+# print('\nAfter fit before renormalization in FitMultipeak:')
+# self._parameters.pretty_print()
+# self.print_fit_report()
+# self._renormalize()
+# print('\nAfter fit after renormalization in FitMultipeak:')
+# self._parameters.pretty_print()
+# self.print_fit_report()
+
+ # Print report and plot components if requested
+ if plot_components:
+ self.print_fit_report()
+ self.plot()
+
+ return(success)
+
+ def _create_model(self, centers, fit_type=None, peak_models=None, center_exprs=None,
+ background_order=None, background_exp=False, param_constraint=False):
+ """Create the multipeak model
+ """
+ if isinstance(centers, (int, float)):
+ centers = [centers]
+ num_peaks = len(centers)
+ if peak_models is None:
+ peak_models = num_peaks*['gaussian']
+ elif isinstance(peak_models, str):
+ peak_models = num_peaks*[peak_models]
+ if len(peak_models) != num_peaks:
+ raise ValueError(f'Inconsistent number of peaks in peak_models ({len(peak_models)} vs '+
+ f'{num_peaks})')
+ if num_peaks == 1:
+ if fit_type is not None:
+ logging.debug('Ignoring fit_type input for fitting one peak')
+ fit_type = None
+ if center_exprs is not None:
+ logging.debug('Ignoring center_exprs input for fitting one peak')
+ center_exprs = None
+ else:
+ if fit_type == 'uniform':
+ if center_exprs is None:
+ center_exprs = [f'scale_factor*{cen}' for cen in centers]
+ if len(center_exprs) != num_peaks:
+ raise ValueError(f'Inconsistent number of peaks in center_exprs '+
+ f'({len(center_exprs)} vs {num_peaks})')
+ elif fit_type == 'unconstrained' or fit_type is None:
+ if center_exprs is not None:
+ logging.warning('Ignoring center_exprs input for unconstrained fit')
+ center_exprs = None
+ else:
+                raise ValueError(f'Invalid fit_type ({fit_type}) in FitMultipeak._create_model')
+ self._sigma_max = None
+ if param_constraint:
+ min_value = float_min
+ if self._fwhm_max is not None:
+ self._sigma_max = np.zeros(num_peaks)
+ else:
+ min_value = None
+
+ # Reset the fit
+ self._model = None
+ self._parameters = Parameters()
+ self._result = None
+
+ # Add background model
+ if background_order is not None:
+ if background_order == 0:
+ self.add_model('constant', prefix='background', parameters=
+ {'name': 'c', 'value': float_min, 'min': min_value})
+ elif background_order == 1:
+ self.add_model('linear', prefix='background', slope=0.0, intercept=0.0)
+ elif background_order == 2:
+ self.add_model('quadratic', prefix='background', a=0.0, b=0.0, c=0.0)
+ else:
+ raise ValueError(f'Invalid parameter background_order ({background_order})')
+ if background_exp:
+ self.add_model('exponential', prefix='background', parameters=(
+ {'name': 'amplitude', 'value': float_min, 'min': min_value},
+ {'name': 'decay', 'value': float_min, 'min': min_value}))
+
+ # Add peaks and guess initial fit parameters
+ ast = Interpreter()
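+        # The module-level fwhm_factor and height_factor expressions convert the
+        # guessed peak height and fwhm into the amplitude and sigma parameters of
+        # the selected lmfit peak model; asteval evaluates each expression with
+        # fwhm and height bound in its symbol table.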
+ if num_peaks == 1:
+ height_init, cen_init, fwhm_init = self.guess_init_peak(self._x, self._y)
+ if self._fwhm_max is not None and fwhm_init > self._fwhm_max:
+ fwhm_init = self._fwhm_max
+ ast(f'fwhm = {fwhm_init}')
+ ast(f'height = {height_init}')
+ sig_init = ast(fwhm_factor[peak_models[0]])
+ amp_init = ast(height_factor[peak_models[0]])
+ sig_max = None
+ if self._sigma_max is not None:
+ ast(f'fwhm = {self._fwhm_max}')
+ sig_max = ast(fwhm_factor[peak_models[0]])
+ self._sigma_max[0] = sig_max
+ self.add_model(peak_models[0], parameters=(
+ {'name': 'amplitude', 'value': amp_init, 'min': min_value},
+ {'name': 'center', 'value': cen_init, 'min': min_value},
+ {'name': 'sigma', 'value': sig_init, 'min': min_value, 'max': sig_max}))
+ else:
+ if fit_type == 'uniform':
+ self.add_parameter(name='scale_factor', value=1.0)
+ for i in range(num_peaks):
+ height_init, cen_init, fwhm_init = self.guess_init_peak(self._x, self._y, i,
+ center_guess=centers)
+ if self._fwhm_max is not None and fwhm_init > self._fwhm_max:
+ fwhm_init = self._fwhm_max
+ ast(f'fwhm = {fwhm_init}')
+ ast(f'height = {height_init}')
+ sig_init = ast(fwhm_factor[peak_models[i]])
+ amp_init = ast(height_factor[peak_models[i]])
+ sig_max = None
+ if self._sigma_max is not None:
+ ast(f'fwhm = {self._fwhm_max}')
+ sig_max = ast(fwhm_factor[peak_models[i]])
+ self._sigma_max[i] = sig_max
+ if fit_type == 'uniform':
+ self.add_model(peak_models[i], prefix=f'peak{i+1}_', parameters=(
+ {'name': 'amplitude', 'value': amp_init, 'min': min_value},
+ {'name': 'center', 'expr': center_exprs[i]},
+ {'name': 'sigma', 'value': sig_init, 'min': min_value,
+ 'max': sig_max}))
+ else:
+                    self.add_model(peak_models[i], prefix=f'peak{i+1}_', parameters=(
+ {'name': 'amplitude', 'value': amp_init, 'min': min_value},
+ {'name': 'center', 'value': cen_init, 'min': min_value},
+ {'name': 'sigma', 'value': sig_init, 'min': min_value,
+ 'max': sig_max}))
+
+ def _check_validity(self):
+ """Check for valid fit parameter results
+ """
+ fit_failure = False
+ index = compile(r'\d+')
+ for name, par in self.best_parameters.items():
+ if 'background' in name:
+# if ((name == 'backgroundc' and par['value'] <= 0.0) or
+# (name.endswith('amplitude') and par['value'] <= 0.0) or
+ if ((name.endswith('amplitude') and par['value'] <= 0.0) or
+ (name.endswith('decay') and par['value'] <= 0.0)):
+ logging.info(f'Invalid fit result for {name} ({par["value"]})')
+ fit_failure = True
+ elif (((name.endswith('amplitude') or name.endswith('height')) and
+ par['value'] <= 0.0) or
+ ((name.endswith('sigma') or name.endswith('fwhm')) and par['value'] <= 0.0) or
+ (name.endswith('center') and par['value'] <= 0.0) or
+ (name == 'scale_factor' and par['value'] <= 0.0)):
+ logging.info(f'Invalid fit result for {name} ({par["value"]})')
+ fit_failure = True
+ if name.endswith('sigma') and self._sigma_max is not None:
+ if name == 'sigma':
+ sigma_max = self._sigma_max[0]
+ else:
+ sigma_max = self._sigma_max[int(index.search(name).group())-1]
+ if par['value'] > sigma_max:
+ logging.info(f'Invalid fit result for {name} ({par["value"]})')
+ fit_failure = True
+ elif par['value'] == sigma_max:
+                logging.warning(f'Edge result for {name} ({par["value"]})')
+ if name.endswith('fwhm') and self._fwhm_max is not None:
+ if par['value'] > self._fwhm_max:
+ logging.info(f'Invalid fit result for {name} ({par["value"]})')
+ fit_failure = True
+ elif par['value'] == self._fwhm_max:
+                logging.warning(f'Edge result for {name} ({par["value"]})')
+ return(fit_failure)
+
+
+class FitMap(Fit):
+ """Fit a map of data
+ """
+ def __init__(self, ymap, x=None, models=None, normalize=True, transpose=None, **kwargs):
+ super().__init__(None)
+ self._best_errors = None
+ self._best_fit = None
+ self._best_parameters = None
+ self._best_values = None
+ self._inv_transpose = None
+ self._max_nfev = None
+ self._memfolder = None
+ self._new_parameters = None
+ self._out_of_bounds = None
+ self._plot = False
+ self._print_report = False
+ self._redchi = None
+ self._redchi_cutoff = 0.1
+ self._skip_init = True
+ self._success = None
+ self._transpose = None
+ self._try_no_bounds = True
+
+ # At this point the fastest index should always be the signal dimension so that the slowest
+ # ndim-1 dimensions are the map dimensions
+ if isinstance(ymap, (tuple, list, np.ndarray)):
+ self._x = np.asarray(x)
+ elif have_xarray and isinstance(ymap, xr.DataArray):
+ if x is not None:
+                logging.warning(f'Ignoring superfluous input x ({x}) in FitMap.__init__')
+ self._x = np.asarray(ymap[ymap.dims[-1]])
+ else:
+ illegal_value(ymap, 'ymap', 'FitMap:__init__', raise_error=True)
+ self._ymap = ymap
+
+ # Verify the input parameters
+ if self._x.ndim != 1:
+ raise ValueError(f'Invalid dimension for input x {self._x.ndim}')
+ if self._ymap.ndim < 2:
+ raise ValueError('Invalid number of dimension of the input dataset '+
+ f'{self._ymap.ndim}')
+ if self._x.size != self._ymap.shape[-1]:
+ raise ValueError(f'Inconsistent x and y dimensions ({self._x.size} vs '+
+ f'{self._ymap.shape[-1]})')
+ if not isinstance(normalize, bool):
+            logging.warning(f'Invalid value for normalize ({normalize}) in FitMap.__init__: '+
+ 'setting normalize to True')
+ normalize = True
+ if isinstance(transpose, bool) and not transpose:
+ transpose = None
+ if transpose is not None and self._ymap.ndim < 3:
+ logging.warning(f'Transpose meaningless for {self._ymap.ndim-1}D data maps: ignoring '+
+ 'transpose')
+ if transpose is not None:
+ if self._ymap.ndim == 3 and isinstance(transpose, bool) and transpose:
+ self._transpose = (1, 0)
+ elif not isinstance(transpose, (tuple, list)):
+                logging.warning(f'Invalid data type for transpose ({transpose}, '+
+                        f'{type(transpose)}) in FitMap.__init__: setting transpose to False')
+            elif len(transpose) != self._ymap.ndim-1:
+                logging.warning(f'Invalid dimension for transpose ({transpose}, must be equal to '+
+                        f'{self._ymap.ndim-1}) in FitMap.__init__: setting transpose to False')
+            elif any(i not in transpose for i in range(len(transpose))):
+                logging.warning(f'Invalid index in transpose ({transpose}) '+
+                        'in FitMap.__init__: setting transpose to False')
+ elif not all(i==transpose[i] for i in range(self._ymap.ndim-1)):
+ self._transpose = transpose
+ if self._transpose is not None:
+ self._inv_transpose = tuple(self._transpose.index(i)
+ for i in range(len(self._transpose)))
+
+ # Flatten the map (transpose if requested)
+ # Store the flattened map in self._ymap_norm, whether normalized or not
+ if self._transpose is not None:
+ self._ymap_norm = np.transpose(np.asarray(self._ymap), list(self._transpose)+
+ [len(self._transpose)])
+ else:
+ self._ymap_norm = np.asarray(self._ymap)
+ self._map_dim = int(self._ymap_norm.size/self._x.size)
+ self._map_shape = self._ymap_norm.shape[:-1]
+ self._ymap_norm = np.reshape(self._ymap_norm, (self._map_dim, self._x.size))
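+        # For example (illustrative): an input ymap of shape (8, 10, 250) with
+        # x.size == 250 is flattened to an (80, 250) array of 80 independent
+        # 1D fits, with self._map_shape == (8, 10).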
+
+ # Check if a mask is provided
+ if 'mask' in kwargs:
+ self._mask = kwargs.pop('mask')
+ if self._mask is None:
+ ymap_min = float(self._ymap_norm.min())
+ ymap_max = float(self._ymap_norm.max())
+ else:
+ self._mask = np.asarray(self._mask).astype(bool)
+ if self._x.size != self._mask.size:
+ raise ValueError(f'Inconsistent mask dimension ({self._x.size} vs '+
+ f'{self._mask.size})')
+ ymap_masked = np.asarray(self._ymap_norm)[:,~self._mask]
+ ymap_min = float(ymap_masked.min())
+ ymap_max = float(ymap_masked.max())
+
+ # Normalize the data
+ self._y_range = ymap_max-ymap_min
+ if normalize and self._y_range > 0.0:
+ self._norm = (ymap_min, self._y_range)
+ self._ymap_norm = (self._ymap_norm-self._norm[0])/self._norm[1]
+ else:
+ self._redchi_cutoff *= self._y_range**2
+ if models is not None:
+ if callable(models) or isinstance(models, str):
+ kwargs = self.add_model(models, **kwargs)
+ elif isinstance(models, (tuple, list)):
+ for model in models:
+ kwargs = self.add_model(model, **kwargs)
+ self.fit(**kwargs)
+
+ @classmethod
+ def fit_map(cls, ymap, models, x=None, normalize=True, **kwargs):
+ return(cls(ymap, x=x, models=models, normalize=normalize, **kwargs))
+
+ @property
+ def best_errors(self):
+ return(self._best_errors)
+
+ @property
+ def best_fit(self):
+ return(self._best_fit)
+
+ @property
+ def best_results(self):
+ """Convert the input data array to a data set and add the fit results.
+ """
+ if self.best_values is None or self.best_errors is None or self.best_fit is None:
+ return(None)
+ if not have_xarray:
+ logging.warning('Unable to load xarray module')
+ return(None)
+ best_values = self.best_values
+ best_errors = self.best_errors
+ if isinstance(self._ymap, xr.DataArray):
+ best_results = self._ymap.to_dataset()
+ dims = self._ymap.dims
+ fit_name = f'{self._ymap.name}_fit'
+ else:
+ coords = {f'dim{n}_index':([f'dim{n}_index'], range(self._ymap.shape[n]))
+ for n in range(self._ymap.ndim-1)}
+ coords['x'] = (['x'], self._x)
+ dims = list(coords.keys())
+ best_results = xr.Dataset(coords=coords)
+ best_results['y'] = (dims, self._ymap)
+ fit_name = 'y_fit'
+ best_results[fit_name] = (dims, self.best_fit)
+ if self._mask is not None:
+ best_results['mask'] = self._mask
+ for n in range(best_values.shape[0]):
+ best_results[f'{self._best_parameters[n]}_values'] = (dims[:-1], best_values[n])
+ best_results[f'{self._best_parameters[n]}_errors'] = (dims[:-1], best_errors[n])
+ best_results.attrs['components'] = self.components
+ return(best_results)
+
+ @property
+ def best_values(self):
+ return(self._best_values)
+
+ @property
+ def chisqr(self):
+ logging.warning('property chisqr not defined for fit.FitMap')
+ return(None)
+
+ @property
+ def components(self):
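+        # Summarize the fit model components as {name: {'expr': ..., 'parameters':
+        # {...}}}; e.g. (illustrative) a Gaussian peak added with prefix 'peak1_'
+        # is keyed 'peak1 (gaussian)' and a constant background with prefix
+        # 'background' is keyed 'background (constant)'.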
+ components = {}
+ if self._result is None:
+ logging.warning('Unable to collect components in FitMap.components')
+ return(components)
+ for component in self._result.components:
+ if 'tmp_normalization_offset_c' in component.param_names:
+ continue
+ parameters = {}
+ for name in component.param_names:
+ if self._parameters[name].vary:
+ parameters[name] = {'free': True}
+ elif self._parameters[name].expr is not None:
+ parameters[name] = {'free': False, 'expr': self._parameters[name].expr}
+ else:
+ parameters[name] = {'free': False, 'value': self.init_parameters[name]['value']}
+ expr = None
+ if isinstance(component, ExpressionModel):
+ name = component._name
+ if name[-1] == '_':
+ name = name[:-1]
+ expr = component.expr
+ else:
+ prefix = component.prefix
+ if len(prefix):
+ if prefix[-1] == '_':
+ prefix = prefix[:-1]
+ name = f'{prefix} ({component._name})'
+ else:
+ name = f'{component._name}'
+ if expr is None:
+ components[name] = {'parameters': parameters}
+ else:
+ components[name] = {'expr': expr, 'parameters': parameters}
+ return(components)
+
+ @property
+ def covar(self):
+ logging.warning('property covar not defined for fit.FitMap')
+ return(None)
+
+ @property
+ def max_nfev(self):
+ return(self._max_nfev)
+
+ @property
+ def num_func_eval(self):
+ logging.warning('property num_func_eval not defined for fit.FitMap')
+ return(None)
+
+ @property
+ def out_of_bounds(self):
+ return(self._out_of_bounds)
+
+ @property
+ def redchi(self):
+ return(self._redchi)
+
+ @property
+ def residual(self):
+ if self.best_fit is None:
+ return(None)
+ if self._mask is None:
+ return(np.asarray(self._ymap)-self.best_fit)
+ else:
+ ymap_flat = np.reshape(np.asarray(self._ymap), (self._map_dim, self._x.size))
+ ymap_flat_masked = ymap_flat[:,~self._mask]
+ ymap_masked = np.reshape(ymap_flat_masked,
+ list(self._map_shape)+[ymap_flat_masked.shape[-1]])
+ return(ymap_masked-self.best_fit)
+
+ @property
+ def success(self):
+ return(self._success)
+
+ @property
+ def var_names(self):
+ logging.warning('property var_names not defined for fit.FitMap')
+ return(None)
+
+ @property
+ def y(self):
+ logging.warning('property y not defined for fit.FitMap')
+ return(None)
+
+ @property
+ def ymap(self):
+ return(self._ymap)
+
+ def best_parameters(self, dims=None):
+ if dims is None:
+ return(self._best_parameters)
+ if not isinstance(dims, (list, tuple)) or len(dims) != len(self._map_shape):
+ illegal_value(dims, 'dims', 'FitMap.best_parameters', raise_error=True)
+ if self.best_values is None or self.best_errors is None:
+ logging.warning(f'Unable to obtain best parameter values for dims = {dims} in '+
+ 'FitMap.best_parameters')
+ return({})
+ # Create current parameters
+ parameters = deepcopy(self._parameters)
+ for n, name in enumerate(self._best_parameters):
+ if self._parameters[name].vary:
+ parameters[name].set(value=self.best_values[n][dims])
+ parameters[name].stderr = self.best_errors[n][dims]
+ parameters_dict = {}
+ for name in sorted(parameters):
+ if name != 'tmp_normalization_offset_c':
+ par = parameters[name]
+ parameters_dict[name] = {'value': par.value, 'error': par.stderr,
+ 'init_value': self.init_parameters[name]['value'], 'min': par.min,
+ 'max': par.max, 'vary': par.vary, 'expr': par.expr}
+ return(parameters_dict)
+
+ def freemem(self):
+ if self._memfolder is None:
+ return
+ try:
+ rmtree(self._memfolder)
+ self._memfolder = None
+ except:
+            logging.warning('Could not clean up the joblib memmap folder automatically.')
+
+ def plot(self, dims, y_title=None, plot_residual=False, plot_comp_legends=False,
+ plot_masked_data=True):
+ if not isinstance(dims, (list, tuple)) or len(dims) != len(self._map_shape):
+ illegal_value(dims, 'dims', 'FitMap.plot', raise_error=True)
+ if self._result is None or self.best_fit is None or self.best_values is None:
+ logging.warning(f'Unable to plot fit for dims = {dims} in FitMap.plot')
+ return
+ if y_title is None or not isinstance(y_title, str):
+ y_title = 'data'
+ if self._mask is None:
+ mask = np.zeros(self._x.size).astype(bool)
+ plot_masked_data = False
+ else:
+ mask = self._mask
+ plots = [(self._x, np.asarray(self._ymap[dims]), 'b.')]
+ legend = [y_title]
+ if plot_masked_data:
+ plots += [(self._x[mask], np.asarray(self._ymap)[(*dims,mask)], 'bx')]
+ legend += ['masked data']
+ plots += [(self._x[~mask], self.best_fit[dims], 'k-')]
+ legend += ['best fit']
+ if plot_residual:
+ plots += [(self._x[~mask], self.residual[dims], 'k--')]
+ legend += ['residual']
+ # Create current parameters
+ parameters = deepcopy(self._parameters)
+ for name in self._best_parameters:
+ if self._parameters[name].vary:
+ parameters[name].set(value=
+ self.best_values[self._best_parameters.index(name)][dims])
+ for component in self._result.components:
+ if 'tmp_normalization_offset_c' in component.param_names:
+ continue
+ if isinstance(component, ExpressionModel):
+ prefix = component._name
+ if prefix[-1] == '_':
+ prefix = prefix[:-1]
+ modelname = f'{prefix}: {component.expr}'
+ else:
+ prefix = component.prefix
+ if len(prefix):
+ if prefix[-1] == '_':
+ prefix = prefix[:-1]
+ modelname = f'{prefix} ({component._name})'
+ else:
+ modelname = f'{component._name}'
+ if len(modelname) > 20:
+ modelname = f'{modelname[0:16]} ...'
+ y = component.eval(params=parameters, x=self._x[~mask])
+ if isinstance(y, (int, float)):
+ y *= np.ones(self._x[~mask].size)
+ plots += [(self._x[~mask], y, '--')]
+ if plot_comp_legends:
+ legend.append(modelname)
+ quick_plot(tuple(plots), legend=legend, title=str(dims), block=True)
+
+ def fit(self, **kwargs):
+# t0 = time()
+ # Check input parameters
+        if self._model is None:
+            logging.error('Undefined fit model')
+            return
+ if 'num_proc' in kwargs:
+ num_proc = kwargs.pop('num_proc')
+ if not is_int(num_proc, ge=1):
+ illegal_value(num_proc, 'num_proc', 'FitMap.fit', raise_error=True)
+ else:
+ num_proc = cpu_count()
+ if num_proc > 1 and not have_joblib:
+            logging.warning('Missing joblib in the conda environment, running FitMap serially')
+ num_proc = 1
+ if num_proc > cpu_count():
+ logging.warning(f'The requested number of processors ({num_proc}) exceeds the maximum '+
+ f'number of processors, num_proc reduced to ({cpu_count()})')
+ num_proc = cpu_count()
+ if 'try_no_bounds' in kwargs:
+ self._try_no_bounds = kwargs.pop('try_no_bounds')
+ if not isinstance(self._try_no_bounds, bool):
+ illegal_value(self._try_no_bounds, 'try_no_bounds', 'FitMap.fit', raise_error=True)
+ if 'redchi_cutoff' in kwargs:
+ self._redchi_cutoff = kwargs.pop('redchi_cutoff')
+            # Note: is_num is not among the names imported from general, so test inline
+            if not isinstance(self._redchi_cutoff, (int, float)) or self._redchi_cutoff <= 0:
+ illegal_value(self._redchi_cutoff, 'redchi_cutoff', 'FitMap.fit', raise_error=True)
+ if 'print_report' in kwargs:
+ self._print_report = kwargs.pop('print_report')
+ if not isinstance(self._print_report, bool):
+ illegal_value(self._print_report, 'print_report', 'FitMap.fit', raise_error=True)
+ if 'plot' in kwargs:
+ self._plot = kwargs.pop('plot')
+ if not isinstance(self._plot, bool):
+ illegal_value(self._plot, 'plot', 'FitMap.fit', raise_error=True)
+ if 'skip_init' in kwargs:
+ self._skip_init = kwargs.pop('skip_init')
+ if not isinstance(self._skip_init, bool):
+ illegal_value(self._skip_init, 'skip_init', 'FitMap.fit', raise_error=True)
+
+ # Apply mask if supplied:
+ if 'mask' in kwargs:
+ self._mask = kwargs.pop('mask')
+ if self._mask is not None:
+ self._mask = np.asarray(self._mask).astype(bool)
+ if self._x.size != self._mask.size:
+ raise ValueError(f'Inconsistent x and mask dimensions ({self._x.size} vs '+
+ f'{self._mask.size})')
+
+ # Add constant offset for a normalized single component model
+ if self._result is None and self._norm is not None and self._norm[0]:
+ self.add_model('constant', prefix='tmp_normalization_offset_', parameters={'name': 'c',
+ 'value': -self._norm[0], 'vary': False, 'norm': True})
+ #'value': -self._norm[0]/self._norm[1], 'vary': False, 'norm': False})
+
+ # Adjust existing parameters for refit:
+ if 'parameters' in kwargs:
+# print('\nIn FitMap before adjusting existing parameters for refit:')
+# self._parameters.pretty_print()
+# if self._result is None:
+# raise ValueError('Invalid parameter parameters ({parameters})')
+# if self._best_values is None:
+# raise ValueError('Valid self._best_values required for refitting in FitMap.fit')
+ parameters = kwargs.pop('parameters')
+# print(f'\nparameters:\n{parameters}')
+ if isinstance(parameters, dict):
+ parameters = (parameters, )
+ elif not is_dict_series(parameters):
+ illegal_value(parameters, 'parameters', 'Fit.fit', raise_error=True)
+ for par in parameters:
+ name = par['name']
+ if name not in self._parameters:
+ raise ValueError(f'Unable to match {name} parameter {par} to an existing one')
+ if self._parameters[name].expr is not None:
+ raise ValueError(f'Unable to modify {name} parameter {par} (currently an '+
+ 'expression)')
+ value = par.get('value')
+ vary = par.get('vary')
+ if par.get('expr') is not None:
+ raise KeyError(f'Illegal "expr" key in {name} parameter {par}')
+ self._parameters[name].set(value=value, vary=vary, min=par.get('min'),
+ max=par.get('max'))
+ # Overwrite existing best values for fixed parameters when a value is specified
+# print(f'best values befored resetting:\n{self._best_values}')
+ if isinstance(value, (int, float)) and vary is False:
+ for i, nname in enumerate(self._best_parameters):
+ if nname == name:
+ self._best_values[i] = value
+# print(f'best values after resetting (value={value}, vary={vary}):\n{self._best_values}')
+#RV print('\nIn FitMap after adjusting existing parameters for refit:')
+#RV self._parameters.pretty_print()
+
+ # Check for uninitialized parameters
+ for name, par in self._parameters.items():
+ if par.expr is None:
+ value = par.value
+ if value is None or np.isinf(value) or np.isnan(value):
+ value = 1.0
+ if self._norm is None or name not in self._parameter_norms:
+ self._parameters[name].set(value=value)
+ elif self._parameter_norms[name]:
+ self._parameters[name].set(value=value*self._norm[1])
+
+ # Create the best parameter list, consisting of all varying parameters plus the expression
+ # parameters in order to collect their errors
+ if self._result is None:
+ # Initial fit
+ assert(self._best_parameters is None)
+ self._best_parameters = [name for name, par in self._parameters.items()
+ if par.vary or par.expr is not None]
+ num_new_parameters = 0
+ else:
+ # Refit
+ assert(len(self._best_parameters))
+ self._new_parameters = [name for name, par in self._parameters.items()
+ if name != 'tmp_normalization_offset_c' and name not in self._best_parameters and
+ (par.vary or par.expr is not None)]
+ num_new_parameters = len(self._new_parameters)
+ num_best_parameters = len(self._best_parameters)
+
+ # Flatten and normalize the best values of the previous fit, remove the remaining results
+ # of the previous fit
+ if self._result is not None:
+# print('\nBefore flatten and normalize:')
+# print(f'self._best_values:\n{self._best_values}')
+ self._out_of_bounds = None
+ self._max_nfev = None
+ self._redchi = None
+ self._success = None
+ self._best_fit = None
+ self._best_errors = None
+ assert(self._best_values is not None)
+ assert(self._best_values.shape[0] == num_best_parameters)
+ assert(self._best_values.shape[1:] == self._map_shape)
+ if self._transpose is not None:
+ self._best_values = np.transpose(self._best_values,
+ [0]+[i+1 for i in self._transpose])
+ self._best_values = [np.reshape(self._best_values[i], self._map_dim)
+ for i in range(num_best_parameters)]
+ if self._norm is not None:
+ for i, name in enumerate(self._best_parameters):
+ if self._parameter_norms.get(name, False):
+ self._best_values[i] /= self._norm[1]
+#RV print('\nAfter flatten and normalize:')
+#RV print(f'self._best_values:\n{self._best_values}')
+
+ # Normalize the initial parameters (and best values for a refit)
+# print('\nIn FitMap before normalize:')
+# self._parameters.pretty_print()
+# print(f'\nparameter_norms:\n{self._parameter_norms}\n')
+ self._normalize()
+# print('\nIn FitMap after normalize:')
+# self._parameters.pretty_print()
+# print(f'\nparameter_norms:\n{self._parameter_norms}\n')
+
+ # Prevent initial values from sitting at boundaries
+ self._parameter_bounds = {name:{'min': par.min, 'max': par.max}
+ for name, par in self._parameters.items() if par.vary}
+ for name, par in self._parameters.items():
+ if par.vary:
+ par.set(value=self._reset_par_at_boundary(par, par.value))
+# print('\nAfter checking boundaries:')
+# self._parameters.pretty_print()
+
+ # Set parameter bounds to unbound (only use bounds when fit fails)
+ if self._try_no_bounds:
+ for name in self._parameter_bounds.keys():
+ self._parameters[name].set(min=-np.inf, max=np.inf)
+
+ # Allocate memory to store fit results
+ if self._mask is None:
+ x_size = self._x.size
+ else:
+ x_size = self._x[~self._mask].size
+ if num_proc == 1:
+ self._out_of_bounds_flat = np.zeros(self._map_dim, dtype=bool)
+ self._max_nfev_flat = np.zeros(self._map_dim, dtype=bool)
+ self._redchi_flat = np.zeros(self._map_dim, dtype=np.float64)
+ self._success_flat = np.zeros(self._map_dim, dtype=bool)
+ self._best_fit_flat = np.zeros((self._map_dim, x_size),
+ dtype=self._ymap_norm.dtype)
+ self._best_errors_flat = [np.zeros(self._map_dim, dtype=np.float64)
+ for _ in range(num_best_parameters+num_new_parameters)]
+ if self._result is None:
+ self._best_values_flat = [np.zeros(self._map_dim, dtype=np.float64)
+ for _ in range(num_best_parameters)]
+ else:
+ self._best_values_flat = self._best_values
+ self._best_values_flat += [np.zeros(self._map_dim, dtype=np.float64)
+ for _ in range(num_new_parameters)]
+ else:
+ self._memfolder = './joblib_memmap'
+ try:
+ mkdir(self._memfolder)
+ except FileExistsError:
+ pass
+ filename_memmap = path.join(self._memfolder, 'out_of_bounds_memmap')
+ self._out_of_bounds_flat = np.memmap(filename_memmap, dtype=bool,
+ shape=(self._map_dim), mode='w+')
+ filename_memmap = path.join(self._memfolder, 'max_nfev_memmap')
+ self._max_nfev_flat = np.memmap(filename_memmap, dtype=bool,
+ shape=(self._map_dim), mode='w+')
+ filename_memmap = path.join(self._memfolder, 'redchi_memmap')
+ self._redchi_flat = np.memmap(filename_memmap, dtype=np.float64,
+ shape=(self._map_dim), mode='w+')
+ filename_memmap = path.join(self._memfolder, 'success_memmap')
+ self._success_flat = np.memmap(filename_memmap, dtype=bool,
+ shape=(self._map_dim), mode='w+')
+ filename_memmap = path.join(self._memfolder, 'best_fit_memmap')
+ self._best_fit_flat = np.memmap(filename_memmap, dtype=self._ymap_norm.dtype,
+ shape=(self._map_dim, x_size), mode='w+')
+ self._best_errors_flat = []
+ for i in range(num_best_parameters+num_new_parameters):
+ filename_memmap = path.join(self._memfolder, f'best_errors_memmap_{i}')
+ self._best_errors_flat.append(np.memmap(filename_memmap, dtype=np.float64,
+ shape=self._map_dim, mode='w+'))
+ self._best_values_flat = []
+ for i in range(num_best_parameters):
+ filename_memmap = path.join(self._memfolder, f'best_values_memmap_{i}')
+ self._best_values_flat.append(np.memmap(filename_memmap, dtype=np.float64,
+ shape=self._map_dim, mode='w+'))
+ if self._result is not None:
+ self._best_values_flat[i][:] = self._best_values[i][:]
+ for i in range(num_new_parameters):
+ filename_memmap = path.join(self._memfolder,
+ f'best_values_memmap_{i+num_best_parameters}')
+ self._best_values_flat.append(np.memmap(filename_memmap, dtype=np.float64,
+ shape=self._map_dim, mode='w+'))
+
+ # Update the best parameter list
+ if num_new_parameters:
+ self._best_parameters += self._new_parameters
+
+ # Perform the first fit to get model component info and initial parameters
+ current_best_values = {}
+# print(f'0 before:\n{current_best_values}')
+# t1 = time()
+ self._result = self._fit(0, current_best_values, return_result=True, **kwargs)
+# t2 = time()
+# print(f'0 after:\n{current_best_values}')
+# print('\nAfter the first fit:')
+# self._parameters.pretty_print()
+# print(self._result.fit_report(show_correl=False))
+
+ # Remove all irrelevant content from self._result
+        for attr in ('_abort', 'aborted', 'aic', 'best_fit', 'best_values', 'bic', 'calc_covar',
+                'call_kws', 'chisqr', 'ci_out', 'col_deriv', 'covar', 'data', 'errorbars',
+                'flatchain', 'ier', 'init_vals', 'init_fit', 'iter_cb', 'jacfcn', 'kws',
+                'last_internal_values', 'lmdif_message', 'message', 'method', 'nan_policy',
+                'ndata', 'nfev', 'nfree', 'params', 'redchi', 'reduce_fcn', 'residual', 'result',
+                'scale_covar', 'show_candidates', 'success', 'userargs', 'userfcn',
+                'userkws', 'values', 'var_names', 'weights', 'user_options'):
+ try:
+ delattr(self._result, attr)
+ except AttributeError:
+#            logging.warning(f'Unknown attribute {attr} in fit.FitMap._cleanup_result')
+ pass
+
+# t3 = time()
+ if num_proc == 1:
+ # Perform the remaining fits serially
+ for n in range(1, self._map_dim):
+# print(f'{n} before:\n{current_best_values}')
+ self._fit(n, current_best_values, **kwargs)
+# print(f'{n} after:\n{current_best_values}')
+ else:
+ # Perform the remaining fits in parallel
+ num_fit = self._map_dim-1
+# print(f'num_fit = {num_fit}')
+ if num_proc > num_fit:
+ logging.warning(f'The requested number of processors ({num_proc}) exceeds the '+
+ f'number of fits, num_proc reduced to ({num_fit})')
+ num_proc = num_fit
+ num_fit_per_proc = 1
+ else:
+                num_fit_per_proc = round(num_fit/num_proc)
+                if num_proc*num_fit_per_proc < num_fit:
+                    num_fit_per_proc += 1
+# print(f'num_fit_per_proc = {num_fit_per_proc}')
+ num_fit_batch = min(num_fit_per_proc, 40)
+# print(f'num_fit_batch = {num_fit_batch}')
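+            # Dispatch the remaining map points to the joblib workers in batches
+            # of at most 40 fits each, presumably to bound per-task overhead while
+            # keeping the shared current_best_values reasonably up to date.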
+ with Parallel(n_jobs=num_proc) as parallel:
+ parallel(delayed(self._fit_parallel)(current_best_values, num_fit_batch,
+ n_start, **kwargs) for n_start in range(1, self._map_dim, num_fit_batch))
+# t4 = time()
+
+ # Renormalize the initial parameters for external use
+ if self._norm is not None and self._normalized:
+ init_values = {}
+ for name, value in self._result.init_values.items():
+ if name not in self._parameter_norms or self._parameters[name].expr is not None:
+ init_values[name] = value
+ elif self._parameter_norms[name]:
+ init_values[name] = value*self._norm[1]
+ self._result.init_values = init_values
+ for name, par in self._result.init_params.items():
+ if par.expr is None and self._parameter_norms.get(name, False):
+ _min = par.min
+ _max = par.max
+ value = par.value*self._norm[1]
+ if not np.isinf(_min) and abs(_min) != float_min:
+ _min *= self._norm[1]
+ if not np.isinf(_max) and abs(_max) != float_min:
+ _max *= self._norm[1]
+ par.set(value=value, min=_min, max=_max)
+ par.init_value = par.value
+
+ # Remap the best results
+# t5 = time()
+ self._out_of_bounds = np.copy(np.reshape(self._out_of_bounds_flat, self._map_shape))
+ self._max_nfev = np.copy(np.reshape(self._max_nfev_flat, self._map_shape))
+ self._redchi = np.copy(np.reshape(self._redchi_flat, self._map_shape))
+ self._success = np.copy(np.reshape(self._success_flat, self._map_shape))
+ self._best_fit = np.copy(np.reshape(self._best_fit_flat,
+ list(self._map_shape)+[x_size]))
+ self._best_values = np.asarray([np.reshape(par, list(self._map_shape))
+ for par in self._best_values_flat])
+ self._best_errors = np.asarray([np.reshape(par, list(self._map_shape))
+ for par in self._best_errors_flat])
+ if self._inv_transpose is not None:
+ self._out_of_bounds = np.transpose(self._out_of_bounds, self._inv_transpose)
+ self._max_nfev = np.transpose(self._max_nfev, self._inv_transpose)
+ self._redchi = np.transpose(self._redchi, self._inv_transpose)
+ self._success = np.transpose(self._success, self._inv_transpose)
+ self._best_fit = np.transpose(self._best_fit,
+ list(self._inv_transpose)+[len(self._inv_transpose)])
+ self._best_values = np.transpose(self._best_values,
+ [0]+[i+1 for i in self._inv_transpose])
+ self._best_errors = np.transpose(self._best_errors,
+ [0]+[i+1 for i in self._inv_transpose])
+ del self._out_of_bounds_flat
+ del self._max_nfev_flat
+ del self._redchi_flat
+ del self._success_flat
+ del self._best_fit_flat
+ del self._best_values_flat
+ del self._best_errors_flat
+# t6 = time()
+
+ # Restore parameter bounds and renormalize the parameters
+ for name, par in self._parameter_bounds.items():
+ self._parameters[name].set(min=par['min'], max=par['max'])
+ self._normalized = False
+ if self._norm is not None:
+ for name, norm in self._parameter_norms.items():
+ par = self._parameters[name]
+ if par.expr is None and norm:
+ value = par.value*self._norm[1]
+ _min = par.min
+ _max = par.max
+ if not np.isinf(_min) and abs(_min) != float_min:
+ _min *= self._norm[1]
+ if not np.isinf(_max) and abs(_max) != float_min:
+ _max *= self._norm[1]
+ par.set(value=value, min=_min, max=_max)
+# t7 = time()
+# print(f'total run time in fit: {t7-t0:.2f} seconds')
+# print(f'run time first fit: {t2-t1:.2f} seconds')
+# print(f'run time remaining fits: {t4-t3:.2f} seconds')
+# print(f'run time remapping results: {t6-t5:.2f} seconds')
+
+# print('\n\nAt end fit:')
+# self._parameters.pretty_print()
+# print(f'self._best_values:\n{self._best_values}\n\n')
+
+ # Free the shared memory
+ self.freemem()
+
+ def _fit_parallel(self, current_best_values, num, n_start, **kwargs):
+ num = min(num, self._map_dim-n_start)
+ for n in range(num):
+# print(f'{n_start+n} before:\n{current_best_values}')
+ self._fit(n_start+n, current_best_values, **kwargs)
+# print(f'{n_start+n} after:\n{current_best_values}')
+
+ def _fit(self, n, current_best_values, return_result=False, **kwargs):
+#RV print(f'\n\nstart FitMap._fit {n}\n')
+#RV print(f'current_best_values = {current_best_values}')
+#RV print(f'self._best_parameters = {self._best_parameters}')
+#RV print(f'self._new_parameters = {self._new_parameters}\n\n')
+# self._parameters.pretty_print()
+ # Set parameters to current best values, but prevent them from sitting at boundaries
+ if self._new_parameters is None:
+ # Initial fit
+ for name, value in current_best_values.items():
+ par = self._parameters[name]
+ par.set(value=self._reset_par_at_boundary(par, value))
+ else:
+ # Refit
+ for i, name in enumerate(self._best_parameters):
+ par = self._parameters[name]
+ if name in self._new_parameters:
+ if name in current_best_values:
+ par.set(value=self._reset_par_at_boundary(par, current_best_values[name]))
+ elif par.expr is None:
+ par.set(value=self._best_values[i][n])
+#RV print(f'\nbefore fit {n}')
+#RV self._parameters.pretty_print()
+ if self._mask is None:
+ result = self._model.fit(self._ymap_norm[n], self._parameters, x=self._x, **kwargs)
+ else:
+ result = self._model.fit(self._ymap_norm[n][~self._mask], self._parameters,
+ x=self._x[~self._mask], **kwargs)
+# print(f'\nafter fit {n}')
+# self._parameters.pretty_print()
+# print(result.fit_report(show_correl=False))
+ out_of_bounds = False
+ for name, par in self._parameter_bounds.items():
+ value = result.params[name].value
+ if not np.isinf(par['min']) and value < par['min']:
+ out_of_bounds = True
+ break
+ if not np.isinf(par['max']) and value > par['max']:
+ out_of_bounds = True
+ break
+ self._out_of_bounds_flat[n] = out_of_bounds
+ if self._try_no_bounds and out_of_bounds:
+ # Rerun fit with parameter bounds in place
+ for name, par in self._parameter_bounds.items():
+ self._parameters[name].set(min=par['min'], max=par['max'])
+ # Set parameters to current best values, but prevent them from sitting at boundaries
+ if self._new_parameters is None:
+ # Initial fit
+ for name, value in current_best_values.items():
+ par = self._parameters[name]
+ par.set(value=self._reset_par_at_boundary(par, value))
+ else:
+ # Refit
+ for i, name in enumerate(self._best_parameters):
+ par = self._parameters[name]
+ if name in self._new_parameters:
+ if name in current_best_values:
+ par.set(value=self._reset_par_at_boundary(par,
+ current_best_values[name]))
+ elif par.expr is None:
+ par.set(value=self._best_values[i][n])
+# print('\nbefore fit')
+# self._parameters.pretty_print()
+# print(result.fit_report(show_correl=False))
+ if self._mask is None:
+ result = self._model.fit(self._ymap_norm[n], self._parameters, x=self._x, **kwargs)
+ else:
+ result = self._model.fit(self._ymap_norm[n][~self._mask], self._parameters,
+ x=self._x[~self._mask], **kwargs)
+# print(f'\nafter fit {n}')
+# self._parameters.pretty_print()
+# print(result.fit_report(show_correl=False))
+ out_of_bounds = False
+ for name, par in self._parameter_bounds.items():
+ value = result.params[name].value
+ if not np.isinf(par['min']) and value < par['min']:
+ out_of_bounds = True
+ break
+ if not np.isinf(par['max']) and value > par['max']:
+ out_of_bounds = True
+ break
+# print(f'{n} redchi < redchi_cutoff = {result.redchi < self._redchi_cutoff} success = {result.success} out_of_bounds = {out_of_bounds}')
+ # Reset parameters back to unbound
+ for name in self._parameter_bounds.keys():
+ self._parameters[name].set(min=-np.inf, max=np.inf)
+ assert(not out_of_bounds)
+ if result.redchi >= self._redchi_cutoff:
+ result.success = False
+ if result.nfev == result.max_nfev:
+# print(f'Maximum number of function evaluations reached for n = {n}')
+# logging.warning(f'Maximum number of function evaluations reached for n = {n}')
+ if result.redchi < self._redchi_cutoff:
+ result.success = True
+ self._max_nfev_flat[n] = True
+ if result.success:
+            assert(all(par in result.params for par in current_best_values))
+ for par in result.params.values():
+ if par.vary:
+ current_best_values[par.name] = par.value
+ else:
+ logging.warning(f'Fit for n = {n} failed: {result.lmdif_message}')
+ # Renormalize the data and results
+ self._renormalize(n, result)
+ if self._print_report:
+ print(result.fit_report(show_correl=False))
+ if self._plot:
+ dims = np.unravel_index(n, self._map_shape)
+ if self._inv_transpose is not None:
+                dims = tuple(dims[self._inv_transpose[i]] for i in range(len(dims)))
+ super().plot(result=result, y=np.asarray(self._ymap[dims]), plot_comp_legends=True,
+ skip_init=self._skip_init, title=str(dims))
+#RV print(f'\n\nend FitMap._fit {n}\n')
+#RV print(f'current_best_values = {current_best_values}')
+# self._parameters.pretty_print()
+# print(result.fit_report(show_correl=False))
+#RV print(f'\nself._best_values_flat:\n{self._best_values_flat}\n\n')
+ if return_result:
+ return(result)
+
+ def _renormalize(self, n, result):
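+        # Store the per-point fit results and, when the data were normalized,
+        # map the fitted values, errors, and bounds back to physical units
+        # using the stored (offset, range) normalization in self._norm.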
+ self._redchi_flat[n] = np.float64(result.redchi)
+ self._success_flat[n] = result.success
+ if self._norm is None or not self._normalized:
+ self._best_fit_flat[n] = result.best_fit
+ for i, name in enumerate(self._best_parameters):
+ self._best_values_flat[i][n] = np.float64(result.params[name].value)
+ self._best_errors_flat[i][n] = np.float64(result.params[name].stderr)
+ else:
+ pars = set(self._parameter_norms) & set(self._best_parameters)
+ for name, par in result.params.items():
+ if name in pars and self._parameter_norms[name]:
+ if par.stderr is not None:
+ par.stderr *= self._norm[1]
+ if par.expr is None:
+ par.value *= self._norm[1]
+ if self._print_report:
+ if par.init_value is not None:
+ par.init_value *= self._norm[1]
+ if not np.isinf(par.min) and abs(par.min) != float_min:
+ par.min *= self._norm[1]
+ if not np.isinf(par.max) and abs(par.max) != float_min:
+ par.max *= self._norm[1]
+ self._best_fit_flat[n] = result.best_fit*self._norm[1]+self._norm[0]
+ for i, name in enumerate(self._best_parameters):
+ self._best_values_flat[i][n] = np.float64(result.params[name].value)
+ self._best_errors_flat[i][n] = np.float64(result.params[name].stderr)
+ if self._plot:
+ if not self._skip_init:
+ result.init_fit = result.init_fit*self._norm[1]+self._norm[0]
+ result.best_fit = np.copy(self._best_fit_flat[n])
diff -r 000000000000 -r 98e23dff1de2 general.py
--- /dev/null Thu Jan 01 00:00:00 1970 +0000
+++ b/general.py Tue Mar 21 16:22:42 2023 +0000
@@ -0,0 +1,1965 @@
+#!/usr/bin/env python3
+
+#FIX write a function that returns a list of peak indices for a given plot
+#FIX use raise_error concept on more functions to optionally raise an error
+
+# -*- coding: utf-8 -*-
+"""
+Created on Mon Dec 6 15:36:22 2021
+
+@author: rv43
+"""
+
+import logging
+logger = logging.getLogger(__name__)
+
+import os
+import sys
+import re
+try:
+ from yaml import safe_load, safe_dump
+except:
+ pass
+try:
+ import h5py
+except:
+ pass
+import numpy as np
+try:
+ import matplotlib.pyplot as plt
+ import matplotlib.lines as mlines
+ from matplotlib import transforms
+ from matplotlib.widgets import Button
+except:
+ pass
+
+from ast import literal_eval
+try:
+ from asteval import Interpreter, get_ast_names
+except:
+ pass
+from copy import deepcopy
+try:
+ from sympy import diff, simplify
+except:
+ pass
+from time import time
+
+
+def depth_list(L): return(isinstance(L, list) and max(map(depth_list, L))+1)
+def depth_tuple(T): return(isinstance(T, tuple) and max(map(depth_tuple, T))+1)
+def unwrap_tuple(T):
+ if depth_tuple(T) > 1 and len(T) == 1:
+ T = unwrap_tuple(*T)
+ return(T)
+
+def illegal_value(value, name, location=None, raise_error=False, log=True):
+ if not isinstance(location, str):
+ location = ''
+ else:
+ location = f'in {location} '
+ if isinstance(name, str):
+ error_msg = f'Illegal value for {name} {location}({value}, {type(value)})'
+ else:
+ error_msg = f'Illegal value {location}({value}, {type(value)})'
+ if log:
+ logger.error(error_msg)
+ if raise_error:
+ raise ValueError(error_msg)
+
+def illegal_combination(value1, name1, value2, name2, location=None, raise_error=False,
+ log=True):
+ if not isinstance(location, str):
+ location = ''
+ else:
+ location = f'in {location} '
+ if isinstance(name1, str):
+ error_msg = f'Illegal combination for {name1} and {name2} {location}'+ \
+ f'({value1}, {type(value1)} and {value2}, {type(value2)})'
+ else:
+ error_msg = f'Illegal combination {location}'+ \
+ f'({value1}, {type(value1)} and {value2}, {type(value2)})'
+ if log:
+ logger.error(error_msg)
+ if raise_error:
+ raise ValueError(error_msg)
+
+def test_ge_gt_le_lt(ge, gt, le, lt, func, location=None, raise_error=False, log=True):
+ """Check individual and mutual validity of ge, gt, le, lt qualifiers
+ func: is_int or is_num to test for int or numbers
+    Return: True upon success or False when the qualifiers are invalid or mutually exclusive
+ """
+ if ge is None and gt is None and le is None and lt is None:
+ return(True)
+ if ge is not None:
+ if not func(ge):
+ illegal_value(ge, 'ge', location, raise_error, log)
+ return(False)
+ if gt is not None:
+ illegal_combination(ge, 'ge', gt, 'gt', location, raise_error, log)
+ return(False)
+ elif gt is not None and not func(gt):
+ illegal_value(gt, 'gt', location, raise_error, log)
+ return(False)
+ if le is not None:
+ if not func(le):
+ illegal_value(le, 'le', location, raise_error, log)
+ return(False)
+ if lt is not None:
+ illegal_combination(le, 'le', lt, 'lt', location, raise_error, log)
+ return(False)
+ elif lt is not None and not func(lt):
+ illegal_value(lt, 'lt', location, raise_error, log)
+ return(False)
+ if ge is not None:
+ if le is not None and ge > le:
+ illegal_combination(ge, 'ge', le, 'le', location, raise_error, log)
+ return(False)
+ elif lt is not None and ge >= lt:
+ illegal_combination(ge, 'ge', lt, 'lt', location, raise_error, log)
+ return(False)
+ elif gt is not None:
+ if le is not None and gt >= le:
+ illegal_combination(gt, 'gt', le, 'le', location, raise_error, log)
+ return(False)
+ elif lt is not None and gt >= lt:
+ illegal_combination(gt, 'gt', lt, 'lt', location, raise_error, log)
+ return(False)
+ return(True)
+
+def range_string_ge_gt_le_lt(ge=None, gt=None, le=None, lt=None):
+ """Return a range string representation matching the ge, gt, le, lt qualifiers
+ Does not validate the inputs, do that as needed before calling
+ """
+ range_string = ''
+ if ge is not None:
+ if le is None and lt is None:
+ range_string += f'>= {ge}'
+ else:
+ range_string += f'[{ge}, '
+ elif gt is not None:
+ if le is None and lt is None:
+ range_string += f'> {gt}'
+ else:
+ range_string += f'({gt}, '
+ if le is not None:
+ if ge is None and gt is None:
+ range_string += f'<= {le}'
+ else:
+ range_string += f'{le}]'
+ elif lt is not None:
+ if ge is None and gt is None:
+ range_string += f'< {lt}'
+ else:
+ range_string += f'{lt})'
+ return(range_string)
+
+def is_int(v, ge=None, gt=None, le=None, lt=None, raise_error=False, log=True):
+ """Value is an integer in range ge <= v <= le or gt < v < lt or some combination.
+    Return: True if yes or False if no
+ """
+ return(_is_int_or_num(v, 'int', ge, gt, le, lt, raise_error, log))
+
+def is_num(v, ge=None, gt=None, le=None, lt=None, raise_error=False, log=True):
+ """Value is a number in range ge <= v <= le or gt < v < lt or some combination.
+    Return: True if yes or False if no
+ """
+ return(_is_int_or_num(v, 'num', ge, gt, le, lt, raise_error, log))
+
+def _is_int_or_num(v, type_str, ge=None, gt=None, le=None, lt=None, raise_error=False,
+ log=True):
+ if type_str == 'int':
+ if not isinstance(v, int):
+ illegal_value(v, 'v', '_is_int_or_num', raise_error, log)
+ return(False)
+ if not test_ge_gt_le_lt(ge, gt, le, lt, is_int, '_is_int_or_num', raise_error, log):
+ return(False)
+ elif type_str == 'num':
+ if not isinstance(v, (int, float)):
+ illegal_value(v, 'v', '_is_int_or_num', raise_error, log)
+ return(False)
+ if not test_ge_gt_le_lt(ge, gt, le, lt, is_num, '_is_int_or_num', raise_error, log):
+ return(False)
+ else:
+ illegal_value(type_str, 'type_str', '_is_int_or_num', raise_error, log)
+ return(False)
+ if ge is None and gt is None and le is None and lt is None:
+ return(True)
+ error = False
+ if ge is not None and v < ge:
+ error = True
+ error_msg = f'Value {v} out of range: {v} !>= {ge}'
+ if not error and gt is not None and v <= gt:
+ error = True
+ error_msg = f'Value {v} out of range: {v} !> {gt}'
+ if not error and le is not None and v > le:
+ error = True
+ error_msg = f'Value {v} out of range: {v} !<= {le}'
+ if not error and lt is not None and v >= lt:
+ error = True
+ error_msg = f'Value {v} out of range: {v} !< {lt}'
+ if error:
+ if log:
+ logger.error(error_msg)
+ if raise_error:
+ raise ValueError(error_msg)
+ return(False)
+ return(True)
+
+def is_int_pair(v, ge=None, gt=None, le=None, lt=None, raise_error=False, log=True):
+ """Value is an integer pair, each in range ge <= v[i] <= le or gt < v[i] < lt or
+ ge[i] <= v[i] <= le[i] or gt[i] < v[i] < lt[i] or some combination.
+    Return: True if yes or False if no
+ """
+ return(_is_int_or_num_pair(v, 'int', ge, gt, le, lt, raise_error, log))
+
+def is_num_pair(v, ge=None, gt=None, le=None, lt=None, raise_error=False, log=True):
+ """Value is a number pair, each in range ge <= v[i] <= le or gt < v[i] < lt or
+ ge[i] <= v[i] <= le[i] or gt[i] < v[i] < lt[i] or some combination.
+    Return: True if yes or False if no
+ """
+ return(_is_int_or_num_pair(v, 'num', ge, gt, le, lt, raise_error, log))
+
+def _is_int_or_num_pair(v, type_str, ge=None, gt=None, le=None, lt=None, raise_error=False,
+ log=True):
+ if type_str == 'int':
+ if not (isinstance(v, (tuple, list)) and len(v) == 2 and isinstance(v[0], int) and
+ isinstance(v[1], int)):
+ illegal_value(v, 'v', '_is_int_or_num_pair', raise_error, log)
+ return(False)
+ func = is_int
+ elif type_str == 'num':
+ if not (isinstance(v, (tuple, list)) and len(v) == 2 and isinstance(v[0], (int, float)) and
+ isinstance(v[1], (int, float))):
+ illegal_value(v, 'v', '_is_int_or_num_pair', raise_error, log)
+ return(False)
+ func = is_num
+ else:
+ illegal_value(type_str, 'type_str', '_is_int_or_num_pair', raise_error, log)
+ return(False)
+ if ge is None and gt is None and le is None and lt is None:
+ return(True)
+ if ge is None or func(ge, log=True):
+ ge = 2*[ge]
+ elif not _is_int_or_num_pair(ge, type_str, raise_error=raise_error, log=log):
+ return(False)
+ if gt is None or func(gt, log=True):
+ gt = 2*[gt]
+ elif not _is_int_or_num_pair(gt, type_str, raise_error=raise_error, log=log):
+ return(False)
+ if le is None or func(le, log=True):
+ le = 2*[le]
+ elif not _is_int_or_num_pair(le, type_str, raise_error=raise_error, log=log):
+ return(False)
+ if lt is None or func(lt, log=True):
+ lt = 2*[lt]
+ elif not _is_int_or_num_pair(lt, type_str, raise_error=raise_error, log=log):
+ return(False)
+ if (not func(v[0], ge[0], gt[0], le[0], lt[0], raise_error, log) or
+ not func(v[1], ge[1], gt[1], le[1], lt[1], raise_error, log)):
+ return(False)
+ return(True)
+
+def is_int_series(l, ge=None, gt=None, le=None, lt=None, raise_error=False, log=True):
+ """Value is a tuple or list of integers, each in range ge <= l[i] <= le or
+ gt < l[i] < lt or some combination.
+ """
+ if not test_ge_gt_le_lt(ge, gt, le, lt, is_int, 'is_int_series', raise_error, log):
+ return(False)
+ if not isinstance(l, (tuple, list)):
+ illegal_value(l, 'l', 'is_int_series', raise_error, log)
+ return(False)
+    if any(not is_int(v, ge, gt, le, lt, raise_error, log) for v in l):
+ return(False)
+ return(True)
+
+def is_num_series(l, ge=None, gt=None, le=None, lt=None, raise_error=False, log=True):
+ """Value is a tuple or list of numbers, each in range ge <= l[i] <= le or
+ gt < l[i] < lt or some combination.
+ """
+    if not test_ge_gt_le_lt(ge, gt, le, lt, is_num, 'is_num_series', raise_error, log):
+ return(False)
+ if not isinstance(l, (tuple, list)):
+ illegal_value(l, 'l', 'is_num_series', raise_error, log)
+ return(False)
+    if any(not is_num(v, ge, gt, le, lt, raise_error, log) for v in l):
+ return(False)
+ return(True)
+
+def is_str_series(l, raise_error=False, log=True):
+ """Value is a tuple or list of strings.
+ """
+    if (not isinstance(l, (tuple, list)) or
+            any(not isinstance(s, str) for s in l)):
+ illegal_value(l, 'l', 'is_str_series', raise_error, log)
+ return(False)
+ return(True)
+
+def is_dict_series(l, raise_error=False, log=True):
+ """Value is a tuple or list of dictionaries.
+ """
+    if (not isinstance(l, (tuple, list)) or
+            any(not isinstance(d, dict) for d in l)):
+ illegal_value(l, 'l', 'is_dict_series', raise_error, log)
+ return(False)
+ return(True)
+
+def is_dict_nums(l, raise_error=False, log=True):
+ """Value is a dictionary with single number values
+ """
+    if (not isinstance(l, dict) or
+            any(not is_num(v, log=False) for v in l.values())):
+ illegal_value(l, 'l', 'is_dict_nums', raise_error, log)
+ return(False)
+ return(True)
+
+def is_dict_strings(l, raise_error=False, log=True):
+ """Value is a dictionary with single string values
+ """
+    if (not isinstance(l, dict) or
+            any(not isinstance(v, str) for v in l.values())):
+ illegal_value(l, 'l', 'is_dict_strings', raise_error, log)
+ return(False)
+ return(True)
+
+def is_index(v, ge=0, lt=None, raise_error=False, log=True):
+ """Value is an array index in range ge <= v < lt.
+ NOTE lt IS NOT included!
+ """
+ if isinstance(lt, int):
+ if lt <= ge:
+ illegal_combination(ge, 'ge', lt, 'lt', 'is_index', raise_error, log)
+ return(False)
+ return(is_int(v, ge=ge, lt=lt, raise_error=raise_error, log=log))
+
+def is_index_range(v, ge=0, le=None, lt=None, raise_error=False, log=True):
+ """Value is an array index range in range ge <= v[0] <= v[1] <= le or ge <= v[0] <= v[1] < lt.
+ NOTE le IS included!
+ """
+ if not is_int_pair(v, raise_error=raise_error, log=log):
+ return(False)
+ if not test_ge_gt_le_lt(ge, None, le, lt, is_int, 'is_index_range', raise_error, log):
+ return(False)
+ if not ge <= v[0] <= v[1] or (le is not None and v[1] > le) or (lt is not None and v[1] >= lt):
+ if le is not None:
+ error_msg = f'Value {v} out of range: !({ge} <= {v[0]} <= {v[1]} <= {le})'
+ else:
+ error_msg = f'Value {v} out of range: !({ge} <= {v[0]} <= {v[1]} < {lt})'
+ if log:
+ logger.error(error_msg)
+ if raise_error:
+ raise ValueError(error_msg)
+ return(False)
+ return(True)
+
+def index_nearest(a, value):
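+    # Return the index of the element of a nearest to value, rounding up on
+    # ties, e.g. (illustrative) index_nearest([0.0, 1.0, 2.0], 1.4) -> 1.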
+ a = np.asarray(a)
+ if a.ndim > 1:
+ raise ValueError(f'Invalid array dimension for parameter a ({a.ndim}, {a})')
+ # Round up for .5
+ value *= 1.0+sys.float_info.epsilon
+    return(int(np.argmin(np.abs(a-value))))
+
+def index_nearest_low(a, value):
+ a = np.asarray(a)
+ if a.ndim > 1:
+ raise ValueError(f'Invalid array dimension for parameter a ({a.ndim}, {a})')
+ index = int(np.argmin(np.abs(a-value)))
+ if value < a[index] and index > 0:
+ index -= 1
+ return(index)
+
+def index_nearest_upp(a, value):
+ a = np.asarray(a)
+ if a.ndim > 1:
+ raise ValueError(f'Invalid array dimension for parameter a ({a.ndim}, {a})')
+ index = int(np.argmin(np.abs(a-value)))
+ if value > a[index] and index < a.size-1:
+ index += 1
+ return(index)
+
+def round_to_n(x, n=1):
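+    # Round x to n significant figures, e.g. (illustrative) round_to_n(1234, 2) -> 1200.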
+ if x == 0.0:
+ return(0)
+ else:
+ return(type(x)(round(x, n-1-int(np.floor(np.log10(abs(x)))))))
+
+def round_up_to_n(x, n=1):
+ xr = round_to_n(x, n)
+ if abs(x/xr) > 1.0:
+ xr += np.sign(x)*10**(np.floor(np.log10(abs(x)))+1-n)
+ return(type(x)(xr))
+
+def trunc_to_n(x, n=1):
+ xr = round_to_n(x, n)
+ if abs(xr/x) > 1.0:
+ xr -= np.sign(x)*10**(np.floor(np.log10(abs(x)))+1-n)
+ return(type(x)(xr))
+
+def almost_equal(a, b, sig_figs):
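+    # True when a and b agree to within sig_figs significant figures,
+    # e.g. (illustrative) almost_equal(1.0001, 1.0002, 3) -> True.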
+ if is_num(a) and is_num(b):
+ return(abs(round_to_n(a-b, sig_figs)) < pow(10, -sig_figs+1))
+ else:
+ raise ValueError(f'Invalid value for a or b in almost_equal (a: {a}, {type(a)}, '+
+ f'b: {b}, {type(b)})')
+
+def string_to_list(s, split_on_dash=True, remove_duplicates=True, sort=True):
+ """Return a list of numbers by splitting/expanding a string on any combination of
+ commas, whitespaces, or dashes (when split_on_dash=True)
+    e.g.: '1, 3, 5-8, 12 ' -> [1, 3, 5, 6, 7, 8, 12]
+ """
+ if not isinstance(s, str):
+ illegal_value(s, location='string_to_list')
+ return(None)
+ if not len(s):
+ return([])
+ try:
+        ll = [x for x in re.split(r'\s+,\s+|\s+,|,\s+|\s+|,', s.strip())]
+ except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
+ return(None)
+ if split_on_dash:
+ try:
+ l = []
+ for l1 in ll:
+                l2 = [literal_eval(x) for x in re.split(r'\s+-\s+|\s+-|-\s+|\s+|-', l1)]
+ if len(l2) == 1:
+ l += l2
+ elif len(l2) == 2 and l2[1] > l2[0]:
+ l += [i for i in range(l2[0], l2[1]+1)]
+ else:
+ raise ValueError
+ except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
+ return(None)
+ else:
+ l = [literal_eval(x) for x in ll]
+ if remove_duplicates:
+ l = list(dict.fromkeys(l))
+ if sort:
+ l = sorted(l)
+ return(l)
+
+def get_trailing_int(string):
+ indexRegex = re.compile(r'\d+$')
+ mo = indexRegex.search(string)
+ if mo is None:
+ return(None)
+ else:
+ return(int(mo.group()))
+
+def input_int(s=None, ge=None, gt=None, le=None, lt=None, default=None, inset=None,
+ raise_error=False, log=True):
+ return(_input_int_or_num('int', s, ge, gt, le, lt, default, inset, raise_error, log))
+
+def input_num(s=None, ge=None, gt=None, le=None, lt=None, default=None, raise_error=False,
+ log=True):
+    return(_input_int_or_num('num', s, ge, gt, le, lt, default, None, raise_error, log))
+
+def _input_int_or_num(type_str, s=None, ge=None, gt=None, le=None, lt=None, default=None,
+ inset=None, raise_error=False, log=True):
+ if type_str == 'int':
+ if not test_ge_gt_le_lt(ge, gt, le, lt, is_int, '_input_int_or_num', raise_error, log):
+ return(None)
+ elif type_str == 'num':
+ if not test_ge_gt_le_lt(ge, gt, le, lt, is_num, '_input_int_or_num', raise_error, log):
+ return(None)
+ else:
+ illegal_value(type_str, 'type_str', '_input_int_or_num', raise_error, log)
+ return(None)
+ if default is not None:
+ if not _is_int_or_num(default, type_str, raise_error=raise_error, log=log):
+ return(None)
+ if ge is not None and default < ge:
+ illegal_combination(ge, 'ge', default, 'default', '_input_int_or_num', raise_error,
+ log)
+ return(None)
+ if gt is not None and default <= gt:
+ illegal_combination(gt, 'gt', default, 'default', '_input_int_or_num', raise_error,
+ log)
+ return(None)
+ if le is not None and default > le:
+ illegal_combination(le, 'le', default, 'default', '_input_int_or_num', raise_error,
+ log)
+ return(None)
+ if lt is not None and default >= lt:
+ illegal_combination(lt, 'lt', default, 'default', '_input_int_or_num', raise_error,
+ log)
+ return(None)
+ default_string = f' [{default}]'
+ else:
+ default_string = ''
+ if inset is not None:
+ if (not isinstance(inset, (tuple, list)) or
+ any(not isinstance(i, int) for i in inset)):
+ illegal_value(inset, 'inset', '_input_int_or_num', raise_error, log)
+ return(None)
+ v_range = f'{range_string_ge_gt_le_lt(ge, gt, le, lt)}'
+ if len(v_range):
+ v_range = f' {v_range}'
+ if s is None:
+ if type_str == 'int':
+ print(f'Enter an integer{v_range}{default_string}: ')
+ else:
+ print(f'Enter a number{v_range}{default_string}: ')
+ else:
+ print(f'{s}{v_range}{default_string}: ')
+ try:
+ i = input()
+ if isinstance(i, str) and not len(i):
+ v = default
+ print(f'{v}')
+ else:
+ v = literal_eval(i)
+ if inset and v not in inset:
+ raise ValueError(f'{v} not part of the set {inset}')
+ except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
+ v = None
+ except:
+ if log:
+ logger.error('Unexpected error')
+ if raise_error:
+ raise ValueError('Unexpected error')
+ if not _is_int_or_num(v, type_str, ge, gt, le, lt):
+ v = _input_int_or_num(type_str, s, ge, gt, le, lt, default, inset, raise_error, log)
+ return(v)
+
+def input_int_list(s=None, ge=None, le=None, split_on_dash=True, remove_duplicates=True,
+ sort=True, raise_error=False, log=True):
+ """Prompt the user to input a list of interger and split the entered string on any combination
+ of commas, whitespaces, or dashes (when split_on_dash is True)
+ e.g: '1 3,5-8 , 12 ' -> [1, 3, 5, 6, 7, 8, 12]
+ remove_duplicates: removes duplicates if True (may also change the order)
+ sort: sort in ascending order if True
+ return None upon an illegal input
+ """
+ return(_input_int_or_num_list('int', s, ge, le, split_on_dash, remove_duplicates, sort,
+ raise_error, log))
+
+def input_num_list(s=None, ge=None, le=None, remove_duplicates=True, sort=True, raise_error=False,
+ log=True):
+ """Prompt the user to input a list of numbers and split the entered string on any combination
+ of commas or whitespaces
+ e.g: '1.0, 3, 5.8, 12 ' -> [1.0, 3.0, 5.8, 12.0]
+ remove_duplicates: removes duplicates if True (may also change the order)
+ sort: sort in ascending order if True
+ return None upon an illegal input
+ """
+ return(_input_int_or_num_list('num', s, ge, le, False, remove_duplicates, sort, raise_error,
+ log))
+
+def _input_int_or_num_list(type_str, s=None, ge=None, le=None, split_on_dash=True,
+ remove_duplicates=True, sort=True, raise_error=False, log=True):
+ #FIX do we want a limit on max dimension?
+ if type_str == 'int':
+ if not test_ge_gt_le_lt(ge, None, le, None, is_int, 'input_int_or_num_list', raise_error,
+ log):
+ return(None)
+ elif type_str == 'num':
+ if not test_ge_gt_le_lt(ge, None, le, None, is_num, 'input_int_or_num_list', raise_error,
+ log):
+ return(None)
+ else:
+ illegal_value(type_str, 'type_str', '_input_int_or_num_list')
+ return(None)
+ v_range = f'{range_string_ge_gt_le_lt(ge=ge, le=le)}'
+ if len(v_range):
+ v_range = f' (each value in {v_range})'
+ if s is None:
+ if type_str == 'int':
+ print(f'Enter a series of integers{v_range}: ')
+ else:
+ print(f'Enter a series of numbers{v_range}: ')
+ else:
+ print(f'{s}{v_range}: ')
+ try:
+ l = string_to_list(input(), split_on_dash, remove_duplicates, sort)
+ except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
+ l = None
+ except:
+ print('Unexpected error')
+ raise
+ if (not isinstance(l, list) or
+ any(not _is_int_or_num(v, type_str, ge=ge, le=le) for v in l)):
+ type_word = 'integers' if type_str == 'int' else 'numbers'
+ if split_on_dash:
+ print('Invalid input: enter a valid set of dash/comma/whitespace separated '+
+ f'{type_word} e.g. 1 3,5-8 , 12')
+ else:
+ print('Invalid input: enter a valid set of comma/whitespace separated '+
+ f'{type_word} e.g. 1 3,5 8 , 12')
+ l = _input_int_or_num_list(type_str, s, ge, le, split_on_dash, remove_duplicates, sort,
+ raise_error, log)
+ return(l)
+
+def input_yesno(s=None, default=None):
+ if default is not None:
+ if not isinstance(default, str):
+ illegal_value(default, 'default', 'input_yesno')
+ return(None)
+ if default.lower() in ('y', 'yes'):
+ default = 'y'
+ elif default.lower() in ('n', 'no'):
+ default = 'n'
+ else:
+ illegal_value(default, 'default', 'input_yesno')
+ return(None)
+ default_string = f' [{default}]'
+ else:
+ default_string = ''
+ if s is None:
+ print(f'Enter yes or no{default_string}: ')
+ else:
+ print(f'{s}{default_string}: ')
+ i = input()
+ if isinstance(i, str) and not len(i):
+ i = default
+ print(f'{i}')
+ if i is not None and i.lower() in ('y', 'yes'):
+ v = True
+ elif i is not None and i.lower() in ('n', 'no'):
+ v = False
+ else:
+ print('Invalid input, enter yes or no')
+ v = input_yesno(s, default)
+ return(v)
+
+def input_menu(items, default=None, header=None):
+ if not isinstance(items, (tuple, list)) or any(not isinstance(i, str) for i in items):
+ illegal_value(items, 'items', 'input_menu')
+ return(None)
+ if default is not None:
+ if not (isinstance(default, str) and default in items):
+ logger.error(f'Invalid value for default ({default}), must be in {items}')
+ return(None)
+ default_string = f' [{items.index(default)+1}]'
+ else:
+ default_string = ''
+ if header is None:
+ print(f'Choose one of the following items (1, {len(items)}){default_string}:')
+ else:
+ print(f'{header} (1, {len(items)}){default_string}:')
+ for i, choice in enumerate(items):
+ print(f' {i+1}: {choice}')
+ try:
+ choice = input()
+ if isinstance(choice, str) and not len(choice):
+ choice = items.index(default)
+ print(f'{choice+1}')
+ else:
+ choice = literal_eval(choice)
+ if isinstance(choice, int) and 1 <= choice <= len(items):
+ choice -= 1
+ else:
+ raise ValueError
+ except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
+ choice = None
+ except:
+ print('Unexpected error')
+ raise
+ if choice is None:
+ print(f'Invalid choice, enter a number between 1 and {len(items)}')
+ choice = input_menu(items, default, header)
+ return(choice)
+
+def assert_no_duplicates_in_list_of_dicts(l: list, raise_error=False) -> list:
+ if not isinstance(l, list):
+ illegal_value(l, 'l', 'assert_no_duplicates_in_list_of_dicts', raise_error)
+ return(None)
+ if any(not isinstance(d, dict) for d in l):
+ illegal_value(l, 'l', 'assert_no_duplicates_in_list_of_dicts', raise_error)
+ return(None)
+ if len(l) != len([dict(t) for t in {tuple(sorted(d.items())) for d in l}]):
+ if raise_error:
+ raise ValueError(f'Duplicate items found in {l}')
+ else:
+ logger.error(f'Duplicate items found in {l}')
+ return(None)
+ else:
+ return(l)
+
+def assert_no_duplicate_key_in_list_of_dicts(l: list, key: str, raise_error=False) -> list:
+ if not isinstance(key, str):
+ illegal_value(key, 'key', 'assert_no_duplicate_key_in_list_of_dicts', raise_error)
+ return(None)
+ if not isinstance(l, list):
+ illegal_value(l, 'l', 'assert_no_duplicate_key_in_list_of_dicts', raise_error)
+ return(None)
+ if any(not isinstance(d, dict) for d in l):
+ illegal_value(l, 'l', 'assert_no_duplicate_key_in_list_of_dicts', raise_error)
+ return(None)
+ keys = [d.get(key, None) for d in l]
+ if None in keys or len(set(keys)) != len(l):
+ if raise_error:
+ raise ValueError(f'Duplicate or missing key ({key}) found in {l}')
+ else:
+ logger.error(f'Duplicate or missing key ({key}) found in {l}')
+ return(None)
+ else:
+ return(l)
+
+def assert_no_duplicate_attr_in_list_of_objs(l: list, attr: str, raise_error=False) -> list:
+ if not isinstance(attr, str):
+ illegal_value(attr, 'attr', 'assert_no_duplicate_attr_in_list_of_objs', raise_error)
+ return(None)
+ if not isinstance(l, list):
+ illegal_value(l, 'l', 'assert_no_duplicate_attr_in_list_of_objs', raise_error)
+ return(None)
+ attrs = [getattr(obj, attr, None) for obj in l]
+ if None in attrs or len(set(attrs)) != len(l):
+ if raise_error:
+ raise ValueError(f'Duplicate or missing attr ({attr}) found in {l}')
+ else:
+ logger.error(f'Duplicate or missing attr ({attr}) found in {l}')
+ return(None)
+ else:
+ return(l)
+
+def file_exists_and_readable(path):
+ if not os.path.isfile(path):
+ raise ValueError(f'{path} is not a valid file')
+ elif not os.access(path, os.R_OK):
+ raise ValueError(f'{path} is not accessible for reading')
+ else:
+ return(path)
+
+def create_mask(x, bounds=None, exclude_bounds=False, current_mask=None):
+ # bounds is a pair of numbers in the same units as x
+ if not isinstance(x, (tuple, list, np.ndarray)) or not len(x):
+ logger.warning(f'Invalid input array ({x}, {type(x)})')
+ return(None)
+ if bounds is not None and not is_num_pair(bounds):
+ logger.warning(f'Invalid bounds parameter ({bounds}, {type(bounds)}), input ignored')
+ bounds = None
+ if bounds is not None:
+ if exclude_bounds:
+ mask = np.logical_or(x < min(bounds), x > max(bounds))
+ else:
+ mask = np.logical_and(x > min(bounds), x < max(bounds))
+ else:
+ mask = np.ones(len(x), dtype=bool)
+ if current_mask is not None:
+ if not isinstance(current_mask, (tuple, list, np.ndarray)) or len(current_mask) != len(x):
+ logger.warning(f'Invalid current_mask ({current_mask}, {type(current_mask)}), '+
+ 'input ignored')
+ else:
+ mask = np.logical_or(mask, current_mask)
+ if not mask.any():
+ logger.warning('Entire data array is masked')
+ return(mask)
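+
+# Illustrative usage sketch (values assumed): keep only data strictly inside the bounds:
+#   x = np.linspace(0.0, 6.0, 7)
+#   create_mask(x, bounds=(2.0, 4.0)) -> [False, False, False, True, False, False, False]
+# Note that a current_mask is OR-ed in, so it widens rather than narrows the selection.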
+
+def eval_expr(name, expr, expr_variables, user_variables=None, max_depth=10, raise_error=False,
+ log=True, **kwargs):
+ """Evaluate an expression of expressions
+ """
+ if not isinstance(name, str):
+ illegal_value(name, 'name', 'eval_expr', raise_error, log)
+ return(None)
+ if not isinstance(expr, str):
+ illegal_value(expr, 'expr', 'eval_expr', raise_error, log)
+ return(None)
+ if not is_dict_strings(expr_variables, log=False):
+ illegal_value(expr_variables, 'expr_variables', 'eval_expr', raise_error, log)
+ return(None)
+ if user_variables is not None and not is_dict_nums(user_variables, log=False):
+ illegal_value(user_variables, 'user_variables', 'eval_expr', raise_error, log)
+ return(None)
+ if not is_int(max_depth, gt=1, log=False):
+ illegal_value(max_depth, 'max_depth', 'eval_expr', raise_error, log)
+ return(None)
+ if not isinstance(raise_error, bool):
+ illegal_value(raise_error, 'raise_error', 'eval_expr', raise_error, log)
+ return(None)
+ if not isinstance(log, bool):
+ illegal_value(log, 'log', 'eval_expr', raise_error, log)
+ return(None)
+# print(f'\nEvaluate the full expression for {expr}')
+ if 'chain' in kwargs:
+ chain = kwargs.pop('chain')
+ if not is_str_series(chain):
+ illegal_value(chain, 'chain', 'eval_expr', raise_error, log)
+ return(None)
+ else:
+ chain = []
+ if len(chain) > max_depth:
+ error_msg = f'Exceeded maximum depth ({max_depth}) in eval_expr'
+ if log:
+ logger.error(error_msg)
+ if raise_error:
+ raise ValueError(error_msg)
+ return(None)
+ if name not in chain:
+ chain.append(name)
+# print(f'start: chain = {chain}')
+ if 'ast' in kwargs:
+ ast = kwargs.pop('ast')
+ else:
+ ast = Interpreter()
+ if user_variables is not None:
+ ast.symtable.update(user_variables)
+ chain_vars = [var for var in get_ast_names(ast.parse(expr))
+ if var in expr_variables and var not in ast.symtable]
+# print(f'chain_vars: {chain_vars}')
+ save_chain = chain.copy()
+ for var in chain_vars:
+# print(f'\n\tname = {name}, var = {var}:\n\t\t{expr_variables[var]}')
+# print(f'\tchain = {chain}')
+ if var in chain:
+ error_msg = f'Circular variable {var} in eval_expr'
+ if log:
+ logger.error(error_msg)
+ if raise_error:
+ raise ValueError(error_msg)
+ return(None)
+# print(f'\tknown symbols:\n\t\t{ast.user_defined_symbols()}\n')
+ if var in ast.user_defined_symbols():
+ val = ast.symtable[var]
+ else:
+ #val = eval_expr(var, expr_variables[var], expr_variables, user_variables=user_variables,
+ val = eval_expr(var, expr_variables[var], expr_variables, max_depth=max_depth,
+ raise_error=raise_error, log=log, chain=chain, ast=ast)
+ if val is None:
+ return(None)
+ ast.symtable[var] = val
+# print(f'\tval = {val}')
+# print(f'\t{var} = {ast.symtable[var]}')
+ chain = save_chain.copy()
+# print(f'\treset loop for {var}: chain = {chain}')
+ val = ast.eval(expr)
+# print(f'return val for {expr} = {val}\n')
+ return(val)
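+
+# Illustrative usage sketch (names and values assumed, not from the original source):
+# expressions may reference other named expressions, which are resolved recursively
+# through expr_variables:
+#   eval_expr('c', 'a+b', {'a': '2*b', 'b': '3'}) -> 9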
+
+def full_gradient(expr, x, expr_name=None, expr_variables=None, valid_variables=None, max_depth=10,
+ raise_error=False, log=True, **kwargs):
+ """Compute the full gradient dexpr/dx
+ """
+ if not isinstance(x, str):
+ illegal_value(x, 'x', 'full_gradient', raise_error, log)
+ return(None)
+ if expr_name is not None and not isinstance(expr_name, str):
+ illegal_value(expr_name, 'expr_name', 'full_gradient', raise_error, log)
+ return(None)
+ if expr_variables is not None and not is_dict_strings(expr_variables, log=False):
+ illegal_value(expr_variables, 'expr_variables', 'full_gradient', raise_error, log)
+ return(None)
+ if valid_variables is not None and not is_str_series(valid_variables, log=False):
+ illegal_value(valid_variables, 'valid_variables', 'full_gradient', raise_error, log)
+ return(None)
+ if not is_int(max_depth, gt=1, log=False):
+ illegal_value(max_depth, 'max_depth', 'full_gradient', raise_error, log)
+ return(None)
+ if not isinstance(raise_error, bool):
+ illegal_value(raise_error, 'raise_error', 'full_gradient', raise_error, log)
+ return(None)
+ if not isinstance(log, bool):
+ illegal_value(log, 'log', 'full_gradient', raise_error, log)
+ return(None)
+# print(f'\nGet full gradient of {expr_name} = {expr} with respect to {x}')
+ if expr_name is not None and expr_name == x:
+ return(1.0)
+ if 'chain' in kwargs:
+ chain = kwargs.pop('chain')
+ if not is_str_series(chain):
+ illegal_value(chain, 'chain', 'full_gradient', raise_error, log)
+ return(None)
+ else:
+ chain = []
+ if len(chain) > max_depth:
+ error_msg = f'Exceeded maximum depth ({max_depth}) in full_gradient'
+ if log:
+ logger.error(error_msg)
+ if raise_error:
+ raise ValueError(error_msg)
+ return(None)
+ if expr_name is not None and expr_name not in chain:
+ chain.append(expr_name)
+# print(f'start ({x}): chain = {chain}')
+ ast = Interpreter()
+ if expr_variables is None:
+ chain_vars = []
+ else:
+ chain_vars = [var for var in get_ast_names(ast.parse(f'{expr}'))
+ if var in expr_variables and var != x and var not in ast.symtable]
+# print(f'chain_vars: {chain_vars}')
+ if valid_variables is not None:
+ unknown_vars = [var for var in chain_vars if var not in valid_variables]
+ if len(unknown_vars):
+ error_msg = f'Unknown variable {unknown_vars} in {expr}'
+ if log:
+ logger.error(error_msg)
+ if raise_error:
+ raise ValueError(error_msg)
+ return(None)
+ dexpr_dx = diff(expr, x)
+# print(f'direct gradient: d({expr})/d({x}) = {dexpr_dx} ({type(dexpr_dx)})')
+ save_chain = chain.copy()
+ for var in chain_vars:
+# print(f'\n\texpr_name = {expr_name}, var = {var}:\n\t\t{expr}')
+# print(f'\tchain = {chain}')
+ if var in chain:
+ error_msg = f'Circular variable {var} in full_gradient'
+ if log:
+ logger.error(error_msg)
+ if raise_error:
+ raise ValueError(error_msg)
+ return(None)
+ dexpr_dvar = diff(expr, var)
+# print(f'\td({expr})/d({var}) = {dexpr_dvar}')
+ if dexpr_dvar:
+ dvar_dx = full_gradient(expr_variables[var], x, expr_name=var,
+ expr_variables=expr_variables, valid_variables=valid_variables,
+ max_depth=max_depth, raise_error=raise_error, log=log, chain=chain)
+# print(f'\t\td({var})/d({x}) = {dvar_dx}')
+ if dvar_dx:
+ dexpr_dx = f'{dexpr_dx}+({dexpr_dvar})*({dvar_dx})'
+# print(f'\t\t2: chain = {chain}')
+ chain = save_chain.copy()
+# print(f'\treset loop for {var}: chain = {chain}')
+# print(f'full gradient: d({expr})/d({x}) = {dexpr_dx} ({type(dexpr_dx)})')
+# print(f'reset end: chain = {chain}\n\n')
+ return(simplify(dexpr_dx))
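+
+# Illustrative usage sketch (names assumed): the chain rule is applied through
+# dependent expressions, with dependent symbols kept symbolic in the result:
+#   full_gradient('a*x', 'x', expr_variables={'a': 'x**2'}) -> a + 2*x**2
+# (substituting a = x**2 gives d(x**3)/dx = 3*x**2, as expected)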
+
+def bounds_from_mask(mask, return_include_bounds:bool=True):
+ bounds = []
+ for i, m in enumerate(mask):
+ if m == return_include_bounds:
+ if len(bounds) == 0 or isinstance(bounds[-1], tuple):
+ bounds.append(i)
+ else:
+ if len(bounds) > 0 and isinstance(bounds[-1], int):
+ bounds[-1] = (bounds[-1], i-1)
+ if len(bounds) > 0 and isinstance(bounds[-1], int):
+ bounds[-1] = (bounds[-1], mask.size-1)
+ return(bounds)
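+
+# Illustrative example (hand-checked): contiguous True runs become inclusive index pairs:
+#   bounds_from_mask(np.array([True, True, False, True])) -> [(0, 1), (3, 3)]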
+
+def draw_mask_1d(ydata, xdata=None, current_index_ranges=None, current_mask=None,
+ select_mask=True, num_index_ranges_max=None, title=None, legend=None, test_mode=False):
+ #FIX make color blind friendly
+ def draw_selections(ax, current_include, current_exclude, selected_index_ranges):
+ ax.clear()
+ ax.set_title(title)
+ ax.legend([legend])
+ ax.plot(xdata, ydata, 'k')
+ for (low, upp) in current_include:
+ xlow = 0.5*(xdata[max(0, low-1)]+xdata[low])
+ xupp = 0.5*(xdata[upp]+xdata[min(num_data-1, upp+1)])
+ ax.axvspan(xlow, xupp, facecolor='green', alpha=0.5)
+ for (low, upp) in current_exclude:
+ xlow = 0.5*(xdata[max(0, low-1)]+xdata[low])
+ xupp = 0.5*(xdata[upp]+xdata[min(num_data-1, upp+1)])
+ ax.axvspan(xlow, xupp, facecolor='red', alpha=0.5)
+ for (low, upp) in selected_index_ranges:
+ xlow = 0.5*(xdata[max(0, low-1)]+xdata[low])
+ xupp = 0.5*(xdata[upp]+xdata[min(num_data-1, upp+1)])
+ ax.axvspan(xlow, xupp, facecolor=selection_color, alpha=0.5)
+ ax.get_figure().canvas.draw()
+
+ def onclick(event):
+ if event.inaxes in [fig.axes[0]]:
+ selected_index_ranges.append(index_nearest_upp(xdata, event.xdata))
+
+ def onrelease(event):
+ if len(selected_index_ranges) > 0:
+ if isinstance(selected_index_ranges[-1], int):
+ if event.inaxes in [fig.axes[0]]:
+ event.xdata = index_nearest_low(xdata, event.xdata)
+ if selected_index_ranges[-1] <= event.xdata:
+ selected_index_ranges[-1] = (selected_index_ranges[-1], event.xdata)
+ else:
+ selected_index_ranges[-1] = (event.xdata, selected_index_ranges[-1])
+ draw_selections(event.inaxes, current_include, current_exclude, selected_index_ranges)
+ else:
+ selected_index_ranges.pop(-1)
+
+ def confirm_selection(event):
+ plt.close()
+
+ def clear_last_selection(event):
+ if len(selected_index_ranges):
+ selected_index_ranges.pop(-1)
+ else:
+ while len(current_include):
+ current_include.pop()
+ while len(current_exclude):
+ current_exclude.pop()
+ selected_mask.fill(False)
+ draw_selections(ax, current_include, current_exclude, selected_index_ranges)
+
+ def update_mask(mask, selected_index_ranges, unselected_index_ranges):
+ for (low, upp) in selected_index_ranges:
+ selected_mask = np.logical_and(xdata >= xdata[low], xdata <= xdata[upp])
+ mask = np.logical_or(mask, selected_mask)
+ for (low, upp) in unselected_index_ranges:
+ unselected_mask = np.logical_and(xdata >= xdata[low], xdata <= xdata[upp])
+ mask[unselected_mask] = False
+ return(mask)
+
+ def update_index_ranges(mask):
+ # Update the currently included index ranges (where mask is True)
+ current_include = []
+ for i, m in enumerate(mask):
+ if m:
+ if len(current_include) == 0 or isinstance(current_include[-1], tuple):
+ current_include.append(i)
+ else:
+ if len(current_include) > 0 and isinstance(current_include[-1], int):
+ current_include[-1] = (current_include[-1], i-1)
+ if len(current_include) > 0 and isinstance(current_include[-1], int):
+ current_include[-1] = (current_include[-1], num_data-1)
+ return(current_include)
+
+ # Check inputs
+ ydata = np.asarray(ydata)
+ if ydata.ndim > 1:
+ logger.warning(f'Invalid ydata dimension ({ydata.ndim})')
+ return(None, None)
+ num_data = ydata.size
+ if xdata is None:
+ xdata = np.arange(num_data)
+ else:
+ xdata = np.asarray(xdata, dtype=np.float64)
+ if xdata.ndim > 1 or xdata.size != num_data:
+ logger.warning(f'Invalid xdata shape ({xdata.shape})')
+ return(None, None)
+ if not np.all(xdata[:-1] < xdata[1:]):
+ logger.warning('Invalid xdata: must be monotonically increasing')
+ return(None, None)
+ if current_index_ranges is not None:
+ if not isinstance(current_index_ranges, (tuple, list)):
+ logger.warning(f'Invalid current_index_ranges parameter ({current_index_ranges}, '+
+ f'{type(current_index_ranges)})')
+ return(None, None)
+ if not isinstance(select_mask, bool):
+ logger.warning(f'Invalid select_mask parameter ({select_mask}, {type(select_mask)})')
+ return(None, None)
+ if num_index_ranges_max is not None:
+ logger.warning('num_index_ranges_max input not yet implemented in draw_mask_1d')
+ if title is None:
+ title = 'select ranges of data'
+ elif not isinstance(title, str):
+ illegal_value(title, 'title', 'draw_mask_1d')
+ title = ''
+ if legend is not None and not isinstance(legend, str):
+ illegal_value(legend, 'legend', 'draw_mask_1d')
+ legend = None
+
+ if select_mask:
+ title = f'Click and drag to {title} you wish to include'
+ selection_color = 'green'
+ else:
+ title = f'Click and drag to {title} you wish to exclude'
+ selection_color = 'red'
+
+ # Set initial selected mask and the selected/unselected index ranges as needed
+ selected_index_ranges = []
+ unselected_index_ranges = []
+ selected_mask = np.full(xdata.shape, False, dtype=bool)
+ if current_index_ranges is None:
+ if current_mask is None:
+ if not select_mask:
+ selected_index_ranges = [(0, num_data-1)]
+ selected_mask = np.full(xdata.shape, True, dtype=bool)
+ else:
+ selected_mask = np.copy(np.asarray(current_mask, dtype=bool))
+ if current_index_ranges is not None and len(current_index_ranges):
+ current_index_ranges = sorted([(low, upp) for (low, upp) in current_index_ranges])
+ for (low, upp) in current_index_ranges:
+ if low > upp or low >= num_data or upp < 0:
+ continue
+ if low < 0:
+ low = 0
+ if upp >= num_data:
+ upp = num_data-1
+ selected_index_ranges.append((low, upp))
+ selected_mask = update_mask(selected_mask, selected_index_ranges, unselected_index_ranges)
+ if current_index_ranges is not None and current_mask is not None:
+ selected_mask = np.logical_and(current_mask, selected_mask)
+ if current_mask is not None:
+ selected_index_ranges = update_index_ranges(selected_mask)
+
+ # Set up range selections for display
+ current_include = selected_index_ranges
+ current_exclude = []
+ selected_index_ranges = []
+ if not len(current_include):
+ if select_mask:
+ current_exclude = [(0, num_data-1)]
+ else:
+ current_include = [(0, num_data-1)]
+ else:
+ if current_include[0][0] > 0:
+ current_exclude.append((0, current_include[0][0]-1))
+ for i in range(1, len(current_include)):
+ current_exclude.append((current_include[i-1][1]+1, current_include[i][0]-1))
+ if current_include[-1][1] < num_data-1:
+ current_exclude.append((current_include[-1][1]+1, num_data-1))
+
+ if not test_mode:
+
+ # Set up matplotlib figure
+ plt.close('all')
+ fig, ax = plt.subplots()
+ plt.subplots_adjust(bottom=0.2)
+ draw_selections(ax, current_include, current_exclude, selected_index_ranges)
+
+ # Set up event handling for click-and-drag range selection
+ cid_click = fig.canvas.mpl_connect('button_press_event', onclick)
+ cid_release = fig.canvas.mpl_connect('button_release_event', onrelease)
+
+ # Set up confirm / clear range selection buttons
+ confirm_b = Button(plt.axes([0.75, 0.05, 0.15, 0.075]), 'Confirm')
+ clear_b = Button(plt.axes([0.59, 0.05, 0.15, 0.075]), 'Clear')
+ cid_confirm = confirm_b.on_clicked(confirm_selection)
+ cid_clear = clear_b.on_clicked(clear_last_selection)
+
+ # Show figure
+ plt.show(block=True)
+
+ # Disconnect callbacks when figure is closed
+ fig.canvas.mpl_disconnect(cid_click)
+ fig.canvas.mpl_disconnect(cid_release)
+ confirm_b.disconnect(cid_confirm)
+ clear_b.disconnect(cid_clear)
+
+ # Swap selection depending on select_mask
+ if not select_mask:
+ selected_index_ranges, unselected_index_ranges = unselected_index_ranges, \
+ selected_index_ranges
+
+ # Update the mask with the currently selected/unselected x-ranges
+ selected_mask = update_mask(selected_mask, selected_index_ranges, unselected_index_ranges)
+
+ # Update the currently included index ranges (where mask is True)
+ current_include = update_index_ranges(selected_mask)
+
+ return(selected_mask, current_include)
+
+def select_peaks(ydata:np.ndarray, x_values:np.ndarray=None, x_mask:np.ndarray=None,
+ peak_x_values:np.ndarray=np.array([]), peak_x_indices:np.ndarray=np.array([]),
+ return_peak_x_values:bool=False, return_peak_x_indices:bool=False,
+ return_peak_input_indices:bool=False, return_sorted:bool=False,
+ title:str=None, xlabel:str=None, ylabel:str=None) -> list :
+
+ # Check arguments
+ if (len(peak_x_values) > 0 or return_peak_x_values) and (
+ x_values is None or not len(x_values)):
+ raise RuntimeError('Cannot use peak_x_values or return_peak_x_values without x_values')
+ if not ((len(peak_x_values) > 0) ^ (len(peak_x_indices) > 0)):
+ raise RuntimeError('Use exactly one of peak_x_values or peak_x_indices')
+ return_flags = (return_peak_x_values, return_peak_x_indices, return_peak_input_indices)
+ if sum(return_flags) != 1:
+ raise RuntimeError('Exactly one of return_peak_x_values, return_peak_x_indices, or '+
+ 'return_peak_input_indices must be True')
+
+ EXCLUDE_PEAK_PROPERTIES = {'color': 'black', 'linestyle': '--','linewidth': 1,
+ 'marker': 10, 'markersize': 5, 'fillstyle': 'none'}
+ INCLUDE_PEAK_PROPERTIES = {'color': 'green', 'linestyle': '-', 'linewidth': 2,
+ 'marker': 10, 'markersize': 10, 'fillstyle': 'full'}
+ MASKED_PEAK_PROPERTIES = {'color': 'gray', 'linestyle': ':', 'linewidth': 1}
+
+ # Setup reference data & plot
+ x_indices = np.arange(len(ydata))
+ if x_values is None:
+ x_values = x_indices
+ if x_mask is None:
+ x_mask = np.full(x_values.shape, True, dtype=bool)
+ fig, ax = plt.subplots()
+ handles = ax.plot(x_values, ydata, label='Reference data')
+ handles.append(mlines.Line2D([], [], label='Excluded / unselected HKL', **EXCLUDE_PEAK_PROPERTIES))
+ handles.append(mlines.Line2D([], [], label='Included / selected HKL', **INCLUDE_PEAK_PROPERTIES))
+ handles.append(mlines.Line2D([], [], label='HKL in masked region (unselectable)', **MASKED_PEAK_PROPERTIES))
+ ax.legend(handles=handles, loc='upper right')
+ ax.set(title=title, xlabel=xlabel, ylabel=ylabel)
+
+
+ # Plot vertical line at each peak
+ value_to_index = lambda x_value: int(np.argmin(abs(x_values - x_value)))
+ if len(peak_x_indices) > 0:
+ peak_x_values = x_values[peak_x_indices]
+ else:
+ peak_x_indices = np.array(list(map(value_to_index, peak_x_values)))
+ peak_vlines = []
+ for loc in peak_x_values:
+ nearest_index = value_to_index(loc)
+ if nearest_index in x_indices[x_mask]:
+ peak_vline = ax.axvline(loc, **EXCLUDE_PEAK_PROPERTIES)
+ peak_vline.set_picker(5)
+ else:
+ peak_vline = ax.axvline(loc, **MASKED_PEAK_PROPERTIES)
+ peak_vlines.append(peak_vline)
+
+ # Indicate masked regions by gray-ing out the axes facecolor
+ mask_exclude_bounds = bounds_from_mask(x_mask, return_include_bounds=False)
+ for (low, upp) in mask_exclude_bounds:
+ xlow = x_values[low]
+ xupp = x_values[upp]
+ ax.axvspan(xlow, xupp, facecolor='gray', alpha=0.5)
+
+ # Setup peak picking
+ selected_peak_input_indices = []
+ def onpick(event):
+ try:
+ peak_index = peak_vlines.index(event.artist)
+ except ValueError:
+ pass
+ else:
+ peak_vline = event.artist
+ if peak_index in selected_peak_input_indices:
+ peak_vline.set(**EXCLUDE_PEAK_PROPERTIES)
+ selected_peak_input_indices.remove(peak_index)
+ else:
+ peak_vline.set(**INCLUDE_PEAK_PROPERTIES)
+ selected_peak_input_indices.append(peak_index)
+ plt.draw()
+ cid_pick_peak = fig.canvas.mpl_connect('pick_event', onpick)
+
+ # Setup "Confirm" button
+ def confirm_selection(event):
+ plt.close()
+ plt.subplots_adjust(bottom=0.2)
+ confirm_b = Button(plt.axes([0.75, 0.05, 0.15, 0.075]), 'Confirm')
+ cid_confirm = confirm_b.on_clicked(confirm_selection)
+
+ # Show figure for user interaction
+ plt.show()
+
+ # Disconnect callbacks when figure is closed
+ fig.canvas.mpl_disconnect(cid_pick_peak)
+ confirm_b.disconnect(cid_confirm)
+
+ if return_peak_input_indices:
+ selected_peaks = np.array(selected_peak_input_indices)
+ if return_peak_x_values:
+ selected_peaks = peak_x_values[selected_peak_input_indices]
+ if return_peak_x_indices:
+ selected_peaks = peak_x_indices[selected_peak_input_indices]
+
+ if return_sorted:
+ selected_peaks.sort()
+
+ return(selected_peaks)
+
+def find_image_files(path, filetype, name=None):
+ if isinstance(name, str):
+ name = f'{name.strip()} '
+ else:
+ name = ''
+ # Find available index range
+ if filetype == 'tif':
+ if not isinstance(path, str) or not os.path.isdir(path):
+ illegal_value(path, 'path', 'find_image_files')
+ return(-1, 0, [])
+ indexRegex = re.compile(r'\d+')
+ # At this point only tiffs
+ files = sorted([f for f in os.listdir(path) if os.path.isfile(os.path.join(path, f)) and
+ f.endswith('.tif') and indexRegex.search(f)])
+ num_img = len(files)
+ if num_img < 1:
+ logger.warning(f'No available {name}files')
+ return(-1, 0, [])
+ first_index = indexRegex.search(files[0]).group()
+ last_index = indexRegex.search(files[-1]).group()
+ if first_index is None or last_index is None:
+ logger.error(f'Unable to find correctly indexed {name}images')
+ return(-1, 0, [])
+ first_index = int(first_index)
+ last_index = int(last_index)
+ if num_img != last_index-first_index+1:
+ logger.error(f'Non-consecutive set of indices for {name}images')
+ return(-1, 0, [])
+ paths = [os.path.join(path, f) for f in files]
+ elif filetype == 'h5':
+ if not isinstance(path, str) or not os.path.isfile(path):
+ illegal_value(path, 'path', 'find_image_files')
+ return(-1, 0, [])
+ # At this point only h5 in alamo2 detector style
+ first_index = 0
+ with h5py.File(path, 'r') as f:
+ num_img = f['entry/instrument/detector/data'].shape[0]
+ last_index = num_img-1
+ paths = [path]
+ else:
+ illegal_value(filetype, 'filetype', 'find_image_files')
+ return(-1, 0, [])
+ logger.info(f'Number of available {name}images: {num_img}')
+ logger.info(f'Index range of available {name}images: [{first_index}, '+
+ f'{last_index}]')
+
+ return(first_index, num_img, paths)
+
+def select_image_range(first_index, offset, num_available, num_img=None, name=None,
+ num_required=None):
+ if isinstance(name, str):
+ name = f'{name.strip()} '
+ else:
+ name = ''
+ # Check existing values
+ if not is_int(num_available, gt=0):
+ logger.warning(f'No available {name}images')
+ return(0, 0, 0)
+ if num_img is not None and not is_int(num_img, ge=0):
+ illegal_value(num_img, 'num_img', 'select_image_range')
+ return(0, 0, 0)
+ if is_int(first_index, ge=0) and is_int(offset, ge=0):
+ if num_required is None:
+ if input_yesno(f'\nCurrent {name}first image index/offset = {first_index}/{offset}, '+
+ 'use these values (y/n)?', 'y'):
+ if num_img is not None:
+ if input_yesno(f'Current number of {name}images = {num_img}, '+
+ 'use this value (y/n)? ', 'y'):
+ return(first_index, offset, num_img)
+ else:
+ if input_yesno(f'Number of available {name}images = {num_available}, '+
+ 'use all (y/n)? ', 'y'):
+ return(first_index, offset, num_available)
+ else:
+ if input_yesno(f'\nCurrent {name}first image offset = {offset}, '+
+ 'use this value (y/n)?', 'y'):
+ return(first_index, offset, num_required)
+
+ # Check range against requirements
+ if num_required is None:
+ if num_available == 1:
+ return(first_index, 0, 1)
+ else:
+ if not is_int(num_required, ge=1):
+ illegal_value(num_required, 'num_required', 'select_image_range')
+ return(0, 0, 0)
+ if num_available < num_required:
+ logger.error(f'Unable to find the required {name}images ({num_available} out of '+
+ f'{num_required})')
+ return(0, 0, 0)
+
+ # Select index range
+ print(f'\nThe number of available {name}images is {num_available}')
+ if num_required is None:
+ last_index = first_index+num_available-1
+ use_all = f'Use all ([{first_index}, {last_index}])'
+ pick_offset = 'Pick the first image index offset and the number of images'
+ pick_bounds = 'Pick the first and last image index'
+ choice = input_menu([use_all, pick_offset, pick_bounds], default=pick_offset)
+ if not choice:
+ offset = 0
+ num_img = num_available
+ elif choice == 1:
+ offset = input_int('Enter the first index offset', ge=0, le=last_index-first_index)
+ if first_index+offset == last_index:
+ num_img = 1
+ else:
+ num_img = input_int('Enter the number of images', ge=1, le=num_available-offset)
+ else:
+ offset = input_int('Enter the first index', ge=first_index, le=last_index)
+ num_img = 1-offset+input_int('Enter the last index', ge=offset, le=last_index)
+ offset -= first_index
+ else:
+ use_all = f'Use ([{first_index}, {first_index+num_required-1}])'
+ pick_offset = 'Pick the first index offset'
+ choice = input_menu([use_all, pick_offset], pick_offset)
+ offset = 0
+ if choice == 1:
+ offset = input_int('Enter the first index offset', ge=0, le=num_available-num_required)
+ num_img = num_required
+
+ return(first_index, offset, num_img)
+
+def load_image(f, img_x_bounds=None, img_y_bounds=None):
+ """Load a single image from file.
+ """
+ if not os.path.isfile(f):
+ logger.error(f'Unable to load {f}')
+ return(None)
+ img_read = plt.imread(f)
+ if not img_x_bounds:
+ img_x_bounds = (0, img_read.shape[0])
+ else:
+ if (not isinstance(img_x_bounds, (tuple, list)) or len(img_x_bounds) != 2 or
+ not (0 <= img_x_bounds[0] < img_x_bounds[1] <= img_read.shape[0])):
+ logger.error(f'inconsistent row dimension in {f}')
+ return(None)
+ if not img_y_bounds:
+ img_y_bounds = (0, img_read.shape[1])
+ else:
+ if (not isinstance(img_y_bounds, (tuple, list)) or len(img_y_bounds) != 2 or
+ not (0 <= img_y_bounds[0] < img_y_bounds[1] <= img_read.shape[1])):
+ logger.error(f'inconsistent column dimension in {f}')
+ return(None)
+ return(img_read[img_x_bounds[0]:img_x_bounds[1],img_y_bounds[0]:img_y_bounds[1]])
+
+def load_image_stack(files, filetype, img_offset, num_img, num_img_skip=0,
+ img_x_bounds=None, img_y_bounds=None):
+ """Load a set of images and return them as a stack.
+ """
+ logger.debug(f'img_offset = {img_offset}')
+ logger.debug(f'num_img = {num_img}')
+ logger.debug(f'num_img_skip = {num_img_skip}')
+ logger.debug(f'\nfiles:\n{files}\n')
+ img_stack = np.array([])
+ if filetype == 'tif':
+ img_read_stack = []
+ i = 1
+ t0 = time()
+ for f in files[img_offset:img_offset+num_img:num_img_skip+1]:
+ if not i%20:
+ logger.info(f' loading {i}/{num_img}: {f}')
+ else:
+ logger.debug(f' loading {i}/{num_img}: {f}')
+ img_read = load_image(f, img_x_bounds, img_y_bounds)
+ img_read_stack.append(img_read)
+ i += num_img_skip+1
+ img_stack = np.stack(img_read_stack)
+ logger.info(f'... done in {time()-t0:.2f} seconds!')
+ logger.debug(f'img_stack shape = {np.shape(img_stack)}')
+ del img_read_stack, img_read
+ elif filetype == 'h5':
+ if not isinstance(files[0], str) or not os.path.isfile(files[0]):
+ illegal_value(files[0], 'files[0]', 'load_image_stack')
+ return(img_stack)
+ t0 = time()
+ logger.info(f'Loading {files[0]}')
+ with h5py.File(files[0], 'r') as f:
+ shape = f['entry/instrument/detector/data'].shape
+ if len(shape) != 3:
+ logger.error(f'inconsistent dimensions in {files[0]}')
+ return(img_stack)
+ if not img_x_bounds:
+ img_x_bounds = (0, shape[1])
+ else:
+ if (not isinstance(img_x_bounds, (tuple, list)) or len(img_x_bounds) != 2 or
+ not (0 <= img_x_bounds[0] < img_x_bounds[1] <= shape[1])):
+ logger.error(f'inconsistent row dimension in {files[0]} {img_x_bounds} '+
+ f'{shape[1]}')
+ if not img_y_bounds:
+ img_y_bounds = (0, shape[2])
+ else:
+ if (not isinstance(img_y_bounds, (tuple, list)) or len(img_y_bounds) != 2 or
+ not (0 <= img_y_bounds[0] < img_y_bounds[1] <= shape[2])):
+ logger.error(f'inconsistent column dimension in {files[0]}')
+ img_stack = f.get('entry/instrument/detector/data')[
+ img_offset:img_offset+num_img:num_img_skip+1,
+ img_x_bounds[0]:img_x_bounds[1],img_y_bounds[0]:img_y_bounds[1]]
+ logger.info(f'... done in {time()-t0:.2f} seconds!')
+ else:
+ illegal_value(filetype, 'filetype', 'load_image_stack')
+ return(img_stack)
+
+def combine_tiffs_in_h5(files, num_img, h5_filename):
+ img_stack = load_image_stack(files, 'tif', 0, num_img)
+ with h5py.File(h5_filename, 'w') as f:
+ f.create_dataset('entry/instrument/detector/data', data=img_stack)
+ del img_stack
+ return([h5_filename])
+
+def clear_imshow(title=None):
+ plt.ioff()
+ if title is None:
+ title = 'quick imshow'
+ elif not isinstance(title, str):
+ illegal_value(title, 'title', 'clear_imshow')
+ return
+ plt.close(fig=title)
+
+def clear_plot(title=None):
+ plt.ioff()
+ if title is None:
+ title = 'quick plot'
+ elif not isinstance(title, str):
+ illegal_value(title, 'title', 'clear_plot')
+ return
+ plt.close(fig=title)
+
+def quick_imshow(a, title=None, path=None, name=None, save_fig=False, save_only=False,
+ clear=True, extent=None, show_grid=False, grid_color='w', grid_linewidth=1,
+ block=False, **kwargs):
+ if title is not None and not isinstance(title, str):
+ illegal_value(title, 'title', 'quick_imshow')
+ return
+ if path is not None and not isinstance(path, str):
+ illegal_value(path, 'path', 'quick_imshow')
+ return
+ if not isinstance(save_fig, bool):
+ illegal_value(save_fig, 'save_fig', 'quick_imshow')
+ return
+ if not isinstance(save_only, bool):
+ illegal_value(save_only, 'save_only', 'quick_imshow')
+ return
+ if not isinstance(clear, bool):
+ illegal_value(clear, 'clear', 'quick_imshow')
+ return
+ if not isinstance(block, bool):
+ illegal_value(block, 'block', 'quick_imshow')
+ return
+ if not title:
+ title='quick imshow'
+# else:
+# title = re.sub(r"\s+", '_', title)
+ if name is None:
+ ttitle = re.sub(r"\s+", '_', title)
+ if path is None:
+ path = f'{ttitle}.png'
+ else:
+ path = f'{path}/{ttitle}.png'
+ else:
+ if path is None:
+ path = name
+ else:
+ path = f'{path}/{name}'
+ if 'cmap' in kwargs and a.ndim == 3 and (a.shape[2] == 3 or a.shape[2] == 4):
+ use_cmap = True
+ if a.shape[2] == 4 and a[:,:,-1].min() != a[:,:,-1].max():
+ use_cmap = False
+ if any(a[i,j,0] != a[i,j,1] and a[i,j,0] != a[i,j,2]
+ for i in range(a.shape[0]) for j in range(a.shape[1])):
+ use_cmap = False
+ if use_cmap:
+ a = a[:,:,0]
+ else:
+ logger.warning('Image incompatible with cmap option, ignoring cmap')
+ kwargs.pop('cmap')
+ if extent is None:
+ extent = (0, a.shape[1], a.shape[0], 0)
+ if clear:
+ try:
+ plt.close(fig=title)
+ except:
+ pass
+ if not save_only:
+ if block:
+ plt.ioff()
+ else:
+ plt.ion()
+ plt.figure(title)
+ plt.imshow(a, extent=extent, **kwargs)
+ if show_grid:
+ ax = plt.gca()
+ ax.grid(color=grid_color, linewidth=grid_linewidth)
+# if title != 'quick imshow':
+# plt.title = title
+ if save_only:
+ plt.savefig(path)
+ plt.close(fig=title)
+ else:
+ if save_fig:
+ plt.savefig(path)
+ if block:
+ plt.show(block=block)
+
+def quick_plot(*args, xerr=None, yerr=None, vlines=None, title=None, xlim=None, ylim=None,
+ xlabel=None, ylabel=None, legend=None, path=None, name=None, show_grid=False,
+ save_fig=False, save_only=False, clear=True, block=False, **kwargs):
+ if title is not None and not isinstance(title, str):
+ illegal_value(title, 'title', 'quick_plot')
+ title = None
+ if xlim is not None and (not isinstance(xlim, (tuple, list)) or len(xlim) != 2):
+ illegal_value(xlim, 'xlim', 'quick_plot')
+ xlim = None
+ if ylim is not None and (not isinstance(ylim, (tuple, list)) or len(ylim) != 2):
+ illegal_value(ylim, 'ylim', 'quick_plot')
+ ylim = None
+ if xlabel is not None and not isinstance(xlabel, str):
+ illegal_value(xlabel, 'xlabel', 'quick_plot')
+ xlabel = None
+ if ylabel is not None and not isinstance(ylabel, str):
+ illegal_value(ylabel, 'ylabel', 'quick_plot')
+ ylabel = None
+ if legend is not None and not isinstance(legend, (tuple, list)):
+ illegal_value(legend, 'legend', 'quick_plot')
+ legend = None
+ if path is not None and not isinstance(path, str):
+ illegal_value(path, 'path', 'quick_plot')
+ return
+ if not isinstance(show_grid, bool):
+ illegal_value(show_grid, 'show_grid', 'quick_plot')
+ return
+ if not isinstance(save_fig, bool):
+ illegal_value(save_fig, 'save_fig', 'quick_plot')
+ return
+ if not isinstance(save_only, bool):
+ illegal_value(save_only, 'save_only', 'quick_plot')
+ return
+ if not isinstance(clear, bool):
+ illegal_value(clear, 'clear', 'quick_plot')
+ return
+ if not isinstance(block, bool):
+ illegal_value(block, 'block', 'quick_plot')
+ return
+ if title is None:
+ title = 'quick plot'
+# else:
+# title = re.sub(r"\s+", '_', title)
+ if name is None:
+ ttitle = re.sub(r"\s+", '_', title)
+ if path is None:
+ path = f'{ttitle}.png'
+ else:
+ path = f'{path}/{ttitle}.png'
+ else:
+ if path is None:
+ path = name
+ else:
+ path = f'{path}/{name}'
+ if clear:
+ try:
+ plt.close(fig=title)
+ except:
+ pass
+ args = unwrap_tuple(args)
+ if depth_tuple(args) > 1 and (xerr is not None or yerr is not None):
+ logger.warning('Error bars ignored for multiple curves')
+ if not save_only:
+ if block:
+ plt.ioff()
+ else:
+ plt.ion()
+ plt.figure(title)
+ if depth_tuple(args) > 1:
+ for y in args:
+ plt.plot(*y, **kwargs)
+ else:
+ if xerr is None and yerr is None:
+ plt.plot(*args, **kwargs)
+ else:
+ plt.errorbar(*args, xerr=xerr, yerr=yerr, **kwargs)
+ if vlines is not None:
+ if isinstance(vlines, (int, float)):
+ vlines = [vlines]
+ for v in vlines:
+ plt.axvline(v, color='r', linestyle='--', **kwargs)
+# if vlines is not None:
+# for s in tuple(([x, x], list(plt.gca().get_ylim())) for x in vlines):
+# plt.plot(*s, color='red', **kwargs)
+ if xlim is not None:
+ plt.xlim(xlim)
+ if ylim is not None:
+ plt.ylim(ylim)
+ if xlabel is not None:
+ plt.xlabel(xlabel)
+ if ylabel is not None:
+ plt.ylabel(ylabel)
+ if show_grid:
+ ax = plt.gca()
+ ax.grid(color='k')#, linewidth=1)
+ if legend is not None:
+ plt.legend(legend)
+ if save_only:
+ plt.savefig(path)
+ plt.close(fig=title)
+ else:
+ if save_fig:
+ plt.savefig(path)
+ if block:
+ plt.show(block=block)
+
+def select_array_bounds(a, x_low=None, x_upp=None, num_x_min=None, ask_bounds=False,
+ title='select array bounds'):
+ """Interactively select the lower and upper data bounds for a numpy array.
+ """
+ if isinstance(a, (tuple, list)):
+ a = np.array(a)
+ if not isinstance(a, np.ndarray) or a.ndim != 1:
+ illegal_value(a, 'array type or dimension', 'select_array_bounds')
+ return(None)
+ len_a = len(a)
+ if num_x_min is None:
+ num_x_min = 1
+ else:
+ if num_x_min < 2 or num_x_min > len_a:
+ logger.warning('Invalid value for num_x_min in select_array_bounds, input ignored')
+ num_x_min = 1
+
+ # Ask to use current bounds
+ if ask_bounds and (x_low is not None or x_upp is not None):
+ if x_low is None:
+ x_low = 0
+ if not is_int(x_low, ge=0, le=len_a-num_x_min):
+ illegal_value(x_low, 'x_low', 'select_array_bounds')
+ return(None)
+ if x_upp is None:
+ x_upp = len_a
+ if not is_int(x_upp, ge=x_low+num_x_min, le=len_a):
+ illegal_value(x_upp, 'x_upp', 'select_array_bounds')
+ return(None)
+ quick_plot((range(len_a), a), vlines=(x_low,x_upp), title=title)
+ if not input_yesno(f'\nCurrent array bounds: [{x_low}, {x_upp}] '+
+ 'use these values (y/n)?', 'y'):
+ x_low = None
+ x_upp = None
+ else:
+ clear_plot(title)
+ return(x_low, x_upp)
+
+ if x_low is None:
+ x_min = 0
+ x_max = len_a
+ x_low_max = len_a-num_x_min
+ while True:
+ quick_plot(range(x_min, x_max), a[x_min:x_max], title=title)
+ zoom_flag = input_yesno('Set lower data bound (y) or zoom in (n)?', 'y')
+ if zoom_flag:
+ x_low = input_int(' Set lower data bound', ge=0, le=x_low_max)
+ break
+ else:
+ x_min = input_int(' Set lower zoom index', ge=0, le=x_low_max)
+ x_max = input_int(' Set upper zoom index', ge=x_min+1, le=x_low_max+1)
+ else:
+ if not is_int(x_low, ge=0, le=len_a-num_x_min):
+ illegal_value(x_low, 'x_low', 'select_array_bounds')
+ return(None)
+ if x_upp is None:
+ x_min = x_low+num_x_min
+ x_max = len_a
+ x_upp_min = x_min
+ while True:
+ quick_plot(range(x_min, x_max), a[x_min:x_max], title=title)
+ zoom_flag = input_yesno('Set upper data bound (y) or zoom in (n)?', 'y')
+ if zoom_flag:
+ x_upp = input_int(' Set upper data bound', ge=x_upp_min, le=len_a)
+ break
+ else:
+ x_min = input_int(' Set lower zoom index', ge=x_upp_min, le=len_a-1)
+ x_max = input_int(' Set upper zoom index', ge=x_min+1, le=len_a)
+ else:
+ if not is_int(x_upp, ge=x_low+num_x_min, le=len_a):
+ illegal_value(x_upp, 'x_upp', 'select_array_bounds')
+ return(None)
+ print(f'lower bound = {x_low} (inclusive)\nupper bound = {x_upp} (exclusive)')
+ quick_plot((range(len_a), a), vlines=(x_low,x_upp), title=title)
+ if not input_yesno('Accept these bounds (y/n)?', 'y'):
+ x_low, x_upp = select_array_bounds(a, None, None, num_x_min, title=title)
+ clear_plot(title)
+ return(x_low, x_upp)
+
+def select_image_bounds(a, axis, low=None, upp=None, num_min=None, title='select array bounds',
+ raise_error=False):
+ """Interactively select the lower and upper data bounds for a 2D numpy array.
+ """
+ a = np.asarray(a)
+ if a.ndim != 2:
+ illegal_value(a.ndim, 'array dimension', location='select_image_bounds',
+ raise_error=raise_error)
+ return(None)
+ if axis < 0 or axis >= a.ndim:
+ illegal_value(axis, 'axis', location='select_image_bounds', raise_error=raise_error)
+ return(None)
+ low_save = low
+ upp_save = upp
+ num_min_save = num_min
+ if num_min is None:
+ num_min = 1
+ else:
+ if num_min < 2 or num_min > a.shape[axis]:
+ logger.warning('Invalid input for num_min in select_image_bounds, input ignored')
+ num_min = 1
+ if low is None:
+ min_ = 0
+ max_ = a.shape[axis]
+ low_max = a.shape[axis]-num_min
+ while True:
+ if axis:
+ quick_imshow(a[:,min_:max_], title=title, aspect='auto',
+ extent=[min_,max_,a.shape[0],0])
+ else:
+ quick_imshow(a[min_:max_,:], title=title, aspect='auto',
+ extent=[0,a.shape[1], max_,min_])
+ zoom_flag = input_yesno('Set lower data bound (y) or zoom in (n)?', 'y')
+ if zoom_flag:
+ low = input_int(' Set lower data bound', ge=0, le=low_max)
+ break
+ else:
+ min_ = input_int(' Set lower zoom index', ge=0, le=low_max)
+ max_ = input_int(' Set upper zoom index', ge=min_+1, le=low_max+1)
+ else:
+ if not is_int(low, ge=0, le=a.shape[axis]-num_min):
+ illegal_value(low, 'low', location='select_image_bounds', raise_error=raise_error)
+ return(None)
+ if upp is None:
+ min_ = low+num_min
+ max_ = a.shape[axis]
+ upp_min = min_
+ while True:
+ if axis:
+ quick_imshow(a[:,min_:max_], title=title, aspect='auto',
+ extent=[min_,max_,a.shape[0],0])
+ else:
+ quick_imshow(a[min_:max_,:], title=title, aspect='auto',
+ extent=[0,a.shape[1], max_,min_])
+ zoom_flag = input_yesno('Set upper data bound (y) or zoom in (n)?', 'y')
+ if zoom_flag:
+ upp = input_int(' Set upper data bound', ge=upp_min, le=a.shape[axis])
+ break
+ else:
+ min_ = input_int(' Set lower zoom index', ge=upp_min, le=a.shape[axis]-1)
+ max_ = input_int(' Set upper zoom index', ge=min_+1, le=a.shape[axis])
+ else:
+ if not is_int(upp, ge=low+num_min, le=a.shape[axis]):
+ illegal_value(upp, 'upp', location='select_image_bounds', raise_error=raise_error)
+ return(None)
+ bounds = (low, upp)
+ a_tmp = np.copy(a)
+ a_tmp_max = a.max()
+ if axis:
+ a_tmp[:,bounds[0]] = a_tmp_max
+ a_tmp[:,bounds[1]-1] = a_tmp_max
+ else:
+ a_tmp[bounds[0],:] = a_tmp_max
+ a_tmp[bounds[1]-1,:] = a_tmp_max
+ print(f'lower bound = {low} (inclusive)\nupper bound = {upp} (exclusive)')
+ quick_imshow(a_tmp, title=title, aspect='auto')
+ del a_tmp
+ if not input_yesno('Accept these bounds (y/n)?', 'y'):
+ bounds = select_image_bounds(a, axis, low=low_save, upp=upp_save, num_min=num_min_save,
+ title=title)
+ return(bounds)
+
+def select_one_image_bound(a, axis, bound=None, bound_name=None, title='select array bounds',
+ default='y', raise_error=False):
+ """Interactively select a data boundary for a 2D numpy array.
+ """
+ a = np.asarray(a)
+ if a.ndim != 2:
+ illegal_value(a.ndim, 'array dimension', location='select_one_image_bound',
+ raise_error=raise_error)
+ return(None)
+ if axis < 0 or axis >= a.ndim:
+ illegal_value(axis, 'axis', location='select_one_image_bound', raise_error=raise_error)
+ return(None)
+ if bound_name is None:
+ bound_name = 'data bound'
+ if bound is None:
+ min_ = 0
+ max_ = a.shape[axis]
+ bound_max = a.shape[axis]-1
+ while True:
+ if axis:
+ quick_imshow(a[:,min_:max_], title=title, aspect='auto',
+ extent=[min_,max_,a.shape[0],0])
+ else:
+ quick_imshow(a[min_:max_,:], title=title, aspect='auto',
+ extent=[0,a.shape[1], max_,min_])
+ zoom_flag = input_yesno(f'Set {bound_name} (y) or zoom in (n)?', 'y')
+ if zoom_flag:
+ bound = input_int(f' Set {bound_name}', ge=0, le=bound_max)
+ clear_imshow(title)
+ break
+ else:
+ min_ = input_int(' Set lower zoom index', ge=0, le=bound_max)
+ max_ = input_int(' Set upper zoom index', ge=min_+1, le=bound_max+1)
+
+ elif not is_int(bound, ge=0, le=a.shape[axis]-1):
+ illegal_value(bound, 'bound', location='select_one_image_bound', raise_error=raise_error)
+ return(None)
+ else:
+ print(f'Current {bound_name} = {bound}')
+ a_tmp = np.copy(a)
+ a_tmp_max = a.max()
+ if axis:
+ a_tmp[:,bound] = a_tmp_max
+ else:
+ a_tmp[bound,:] = a_tmp_max
+ quick_imshow(a_tmp, title=title, aspect='auto')
+ del a_tmp
+ if not input_yesno(f'Accept this {bound_name} (y/n)?', default):
+ bound = select_one_image_bound(a, axis, bound_name=bound_name, title=title)
+ clear_imshow(title)
+ return(bound)
+
+
+class Config:
+ """Base class for processing a config file or dictionary.
+ """
+ def __init__(self, config_file=None, config_dict=None):
+ self.config = {}
+ self.load_flag = False
+ self.suffix = None
+
+ # Load config file
+ if config_file is not None and config_dict is not None:
+ logger.warning('Ignoring config_dict (both config_file and config_dict are specified)')
+ if config_file is not None:
+ self.load_file(config_file)
+ elif config_dict is not None:
+ self.load_dict(config_dict)
+
+ def load_file(self, config_file):
+ """Load a config file.
+ """
+ if self.load_flag:
+ logger.warning('Overwriting any previously loaded config file')
+ self.config = {}
+
+ # Ensure config file exists
+ if not os.path.isfile(config_file):
+ logger.error(f'Unable to load {config_file}')
+ return
+
+ # Load config file (for now for Galaxy, allow .dat extension)
+ self.suffix = os.path.splitext(config_file)[1]
+ if self.suffix == '.yml' or self.suffix == '.yaml' or self.suffix == '.dat':
+ with open(config_file, 'r') as f:
+ self.config = safe_load(f)
+ elif self.suffix == '.txt':
+ with open(config_file, 'r') as f:
+ lines = f.read().splitlines()
+ self.config = {item[0].strip():literal_eval(item[1].strip()) for item in
+ [line.split('#')[0].split('=') for line in lines if '=' in line.split('#')[0]]}
+ else:
+ illegal_value(self.suffix, 'config file extension', 'Config.load_file')
+
+ # Make sure config file was correctly loaded
+ if isinstance(self.config, dict):
+ self.load_flag = True
+ else:
+ logger.error(f'Unable to load dictionary from config file: {config_file}')
+ self.config = {}
+
+ def load_dict(self, config_dict):
+ """Takes a dictionary and places it into self.config.
+ """
+ if self.load_flag:
+ logger.warning('Overwriting the previously loaded config file')
+
+ if isinstance(config_dict, dict):
+ self.config = config_dict
+ self.load_flag = True
+ else:
+ illegal_value(config_dict, 'dictionary config object', 'Config.load_dict')
+ self.config = {}
+
+ def save_file(self, config_file):
+ """Save the config file (as a yaml file only right now).
+ """
+ suffix = os.path.splitext(config_file)[1]
+ if suffix != '.yml' and suffix != '.yaml':
+ illegal_value(suffix, 'config file extension', 'Config.save_file')
+ return
+
+ # Check if config file exists
+ if os.path.isfile(config_file):
+ logger.info(f'Updating {config_file}')
+ else:
+ logger.info(f'Saving {config_file}')
+
+ # Save config file
+ with open(config_file, 'w') as f:
+ safe_dump(self.config, f)
+
+ def validate(self, pars_required, pars_missing=None):
+ """Returns False if any required keys are missing.
+ """
+ if not self.load_flag:
+ logger.error('Load a config file prior to calling Config.validate')
+ return(False)
+
+ def validate_nested_pars(config, par):
+ par_levels = par.split(':')
+ first_level_par = par_levels[0]
+ try:
+ first_level_par = int(first_level_par)
+ except ValueError:
+ pass
+ try:
+ next_level_config = config[first_level_par]
+ if len(par_levels) > 1:
+ next_level_par = ':'.join(par_levels[1:])
+ return(validate_nested_pars(next_level_config, next_level_par))
+ else:
+ return(True)
+ except (KeyError, IndexError, TypeError):
+ return(False)
+
+ missing = [p for p in pars_required if not validate_nested_pars(self.config, p)]
+ if pars_missing is not None:
+ pars_missing.extend(missing)
+ if len(missing) > 0:
+ logger.error(f'Missing item(s) in configuration: {", ".join(missing)}')
+ return(False)
+ else:
+ return(True)
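+
+# Illustrative usage sketch of Config.validate (config values assumed): required
+# parameters are colon-separated key paths, and integer segments index into lists:
+#   config = Config(config_dict={'scan': {'images': [{'index': 1}]}})
+#   config.validate(['scan:images:0:index']) -> True
+#   config.validate(['scan:detector'])       -> False (and logs the missing key)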
diff -r 000000000000 -r 98e23dff1de2 tomo_macros.xml
--- /dev/null Thu Jan 01 00:00:00 1970 +0000
+++ b/tomo_macros.xml Tue Mar 21 16:22:42 2023 +0000
@@ -0,0 +1,30 @@
+<macros>
+ <xml name="requirements">
+ <requirements>
+ <requirement type="package">lmfit</requirement>
+ <requirement type="package">matplotlib</requirement>
+ <requirement type="package">nexusformat</requirement>
+ <requirement type="package">tomopy</requirement>
+ </requirements>
+ </xml>
+ <xml name="citations">
+ <citations>
+ <citation type="bibtex">
+@misc{github_files,
+ author = {Verberg, Rolf},
+ year = {2022},
+ title = {Tomo Reconstruction},
+}
+ </citation>
+ </citations>
+ </xml>
+</macros>
diff -r 000000000000 -r 98e23dff1de2 tomo_reconstruct.py
--- /dev/null Thu Jan 01 00:00:00 1970 +0000
+++ b/tomo_reconstruct.py Tue Mar 21 16:22:42 2023 +0000
@@ -0,0 +1,113 @@
+#!/usr/bin/env python3
+
+import logging
+
+import argparse
+import io
+import pathlib
+import sys
+#import tracemalloc
+
+from workflow.run_tomo import Tomo
+
+#from memory_profiler import profile
+#@profile
+def __main__():
+ # Parse command line arguments
+ parser = argparse.ArgumentParser(
+ description='Perform a tomography reconstruction')
+ parser.add_argument('-i', '--input_file',
+ required=True,
+ type=pathlib.Path,
+ help='''Full or relative path to the input file (in Nexus format).''')
+ parser.add_argument('-c', '--center_file',
+ required=True,
+ type=pathlib.Path,
+ help='''Full or relative path to the center info file (in yaml format).''')
+ parser.add_argument('-o', '--output_file',
+ required=False,
+ type=pathlib.Path,
+ help='''Full or relative path to the output file (in Nexus format).''')
+ parser.add_argument('--galaxy_flag',
+ action='store_true',
+ help='''Use this flag to run the scripts as a galaxy tool.''')
+ parser.add_argument('-l', '--log',
+# type=argparse.FileType('w'),
+ default=sys.stdout,
+ help='Logging stream or filename')
+ parser.add_argument('--log_level',
+ choices=logging._nameToLevel.keys(),
+ default='INFO',
+ help='''Specify a preferred logging level.''')
+ parser.add_argument('--x_bounds',
+ required=False,
+ nargs=2,
+ type=int,
+ help='''Boundaries of reconstructed images in x-direction.''')
+ parser.add_argument('--y_bounds',
+ required=False,
+ nargs=2,
+ type=int,
+ help='''Boundaries of reconstructed images in y-direction.''')
+ args = parser.parse_args()
+
+ # Set log configuration
+ # When logging to file, the stdout log level defaults to WARNING
+ logging_format = '%(asctime)s : %(levelname)s - %(module)s : %(funcName)s - %(message)s'
+ level = logging.getLevelName(args.log_level)
+ if args.log is sys.stdout:
+ logging.basicConfig(format=logging_format, level=level, force=True,
+ handlers=[logging.StreamHandler()])
+ else:
+ if isinstance(args.log, str):
+ logging.basicConfig(filename=f'{args.log}', filemode='w',
+ format=logging_format, level=level, force=True)
+ elif isinstance(args.log, io.TextIOWrapper):
+ logging.basicConfig(filemode='w', format=logging_format, level=level,
+ stream=args.log, force=True)
+ else:
+ raise ValueError(f'Invalid argument --log: {args.log}')
+ stream_handler = logging.StreamHandler()
+ logging.getLogger().addHandler(stream_handler)
+ stream_handler.setLevel(logging.WARNING)
+ stream_handler.setFormatter(logging.Formatter(logging_format))
+
+ # Starting memory monitoring
+# tracemalloc.start()
+
+ # Log command line arguments
+ logging.info(f'input_file = {args.input_file}')
+ logging.info(f'center_file = {args.center_file}')
+ logging.info(f'output_file = {args.output_file}')
+ logging.info(f'galaxy_flag = {args.galaxy_flag}')
+ logging.debug(f'log = {args.log}')
+ logging.debug(f'is log stdout? {args.log is sys.stdout}')
+ logging.debug(f'log_level = {args.log_level}')
+ logging.info(f'x_bounds = {args.x_bounds}')
+ logging.info(f'y_bounds = {args.y_bounds}')
+
+ # Instantiate Tomo object
+ tomo = Tomo(galaxy_flag=args.galaxy_flag)
+
+ # Read input file
+ data = tomo.read(args.input_file)
+
+ # Read center data
+ center_data = tomo.read(args.center_file)
+
+ # Find the calibrated center axis info
+ data = tomo.reconstruct_data(data, center_data, x_bounds=args.x_bounds, y_bounds=args.y_bounds)
+
+ # Write output file
+ data = tomo.write(data, args.output_file)
+
+ # Displaying memory usage
+# logging.info(f'Memory usage: {tracemalloc.get_traced_memory()}')
+
+ # stopping memory monitoring
+# tracemalloc.stop()
+
+ logging.info('Completed tomography reconstruction')
+
+
+if __name__ == "__main__":
+ __main__()
diff -r 000000000000 -r 98e23dff1de2 tomo_reconstruct.xml
--- /dev/null Thu Jan 01 00:00:00 1970 +0000
+++ b/tomo_reconstruct.xml Tue Mar 21 16:22:42 2023 +0000
@@ -0,0 +1,34 @@
+<!-- Galaxy tool wrapper for the tomography reconstruction step. The XML markup of
+     this file did not survive extraction; the recoverable content is its description,
+     "Perform a tomography reconstruction", and its macro import, tomo_macros.xml. -->
diff -r 000000000000 -r 98e23dff1de2 workflow/__main__.py
--- /dev/null Thu Jan 01 00:00:00 1970 +0000
+++ b/workflow/__main__.py Tue Mar 21 16:22:42 2023 +0000
@@ -0,0 +1,236 @@
+#!/usr/bin/env python3
+
+import logging
+logger = logging.getLogger(__name__)
+
+import argparse
+import io
+import pathlib
+import sys
+
+from .models import TomoWorkflow as Workflow
+try:
+ from deepdiff import DeepDiff
+except:
+ pass
+
+parser = argparse.ArgumentParser(description='''Operate on representations of
+ Tomo data workflows saved to files.''')
+parser.add_argument('-l', '--log',
+# type=argparse.FileType('w'),
+ default=sys.stdout,
+ help='Logging stream or filename')
+parser.add_argument('--log_level',
+ choices=logging._nameToLevel.keys(),
+ default='INFO',
+ help='''Specify a preferred logging level.''')
+subparsers = parser.add_subparsers(title='subcommands', required=True)#, dest='command')
+
+
+# CONSTRUCT
+def construct(args:list) -> None:
+ if args.template_file is not None:
+ wf = Workflow.construct_from_file(args.template_file)
+ wf.cli()
+ else:
+ wf = Workflow.construct_from_cli()
+ wf.write_to_file(args.output_file, force_overwrite=args.force_overwrite)
+
+construct_parser = subparsers.add_parser('construct', help='''Construct a valid Tomo
+ workflow representation on the command line and save it to a file. Optionally use
+ an existing file as a template and/or perform the reconstruction or transfer to Galaxy.''')
+construct_parser.set_defaults(func=construct)
+construct_parser.add_argument('-t', '--template_file',
+ type=pathlib.Path,
+ required=False,
+ help='''Full or relative template file path for the constructed workflow.''')
+construct_parser.add_argument('-f', '--force_overwrite',
+ action='store_true',
+ help='''Use this flag to overwrite the output file if it already exists.''')
+construct_parser.add_argument('-o', '--output_file',
+ type=pathlib.Path,
+ help='''Full or relative file path to which the constructed workflow will be written.''')
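+
+# A hypothetical example (file names are placeholders, assuming the workflow package
+# is importable from the working directory):
+#   python -m workflow construct -t template.yaml -o workflow.nxs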
+
+
+# VALIDATE
+def validate(args:list) -> bool:
+ try:
+ wf = Workflow.construct_from_file(args.input_file)
+ logger.info(f'Success: {args.input_file} represents a valid Tomo workflow configuration.')
+ return(True)
+ except BaseException as e:
+ logger.error(f'{e.__class__.__name__}: {str(e)}')
+ logger.info(f'''Failure: {args.input_file} does not represent a valid Tomo workflow
+ configuration.''')
+ return(False)
+
+validate_parser = subparsers.add_parser('validate',
+ help='''Validate a file as a representation of a Tomo workflow (this is most useful
+ after a .yaml file has been manually edited).''')
+validate_parser.set_defaults(func=validate)
+validate_parser.add_argument('input_file',
+ type=pathlib.Path,
+ help='''Full or relative file path to validate as a Tomo workflow.''')
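+
+# A hypothetical example (file name is a placeholder):
+#   python -m workflow validate workflow.yaml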
+
+
+# CONVERT
+def convert(args:list) -> None:
+ wf = Workflow.construct_from_file(args.input_file)
+ wf.write_to_file(args.output_file, force_overwrite=args.force_overwrite)
+
+convert_parser = subparsers.add_parser('convert', help='''Convert one Tomo workflow
+ representation to another. File format of both input and output files will be
+ automatically determined from the files' extensions.''')
+convert_parser.set_defaults(func=convert)
+convert_parser.add_argument('-f', '--force_overwrite',
+ action='store_true',
+ help='''Use this flag to overwrite the output file if it already exists.''')
+convert_parser.add_argument('-i', '--input_file',
+ type=pathlib.Path,
+ required=True,
+ help='''Full or relative input file path to be converted.''')
+convert_parser.add_argument('-o', '--output_file',
+ type=pathlib.Path,
+ required=True,
+ help='''Full or relative file path to which the converted input will be written.''')
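+
+# A hypothetical example converting a yaml representation to NeXus (file names are placeholders):
+#   python -m workflow convert -i workflow.yaml -o workflow.nxs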
+
+
+# DIFF / COMPARE
+def diff(args:list) -> bool:
+ raise ValueError('diff not tested')
+# wf1 = Workflow.construct_from_file(args.file1).dict_for_yaml()
+# wf2 = Workflow.construct_from_file(args.file2).dict_for_yaml()
+# diff = DeepDiff(wf1,wf2,
+# ignore_order_func=lambda level:'independent_dimensions' not in level.path(),
+# report_repetition=True,
+# ignore_string_type_changes=True,
+# ignore_numeric_type_changes=True)
+# diff_report = diff.pretty()
+# if len(diff_report) > 0:
+# logger.info(f'The configurations in {args.file1} and {args.file2} are not identical.')
+# print(diff_report)
+# return(True)
+# else:
+# logger.info(f'The configurations in {args.file1} and {args.file2} are identical.')
+# return(False)
+
+diff_parser = subparsers.add_parser('diff', aliases=['compare'], help='''Print a comparison of
+ two Tomo workflow representations stored in files. The files may have different formats.''')
+diff_parser.set_defaults(func=diff)
+diff_parser.add_argument('file1',
+ type=pathlib.Path,
+ help='''Full or relative path to the first file for comparison.''')
+diff_parser.add_argument('file2',
+ type=pathlib.Path,
+ help='''Full or relative path to the second file for comparison.''')
+
+
+# LINK TO GALAXY
+def link_to_galaxy(args:list) -> None:
+ from .link_to_galaxy import link_to_galaxy
+ link_to_galaxy(args.input_file, galaxy=args.galaxy, user=args.user,
+ password=args.password, api_key=args.api_key)
+
+link_parser = subparsers.add_parser('link_to_galaxy', help='''Construct a Galaxy history and link
+ to an existing Tomo workflow representation in a NeXus file.''')
+link_parser.set_defaults(func=link_to_galaxy)
+link_parser.add_argument('-i', '--input_file',
+ type=pathlib.Path,
+ required=True,
+ help='''Full or relative input file path to the existing Tomo workflow representation as
+ a NeXus file.''')
+link_parser.add_argument('-g', '--galaxy',
+ required=True,
+ help='Target Galaxy instance URL/IP address')
+link_parser.add_argument('-u', '--user',
+ default=None,
+ help='Galaxy user email address')
+link_parser.add_argument('-p', '--password',
+ default=None,
+ help='Password for the Galaxy user')
+link_parser.add_argument('-a', '--api_key',
+ default=None,
+ help='Galaxy admin user API key (required if not defined in the tools list file)')
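+
+# A hypothetical example (URL and key are placeholders):
+#   python -m workflow link_to_galaxy -i workflow.nxs -g https://galaxy.example.org -a $API_KEY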
+
+
+# RUN THE RECONSTRUCTION
+def run_tomo(args:list) -> None:
+ from .run_tomo import run_tomo
+ run_tomo(args.input_file, args.output_file, args.modes, center_file=args.center_file,
+ num_core=args.num_core, output_folder=args.output_folder, save_figs=args.save_figs)
+
+tomo_parser = subparsers.add_parser('run_tomo', help='''Construct and add reconstructed tomography
+ data to an existing Tomo workflow representation in a NeXus file.''')
+tomo_parser.set_defaults(func=run_tomo)
+tomo_parser.add_argument('-i', '--input_file',
+ required=True,
+ type=pathlib.Path,
+ help='''Full or relative input file path containing raw and/or reduced data.''')
+tomo_parser.add_argument('-o', '--output_file',
+ required=True,
+ type=pathlib.Path,
+ help='''Full or relative output file path for the processed data.''')
+tomo_parser.add_argument('-c', '--center_file',
+ type=pathlib.Path,
+ help='''Full or relative input file path containing the rotation axis centers info.''')
+#tomo_parser.add_argument('-f', '--force_overwrite',
+# action='store_true',
+# help='''Use this flag to overwrite any existing reduced data.''')
+tomo_parser.add_argument('-n', '--num_core',
+ type=int,
+ default=-1,
+ help='''Specify the number of processors to use.''')
+tomo_parser.add_argument('--output_folder',
+ type=pathlib.Path,
+ default='.',
+ help='Full or relative path to an output folder')
+tomo_parser.add_argument('-s', '--save_figs',
+ choices=['yes', 'no', 'only'],
+ default='no',
+ help='''Specify whether to display and save figures ('yes'), display them without
+ saving ('no'), or save them without displaying ('only').''')
+tomo_parser.add_argument('--reduce_data',
+ dest='modes',
+ const='reduce_data',
+ action='append_const',
+ help='''Use this flag to create and add reduced data to the input file.''')
+tomo_parser.add_argument('--find_center',
+ dest='modes',
+ const='find_center',
+ action='append_const',
+ help='''Use this flag to find and add the calibrated center axis info to the input file.''')
+tomo_parser.add_argument('--reconstruct_data',
+ dest='modes',
+ const='reconstruct_data',
+ action='append_const',
+ help='''Use this flag to create and add reconstructed data to the input file.''')
+tomo_parser.add_argument('--combine_data',
+ dest='modes',
+ const='combine_data',
+ action='append_const',
+ help='''Use this flag to combine reconstructed data and add it to the input file.''')
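+
+# A hypothetical end-to-end example (file names are placeholders):
+#   python -m workflow run_tomo -i map.nxs -o recon.nxs -c centers.yaml --reduce_data --reconstruct_data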
+
+
+if __name__ == '__main__':
+ args = parser.parse_args(sys.argv[1:])
+
+ # Set log configuration
+ # When logging to file, the stdout log level defaults to WARNING
+ logging_format = '%(asctime)s : %(levelname)s - %(module)s : %(funcName)s - %(message)s'
+ level = logging.getLevelName(args.log_level)
+ if args.log is sys.stdout:
+ logging.basicConfig(format=logging_format, level=level, force=True,
+ handlers=[logging.StreamHandler()])
+ else:
+ if isinstance(args.log, str):
+ logging.basicConfig(filename=f'{args.log}', filemode='w',
+ format=logging_format, level=level, force=True)
+ elif isinstance(args.log, io.TextIOWrapper):
+ logging.basicConfig(filemode='w', format=logging_format, level=level,
+ stream=args.log, force=True)
+ else:
+ raise ValueError(f'Invalid argument --log: {args.log}')
+ stream_handler = logging.StreamHandler()
+ logging.getLogger().addHandler(stream_handler)
+ stream_handler.setLevel(logging.WARNING)
+ stream_handler.setFormatter(logging.Formatter(logging_format))
+
+ args.func(args)
diff -r 000000000000 -r 98e23dff1de2 workflow/__version__.py
--- /dev/null Thu Jan 01 00:00:00 1970 +0000
+++ b/workflow/__version__.py Tue Mar 21 16:22:42 2023 +0000
@@ -0,0 +1,1 @@
+__version__ = '2022.3.0'
diff -r 000000000000 -r 98e23dff1de2 workflow/link_to_galaxy.py
--- /dev/null Thu Jan 01 00:00:00 1970 +0000
+++ b/workflow/link_to_galaxy.py Tue Mar 21 16:22:42 2023 +0000
@@ -0,0 +1,120 @@
+#!/usr/bin/env python3
+
+import logging
+logger = logging.getLogger(__name__)
+
+from bioblend.galaxy import GalaxyInstance
+from nexusformat.nexus import *
+from os import path
+from yaml import safe_load
+
+from .models import import_scanparser, TomoWorkflow
+
+def get_folder_id(gi, lib_path):
+ # Renamed from 'path' to avoid shadowing os.path imported above
+ library_id = None
+ folder_id = None
+ folder_names = lib_path[1:] if len(lib_path) > 1 else []
+ new_folders = folder_names
+ libs = gi.libraries.get_libraries(name=lib_path[0])
+ if libs:
+ for lib in libs:
+ library_id = lib['id']
+ folders = gi.libraries.get_folders(library_id, folder_id=None, name=None)
+ for i, folder in enumerate(folders):
+ fid = folder['id']
+ details = gi.libraries.show_folder(library_id, fid)
+ library_path = details['library_path']
+ if library_path == folder_names:
+ return (library_id, fid, [])
+ elif len(library_path) < len(folder_names):
+ if library_path == folder_names[:len(library_path)]:
+ nf = folder_names[len(library_path):]
+ if len(nf) < len(new_folders):
+ folder_id = fid
+ new_folders = nf
+ return (library_id, folder_id, new_folders)
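+
+# get_folder_id walks an existing library tree: given a path like
+# ['2022-2', 'btr-id', 'sample'] (all placeholders), it returns the ids of the deepest
+# existing library/folder and the folder names that still need to be created.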
+
+def link_to_galaxy(filename:str, galaxy=None, user=None, password=None, api_key=None) -> None:
+ # Read input file
+ extension = path.splitext(filename)[1]
+# RV yaml input not incorporated yet, since Galaxy can't use pyspec right now
+# if extension == '.yml' or extension == '.yaml':
+# with open(filename, 'r') as f:
+# data = safe_load(f)
+# elif extension == '.nxs':
+ if extension == '.nxs':
+ with NXFile(filename, mode='r') as nxfile:
+ data = nxfile.readfile()
+ else:
+ raise ValueError(f'Invalid filename extension ({extension})')
+ if isinstance(data, dict):
+ # Create Nexus format object from input dictionary
+ wf = TomoWorkflow(**data)
+ if len(wf.sample_maps) > 1:
+ raise ValueError('Multiple sample maps not yet implemented')
+ nxroot = NXroot()
+ for sample_map in wf.sample_maps:
+ import_scanparser(sample_map.station)
+# RV raw data must be included, since Galaxy can't use pyspec right now
+# sample_map.construct_nxentry(nxroot, include_raw_data=False)
+ sample_map.construct_nxentry(nxroot, include_raw_data=True)
+ nxentry = nxroot[nxroot.attrs['default']]
+ elif isinstance(data, NXroot):
+ nxentry = data[data.attrs['default']]
+ else:
+ raise ValueError(f'Invalid input file data ({data})')
+
+ # Get a Galaxy instance
+ if user is not None and password is not None:
+ gi = GalaxyInstance(url=galaxy, email=user, password=password)
+ elif api_key is not None:
+ gi = GalaxyInstance(url=galaxy, key=api_key)
+ else:
+ raise SystemExit('Please specify either a valid Galaxy username/password or an API key.')
+
+ cycle = nxentry.instrument.source.attrs['cycle']
+ btr = nxentry.instrument.source.attrs['btr']
+ sample = nxentry.sample.name
+
+ # Create a Galaxy work library/folder
+ # Combine the cycle, BTR and sample name as the base library name
+ lib_path = [p.strip() for p in f'{cycle}/{btr}/{sample}'.split('/')]
+ (library_id, folder_id, folder_names) = get_folder_id(gi, lib_path)
+ if not library_id:
+ library = gi.libraries.create_library(lib_path[0], description=None, synopsis=None)
+ library_id = library['id']
+# if user:
+# gi.libraries.set_library_permissions(library_id, access_ids=user,
+# manage_ids=user, modify_ids=user)
+ logger.info(f'Created Library:\n{library}')
+ if len(folder_names):
+ folder = gi.libraries.create_folder(library_id, folder_names[0], description=None,
+ base_folder_id=folder_id)[0]
+ folder_id = folder['id']
+ logger.info(f'Created Folder:\n{folder}')
+ folder_names.pop(0)
+ while len(folder_names):
+ folder = gi.folders.create_folder(folder['id'], folder_names[0],
+ description=None)
+ folder_id = folder['id']
+ logger.info(f'Created Folder:\n{folder}')
+ folder_names.pop(0)
+
+ # Create a sym link for the Nexus file
+ dataset = gi.libraries.upload_from_galaxy_filesystem(library_id, path.abspath(filename),
+ folder_id=folder_id, file_type='auto', dbkey='?', link_data_only='link_to_files',
+ roles='', preserve_dirs=False, tag_using_filenames=False, tags=None)[0]
+
+ # Make a history for the data
+ history_name = f'tomo {btr} {sample}'
+ history = gi.histories.create_history(name=history_name)
+ logger.info(f'Created history:\n{history}')
+ history_id = history['id']
+ gi.histories.copy_dataset(history_id, dataset['id'], source='library')
+
+# TODO add option to either
+# get a URL to share the history
+# or to share with specific users
+# This might require using:
+# https://bioblend.readthedocs.io/en/latest/api_docs/galaxy/docs.html#using-bioblend-for-raw-api-calls
+
diff -r 000000000000 -r 98e23dff1de2 workflow/models.py
--- /dev/null Thu Jan 01 00:00:00 1970 +0000
+++ b/workflow/models.py Tue Mar 21 16:22:42 2023 +0000
@@ -0,0 +1,1096 @@
+#!/usr/bin/env python3
+
+import logging
+logger = logging.getLogger(__name__)
+
+import numpy as np
+import os
+import yaml
+
+from functools import cache
+from pathlib import PosixPath
+from pydantic import BaseModel as PydanticBaseModel
+from pydantic import validator, ValidationError, conint, confloat, constr, conlist, FilePath, \
+ PrivateAttr
+from nexusformat.nexus import *
+from time import time
+from typing import Optional, Literal
+from typing_extensions import TypedDict
+try:
+ from pyspec.file.spec import FileSpec
+except:
+ pass
+
+try:
+ from msnctools.general import is_int, is_num, input_int, input_int_list, input_num, \
+ input_yesno, input_menu, index_nearest, string_to_list, file_exists_and_readable
+except:
+ from general import is_int, is_num, input_int, input_int_list, input_num, \
+ input_yesno, input_menu, index_nearest, string_to_list, file_exists_and_readable
+
+
+def import_scanparser(station):
+ if station in ('id1a3', 'id3a'):
+ try:
+ from msnctools.scanparsers import SMBRotationScanParser
+ globals()['ScanParser'] = SMBRotationScanParser
+ except:
+ try:
+ from scanparsers import SMBRotationScanParser
+ globals()['ScanParser'] = SMBRotationScanParser
+ except:
+ pass
+ elif station in ('id3b',):
+ try:
+ from msnctools.scanparsers import FMBRotationScanParser
+ globals()['ScanParser'] = FMBRotationScanParser
+ except:
+ try:
+ from scanparsers import FMBRotationScanParser
+ globals()['ScanParser'] = FMBRotationScanParser
+ except:
+ pass
+ else:
+ raise RuntimeError(f'Invalid station: {station}')
+
+@cache
+def get_available_scan_numbers(spec_file:str):
+ scans = FileSpec(spec_file).scans
+ scan_numbers = list(scans.keys())
+ for scan_number in scan_numbers.copy():
+ try:
+ # Keep only the scan numbers that ScanParser can actually parse
+ ScanParser(spec_file, scan_number)
+ except:
+ scan_numbers.remove(scan_number)
+ return(scan_numbers)
+
+@cache
+def get_scanparser(spec_file:str, scan_number:int):
+ if scan_number not in get_available_scan_numbers(spec_file):
+ return(None)
+ else:
+ return(ScanParser(spec_file, scan_number))
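+
+# Both helpers above are memoized with functools.cache, so repeated lookups of the
+# same SPEC file and scan number reuse the parsed objects instead of re-reading the file.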
+
+
+class BaseModel(PydanticBaseModel):
+ class Config:
+ validate_assignment = True
+ arbitrary_types_allowed = True
+
+ @classmethod
+ def construct_from_cli(cls):
+ obj = cls.construct()
+ obj.cli()
+ return(obj)
+
+ @classmethod
+ def construct_from_yaml(cls, filename):
+ try:
+ with open(filename, 'r') as infile:
+ indict = yaml.load(infile, Loader=yaml.CLoader)
+ except:
+ raise ValueError(f'Could not load a dictionary from {filename}')
+ else:
+ obj = cls(**indict)
+ return(obj)
+
+ @classmethod
+ def construct_from_file(cls, filename):
+ file_exists_and_readable(filename)
+ filename = os.path.abspath(filename)
+ fileformat = os.path.splitext(filename)[1]
+ yaml_extensions = ('.yaml','.yml')
+ nexus_extensions = ('.nxs','.nx5','.h5','.hdf5')
+ t0 = time()
+ if fileformat.lower() in yaml_extensions:
+ obj = cls.construct_from_yaml(filename)
+ logger.info(f'Constructed a model from {filename} in {time()-t0:.2f} seconds.')
+ return(obj)
+ elif fileformat.lower() in nexus_extensions:
+ obj = cls.construct_from_nexus(filename)
+ logger.info(f'Constructed a model from {filename} in {time()-t0:.2f} seconds.')
+ return(obj)
+ else:
+ logger.error(f'Unsupported file extension for constructing a model: {fileformat}')
+ raise TypeError(f'Unrecognized file extension: {fileformat}')
+
+ def dict_for_yaml(self, exclude_fields=[]):
+ yaml_dict = {}
+ for field_name in self.__fields__:
+ if field_name in exclude_fields:
+ continue
+ else:
+ field_value = getattr(self, field_name, None)
+ if field_value is not None:
+ if isinstance(field_value, BaseModel):
+ yaml_dict[field_name] = field_value.dict_for_yaml()
+ elif isinstance(field_value,list) and all(isinstance(item,BaseModel)
+ for item in field_value):
+ yaml_dict[field_name] = [item.dict_for_yaml() for item in field_value]
+ elif isinstance(field_value, PosixPath):
+ yaml_dict[field_name] = str(field_value)
+ else:
+ yaml_dict[field_name] = field_value
+ else:
+ continue
+ return(yaml_dict)
+
+ def write_to_yaml(self, filename=None):
+ yaml_dict = self.dict_for_yaml()
+ if filename is None:
+ logger.info('Printing yaml representation here:\n'+
+ f'{yaml.dump(yaml_dict, sort_keys=False)}')
+ else:
+ try:
+ with open(filename, 'w') as outfile:
+ yaml.dump(yaml_dict, outfile, sort_keys=False)
+ logger.info(f'Successfully wrote this model to {filename}')
+ except:
+ logger.error(f'Unknown error -- could not write to {filename} in yaml format.')
+ logger.info('Printing yaml representation here:\n'+
+ f'{yaml.dump(yaml_dict, sort_keys=False)}')
+
+ def write_to_file(self, filename, force_overwrite=False):
+ file_writeable, fileformat = self.output_file_valid(filename,
+ force_overwrite=force_overwrite)
+ if fileformat == 'yaml':
+ if file_writeable:
+ self.write_to_yaml(filename=filename)
+ else:
+ self.write_to_yaml()
+ elif fileformat == 'nexus':
+ if file_writeable:
+ self.write_to_nexus(filename=filename)
+
+ def output_file_valid(self, filename, force_overwrite=False):
+ filename = os.path.abspath(filename)
+ fileformat = os.path.splitext(filename)[1]
+ yaml_extensions = ('.yaml','.yml')
+ nexus_extensions = ('.nxs','.nx5','.h5','.hdf5')
+ if fileformat.lower() not in (*yaml_extensions, *nexus_extensions):
+ return(False, None) # Only yaml and NeXus files allowed for output now.
+ elif fileformat.lower() in yaml_extensions:
+ fileformat = 'yaml'
+ elif fileformat.lower() in nexus_extensions:
+ fileformat = 'nexus'
+ if os.path.isfile(filename):
+ if os.access(filename, os.W_OK):
+ if not force_overwrite:
+ logger.error(f'{filename} will not be overwritten.')
+ return(False, fileformat)
+ else:
+ logger.error(f'Cannot access {filename} for writing.')
+ return(False, fileformat)
+ if os.path.isdir(os.path.dirname(filename)):
+ if os.access(os.path.dirname(filename), os.W_OK):
+ return(True, fileformat)
+ else:
+ logger.error(f'Cannot access {os.path.dirname(filename)} for writing.')
+ return(False, fileformat)
+ else:
+ try:
+ os.makedirs(os.path.dirname(filename))
+ return(True, fileformat)
+ except:
+ logger.error(f'Cannot create {os.path.dirname(filename)} for output.')
+ return(False, fileformat)
+
+ def set_single_attr_cli(self, attr_name, attr_desc='unknown attribute', list_flag=False,
+ **cli_kwargs):
+ if cli_kwargs.get('chain_attr_desc', False):
+ cli_kwargs['attr_desc'] = attr_desc
+ try:
+ attr = getattr(self, attr_name, None)
+ if attr is None:
+ attr = self.__fields__[attr_name].type_.construct()
+ if cli_kwargs.get('chain_attr_desc', False):
+ cli_kwargs['attr_desc'] = attr_desc
+ input_accepted = False
+ while not input_accepted:
+ try:
+ attr.cli(**cli_kwargs)
+ except ValidationError as e:
+ print(e)
+ print(f'Removing {attr_desc} configuration')
+ attr = self.__fields__[attr_name].type_.construct()
+ continue
+ except KeyboardInterrupt as e:
+ raise e
+ except BaseException as e:
+ print(f'{type(e).__name__}: {e}')
+ print(f'Removing {attr_desc} configuration')
+ attr = self.__fields__[attr_name].type_.construct()
+ continue
+ try:
+ setattr(self, attr_name, attr)
+ except ValidationError as e:
+ print(e)
+ except KeyboardInterrupt as e:
+ raise e
+ except BaseException as e:
+ print(f'{type(e).__name__}: {e}')
+ else:
+ input_accepted = True
+ except:
+ input_accepted = False
+ while not input_accepted:
+ attr = getattr(self, attr_name, None)
+ if attr is None:
+ input_value = input(f'Type and enter a value for {attr_desc}: ')
+ else:
+ input_value = input(f'Type and enter a new value for {attr_desc} or press '+
+ f'enter to keep the current one ({attr}): ')
+ if list_flag:
+ input_value = string_to_list(input_value, remove_duplicates=False, sort=False)
+ if len(input_value) == 0:
+ input_value = getattr(self, attr_name, None)
+ try:
+ setattr(self, attr_name, input_value)
+ except ValidationError as e:
+ print(e)
+ except KeyboardInterrupt as e:
+ raise e
+ except BaseException as e:
+ print(f'Unexpected {type(e).__name__}: {e}')
+ else:
+ input_accepted = True
+
+ def set_list_attr_cli(self, attr_name, attr_desc='unknown attribute', **cli_kwargs):
+ if cli_kwargs.get('chain_attr_desc', False):
+ cli_kwargs['attr_desc'] = attr_desc
+ attr = getattr(self, attr_name, None)
+ if attr is not None:
+ # Check existing items
+ for item in attr:
+ item_accepted = False
+ while not item_accepted:
+ item.cli(**cli_kwargs)
+ try:
+ setattr(self, attr_name, attr)
+ except ValidationError as e:
+ print(e)
+ except KeyboardInterrupt as e:
+ raise e
+ except BaseException as e:
+ print(f'{type(e).__name__}: {e}')
+ else:
+ item_accepted = True
+ else:
+ # Initialize list for new attribute & starting item
+ attr = []
+ item = self.__fields__[attr_name].type_.construct()
+ # Append (optional) additional items
+ append = input_yesno(f'Add a {attr_desc} configuration? (y/n)', 'n')
+ while append:
+ attr.append(item.__class__.construct_from_cli())
+ try:
+ setattr(self, attr_name, attr)
+ except ValidationError as e:
+ print(e)
+ print(f'Removing last {attr_desc} configuration from the list')
+ attr.pop()
+ except KeyboardInterrupt as e:
+ raise e
+ except BaseException as e:
+ print(f'{type(e).__name__}: {e}')
+ print(f'Removing last {attr_desc} configuration from the list')
+ attr.pop()
+ else:
+ append = input_yesno(f'Add another {attr_desc} configuration? (y/n)', 'n')
+
+
+class Detector(BaseModel):
+ prefix: constr(strip_whitespace=True, min_length=1)
+ rows: conint(gt=0)
+ columns: conint(gt=0)
+ pixel_size: conlist(item_type=confloat(gt=0), min_items=1, max_items=2)
+ lens_magnification: confloat(gt=0) = 1.0
+
+ @property
+ def get_pixel_size(self):
+ return(list(np.asarray(self.pixel_size)/self.lens_magnification))
+
+ def construct_from_yaml(self, filename):
+ try:
+ with open(filename, 'r') as infile:
+ indict = yaml.load(infile, Loader=yaml.CLoader)
+ detector = indict['detector']
+ self.prefix = detector['id']
+ pixels = detector['pixels']
+ self.rows = pixels['rows']
+ self.columns = pixels['columns']
+ self.pixel_size = pixels['size']
+ self.lens_magnification = indict['lens_magnification']
+ except:
+ logger.warning(f'Could not load a dictionary from {filename}')
+ return(False)
+ else:
+ return(True)
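+
+ # The yaml file read above is assumed to be laid out as follows (values are placeholders):
+ # detector:
+ #   id: andor2
+ #   pixels:
+ #     rows: 2048
+ #     columns: 2048
+ #     size: [0.0065]
+ # lens_magnification: 5.0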
+
+ def cli(self):
+ print('\n -- Configure the detector -- ')
+ self.set_single_attr_cli('prefix', 'detector ID')
+ self.set_single_attr_cli('rows', 'number of pixel rows')
+ self.set_single_attr_cli('columns', 'number of pixel columns')
+ self.set_single_attr_cli('pixel_size', 'pixel size in mm (enter either a single value for '+
+ 'square pixels or a pair of values for the size in the respective row and column '+
+ 'directions)', list_flag=True)
+ self.set_single_attr_cli('lens_magnification', 'lens magnification')
+
+ def construct_nxdetector(self):
+ nxdetector = NXdetector()
+ nxdetector.local_name = self.prefix
+ pixel_size = self.get_pixel_size
+ if len(pixel_size) == 1:
+ nxdetector.x_pixel_size = pixel_size[0]
+ nxdetector.y_pixel_size = pixel_size[0]
+ else:
+ nxdetector.x_pixel_size = pixel_size[0]
+ nxdetector.y_pixel_size = pixel_size[1]
+ nxdetector.x_pixel_size.attrs['units'] = 'mm'
+ nxdetector.y_pixel_size.attrs['units'] = 'mm'
+ return(nxdetector)
+
+
+class ScanInfo(TypedDict):
+ scan_number: int
+ starting_image_offset: conint(ge=0)
+ num_image: conint(gt=0)
+ ref_x: float
+ ref_z: float
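+
+# One ScanInfo entry is a plain dict, e.g. (values are illustrative):
+# {'scan_number': 1, 'starting_image_offset': 0, 'num_image': 721,
+#  'ref_x': 0.0, 'ref_z': 0.0}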
+
+class SpecScans(BaseModel):
+ spec_file: FilePath
+ scan_numbers: conlist(item_type=conint(gt=0), min_items=1)
+ stack_info: conlist(item_type=ScanInfo, min_items=1) = []
+
+ @validator('spec_file')
+ def validate_spec_file(cls, spec_file):
+ try:
+ spec_file = os.path.abspath(spec_file)
+ sspec_file = FileSpec(spec_file)
+ except:
+ raise ValueError(f'Invalid SPEC file {spec_file}')
+ else:
+ return(spec_file)
+
+ @validator('scan_numbers')
+ def validate_scan_numbers(cls, scan_numbers, values):
+ spec_file = values.get('spec_file')
+ if spec_file is not None:
+ spec_scans = FileSpec(spec_file)
+ for scan_number in scan_numbers:
+ scan = spec_scans.get_scan_by_number(scan_number)
+ if scan is None:
+ raise ValueError(f'There is no scan number {scan_number} in {spec_file}')
+ return(scan_numbers)
+
+ @validator('stack_info')
+ def validate_stack_info(cls, stack_info, values):
+ scan_numbers = values.get('scan_numbers')
+ assert(len(scan_numbers) == len(stack_info))
+ for scan_info in stack_info:
+ assert(scan_info['scan_number'] in scan_numbers)
+ is_int(scan_info['starting_image_offset'], ge=0, lt=scan_info['num_image'],
+ raise_error=True)
+ return(stack_info)
+
+ @classmethod
+ def construct_from_nxcollection(cls, nxcollection:NXcollection):
+ config = {}
+ config['spec_file'] = nxcollection.attrs['spec_file']
+ scan_numbers = []
+ stack_info = []
+ for nxsubentry_name, nxsubentry in nxcollection.items():
+ scan_number = int(nxsubentry_name.split('_')[-1])
+ scan_numbers.append(scan_number)
+ stack_info.append({'scan_number': scan_number,
+ 'starting_image_offset': int(nxsubentry.instrument.detector.frame_start_number),
+ 'num_image': len(nxsubentry.sample.rotation_angle),
+ 'ref_x': float(nxsubentry.sample.x_translation),
+ 'ref_z': float(nxsubentry.sample.z_translation)})
+ config['scan_numbers'] = sorted(scan_numbers)
+ config['stack_info'] = stack_info
+ return(cls(**config))
+
+ @property
+ def available_scan_numbers(self):
+ return(get_available_scan_numbers(self.spec_file))
+
+ def set_from_nxcollection(self, nxcollection:NXcollection):
+ self.spec_file = nxcollection.attrs['spec_file']
+ scan_numbers = []
+ stack_info = []
+ for nxsubentry_name, nxsubentry in nxcollection.items():
+ scan_number = int(nxsubentry_name.split('_')[-1])
+ scan_numbers.append(scan_number)
+ stack_info.append({'scan_number': scan_number,
+ 'starting_image_offset': int(nxsubentry.instrument.detector.frame_start_number),
+ 'num_image': len(nxsubentry.sample.rotation_angle),
+ 'ref_x': float(nxsubentry.sample.x_translation),
+ 'ref_z': float(nxsubentry.sample.z_translation)})
+ self.scan_numbers = sorted(scan_numbers)
+ self.stack_info = stack_info
+
+ def get_scan_index(self, scan_number):
+ scan_index = [scan_index for scan_index, scan_info in enumerate(self.stack_info)
+ if scan_info['scan_number'] == scan_number]
+ if len(scan_index) > 1:
+ raise ValueError('Duplicate scan_numbers in image stack')
+ elif len(scan_index) == 1:
+ return(scan_index[0])
+ else:
+ return(None)
+
+ def get_scanparser(self, scan_number):
+ return(get_scanparser(self.spec_file, scan_number))
+
+ def get_detector_data(self, detector_prefix, scan_number=None, scan_step_index=None):
+ image_stacks = []
+ if scan_number is None:
+ scan_numbers = self.scan_numbers
+ else:
+ scan_numbers = [scan_number]
+ for scan_number in scan_numbers:
+ parser = self.get_scanparser(scan_number)
+ scan_info = self.stack_info[self.get_scan_index(scan_number)]
+ image_offset = scan_info['starting_image_offset']
+ if scan_step_index is None:
+ num_image = scan_info['num_image']
+ image_stacks.append(parser.get_detector_data(detector_prefix,
+ (image_offset, image_offset+num_image)))
+ else:
+ image_stacks.append(parser.get_detector_data(detector_prefix,
+ image_offset+scan_step_index))
+ if scan_number is not None and scan_step_index is not None:
+ # Return a single image for a specific scan_number and scan_step_index request
+ return(image_stacks[0])
+ else:
+ # Return a list otherwise
+ return(image_stacks)
+
+ def scan_numbers_cli(self, attr_desc, **kwargs):
+ available_scan_numbers = self.available_scan_numbers
+ station = kwargs.get('station')
+ if (station is not None and station in ('id1a3', 'id3a') and
+ 'scan_type' in kwargs):
+ scan_type = kwargs['scan_type']
+ if scan_type == 'ts1':
+ available_scan_numbers = []
+ for scan_number in self.available_scan_numbers:
+ parser = self.get_scanparser(scan_number)
+ try:
+ if parser.scan_type == scan_type:
+ available_scan_numbers.append(scan_number)
+ except:
+ pass
+ elif scan_type == 'df1':
+ tomo_scan_numbers = kwargs['tomo_scan_numbers']
+ available_scan_numbers = []
+ for scan_number in tomo_scan_numbers:
+ parser = self.get_scanparser(scan_number-2)
+ assert(parser.scan_type == scan_type)
+ available_scan_numbers.append(scan_number-2)
+ elif scan_type == 'bf1':
+ tomo_scan_numbers = kwargs['tomo_scan_numbers']
+ available_scan_numbers = []
+ for scan_number in tomo_scan_numbers:
+ parser = self.get_scanparser(scan_number-1)
+ assert(parser.scan_type == scan_type)
+ available_scan_numbers.append(scan_number-1)
+ if len(available_scan_numbers) == 1:
+ input_mode = 1
+ else:
+ if hasattr(self, 'scan_numbers'):
+ print(f'Currently selected {attr_desc}scan numbers are: {self.scan_numbers}')
+ menu_options = [f'Select a subset of the available {attr_desc}scan numbers',
+ f'Use all available {attr_desc}scan numbers in {self.spec_file}',
+ f'Keep the currently selected {attr_desc}scan numbers']
+ else:
+ menu_options = [f'Select a subset of the available {attr_desc}scan numbers',
+ f'Use all available {attr_desc}scan numbers in {self.spec_file}']
+ print(f'Available scan numbers in {self.spec_file} are: '+
+ f'{available_scan_numbers}')
+ input_mode = input_menu(menu_options, header='Choose one of the following options '+
+ 'for selecting scan numbers')
+ if input_mode == 0:
+ accept_scan_numbers = False
+ while not accept_scan_numbers:
+ try:
+ self.scan_numbers = \
+ input_int_list(f'Enter a series of {attr_desc}scan numbers')
+ except ValidationError as e:
+ print(e)
+ except KeyboardInterrupt as e:
+ raise e
+ except BaseException as e:
+ print(f'Unexpected {type(e).__name__}: {e}')
+ else:
+ accept_scan_numbers = True
+ elif input_mode == 1:
+ self.scan_numbers = available_scan_numbers
+ elif input_mode == 2:
+ pass
+
+ def cli(self, **cli_kwargs):
+ if cli_kwargs.get('attr_desc') is not None:
+ attr_desc = f'{cli_kwargs["attr_desc"]} '
+ else:
+ attr_desc = ''
+ print(f'\n -- Configure which scans to use from a single {attr_desc}SPEC file')
+ self.set_single_attr_cli('spec_file', attr_desc+'SPEC file path')
+ self.scan_numbers_cli(attr_desc)
+
+ def construct_nxcollection(self, image_key, thetas, detector):
+ nxcollection = NXcollection()
+ nxcollection.attrs['spec_file'] = str(self.spec_file)
+ parser = self.get_scanparser(self.scan_numbers[0])
+ nxcollection.attrs['date'] = parser.spec_scan.file_date
+ for scan_number in self.scan_numbers:
+ # Get scan info
+ scan_info = self.stack_info[self.get_scan_index(scan_number)]
+ # Add an NXsubentry to the NXcollection for each scan
+ entry_name = f'scan_{scan_number}'
+ nxsubentry = NXsubentry()
+ nxcollection[entry_name] = nxsubentry
+ parser = self.get_scanparser(scan_number)
+ nxsubentry.start_time = parser.spec_scan.date
+ nxsubentry.spec_command = parser.spec_command
+ # Add an NXdata for independent dimensions to the scan's NXsubentry
+ num_image = scan_info['num_image']
+ if thetas is None:
+ thetas = num_image*[0.0]
+ else:
+ assert(num_image == len(thetas))
+# nxsubentry.independent_dimensions = NXdata()
+# nxsubentry.independent_dimensions.rotation_angle = thetas
+# nxsubentry.independent_dimensions.rotation_angle.units = 'degrees'
+ # Add an NXinstrument to the scan's NXsubentry
+ nxsubentry.instrument = NXinstrument()
+ # Add an NXdetector to the NXinstrument to the scan's NXsubentry
+ nxsubentry.instrument.detector = detector.construct_nxdetector()
+ nxsubentry.instrument.detector.frame_start_number = scan_info['starting_image_offset']
+ nxsubentry.instrument.detector.image_key = image_key
+ # Add an NXsample to the scan's NXsubentry
+ nxsubentry.sample = NXsample()
+ nxsubentry.sample.rotation_angle = thetas
+ nxsubentry.sample.rotation_angle.units = 'degrees'
+ nxsubentry.sample.x_translation = scan_info['ref_x']
+ nxsubentry.sample.x_translation.units = 'mm'
+ nxsubentry.sample.z_translation = scan_info['ref_z']
+ nxsubentry.sample.z_translation.units = 'mm'
+ return(nxcollection)
+
+
+class FlatField(SpecScans):
+
+ def image_range_cli(self, attr_desc, detector_prefix):
+ stack_info = self.stack_info
+ for scan_number in self.scan_numbers:
+ # Parse the available image range
+ parser = self.get_scanparser(scan_number)
+ image_offset = parser.starting_image_offset
+ num_image = parser.get_num_image(detector_prefix.upper())
+ scan_index = self.get_scan_index(scan_number)
+
+ # Select the image set
+ last_image_index = image_offset+num_image
+ print(f'Available good image set index range: [{image_offset}, {last_image_index})')
+ image_set_approved = False
+ if scan_index is not None:
+ scan_info = stack_info[scan_index]
+ print(f'Current starting image offset and number of images: '+
+ f'{scan_info["starting_image_offset"]} and {scan_info["num_image"]}')
+ image_set_approved = input_yesno(f'Accept these values (y/n)?', 'y')
+ if image_set_approved:
+ offset = scan_info['starting_image_offset']
+ num = scan_info['num_image']
+ if not image_set_approved:
+ print(f'Default starting image offset and number of images: '+
+ f'{image_offset} and {num_image}')
+ image_set_approved = input_yesno(f'Accept these values (y/n)?', 'y')
+ if image_set_approved:
+ offset = image_offset
+ num = last_image_index-offset
+ while not image_set_approved:
+ offset = input_int(f'Enter the starting image offset', ge=image_offset,
+ lt=last_image_index)#, default=image_offset)
+ num = input_int(f'Enter the number of images', ge=1,
+ le=last_image_index-offset)#, default=last_image_index-offset)
+ print(f'Current starting image offset and number of images: {offset} and {num}')
+ image_set_approved = input_yesno(f'Accept these values (y/n)?', 'y')
+ if scan_index is not None:
+ scan_info['starting_image_offset'] = offset
+ scan_info['num_image'] = num
+ scan_info['ref_x'] = parser.horizontal_shift
+ scan_info['ref_z'] = parser.vertical_shift
+ else:
+ stack_info.append({'scan_number': scan_number, 'starting_image_offset': offset,
+ 'num_image': num, 'ref_x': parser.horizontal_shift,
+ 'ref_z': parser.vertical_shift})
+ self.stack_info = stack_info
+
+ def cli(self, **cli_kwargs):
+ if cli_kwargs.get('attr_desc') is not None:
+ attr_desc = f'{cli_kwargs["attr_desc"]} '
+ else:
+ attr_desc = ''
+ station = cli_kwargs.get('station')
+ detector = cli_kwargs.get('detector')
+ print(f'\n -- Configure the location of the {attr_desc}scan data -- ')
+ if station in ('id1a3', 'id3a'):
+ self.spec_file = cli_kwargs['spec_file']
+ tomo_scan_numbers = cli_kwargs['tomo_scan_numbers']
+ scan_type = cli_kwargs['scan_type']
+ self.scan_numbers_cli(attr_desc, station=station, tomo_scan_numbers=tomo_scan_numbers,
+ scan_type=scan_type)
+ else:
+ self.set_single_attr_cli('spec_file', attr_desc+'SPEC file path')
+ self.scan_numbers_cli(attr_desc)
+ self.image_range_cli(attr_desc, detector.prefix)
+
+
+class TomoField(SpecScans):
+ theta_range: dict = {}
+
+ @validator('theta_range')
+ def validate_theta_range(cls, theta_range):
+ if len(theta_range) != 3 and len(theta_range) != 4:
+ raise ValueError(f'Invalid theta range {theta_range}')
+ is_num(theta_range['start'], raise_error=True)
+ is_num(theta_range['end'], raise_error=True)
+ is_int(theta_range['num'], gt=1, raise_error=True)
+ if theta_range['end'] <= theta_range['start']:
+ raise ValueError(f'Invalid theta range {theta_range}')
+ if 'start_index' in theta_range:
+ is_int(theta_range['start_index'], ge=0, raise_error=True)
+ return(theta_range)
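+
+ # A valid theta_range dict looks like (values are illustrative):
+ # {'start': 0.0, 'end': 180.0, 'num': 721, 'start_index': 0}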
+
+ @classmethod
+ def construct_from_nxcollection(cls, nxcollection:NXcollection):
+ #RV Can I derive this from the same classfunction for SpecScans by adding theta_range
+ config = {}
+ config['spec_file'] = nxcollection.attrs['spec_file']
+ scan_numbers = []
+ stack_info = []
+ for nxsubentry_name, nxsubentry in nxcollection.items():
+ scan_number = int(nxsubentry_name.split('_')[-1])
+ scan_numbers.append(scan_number)
+ stack_info.append({'scan_number': scan_number,
+ 'starting_image_offset': int(nxsubentry.instrument.detector.frame_start_number),
+ 'num_image': len(nxsubentry.sample.rotation_angle),
+ 'ref_x': float(nxsubentry.sample.x_translation),
+ 'ref_z': float(nxsubentry.sample.z_translation)})
+ config['scan_numbers'] = sorted(scan_numbers)
+ config['stack_info'] = stack_info
+ for name in nxcollection.entries:
+ if 'scan_' in name:
+ thetas = np.asarray(nxcollection[name].sample.rotation_angle)
+ config['theta_range'] = {'start': thetas[0], 'end': thetas[-1], 'num': thetas.size}
+ break
+ return(cls(**config))
+
+ def get_horizontal_shifts(self, scan_number=None):
+ horizontal_shifts = []
+ if scan_number is None:
+ scan_numbers = self.scan_numbers
+ else:
+ scan_numbers = [scan_number]
+ for scan_number in scan_numbers:
+ parser = self.get_scanparser(scan_number)
+ horizontal_shifts.append(parser.horizontal_shift)
+ if len(horizontal_shifts) == 1:
+ return(horizontal_shifts[0])
+ else:
+ return(horizontal_shifts)
+
+ def get_vertical_shifts(self, scan_number=None):
+ vertical_shifts = []
+ if scan_number is None:
+ scan_numbers = self.scan_numbers
+ else:
+ scan_numbers = [scan_number]
+ for scan_number in scan_numbers:
+ parser = self.get_scanparser(scan_number)
+ vertical_shifts.append(parser.vertical_shift)
+ if len(vertical_shifts) == 1:
+ return(vertical_shifts[0])
+ else:
+ return(vertical_shifts)
+
+ def theta_range_cli(self, scan_number, attr_desc, station):
+ # Parse the available theta range
+ parser = self.get_scanparser(scan_number)
+ theta_vals = parser.theta_vals
+ spec_theta_start = theta_vals.get('start')
+ spec_theta_end = theta_vals.get('end')
+ spec_num_theta = theta_vals.get('num')
+
+ # Check for consistency of theta ranges between scans
+ if scan_number != self.scan_numbers[0]:
+ parser = self.get_scanparser(self.scan_numbers[0])
+ if (parser.theta_vals.get('start') != spec_theta_start or
+ parser.theta_vals.get('end') != spec_theta_end or
+ parser.theta_vals.get('num') != spec_num_theta):
+ raise ValueError(f'Incompatible theta ranges between {attr_desc}scans:'+
+ f'\n\tScan {scan_number}: {theta_vals}'+
+ f'\n\tScan {self.scan_numbers[0]}: {parser.theta_vals}')
+ return
+
+ # Select the theta range for the tomo reconstruction from the first scan
+ theta_range_approved = False
+ thetas = np.linspace(spec_theta_start, spec_theta_end, spec_num_theta)
+ delta_theta = thetas[1]-thetas[0]
+ print(f'Theta range obtained from SPEC data: [{spec_theta_start}, {spec_theta_end}]')
+ print(f'Theta step size = {delta_theta}')
+ print(f'Number of theta values: {spec_num_theta}')
+ default_start = None
+ default_end = None
+ if station in ('id1a3', 'id3a'):
+ theta_range_approved = input_yesno(f'Accept this theta range (y/n)?', 'y')
+ if theta_range_approved:
+ self.theta_range = {'start': float(spec_theta_start), 'end': float(spec_theta_end),
+ 'num': int(spec_num_theta), 'start_index': 0}
+ return
+ elif station in ('id3b',):
+ if spec_theta_start <= 0.0 and spec_theta_end >= 180.0:
+ default_start = 0
+ default_end = 180
+ elif spec_theta_end-spec_theta_start == 180:
+ default_start = spec_theta_start
+ default_end = spec_theta_end
+ while not theta_range_approved:
+ theta_start = input_num(f'Enter the first theta (included)', ge=spec_theta_start,
+ lt=spec_theta_end, default=default_start)
+ theta_index_start = index_nearest(thetas, theta_start)
+ theta_start = thetas[theta_index_start]
+ theta_end = input_num(f'Enter the last theta (excluded)',
+ ge=theta_start+delta_theta, le=spec_theta_end, default=default_end)
+ theta_index_end = index_nearest(thetas, theta_end)
+ theta_end = thetas[theta_index_end]
+ num_theta = theta_index_end-theta_index_start
+ print(f'Selected theta range: [{theta_start}, {theta_start+delta_theta}, ..., '+
+ f'{theta_end})')
+ print(f'Number of theta values: {num_theta}')
+ theta_range_approved = input_yesno(f'Accept this theta range (y/n)?', 'y')
+ self.theta_range = {'start': float(theta_start), 'end': float(theta_end),
+ 'num': int(num_theta), 'start_index': int(theta_index_start)}
+
+ def image_range_cli(self, attr_desc, detector_prefix):
+ stack_info = self.stack_info
+ for scan_number in self.scan_numbers:
+ # Parse the available image range
+ parser = self.get_scanparser(scan_number)
+ image_offset = parser.starting_image_offset
+ num_image = parser.get_num_image(detector_prefix.upper())
+ scan_index = self.get_scan_index(scan_number)
+
+ # Select the image set matching the theta range
+ num_theta = self.theta_range['num']
+ theta_index_start = self.theta_range['start_index']
+ if num_theta > num_image-theta_index_start:
+ raise ValueError(f'Available {attr_desc}image indices incompatible with thetas:'+
+ f'\n\tNumber of thetas and offset = {num_theta} and {theta_index_start}'+
+ f'\n\tNumber of available images {num_image}')
+ if scan_index is not None:
+ scan_info = stack_info[scan_index]
+ scan_info['starting_image_offset'] = image_offset+theta_index_start
+ scan_info['num_image'] = num_theta
+ scan_info['ref_x'] = parser.horizontal_shift
+ scan_info['ref_z'] = parser.vertical_shift
+ else:
+ stack_info.append({'scan_number': scan_number,
+ 'starting_image_offset': image_offset+theta_index_start,
+ 'num_image': num_theta, 'ref_x': parser.horizontal_shift,
+ 'ref_z': parser.vertical_shift})
+ self.stack_info = stack_info
+
+ def cli(self, **cli_kwargs):
+ if cli_kwargs.get('attr_desc') is not None:
+ attr_desc = f'{cli_kwargs["attr_desc"]} '
+ else:
+ attr_desc = ''
+ cycle = cli_kwargs.get('cycle')
+ btr = cli_kwargs.get('btr')
+ station = cli_kwargs.get('station')
+ detector = cli_kwargs.get('detector')
+ sample_name = cli_kwargs.get('sample_name')
+ print(f'\n -- Configure the location of the {attr_desc}scan data -- ')
+ if station in ('id1a3', 'id3a'):
+ basedir = f'/nfs/chess/{station}/{cycle}/{btr}'
+ runs = [d for d in os.listdir(basedir) if os.path.isdir(os.path.join(basedir, d))]
+#RV index = 15-1
+#RV index = 7-1
+ if sample_name is not None and sample_name in runs:
+ index = runs.index(sample_name)
+ else:
+ index = input_menu(runs, header='Choose a sample directory')
+ self.spec_file = f'{basedir}/{runs[index]}/spec.log'
+ self.scan_numbers_cli(attr_desc, station=station, scan_type='ts1')
+ else:
+ self.set_single_attr_cli('spec_file', attr_desc+'SPEC file path')
+ self.scan_numbers_cli(attr_desc)
+ for scan_number in self.scan_numbers:
+ self.theta_range_cli(scan_number, attr_desc, station)
+ self.image_range_cli(attr_desc, detector.prefix)
+
+
+class Sample(BaseModel):
+ name: constr(min_length=1)
+ description: Optional[str]
+ rotation_angles: Optional[list]
+ x_translations: Optional[list]
+ z_translations: Optional[list]
+
+ @classmethod
+ def construct_from_nxsample(cls, nxsample:NXsample):
+ config = {}
+ config['name'] = nxsample.name.nxdata
+ if 'description' in nxsample:
+ config['description'] = nxsample.description.nxdata
+ if 'rotation_angle' in nxsample:
+ config['rotation_angles'] = nxsample.rotation_angle.nxdata
+ if 'x_translation' in nxsample:
+ config['x_translations'] = nxsample.x_translation.nxdata
+ if 'z_translation' in nxsample:
+ config['z_translations'] = nxsample.z_translation.nxdata
+ return(cls(**config))
+
+ def cli(self):
+ print('\n -- Configure the sample metadata -- ')
+#RV self.name = 'sobhani-3249-A'
+#RV self.name = 'tenstom_1304r-1'
+ self.set_single_attr_cli('name', 'the sample name')
+#RV self.description = 'test sample'
+ self.set_single_attr_cli('description', 'a description of the sample (optional)')
+
+
+class MapConfig(BaseModel):
+ cycle: constr(strip_whitespace=True, min_length=1)
+ btr: constr(strip_whitespace=True, min_length=1)
+ title: constr(strip_whitespace=True, min_length=1)
+ station: Optional[Literal['id1a3', 'id3a', 'id3b']] = None
+ sample: Sample
+ detector: Detector = Detector.construct()
+ tomo_fields: TomoField
+ dark_field: Optional[FlatField]
+ bright_field: FlatField
+ _thetas: list[float] = PrivateAttr()
+ _field_types = ({'name': 'dark_field', 'image_key': 2}, {'name': 'bright_field',
+ 'image_key': 1}, {'name': 'tomo_fields', 'image_key': 0})
+
+ @classmethod
+ def construct_from_nxentry(cls, nxentry:NXentry):
+ config = {}
+ config['cycle'] = nxentry.instrument.source.attrs['cycle']
+ config['btr'] = nxentry.instrument.source.attrs['btr']
+ config['title'] = nxentry.nxname
+ config['station'] = nxentry.instrument.source.attrs['station']
+ config['sample'] = Sample.construct_from_nxsample(nxentry['sample'])
+ for nxobject_name, nxobject in nxentry.spec_scans.items():
+ if isinstance(nxobject, NXcollection):
+ if nxobject_name == 'tomo_fields':
+ config[nxobject_name] = TomoField.construct_from_nxcollection(nxobject)
+ else:
+ config[nxobject_name] = FlatField.construct_from_nxcollection(nxobject)
+ return(cls(**config))
+
+#FIX cache?
+ @property
+ def thetas(self):
+ try:
+ return(self._thetas)
+ except:
+ theta_range = self.tomo_fields.theta_range
+ self._thetas = list(np.linspace(theta_range['start'], theta_range['end'],
+ theta_range['num']))
+ return(self._thetas)
+
+ def cli(self):
+ print('\n -- Configure a map from a set of SPEC scans (dark, bright, and tomo), '+
+ 'and / or detector data -- ')
+#RV self.cycle = '2021-3'
+#RV self.cycle = '2022-2'
+#RV self.cycle = '2023-1'
+ self.set_single_attr_cli('cycle', 'beam cycle')
+#RV self.btr = 'z-3234-A'
+#RV self.btr = 'sobhani-3249-A'
+#RV self.btr = 'przybyla-3606-a'
+ self.set_single_attr_cli('btr', 'BTR')
+#RV self.title = 'z-3234-A'
+#RV self.title = 'tomo7C'
+#RV self.title = 'cmc-test-dwell-1'
+ self.set_single_attr_cli('title', 'title for the map entry')
+#RV self.station = 'id3a'
+#RV self.station = 'id3b'
+#RV self.station = 'id1a3'
+ self.set_single_attr_cli('station', 'name of the station at which scans were collected '+
+ '(currently choose from: id1a3, id3a, id3b)')
+ import_scanparser(self.station)
+ self.set_single_attr_cli('sample')
+ use_detector_config = False
+ if hasattr(self.detector, 'prefix') and len(self.detector.prefix):
+ use_detector_config = input_yesno(f'Current detector settings:\n{self.detector}\n'+
+ f'Keep these settings? (y/n)')
+ if not use_detector_config:
+ menu_options = ['not listed', 'andor2', 'manta', 'retiga']
+ input_mode = input_menu(menu_options, header='Choose one of the following detector '+
+ 'configuration options')
+ if input_mode:
+ detector_config_file = f'{menu_options[input_mode]}.yaml'
+ have_detector_config = self.detector.construct_from_yaml(detector_config_file)
+ else:
+ have_detector_config = False
+ if not have_detector_config:
+ self.set_single_attr_cli('detector', 'detector')
+ self.set_single_attr_cli('tomo_fields', 'Tomo field', chain_attr_desc=True,
+ cycle=self.cycle, btr=self.btr, station=self.station, detector=self.detector,
+ sample_name=self.sample.name)
+ if self.station in ('id1a3', 'id3a'):
+ have_dark_field = True
+ tomo_spec_file = self.tomo_fields.spec_file
+ else:
+ have_dark_field = input_yesno(f'Are Dark field images available? (y/n)')
+ tomo_spec_file = None
+ if have_dark_field:
+ self.set_single_attr_cli('dark_field', 'Dark field', chain_attr_desc=True,
+ station=self.station, detector=self.detector, spec_file=tomo_spec_file,
+ tomo_scan_numbers=self.tomo_fields.scan_numbers, scan_type='df1')
+ self.set_single_attr_cli('bright_field', 'Bright field', chain_attr_desc=True,
+ station=self.station, detector=self.detector, spec_file=tomo_spec_file,
+ tomo_scan_numbers=self.tomo_fields.scan_numbers, scan_type='bf1')
+
+ def construct_nxentry(self, nxroot, include_raw_data=True):
+ # Construct base NXentry
+ nxentry = NXentry()
+
+ # Add an NXentry to the NXroot
+ nxroot[self.title] = nxentry
+ nxroot.attrs['default'] = self.title
+ nxentry.definition = 'NXtomo'
+# nxentry.attrs['default'] = 'data'
+
+ # Add an NXinstrument to the NXentry
+ nxinstrument = NXinstrument()
+ nxentry.instrument = nxinstrument
+
+ # Add an NXsource to the NXinstrument
+ nxsource = NXsource()
+ nxinstrument.source = nxsource
+ nxsource.type = 'Synchrotron X-ray Source'
+ nxsource.name = 'CHESS'
+ nxsource.probe = 'x-ray'
+
+ # Tag the NXsource with the runinfo (as an attribute)
+ nxsource.attrs['cycle'] = self.cycle
+ nxsource.attrs['btr'] = self.btr
+ nxsource.attrs['station'] = self.station
+
+ # Add an NXdetector to the NXinstrument (don't fill in data fields yet)
+ nxinstrument.detector = self.detector.construct_nxdetector()
+
+ # Add an NXsample to NXentry (don't fill in data fields yet)
+ nxsample = NXsample()
+ nxentry.sample = nxsample
+ nxsample.name = self.sample.name
+ nxsample.description = self.sample.description
+
+ # Add an NXcollection to the base NXentry to hold metadata about the spec scans in the map
+ # Also obtain the data fields in NXsample and NXdetector
+ nxspec_scans = NXcollection()
+ nxentry.spec_scans = nxspec_scans
+ image_keys = []
+ sequence_numbers = []
+ image_stacks = []
+ rotation_angles = []
+ x_translations = []
+ z_translations = []
+ for field_type in self._field_types:
+ field_name = field_type['name']
+ field = getattr(self, field_name)
+ if field is None:
+ continue
+ image_key = field_type['image_key']
+ if field_type['name'] == 'tomo_fields':
+ thetas = self.thetas
+ else:
+ thetas = None
+ # Add the scans in a single spec file
+ nxspec_scans[field_name] = field.construct_nxcollection(image_key, thetas,
+ self.detector)
+ if include_raw_data:
+ image_stacks += field.get_detector_data(self.detector.prefix)
+ for scan_number in field.scan_numbers:
+ parser = field.get_scanparser(scan_number)
+ scan_info = field.stack_info[field.get_scan_index(scan_number)]
+ num_image = scan_info['num_image']
+ image_keys += num_image*[image_key]
+ sequence_numbers += list(range(num_image))
+ if thetas is None:
+ rotation_angles += scan_info['num_image']*[0.0]
+ else:
+ assert(num_image == len(thetas))
+ rotation_angles += thetas
+ x_translations += scan_info['num_image']*[scan_info['ref_x']]
+ z_translations += scan_info['num_image']*[scan_info['ref_z']]
+
+ if include_raw_data:
+ # Add image data to NXdetector
+ nxinstrument.detector.image_key = image_keys
+ nxinstrument.detector.sequence_number = sequence_numbers
+ nxinstrument.detector.data = np.concatenate(image_stacks)
+
+ # Add image data to NXsample
+ nxsample.rotation_angle = rotation_angles
+ nxsample.rotation_angle.attrs['units'] = 'degrees'
+ nxsample.x_translation = x_translations
+ nxsample.x_translation.attrs['units'] = 'mm'
+ nxsample.z_translation = z_translations
+ nxsample.z_translation.attrs['units'] = 'mm'
+
+ # Add an NXdata to NXentry
+ nxdata = NXdata()
+ nxentry.data = nxdata
+ nxdata.makelink(nxentry.instrument.detector.data, name='data')
+ nxdata.makelink(nxentry.instrument.detector.image_key)
+ nxdata.makelink(nxentry.sample.rotation_angle)
+ nxdata.makelink(nxentry.sample.x_translation)
+ nxdata.makelink(nxentry.sample.z_translation)
+# nxdata.attrs['axes'] = ['field', 'row', 'column']
+# nxdata.attrs['field_indices'] = 0
+# nxdata.attrs['row_indices'] = 1
+# nxdata.attrs['column_indices'] = 2
+
+
+class TomoWorkflow(BaseModel):
+ sample_maps: conlist(item_type=MapConfig, min_items=1) = [MapConfig.construct()]
+
+ @classmethod
+ def construct_from_nexus(cls, filename):
+ nxroot = nxload(filename)
+ sample_maps = []
+ config = {'sample_maps': sample_maps}
+ for nxentry_name, nxentry in nxroot.items():
+ sample_maps.append(MapConfig.construct_from_nxentry(nxentry))
+ return(cls(**config))
+
+ def cli(self):
+ print('\n -- Configure a map -- ')
+ self.set_list_attr_cli('sample_maps', 'sample map')
+
+ def construct_nxfile(self, filename, mode='w-'):
+ nxroot = NXroot()
+ t0 = time()
+ for sample_map in self.sample_maps:
+ logger.info(f'Start constructing the {sample_map.title} map.')
+ import_scanparser(sample_map.station)
+ sample_map.construct_nxentry(nxroot)
+ logger.info(f'Constructed all sample maps in {time()-t0:.2f} seconds.')
+ logger.info(f'Start saving all sample maps to {filename}.')
+ nxroot.save(filename, mode=mode)
+
+ def write_to_nexus(self, filename):
+ t0 = time()
+ self.construct_nxfile(filename, mode='w')
+ logger.info(f'Saved all sample maps to {filename} in {time()-t0:.2f} seconds.')
diff -r 000000000000 -r 98e23dff1de2 workflow/run_tomo.py
--- /dev/null Thu Jan 01 00:00:00 1970 +0000
+++ b/workflow/run_tomo.py Tue Mar 21 16:22:42 2023 +0000
@@ -0,0 +1,1645 @@
+#!/usr/bin/env python3
+
+import logging
+logger = logging.getLogger(__name__)
+
+import numpy as np
+try:
+ import numexpr as ne
+except:
+ pass
+try:
+ import scipy.ndimage as spi
+except:
+ pass
+
+from multiprocessing import cpu_count
+from nexusformat.nexus import *
+from os import mkdir
+from os import path as os_path
+try:
+ from skimage.transform import iradon
+except:
+ pass
+try:
+ from skimage.restoration import denoise_tv_chambolle
+except:
+ pass
+from time import time
+try:
+ import tomopy
+except:
+ pass
+from yaml import safe_load, safe_dump
+
+try:
+ from msnctools.fit import Fit
+except:
+ from fit import Fit
+try:
+ from msnctools.general import illegal_value, is_int, is_int_pair, is_num, is_index_range, \
+ input_int, input_num, input_yesno, input_menu, draw_mask_1d, select_image_bounds, \
+ select_one_image_bound, clear_imshow, quick_imshow, clear_plot, quick_plot
+except:
+ from general import illegal_value, is_int, is_int_pair, is_num, is_index_range, \
+ input_int, input_num, input_yesno, input_menu, draw_mask_1d, select_image_bounds, \
+ select_one_image_bound, clear_imshow, quick_imshow, clear_plot, quick_plot
+
+try:
+ from workflow.models import import_scanparser, FlatField, TomoField, TomoWorkflow
+ from workflow.__version__ import __version__
+except:
+ pass
+
+num_core_tomopy_limit = 24
+
+def nxcopy(nxobject:NXobject, exclude_nxpaths:list[str]=[], nxpath_prefix:str='') -> NXobject:
+ '''Function that returns a copy of a nexus object, optionally excluding certain child items.
+
+ :param nxobject: the original nexus object to return a "copy" of
+ :type nxobject: nexusformat.nexus.NXobject
+ :param exclude_nxpaths: a list of paths to child nexus objects that
+ should be excluded from the returned "copy", defaults to `[]`
+ :type exclude_nxpaths: list[str], optional
+ :param nxpath_prefix: for use in recursive calls from inside this
+ function only!
+ :type nxpath_prefix: str
+ :return: a copy of `nxobject` with some children optionally excluded.
+ :rtype: NXobject
+ '''
+
+ nxobject_copy = nxobject.__class__()
+ if not len(nxpath_prefix):
+ if 'default' in nxobject.attrs:
+ nxobject_copy.attrs['default'] = nxobject.attrs['default']
+ else:
+ for k, v in nxobject.attrs.items():
+ nxobject_copy.attrs[k] = v
+
+ for k, v in nxobject.items():
+ nxpath = os_path.join(nxpath_prefix, k)
+
+ if nxpath in exclude_nxpaths:
+ continue
+
+ if isinstance(v, NXgroup):
+ nxobject_copy[k] = nxcopy(v, exclude_nxpaths=exclude_nxpaths,
+ nxpath_prefix=os_path.join(nxpath_prefix, k))
+ else:
+ nxobject_copy[k] = v
+
+ return(nxobject_copy)
+
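+# Illustrative usage of nxcopy (the file name and excluded path are hypothetical):
+#     nxroot_copy = nxcopy(nxload('map.nxs'), exclude_nxpaths=['entry/data/data'])
+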
+class set_numexpr_threads:
+
+ def __init__(self, num_core):
+ if num_core is None or num_core < 1 or num_core > cpu_count():
+ self.num_core = cpu_count()
+ else:
+ self.num_core = num_core
+
+ def __enter__(self):
+ self.num_core_org = ne.set_num_threads(self.num_core)
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ ne.set_num_threads(self.num_core_org)
+
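+# Illustrative usage of set_numexpr_threads (arrays a and b are hypothetical):
+#     with set_numexpr_threads(4):
+#         ne.evaluate('a/b', out=a, truediv=True)
+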
+class Tomo:
+ """Processing tomography data with misalignment.
+ """
+ def __init__(self, galaxy_flag=False, num_core=-1, output_folder='.', save_figs=None,
+ test_mode=False):
+ """Initialize with optional config input file or dictionary
+ """
+ if not isinstance(galaxy_flag, bool):
+ raise ValueError(f'Invalid parameter galaxy_flag ({galaxy_flag})')
+ self.galaxy_flag = galaxy_flag
+ self.num_core = num_core
+ if self.galaxy_flag:
+ if output_folder != '.':
+ logger.warning('Ignoring output_folder in galaxy mode')
+ self.output_folder = '.'
+            if test_mode:
+ logger.warning('Ignoring test_mode in galaxy mode')
+ self.test_mode = False
+ if save_figs is not None:
+ logger.warning('Ignoring save_figs in galaxy mode')
+ save_figs = 'only'
+ else:
+ self.output_folder = os_path.abspath(output_folder)
+ if not os_path.isdir(output_folder):
+ mkdir(os_path.abspath(output_folder))
+ if not isinstance(test_mode, bool):
+ raise ValueError(f'Invalid parameter test_mode ({test_mode})')
+ self.test_mode = test_mode
+ if save_figs is None:
+ save_figs = 'no'
+ self.test_config = {}
+ if self.test_mode:
+ if save_figs != 'only':
+ logger.warning('Ignoring save_figs in test mode')
+ save_figs = 'only'
+ if save_figs == 'only':
+ self.save_only = True
+ self.save_figs = True
+ elif save_figs == 'yes':
+ self.save_only = False
+ self.save_figs = True
+ elif save_figs == 'no':
+ self.save_only = False
+ self.save_figs = False
+ else:
+ raise ValueError(f'Invalid parameter save_figs ({save_figs})')
+ if self.save_only:
+ self.block = False
+ else:
+ self.block = True
+ if self.num_core == -1:
+ self.num_core = cpu_count()
+ if not is_int(self.num_core, gt=0, log=False):
+ raise ValueError(f'Invalid parameter num_core ({num_core})')
+ if self.num_core > cpu_count():
+ logger.warning(f'num_core = {self.num_core} is larger than the number of available '
+ f'processors and reduced to {cpu_count()}')
+            self.num_core = cpu_count()
+
+ def read(self, filename):
+ logger.info(f'looking for {filename}')
+ if self.galaxy_flag:
+ try:
+ with open(filename, 'r') as f:
+ config = safe_load(f)
+ return(config)
+ except:
+ try:
+ with NXFile(filename, mode='r') as nxfile:
+ nxroot = nxfile.readfile()
+ return(nxroot)
+ except:
+ raise ValueError(f'Unable to open ({filename})')
+ else:
+ extension = os_path.splitext(filename)[1]
+ if extension == '.yml' or extension == '.yaml':
+ with open(filename, 'r') as f:
+ config = safe_load(f)
+# if len(config) > 1:
+# raise ValueError(f'Multiple root entries in {filename} not yet implemented')
+# if len(list(config.values())[0]) > 1:
+# raise ValueError(f'Multiple sample maps in {filename} not yet implemented')
+ return(config)
+ elif extension == '.nxs':
+ with NXFile(filename, mode='r') as nxfile:
+ nxroot = nxfile.readfile()
+ return(nxroot)
+ else:
+ raise ValueError(f'Invalid filename extension ({extension})')
+
+ def write(self, data, filename):
+ extension = os_path.splitext(filename)[1]
+ if extension == '.yml' or extension == '.yaml':
+ with open(filename, 'w') as f:
+ safe_dump(data, f)
+ elif extension == '.nxs' or extension == '.nex':
+ data.save(filename, mode='w')
+ elif extension == '.nc':
+            data.to_netcdf(filename)
+ else:
+ raise ValueError(f'Invalid filename extension ({extension})')
+
+ def gen_reduced_data(self, data, img_x_bounds=None):
+ """Generate the reduced tomography images.
+ """
+ logger.info('Generate the reduced tomography images')
+
+ # Create plot galaxy path directory if needed
+ if self.galaxy_flag and not os_path.exists('tomo_reduce_plots'):
+ mkdir('tomo_reduce_plots')
+
+ if isinstance(data, dict):
+ # Create Nexus format object from input dictionary
+ wf = TomoWorkflow(**data)
+ if len(wf.sample_maps) > 1:
+ raise ValueError(f'Multiple sample maps not yet implemented')
+# print(f'\nwf:\n{wf}\n')
+ nxroot = NXroot()
+ t0 = time()
+ for sample_map in wf.sample_maps:
+ logger.info(f'Start constructing the {sample_map.title} map.')
+ import_scanparser(sample_map.station)
+ sample_map.construct_nxentry(nxroot, include_raw_data=False)
+ logger.info(f'Constructed all sample maps in {time()-t0:.2f} seconds.')
+ nxentry = nxroot[nxroot.attrs['default']]
+ # Get test mode configuration info
+ if self.test_mode:
+ self.test_config = data['sample_maps'][0]['test_mode']
+ elif isinstance(data, NXroot):
+ nxentry = data[data.attrs['default']]
+ else:
+ raise ValueError(f'Invalid parameter data ({data})')
+
+ # Create an NXprocess to store data reduction (meta)data
+ reduced_data = NXprocess()
+
+ # Generate dark field
+ if 'dark_field' in nxentry['spec_scans']:
+ reduced_data = self._gen_dark(nxentry, reduced_data)
+
+ # Generate bright field
+ reduced_data = self._gen_bright(nxentry, reduced_data)
+
+ # Set vertical detector bounds for image stack
+ img_x_bounds = self._set_detector_bounds(nxentry, reduced_data, img_x_bounds=img_x_bounds)
+ logger.info(f'img_x_bounds = {img_x_bounds}')
+ reduced_data['img_x_bounds'] = img_x_bounds
+
+        # Set zoom and/or theta skip to reduce the memory requirement
+ zoom_perc, num_theta_skip = self._set_zoom_or_skip()
+ if zoom_perc is not None:
+ reduced_data.attrs['zoom_perc'] = zoom_perc
+ if num_theta_skip is not None:
+ reduced_data.attrs['num_theta_skip'] = num_theta_skip
+
+ # Generate reduced tomography fields
+ reduced_data = self._gen_tomo(nxentry, reduced_data)
+
+ # Create a copy of the input Nexus object and remove raw and any existing reduced data
+ if isinstance(data, NXroot):
+ exclude_items = [f'{nxentry._name}/reduced_data/data',
+ f'{nxentry._name}/instrument/detector/data',
+ f'{nxentry._name}/instrument/detector/image_key',
+ f'{nxentry._name}/instrument/detector/sequence_number',
+ f'{nxentry._name}/sample/rotation_angle',
+ f'{nxentry._name}/sample/x_translation',
+ f'{nxentry._name}/sample/z_translation',
+ f'{nxentry._name}/data/data',
+ f'{nxentry._name}/data/image_key',
+ f'{nxentry._name}/data/rotation_angle',
+ f'{nxentry._name}/data/x_translation',
+ f'{nxentry._name}/data/z_translation']
+ nxroot = nxcopy(data, exclude_nxpaths=exclude_items)
+ nxentry = nxroot[nxroot.attrs['default']]
+
+ # Add the reduced data NXprocess
+ nxentry.reduced_data = reduced_data
+
+ if 'data' not in nxentry:
+ nxentry.data = NXdata()
+ nxentry.attrs['default'] = 'data'
+ nxentry.data.makelink(nxentry.reduced_data.data.tomo_fields, name='reduced_data')
+ nxentry.data.makelink(nxentry.reduced_data.rotation_angle, name='rotation_angle')
+ nxentry.data.attrs['signal'] = 'reduced_data'
+
+ return(nxroot)
+
+ def find_centers(self, nxroot, center_rows=None, center_stack_index=None):
+ """Find the calibrated center axis info
+ """
+ logger.info('Find the calibrated center axis info')
+
+ if not isinstance(nxroot, NXroot):
+ raise ValueError(f'Invalid parameter nxroot ({nxroot})')
+ nxentry = nxroot[nxroot.attrs['default']]
+ if not isinstance(nxentry, NXentry):
+ raise ValueError(f'Invalid nxentry ({nxentry})')
+ if self.galaxy_flag:
+ if center_rows is not None:
+ center_rows = tuple(center_rows)
+ if not is_int_pair(center_rows):
+ raise ValueError(f'Invalid parameter center_rows ({center_rows})')
+ elif center_rows is not None:
+ logger.warning(f'Ignoring parameter center_rows ({center_rows})')
+ center_rows = None
+ if self.galaxy_flag:
+ if center_stack_index is not None and not is_int(center_stack_index, ge=0):
+ raise ValueError(f'Invalid parameter center_stack_index ({center_stack_index})')
+ elif center_stack_index is not None:
+ logger.warning(f'Ignoring parameter center_stack_index ({center_stack_index})')
+ center_stack_index = None
+
+ # Create plot galaxy path directory and path if needed
+ if self.galaxy_flag:
+ if not os_path.exists('tomo_find_centers_plots'):
+ mkdir('tomo_find_centers_plots')
+ path = 'tomo_find_centers_plots'
+ else:
+ path = self.output_folder
+
+ # Check if reduced data is available
+ if ('reduced_data' not in nxentry or 'reduced_data' not in nxentry.data):
+ raise KeyError(f'Unable to find valid reduced data in {nxentry}.')
+
+ # Select the image stack to calibrate the center axis
+ # reduced data axes order: stack,theta,row,column
+ # Note: Nexus cannot follow a link if the data it points to is too big,
+ # so get the data from the actual place, not from nxentry.data
+ tomo_fields_shape = nxentry.reduced_data.data.tomo_fields.shape
+        if len(tomo_fields_shape) != 4 or any(not dim for dim in tomo_fields_shape):
+ raise KeyError('Unable to load the required reduced tomography stack')
+ num_tomo_stacks = tomo_fields_shape[0]
+ if num_tomo_stacks == 1:
+ center_stack_index = 0
+ default = 'n'
+ else:
+ if self.test_mode:
+ center_stack_index = self.test_config['center_stack_index']-1 # make offset 0
+ elif self.galaxy_flag:
+ if center_stack_index is None:
+ center_stack_index = int(num_tomo_stacks/2)
+ if center_stack_index >= num_tomo_stacks:
+ raise ValueError(f'Invalid parameter center_stack_index ({center_stack_index})')
+ else:
+ center_stack_index = input_int('\nEnter tomography stack index to calibrate the '
+ 'center axis', ge=1, le=num_tomo_stacks, default=int(1+num_tomo_stacks/2))
+ center_stack_index -= 1
+ default = 'y'
+
+ # Get thetas (in degrees)
+ thetas = np.asarray(nxentry.reduced_data.rotation_angle)
+
+ # Get effective pixel_size
+ if 'zoom_perc' in nxentry.reduced_data:
+ eff_pixel_size = 100.*(nxentry.instrument.detector.x_pixel_size/
+ nxentry.reduced_data.attrs['zoom_perc'])
+ else:
+ eff_pixel_size = nxentry.instrument.detector.x_pixel_size
+
+ # Get cross sectional diameter
+ cross_sectional_dim = tomo_fields_shape[3]*eff_pixel_size
+ logger.debug(f'cross_sectional_dim = {cross_sectional_dim}')
+
+ # Determine center offset at sample row boundaries
+ logger.info('Determine center offset at sample row boundaries')
+
+ # Lower row center
+ if self.test_mode:
+ lower_row = self.test_config['lower_row']
+ elif self.galaxy_flag:
+ if center_rows is None or center_rows[0] == -1:
+ lower_row = 0
+ else:
+ lower_row = min(center_rows)
+ if not 0 <= lower_row < tomo_fields_shape[2]-1:
+ raise ValueError(f'Invalid parameter center_rows ({center_rows})')
+ else:
+ lower_row = select_one_image_bound(
+ nxentry.reduced_data.data.tomo_fields[center_stack_index,0,:,:], 0, bound=0,
+ title=f'theta={round(thetas[0], 2)+0}',
+ bound_name='row index to find lower center', default=default, raise_error=True)
+ logger.debug('Finding center...')
+ t0 = time()
+ lower_center_offset = self._find_center_one_plane(
+ #np.asarray(nxentry.reduced_data.data.tomo_fields[center_stack_index,:,lower_row,:]),
+ nxentry.reduced_data.data.tomo_fields[center_stack_index,:,lower_row,:],
+ lower_row, thetas, eff_pixel_size, cross_sectional_dim, path=path,
+ num_core=self.num_core)
+ logger.debug(f'... done in {time()-t0:.2f} seconds')
+ logger.debug(f'lower_row = {lower_row:.2f}')
+ logger.debug(f'lower_center_offset = {lower_center_offset:.2f}')
+
+ # Upper row center
+ if self.test_mode:
+ upper_row = self.test_config['upper_row']
+ elif self.galaxy_flag:
+ if center_rows is None or center_rows[1] == -1:
+ upper_row = tomo_fields_shape[2]-1
+ else:
+ upper_row = max(center_rows)
+ if not lower_row < upper_row < tomo_fields_shape[2]:
+ raise ValueError(f'Invalid parameter center_rows ({center_rows})')
+ else:
+ upper_row = select_one_image_bound(
+ nxentry.reduced_data.data.tomo_fields[center_stack_index,0,:,:], 0,
+ bound=tomo_fields_shape[2]-1, title=f'theta={round(thetas[0], 2)+0}',
+ bound_name='row index to find upper center', default=default, raise_error=True)
+ logger.debug('Finding center...')
+ t0 = time()
+ upper_center_offset = self._find_center_one_plane(
+ #np.asarray(nxentry.reduced_data.data.tomo_fields[center_stack_index,:,upper_row,:]),
+ nxentry.reduced_data.data.tomo_fields[center_stack_index,:,upper_row,:],
+ upper_row, thetas, eff_pixel_size, cross_sectional_dim, path=path,
+ num_core=self.num_core)
+ logger.debug(f'... done in {time()-t0:.2f} seconds')
+ logger.debug(f'upper_row = {upper_row:.2f}')
+ logger.debug(f'upper_center_offset = {upper_center_offset:.2f}')
+
+ center_config = {'lower_row': lower_row, 'lower_center_offset': lower_center_offset,
+ 'upper_row': upper_row, 'upper_center_offset': upper_center_offset}
+ if num_tomo_stacks > 1:
+ center_config['center_stack_index'] = center_stack_index+1 # save as offset 1
+
+ # Save test data to file
+ if self.test_mode:
+ with open(f'{self.output_folder}/center_config.yaml', 'w') as f:
+ safe_dump(center_config, f)
+
+ return(center_config)
+
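+    # Example of a returned center_config (values are illustrative):
+    #     {'lower_row': 10, 'lower_center_offset': -1.75,
+    #      'upper_row': 450, 'upper_center_offset': 2.25,
+    #      'center_stack_index': 2}
+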
+ def reconstruct_data(self, nxroot, center_info, x_bounds=None, y_bounds=None):
+ """Reconstruct the tomography data.
+ """
+ logger.info('Reconstruct the tomography data')
+
+ if not isinstance(nxroot, NXroot):
+ raise ValueError(f'Invalid parameter nxroot ({nxroot})')
+ nxentry = nxroot[nxroot.attrs['default']]
+ if not isinstance(nxentry, NXentry):
+ raise ValueError(f'Invalid nxentry ({nxentry})')
+ if not isinstance(center_info, dict):
+ raise ValueError(f'Invalid parameter center_info ({center_info})')
+
+ # Create plot galaxy path directory and path if needed
+ if self.galaxy_flag:
+ if not os_path.exists('tomo_reconstruct_plots'):
+ mkdir('tomo_reconstruct_plots')
+ path = 'tomo_reconstruct_plots'
+ else:
+ path = self.output_folder
+
+ # Check if reduced data is available
+ if ('reduced_data' not in nxentry or 'reduced_data' not in nxentry.data):
+ raise KeyError(f'Unable to find valid reduced data in {nxentry}.')
+
+ # Create an NXprocess to store image reconstruction (meta)data
+ nxprocess = NXprocess()
+
+ # Get rotation axis rows and centers
+ lower_row = center_info.get('lower_row')
+ lower_center_offset = center_info.get('lower_center_offset')
+ upper_row = center_info.get('upper_row')
+ upper_center_offset = center_info.get('upper_center_offset')
+ if (lower_row is None or lower_center_offset is None or upper_row is None or
+ upper_center_offset is None):
+ raise KeyError(f'Unable to find valid calibrated center axis info in {center_info}.')
+ center_slope = (upper_center_offset-lower_center_offset)/(upper_row-lower_row)
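+        # The center offset at any row then follows by linear interpolation:
+        #     center_offset(row) = lower_center_offset + (row-lower_row)*center_slope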
+
+ # Get thetas (in degrees)
+ thetas = np.asarray(nxentry.reduced_data.rotation_angle)
+
+ # Reconstruct tomography data
+ # reduced data axes order: stack,theta,row,column
+ # reconstructed data order in each stack: row/z,x,y
+ # Note: Nexus cannot follow a link if the data it points to is too big,
+ # so get the data from the actual place, not from nxentry.data
+ if 'zoom_perc' in nxentry.reduced_data:
+ res_title = f'{nxentry.reduced_data.attrs["zoom_perc"]}p'
+ else:
+ res_title = 'fullres'
+ load_error = False
+ num_tomo_stacks = nxentry.reduced_data.data.tomo_fields.shape[0]
+ tomo_recon_stacks = num_tomo_stacks*[np.array([])]
+ for i in range(num_tomo_stacks):
+ # Convert reduced data stack from theta,row,column to row,theta,column
+ logger.debug(f'Reading reduced data stack {i+1}...')
+ t0 = time()
+ tomo_stack = np.asarray(nxentry.reduced_data.data.tomo_fields[i])
+ logger.debug(f'... done in {time()-t0:.2f} seconds')
+            if len(tomo_stack.shape) != 3 or any(not dim for dim in tomo_stack.shape):
+ raise ValueError(f'Unable to load tomography stack {i+1} for reconstruction')
+ tomo_stack = np.swapaxes(tomo_stack, 0, 1)
+ assert(len(thetas) == tomo_stack.shape[1])
+ assert(0 <= lower_row < upper_row < tomo_stack.shape[0])
+ center_offsets = [lower_center_offset-lower_row*center_slope,
+ upper_center_offset+(tomo_stack.shape[0]-1-upper_row)*center_slope]
+ t0 = time()
+ logger.debug(f'Running _reconstruct_one_tomo_stack on {self.num_core} cores ...')
+ tomo_recon_stack = self._reconstruct_one_tomo_stack(tomo_stack, thetas,
+ center_offsets=center_offsets, num_core=self.num_core, algorithm='gridrec')
+ logger.debug(f'... done in {time()-t0:.2f} seconds')
+ logger.info(f'Reconstruction of stack {i+1} took {time()-t0:.2f} seconds')
+
+ # Combine stacks
+ tomo_recon_stacks[i] = tomo_recon_stack
+
+ # Resize the reconstructed tomography data
+ # reconstructed data order in each stack: row/z,x,y
+ if self.test_mode:
+ x_bounds = self.test_config.get('x_bounds')
+ y_bounds = self.test_config.get('y_bounds')
+ z_bounds = None
+ elif self.galaxy_flag:
+ if x_bounds is not None and not is_int_pair(x_bounds, ge=0,
+ lt=tomo_recon_stacks[0].shape[1]):
+ raise ValueError(f'Invalid parameter x_bounds ({x_bounds})')
+ if y_bounds is not None and not is_int_pair(y_bounds, ge=0,
+                    lt=tomo_recon_stacks[0].shape[2]):
+ raise ValueError(f'Invalid parameter y_bounds ({y_bounds})')
+ z_bounds = None
+ else:
+ x_bounds, y_bounds, z_bounds = self._resize_reconstructed_data(tomo_recon_stacks)
+ if x_bounds is None:
+ x_range = (0, tomo_recon_stacks[0].shape[1])
+ x_slice = int(x_range[1]/2)
+ else:
+ x_range = (min(x_bounds), max(x_bounds))
+ x_slice = int((x_bounds[0]+x_bounds[1])/2)
+ if y_bounds is None:
+ y_range = (0, tomo_recon_stacks[0].shape[2])
+ y_slice = int(y_range[1]/2)
+ else:
+ y_range = (min(y_bounds), max(y_bounds))
+ y_slice = int((y_bounds[0]+y_bounds[1])/2)
+ if z_bounds is None:
+ z_range = (0, tomo_recon_stacks[0].shape[0])
+ z_slice = int(z_range[1]/2)
+ else:
+ z_range = (min(z_bounds), max(z_bounds))
+ z_slice = int((z_bounds[0]+z_bounds[1])/2)
+
+ # Plot a few reconstructed image slices
+        for i, stack in enumerate(tomo_recon_stacks):
+            if num_tomo_stacks == 1:
+                basetitle = 'recon'
+            else:
+                basetitle = f'recon stack {i+1}'
+ title = f'{basetitle} {res_title} xslice{x_slice}'
+ quick_imshow(stack[z_range[0]:z_range[1],x_slice,y_range[0]:y_range[1]],
+ title=title, path=path, save_fig=self.save_figs, save_only=self.save_only,
+ block=self.block)
+ title = f'{basetitle} {res_title} yslice{y_slice}'
+ quick_imshow(stack[z_range[0]:z_range[1],x_range[0]:x_range[1],y_slice],
+ title=title, path=path, save_fig=self.save_figs, save_only=self.save_only,
+ block=self.block)
+ title = f'{basetitle} {res_title} zslice{z_slice}'
+ quick_imshow(stack[z_slice,x_range[0]:x_range[1],y_range[0]:y_range[1]],
+ title=title, path=path, save_fig=self.save_figs, save_only=self.save_only,
+ block=self.block)
+
+ # Save test data to file
+ # reconstructed data order in each stack: row/z,x,y
+ if self.test_mode:
+ for i, stack in enumerate(tomo_recon_stacks):
+ np.savetxt(f'{self.output_folder}/recon_stack_{i+1}.txt',
+ stack[z_slice,x_range[0]:x_range[1],y_range[0]:y_range[1]], fmt='%.6e')
+
+ # Add image reconstruction to reconstructed data NXprocess
+ # reconstructed data order in each stack: row/z,x,y
+ nxprocess.data = NXdata()
+ nxprocess.attrs['default'] = 'data'
+ for k, v in center_info.items():
+ nxprocess[k] = v
+ if x_bounds is not None:
+ nxprocess.x_bounds = x_bounds
+ if y_bounds is not None:
+ nxprocess.y_bounds = y_bounds
+ if z_bounds is not None:
+ nxprocess.z_bounds = z_bounds
+ nxprocess.data['reconstructed_data'] = np.asarray([stack[z_range[0]:z_range[1],
+ x_range[0]:x_range[1],y_range[0]:y_range[1]] for stack in tomo_recon_stacks])
+ nxprocess.data.attrs['signal'] = 'reconstructed_data'
+
+ # Create a copy of the input Nexus object and remove reduced data
+ exclude_items = [f'{nxentry._name}/reduced_data/data', f'{nxentry._name}/data/reduced_data']
+ nxroot_copy = nxcopy(nxroot, exclude_nxpaths=exclude_items)
+
+ # Add the reconstructed data NXprocess to the new Nexus object
+ nxentry_copy = nxroot_copy[nxroot_copy.attrs['default']]
+ nxentry_copy.reconstructed_data = nxprocess
+ if 'data' not in nxentry_copy:
+ nxentry_copy.data = NXdata()
+ nxentry_copy.attrs['default'] = 'data'
+ nxentry_copy.data.makelink(nxprocess.data.reconstructed_data, name='reconstructed_data')
+ nxentry_copy.data.attrs['signal'] = 'reconstructed_data'
+
+ return(nxroot_copy)
+
+ def combine_data(self, nxroot, x_bounds=None, y_bounds=None):
+ """Combine the reconstructed tomography stacks.
+ """
+ logger.info('Combine the reconstructed tomography stacks')
+
+ if not isinstance(nxroot, NXroot):
+ raise ValueError(f'Invalid parameter nxroot ({nxroot})')
+ nxentry = nxroot[nxroot.attrs['default']]
+ if not isinstance(nxentry, NXentry):
+ raise ValueError(f'Invalid nxentry ({nxentry})')
+
+ # Create plot galaxy path directory and path if needed
+ if self.galaxy_flag:
+ if not os_path.exists('tomo_combine_plots'):
+ mkdir('tomo_combine_plots')
+ path = 'tomo_combine_plots'
+ else:
+ path = self.output_folder
+
+ # Check if reconstructed image data is available
+ if ('reconstructed_data' not in nxentry or 'reconstructed_data' not in nxentry.data):
+ raise KeyError(f'Unable to find valid reconstructed image data in {nxentry}.')
+
+ # Create an NXprocess to store combined image reconstruction (meta)data
+ nxprocess = NXprocess()
+
+ # Get the reconstructed data
+ # reconstructed data order: stack,row(z),x,y
+ # Note: Nexus cannot follow a link if the data it points to is too big,
+ # so get the data from the actual place, not from nxentry.data
+ num_tomo_stacks = nxentry.reconstructed_data.data.reconstructed_data.shape[0]
+ if num_tomo_stacks == 1:
+ logger.info('Only one stack available: leaving combine_data')
+ return(None)
+
+ # Combine the reconstructed stacks
+ # (load one stack at a time to reduce risk of hitting Nexus data access limit)
+ t0 = time()
+ logger.debug(f'Combining the reconstructed stacks ...')
+ tomo_recon_combined = np.asarray(nxentry.reconstructed_data.data.reconstructed_data[0])
+ if num_tomo_stacks > 2:
+ tomo_recon_combined = np.concatenate([tomo_recon_combined]+
+ [nxentry.reconstructed_data.data.reconstructed_data[i]
+ for i in range(1, num_tomo_stacks-1)])
+ if num_tomo_stacks > 1:
+ tomo_recon_combined = np.concatenate([tomo_recon_combined]+
+ [nxentry.reconstructed_data.data.reconstructed_data[num_tomo_stacks-1]])
+ logger.debug(f'... done in {time()-t0:.2f} seconds')
+ logger.info(f'Combining the reconstructed stacks took {time()-t0:.2f} seconds')
+
+ # Resize the combined tomography data stacks
+ # combined data order: row/z,x,y
+ if self.test_mode:
+ x_bounds = None
+ y_bounds = None
+ z_bounds = self.test_config.get('z_bounds')
+ elif self.galaxy_flag:
+            if x_bounds is not None and not is_int_pair(x_bounds, ge=0,
+                    lt=tomo_recon_combined.shape[1]):
+                raise ValueError(f'Invalid parameter x_bounds ({x_bounds})')
+            if y_bounds is not None and not is_int_pair(y_bounds, ge=0,
+                    lt=tomo_recon_combined.shape[2]):
+                raise ValueError(f'Invalid parameter y_bounds ({y_bounds})')
+ z_bounds = None
+ else:
+ x_bounds, y_bounds, z_bounds = self._resize_reconstructed_data(tomo_recon_combined,
+ z_only=True)
+ if x_bounds is None:
+ x_range = (0, tomo_recon_combined.shape[1])
+ x_slice = int(x_range[1]/2)
+ else:
+ x_range = x_bounds
+ x_slice = int((x_bounds[0]+x_bounds[1])/2)
+ if y_bounds is None:
+ y_range = (0, tomo_recon_combined.shape[2])
+ y_slice = int(y_range[1]/2)
+ else:
+ y_range = y_bounds
+ y_slice = int((y_bounds[0]+y_bounds[1])/2)
+ if z_bounds is None:
+ z_range = (0, tomo_recon_combined.shape[0])
+ z_slice = int(z_range[1]/2)
+ else:
+ z_range = z_bounds
+ z_slice = int((z_bounds[0]+z_bounds[1])/2)
+
+ # Plot a few combined image slices
+ quick_imshow(tomo_recon_combined[z_range[0]:z_range[1],x_slice,y_range[0]:y_range[1]],
+ title=f'recon combined xslice{x_slice}', path=path,
+ save_fig=self.save_figs, save_only=self.save_only, block=self.block)
+ quick_imshow(tomo_recon_combined[z_range[0]:z_range[1],x_range[0]:x_range[1],y_slice],
+ title=f'recon combined yslice{y_slice}', path=path,
+ save_fig=self.save_figs, save_only=self.save_only, block=self.block)
+ quick_imshow(tomo_recon_combined[z_slice,x_range[0]:x_range[1],y_range[0]:y_range[1]],
+ title=f'recon combined zslice{z_slice}', path=path,
+ save_fig=self.save_figs, save_only=self.save_only, block=self.block)
+
+ # Save test data to file
+ # combined data order: row/z,x,y
+ if self.test_mode:
+ np.savetxt(f'{self.output_folder}/recon_combined.txt', tomo_recon_combined[
+ z_slice,x_range[0]:x_range[1],y_range[0]:y_range[1]], fmt='%.6e')
+
+ # Add image reconstruction to reconstructed data NXprocess
+ # combined data order: row/z,x,y
+ nxprocess.data = NXdata()
+ nxprocess.attrs['default'] = 'data'
+ if x_bounds is not None:
+ nxprocess.x_bounds = x_bounds
+ if y_bounds is not None:
+ nxprocess.y_bounds = y_bounds
+ if z_bounds is not None:
+ nxprocess.z_bounds = z_bounds
+ nxprocess.data['combined_data'] = tomo_recon_combined
+ nxprocess.data.attrs['signal'] = 'combined_data'
+
+ # Create a copy of the input Nexus object and remove reconstructed data
+ exclude_items = [f'{nxentry._name}/reconstructed_data/data',
+ f'{nxentry._name}/data/reconstructed_data']
+ nxroot_copy = nxcopy(nxroot, exclude_nxpaths=exclude_items)
+
+ # Add the combined data NXprocess to the new Nexus object
+ nxentry_copy = nxroot_copy[nxroot_copy.attrs['default']]
+ nxentry_copy.combined_data = nxprocess
+ if 'data' not in nxentry_copy:
+ nxentry_copy.data = NXdata()
+ nxentry_copy.attrs['default'] = 'data'
+ nxentry_copy.data.makelink(nxprocess.data.combined_data, name='combined_data')
+ nxentry_copy.data.attrs['signal'] = 'combined_data'
+
+ return(nxroot_copy)
+
+ def _gen_dark(self, nxentry, reduced_data):
+ """Generate dark field.
+ """
+ # Get the dark field images
+ image_key = nxentry.instrument.detector.get('image_key', None)
+ if image_key and 'data' in nxentry.instrument.detector:
+ field_indices = [index for index, key in enumerate(image_key) if key == 2]
+ tdf_stack = nxentry.instrument.detector.data[field_indices,:,:]
+        # RV the default NXtomo form does not accommodate bright or dark field stacks
+ else:
+ dark_field_scans = nxentry.spec_scans.dark_field
+ dark_field = FlatField.construct_from_nxcollection(dark_field_scans)
+ prefix = str(nxentry.instrument.detector.local_name)
+ tdf_stack = dark_field.get_detector_data(prefix)
+ if isinstance(tdf_stack, list):
+ assert(len(tdf_stack) == 1) # TODO
+ tdf_stack = tdf_stack[0]
+
+ # Take median
+ if tdf_stack.ndim == 2:
+ tdf = tdf_stack
+ elif tdf_stack.ndim == 3:
+ tdf = np.median(tdf_stack, axis=0)
+ del tdf_stack
+ else:
+ raise ValueError(f'Invalid tdf_stack shape ({tdf_stack.shape})')
+
+ # Remove dark field intensities above the cutoff
+#RV tdf_cutoff = None
+ tdf_cutoff = tdf.min()+2*(np.median(tdf)-tdf.min())
+ logger.debug(f'tdf_cutoff = {tdf_cutoff}')
+ if tdf_cutoff is not None:
+ if not is_num(tdf_cutoff, ge=0):
+ logger.warning(f'Ignoring illegal value of tdf_cutoff {tdf_cutoff}')
+ else:
+ tdf[tdf > tdf_cutoff] = np.nan
+ logger.debug(f'tdf_cutoff = {tdf_cutoff}')
+
+ # Remove nans
+ tdf_mean = np.nanmean(tdf)
+ logger.debug(f'tdf_mean = {tdf_mean}')
+ np.nan_to_num(tdf, copy=False, nan=tdf_mean, posinf=tdf_mean, neginf=0.)
+
+ # Plot dark field
+ if self.galaxy_flag:
+ quick_imshow(tdf, title='dark field', path='tomo_reduce_plots', save_fig=self.save_figs,
+ save_only=self.save_only)
+ elif not self.test_mode:
+ quick_imshow(tdf, title='dark field', path=self.output_folder, save_fig=self.save_figs,
+ save_only=self.save_only)
+ clear_imshow('dark field')
+# quick_imshow(tdf, title='dark field', block=True)
+
+ # Add dark field to reduced data NXprocess
+ reduced_data.data = NXdata()
+ reduced_data.data['dark_field'] = tdf
+
+ return(reduced_data)
+
+ def _gen_bright(self, nxentry, reduced_data):
+ """Generate bright field.
+ """
+ # Get the bright field images
+ image_key = nxentry.instrument.detector.get('image_key', None)
+ if image_key and 'data' in nxentry.instrument.detector:
+ field_indices = [index for index, key in enumerate(image_key) if key == 1]
+ tbf_stack = nxentry.instrument.detector.data[field_indices,:,:]
+        # RV the default NXtomo form does not accommodate bright or dark field stacks
+ else:
+ bright_field_scans = nxentry.spec_scans.bright_field
+ bright_field = FlatField.construct_from_nxcollection(bright_field_scans)
+ prefix = str(nxentry.instrument.detector.local_name)
+ tbf_stack = bright_field.get_detector_data(prefix)
+ if isinstance(tbf_stack, list):
+ assert(len(tbf_stack) == 1) # TODO
+ tbf_stack = tbf_stack[0]
+
+ # Take median if more than one image
+ """Median or mean: It may be best to try the median because of some image
+ artifacts that arise due to crinkles in the upstream kapton tape windows
+ causing some phase contrast images to appear on the detector.
+ One thing that also may be useful in a future implementation is to do a
+ brightfield adjustment on EACH frame of the tomo based on a ROI in the
+ corner of the frame where there is no sample but there is the direct X-ray
+ beam because there is frame to frame fluctuations from the incoming beam.
+ We don’t typically account for them but potentially could.
+ """
+ if tbf_stack.ndim == 2:
+ tbf = tbf_stack
+ elif tbf_stack.ndim == 3:
+ tbf = np.median(tbf_stack, axis=0)
+ del tbf_stack
+ else:
+            raise ValueError(f'Invalid tbf_stack shape ({tbf_stack.shape})')
+
+ # Subtract dark field
+ if 'data' in reduced_data and 'dark_field' in reduced_data.data:
+ tbf -= reduced_data.data.dark_field
+ else:
+ logger.warning('Dark field unavailable')
+
+ # Set any non-positive values to one
+ # (avoid negative bright field values for spikes in dark field)
+ tbf[tbf < 1] = 1
+
+ # Plot bright field
+ if self.galaxy_flag:
+ quick_imshow(tbf, title='bright field', path='tomo_reduce_plots',
+ save_fig=self.save_figs, save_only=self.save_only)
+ elif not self.test_mode:
+ quick_imshow(tbf, title='bright field', path=self.output_folder,
+ save_fig=self.save_figs, save_only=self.save_only)
+ clear_imshow('bright field')
+# quick_imshow(tbf, title='bright field', block=True)
+
+ # Add bright field to reduced data NXprocess
+ if 'data' not in reduced_data:
+ reduced_data.data = NXdata()
+ reduced_data.data['bright_field'] = tbf
+
+ return(reduced_data)
+
+ def _set_detector_bounds(self, nxentry, reduced_data, img_x_bounds=None):
+ """Set vertical detector bounds for each image stack.
+ Right now the range is the same for each set in the image stack.
+ """
+ if self.test_mode:
+ return(tuple(self.test_config['img_x_bounds']))
+
+ # Get the first tomography image and the reference heights
+ image_key = nxentry.instrument.detector.get('image_key', None)
+ if image_key and 'data' in nxentry.instrument.detector:
+ field_indices = [index for index, key in enumerate(image_key) if key == 0]
+ first_image = np.asarray(nxentry.instrument.detector.data[field_indices[0],:,:])
+ theta = float(nxentry.sample.rotation_angle[field_indices[0]])
+ z_translation_all = nxentry.sample.z_translation[field_indices]
+ vertical_shifts = sorted(list(set(z_translation_all)))
+ num_tomo_stacks = len(vertical_shifts)
+ else:
+ tomo_field_scans = nxentry.spec_scans.tomo_fields
+ tomo_fields = TomoField.construct_from_nxcollection(tomo_field_scans)
+ vertical_shifts = tomo_fields.get_vertical_shifts()
+ if not isinstance(vertical_shifts, list):
+ vertical_shifts = [vertical_shifts]
+ prefix = str(nxentry.instrument.detector.local_name)
+ t0 = time()
+ first_image = tomo_fields.get_detector_data(prefix, tomo_fields.scan_numbers[0], 0)
+ logger.debug(f'Getting first image took {time()-t0:.2f} seconds')
+ num_tomo_stacks = len(tomo_fields.scan_numbers)
+ theta = tomo_fields.theta_range['start']
+
+ # Select image bounds
+ title = f'tomography image at theta={round(theta, 2)+0}'
+ if img_x_bounds is not None:
+ if is_int_pair(img_x_bounds) and img_x_bounds[0] == -1 and img_x_bounds[1] == -1:
+ img_x_bounds = None
+ elif not is_index_range(img_x_bounds, ge=0, le=first_image.shape[0]):
+ raise ValueError(f'Invalid parameter img_x_bounds ({img_x_bounds})')
+ if nxentry.instrument.source.attrs['station'] in ('id1a3', 'id3a'):
+ pixel_size = nxentry.instrument.detector.x_pixel_size
+ # Try to get a fit from the bright field
+ tbf = np.asarray(reduced_data.data.bright_field)
+ tbf_shape = tbf.shape
+ x_sum = np.sum(tbf, 1)
+ x_sum_min = x_sum.min()
+ x_sum_max = x_sum.max()
+ fit = Fit.fit_data(x_sum, 'rectangle', x=np.array(range(len(x_sum))), form='atan',
+ guess=True)
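+            # The 'rectangle' model fits two opposing step edges: center1/center2 are
+            # the fitted edge positions and sigma1/sigma2 their widths.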
+ parameters = fit.best_values
+ x_low_fit = parameters.get('center1', None)
+ x_upp_fit = parameters.get('center2', None)
+ sig_low = parameters.get('sigma1', None)
+ sig_upp = parameters.get('sigma2', None)
+ have_fit = fit.success and x_low_fit is not None and x_upp_fit is not None and \
+ sig_low is not None and sig_upp is not None and \
+ 0 <= x_low_fit < x_upp_fit <= x_sum.size and \
+ (sig_low+sig_upp)/(x_upp_fit-x_low_fit) < 0.1
+ if have_fit:
+ # Set a 5% margin on each side
+ margin = 0.05*(x_upp_fit-x_low_fit)
+ x_low_fit = max(0, x_low_fit-margin)
+ x_upp_fit = min(tbf_shape[0], x_upp_fit+margin)
+ if num_tomo_stacks == 1:
+ if have_fit:
+ # Set the default range to enclose the full fitted window
+ x_low = int(x_low_fit)
+ x_upp = int(x_upp_fit)
+ else:
+ # Center a default range of 1 mm (RV: can we get this from the slits?)
+ num_x_min = int((1.0-0.5*pixel_size)/pixel_size)
+ x_low = int(0.5*(tbf_shape[0]-num_x_min))
+ x_upp = x_low+num_x_min
+ else:
+ # Get the default range from the reference heights
+ delta_z = vertical_shifts[1]-vertical_shifts[0]
+ for i in range(2, num_tomo_stacks):
+ delta_z = min(delta_z, vertical_shifts[i]-vertical_shifts[i-1])
+ logger.debug(f'delta_z = {delta_z}')
+ num_x_min = int((delta_z-0.5*pixel_size)/pixel_size)
+ logger.debug(f'num_x_min = {num_x_min}')
+ if num_x_min > tbf_shape[0]:
+ logger.warning('Image bounds and pixel size prevent seamless stacking')
+ if have_fit:
+ # Center the default range relative to the fitted window
+ x_low = int(0.5*(x_low_fit+x_upp_fit-num_x_min))
+ x_upp = x_low+num_x_min
+ else:
+ # Center the default range
+ x_low = int(0.5*(tbf_shape[0]-num_x_min))
+ x_upp = x_low+num_x_min
+ if self.galaxy_flag:
+ img_x_bounds = (x_low, x_upp)
+ else:
+ tmp = np.copy(tbf)
+ tmp_max = tmp.max()
+ tmp[x_low,:] = tmp_max
+ tmp[x_upp-1,:] = tmp_max
+ quick_imshow(tmp, title='bright field')
+ tmp = np.copy(first_image)
+ tmp_max = tmp.max()
+ tmp[x_low,:] = tmp_max
+ tmp[x_upp-1,:] = tmp_max
+ quick_imshow(tmp, title=title)
+ del tmp
+ quick_plot((range(x_sum.size), x_sum),
+ ([x_low, x_low], [x_sum_min, x_sum_max], 'r-'),
+ ([x_upp, x_upp], [x_sum_min, x_sum_max], 'r-'),
+ title='sum over theta and y')
+ print(f'lower bound = {x_low} (inclusive)')
+                print(f'upper bound = {x_upp} (exclusive)')
+ accept = input_yesno('Accept these bounds (y/n)?', 'y')
+ clear_imshow('bright field')
+ clear_imshow(title)
+ clear_plot('sum over theta and y')
+ if accept:
+ img_x_bounds = (x_low, x_upp)
+ else:
+ while True:
+ mask, img_x_bounds = draw_mask_1d(x_sum, title='select x data range',
+ legend='sum over theta and y')
+ if len(img_x_bounds) == 1:
+ break
+ else:
+ print(f'Choose a single connected data range')
+ img_x_bounds = tuple(img_x_bounds[0])
+ if (num_tomo_stacks > 1 and img_x_bounds[1]-img_x_bounds[0]+1 <
+ int((delta_z-0.5*pixel_size)/pixel_size)):
+ logger.warning('Image bounds and pixel size prevent seamless stacking')
+ else:
+ if num_tomo_stacks > 1:
+ raise NotImplementedError('Selecting image bounds for multiple stacks on FMB')
+ # For FMB: use the first tomography image to select range
+ # RV: revisit if they do tomography with multiple stacks
+ x_sum = np.sum(first_image, 1)
+ x_sum_min = x_sum.min()
+ x_sum_max = x_sum.max()
+ if self.galaxy_flag:
+ if img_x_bounds is None:
+ img_x_bounds = (0, first_image.shape[0])
+ else:
+ quick_imshow(first_image, title=title)
+ print('Select vertical data reduction range from first tomography image')
+ img_x_bounds = select_image_bounds(first_image, 0, title=title)
+ clear_imshow(title)
+ if img_x_bounds is None:
+ raise ValueError('Unable to select image bounds')
+
+ # Plot results
+ if self.galaxy_flag:
+ path = 'tomo_reduce_plots'
+ else:
+ path = self.output_folder
+ x_low = img_x_bounds[0]
+ x_upp = img_x_bounds[1]
+ tmp = np.copy(first_image)
+ tmp_max = tmp.max()
+ tmp[x_low,:] = tmp_max
+ tmp[x_upp-1,:] = tmp_max
+ quick_imshow(tmp, title=title, path=path, save_fig=self.save_figs, save_only=self.save_only,
+ block=self.block)
+ del tmp
+ quick_plot((range(x_sum.size), x_sum),
+ ([x_low, x_low], [x_sum_min, x_sum_max], 'r-'),
+ ([x_upp, x_upp], [x_sum_min, x_sum_max], 'r-'),
+ title='sum over theta and y', path=path, save_fig=self.save_figs,
+ save_only=self.save_only, block=self.block)
+
+ return(img_x_bounds)
+
+ def _set_zoom_or_skip(self):
+ """Set zoom and/or theta skip to reduce memory the requirement for the analysis.
+ """
+# if input_yesno('\nDo you want to zoom in to reduce memory requirement (y/n)?', 'n'):
+# zoom_perc = input_int(' Enter zoom percentage', ge=1, le=100)
+# else:
+# zoom_perc = None
+ zoom_perc = None
+# if input_yesno('Do you want to skip thetas to reduce memory requirement (y/n)?', 'n'):
+# num_theta_skip = input_int(' Enter the number skip theta interval', ge=0,
+# lt=num_theta)
+# else:
+# num_theta_skip = None
+ num_theta_skip = None
+ logger.debug(f'zoom_perc = {zoom_perc}')
+ logger.debug(f'num_theta_skip = {num_theta_skip}')
+
+ return(zoom_perc, num_theta_skip)
+
+ def _gen_tomo(self, nxentry, reduced_data):
+ """Generate tomography fields.
+ """
+ # Get full bright field
+ tbf = np.asarray(reduced_data.data.bright_field)
+ tbf_shape = tbf.shape
+
+ # Get image bounds
+ img_x_bounds = tuple(reduced_data.get('img_x_bounds', (0, tbf_shape[0])))
+ img_y_bounds = tuple(reduced_data.get('img_y_bounds', (0, tbf_shape[1])))
+
+ # Get resized dark field
+# if 'dark_field' in data:
+# tbf = np.asarray(reduced_data.data.dark_field[
+# img_x_bounds[0]:img_x_bounds[1],img_y_bounds[0]:img_y_bounds[1]])
+# else:
+# logger.warning('Dark field unavailable')
+# tdf = None
+ tdf = None
+
+ # Resize bright field
+ if img_x_bounds != (0, tbf.shape[0]) or img_y_bounds != (0, tbf.shape[1]):
+ tbf = tbf[img_x_bounds[0]:img_x_bounds[1],img_y_bounds[0]:img_y_bounds[1]]
+
+ # Get the tomography images
+ image_key = nxentry.instrument.detector.get('image_key', None)
+ if image_key and 'data' in nxentry.instrument.detector:
+ field_indices_all = [index for index, key in enumerate(image_key) if key == 0]
+ z_translation_all = nxentry.sample.z_translation[field_indices_all]
+ z_translation_levels = sorted(list(set(z_translation_all)))
+ num_tomo_stacks = len(z_translation_levels)
+            horizontal_shifts = []
+            vertical_shifts = []
+            thetas = None
+            tomo_stacks = []
+ for i, z_translation in enumerate(z_translation_levels):
+ field_indices = [field_indices_all[index]
+ for index, z in enumerate(z_translation_all) if z == z_translation]
+ horizontal_shift = list(set(nxentry.sample.x_translation[field_indices]))
+ assert(len(horizontal_shift) == 1)
+ horizontal_shifts += horizontal_shift
+ vertical_shift = list(set(nxentry.sample.z_translation[field_indices]))
+ assert(len(vertical_shift) == 1)
+ vertical_shifts += vertical_shift
+ sequence_numbers = nxentry.instrument.detector.sequence_number[field_indices]
+ if thetas is None:
+ thetas = np.asarray(nxentry.sample.rotation_angle[field_indices]) \
+ [sequence_numbers]
+ else:
+ assert(all(thetas[i] == nxentry.sample.rotation_angle[field_indices[index]]
+ for i, index in enumerate(sequence_numbers)))
+ assert(list(set(sequence_numbers)) == [i for i in range(len(sequence_numbers))])
+ if list(sequence_numbers) == [i for i in range(len(sequence_numbers))]:
+ tomo_stack = np.asarray(nxentry.instrument.detector.data[field_indices])
+ else:
+ raise ValueError('Unable to load the tomography images')
+ tomo_stacks.append(tomo_stack)
+ else:
+ tomo_field_scans = nxentry.spec_scans.tomo_fields
+ tomo_fields = TomoField.construct_from_nxcollection(tomo_field_scans)
+ horizontal_shifts = tomo_fields.get_horizontal_shifts()
+ vertical_shifts = tomo_fields.get_vertical_shifts()
+ prefix = str(nxentry.instrument.detector.local_name)
+ t0 = time()
+ tomo_stacks = tomo_fields.get_detector_data(prefix)
+            logger.debug(f'Getting all tomography images took {time()-t0:.2f} seconds')
+ thetas = np.linspace(tomo_fields.theta_range['start'], tomo_fields.theta_range['end'],
+ tomo_fields.theta_range['num'])
+ if not isinstance(tomo_stacks, list):
+ horizontal_shifts = [horizontal_shifts]
+ vertical_shifts = [vertical_shifts]
+ tomo_stacks = [tomo_stacks]
+
+ reduced_tomo_stacks = []
+ if self.galaxy_flag:
+ path = 'tomo_reduce_plots'
+ else:
+ path = self.output_folder
+ for i, tomo_stack in enumerate(tomo_stacks):
+ # Resize the tomography images
+ # Right now the range is the same for each set in the image stack.
+ if img_x_bounds != (0, tbf.shape[0]) or img_y_bounds != (0, tbf.shape[1]):
+ t0 = time()
+ tomo_stack = tomo_stack[:,img_x_bounds[0]:img_x_bounds[1],
+ img_y_bounds[0]:img_y_bounds[1]].astype('float64')
+ logger.debug(f'Resizing tomography images took {time()-t0:.2f} seconds')
+
+ # Subtract dark field
+ if tdf is not None:
+ t0 = time()
+ with set_numexpr_threads(self.num_core):
+ ne.evaluate('tomo_stack-tdf', out=tomo_stack)
+ logger.debug(f'Subtracting dark field took {time()-t0:.2f} seconds')
+
+ # Normalize
+ t0 = time()
+ with set_numexpr_threads(self.num_core):
+ ne.evaluate('tomo_stack/tbf', out=tomo_stack, truediv=True)
+ logger.debug(f'Normalizing took {time()-t0:.2f} seconds')
+
+ # Remove non-positive values and linearize data
+ t0 = time()
+ cutoff = 1.e-6
+ with set_numexpr_threads(self.num_core):
+                ne.evaluate('where(tomo_stack<cutoff, cutoff, tomo_stack)', out=tomo_stack)
+            with set_numexpr_threads(self.num_core):
+                ne.evaluate('-log(tomo_stack)', out=tomo_stack)
+            logger.debug(f'Removing non-positive values and linearizing data took '+
+                f'{time()-t0:.2f} seconds')
+            reduced_tomo_stacks.append(tomo_stack)
+
+        # Add the reduced tomography fields and their angles/shifts to the NXprocess
+        # reduced data axes order: stack,theta,row,column
+        reduced_data['rotation_angle'] = thetas
+        reduced_data['x_translation'] = np.asarray(horizontal_shifts)
+        reduced_data['z_translation'] = np.asarray(vertical_shifts)
+        reduced_data.data['tomo_fields'] = np.asarray(reduced_tomo_stacks)
+
+        return(reduced_data)
+
+    def _find_center_one_plane(self, sinogram, row, thetas, eff_pixel_size,
+            cross_sectional_dim, path=None, num_core=1):
+        """Find the center offset for a single tomography plane.
+        """
+        # sinogram index order: theta,column
+        # keep a transpose for the single plane reconstructions below
+        sinogram = np.asarray(sinogram)
+        sinogram_T = sinogram.T
+        center = sinogram.shape[1]/2
+
+        # Try finding the center using Nghia Vo's method
+        t0 = time()
+        if num_core > num_core_tomopy_limit:
+ logger.debug(f'Running find_center_vo on {num_core_tomopy_limit} cores ...')
+ tomo_center = tomopy.find_center_vo(sinogram, ncore=num_core_tomopy_limit)
+ else:
+ logger.debug(f'Running find_center_vo on {num_core} cores ...')
+ tomo_center = tomopy.find_center_vo(sinogram, ncore=num_core)
+ logger.debug(f'... done in {time()-t0:.2f} seconds')
+ logger.info(f'Finding the center using Nghia Vo’s method took {time()-t0:.2f} seconds')
+ center_offset_vo = tomo_center-center
+ logger.info(f'Center at row {row} using Nghia Vo’s method = {center_offset_vo:.2f}')
+ t0 = time()
+ logger.debug(f'Running _reconstruct_one_plane on {self.num_core} cores ...')
+ recon_plane = self._reconstruct_one_plane(sinogram_T, tomo_center, thetas,
+ eff_pixel_size, cross_sectional_dim, False, num_core)
+ logger.debug(f'... done in {time()-t0:.2f} seconds')
+ logger.info(f'Reconstructing row {row} took {time()-t0:.2f} seconds')
+
+ title = f'edges row{row} center offset{center_offset_vo:.2f} Vo'
+ self._plot_edges_one_plane(recon_plane, title, path=path)
+
+ # Try using phase correlation method
+# if input_yesno('Try finding center using phase correlation (y/n)?', 'n'):
+# t0 = time()
+# logger.debug(f'Running find_center_pc ...')
+# tomo_center = tomopy.find_center_pc(sinogram, sinogram, tol=0.1, rotc_guess=tomo_center)
+# error = 1.
+# while error > tol:
+# prev = tomo_center
+# tomo_center = tomopy.find_center_pc(sinogram, sinogram, tol=tol,
+# rotc_guess=tomo_center)
+# error = np.abs(tomo_center-prev)
+# logger.debug(f'... done in {time()-t0:.2f} seconds')
+# logger.info('Finding the center using the phase correlation method took '+
+# f'{time()-t0:.2f} seconds')
+# center_offset = tomo_center-center
+# print(f'Center at row {row} using phase correlation = {center_offset:.2f}')
+# t0 = time()
+# logger.debug(f'Running _reconstruct_one_plane on {self.num_core} cores ...')
+# recon_plane = self._reconstruct_one_plane(sinogram_T, tomo_center, thetas,
+# eff_pixel_size, cross_sectional_dim, False, num_core)
+# logger.debug(f'... done in {time()-t0:.2f} seconds')
+# logger.info(f'Reconstructing row {row} took {time()-t0:.2f} seconds')
+#
+# title = f'edges row{row} center_offset{center_offset:.2f} PC'
+# self._plot_edges_one_plane(recon_plane, title, path=path)
+
+ # Select center location
+# if input_yesno('Accept a center location (y) or continue search (n)?', 'y'):
+ if True:
+# center_offset = input_num(' Enter chosen center offset', ge=-center, le=center,
+# default=center_offset_vo)
+ center_offset = center_offset_vo
+ del sinogram_T
+ del recon_plane
+ return float(center_offset)
+
+ # perform center finding search
+ while True:
+ center_offset_low = input_int('\nEnter lower bound for center offset', ge=-center,
+ le=center)
+ center_offset_upp = input_int('Enter upper bound for center offset',
+ ge=center_offset_low, le=center)
+ if center_offset_upp == center_offset_low:
+ center_offset_step = 1
+ else:
+ center_offset_step = input_int('Enter step size for center offset search', ge=1,
+ le=center_offset_upp-center_offset_low)
+ num_center_offset = 1+int((center_offset_upp-center_offset_low)/center_offset_step)
+ center_offsets = np.linspace(center_offset_low, center_offset_upp, num_center_offset)
+ for center_offset in center_offsets:
+ if center_offset == center_offset_vo:
+ continue
+ t0 = time()
+ logger.debug(f'Running _reconstruct_one_plane on {num_core} cores ...')
+ recon_plane = self._reconstruct_one_plane(sinogram_T, center_offset+center, thetas,
+ eff_pixel_size, cross_sectional_dim, False, num_core)
+ logger.debug(f'... done in {time()-t0:.2f} seconds')
+ logger.info(f'Reconstructing center_offset {center_offset} took '+
+ f'{time()-t0:.2f} seconds')
+ title = f'edges row{row} center_offset{center_offset:.2f}'
+ self._plot_edges_one_plane(recon_plane, title, path=path)
+ if input_int('\nContinue (0) or end the search (1)', ge=0, le=1):
+ break
+
+ del sinogram_T
+ del recon_plane
+ center_offset = input_num(' Enter chosen center offset', ge=-center, le=center)
+ return float(center_offset)
+
+ def _reconstruct_one_plane(self, tomo_plane_T, center, thetas, eff_pixel_size,
+ cross_sectional_dim, plot_sinogram=True, num_core=1):
+ """Invert the sinogram for a single tomography plane.
+ """
+ # tomo_plane_T index order: column,theta
+ assert(0 <= center < tomo_plane_T.shape[0])
+ center_offset = center-tomo_plane_T.shape[0]/2
+ two_offset = 2*int(np.round(center_offset))
+ two_offset_abs = np.abs(two_offset)
+ max_rad = int(0.55*(cross_sectional_dim/eff_pixel_size)) # 10% slack to avoid edge effects
+ if max_rad > 0.5*tomo_plane_T.shape[0]:
+ max_rad = 0.5*tomo_plane_T.shape[0]
+ dist_from_edge = max(1, int(np.floor((tomo_plane_T.shape[0]-two_offset_abs)/2.)-max_rad))
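+        # Trim the sinogram so that it stays symmetric about the rotation center,
+        # since iradon assumes the rotation axis lies at the middle column.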
+ if two_offset >= 0:
+ logger.debug(f'sinogram range = [{two_offset+dist_from_edge}, {-dist_from_edge}]')
+ sinogram = tomo_plane_T[two_offset+dist_from_edge:-dist_from_edge,:]
+ else:
+ logger.debug(f'sinogram range = [{dist_from_edge}, {two_offset-dist_from_edge}]')
+ sinogram = tomo_plane_T[dist_from_edge:two_offset-dist_from_edge,:]
+ if not self.galaxy_flag and plot_sinogram:
+ quick_imshow(sinogram.T, f'sinogram center offset{center_offset:.2f}', aspect='auto',
+ path=self.output_folder, save_fig=self.save_figs, save_only=self.save_only,
+ block=self.block)
+
+ # Inverting sinogram
+ t0 = time()
+ recon_sinogram = iradon(sinogram, theta=thetas, circle=True)
+ logger.debug(f'Inverting sinogram took {time()-t0:.2f} seconds')
+ del sinogram
+
+ # Performing Gaussian filtering and removing ring artifacts
+        recon_parameters = None  # self.config.get('recon_parameters')
+ if recon_parameters is None:
+ sigma = 1.0
+ ring_width = 15
+ else:
+ sigma = recon_parameters.get('gaussian_sigma', 1.0)
+ if not is_num(sigma, ge=0.0):
+ logger.warning(f'Invalid gaussian_sigma ({sigma}) in _reconstruct_one_plane, '+
+ 'set to a default value of 1.0')
+ sigma = 1.0
+ ring_width = recon_parameters.get('ring_width', 15)
+ if not is_int(ring_width, ge=0):
+ logger.warning(f'Invalid ring_width ({ring_width}) in _reconstruct_one_plane, '+
+ 'set to a default value of 15')
+ ring_width = 15
+ t0 = time()
+ recon_sinogram = spi.gaussian_filter(recon_sinogram, sigma, mode='nearest')
+ recon_clean = np.expand_dims(recon_sinogram, axis=0)
+ del recon_sinogram
+ recon_clean = tomopy.misc.corr.remove_ring(recon_clean, rwidth=ring_width, ncore=num_core)
+ logger.debug(f'Filtering and removing ring artifacts took {time()-t0:.2f} seconds')
+
+ return recon_clean
+
+ def _plot_edges_one_plane(self, recon_plane, title, path=None):
+        vis_parameters = None  # self.config.get('vis_parameters')
+ if vis_parameters is None:
+ weight = 0.1
+ else:
+ weight = vis_parameters.get('denoise_weight', 0.1)
+ if not is_num(weight, ge=0.0):
+ logger.warning(f'Invalid weight ({weight}) in _plot_edges_one_plane, '+
+ 'set to a default value of 0.1')
+ weight = 0.1
+ edges = denoise_tv_chambolle(recon_plane, weight=weight)
+ vmax = np.max(edges[0,:,:])
+ vmin = -vmax
+ if path is None:
+ path = self.output_folder
+ quick_imshow(edges[0,:,:], f'{title} coolwarm', path=path, cmap='coolwarm',
+ save_fig=self.save_figs, save_only=self.save_only, block=self.block)
+ quick_imshow(edges[0,:,:], f'{title} gray', path=path, cmap='gray', vmin=vmin, vmax=vmax,
+ save_fig=self.save_figs, save_only=self.save_only, block=self.block)
+ del edges
+
+ def _reconstruct_one_tomo_stack(self, tomo_stack, thetas, center_offsets=[], num_core=1,
+ algorithm='gridrec'):
+ """Reconstruct a single tomography stack.
+ """
+ # tomo_stack order: row,theta,column
+ # input thetas must be in degrees
+ # centers_offset: tomography axis shift in pixels relative to column center
+ # RV should we remove stripes?
+ # https://tomopy.readthedocs.io/en/latest/api/tomopy.prep.stripe.html
+ # RV should we remove rings?
+ # https://tomopy.readthedocs.io/en/latest/api/tomopy.misc.corr.html
+ # RV: Add an option to do (extra) secondary iterations later or to do some sort of convergence test?
+ if not len(center_offsets):
+ centers = np.zeros((tomo_stack.shape[0]))
+ elif len(center_offsets) == 2:
+ centers = np.linspace(center_offsets[0], center_offsets[1], tomo_stack.shape[0])
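+            # Two values are interpreted as the offsets at the first and last row,
+            # with a linear gradient across the rows in between.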
+ else:
+ if center_offsets.size != tomo_stack.shape[0]:
+ raise ValueError('center_offsets dimension mismatch in reconstruct_one_tomo_stack')
+ centers = center_offsets
+ centers += tomo_stack.shape[2]/2
+
+ # Get reconstruction parameters
+        recon_parameters = None  # self.config.get('recon_parameters')
+ if recon_parameters is None:
+ sigma = 2.0
+ secondary_iters = 0
+ ring_width = 15
+ else:
+            sigma = recon_parameters.get('stripe_fw_sigma', 2.0)
+            if not is_num(sigma, ge=0):
+                logger.warning(f'Invalid stripe_fw_sigma ({sigma}) in '+
+                    '_reconstruct_one_tomo_stack, set to a default value of 2.0')
+                sigma = 2.0
+            secondary_iters = recon_parameters.get('secondary_iters', 0)
+            if not is_int(secondary_iters, ge=0):
+                logger.warning(f'Invalid secondary_iters ({secondary_iters}) in '+
+                    '_reconstruct_one_tomo_stack, set to a default value of 0 (skip them)')
+                secondary_iters = 0
+            ring_width = recon_parameters.get('ring_width', 15)
+            if not is_int(ring_width, ge=0):
+                logger.warning(f'Invalid ring_width ({ring_width}) in '+
+                    '_reconstruct_one_tomo_stack, set to a default value of 15')
+                ring_width = 15
+
+ # Remove horizontal stripe
+ t0 = time()
+ if num_core > num_core_tomopy_limit:
+            logger.debug(f'Running remove_stripe_fw on {num_core_tomopy_limit} cores ...')
+ tomo_stack = tomopy.prep.stripe.remove_stripe_fw(tomo_stack, sigma=sigma,
+ ncore=num_core_tomopy_limit)
+ else:
+ logger.debug(f'Running remove_stripe_fw on {num_core} cores ...')
+ tomo_stack = tomopy.prep.stripe.remove_stripe_fw(tomo_stack, sigma=sigma,
+ ncore=num_core)
+ logger.debug(f'... tomopy.prep.stripe.remove_stripe_fw took {time()-t0:.2f} seconds')
+
+ # Perform initial image reconstruction
+ logger.debug('Performing initial image reconstruction')
+ t0 = time()
+ logger.debug(f'Running recon on {num_core} cores ...')
+ tomo_recon_stack = tomopy.recon(tomo_stack, np.radians(thetas), centers,
+ sinogram_order=True, algorithm=algorithm, ncore=num_core)
+ logger.debug(f'... done in {time()-t0:.2f} seconds')
+ logger.info(f'Performing initial image reconstruction took {time()-t0:.2f} seconds')
+
+ # Run optional secondary iterations
+ if secondary_iters > 0:
+ logger.debug(f'Running {secondary_iters} secondary iterations')
+ #options = {'method':'SIRT_CUDA', 'proj_type':'cuda', 'num_iter':secondary_iters}
+ #RV: doesn't work for me:
+ #"Error: CUDA error 803: system has unsupported display driver/cuda driver combination."
+ #options = {'method':'SIRT', 'proj_type':'linear', 'MinConstraint': 0, 'num_iter':secondary_iters}
+ #SIRT did not finish while running overnight
+ #options = {'method':'SART', 'proj_type':'linear', 'num_iter':secondary_iters}
+ options = {'method':'SART', 'proj_type':'linear', 'MinConstraint': 0,
+ 'num_iter':secondary_iters}
+ t0 = time()
+ logger.debug(f'Running recon on {num_core} cores ...')
+ tomo_recon_stack = tomopy.recon(tomo_stack, np.radians(thetas), centers,
+ init_recon=tomo_recon_stack, options=options, sinogram_order=True,
+ algorithm=tomopy.astra, ncore=num_core)
+ logger.debug(f'... done in {time()-t0:.2f} seconds')
+ logger.info(f'Performing secondary iterations took {time()-t0:.2f} seconds')
+
+ # Remove ring artifacts
+ t0 = time()
+ tomopy.misc.corr.remove_ring(tomo_recon_stack, rwidth=ring_width, out=tomo_recon_stack,
+ ncore=num_core)
+ logger.debug(f'Removing ring artifacts took {time()-t0:.2f} seconds')
+
+ return tomo_recon_stack
+
+ def _resize_reconstructed_data(self, data, z_only=False):
+ """Resize the reconstructed tomography data.
+ """
+ # Data order: row(z),x,y or stack,row(z),x,y
+ if isinstance(data, list):
+ for stack in data:
+ assert(stack.ndim == 3)
+ num_tomo_stacks = len(data)
+ tomo_recon_stacks = data
+ else:
+ assert(data.ndim == 3)
+ num_tomo_stacks = 1
+ tomo_recon_stacks = [data]
+
+ if z_only:
+ x_bounds = None
+ y_bounds = None
+ else:
+ # Selecting x bounds (in yz-plane)
+            tomosum = 0
+            for i in range(num_tomo_stacks):
+                tomosum = tomosum+np.sum(tomo_recon_stacks[i], axis=(0,2))
+ select_x_bounds = input_yesno('\nDo you want to change the image x-bounds (y/n)?', 'y')
+ if not select_x_bounds:
+ x_bounds = None
+ else:
+ accept = False
+ index_ranges = None
+ while not accept:
+ mask, x_bounds = draw_mask_1d(tomosum, current_index_ranges=index_ranges,
+ title='select x data range', legend='recon stack sum yz')
+ while len(x_bounds) != 1:
+ print('Please select exactly one continuous range')
+ mask, x_bounds = draw_mask_1d(tomosum, title='select x data range',
+ legend='recon stack sum yz')
+ x_bounds = x_bounds[0]
+# quick_plot(tomosum, vlines=x_bounds, title='recon stack sum yz')
+# print(f'x_bounds = {x_bounds} (lower bound inclusive, upper bound '+
+# 'exclusive)')
+# accept = input_yesno('Accept these bounds (y/n)?', 'y')
+ accept = True
+ logger.debug(f'x_bounds = {x_bounds}')
+
+ # Selecting y bounds (in xz-plane)
+            tomosum = 0
+            for i in range(num_tomo_stacks):
+                tomosum = tomosum+np.sum(tomo_recon_stacks[i], axis=(0,1))
+ select_y_bounds = input_yesno('\nDo you want to change the image y-bounds (y/n)?', 'y')
+ if not select_y_bounds:
+ y_bounds = None
+ else:
+ accept = False
+ index_ranges = None
+ while not accept:
+ mask, y_bounds = draw_mask_1d(tomosum, current_index_ranges=index_ranges,
+                            title='select y data range', legend='recon stack sum xz')
+ while len(y_bounds) != 1:
+ print('Please select exactly one continuous range')
+                        mask, y_bounds = draw_mask_1d(tomosum, title='select y data range',
+ legend='recon stack sum xz')
+ y_bounds = y_bounds[0]
+# quick_plot(tomosum, vlines=y_bounds, title='recon stack sum xz')
+# print(f'y_bounds = {y_bounds} (lower bound inclusive, upper bound '+
+# 'exclusive)')
+# accept = input_yesno('Accept these bounds (y/n)?', 'y')
+ accept = True
+ logger.debug(f'y_bounds = {y_bounds}')
+
+ # Selecting z bounds (in xy-plane) (only valid for a single image stack)
+ if num_tomo_stacks != 1:
+ z_bounds = None
+ else:
+            tomosum = 0
+            for i in range(num_tomo_stacks):
+                tomosum = tomosum+np.sum(tomo_recon_stacks[i], axis=(1,2))
+ select_z_bounds = input_yesno('Do you want to change the image z-bounds (y/n)?', 'n')
+ if not select_z_bounds:
+ z_bounds = None
+ else:
+ accept = False
+ index_ranges = None
+ while not accept:
+ mask, z_bounds = draw_mask_1d(tomosum, current_index_ranges=index_ranges,
+                            title='select z data range', legend='recon stack sum xy')
+ while len(z_bounds) != 1:
+ print('Please select exactly one continuous range')
+                        mask, z_bounds = draw_mask_1d(tomosum, title='select z data range',
+ legend='recon stack sum xy')
+ z_bounds = z_bounds[0]
+# quick_plot(tomosum, vlines=z_bounds, title='recon stack sum xy')
+# print(f'z_bounds = {z_bounds} (lower bound inclusive, upper bound '+
+# 'exclusive)')
+# accept = input_yesno('Accept these bounds (y/n)?', 'y')
+ accept = True
+ logger.debug(f'z_bounds = {z_bounds}')
+
+ return(x_bounds, y_bounds, z_bounds)
+
+
+def run_tomo(input_file:str, output_file:str, modes:list[str], center_file=None, num_core=-1,
+ output_folder='.', save_figs='no', test_mode=False) -> None:
+
+ if test_mode:
+ logging_format = '%(asctime)s : %(levelname)s - %(module)s : %(funcName)s - %(message)s'
+ level = logging.getLevelName('INFO')
+ logging.basicConfig(filename=f'{output_folder}/tomo.log', filemode='w',
+ format=logging_format, level=level, force=True)
+ logger.info(f'input_file = {input_file}')
+ logger.info(f'center_file = {center_file}')
+ logger.info(f'output_file = {output_file}')
+    logger.debug(f'modes = {modes}')
+    logger.debug(f'num_core = {num_core}')
+ logger.info(f'output_folder = {output_folder}')
+ logger.info(f'save_figs = {save_figs}')
+ logger.info(f'test_mode = {test_mode}')
+
+ # Check for correction modes
+ legal_modes = ['reduce_data', 'find_center', 'reconstruct_data', 'combine_data', 'all']
+ if modes is None:
+ modes = ['all']
+    if not all(mode in legal_modes for mode in modes):
+ raise ValueError(f'Invalid parameter modes ({modes})')
+
+ # Instantiate Tomo object
+ tomo = Tomo(num_core=num_core, output_folder=output_folder, save_figs=save_figs,
+ test_mode=test_mode)
+
+ # Read input file
+ data = tomo.read(input_file)
+
+ # Generate reduced tomography images
+ if 'reduce_data' in modes or 'all' in modes:
+ data = tomo.gen_reduced_data(data)
+
+ # Find rotation axis centers for the tomography stacks.
+ center_data = None
+ if 'find_center' in modes or 'all' in modes:
+ center_data = tomo.find_centers(data)
+
+ # Reconstruct tomography stacks
+ if 'reconstruct_data' in modes or 'all' in modes:
+ if center_data is None:
+ # Read input file
+ center_data = tomo.read(center_file)
+ data = tomo.reconstruct_data(data, center_data)
+ center_data = None
+
+ # Combine reconstructed tomography stacks
+ if 'combine_data' in modes or 'all' in modes:
+ data = tomo.combine_data(data)
+
+ # Write output file
+    if data is not None and not test_mode:
+        if center_data is None:
+            tomo.write(data, output_file)
+        else:
+            tomo.write(center_data, output_file)
+
+ logger.info(f'Completed modes: {modes}')
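+
+# Illustrative invocation (file names are hypothetical):
+#     run_tomo('map.yaml', 'recon.nxs', modes=['all'], num_core=4, output_folder='out')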