build/lib/CHAP/processor.py @ 0:cbbe42422d56 draft

planemo upload for repository https://github.com/CHESSComputing/ChessAnalysisPipeline/tree/galaxy commit 1401a7e1ae007a6bda260d147f9b879e789b73e0-dirty
author kls286
date Tue, 28 Mar 2023 15:07:30 +0000
1 #!/usr/bin/env python
2 #-*- coding: utf-8 -*-
3 #pylint: disable=
4 """
5 File : processor.py
6 Author : Valentin Kuznetsov <vkuznet AT gmail dot com>
7 Description: Processor module
8 """
9
10 # system modules
11 import argparse
12 import json
13 import logging
14 import sys
15 from time import time
16
17 # local modules
18 # from pipeline import PipelineObject
19
20 class Processor():
21 """
22 Processor represents a generic processor
23 """
24 def __init__(self):
25 """
26 Processor constructor
27 """
28 self.__name__ = self.__class__.__name__
29 self.logger = logging.getLogger(self.__name__)
30 self.logger.propagate = False
31
32 def process(self, data):
33 """
34 process data API
35 """
36
37 t0 = time()
38 self.logger.info(f'Executing "process" with type(data)={type(data)}')
39
40 data = self._process(data)
41
42 self.logger.info(f'Finished "process" in {time()-t0:.3f} seconds\n')
43
44 return(data)
45
46 def _process(self, data):
47 # If needed, extract data from a returned value of Reader.read
48 if isinstance(data, list):
49 if all(isinstance(d, dict) for d in data):
50 data = data[0]['data']
51 # process operation is a simple print function
52 data += "process part\n"
53 # and we return data back to pipeline
54 return data
55
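A minimal sketch of the intended subclassing pattern (the class below is hypothetical): a concrete processor overrides only `_process`, while the inherited `process` method supplies the logging and timing around it.

    class UppercaseProcessor(Processor):
        '''Hypothetical processor that upper-cases string data.'''
        def _process(self, data):
            return data.upper()

    # usage: logs timing, returns 'TEXT'
    # UppercaseProcessor().process('text')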
56
57 class TFaaSImageProcessor(Processor):
58 '''
59 A Processor to get predictions from a TFaaS inference server.
60 '''
61 def process(self, data, url, model, verbose=False):
62 """
63 process data API
64 """
65
66 t0 = time()
67 self.logger.info(f'Executing "process" with url {url} model {model}')
68
69 data = self._process(data, url, model, verbose)
70
71 self.logger.info(f'Finished "process" in {time()-t0:.3f} seconds\n')
72
73 return(data)
74
75 def _process(self, data, url, model, verbose):
76 '''Get a prediction from the TFaaS inference server for the input image data.
77
78 :param data: input image data, either a file name or the image data itself
79 :type data: object
80 :return: prediction results returned by the server
81 :rtype: object
82 '''
83 from MLaaS.tfaas_client import predictImage
84 from pathlib import Path
85 self.logger.info(f"input data {type(data)}")
86 if isinstance(data, str) and Path(data).is_file():
87 imgFile = data
88 data = predictImage(url, imgFile, model, verbose)
89 else:
90 rdict = data[0]
91 import requests
92 img = rdict['data']
93 session = requests.Session()
94 rurl = url + '/predict/image'
95 payload = dict(model=model)
96 files = dict(image=img)
97 self.logger.info(f"HTTP request {rurl} with image file and {payload} payload")
98 req = session.post(rurl, files=files, data=payload)
99 data = req.content
100 data = data.decode("utf-8").replace('\n', '')
101 self.logger.info(f"HTTP response {data}")
102
103 return(data)
104
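A usage sketch (the server URL and model name are hypothetical; `data` may be the name of an existing image file or the list-of-dicts output of an upstream image reader):

    processor = TFaaSImageProcessor()
    prediction = processor.process(data='img.png',
                                   url='http://localhost:8083',
                                   model='mnist')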
105 class URLResponseProcessor(Processor):
106 def _process(self, data):
107 '''Take data returned from URLReader.read and return a decoded version of
108 the content.
109
110 :param data: input data (output of URLReader.read)
111 :type data: list[dict]
112 :return: decoded data contents
113 :rtype: object
114 '''
115
116 data = data[0]
117
118 content = data['data']
119 encoding = data['encoding']
120
121 self.logger.debug(f'Decoding content of type {type(content)} with {encoding}')
122
123 try:
124 content = content.decode(encoding)
125 except Exception:
126 self.logger.warning(f'Failed to decode content of type {type(content)} with {encoding}')
127
128 return(content)
129
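For example, given the list-of-dicts structure produced by `URLReader.read`:

    processor = URLResponseProcessor()
    content = processor.process([{'data': b'hello', 'encoding': 'utf-8'}])
    # content == 'hello'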
130 class PrintProcessor(Processor):
131 '''A Processor to simply print the input data to stdout and return the
132 original input data, unchanged in any way.
133 '''
134
135 def _process(self, data):
136 '''Print and return the input data.
137
138 :param data: Input data
139 :type data: object
140 :return: `data`
141 :rtype: object
142 '''
143
144 print(f'{self.__name__} data :')
145
146 if callable(getattr(data, '_str_tree', None)):
147 # If data is likely an NXobject, print its tree representation
148 # (since NXobjects' str representations are just their nxname -- not
149 # very helpful).
150 print(data._str_tree(attrs=True, recursive=True))
151 else:
152 print(str(data))
153
154 return(data)
155
156 class NexusToNumpyProcessor(Processor):
157 '''A class to convert the default plottable data in an `NXobject` into a
158 `numpy.ndarray`.
159 '''
160
161 def _process(self, data):
162 '''Return the default plottable data signal in `data` as a
163 `numpy.ndarray`.
164
165 :param data: input NeXus structure
166 :type data: nexusformat.nexus.tree.NXobject
167 :raises ValueError: if `data` has no default plottable data signal
168 :return: default plottable data signal in `data`
169 :rtype: numpy.ndarray
170 '''
171
172 default_data = data.plottable_data
173
174 if default_data is None:
175 default_data_path = data.attrs['default']
176 default_data = data.get(default_data_path)
177 if default_data is None:
178 raise(ValueError(f'The structure of {data} contains no default data'))
179
180 default_signal = default_data.attrs.get('signal')
181 if default_signal is None:
182 raise(ValueError(f'The signal of {default_data} is unknown'))
183 default_signal = default_signal.nxdata
184
185 np_data = default_data[default_signal].nxdata
186
187 return(np_data)
188
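A minimal sketch of the NeXus structure this processor consumes (assuming `nexusformat` and `numpy` are installed; all names are illustrative):

    import numpy as np
    from nexusformat.nexus import NXdata, NXentry, NXfield

    x = NXfield(np.linspace(0.0, 1.0, 5), name='x')
    counts = NXfield(np.arange(5), name='counts')
    nxentry = NXentry()
    nxentry['data'] = NXdata(signal=counts, axes=(x,))  # sets 'signal' & 'axes' attrs
    nxentry.attrs['default'] = 'data'

    arr = NexusToNumpyProcessor().process(nxentry)  # -> array([0, 1, 2, 3, 4])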
189 class NexusToXarrayProcessor(Processor):
190 '''A class to convert the default plottable data in an `NXobject` into an
191 `xarray.DataArray`.'''
192
193 def _process(self, data):
194 '''Return the default plottable data signal in `data` as an
195 `xarray.DataArray`.
196
197 :param data: input NeXus structure
198 :type data: nexusformat.nexus.tree.NXobject
199 :raises ValueError: if metadata for `xarray` is absent from `data`
200 :return: default plottable data signal in `data`
201 :rtype: xarray.DataArray
202 '''
203
204 from xarray import DataArray
205
206 default_data = data.plottable_data
207
208 if default_data is None:
209 default_data_path = data.attrs['default']
210 default_data = data.get(default_data_path)
211 if default_data is None:
212 raise(ValueError(f'The structure of {data} contains no default data'))
213
214 default_signal = default_data.attrs.get('signal')
215 if default_signal is None:
216 raise(ValueError(f'The signal of {default_data} is unknown'))
217 default_signal = default_signal.nxdata
218
219 signal_data = default_data[default_signal].nxdata
220
221 axes = default_data.attrs['axes']
222 coords = {}
223 for axis_name in axes:
224 axis = default_data[axis_name]
225 coords[axis_name] = (axis_name,
226 axis.nxdata,
227 axis.attrs)
228
229 dims = tuple(axes)
230
231 name = default_signal
232
233 attrs = default_data[default_signal].attrs
234
235 return(DataArray(data=signal_data,
236 coords=coords,
237 dims=dims,
238 name=name,
239 attrs=attrs))
240
241 class XarrayToNexusProcessor(Processor):
242 '''A class to convert the data in an `xarray` structure to a
243 `nexusformat.nexus.NXdata`.
244 '''
245
246 def _process(self, data):
247 '''Return `data` represented as a `nexusformat.nexus.NXdata`.
248
249 :param data: The input `xarray` structure
250 :type data: typing.Union[xarray.DataArray, xarray.Dataset]
251 :return: The data and metadata in `data`
252 :rtype: nexusformat.nexus.NXdata
253 '''
254
255 from nexusformat.nexus import NXdata, NXfield
256
257 signal = NXfield(value=data.data, name=data.name, attrs=data.attrs)
258
259 axes = []
260 for name, coord in data.coords.items():
261 axes.append(NXfield(value=coord.data, name=name, attrs=coord.attrs))
262 axes = tuple(axes)
263
264 return(NXdata(signal=signal, axes=axes))
265
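A round-trip sketch with the NexusToXarrayProcessor above (assuming `xarray` and `nexusformat` are installed; the two-dimensional data and its attributes are illustrative):

    import numpy as np
    from xarray import DataArray

    da = DataArray(data=np.arange(6.0).reshape(2, 3),
                   coords={'y': ('y', np.arange(2.0), {'units': 'mm'}),
                           'x': ('x', np.arange(3.0), {'units': 'mm'})},
                   dims=('y', 'x'),
                   name='I',
                   attrs={'units': 'a.u.'})

    nxdata = XarrayToNexusProcessor().process(da)        # -> NXdata
    da_back = NexusToXarrayProcessor().process(nxdata)   # -> DataArray again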
266 class XarrayToNumpyProcessor(Processor):
267 '''A class to convert the data in an `xarray.DataArray` structure to a
268 `numpy.ndarray`.
269 '''
270
271 def _process(self, data):
272 '''Return just the signal values contained in `data`.
273
274 :param data: The input `xarray.DataArray`
275 :type data: xarray.DataArray
276 :return: The data in `data`
277 :rtype: numpy.ndarray
278 '''
279
280 return(data.data)
281
282 class MapProcessor(Processor):
283 '''Class representing a process that takes a map configuration and returns a
284 `nexusformat.nexus.NXentry` representing that map's metadata and any
285 scalar-valued raw data requested by the supplied map configuration.
286 '''
287
288 def _process(self, data):
289 '''Process the output of a `Reader` that contains a map configuration and
290 return a `nexusformat.nexus.NXentry` representing the map.
291
292 :param data: Result of `Reader.read` where at least one item has the
293 value `'MapConfig'` for the `'schema'` key.
294 :type data: list[dict[str,object]]
295 :return: Map data & metadata (SPEC only, no detector)
296 :rtype: nexusformat.nexus.NXentry
297 '''
298
299 map_config = self.get_map_config(data)
300 nxentry = self.__class__.get_nxentry(map_config)
301
302 return(nxentry)
303
304 def get_map_config(self, data):
305 '''Get an instance of `MapConfig` from a returned value of `Reader.read`
306
307 :param data: Result of `Reader.read` where at least one item has the
308 value `'MapConfig'` for the `'schema'` key.
309 :type data: list[dict[str,object]]
310 :raises Exception: If a valid `MapConfig` cannot be constructed from `data`.
311 :return: a valid instance of `MapConfig` with field values taken from `data`.
312 :rtype: MapConfig
313 '''
314
315 from CHAP.models.map import MapConfig
316
317 map_config = False
318 if isinstance(data, list):
319 for item in data:
320 if isinstance(item, dict):
321 if item.get('schema') == 'MapConfig':
322 map_config = item.get('data')
323 break
324
325 if not map_config:
326 raise(ValueError('No map configuration found'))
327
328 return(MapConfig(**map_config))
329
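The schema convention used throughout this module makes the expected input look like the sketch below; the contents of the 'data' payload must satisfy the `MapConfig` model in `CHAP.models.map`, whose fields are not shown in this file:

    reader_output = [
        # map_config_dict: a placeholder for a dict of MapConfig fields
        {'schema': 'MapConfig', 'data': map_config_dict},
    ]
    map_config = MapProcessor().get_map_config(reader_output)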
330 @staticmethod
331 def get_nxentry(map_config):
332 '''Use a `MapConfig` to construct a `nexusformat.nexus.NXentry`
333
334 :param map_config: a valid map configuration
335 :type map_config: MapConfig
336 :return: the map's data and metadata contained in a NeXus structure
337 :rtype: nexusformat.nexus.NXentry
338 '''
339
340 from nexusformat.nexus import (NXcollection,
341 NXdata,
342 NXentry,
343 NXfield,
344 NXsample)
345 import numpy as np
346
347 nxentry = NXentry(name=map_config.title)
348
349 nxentry.map_config = json.dumps(map_config.dict())
350
351 nxentry[map_config.sample.name] = NXsample(**map_config.sample.dict())
352
353 nxentry.attrs['station'] = map_config.station
354
355 nxentry.spec_scans = NXcollection()
356 for scans in map_config.spec_scans:
357 nxentry.spec_scans[scans.scanparsers[0].scan_name] = \
358 NXfield(value=scans.scan_numbers,
359 dtype='int8',
360 attrs={'spec_file':str(scans.spec_file)})
361
362 nxentry.data = NXdata()
363 nxentry.data.attrs['axes'] = map_config.dims
364 for i,dim in enumerate(map_config.independent_dimensions[::-1]):
365 nxentry.data[dim.label] = NXfield(value=map_config.coords[dim.label],
366 units=dim.units,
367 attrs={'long_name': f'{dim.label} ({dim.units})',
368 'data_type': dim.data_type,
369 'local_name': dim.name})
370 nxentry.data.attrs[f'{dim.label}_indices'] = i
371
372 signal = False
373 auxiliary_signals = []
374 for data in map_config.all_scalar_data:
375 nxentry.data[data.label] = NXfield(value=np.empty(map_config.shape),
376 units=data.units,
377 attrs={'long_name': f'{data.label} ({data.units})',
378 'data_type': data.data_type,
379 'local_name': data.name})
380 if not signal:
381 signal = data.label
382 else:
383 auxiliary_signals.append(data.label)
384
385 if signal:
386 nxentry.data.attrs['signal'] = signal
387 nxentry.data.attrs['auxiliary_signals'] = auxiliary_signals
388
389 for scans in map_config.spec_scans:
390 for scan_number in scans.scan_numbers:
391 scanparser = scans.get_scanparser(scan_number)
392 for scan_step_index in range(scanparser.spec_scan_npts):
393 map_index = scans.get_index(scan_number, scan_step_index, map_config)
394 for data in map_config.all_scalar_data:
395 nxentry.data[data.label][map_index] = data.get_value(scans, scan_number, scan_step_index)
396
397 return(nxentry)
398
399 class IntegrationProcessor(Processor):
400 '''Class for integrating 2D detector data
401 '''
402
403 def _process(self, data):
404 '''Integrate the input data with the integration method and keyword
405 arguments supplied and return the results.
406
407 :param data: tuple of the raw detector data, the integration method, and
408 keyword args for the integration method. The integration method must
409 be a method of a
410 `pyFAI.azimuthalIntegrator.AzimuthalIntegrator` or
411 `pyFAI.multi_geometry.MultiGeometry` that returns the desired
412 integration results.
413 :type data: tuple[typing.Union[numpy.ndarray, list[numpy.ndarray]],
414 callable,
415 dict]
416 :return: integrated raw data
417 :rtype: pyFAI.containers.IntegrateResult
418 '''
419
420 detector_data, integration_method, integration_kwargs = data
421
422 return(integration_method(detector_data, **integration_kwargs))
423
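A minimal sketch (assumes `pyFAI` is installed; the calibration file name and detector frame are stand-ins):

    import numpy as np
    import pyFAI

    ai = pyFAI.load('calib.poni')          # an AzimuthalIntegrator
    frame = np.zeros(ai.detector.shape)    # hypothetical raw detector frame
    result = IntegrationProcessor().process(
        (frame, ai.integrate1d, {'npt': 1000}))
    # result.radial and result.intensity hold the 1D pattern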
424 class IntegrateMapProcessor(Processor):
425 '''Class representing a process that takes a map and integration
426 configuration and returns a `nexusformat.nexus.NXprocess` containing a map of
427 the integrated detector data requested.
428 '''
429
430 def _process(self, data):
431 '''Process the output of a `Reader` that contains a map and integration
432 configuration and return a `nexusformat.nexus.NXprocess` containing a map
433 of the integrated detector data requested
434
435 :param data: Result of `Reader.read` where at least one item has the
436 value `'MapConfig'` for the `'schema'` key, and at least one item has
437 the value `'IntegrationConfig'` for the `'schema'` key.
438 :type data: list[dict[str,object]]
439 :return: integrated data and process metadata
440 :rtype: nexusformat.nexus.NXprocess
441 '''
442
443 map_config, integration_config = self.get_configs(data)
444 nxprocess = self.get_nxprocess(map_config, integration_config)
445
446 return(nxprocess)
447
448 def get_configs(self, data):
449 '''Return valid instances of `MapConfig` and `IntegrationConfig` from the
450 input supplied by `MultipleReader`.
451
452 :param data: Result of `Reader.read` where at least one item has the
453 value `'MapConfig'` for the `'schema'` key, and at least one item has
454 the value `'IntegrationConfig'` for the `'schema'` key.
455 :type data: list[dict[str,object]]
456 :raises ValueError: if `data` cannot be parsed into map and integration configurations.
457 :return: valid map and integration configuration objects.
458 :rtype: tuple[MapConfig, IntegrationConfig]
459 '''
460
461 self.logger.debug('Getting configuration objects')
462 t0 = time()
463
464 from CHAP.models.map import MapConfig
465 from CHAP.models.integration import IntegrationConfig
466
467 map_config = False
468 integration_config = False
469 if isinstance(data, list):
470 for item in data:
471 if isinstance(item, dict):
472 schema = item.get('schema')
473 if schema == 'MapConfig':
474 map_config = item.get('data')
475 elif schema == 'IntegrationConfig':
476 integration_config = item.get('data')
477
478 if not map_config:
479 raise(ValueError('No map configuration found'))
480 if not integration_config:
481 raise(ValueError('No integration configuration found'))
482
483 map_config = MapConfig(**map_config)
484 integration_config = IntegrationConfig(**integration_config)
485
486 self.logger.debug(f'Got configuration objects in {time()-t0:.3f} seconds')
487
488 return(map_config, integration_config)
489
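The expected Reader output is a list like the following (order does not matter; the two 'data' payloads must satisfy the corresponding models in `CHAP.models`, and the dict names here are placeholders):

    reader_output = [
        {'schema': 'MapConfig', 'data': map_config_dict},
        {'schema': 'IntegrationConfig', 'data': integration_config_dict},
    ]
    map_config, integration_config = \
        IntegrateMapProcessor().get_configs(reader_output)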
490 def get_nxprocess(self, map_config, integration_config):
491 '''Use a `MapConfig` and `IntegrationConfig` to construct a
492 `nexusformat.nexus.NXprocess`
493
494 :param map_config: a valid map configuration
495 :type map_config: MapConfig
496 :param integration_config: a valid integration configuration
497 :type integration_config: IntegrationConfig
498 :return: the integrated detector data and metadata contained in a NeXus
499 structure
500 :rtype: nexusformat.nexus.NXprocess
501 '''
502
503 self.logger.debug('Constructing NXprocess')
504 t0 = time()
505
506 from nexusformat.nexus import (NXdata,
507 NXdetector,
508 NXfield,
509 NXprocess)
510 import numpy as np
511 import pyFAI
512
513 nxprocess = NXprocess(name=integration_config.title)
514
515 nxprocess.map_config = json.dumps(map_config.dict())
516 nxprocess.integration_config = json.dumps(integration_config.dict())
517
518 nxprocess.program = 'pyFAI'
519 nxprocess.version = pyFAI.version
520
521 for k,v in integration_config.dict().items():
522 if k == 'detectors':
523 continue
524 nxprocess.attrs[k] = v
525
526 for detector in integration_config.detectors:
527 nxprocess[detector.prefix] = NXdetector()
528 nxprocess[detector.prefix].local_name = detector.prefix
529 nxprocess[detector.prefix].distance = detector.azimuthal_integrator.dist
530 nxprocess[detector.prefix].distance.attrs['units'] = 'm'
531 nxprocess[detector.prefix].calibration_wavelength = detector.azimuthal_integrator.wavelength
532 nxprocess[detector.prefix].calibration_wavelength.attrs['units'] = 'm'
533 nxprocess[detector.prefix].attrs['poni_file'] = str(detector.poni_file)
534 nxprocess[detector.prefix].attrs['mask_file'] = str(detector.mask_file)
535 nxprocess[detector.prefix].raw_data_files = np.full(map_config.shape, '', dtype='|S256')
536
537 nxprocess.data = NXdata()
538
539 nxprocess.data.attrs['axes'] = (*map_config.dims, *integration_config.integrated_data_dims)
540 for i,dim in enumerate(map_config.independent_dimensions[::-1]):
541 nxprocess.data[dim.label] = NXfield(value=map_config.coords[dim.label],
542 units=dim.units,
543 attrs={'long_name': f'{dim.label} ({dim.units})',
544 'data_type': dim.data_type,
545 'local_name': dim.name})
546 nxprocess.data.attrs[f'{dim.label}_indices'] = i
547
548 for i,(coord_name,coord_values) in enumerate(integration_config.integrated_data_coordinates.items()):
549 if coord_name == 'radial':
550 type_ = pyFAI.units.RADIAL_UNITS
551 elif coord_name == 'azimuthal':
552 type_ = pyFAI.units.AZIMUTHAL_UNITS
553 coord_units = pyFAI.units.to_unit(getattr(integration_config, f'{coord_name}_units'), type_=type_)
554 nxprocess.data[coord_units.name] = coord_values
555 nxprocess.data.attrs[f'{coord_units.name}_indices'] = i+len(map_config.coords)
556 nxprocess.data[coord_units.name].units = coord_units.unit_symbol
557 nxprocess.data[coord_units.name].attrs['long_name'] = coord_units.label
558
559 nxprocess.data.attrs['signal'] = 'I'
560 nxprocess.data.I = NXfield(value=np.empty((*[len(v) for v in map_config.coords.values()][::-1], *integration_config.integrated_data_shape)),
561 units='a.u.',
562 attrs={'long_name': 'Intensity (a.u.)'})
563
564 integrator = integration_config.get_multi_geometry_integrator()
565 if integration_config.integration_type == 'azimuthal':
566 integration_method = integrator.integrate1d
567 integration_kwargs = {
568 'lst_mask': [detector.mask_array for detector in integration_config.detectors],
569 'npt': integration_config.radial_npt
570 }
571 elif integration_config.integration_type == 'cake':
572 integration_method = integrator.integrate2d
573 integration_kwargs = {
574 'lst_mask': [detector.mask_array for detector in integration_config.detectors],
575 'npt_rad': integration_config.radial_npt,
576 'npt_azim': integration_config.azimuthal_npt,
577 'method': 'bbox'
578 }
579
580 integration_processor = IntegrationProcessor()
581 integration_processor.logger.setLevel(self.logger.getEffectiveLevel())
582 integration_processor.logger.addHandler(self.logger.handlers[0])
583 lst_args = []
584 for scans in map_config.spec_scans:
585 for scan_number in scans.scan_numbers:
586 scanparser = scans.get_scanparser(scan_number)
587 for scan_step_index in range(scanparser.spec_scan_npts):
588 map_index = scans.get_index(scan_number, scan_step_index, map_config)
589 detector_data = scans.get_detector_data(integration_config.detectors, scan_number, scan_step_index)
590 result = integration_processor.process((detector_data, integration_method, integration_kwargs))
591 nxprocess.data.I[map_index] = result.intensity
592 for detector in integration_config.detectors:
593 nxprocess[detector.prefix].raw_data_files[map_index] = scanparser.get_detector_data_file(detector.prefix, scan_step_index)
594
595 self.logger.debug(f'Constructed NXprocess in {time()-t0:.3f} seconds')
596
597 return(nxprocess)
598
599 class MCACeriaCalibrationProcessor(Processor):
600 '''Class representing the procedure to use a CeO2 scan to obtain tuned values
601 for the Bragg diffraction angle and linear correction parameters for MCA
602 channel energies for an EDD experimental setup.
603 '''
604
605 def _process(self, data):
606 '''Return tuned values for 2&theta and linear correction parameters for
607 the MCA channel energies.
608
609 :param data: input configuration for the raw data & tuning procedure
610 :type data: list[dict[str,object]]
611 :return: original configuration dictionary with tuned values added
612 :rtype: dict[str,float]
613 '''
614
615 calibration_config = self.get_config(data)
616
617 tth, slope, intercept = self.calibrate(calibration_config)
618
619 calibration_config.tth_calibrated = tth
620 calibration_config.slope_calibrated = slope
621 calibration_config.intercept_calibrated = intercept
622
623 return(calibration_config.dict())
624
625 def get_config(self, data):
626 '''Get an instance of the configuration object needed by this
627 `Processor` from a returned value of `Reader.read`
628
629 :param data: Result of `Reader.read` where at least one item has the
630 value `'MCACeriaCalibrationConfig'` for the `'schema'` key.
631 :type data: list[dict[str,object]]
632 :raises Exception: If a valid config object cannot be constructed from `data`.
633 :return: a valid instance of a configuration object with field values
634 taken from `data`.
635 :rtype: MCACeriaCalibrationConfig
636 '''
637
638 from CHAP.models.edd import MCACeriaCalibrationConfig
639
640 calibration_config = False
641 if isinstance(data, list):
642 for item in data:
643 if isinstance(item, dict):
644 if item.get('schema') == 'MCACeriaCalibrationConfig':
645 calibration_config = item.get('data')
646 break
647
648 if not calibration_config:
649 raise(ValueError('No MCA ceria calibration configuration found in input data'))
650
651 return(MCACeriaCalibrationConfig(**calibration_config))
652
653 def calibrate(self, calibration_config):
654 '''Iteratively calibrate 2&theta by fitting selected peaks of an MCA
655 spectrum until the computed strain is sufficiently small. Use the fitted
656 peak locations to determine linear correction parameters for the MCA's
657 channel energies.
658
659 :param calibration_config: object configuring the CeO2 calibration procedure
660 :type calibration_config: MCACeriaCalibrationConfig
661 :return: calibrated values of 2&theta and linear correction parameters
662 for MCA channel energies : tth, slope, intercept
663 :rtype: float, float, float
664 '''
665
666 from msnctools.fit import Fit, FitMultipeak
667 import numpy as np
668 from scipy.constants import physical_constants
669
670 hc = physical_constants['Planck constant in eV/Hz'][0] * \
671 physical_constants['speed of light in vacuum'][0] * \
672 1e7 # We'll work in keV and A, not eV and m.
673
674 # Collect raw MCA data of interest
675 mca_data = calibration_config.mca_data()
676 mca_bin_energies = np.arange(0, calibration_config.num_bins) * \
677 (calibration_config.max_energy_kev / calibration_config.num_bins)
678
679 # Mask out the corrected MCA data for fitting
680 mca_mask = calibration_config.mca_mask()
681 fit_mca_energies = mca_bin_energies[mca_mask]
682 fit_mca_intensities = mca_data[mca_mask]
683
684 # Correct raw MCA data for variable flux at different energies
685 flux_correct = calibration_config.flux_correction_interpolation_function()
686 mca_intensity_weights = flux_correct(fit_mca_energies)
687 fit_mca_intensities = fit_mca_intensities / mca_intensity_weights
688
689 # Get the HKLs and lattice spacings that will be used for fitting
690 tth = calibration_config.tth_initial_guess
691 fit_hkls, fit_ds = calibration_config.fit_ds()
692 c_1 = fit_hkls[:,0]**2 + fit_hkls[:,1]**2 + fit_hkls[:,2]**2
693
694 for iter_i in range(calibration_config.max_iter):
695
696 ### Perform the uniform fit first ###
697
698 # Get expected peak energy locations for this iteration's starting
699 # value of tth
700 fit_lambda = 2.0 * fit_ds * np.sin(0.5*np.radians(tth))
701 fit_E0 = hc / fit_lambda
702
703 # Run the uniform fit
704 best_fit, residual, best_values, best_errors, redchi, success = \
705 FitMultipeak.fit_multipeak(fit_mca_intensities,
706 fit_E0,
707 x=fit_mca_energies,
708 fit_type='uniform')
709
710 # Extract values of interest from the best values for the uniform fit
711 # parameters
712 uniform_fit_centers = [best_values[f'peak{i+1}_center'] for i in range(len(calibration_config.fit_hkls))]
713 # uniform_a = best_values['scale_factor']
714 # uniform_strain = np.log(uniform_a / calibration_config.lattice_parameter_angstrom)
715 # uniform_tth = tth * (1.0 + uniform_strain)
716 # uniform_rel_rms_error = np.linalg.norm(residual) / np.linalg.norm(fit_mca_intensities)
717
718 ### Next, perform the unconstrained fit ###
719
720 # Use the peak locations found in the uniform fit as the initial
721 # guesses for peak locations in the unconstrained fit
722 best_fit, residual, best_values, best_errors, redchi, success = \
723 FitMultipeak.fit_multipeak(fit_mca_intensities,
724 uniform_fit_centers,
725 x=fit_mca_energies,
726 fit_type='unconstrained')
727
728 # Extract values of interest from the best values for the
729 # unconstrained fit parameters
730 unconstrained_fit_centers = np.array([best_values[f'peak{i+1}_center'] for i in range(len(calibration_config.fit_hkls))])
731 unconstrained_a = 0.5 * hc * np.sqrt(c_1) / (unconstrained_fit_centers * abs(np.sin(0.5*np.radians(tth))))
732 unconstrained_strains = np.log(unconstrained_a / calibration_config.lattice_parameter_angstrom)
733 unconstrained_strain = np.mean(unconstrained_strains)
734 unconstrained_tth = tth * (1.0 + unconstrained_strain)
735 # unconstrained_rel_rms_error = np.linalg.norm(residual) / np.linalg.norm(fit_mca_intensities)
736
737
738 # Update tth for the next iteration of tuning
739 prev_tth = tth
740 tth = unconstrained_tth
741
742 # Stop tuning tth at this iteration if differences are small enough
743 if abs(tth - prev_tth) < calibration_config.tune_tth_tol:
744 break
745
746 # Fit line to expected / computed peak locations from the last
747 # unconstrained fit.
748 fit = Fit.fit_data(fit_E0,'linear', x=unconstrained_fit_centers, nan_policy='omit')
749 slope = fit.best_values['slope']
750 intercept = fit.best_values['intercept']
751
752 return(float(tth), float(slope), float(intercept))
753
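The peak-energy relation used above is Bragg's law in energy-dispersive form, E = hc / (2 d sin(tth/2)). A standalone numeric sketch (the angle is hypothetical; the d-spacing is the CeO2 (111) value):

    import numpy as np

    hc = 12.398   # keV*Angstrom (approximate)
    tth = 10.0    # degrees, hypothetical take-off angle
    d = 3.124     # Angstrom, CeO2 (111) lattice spacing
    E = hc / (2.0 * d * np.sin(np.radians(0.5 * tth)))
    # E comes out near 22.8 keV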
754 class MCADataProcessor(Processor):
755 '''Class representing a process to return data from an MCA, restructured to
756 incorporate the shape & metadata associated with a map configuration to
757 which the MCA data belongs, and linearly transformed according to the
758 results of a ceria calibration.
759 '''
760
761 def _process(self, data):
762 '''Process configurations for a map and MCA detector(s), and return the
763 raw MCA data collected over the map.
764
765 :param data: input map configuration and results of ceria calibration
766 :type data: list[dict[str,object]]
767 :return: calibrated and flux-corrected MCA data
768 :rtype: nexusformat.nexus.NXentry
769 '''
770
771 map_config, calibration_config = self.get_configs(data)
772 nxroot = self.get_nxroot(map_config, calibration_config)
773
774 return(nxroot)
775
776 def get_configs(self, data):
777 '''Get instances of the configuration objects needed by this
778 `Processor` from a returned value of `Reader.read`
779
780 :param data: Result of `Reader.read` where at least one item has the
781 value `'MapConfig'` for the `'schema'` key, and at least one item has
782 the value `'MCACeriaCalibrationConfig'` for the `'schema'` key.
783 :type data: list[dict[str,object]]
784 :raises Exception: If valid config objects cannot be constructed from `data`.
785 :return: valid instances of the configuration objects with field values
786 taken from `data`.
787 :rtype: tuple[MapConfig, MCACeriaCalibrationConfig]
788 '''
789
790 from CHAP.models.map import MapConfig
791 from CHAP.models.edd import MCACeriaCalibrationConfig
792
793 map_config = False
794 calibration_config = False
795 if isinstance(data, list):
796 for item in data:
797 if isinstance(item, dict):
798 schema = item.get('schema')
799 if schema == 'MapConfig':
800 map_config = item.get('data')
801 elif schema == 'MCACeriaCalibrationConfig':
802 calibration_config = item.get('data')
803
804 if not map_config:
805 raise(ValueError('No map configuration found in input data'))
806 if not calibration_config:
807 raise(ValueError('No MCA ceria calibration configuration found in input data'))
808
809 return(MapConfig(**map_config), MCACeriaCalibrationConfig(**calibration_config))
810
811 def get_nxroot(self, map_config, calibration_config):
812 '''Get a map of the MCA data collected by the scans in `map_config`. The
813 MCA data will be calibrated and flux-corrected according to the
814 parameters included in `calibration_config`. The data will be returned
815 along with relevant metadata in the form of a NeXus structure.
816
817 :param map_config: the map configuration
818 :type map_config: MapConfig
819 :param calibration_config: the calibration configuration
820 :type calibration_config: MCACeriaCalibrationConfig
821 :return: a map of the calibrated and flux-corrected MCA data
822 :rtype: nexusformat.nexus.NXroot
823 '''
824
825 from nexusformat.nexus import (NXdata,
826 NXdetector,
827 NXentry,
828 NXinstrument,
829 NXroot)
830 import numpy as np
831
832 nxroot = NXroot()
833
834 nxroot[map_config.title] = MapProcessor.get_nxentry(map_config)
835 nxentry = nxroot[map_config.title]
836
837 nxentry.instrument = NXinstrument()
838 nxentry.instrument.detector = NXdetector()
839 nxentry.instrument.detector.calibration_configuration = json.dumps(calibration_config.dict())
840
841 nxentry.instrument.detector.data = NXdata()
842 nxdata = nxentry.instrument.detector.data
843 nxdata.raw = np.empty((*map_config.shape, calibration_config.num_bins))
844 nxdata.raw.attrs['units'] = 'counts'
845 nxdata.channel_energy = calibration_config.slope_calibrated * \
846 np.arange(0, calibration_config.num_bins) * \
847 (calibration_config.max_energy_kev / calibration_config.num_bins) + \
848 calibration_config.intercept_calibrated
849 nxdata.channel_energy.attrs['units'] = 'keV'
850
851 for scans in map_config.spec_scans:
852 for scan_number in scans.scan_numbers:
853 scanparser = scans.get_scanparser(scan_number)
854 for scan_step_index in range(scanparser.spec_scan_npts):
855 map_index = scans.get_index(scan_number, scan_step_index, map_config)
856 nxdata.raw[map_index] = scanparser.get_detector_data(calibration_config.detector_name, scan_step_index)
857
858 nxentry.data.makelink(nxdata.raw, name=calibration_config.detector_name)
859 nxentry.data.makelink(nxdata.channel_energy, name=f'{calibration_config.detector_name}_channel_energy')
860 if isinstance(nxentry.data.attrs['axes'], str):
861 nxentry.data.attrs['axes'] = [nxentry.data.attrs['axes'], f'{calibration_config.detector_name}_channel_energy']
862 else:
863 nxentry.data.attrs['axes'] += [f'{calibration_config.detector_name}_channel_energy']
864 nxentry.data.attrs['signal'] = calibration_config.detector_name
865
866 return(nxroot)
867
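The channel_energy array built in get_nxroot is a linear correction of the nominal bin energies; the same computation in isolation (all values below are hypothetical):

    import numpy as np

    num_bins, max_energy_kev = 2048, 200.0   # MCA settings
    slope, intercept = 1.002, -0.03          # ceria calibration results
    nominal = np.arange(num_bins) * (max_energy_kev / num_bins)
    channel_energy = slope * nominal + intercept   # keV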
868 class StrainAnalysisProcessor(Processor):
869 '''Class representing a process to compute a map of sample strains by fitting
870 Bragg peaks in 1D detector data and analyzing the difference between measured
871 peak locations and expected peak locations for the sample measured.
872 '''
873
874 def _process(self, data):
875 '''Process the input map detector data & configuration for the strain
876 analysis procedure, and return a map of sample strains.
877
878 :param data: results of `MultipleReader.read` containing input map
879 detector data and strain analysis configuration
880 :type data: list[dict[str,object]]
881 :return: map of sample strains
882 :rtype: xarray.Dataset
883 '''
884
885 strain_analysis_config = self.get_config(data)
886
887 return(data)
888
889 def get_config(self, data):
890 '''Get an instance of the configuration object needed by this
891 `Processor` from a returned value of `Reader.read`
892
893 :param data: Result of `Reader.read` where at least one item has the
894 value `'StrainAnalysisConfig'` for the `'schema'` key.
895 :type data: list[dict[str,object]]
896 :raises Exception: If a valid configuration cannot be extracted from `data`.
897 :return: the raw strain analysis configuration dictionary with field
898 values taken from `data` (no validated config model is constructed here)
899 :rtype: dict
900 '''
901
902 strain_analysis_config = False
903 if isinstance(data, list):
904 for item in data:
905 if isinstance(item, dict):
906 schema = item.get('schema')
907 if schema == 'StrainAnalysisConfig':
908 strain_analysis_config = item.get('data')
909
910 if not strain_analysis_config:
911 raise(ValueError('No strain analysis configuration found in input data'))
912
913 return(strain_analysis_config)
914
915
916 class OptionParser():
917 '''User-based option parser'''
918 def __init__(self):
919 self.parser = argparse.ArgumentParser(prog='PROG')
920 self.parser.add_argument("--data", action="store",
921 dest="data", default="", help="Input data")
922 self.parser.add_argument("--processor", action="store",
923 dest="processor", default="Processor", help="Processor class name")
924 self.parser.add_argument('--log-level', choices=logging._nameToLevel.keys(),
925 dest='log_level', default='INFO', help='logging level')
926
927 def main():
928 '''Main function'''
929 optmgr = OptionParser()
930 opts = optmgr.parser.parse_args()
931 clsName = opts.processor
932 try:
933 processorCls = getattr(sys.modules[__name__],clsName)
934 except AttributeError:
935 print(f'Unsupported processor {clsName}')
936 sys.exit(1)
937
938 processor = processorCls()
939 processor.logger.setLevel(getattr(logging, opts.log_level))
940 log_handler = logging.StreamHandler()
941 log_handler.setFormatter(logging.Formatter('{name:20}: {message}', style='{'))
942 processor.logger.addHandler(log_handler)
943 data = processor.process(opts.data)
944
945 print(f"Processor {processor} operates on data {data}")
946
947 if __name__ == '__main__':
948 main()
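An example invocation when the module is run directly (the data string, processor name, and log level are illustrative):

    python processor.py --data "some text" --processor PrintProcessor --log-level DEBUG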