comparison workflow/models.py @ 69:fba792d5f83b draft

planemo upload for repository https://github.com/rolfverberg/galaxytools commit ab9f412c362a4ab986d00e21d5185cfcf82485c1
author rv43
date Fri, 10 Mar 2023 16:02:04 +0000
parents
children 1cf15b61cd83
comparison
equal deleted inserted replaced
68:ba5866d0251d 69:fba792d5f83b
#!/usr/bin/env python3
"""Pydantic configuration models for a tomography workflow.

Defines the detector, SPEC-scan, flat-field, tomography-field, sample and map
configuration models, together with helpers to fill them in interactively and
to convert them to/from YAML files and NeXus objects.
"""

import logging
import os

import numpy as np
import yaml

from functools import cache
from pathlib import PosixPath
from pydantic import validator, ValidationError, conint, confloat, constr, \
    conlist, FilePath, PrivateAttr
from pydantic import BaseModel as PydanticBaseModel
from nexusformat.nexus import *
from time import time
from typing import Optional, Literal
from typing_extensions import TypedDict
try:
    from pyspec.file.spec import FileSpec
except ImportError:
    # pyspec is only needed to read SPEC files; without it the module still
    # imports, and code that uses FileSpec fails with a NameError.
    pass

from msnctools.general import is_int, is_num, input_int, input_int_list, input_num, \
    input_yesno, input_menu, index_nearest, string_to_list, file_exists_and_readable

logger = logging.getLogger(__name__)
28
29
def import_scanparser(station):
    """Bind the module-level name ``ScanParser`` to the parser class for *station*.

    Stations 'id1a3' and 'id3a' use ``SMBRotationScanParser``; station 'id3b'
    uses ``FMBRotationScanParser``. The class is stored in this module's
    globals, where ``get_available_scan_numbers`` and ``get_scanparser`` look
    it up.

    :raises RuntimeError: if *station* is not one of the supported stations.
    """
    if station in ('id1a3', 'id3a'):
        from msnctools.scanparsers import SMBRotationScanParser
        globals()['ScanParser'] = SMBRotationScanParser
    # Equality, not ``in ('id3b')``: that is a substring test on a plain string.
    elif station == 'id3b':
        from msnctools.scanparsers import FMBRotationScanParser
        globals()['ScanParser'] = FMBRotationScanParser
    else:
        raise RuntimeError(f'Invalid station: {station}')
39
@cache
def get_available_scan_numbers(spec_file:str):
    """Return the scan numbers in *spec_file* that ``ScanParser`` can open.

    Every scan listed by ``FileSpec(spec_file).scans`` is tried with the
    currently bound ``ScanParser`` (see ``import_scanparser``). Scans whose
    parser construction raises an exception are left out; the original order
    is kept. The result is cached per *spec_file*, so later calls return the
    same list object and callers should not mutate it.
    """
    def parser_accepts(scan_number):
        # Only ordinary exceptions mean "unusable scan"; Ctrl-C and SystemExit
        # must propagate instead of silently shrinking the list.
        try:
            ScanParser(spec_file, scan_number)
        except Exception:
            return(False)
        return(True)

    return([scan_number for scan_number in list(FileSpec(spec_file).scans.keys())
            if parser_accepts(scan_number)])
55
@cache
def get_scanparser(spec_file:str, scan_number:int):
    """Return a ``ScanParser`` for *scan_number* in *spec_file*, or None.

    None is returned when the scan is not in
    ``get_available_scan_numbers(spec_file)``. Results are cached per
    ``(spec_file, scan_number)``, so repeated calls share one parser object.
    """
    if scan_number in get_available_scan_numbers(spec_file):
        return(ScanParser(spec_file, scan_number))
    return(None)
62
63
class BaseModel(PydanticBaseModel):
    """Pydantic base model with interactive (CLI), YAML and file I/O helpers.

    Subclasses provide ``cli()`` (used by ``construct_from_cli`` and by the
    ``set_*_attr_cli`` helpers for nested models) and, for NeXus files,
    ``construct_from_nexus`` / ``write_to_nexus``. None of these is defined
    here.
    """

    class Config:
        # Re-validate on every attribute assignment; the interactive setters
        # below rely on this to reject bad input.
        validate_assignment = True
        arbitrary_types_allowed = True

    @classmethod
    def construct_from_cli(cls):
        """Create an unvalidated instance and let the user fill it in via ``cli()``."""
        obj = cls.construct()
        obj.cli()
        return(obj)

    @classmethod
    def construct_from_yaml(cls, filename):
        """Build a validated model from the mapping stored in YAML file *filename*.

        :raises ValueError: if the file cannot be read or parsed, or if its
            top-level value is not a mapping.
        """
        # NOTE(review): CLoader/Loader can instantiate arbitrary Python objects
        # from tagged YAML; only use this on trusted configuration files.
        # Fall back to the pure-Python loader when PyYAML was built without LibYAML.
        loader = getattr(yaml, 'CLoader', yaml.Loader)
        try:
            with open(filename, 'r') as infile:
                indict = yaml.load(infile, Loader=loader)
        except Exception as e:
            raise ValueError(f'Could not load a dictionary from {filename}') from e
        if not isinstance(indict, dict):
            raise ValueError(f'Could not load a dictionary from {filename}')
        return(cls(**indict))

    @classmethod
    def construct_from_file(cls, filename):
        """Build a model from a YAML (.yaml/.yml) or NeXus (.nxs/.nx5/.h5/.hdf5) file.

        NeXus input is delegated to ``cls.construct_from_nexus``.

        :raises TypeError: for any other file extension.
        """
        file_exists_and_readable(filename)
        filename = os.path.abspath(filename)
        fileformat = os.path.splitext(filename)[1]
        extension = fileformat.lower()
        t0 = time()
        if extension in ('.yaml', '.yml'):
            obj = cls.construct_from_yaml(filename)
        elif extension in ('.nxs', '.nx5', '.h5', '.hdf5'):
            obj = cls.construct_from_nexus(filename)
        else:
            logger.error(f'Unsupported file extension for constructing a model: {fileformat}')
            raise TypeError(f'Unrecognized file extension: {fileformat}')
        logger.info(f'Constructed a model from {filename} in {time()-t0:.2f} seconds.')
        return(obj)

    def dict_for_yaml(self, exclude_fields=None):
        """Return a plain dict of this model's non-None fields, ready for ``yaml.dump``.

        Nested models, and lists made up entirely of models, are converted
        recursively. Filesystem paths become strings. Every other value is
        stored as-is.

        :param exclude_fields: field names to leave out (default: none).
        """
        if exclude_fields is None:
            exclude_fields = ()
        yaml_dict = {}
        for field_name in self.__fields__:
            if field_name in exclude_fields:
                continue
            field_value = getattr(self, field_name, None)
            if field_value is None:
                continue
            if isinstance(field_value, BaseModel):
                yaml_dict[field_name] = field_value.dict_for_yaml()
            elif (isinstance(field_value, list)
                    and all(isinstance(item, BaseModel) for item in field_value)):
                yaml_dict[field_name] = [item.dict_for_yaml() for item in field_value]
            elif isinstance(field_value, os.PathLike):
                yaml_dict[field_name] = os.fspath(field_value)
            else:
                yaml_dict[field_name] = field_value
        return(yaml_dict)

    def write_to_yaml(self, filename=None):
        """Write ``dict_for_yaml()`` to *filename* as YAML, or log it if no file is given.

        If writing fails, the error is logged and the YAML text is logged
        instead, so the configuration is not lost.
        """
        yaml_dict = self.dict_for_yaml()
        if filename is not None:
            try:
                with open(filename, 'w') as outfile:
                    yaml.dump(yaml_dict, outfile, sort_keys=False)
            except Exception:
                logger.error(f'Unknown error -- could not write to {filename} in yaml format.')
            else:
                logger.info(f'Successfully wrote this model to {filename}')
                return
        logger.info('Printing yaml representation here:\n'+
                f'{yaml.dump(yaml_dict, sort_keys=False)}')

    def write_to_file(self, filename, force_overwrite=False):
        """Write the model to *filename*, choosing the format from its extension.

        YAML goes through ``write_to_yaml``; when the file may not be written,
        the YAML text is logged instead. NeXus goes through ``write_to_nexus``.
        Nothing is written for unsupported extensions; ``output_file_valid``
        logs the reason.
        """
        file_writeable, fileformat = self.output_file_valid(filename,
                force_overwrite=force_overwrite)
        if fileformat == 'yaml':
            if file_writeable:
                self.write_to_yaml(filename=filename)
            else:
                self.write_to_yaml()
        elif fileformat == 'nexus':
            if file_writeable:
                self.write_to_nexus(filename=filename)

    def output_file_valid(self, filename, force_overwrite=False):
        """Check whether *filename* may be written, creating its directory if needed.

        :returns: ``(writeable, fileformat)``, where *fileformat* is 'yaml',
            'nexus', or None for an unsupported extension.
        """
        filename = os.path.abspath(filename)
        fileformat = os.path.splitext(filename)[1]
        extension = fileformat.lower()
        if extension in ('.yaml', '.yml'):
            fileformat = 'yaml'
        elif extension in ('.nxs', '.nx5', '.h5', '.hdf5'):
            fileformat = 'nexus'
        else:
            # Only yaml and NeXus files are allowed for output.
            logger.error(f'Unsupported file extension for writing a model: {fileformat}')
            return(False, None)
        if os.path.isfile(filename):
            if not os.access(filename, os.W_OK):
                logger.error(f'Cannot access {filename} for writing.')
                return(False, fileformat)
            if not force_overwrite:
                logger.error(f'{filename} will not be overwritten.')
                return(False, fileformat)
        dirname = os.path.dirname(filename)
        if os.path.isdir(dirname):
            if os.access(dirname, os.W_OK):
                return(True, fileformat)
            logger.error(f'Cannot access {dirname} for writing.')
            return(False, fileformat)
        try:
            os.makedirs(dirname)
        except OSError:
            logger.error(f'Cannot create {dirname} for output.')
            return(False, fileformat)
        return(True, fileformat)

    def set_single_attr_cli(self, attr_name, attr_desc='unknown attribute', list_flag=False,
            **cli_kwargs):
        """Interactively set the field *attr_name*.

        If the field's type is a ``BaseModel`` subclass, the existing (or a
        newly constructed) sub-model runs its own ``cli(**cli_kwargs)`` until
        the result passes validation. Otherwise the user is prompted for a
        value; pressing enter keeps the current one.

        :param attr_desc: human-readable description used in prompts.
        :param list_flag: split typed input with ``string_to_list``.
        :param cli_kwargs: forwarded to the sub-model's ``cli``. If it holds a
            truthy 'chain_attr_desc', 'attr_desc' is added to it as well.
        """
        if cli_kwargs.get('chain_attr_desc', False):
            cli_kwargs['attr_desc'] = attr_desc
        # Decide the path up front instead of inferring it from exceptions,
        # so Ctrl-C and unrelated errors are not mistaken for "plain value".
        field = self.__fields__.get(attr_name)
        field_type = getattr(field, 'type_', None)
        if isinstance(field_type, type) and issubclass(field_type, BaseModel):
            self._set_model_attr_cli(attr_name, attr_desc, field_type, cli_kwargs)
        else:
            self._set_value_attr_cli(attr_name, attr_desc, list_flag)

    def _set_model_attr_cli(self, attr_name, attr_desc, model_type, cli_kwargs):
        """Run a sub-model's ``cli`` until it and the assignment validate, then store it."""
        attr = getattr(self, attr_name, None)
        if attr is None:
            attr = model_type.construct()
        while True:
            try:
                attr.cli(**cli_kwargs)
            except ValidationError as e:
                print(e)
                print(f'Removing {attr_desc} configuration')
                attr = model_type.construct()
                continue
            except Exception as e:
                print(f'{type(e).__name__}: {e}')
                print(f'Removing {attr_desc} configuration')
                attr = model_type.construct()
                continue
            try:
                setattr(self, attr_name, attr)
            except ValidationError as e:
                print(e)
            except Exception as e:
                print(f'{type(e).__name__}: {e}')
            else:
                return

    def _set_value_attr_cli(self, attr_name, attr_desc, list_flag):
        """Prompt for a plain value until pydantic accepts it; enter keeps the current one."""
        while True:
            attr = getattr(self, attr_name, None)
            if attr is None:
                input_value = input(f'Type and enter a value for {attr_desc}: ')
            else:
                input_value = input(f'Type and enter a new value for {attr_desc} or press '+
                        f'enter to keep the current one ({attr}): ')
            if list_flag:
                input_value = string_to_list(input_value, remove_duplicates=False, sort=False)
            if len(input_value) == 0:
                input_value = attr
            try:
                setattr(self, attr_name, input_value)
            except ValidationError as e:
                print(e)
            except Exception as e:
                print(f'Unexpected {type(e).__name__}: {e}')
            else:
                return

    def set_list_attr_cli(self, attr_name, attr_desc='unknown attribute', **cli_kwargs):
        """Interactively review the items of list field *attr_name*, then append new ones.

        Each existing item re-runs its own ``cli(**cli_kwargs)`` until the list
        validates. The user is then offered to append items, each created with
        the field item type's ``construct_from_cli``. An appended item that
        makes validation fail is removed again.
        """
        if cli_kwargs.get('chain_attr_desc', False):
            cli_kwargs['attr_desc'] = attr_desc
        # Use the declared item type, so appending also works when the current
        # list is empty.
        item_type = self.__fields__[attr_name].type_
        attr = getattr(self, attr_name, None)
        if attr is None:
            attr = []
        else:
            # Check existing items
            for item in attr:
                while True:
                    item.cli(**cli_kwargs)
                    try:
                        setattr(self, attr_name, attr)
                    except ValidationError as e:
                        print(e)
                    except Exception as e:
                        print(f'{type(e).__name__}: {e}')
                    else:
                        break
        # Append (optional) additional items
        append = input_yesno(f'Add a {attr_desc} configuration? (y/n)', 'n')
        while append:
            attr.append(item_type.construct_from_cli())
            try:
                setattr(self, attr_name, attr)
            except ValidationError as e:
                print(e)
                print(f'Removing last {attr_desc} configuration from the list')
                attr.pop()
            except Exception as e:
                print(f'{type(e).__name__}: {e}')
                print(f'Removing last {attr_desc} configuration from the list')
                attr.pop()
            else:
                append = input_yesno(f'Add another {attr_desc} configuration? (y/n)', 'n')
289
290
class Detector(BaseModel):
    """Detector description: ID prefix, pixel grid, pixel size and lens magnification."""
    prefix: constr(strip_whitespace=True, min_length=1)
    rows: conint(gt=0)
    columns: conint(gt=0)
    # One value for square pixels, or a (row, column) pair; prompted for in mm.
    pixel_size: conlist(item_type=confloat(gt=0), min_items=1, max_items=2)
    lens_magnification: confloat(gt=0) = 1.0

    @property
    def get_pixel_size(self):
        """The pixel size(s) divided by the lens magnification, as a list."""
        return(list(np.asarray(self.pixel_size)/self.lens_magnification))

    def construct_from_yaml(self, filename):
        """Fill this detector in place from a detector YAML file.

        Reads ``detector: {id, pixels: {rows, columns, size}}`` and a top-level
        ``lens_magnification``. All values are looked up before any field is
        assigned, so a missing key leaves the detector unchanged. A value that
        fails validation can still leave earlier fields updated.

        NOTE: unlike ``BaseModel.construct_from_yaml`` (a classmethod that
        returns a new model), this is an instance method returning a flag.

        :returns: True on success; False, after logging a warning, otherwise.
        """
        # NOTE(review): CLoader/Loader can instantiate arbitrary Python objects;
        # only load trusted files. Falls back to the pure-Python loader.
        loader = getattr(yaml, 'CLoader', yaml.Loader)
        try:
            with open(filename, 'r') as infile:
                indict = yaml.load(infile, Loader=loader)
            detector = indict['detector']
            pixels = detector['pixels']
            new_values = {
                'prefix': detector['id'],
                'rows': pixels['rows'],
                'columns': pixels['columns'],
                'pixel_size': pixels['size'],
                'lens_magnification': indict['lens_magnification'],
            }
            for field_name, value in new_values.items():
                setattr(self, field_name, value)
        except Exception:
            logger.warning(f'Could not load a dictionary from {filename}')
            return(False)
        return(True)

    def cli(self):
        """Prompt for every detector field."""
        print('\n -- Configure the detector -- ')
        self.set_single_attr_cli('prefix', 'detector ID')
        self.set_single_attr_cli('rows', 'number of pixel rows')
        self.set_single_attr_cli('columns', 'number of pixel columns')
        self.set_single_attr_cli('pixel_size', 'pixel size in mm (enter either a single value for '+
                'square pixels or a pair of values for the size in the respective row and column '+
                'directions)', list_flag=True)
        self.set_single_attr_cli('lens_magnification', 'lens magnification')

    def construct_nxdetector(self):
        """Return an NXdetector with ``local_name`` and x/y pixel sizes (units 'mm')."""
        nxdetector = NXdetector()
        nxdetector.local_name = self.prefix
        pixel_size = self.get_pixel_size
        # A single value means square pixels: use it in both directions.
        nxdetector.x_pixel_size = pixel_size[0]
        nxdetector.y_pixel_size = pixel_size[0] if len(pixel_size) == 1 else pixel_size[1]
        nxdetector.x_pixel_size.attrs['units'] = 'mm'
        nxdetector.y_pixel_size.attrs['units'] = 'mm'
        return(nxdetector)
342
343
class ScanInfo(TypedDict):
    """Image-stack description of one scan, as stored in ``SpecScans.stack_info``."""
    # SPEC scan number this entry describes.
    scan_number: int
    # Index of the first image used from the scan.
    starting_image_offset: conint(ge=0)
    # Number of images used, counted from starting_image_offset.
    num_image: conint(gt=0)
    # Reference x and z sample translations recorded for the scan.
    ref_x: float
    ref_z: float
350
class SpecScans(BaseModel):
    """Scans selected from one SPEC file, with per-scan image-stack information."""
    spec_file: FilePath
    scan_numbers: conlist(item_type=conint(gt=0), min_items=1)
    # One ScanInfo per entry of scan_numbers (checked by validate_stack_info).
    stack_info: conlist(item_type=ScanInfo, min_items=1) = []

    @validator('spec_file')
    def validate_spec_file(cls, spec_file):
        """Return *spec_file* as an absolute path after checking that pyspec can open it."""
        spec_file = os.path.abspath(spec_file)
        try:
            FileSpec(spec_file)
        except Exception as e:
            raise ValueError(f'Invalid SPEC file {spec_file}') from e
        return(spec_file)

    @validator('scan_numbers')
    def validate_scan_numbers(cls, scan_numbers, values):
        """Check that every scan number exists in the (already validated) SPEC file."""
        spec_file = values.get('spec_file')
        if spec_file is not None:
            spec_scans = FileSpec(spec_file)
            for scan_number in scan_numbers:
                if spec_scans.get_scan_by_number(scan_number) is None:
                    raise ValueError(f'There is no scan number {scan_number} in {spec_file}')
        return(scan_numbers)

    @validator('stack_info')
    def validate_stack_info(cls, stack_info, values):
        """Check that stack_info has one consistent entry per selected scan number.

        Uses explicit ``raise`` rather than ``assert`` so the checks survive ``python -O``.
        """
        scan_numbers = values.get('scan_numbers')
        if scan_numbers is None:
            # scan_numbers is missing or failed its own validation, which is
            # reported separately; nothing to cross-check against.
            return(stack_info)
        if len(scan_numbers) != len(stack_info):
            raise ValueError(f'stack_info has {len(stack_info)} entries for '
                    f'{len(scan_numbers)} scan numbers')
        for scan_info in stack_info:
            if scan_info['scan_number'] not in scan_numbers:
                raise ValueError(f'Scan number {scan_info["scan_number"]} in stack_info is '
                        'not among the selected scan numbers')
            is_int(scan_info['starting_image_offset'], ge=0, lt=scan_info['num_image'],
                    raise_error=True)
        return(stack_info)

    @staticmethod
    def _config_from_nxcollection(nxcollection):
        """Return spec_file, sorted scan_numbers and stack_info read from an NXcollection.

        Each subentry name is expected to end in '_<scan number>', as written
        by ``construct_nxcollection``.
        """
        spec_file = nxcollection.attrs['spec_file']
        scan_numbers = []
        stack_info = []
        for nxsubentry_name, nxsubentry in nxcollection.items():
            scan_number = int(nxsubentry_name.split('_')[-1])
            scan_numbers.append(scan_number)
            stack_info.append({'scan_number': scan_number,
                    'starting_image_offset':
                        int(nxsubentry.instrument.detector.frame_start_number),
                    'num_image': len(nxsubentry.sample.rotation_angle),
                    'ref_x': float(nxsubentry.sample.x_translation),
                    'ref_z': float(nxsubentry.sample.z_translation)})
        return({'spec_file': spec_file, 'scan_numbers': sorted(scan_numbers),
                'stack_info': stack_info})

    @classmethod
    def construct_from_nxcollection(cls, nxcollection:NXcollection):
        """Build a validated instance from an NXcollection written by ``construct_nxcollection``."""
        return(cls(**cls._config_from_nxcollection(nxcollection)))

    @property
    def available_scan_numbers(self):
        """Scan numbers in spec_file that the current ScanParser can open (cached)."""
        return(get_available_scan_numbers(self.spec_file))

    def set_from_nxcollection(self, nxcollection:NXcollection):
        """Update this instance in place from an NXcollection (layout as in ``construct_nxcollection``).

        The collection is parsed completely before any field is assigned.
        """
        config = self._config_from_nxcollection(nxcollection)
        self.spec_file = config['spec_file']
        self.scan_numbers = config['scan_numbers']
        self.stack_info = config['stack_info']

    def get_scan_index(self, scan_number):
        """Return the position of *scan_number* in stack_info, or None if it is absent.

        :raises ValueError: if the scan appears more than once.
        """
        matches = [index for index, scan_info in enumerate(self.stack_info)
                if scan_info['scan_number'] == scan_number]
        if len(matches) > 1:
            raise ValueError('Duplicate scan_numbers in image stack')
        return(matches[0] if matches else None)

    def get_scanparser(self, scan_number):
        """Return the cached scan parser for *scan_number* in spec_file, or None."""
        return(get_scanparser(self.spec_file, scan_number))

    def _require_scan_info(self, scan_number):
        """Return the stack_info entry for *scan_number*; raise ValueError if missing."""
        scan_index = self.get_scan_index(scan_number)
        if scan_index is None:
            raise ValueError(f'No image stack information for scan {scan_number}')
        return(self.stack_info[scan_index])

    def get_detector_data(self, detector_prefix, scan_number=None, scan_step_index=None):
        """Read detector images for one scan, or for every selected scan.

        With *scan_step_index* None, each scan's image index range
        ``(starting_image_offset, starting_image_offset+num_image)`` is read.
        Otherwise only the image at ``starting_image_offset+scan_step_index``
        is read.

        :returns: the single result when exactly one scan is read, otherwise a
            list with one result per scan.
        :raises ValueError: if a scan has no stack_info entry.
        """
        scan_numbers = self.scan_numbers if scan_number is None else [scan_number]
        image_stacks = []
        for number in scan_numbers:
            scan_info = self._require_scan_info(number)
            parser = self.get_scanparser(number)
            image_offset = scan_info['starting_image_offset']
            if scan_step_index is None:
                index = (image_offset, image_offset+scan_info['num_image'])
            else:
                index = image_offset+scan_step_index
            image_stacks.append(parser.get_detector_data(detector_prefix, index))
        return(image_stacks[0] if len(image_stacks) == 1 else image_stacks)

    def scan_numbers_cli(self, attr_desc, **kwargs):
        """Interactively choose ``scan_numbers`` among the SPEC file's available scans.

        For stations 'id1a3'/'id3a` with a 'scan_type' keyword, the candidates
        are narrowed first:
        - 'ts1' keeps the scans whose parser reports that scan type.
        - 'df1'/'bf1' take the scan two/one numbers before each entry of
          ``kwargs['tomo_scan_numbers']`` and require it to be of that type.

        A single candidate is selected without prompting.

        :raises ValueError: if a 'df1'/'bf1' candidate has the wrong scan type.
        """
        available_scan_numbers = self.available_scan_numbers
        station = kwargs.get('station')
        if station in ('id1a3', 'id3a') and 'scan_type' in kwargs:
            scan_type = kwargs['scan_type']
            if scan_type == 'ts1':
                available_scan_numbers = [scan_number
                        for scan_number in self.available_scan_numbers
                        if self.get_scanparser(scan_number).scan_type == scan_type]
            elif scan_type in ('df1', 'bf1'):
                scan_number_offset = 2 if scan_type == 'df1' else 1
                available_scan_numbers = []
                for tomo_scan_number in kwargs['tomo_scan_numbers']:
                    scan_number = tomo_scan_number-scan_number_offset
                    parser = self.get_scanparser(scan_number)
                    if parser is None or parser.scan_type != scan_type:
                        raise ValueError(f'Scan {scan_number} in {self.spec_file} is not a '
                                f'{scan_type} scan')
                    available_scan_numbers.append(scan_number)
        if len(available_scan_numbers) == 1:
            input_mode = 1
        else:
            menu_options = [f'Select a subset of the available {attr_desc}scan numbers',
                    f'Use all available {attr_desc}scan numbers in {self.spec_file}']
            if hasattr(self, 'scan_numbers'):
                print(f'Currently selected {attr_desc}scan numbers are: {self.scan_numbers}')
                menu_options.append(f'Keep the currently selected {attr_desc}scan numbers')
            print(f'Available scan numbers in {self.spec_file} are: '+
                    f'{available_scan_numbers}')
            input_mode = input_menu(menu_options, header='Choose one of the following options '+
                    'for selecting scan numbers')
        if input_mode == 0:
            while True:
                try:
                    self.scan_numbers = \
                            input_int_list(f'Enter a series of {attr_desc}scan numbers')
                except ValidationError as e:
                    print(e)
                except Exception as e:
                    print(f'Unexpected {type(e).__name__}: {e}')
                else:
                    break
        elif input_mode == 1:
            self.scan_numbers = available_scan_numbers
        # input_mode 2: keep the currently selected scan numbers.

    def cli(self, **cli_kwargs):
        """Configure the SPEC file and the scan numbers interactively."""
        if cli_kwargs.get('attr_desc') is not None:
            attr_desc = f'{cli_kwargs["attr_desc"]} '
        else:
            attr_desc = ''
        print(f'\n -- Configure which scans to use from a single {attr_desc}SPEC file')
        self.set_single_attr_cli('spec_file', attr_desc+'SPEC file path')
        self.scan_numbers_cli(attr_desc)

    def construct_nxcollection(self, image_key, thetas, detector):
        """Return an NXcollection with one NXsubentry ('scan_<n>') per selected scan.

        :param image_key: value stored as each scan detector's ``image_key``.
        :param thetas: rotation angles (units 'degrees') used for every scan,
            or None to store one 0.0 per image of each scan.
        :param detector: ``Detector`` used to build each scan's NXdetector.
        :raises ValueError: if *thetas* does not match a scan's image count,
            or if a scan has no stack_info entry.
        """
        nxcollection = NXcollection()
        nxcollection.attrs['spec_file'] = str(self.spec_file)
        parser = self.get_scanparser(self.scan_numbers[0])
        nxcollection.attrs['date'] = parser.spec_scan.file_date
        for scan_number in self.scan_numbers:
            scan_info = self._require_scan_info(scan_number)
            # Add an NXsubentry to the NXcollection for each scan
            nxsubentry = NXsubentry()
            nxcollection[f'scan_{scan_number}'] = nxsubentry
            parser = self.get_scanparser(scan_number)
            nxsubentry.start_time = parser.spec_scan.date
            nxsubentry.spec_command = parser.spec_command
            num_image = scan_info['num_image']
            # Per-scan local: never overwrite *thetas*, or a None input would
            # be fixed to the first scan's length for all later scans.
            if thetas is None:
                scan_thetas = num_image*[0.0]
            else:
                if num_image != len(thetas):
                    raise ValueError(f'Scan {scan_number} has {num_image} images but '
                            f'{len(thetas)} thetas were given')
                scan_thetas = thetas
            # Add an NXinstrument with an NXdetector to the scan's NXsubentry
            nxsubentry.instrument = NXinstrument()
            nxsubentry.instrument.detector = detector.construct_nxdetector()
            nxsubentry.instrument.detector.frame_start_number = scan_info['starting_image_offset']
            nxsubentry.instrument.detector.image_key = image_key
            # Add an NXsample to the scan's NXsubentry
            nxsubentry.sample = NXsample()
            nxsubentry.sample.rotation_angle = scan_thetas
            nxsubentry.sample.rotation_angle.units = 'degrees'
            nxsubentry.sample.x_translation = scan_info['ref_x']
            nxsubentry.sample.x_translation.units = 'mm'
            nxsubentry.sample.z_translation = scan_info['ref_z']
            nxsubentry.sample.z_translation.units = 'mm'
        return(nxcollection)
576
577
class FlatField(SpecScans):
    """SPEC scans of flat-field images (e.g. the dark and bright fields)."""

    def image_range_cli(self, attr_desc, detector_prefix):
        """Interactively choose the image offset and count for each selected scan.

        The user may accept the current stack_info values (if any), then the
        defaults derived from the scan parser, or else enter the values. The
        result, including the parser's horizontal/vertical shift as
        ref_x/ref_z, is written back to ``stack_info``.
        """
        stack_info = self.stack_info
        for scan_number in self.scan_numbers:
            # Parse the available image range
            parser = self.get_scanparser(scan_number)
            image_offset = parser.starting_image_offset
            num_image = parser.get_num_image(detector_prefix.upper())
            scan_index = self.get_scan_index(scan_number)

            # Select the image set
            last_image_index = image_offset+num_image-1
            print(f'Available good image set index range: [{image_offset}, {last_image_index}]')
            image_set_approved = False
            if scan_index is not None:
                scan_info = stack_info[scan_index]
                print(f'Current starting image offset and number of images: '+
                        f'{scan_info["starting_image_offset"]} and {scan_info["num_image"]}')
                image_set_approved = input_yesno('Accept these values (y/n)?', 'y')
                if image_set_approved:
                    # offset/num must be bound on this path too; keep the current values.
                    offset = scan_info['starting_image_offset']
                    num = scan_info['num_image']
            if not image_set_approved:
                # NOTE(review): the default count is num_image-1, i.e. it leaves
                # out the last available image; confirm this is intended.
                print(f'Default starting image offset and number of images: '+
                        f'{image_offset} and {last_image_index-image_offset}')
                image_set_approved = input_yesno('Accept these values (y/n)?', 'y')
                if image_set_approved:
                    offset = image_offset
                    num = last_image_index-offset
            while not image_set_approved:
                offset = input_int('Enter the starting image offset', ge=image_offset,
                        le=last_image_index-1)
                num = input_int('Enter the number of images', ge=1,
                        le=last_image_index-offset+1)
                print(f'Current starting image offset and number of images: {offset} and {num}')
                image_set_approved = input_yesno('Accept these values (y/n)?', 'y')
            scan_entry = {'scan_number': scan_number, 'starting_image_offset': offset,
                    'num_image': num, 'ref_x': parser.horizontal_shift,
                    'ref_z': parser.vertical_shift}
            if scan_index is not None:
                scan_info.update(scan_entry)
            else:
                stack_info.append(scan_entry)
        self.stack_info = stack_info

    def cli(self, **cli_kwargs):
        """Configure the SPEC file, scan numbers and image ranges interactively.

        Recognized cli_kwargs:
        - 'attr_desc': prompt prefix.
        - 'station'.
        - 'detector': must have ``.prefix``.
        - For stations 'id1a3'/'id3a' also 'spec_file', 'tomo_scan_numbers'
          and 'scan_type'.
        """
        if cli_kwargs.get('attr_desc') is not None:
            attr_desc = f'{cli_kwargs["attr_desc"]} '
        else:
            attr_desc = ''
        station = cli_kwargs.get('station')
        detector = cli_kwargs.get('detector')
        print(f'\n -- Configure the location of the {attr_desc}scan data -- ')
        if station in ('id1a3', 'id3a'):
            self.spec_file = cli_kwargs['spec_file']
            tomo_scan_numbers = cli_kwargs['tomo_scan_numbers']
            scan_type = cli_kwargs['scan_type']
            self.scan_numbers_cli(attr_desc, station=station, tomo_scan_numbers=tomo_scan_numbers,
                    scan_type=scan_type)
        else:
            self.set_single_attr_cli('spec_file', attr_desc+'SPEC file path')
            self.scan_numbers_cli(attr_desc)
        self.image_range_cli(attr_desc, detector.prefix)
641
642
class TomoField(SpecScans):
    """SPEC scans holding tomography images, plus the selected theta range."""
    # Keys 'start', 'end', 'num' and optionally 'start_index' (see
    # validate_theta_range); filled in by theta_range_cli.
    theta_range: dict = {}

    @validator('theta_range')
    def validate_theta_range(cls, theta_range):
        """Check that *theta_range* holds a valid start/end/num (and optional start_index).

        Missing keys are reported as ValueError, so they surface as a
        validation error rather than a bare KeyError.
        """
        if len(theta_range) != 3 and len(theta_range) != 4:
            raise ValueError(f'Invalid theta range {theta_range}')
        for key in ('start', 'end', 'num'):
            if key not in theta_range:
                raise ValueError(f'Invalid theta range {theta_range}')
        is_num(theta_range['start'], raise_error=True)
        is_num(theta_range['end'], raise_error=True)
        is_int(theta_range['num'], gt=1, raise_error=True)
        if theta_range['end'] <= theta_range['start']:
            raise ValueError(f'Invalid theta range {theta_range}')
        if 'start_index' in theta_range:
            is_int(theta_range['start_index'], ge=0, raise_error=True)
        return(theta_range)

    @classmethod
    def construct_from_nxcollection(cls, nxcollection:NXcollection):
        """Build a validated instance from an NXcollection written by ``construct_nxcollection``.

        theta_range is taken from the rotation angles of the first 'scan_'
        entry: its first angle, last angle and number of angles.
        """
        config = {'spec_file': nxcollection.attrs['spec_file']}
        scan_numbers = []
        stack_info = []
        for nxsubentry_name, nxsubentry in nxcollection.items():
            scan_number = int(nxsubentry_name.split('_')[-1])
            scan_numbers.append(scan_number)
            stack_info.append({'scan_number': scan_number,
                    'starting_image_offset':
                        int(nxsubentry.instrument.detector.frame_start_number),
                    'num_image': len(nxsubentry.sample.rotation_angle),
                    'ref_x': float(nxsubentry.sample.x_translation),
                    'ref_z': float(nxsubentry.sample.z_translation)})
        config['scan_numbers'] = sorted(scan_numbers)
        config['stack_info'] = stack_info
        for name in nxcollection.entries:
            if 'scan_' in name:
                thetas = np.asarray(nxcollection[name].sample.rotation_angle)
                # Plain Python numbers keep the YAML output free of numpy tags.
                config['theta_range'] = {'start': float(thetas[0]), 'end': float(thetas[-1]),
                        'num': int(thetas.size)}
                break
        return(cls(**config))

    def _parser_values(self, scan_number, method_name):
        """Return ``parser.<method_name>()`` for *scan_number*, or for every selected scan.

        A single value is returned when exactly one scan is queried, otherwise a list.
        """
        scan_numbers = self.scan_numbers if scan_number is None else [scan_number]
        values = [getattr(self.get_scanparser(number), method_name)()
                for number in scan_numbers]
        return(values[0] if len(values) == 1 else values)

    def get_horizontal_shifts(self, scan_number=None):
        """Horizontal shift(s) from the scan parser, for one scan or all selected scans."""
        return(self._parser_values(scan_number, 'get_horizontal_shift'))

    def get_vertical_shifts(self, scan_number=None):
        """Vertical shift(s) from the scan parser, for one scan or all selected scans."""
        return(self._parser_values(scan_number, 'get_vertical_shift'))

    def theta_range_cli(self, scan_number, attr_desc, station):
        """Select ``theta_range`` interactively from the SPEC theta values.

        For the first selected scan, the SPEC range is shown and the chosen
        range is stored in ``self.theta_range`` as 'start', 'end', 'num' and
        'start_index'. 'start_index' is the index of the first chosen theta in
        the SPEC range. For any other scan, only its SPEC theta range is
        compared with the first scan's.

        :raises ValueError: if the SPEC theta ranges of the scans differ.
        """
        # Parse the available theta range
        parser = self.get_scanparser(scan_number)
        theta_vals = parser.theta_vals
        spec_theta_start = theta_vals.get('start')
        spec_theta_end = theta_vals.get('end')
        spec_num_theta = theta_vals.get('num')

        # Check for consistency of theta ranges between scans
        if scan_number != self.scan_numbers[0]:
            first_theta_vals = self.get_scanparser(self.scan_numbers[0]).theta_vals
            if (first_theta_vals.get('start') != spec_theta_start or
                    first_theta_vals.get('end') != spec_theta_end or
                    first_theta_vals.get('num') != spec_num_theta):
                raise ValueError(f'Incompatible theta ranges between {attr_desc}scans:'+
                        f'\n\tScan {scan_number}: {theta_vals}'+
                        f'\n\tScan {self.scan_numbers[0]}: {first_theta_vals}')
            return

        # Select the theta range for the tomo reconstruction from the first scan
        thetas = np.linspace(spec_theta_start, spec_theta_end, spec_num_theta)
        delta_theta = thetas[1]-thetas[0]
        theta_range_approved = False
        print(f'Theta range obtained from SPEC data: [{spec_theta_start}, {spec_theta_end})')
        print(f'Theta step size = {delta_theta}')
        print(f'Number of theta values: {spec_num_theta}')
        default_start = None
        default_end = None
        if station in ('id1a3', 'id3a'):
            theta_range_approved = input_yesno('Accept this theta range (y/n)?', 'y')
            if theta_range_approved:
                theta_start = spec_theta_start
                theta_end = spec_theta_end
                num_theta = spec_num_theta
                theta_index_start = 0
        elif station == 'id3b':
            # Offer [0, 180] when the scan covers it, or the SPEC range itself
            # when that spans exactly 180.
            if spec_theta_start <= 0.0 and spec_theta_end >= 180.0:
                default_start = 0
                default_end = 180
            elif spec_theta_end-spec_theta_start == 180:
                default_start = spec_theta_start
                default_end = spec_theta_end
        while not theta_range_approved:
            theta_start = input_num('Enter the first theta (included)', ge=spec_theta_start,
                    lt=spec_theta_end, default=default_start)
            theta_index_start = index_nearest(thetas, theta_start)
            theta_start = thetas[theta_index_start]
            theta_end = input_num('Enter the last theta (excluded)',
                    ge=theta_start+delta_theta, le=spec_theta_end, default=default_end)
            theta_index_end = index_nearest(thetas, theta_end)
            theta_end = thetas[theta_index_end]
            num_theta = theta_index_end-theta_index_start
            print(f'Selected theta range: [{theta_start}, {theta_start+delta_theta}, ..., '+
                    f'{theta_end})')
            print(f'Number of theta values: {num_theta}')
            theta_range_approved = input_yesno('Accept this theta range (y/n)?', 'y')
        # Store the selection in the model field that image_range_cli reads.
        # NOTE(review): the "accept SPEC range" path stores the SPEC end value,
        # while the interactive path stores an excluded end; confirm how
        # consumers of theta_range['end'] interpret it.
        self.theta_range = {'start': float(theta_start), 'end': float(theta_end),
                'num': int(num_theta), 'start_index': int(theta_index_start)}

    def image_range_cli(self, attr_desc, detector_prefix):
        """Set each scan's stack_info to the images that match ``theta_range``.

        The first image is the parser's starting image offset plus
        ``theta_range['start_index']`` (0 if absent). ``theta_range['num']``
        images are used.

        :raises ValueError: if a scan has too few images for the theta range.
        """
        stack_info = self.stack_info
        num_theta = self.theta_range['num']
        theta_index_start = self.theta_range.get('start_index', 0)
        for scan_number in self.scan_numbers:
            # Parse the available image range
            parser = self.get_scanparser(scan_number)
            image_offset = parser.starting_image_offset
            num_image = parser.get_num_image(detector_prefix.upper())
            scan_index = self.get_scan_index(scan_number)

            # Select the image set matching the theta range
            if num_theta > num_image-theta_index_start:
                raise ValueError(f'Available {attr_desc}image indices incompatible with thetas:'+
                        f'\n\tNumber of thetas and offset = {num_theta} and {theta_index_start}'+
                        f'\n\tNumber of available images {num_image}')
            scan_entry = {'scan_number': scan_number,
                    'starting_image_offset': image_offset+theta_index_start,
                    'num_image': num_theta, 'ref_x': parser.horizontal_shift,
                    'ref_z': parser.vertical_shift}
            if scan_index is not None:
                stack_info[scan_index].update(scan_entry)
            else:
                stack_info.append(scan_entry)
        self.stack_info = stack_info

    def cli(self, **cli_kwargs):
        """Configure the SPEC file, scan numbers, theta range and image ranges.

        For stations 'id1a3'/'id3a', the SPEC file is chosen from the sample
        directories under ``/nfs/chess/<station>/<cycle>/<btr>``, and only
        'ts1' scans are offered.
        """
        if cli_kwargs.get('attr_desc') is not None:
            attr_desc = f'{cli_kwargs["attr_desc"]} '
        else:
            attr_desc = ''
        cycle = cli_kwargs.get('cycle')
        btr = cli_kwargs.get('btr')
        station = cli_kwargs.get('station')
        detector = cli_kwargs.get('detector')
        print(f'\n -- Configure the location of the {attr_desc}scan data -- ')
        if station in ('id1a3', 'id3a'):
            basedir = f'/nfs/chess/{station}/{cycle}/{btr}'
            runs = [d for d in os.listdir(basedir) if os.path.isdir(os.path.join(basedir, d))]
            index = input_menu(runs, header='Choose a sample directory')
            self.spec_file = f'{basedir}/{runs[index]}/spec.log'
            self.scan_numbers_cli(attr_desc, station=station, scan_type='ts1')
        else:
            self.set_single_attr_cli('spec_file', attr_desc+'SPEC file path')
            self.scan_numbers_cli(attr_desc)
        for scan_number in self.scan_numbers:
            self.theta_range_cli(scan_number, attr_desc, station)
        self.image_range_cli(attr_desc, detector.prefix)
823
824
class Sample(BaseModel):
    """Sample metadata for one map.

    In an NXsample group the per-image positions are stored under the
    singular names ``rotation_angle``, ``x_translation`` and
    ``z_translation``. The corresponding model fields are plural.
    """
    name: constr(min_length=1)
    description: Optional[str]
    rotation_angles: Optional[list]
    x_translations: Optional[list]
    z_translations: Optional[list]

    @classmethod
    def construct_from_nxsample(cls, nxsample:NXsample):
        """Create a Sample from an NXsample group.

        :param nxsample: group with a ``name`` field and, optionally,
            ``description``, ``rotation_angle``, ``x_translation`` and
            ``z_translation`` fields.
        :return: the validated Sample.
        """
        config = {'name': nxsample.name.nxdata}
        if 'description' in nxsample:
            config['description'] = nxsample.description.nxdata
        # Map the singular NXsample names onto the plural model fields, and
        # convert the stored arrays to flat Python lists so they validate as
        # ``list`` fields.
        for nx_name, field_name in (('rotation_angle', 'rotation_angles'),
                ('x_translation', 'x_translations'),
                ('z_translation', 'z_translations')):
            if nx_name in nxsample:
                config[field_name] = np.ravel(nxsample[nx_name].nxdata).tolist()
        return(cls(**config))

    def cli(self):
        """Interactively set the sample name and (optional) description."""
        print('\n -- Configure the sample metadata -- ')
        self.set_single_attr_cli('name', 'the sample name')
        self.set_single_attr_cli('description', 'a description of the sample (optional)')
844 return(cls(**config))
845
846 def cli(self):
847 print('\n -- Configure the sample metadata -- ')
848 #RV self.name = 'test'
849 #RV self.name = 'sobhani-3249-A'
850 self.set_single_attr_cli('name', 'the sample name')
851 #RV self.description = 'test sample'
852 self.set_single_attr_cli('description', 'a description of the sample (optional)')
853
854
855 class MapConfig(BaseModel):
856 cycle: constr(strip_whitespace=True, min_length=1)
857 btr: constr(strip_whitespace=True, min_length=1)
858 title: constr(strip_whitespace=True, min_length=1)
859 station: Literal['id1a3', 'id3a', 'id3b'] = None
860 sample: Sample
861 detector: Detector = Detector.construct()
862 tomo_fields: TomoField
863 dark_field: Optional[FlatField]
864 bright_field: FlatField
865 _thetas: list[float] = PrivateAttr()
866 _field_types = ({'name': 'dark_field', 'image_key': 2}, {'name': 'bright_field',
867 'image_key': 1}, {'name': 'tomo_fields', 'image_key': 0})
868
869 @classmethod
870 def construct_from_nxentry(cls, nxentry:NXentry):
871 config = {}
872 config['cycle'] = nxentry.instrument.source.attrs['cycle']
873 config['btr'] = nxentry.instrument.source.attrs['btr']
874 config['title'] = nxentry.nxname
875 config['station'] = nxentry.instrument.source.attrs['station']
876 config['sample'] = Sample.construct_from_nxsample(nxentry['sample'])
877 for nxobject_name, nxobject in nxentry.spec_scans.items():
878 if isinstance(nxobject, NXcollection):
879 config[nxobject_name] = SpecScans.construct_from_nxcollection(nxobject)
880 return(cls(**config))
881
882 #FIX cache?
883 @property
884 def thetas(self):
885 try:
886 return(self._thetas)
887 except:
888 theta_range = self.tomo_fields.theta_range
889 self._thetas = list(np.linspace(theta_range['start'], theta_range['end'],
890 theta_range['num']))
891 return(self._thetas)
892
893 def cli(self):
894 print('\n -- Configure a map from a set of SPEC scans (dark, bright, and tomo), '+
895 'and / or detector data -- ')
896 #RV self.cycle = '2021-3'
897 #RV self.cycle = '2022-2'
898 #RV self.cycle = '2023-1'
899 self.set_single_attr_cli('cycle', 'beam cycle')
900 #RV self.btr = 'z-3234-A'
901 #RV self.btr = 'sobhani-3249-A'
902 #RV self.btr = 'przybyla-3606-a'
903 self.set_single_attr_cli('btr', 'BTR')
904 #RV self.title = 'z-3234-A'
905 #RV self.title = 'tomo7C'
906 #RV self.title = 'cmc-test-dwell-1'
907 self.set_single_attr_cli('title', 'title for the map entry')
908 #RV self.station = 'id3a'
909 #RV self.station = 'id3b'
910 #RV self.station = 'id1a3'
911 self.set_single_attr_cli('station', 'name of the station at which scans were collected '+
912 '(currently choose from: id1a3, id3a, id3b)')
913 import_scanparser(self.station)
914 self.set_single_attr_cli('sample')
915 use_detector_config = False
916 if hasattr(self.detector, 'prefix') and len(self.detector.prefix):
917 use_detector_config = input_yesno(f'Current detector settings:\n{self.detector}\n'+
918 f'Keep these settings? (y/n)')
919 if not use_detector_config:
920 #RV have_detector_config = True
921 have_detector_config = input_yesno(f'Is a detector configuration file available? (y/n)')
922 if have_detector_config:
923 #RV detector_config_file = 'retiga.yaml'
924 #RV detector_config_file = 'andor2.yaml'
925 detector_config_file = input(f'Enter detector configuration file name: ')
926 have_detector_config = self.detector.construct_from_yaml(detector_config_file)
927 if not have_detector_config:
928 self.set_single_attr_cli('detector', 'detector')
929 self.set_single_attr_cli('tomo_fields', 'Tomo field', chain_attr_desc=True,
930 cycle=self.cycle, btr=self.btr, station=self.station, detector=self.detector)
931 if self.station in ('id1a3', 'id3a'):
932 have_dark_field = True
933 tomo_spec_file = self.tomo_fields.spec_file
934 else:
935 have_dark_field = input_yesno(f'Are Dark field images available? (y/n)')
936 tomo_spec_file = None
937 if have_dark_field:
938 self.set_single_attr_cli('dark_field', 'Dark field', chain_attr_desc=True,
939 station=self.station, detector=self.detector, spec_file=tomo_spec_file,
940 tomo_scan_numbers=self.tomo_fields.scan_numbers, scan_type='df1')
941 self.set_single_attr_cli('bright_field', 'Bright field', chain_attr_desc=True,
942 station=self.station, detector=self.detector, spec_file=tomo_spec_file,
943 tomo_scan_numbers=self.tomo_fields.scan_numbers, scan_type='bf1')
944
945 def construct_nxentry(self, nxroot, include_raw_data=True):
946 # Construct base NXentry
947 nxentry = NXentry()
948
949 # Add an NXentry to the NXroot
950 nxroot[self.title] = nxentry
951 nxroot.attrs['default'] = self.title
952 nxentry.definition = 'NXtomo'
953 # nxentry.attrs['default'] = 'data'
954
955 # Add an NXinstrument to the NXentry
956 nxinstrument = NXinstrument()
957 nxentry.instrument = nxinstrument
958
959 # Add an NXsource to the NXinstrument
960 nxsource = NXsource()
961 nxinstrument.source = nxsource
962 nxsource.type = 'Synchrotron X-ray Source'
963 nxsource.name = 'CHESS'
964 nxsource.probe = 'x-ray'
965
966 # Tag the NXsource with the runinfo (as an attribute)
967 nxsource.attrs['cycle'] = self.cycle
968 nxsource.attrs['btr'] = self.btr
969 nxsource.attrs['station'] = self.station
970
971 # Add an NXdetector to the NXinstrument (don't fill in data fields yet)
972 nxinstrument.detector = self.detector.construct_nxdetector()
973
974 # Add an NXsample to NXentry (don't fill in data fields yet)
975 nxsample = NXsample()
976 nxentry.sample = nxsample
977 nxsample.name = self.sample.name
978 nxsample.description = self.sample.description
979
980 # Add an NXcollection to the base NXentry to hold metadata about the spec scans in the map
981 # Also obtain the data fields in NXsample and NXdetector
982 nxspec_scans = NXcollection()
983 nxentry.spec_scans = nxspec_scans
984 image_keys = []
985 sequence_numbers = []
986 image_stacks = []
987 rotation_angles = []
988 x_translations = []
989 z_translations = []
990 for field_type in self._field_types:
991 field_name = field_type['name']
992 field = getattr(self, field_name)
993 if field is None:
994 continue
995 image_key = field_type['image_key']
996 if field_type['name'] == 'tomo_fields':
997 thetas = self.thetas
998 else:
999 thetas = None
1000 # Add the scans in a single spec file
1001 nxspec_scans[field_name] = field.construct_nxcollection(image_key, thetas,
1002 self.detector)
1003 if include_raw_data:
1004 image_stacks.append(field.get_detector_data(self.detector.prefix))
1005 for scan_number in field.scan_numbers:
1006 parser = field.get_scanparser(scan_number)
1007 scan_info = field.stack_info[field.get_scan_index(scan_number)]
1008 num_image = scan_info['num_image']
1009 image_keys += num_image*[image_key]
1010 sequence_numbers += [i for i in range(num_image)]
1011 if thetas is None:
1012 rotation_angles += scan_info['num_image']*[0.0]
1013 else:
1014 assert(num_image == len(thetas))
1015 rotation_angles += thetas
1016 x_translations += scan_info['num_image']*[scan_info['ref_x']]
1017 z_translations += scan_info['num_image']*[scan_info['ref_z']]
1018
1019 if include_raw_data:
1020 # Add image data to NXdetector
1021 nxinstrument.detector.image_key = image_keys
1022 nxinstrument.detector.sequence_number = sequence_numbers
1023 nxinstrument.detector.data = np.concatenate([image for image in image_stacks])
1024
1025 # Add image data to NXsample
1026 nxsample.rotation_angle = rotation_angles
1027 nxsample.rotation_angle.attrs['units'] = 'degrees'
1028 nxsample.x_translation = x_translations
1029 nxsample.x_translation.attrs['units'] = 'mm'
1030 nxsample.z_translation = z_translations
1031 nxsample.z_translation.attrs['units'] = 'mm'
1032
1033 # Add an NXdata to NXentry
1034 nxdata = NXdata()
1035 nxentry.data = nxdata
1036 nxdata.makelink(nxentry.instrument.detector.data, name='data')
1037 nxdata.makelink(nxentry.instrument.detector.image_key)
1038 nxdata.makelink(nxentry.sample.rotation_angle)
1039 nxdata.makelink(nxentry.sample.x_translation)
1040 nxdata.makelink(nxentry.sample.z_translation)
1041 # nxdata.attrs['axes'] = ['field', 'row', 'column']
1042 # nxdata.attrs['field_indices'] = 0
1043 # nxdata.attrs['row_indices'] = 1
1044 # nxdata.attrs['column_indices'] = 2
1045
1046
1047 class TomoWorkflow(BaseModel):
1048 sample_maps: conlist(item_type=MapConfig, min_items=1) = [MapConfig.construct()]
1049
1050 @classmethod
1051 def construct_from_nexus(cls, filename):
1052 nxroot = nxload(filename)
1053 sample_maps = []
1054 config = {'sample_maps': sample_maps}
1055 for nxentry_name, nxentry in nxroot.items():
1056 sample_maps.append(MapConfig.construct_from_nxentry(nxentry))
1057 return(cls(**config))
1058
1059 def cli(self):
1060 print('\n -- Configure a map -- ')
1061 self.set_list_attr_cli('sample_maps', 'sample map')
1062
1063 def construct_nxfile(self, filename, mode='w-'):
1064 nxroot = NXroot()
1065 t0 = time()
1066 for sample_map in self.sample_maps:
1067 logger.info(f'Start constructing the {sample_map.title} map.')
1068 import_scanparser(sample_map.station)
1069 sample_map.construct_nxentry(nxroot)
1070 logger.info(f'Constructed all sample maps in {time()-t0:.2f} seconds.')
1071 logger.info(f'Start saving all sample maps to {filename}.')
1072 nxroot.save(filename, mode=mode)
1073
1074 def write_to_nexus(self, filename):
1075 t0 = time()
1076 self.construct_nxfile(filename, mode='w')
1077 logger.info(f'Saved all sample maps to {filename} in {time()-t0:.2f} seconds.')