comparison workflow/models.py @ 0:98e23dff1de2 draft default tip

planemo upload for repository https://github.com/rolfverberg/galaxytools commit f8c4bdb31c20c468045ad5e6eb255a293244bc6c-dirty
author rv43
date Tue, 21 Mar 2023 16:22:42 +0000
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:98e23dff1de2
1 #!/usr/bin/env python3
2
3 import logging
4 logger = logging.getLogger(__name__)
5
6 import logging
7
8 import numpy as np
9 import os
10 import yaml
11
12 from functools import cache
13 from pathlib import PosixPath
14 from pydantic import BaseModel as PydanticBaseModel
15 from pydantic import validator, ValidationError, conint, confloat, constr, conlist, FilePath, \
16 PrivateAttr
17 from nexusformat.nexus import *
18 from time import time
19 from typing import Optional, Literal
20 from typing_extensions import TypedDict
21 try:
22 from pyspec.file.spec import FileSpec
23 except:
24 pass
25
26 try:
27 from msnctools.general import is_int, is_num, input_int, input_int_list, input_num, \
28 input_yesno, input_menu, index_nearest, string_to_list, file_exists_and_readable
29 except:
30 from general import is_int, is_num, input_int, input_int_list, input_num, \
31 input_yesno, input_menu, index_nearest, string_to_list, file_exists_and_readable
32
33
def import_scanparser(station):
    """Bind the module-level name `ScanParser` to the scan parser class for `station`.

    The parser class is looked up in `msnctools.scanparsers` first and in a top-level
    `scanparsers` module second. If neither can provide it, `ScanParser` is left untouched and a
    warning is logged (the lookup is best effort, it never raises for a valid station).

    :param station: station identifier; one of 'id1a3', 'id3a' or 'id3b'
    :raises RuntimeError: if `station` is not one of the supported identifiers
    """
    if station in ('id1a3', 'id3a'):
        parser_name = 'SMBRotationScanParser'
    elif station == 'id3b':
        # Compare for equality: `station in ('id3b')` is a substring test on a plain string,
        # which accepts 'id3', 'b' or '' and raises a TypeError for None.
        parser_name = 'FMBRotationScanParser'
    else:
        raise RuntimeError(f'Invalid station: {station}')

    # Local import: importlib is only needed by this function.
    from importlib import import_module
    for module_name in ('msnctools.scanparsers', 'scanparsers'):
        try:
            parser_class = getattr(import_module(module_name), parser_name)
        except Exception:
            # Try the next candidate module
            continue
        globals()['ScanParser'] = parser_class
        return
    logger.warning(f'Unable to import {parser_name} for station {station}, '+
        'ScanParser was not updated')
57
@cache
def get_available_scan_numbers(spec_file:str):
    """Return the scan numbers in `spec_file` for which a `ScanParser` can be constructed.

    The result is cached per `spec_file`, so every caller receives the same list object and
    should not modify it in place.

    NOTE(review): any failure to construct the parser drops the scan number, including the
    NameError raised when `ScanParser` has not been set by `import_scanparser` yet. In that case
    an empty list is returned (and cached).

    :param spec_file: path of the SPEC file
    :return: list of usable scan numbers, in the order reported by the SPEC file
    """
    scan_numbers = []
    for scan_number in FileSpec(spec_file).scans.keys():
        try:
            ScanParser(spec_file, scan_number)
        except Exception:
            # Not a scan this parser can handle
            continue
        scan_numbers.append(scan_number)
    return(scan_numbers)
72
@cache
def get_scanparser(spec_file:str, scan_number:int):
    """Return a (cached) `ScanParser` for one scan of a SPEC file.

    :param spec_file: path of the SPEC file
    :param scan_number: number of the requested scan
    :return: the parser, or None if `scan_number` is not among the available scan numbers of
        `spec_file`
    """
    if scan_number in get_available_scan_numbers(spec_file):
        return(ScanParser(spec_file, scan_number))
    return(None)
79
80
class BaseModel(PydanticBaseModel):
    """Common base class for the workflow configuration models.

    Extends the pydantic base model with alternate constructors (command line, YAML or NeXus
    file), YAML serialization, output file checks and interactive command-line setters for
    individual attributes.
    """

    class Config:
        # Validate a field again whenever it is assigned to
        validate_assignment = True
        arbitrary_types_allowed = True

    @classmethod
    def construct_from_cli(cls):
        """Create an unvalidated instance with `construct()` and fill it in by calling its
        `cli()` method.

        :return: the interactively configured instance
        """
        obj = cls.construct()
        obj.cli()
        return(obj)

    @classmethod
    def construct_from_yaml(cls, filename):
        """Construct a validated model from the dictionary stored in a YAML file.

        :param filename: path of the YAML file
        :return: the constructed model
        :raises ValueError: if the file cannot be opened or parsed (the original error is
            attached as the cause)
        """
        try:
            with open(filename, 'r') as infile:
                # NOTE(review): yaml.CLoader is not a safe loader (it can construct arbitrary
                # Python objects) and only exists when PyYAML is built with libyaml. Consider
                # yaml.CSafeLoader/yaml.SafeLoader if input files may be untrusted.
                indict = yaml.load(infile, Loader=yaml.CLoader)
        except Exception as exc:
            raise ValueError(f'Could not load a dictionary from {filename}') from exc
        return(cls(**indict))

    @classmethod
    def construct_from_file(cls, filename):
        """Construct a model from a YAML or NeXus file, selected by the file extension.

        :param filename: path of the input file; checked with `file_exists_and_readable` first
        :return: the constructed model
        :raises TypeError: if the extension is neither a YAML nor a NeXus extension
        """
        file_exists_and_readable(filename)
        filename = os.path.abspath(filename)
        fileformat = os.path.splitext(filename)[1]
        yaml_extensions = ('.yaml','.yml')
        nexus_extensions = ('.nxs','.nx5','.h5','.hdf5')
        t0 = time()
        if fileformat.lower() in yaml_extensions:
            obj = cls.construct_from_yaml(filename)
        elif fileformat.lower() in nexus_extensions:
            # NOTE(review): construct_from_nexus is not defined on this base class; presumably
            # it is provided by the subclasses that support NeXus input.
            obj = cls.construct_from_nexus(filename)
        else:
            logger.error(f'Unsupported file extension for constructing a model: {fileformat}')
            raise TypeError(f'Unrecognized file extension: {fileformat}')
        logger.info(f'Constructed a model from {filename} in {time()-t0:.2f} seconds.')
        return(obj)

    def dict_for_yaml(self, exclude_fields=()):
        """Return a plain dictionary representation of this model suitable for `yaml.dump`.

        Fields that are unset or None are omitted, nested models (and lists consisting only of
        models) are converted recursively and PosixPath values are converted to strings.

        :param exclude_fields: names of fields to leave out
        :return: dictionary mapping field names to values
        """
        yaml_dict = {}
        for field_name in self.__fields__:
            if field_name in exclude_fields:
                continue
            field_value = getattr(self, field_name, None)
            if field_value is None:
                continue
            if isinstance(field_value, BaseModel):
                yaml_dict[field_name] = field_value.dict_for_yaml()
            elif (isinstance(field_value, list)
                    and all(isinstance(item, BaseModel) for item in field_value)):
                yaml_dict[field_name] = [item.dict_for_yaml() for item in field_value]
            elif isinstance(field_value, PosixPath):
                yaml_dict[field_name] = str(field_value)
            else:
                yaml_dict[field_name] = field_value
        return(yaml_dict)

    def write_to_yaml(self, filename=None):
        """Write the YAML representation of this model to a file, or to the log.

        :param filename: output path; if None, or if writing fails, the YAML text is logged at
            INFO level instead
        """
        yaml_dict = self.dict_for_yaml()
        if filename is None:
            logger.info('Printing yaml representation here:\n'+
                f'{yaml.dump(yaml_dict, sort_keys=False)}')
            return
        try:
            with open(filename, 'w') as outfile:
                yaml.dump(yaml_dict, outfile, sort_keys=False)
        except Exception:
            # Best effort: report the failure (with traceback) and fall back to the log
            logger.error(f'Unknown error -- could not write to {filename} in yaml format.',
                exc_info=True)
            logger.info('Printing yaml representation here:\n'+
                f'{yaml.dump(yaml_dict, sort_keys=False)}')
        else:
            logger.info(f'Successfully wrote this model to {filename}')

    def write_to_file(self, filename, force_overwrite=False):
        """Write this model to a YAML or NeXus file, selected by the file extension.

        For YAML output that cannot be written, the YAML text is logged instead. For NeXus
        output that cannot be written, nothing is written.

        :param filename: output path
        :param force_overwrite: allow an existing file to be overwritten
        """
        file_writeable, fileformat = self.output_file_valid(filename,
            force_overwrite=force_overwrite)
        if fileformat == 'yaml':
            if file_writeable:
                self.write_to_yaml(filename=filename)
            else:
                self.write_to_yaml()
        elif fileformat == 'nexus':
            if file_writeable:
                # NOTE(review): write_to_nexus is not defined on this base class; presumably it
                # is provided by the subclasses that support NeXus output.
                self.write_to_nexus(filename=filename)
        else:
            # Previously an unsupported extension was ignored without any message
            logger.error(f'Unsupported file extension for output: {filename}')

    def output_file_valid(self, filename, force_overwrite=False):
        """Check whether `filename` can be written and determine its output format.

        A missing output directory is created as part of the check.

        :param filename: output path
        :param force_overwrite: allow an existing (writeable) file to be overwritten
        :return: tuple `(writeable, fileformat)` with `fileformat` equal to 'yaml', 'nexus' or
            None for an unsupported extension
        """
        filename = os.path.abspath(filename)
        fileformat = os.path.splitext(filename)[1]
        yaml_extensions = ('.yaml','.yml')
        nexus_extensions = ('.nxs','.nx5','.h5','.hdf5')
        if fileformat.lower() in yaml_extensions:
            fileformat = 'yaml'
        elif fileformat.lower() in nexus_extensions:
            fileformat = 'nexus'
        else:
            return(False, None) # Only yaml and NeXus files allowed for output now.
        if os.path.isfile(filename):
            if not os.access(filename, os.W_OK):
                logger.error(f'Cannot access {filename} for writing.')
                return(False, fileformat)
            if not force_overwrite:
                logger.error(f'{filename} will not be overwritten.')
                return(False, fileformat)
        dirname = os.path.dirname(filename)
        if os.path.isdir(dirname):
            if os.access(dirname, os.W_OK):
                return(True, fileformat)
            logger.error(f'Cannot access {dirname} for writing.')
            return(False, fileformat)
        try:
            os.makedirs(dirname)
        except OSError:
            logger.error(f'Cannot create {dirname} for output.')
            return(False, fileformat)
        return(True, fileformat)

    def set_single_attr_cli(self, attr_name, attr_desc='unknown attribute', list_flag=False,
            **cli_kwargs):
        """Interactively set a single attribute.

        The attribute is first treated as a nested model: its `cli()` method is called until the
        result can be assigned. If that is not possible (e.g. the field type has no
        `construct()`/`cli()`), the user is prompted for a plain value until one validates.
        A KeyboardInterrupt (or SystemExit) always propagates to the caller.

        :param attr_name: name of the attribute to set
        :param attr_desc: description used in the prompts
        :param list_flag: convert the typed value to a list with `string_to_list`
        :param cli_kwargs: passed on to the nested model's `cli()`; if `chain_attr_desc` is
            true, `attr_desc` is passed on as well
        """
        if cli_kwargs.get('chain_attr_desc', False):
            cli_kwargs['attr_desc'] = attr_desc
        try:
            attr = getattr(self, attr_name, None)
            if attr is None:
                attr = self.__fields__[attr_name].type_.construct()
            input_accepted = False
            while not input_accepted:
                try:
                    attr.cli(**cli_kwargs)
                except ValidationError as e:
                    print(e)
                    print(f'Removing {attr_desc} configuration')
                    attr = self.__fields__[attr_name].type_.construct()
                    continue
                except Exception as e:
                    print(f'{type(e).__name__}: {e}')
                    print(f'Removing {attr_desc} configuration')
                    # Raises for a field type without construct(), which hands control to the
                    # plain value prompt below
                    attr = self.__fields__[attr_name].type_.construct()
                    continue
                try:
                    setattr(self, attr_name, attr)
                except ValidationError as e:
                    print(e)
                except Exception as e:
                    print(f'{type(e).__name__}: {e}')
                else:
                    input_accepted = True
        except Exception:
            # Not a nested model: prompt for a plain value. Only Exception is caught here so
            # that Ctrl-C aborts instead of dropping into this prompt.
            input_accepted = False
            while not input_accepted:
                attr = getattr(self, attr_name, None)
                if attr is None:
                    input_value = input(f'Type and enter a value for {attr_desc}: ')
                else:
                    input_value = input(f'Type and enter a new value for {attr_desc} or press '+
                        f'enter to keep the current one ({attr}): ')
                if list_flag:
                    input_value = string_to_list(input_value, remove_duplicates=False,
                        sort=False)
                if len(input_value) == 0:
                    # Empty input keeps the current value
                    input_value = getattr(self, attr_name, None)
                try:
                    setattr(self, attr_name, input_value)
                except ValidationError as e:
                    print(e)
                except Exception as e:
                    print(f'Unexpected {type(e).__name__}: {e}')
                else:
                    input_accepted = True

    def set_list_attr_cli(self, attr_name, attr_desc='unknown attribute', **cli_kwargs):
        """Interactively set an attribute holding a list of nested models.

        Existing items are reviewed one by one through their `cli()` method, after which the
        user can append any number of new items.

        :param attr_name: name of the list attribute to set
        :param attr_desc: description used in the prompts
        :param cli_kwargs: passed on to the `cli()` method of the existing items; if
            `chain_attr_desc` is true, `attr_desc` is passed on as well
        """
        if cli_kwargs.get('chain_attr_desc', False):
            cli_kwargs['attr_desc'] = attr_desc
        attr = getattr(self, attr_name, None)
        # Class used for new items: that of the last existing item if there is one (as before),
        # the declared field type otherwise (this also covers an existing but empty list).
        item_type = self.__fields__[attr_name].type_
        if attr is not None:
            # Check existing items
            for item in attr:
                item_accepted = False
                while not item_accepted:
                    item.cli(**cli_kwargs)
                    try:
                        setattr(self, attr_name, attr)
                    except ValidationError as e:
                        print(e)
                    except Exception as e:
                        print(f'{type(e).__name__}: {e}')
                    else:
                        item_accepted = True
            if attr:
                item_type = attr[-1].__class__
        else:
            # Initialize list for new attribute
            attr = []
        # Append (optional) additional items
        append = input_yesno(f'Add a {attr_desc} configuration? (y/n)', 'n')
        while append:
            attr.append(item_type.construct_from_cli())
            try:
                setattr(self, attr_name, attr)
            except ValidationError as e:
                print(e)
                print(f'Removing last {attr_desc} configuration from the list')
                attr.pop()
            except Exception as e:
                print(f'{type(e).__name__}: {e}')
                print(f'Removing last {attr_desc} configuration from the list')
                attr.pop()
            else:
                append = input_yesno(f'Add another {attr_desc} configuration? (y/n)', 'n')
306
307
class Detector(BaseModel):
    """Detector configuration.

    Fields:
        prefix: detector ID
        rows, columns: number of pixel rows and columns
        pixel_size: one value for square pixels or a (row, column) pair; construct_nxdetector
            labels the values with units of mm
        lens_magnification: factor by which `pixel_size` is divided in `get_pixel_size`
    """
    prefix: constr(strip_whitespace=True, min_length=1)
    rows: conint(gt=0)
    columns: conint(gt=0)
    pixel_size: conlist(item_type=confloat(gt=0), min_items=1, max_items=2)
    lens_magnification: confloat(gt=0) = 1.0

    @property
    def get_pixel_size(self):
        """The pixel size(s) divided by the lens magnification, as a list."""
        return(list(np.asarray(self.pixel_size)/self.lens_magnification))

    def construct_from_yaml(self, filename):
        """Set the fields of this detector from a detector YAML file.

        Unlike `BaseModel.construct_from_yaml`, which this method overrides, this is an instance
        method that updates `self` in place and reports success as a boolean.

        The file is read as a dictionary with a `detector` entry (holding `id` and `pixels`,
        the latter with `rows`, `columns` and `size`) and a top-level `lens_magnification`.

        :param filename: path of the detector YAML file
        :return: True on success, False if the file could not be read or did not validate
        """
        try:
            with open(filename, 'r') as infile:
                # NOTE(review): yaml.CLoader is not a safe loader; consider yaml.CSafeLoader if
                # detector files may be untrusted.
                indict = yaml.load(infile, Loader=yaml.CLoader)
            detector = indict['detector']
            pixels = detector['pixels']
            lens_magnification = indict['lens_magnification']
            # Assign only after all entries were found, each assignment is validated
            self.prefix = detector['id']
            self.rows = pixels['rows']
            self.columns = pixels['columns']
            self.pixel_size = pixels['size']
            self.lens_magnification = lens_magnification
        except Exception:
            # Use the module logger like the rest of the file, not the root logger
            logger.warning(f'Could not load a dictionary from {filename}')
            return(False)
        return(True)

    def cli(self):
        """Prompt the user for every detector field."""
        print('\n -- Configure the detector -- ')
        self.set_single_attr_cli('prefix', 'detector ID')
        self.set_single_attr_cli('rows', 'number of pixel rows')
        self.set_single_attr_cli('columns', 'number of pixel columns')
        self.set_single_attr_cli('pixel_size', 'pixel size in mm (enter either a single value for '+
            'square pixels or a pair of values for the size in the respective row and column '+
            'directions)', list_flag=True)
        self.set_single_attr_cli('lens_magnification', 'lens magnification')

    def construct_nxdetector(self):
        """Return an `NXdetector` holding the detector ID and the effective pixel sizes.

        :return: NXdetector with `local_name`, `x_pixel_size` and `y_pixel_size` (units 'mm')
        """
        nxdetector = NXdetector()
        nxdetector.local_name = self.prefix
        pixel_size = self.get_pixel_size
        # A single value describes square pixels: first and last entry are then the same
        nxdetector.x_pixel_size = pixel_size[0]
        nxdetector.y_pixel_size = pixel_size[-1]
        nxdetector.x_pixel_size.attrs['units'] = 'mm'
        nxdetector.y_pixel_size.attrs['units'] = 'mm'
        return(nxdetector)
359
360
class ScanInfo(TypedDict):
    """Image range and reference position used from a single scan (one `stack_info` entry)."""
    # Number of the scan this entry belongs to
    scan_number: int
    # Index of the first image used from the scan
    starting_image_offset: conint(ge=0)
    # Number of images used from the scan
    num_image: conint(gt=0)
    # Reference x and z translations of the scan
    ref_x: float
    ref_z: float
367
class SpecScans(BaseModel):
    """A selection of scans from a single SPEC file and the image range used from each scan.

    Fields:
        spec_file: path of an existing SPEC file
        scan_numbers: the selected scan numbers (at least one)
        stack_info: one `ScanInfo` entry per selected scan
    """
    spec_file: FilePath
    scan_numbers: conlist(item_type=conint(gt=0), min_items=1)
    stack_info: conlist(item_type=ScanInfo, min_items=1) = []

    @validator('spec_file')
    def validate_spec_file(cls, spec_file):
        """Check that `spec_file` can be opened as a SPEC file and return its absolute path.

        :raises ValueError: if `FileSpec` cannot read the file (this includes the case where
            the optional `pyspec` import at the top of the file failed)
        """
        try:
            spec_file = os.path.abspath(spec_file)
            FileSpec(spec_file)
        except Exception as exc:
            raise ValueError(f'Invalid SPEC file {spec_file}') from exc
        return(spec_file)

    @validator('scan_numbers')
    def validate_scan_numbers(cls, scan_numbers, values):
        """Check that every scan number exists in the (already validated) SPEC file.

        :raises ValueError: if a scan number is not present in the SPEC file
        """
        spec_file = values.get('spec_file')
        if spec_file is not None:
            spec_scans = FileSpec(spec_file)
            for scan_number in scan_numbers:
                if spec_scans.get_scan_by_number(scan_number) is None:
                    raise ValueError(f'There is no scan number {scan_number} in {spec_file}')
        return(scan_numbers)

    @validator('stack_info')
    def validate_stack_info(cls, stack_info, values):
        """Check that `stack_info` holds one consistent entry per selected scan.

        :raises ValueError: if the number of entries differs from the number of scans or an
            entry refers to a scan that is not selected
        """
        scan_numbers = values.get('scan_numbers')
        # Explicit exceptions instead of assert statements, which are removed under `python -O`
        if scan_numbers is None:
            raise ValueError('Unable to validate stack_info without valid scan_numbers')
        if len(scan_numbers) != len(stack_info):
            raise ValueError(f'Number of stack_info entries ({len(stack_info)}) does not '+
                f'match the number of scans ({len(scan_numbers)})')
        for scan_info in stack_info:
            if scan_info['scan_number'] not in scan_numbers:
                raise ValueError(f'Scan number {scan_info["scan_number"]} in stack_info is '+
                    f'not one of the selected scans {scan_numbers}')
            # NOTE(review): this requires the starting offset to be smaller than the number of
            # images used; confirm that this is the intended upper bound.
            is_int(scan_info['starting_image_offset'], ge=0, lt=scan_info['num_image'],
                raise_error=True)
        return(stack_info)

    @staticmethod
    def _stack_from_nxcollection(nxcollection):
        """Return the sorted scan numbers and the stack info stored in an NXcollection.

        The scan number is taken from the part of each entry name after the last underscore.
        """
        scan_numbers = []
        stack_info = []
        for nxsubentry_name, nxsubentry in nxcollection.items():
            scan_number = int(nxsubentry_name.split('_')[-1])
            scan_numbers.append(scan_number)
            stack_info.append({'scan_number': scan_number,
                'starting_image_offset': int(nxsubentry.instrument.detector.frame_start_number),
                'num_image': len(nxsubentry.sample.rotation_angle),
                'ref_x': float(nxsubentry.sample.x_translation),
                'ref_z': float(nxsubentry.sample.z_translation)})
        return(sorted(scan_numbers), stack_info)

    @classmethod
    def construct_from_nxcollection(cls, nxcollection:NXcollection):
        """Construct a validated instance from an NXcollection as written by
        `construct_nxcollection`.

        :param nxcollection: collection with a `spec_file` attribute and one entry per scan
        :return: the constructed instance
        """
        scan_numbers, stack_info = cls._stack_from_nxcollection(nxcollection)
        return(cls(spec_file=nxcollection.attrs['spec_file'], scan_numbers=scan_numbers,
            stack_info=stack_info))

    @property
    def available_scan_numbers(self):
        """Scan numbers in the SPEC file that a `ScanParser` can handle (cached per file)."""
        return(get_available_scan_numbers(self.spec_file))

    def set_from_nxcollection(self, nxcollection:NXcollection):
        """Overwrite the fields of this instance with the content of an NXcollection.

        :param nxcollection: collection with a `spec_file` attribute and one entry per scan
        """
        scan_numbers, stack_info = self._stack_from_nxcollection(nxcollection)
        self.spec_file = nxcollection.attrs['spec_file']
        self.scan_numbers = scan_numbers
        self.stack_info = stack_info

    def get_scan_index(self, scan_number):
        """Return the index of the `stack_info` entry for `scan_number`.

        :return: the index, or None if there is no entry for this scan
        :raises ValueError: if there is more than one entry for this scan
        """
        scan_index = [scan_index for scan_index, scan_info in enumerate(self.stack_info)
            if scan_info['scan_number'] == scan_number]
        if len(scan_index) > 1:
            raise ValueError('Duplicate scan_numbers in image stack')
        if len(scan_index) == 1:
            return(scan_index[0])
        return(None)

    def get_scanparser(self, scan_number):
        """Return the (cached) `ScanParser` for `scan_number`, or None if it is unavailable."""
        return(get_scanparser(self.spec_file, scan_number))

    def get_detector_data(self, detector_prefix, scan_number=None, scan_step_index=None):
        """Read detector images for one or all of the selected scans.

        :param detector_prefix: detector ID passed to the scan parser
        :param scan_number: scan to read; all selected scans if None
        :param scan_step_index: index of a single image relative to the scan's starting image
            offset; the full image range of the scan is read if None
        :return: the data of a single image if both `scan_number` and `scan_step_index` are
            given, a list with one item per scan otherwise
        :raises ValueError: if a scan has no `stack_info` entry
        """
        if scan_number is None:
            scan_numbers = self.scan_numbers
        else:
            scan_numbers = [scan_number]
        image_stacks = []
        # Use a separate loop variable: rebinding `scan_number` made the single image branch
        # below fire even when no specific scan was requested
        for number in scan_numbers:
            parser = self.get_scanparser(number)
            scan_index = self.get_scan_index(number)
            if scan_index is None:
                raise ValueError(f'No image stack info available for scan number {number}')
            scan_info = self.stack_info[scan_index]
            image_offset = scan_info['starting_image_offset']
            if scan_step_index is None:
                num_image = scan_info['num_image']
                image_stacks.append(parser.get_detector_data(detector_prefix,
                    (image_offset, image_offset+num_image)))
            else:
                image_stacks.append(parser.get_detector_data(detector_prefix,
                    image_offset+scan_step_index))
        if scan_number is not None and scan_step_index is not None:
            # Return a single image for a specific scan_number and scan_step_index request
            return(image_stacks[0])
        # Return a list otherwise
        return(image_stacks)

    def scan_numbers_cli(self, attr_desc, **kwargs):
        """Interactively select the scan numbers.

        For stations 'id1a3' and 'id3a' the candidates are restricted by `scan_type`: 'ts1'
        keeps the available scans of that type, 'df1' and 'bf1' use the scans two respectively
        one before each of the `tomo_scan_numbers`. A single candidate is selected without
        prompting.

        :param attr_desc: description inserted in the prompts
        :param kwargs: optional `station`, `scan_type` and (for 'df1'/'bf1')
            `tomo_scan_numbers`
        :raises ValueError: if a 'df1'/'bf1' candidate is unavailable or of another scan type
        """
        available_scan_numbers = self.available_scan_numbers
        station = kwargs.get('station')
        if station in ('id1a3', 'id3a') and 'scan_type' in kwargs:
            scan_type = kwargs['scan_type']
            if scan_type == 'ts1':
                available_scan_numbers = []
                for scan_number in self.available_scan_numbers:
                    parser = self.get_scanparser(scan_number)
                    try:
                        if parser.scan_type == scan_type:
                            available_scan_numbers.append(scan_number)
                    except Exception:
                        # Skip scans for which the scan type cannot be determined
                        pass
            elif scan_type in ('df1', 'bf1'):
                # Dark field scans precede the tomography scan by two, bright field scans by one
                scan_offset = 2 if scan_type == 'df1' else 1
                available_scan_numbers = []
                for tomo_scan_number in kwargs['tomo_scan_numbers']:
                    scan_number = tomo_scan_number-scan_offset
                    parser = self.get_scanparser(scan_number)
                    if parser is None or parser.scan_type != scan_type:
                        raise ValueError(f'Scan {scan_number} in {self.spec_file} is not an '+
                            f'available {scan_type} scan')
                    available_scan_numbers.append(scan_number)
        if len(available_scan_numbers) == 1:
            input_mode = 1
        else:
            menu_options = [f'Select a subset of the available {attr_desc}scan numbers',
                f'Use all available {attr_desc}scan numbers in {self.spec_file}']
            if hasattr(self, 'scan_numbers'):
                print(f'Currently selected {attr_desc}scan numbers are: {self.scan_numbers}')
                menu_options.append(f'Keep the currently selected {attr_desc}scan numbers')
            print(f'Available scan numbers in {self.spec_file} are: '+
                f'{available_scan_numbers}')
            input_mode = input_menu(menu_options, header='Choose one of the following options '+
                'for selecting scan numbers')
        if input_mode == 0:
            accept_scan_numbers = False
            while not accept_scan_numbers:
                try:
                    self.scan_numbers = \
                        input_int_list(f'Enter a series of {attr_desc}scan numbers')
                except ValidationError as e:
                    print(e)
                except Exception as e:
                    # KeyboardInterrupt is not an Exception and aborts the prompt
                    print(f'Unexpected {type(e).__name__}: {e}')
                else:
                    accept_scan_numbers = True
        elif input_mode == 1:
            self.scan_numbers = available_scan_numbers
        # input_mode == 2: keep the currently selected scan numbers

    def cli(self, **cli_kwargs):
        """Interactively select the SPEC file and the scan numbers.

        :param cli_kwargs: optional `attr_desc` used in the prompts
        """
        if cli_kwargs.get('attr_desc') is not None:
            attr_desc = f'{cli_kwargs["attr_desc"]} '
        else:
            attr_desc = ''
        print(f'\n -- Configure which scans to use from a single {attr_desc}SPEC file')
        self.set_single_attr_cli('spec_file', attr_desc+'SPEC file path')
        self.scan_numbers_cli(attr_desc)

    def construct_nxcollection(self, image_key, thetas, detector):
        """Return an NXcollection describing the selected scans.

        The collection holds one NXsubentry named `scan_<scan_number>` per scan, with the
        detector, the first frame number, the image key, the rotation angles and the reference
        translations of that scan.

        :param image_key: value stored as the `image_key` of each scan's detector
        :param thetas: rotation angles stored for every scan; if None, zeros are used (as many
            as there are images in the first scan)
        :param detector: `Detector` used to construct the NXdetector of each scan
        :return: the NXcollection
        :raises ValueError: if the number of images of a scan differs from the number of
            rotation angles
        """
        nxcollection = NXcollection()
        nxcollection.attrs['spec_file'] = str(self.spec_file)
        parser = self.get_scanparser(self.scan_numbers[0])
        nxcollection.attrs['date'] = parser.spec_scan.file_date
        for scan_number in self.scan_numbers:
            # Get scan info
            scan_info = self.stack_info[self.get_scan_index(scan_number)]
            # Add an NXsubentry to the NXcollection for each scan
            entry_name = f'scan_{scan_number}'
            nxsubentry = NXsubentry()
            nxcollection[entry_name] = nxsubentry
            parser = self.get_scanparser(scan_number)
            nxsubentry.start_time = parser.spec_scan.date
            nxsubentry.spec_command = parser.spec_command
            num_image = scan_info['num_image']
            if thetas is None:
                thetas = num_image*[0.0]
            elif num_image != len(thetas):
                raise ValueError(f'Number of images ({num_image}) in scan {scan_number} does '+
                    f'not match the number of rotation angles ({len(thetas)})')
            # Add an NXinstrument to the scan's NXsubentry
            nxsubentry.instrument = NXinstrument()
            # Add an NXdetector to the NXinstrument to the scan's NXsubentry
            nxsubentry.instrument.detector = detector.construct_nxdetector()
            nxsubentry.instrument.detector.frame_start_number = \
                scan_info['starting_image_offset']
            nxsubentry.instrument.detector.image_key = image_key
            # Add an NXsample to the scan's NXsubentry
            nxsubentry.sample = NXsample()
            nxsubentry.sample.rotation_angle = thetas
            nxsubentry.sample.rotation_angle.units = 'degrees'
            nxsubentry.sample.x_translation = scan_info['ref_x']
            nxsubentry.sample.x_translation.units = 'mm'
            nxsubentry.sample.z_translation = scan_info['ref_z']
            nxsubentry.sample.z_translation.units = 'mm'
        return(nxcollection)
590
591
class FlatField(SpecScans):
    """Scan selection for flat field (dark or bright field) images."""

    def image_range_cli(self, attr_desc, detector_prefix):
        """Interactively select the image range used from each selected scan.

        For every scan the user can keep the current range (if the scan already has a
        `stack_info` entry), accept the default range reported by the scan parser or enter a
        starting offset and a number of images by hand.

        NOTE(review): the indentation of this method was lost in the available source. The
        nesting below is the reconstruction in which `offset` and `num` are defined on every
        path that uses them: an accepted current range is left unchanged.

        :param attr_desc: description for the prompts (not used by this method)
        :param detector_prefix: detector ID; its upper case form is passed to the scan parser
        """
        stack_info = self.stack_info
        for scan_number in self.scan_numbers:
            # Parse the available image range
            parser = self.get_scanparser(scan_number)
            image_offset = parser.starting_image_offset
            num_image = parser.get_num_image(detector_prefix.upper())
            scan_index = self.get_scan_index(scan_number)

            # Select the image set
            last_image_index = image_offset+num_image
            print(f'Available good image set index range: [{image_offset}, {last_image_index})')
            image_set_approved = False
            if scan_index is not None:
                # The scan already has a stack_info entry: offer to keep it
                scan_info = stack_info[scan_index]
                print(f'Current starting image offset and number of images: '+
                    f'{scan_info["starting_image_offset"]} and {scan_info["num_image"]}')
                image_set_approved = input_yesno(f'Accept these values (y/n)?', 'y')
            if not image_set_approved:
                # Offer the full range reported by the scan parser
                print(f'Default starting image offset and number of images: '+
                    f'{image_offset} and {num_image}')
                image_set_approved = input_yesno(f'Accept these values (y/n)?', 'y')
                if image_set_approved:
                    offset = image_offset
                    num = last_image_index-offset
                # Otherwise ask for a range until the user accepts it
                while not image_set_approved:
                    offset = input_int(f'Enter the starting image offset', ge=image_offset,
                        lt=last_image_index)#, default=image_offset)
                    num = input_int(f'Enter the number of images', ge=1,
                        le=last_image_index-offset)#, default=last_image_index-offset)
                    print(f'Current starting image offset and number of images: {offset} and {num}')
                    image_set_approved = input_yesno(f'Accept these values (y/n)?', 'y')
                if scan_index is not None:
                    scan_info['starting_image_offset'] = offset
                    scan_info['num_image'] = num
                    scan_info['ref_x'] = parser.horizontal_shift
                    scan_info['ref_z'] = parser.vertical_shift
                else:
                    stack_info.append({'scan_number': scan_number, 'starting_image_offset': offset,
                        'num_image': num, 'ref_x': parser.horizontal_shift,
                        'ref_z': parser.vertical_shift})
        # Assign the list so that stack_info is validated
        self.stack_info = stack_info

    def cli(self, **cli_kwargs):
        """Interactively configure the flat field scans.

        For stations 'id1a3' and 'id3a' the SPEC file is taken from `cli_kwargs` and the scan
        numbers are derived from the tomography scan numbers; otherwise the user is prompted
        for the SPEC file and the scan numbers.

        :param cli_kwargs: `detector` (an object with a `prefix` attribute) is always read;
            `attr_desc` and `station` are optional; `spec_file`, `tomo_scan_numbers` and
            `scan_type` are required for stations 'id1a3' and 'id3a'
        """
        if cli_kwargs.get('attr_desc') is not None:
            attr_desc = f'{cli_kwargs["attr_desc"]} '
        else:
            attr_desc = ''
        station = cli_kwargs.get('station')
        detector = cli_kwargs.get('detector')
        print(f'\n -- Configure the location of the {attr_desc}scan data -- ')
        if station in ('id1a3', 'id3a'):
            self.spec_file = cli_kwargs['spec_file']
            tomo_scan_numbers = cli_kwargs['tomo_scan_numbers']
            scan_type = cli_kwargs['scan_type']
            self.scan_numbers_cli(attr_desc, station=station, tomo_scan_numbers=tomo_scan_numbers,
                scan_type=scan_type)
        else:
            self.set_single_attr_cli('spec_file', attr_desc+'SPEC file path')
            self.scan_numbers_cli(attr_desc)
        self.image_range_cli(attr_desc, detector.prefix)
655
656
class TomoField(SpecScans):
    """Scan selection for the tomography images, including the rotation angle range.

    Fields (in addition to those of `SpecScans`):
        theta_range: dictionary with the keys `start`, `end` and `num` and optionally
            `start_index` (index of `start` among the rotation angles of the scan)
    """
    theta_range: dict = {}

    @validator('theta_range')
    def validate_theta_range(cls, theta_range):
        """Check the keys and values of `theta_range`.

        :raises ValueError: if the dictionary does not have three or four items, a required key
            is missing or `end` is not larger than `start`
        """
        if len(theta_range) != 3 and len(theta_range) != 4:
            raise ValueError(f'Invalid theta range {theta_range}')
        # A missing key would otherwise escape as a KeyError, which pydantic does not convert
        # to a validation error
        if any(key not in theta_range for key in ('start', 'end', 'num')):
            raise ValueError(f'Invalid theta range {theta_range}')
        is_num(theta_range['start'], raise_error=True)
        is_num(theta_range['end'], raise_error=True)
        is_int(theta_range['num'], gt=1, raise_error=True)
        if theta_range['end'] <= theta_range['start']:
            raise ValueError(f'Invalid theta range {theta_range}')
        if 'start_index' in theta_range:
            is_int(theta_range['start_index'], ge=0, raise_error=True)
        return(theta_range)

    @classmethod
    def construct_from_nxcollection(cls, nxcollection:NXcollection):
        """Construct a validated instance from an NXcollection.

        The SPEC file, scan numbers and stack info are read by the `SpecScans` constructor. The
        theta range is taken from the rotation angles of the first entry whose name contains
        'scan_'.

        :param nxcollection: collection with a `spec_file` attribute and one entry per scan
        :return: the constructed instance
        """
        obj = super().construct_from_nxcollection(nxcollection)
        for name in nxcollection.entries:
            if 'scan_' in name:
                thetas = np.asarray(nxcollection[name].sample.rotation_angle)
                # Store plain Python numbers (as theta_range_cli does), not numpy scalars
                obj.theta_range = {'start': float(thetas[0]), 'end': float(thetas[-1]),
                    'num': int(thetas.size)}
                break
        return(obj)

    def _get_shifts(self, shift_name, scan_number=None):
        """Return the parser attribute `shift_name` for one or all of the selected scans.

        A single value is returned if only one scan is involved, a list otherwise.
        """
        if scan_number is None:
            scan_numbers = self.scan_numbers
        else:
            scan_numbers = [scan_number]
        shifts = [getattr(self.get_scanparser(number), shift_name) for number in scan_numbers]
        if len(shifts) == 1:
            return(shifts[0])
        return(shifts)

    def get_horizontal_shifts(self, scan_number=None):
        """Return the horizontal shift reported by the scan parser.

        :param scan_number: scan to use; all selected scans if None
        :return: a single value if only one scan is involved, a list otherwise
        """
        return(self._get_shifts('horizontal_shift', scan_number))

    def get_vertical_shifts(self, scan_number=None):
        """Return the vertical shift reported by the scan parser.

        :param scan_number: scan to use; all selected scans if None
        :return: a single value if only one scan is involved, a list otherwise
        """
        return(self._get_shifts('vertical_shift', scan_number))

    def theta_range_cli(self, scan_number, attr_desc, station):
        """Interactively select the theta range, or check a scan against the first scan.

        For the first selected scan the user selects the theta range, which is stored in
        `theta_range`. For any other scan only the consistency of its theta values with those
        of the first scan is checked.

        :param scan_number: number of the scan
        :param attr_desc: description used in the error message
        :param station: station identifier, selects how the range is proposed
        :raises ValueError: if the theta values differ from those of the first scan, or if the
            scan has fewer than two theta values
        """
        # Parse the available theta range
        parser = self.get_scanparser(scan_number)
        theta_vals = parser.theta_vals
        spec_theta_start = theta_vals.get('start')
        spec_theta_end = theta_vals.get('end')
        spec_num_theta = theta_vals.get('num')

        # Check for consistency of theta ranges between scans
        if scan_number != self.scan_numbers[0]:
            parser = self.get_scanparser(self.scan_numbers[0])
            if (parser.theta_vals.get('start') != spec_theta_start or
                    parser.theta_vals.get('end') != spec_theta_end or
                    parser.theta_vals.get('num') != spec_num_theta):
                raise ValueError(f'Incompatible theta ranges between {attr_desc}scans:'+
                    f'\n\tScan {scan_number}: {theta_vals}'+
                    f'\n\tScan {self.scan_numbers[0]}: {parser.theta_vals}')
            return

        # Select the theta range for the tomo reconstruction from the first scan
        if spec_num_theta < 2:
            # The step size below needs two values
            raise ValueError(f'Invalid theta range {theta_vals}')
        theta_range_approved = False
        thetas = np.linspace(spec_theta_start, spec_theta_end, spec_num_theta)
        delta_theta = thetas[1]-thetas[0]
        print(f'Theta range obtained from SPEC data: [{spec_theta_start}, {spec_theta_end}]')
        print(f'Theta step size = {delta_theta}')
        print(f'Number of theta values: {spec_num_theta}')
        default_start = None
        default_end = None
        if station in ('id1a3', 'id3a'):
            theta_range_approved = input_yesno(f'Accept this theta range (y/n)?', 'y')
            if theta_range_approved:
                self.theta_range = {'start': float(spec_theta_start),
                    'end': float(spec_theta_end), 'num': int(spec_num_theta), 'start_index': 0}
                return
        elif station == 'id3b':
            # Equality, not `station in ('id3b')`: that is a substring test on a string which
            # raises a TypeError when no station is given
            if spec_theta_start <= 0.0 and spec_theta_end >= 180.0:
                default_start = 0
                default_end = 180
            elif spec_theta_end-spec_theta_start == 180:
                default_start = spec_theta_start
                default_end = spec_theta_end
        while not theta_range_approved:
            theta_start = input_num(f'Enter the first theta (included)', ge=spec_theta_start,
                lt=spec_theta_end, default=default_start)
            theta_index_start = index_nearest(thetas, theta_start)
            theta_start = thetas[theta_index_start]
            theta_end = input_num(f'Enter the last theta (excluded)',
                ge=theta_start+delta_theta, le=spec_theta_end, default=default_end)
            theta_index_end = index_nearest(thetas, theta_end)
            theta_end = thetas[theta_index_end]
            num_theta = theta_index_end-theta_index_start
            print(f'Selected theta range: [{theta_start}, {theta_start+delta_theta}, ..., '+
                f'{theta_end})')
            print(f'Number of theta values: {num_theta}')
            theta_range_approved = input_yesno(f'Accept this theta range (y/n)?', 'y')
        self.theta_range = {'start': float(theta_start), 'end': float(theta_end),
            'num': int(num_theta), 'start_index': int(theta_index_start)}

    def image_range_cli(self, attr_desc, detector_prefix):
        """Set the image range of every selected scan to match the selected theta range.

        :param attr_desc: description used in the error message
        :param detector_prefix: detector ID; its upper case form is passed to the scan parser
        :raises ValueError: if a scan has fewer images than required by the theta range
        """
        stack_info = self.stack_info
        num_theta = self.theta_range['num']
        # start_index is optional in a valid theta_range (e.g. when read from an NXcollection)
        theta_index_start = self.theta_range.get('start_index', 0)
        for scan_number in self.scan_numbers:
            # Parse the available image range
            parser = self.get_scanparser(scan_number)
            image_offset = parser.starting_image_offset
            num_image = parser.get_num_image(detector_prefix.upper())
            scan_index = self.get_scan_index(scan_number)

            # Select the image set matching the theta range
            if num_theta > num_image-theta_index_start:
                raise ValueError(f'Available {attr_desc}image indices incompatible with thetas:'+
                    f'\n\tNumber of thetas and offset = {num_theta} and {theta_index_start}'+
                    f'\n\tNumber of available images {num_image}')
            if scan_index is not None:
                scan_info = stack_info[scan_index]
                scan_info['starting_image_offset'] = image_offset+theta_index_start
                scan_info['num_image'] = num_theta
                scan_info['ref_x'] = parser.horizontal_shift
                scan_info['ref_z'] = parser.vertical_shift
            else:
                stack_info.append({'scan_number': scan_number,
                    'starting_image_offset': image_offset+theta_index_start,
                    'num_image': num_theta, 'ref_x': parser.horizontal_shift,
                    'ref_z': parser.vertical_shift})
        # Assign the list so that stack_info is validated
        self.stack_info = stack_info

    def cli(self, **cli_kwargs):
        """Interactively configure the tomography scans.

        For stations 'id1a3' and 'id3a' the SPEC file is `spec.log` in a sample directory below
        `/nfs/chess/<station>/<cycle>/<btr>`: the directory named `sample_name` if it exists,
        one chosen from a menu otherwise. For other stations the user is prompted for the SPEC
        file.

        :param cli_kwargs: `detector` (an object with a `prefix` attribute) is always read;
            `attr_desc`, `station`, `cycle`, `btr` and `sample_name` are optional
        """
        if cli_kwargs.get('attr_desc') is not None:
            attr_desc = f'{cli_kwargs["attr_desc"]} '
        else:
            attr_desc = ''
        cycle = cli_kwargs.get('cycle')
        btr = cli_kwargs.get('btr')
        station = cli_kwargs.get('station')
        detector = cli_kwargs.get('detector')
        sample_name = cli_kwargs.get('sample_name')
        print(f'\n -- Configure the location of the {attr_desc}scan data -- ')
        if station in ('id1a3', 'id3a'):
            basedir = f'/nfs/chess/{station}/{cycle}/{btr}'
            runs = [d for d in os.listdir(basedir) if os.path.isdir(os.path.join(basedir, d))]
            if sample_name is not None and sample_name in runs:
                index = runs.index(sample_name)
            else:
                index = input_menu(runs, header='Choose a sample directory')
            self.spec_file = f'{basedir}/{runs[index]}/spec.log'
            self.scan_numbers_cli(attr_desc, station=station, scan_type='ts1')
        else:
            self.set_single_attr_cli('spec_file', attr_desc+'SPEC file path')
            self.scan_numbers_cli(attr_desc)
        for scan_number in self.scan_numbers:
            self.theta_range_cli(scan_number, attr_desc, station)
        self.image_range_cli(attr_desc, detector.prefix)
840
841
class Sample(BaseModel):
    """Sample metadata: a name, an optional description and optional
    lists of rotation angles and x/z translations.
    """
    name: constr(min_length=1)
    description: Optional[str]
    rotation_angles: Optional[list]
    x_translations: Optional[list]
    z_translations: Optional[list]

    @classmethod
    def construct_from_nxsample(cls, nxsample:NXsample):
        """Create a Sample from the fields found in an NXsample group.

        ``name`` is always read; ``description``, ``rotation_angle``,
        ``x_translation`` and ``z_translation`` are read only when present.
        """
        config = {'name': nxsample.name.nxdata}
        if 'description' in nxsample:
            config['description'] = nxsample.description.nxdata
        # The NeXus fields use singular names while the model attributes are
        # plural; store each value under the declared attribute name so it is
        # not lost, and convert it to a plain list to match Optional[list].
        field_map = (('rotation_angle', 'rotation_angles'),
                ('x_translation', 'x_translations'),
                ('z_translation', 'z_translations'))
        for nxfield_name, attr_name in field_map:
            if nxfield_name in nxsample:
                config[attr_name] = np.atleast_1d(nxsample[nxfield_name].nxdata).tolist()
        return(cls(**config))

    def cli(self):
        """Prompt for the sample name and the (optional) sample description."""
        print('\n -- Configure the sample metadata -- ')
        self.set_single_attr_cli('name', 'the sample name')
        self.set_single_attr_cli('description', 'a description of the sample (optional)')
870
871
class MapConfig(BaseModel):
    """Configuration of a single map: run information (cycle, BTR, station),
    the sample, the detector and the dark, bright and tomo field scans.
    """
    cycle: constr(strip_whitespace=True, min_length=1)
    btr: constr(strip_whitespace=True, min_length=1)
    title: constr(strip_whitespace=True, min_length=1)
    station: Literal['id1a3', 'id3a', 'id3b'] = None
    sample: Sample
    detector: Detector = Detector.construct()
    tomo_fields: TomoField
    dark_field: Optional[FlatField]
    bright_field: FlatField
    # Lazily filled by the `thetas` property
    _thetas: list[float] = PrivateAttr()
    # Attribute name of each field type and the image_key value written for
    # its images (presumably the NXtomo image_key convention -- confirm)
    _field_types = ({'name': 'dark_field', 'image_key': 2}, {'name': 'bright_field',
            'image_key': 1}, {'name': 'tomo_fields', 'image_key': 0})

    @classmethod
    def construct_from_nxentry(cls, nxentry:NXentry):
        """Create a MapConfig from an NXentry.

        The cycle, BTR and station come from the NXsource attributes, the
        title is the entry name, the sample is built from ``nxentry['sample']``
        and every NXcollection in ``nxentry.spec_scans`` is converted with
        ``SpecScans.construct_from_nxcollection`` and stored under its name.
        """
        config = {}
        source_attrs = nxentry.instrument.source.attrs
        config['cycle'] = source_attrs['cycle']
        config['btr'] = source_attrs['btr']
        config['title'] = nxentry.nxname
        config['station'] = source_attrs['station']
        config['sample'] = Sample.construct_from_nxsample(nxentry['sample'])
        for nxobject_name, nxobject in nxentry.spec_scans.items():
            if isinstance(nxobject, NXcollection):
                config[nxobject_name] = SpecScans.construct_from_nxcollection(nxobject)
        return(cls(**config))

    #FIX cache?
    @property
    def thetas(self):
        """List of rotation angles spanning the tomo field theta range.

        Computed with ``np.linspace`` from the ``start``, ``end`` and ``num``
        entries of ``tomo_fields.theta_range`` on first access and stored in
        the private ``_thetas`` attribute afterwards.
        """
        try:
            return(self._thetas)
        except AttributeError:
            # Private attribute not set yet: compute and store it. Only
            # AttributeError is handled so that unrelated errors (and
            # KeyboardInterrupt) are no longer swallowed.
            theta_range = self.tomo_fields.theta_range
            self._thetas = list(np.linspace(theta_range['start'], theta_range['end'],
                    theta_range['num']))
            return(self._thetas)

    def cli(self):
        """Interactively configure the map: run information, sample,
        detector and the tomo, dark and bright field scans.
        """
        print('\n -- Configure a map from a set of SPEC scans (dark, bright, and tomo), '+
                'and / or detector data -- ')
        self.set_single_attr_cli('cycle', 'beam cycle')
        self.set_single_attr_cli('btr', 'BTR')
        self.set_single_attr_cli('title', 'title for the map entry')
        self.set_single_attr_cli('station', 'name of the station at which scans were collected '+
                '(currently choose from: id1a3, id3a, id3b)')
        # The scan parser class depends on the station
        import_scanparser(self.station)
        self.set_single_attr_cli('sample')

        # Detector: keep the current settings, load a known configuration
        # file, or fall back to entering the detector settings by hand
        use_detector_config = False
        if hasattr(self.detector, 'prefix') and len(self.detector.prefix):
            use_detector_config = input_yesno(f'Current detector settings:\n{self.detector}\n'+
                    'Keep these settings? (y/n)')
        if not use_detector_config:
            menu_options = ['not listed', 'andor2', 'manta', 'retiga']
            input_mode = input_menu(menu_options, header='Choose one of the following detector '+
                    'configuration options')
            # Index 0 is 'not listed': there is no configuration file to load
            if input_mode:
                detector_config_file = f'{menu_options[input_mode]}.yaml'
                have_detector_config = self.detector.construct_from_yaml(detector_config_file)
            else:
                have_detector_config = False
            if not have_detector_config:
                self.set_single_attr_cli('detector', 'detector')

        self.set_single_attr_cli('tomo_fields', 'Tomo field', chain_attr_desc=True,
                cycle=self.cycle, btr=self.btr, station=self.station, detector=self.detector,
                sample_name=self.sample.name)
        if self.station in ('id1a3', 'id3a'):
            # For these stations the dark and bright fields are configured
            # with the spec file of the tomo fields
            have_dark_field = True
            tomo_spec_file = self.tomo_fields.spec_file
        else:
            have_dark_field = input_yesno('Are Dark field images available? (y/n)')
            tomo_spec_file = None
        if have_dark_field:
            self.set_single_attr_cli('dark_field', 'Dark field', chain_attr_desc=True,
                    station=self.station, detector=self.detector, spec_file=tomo_spec_file,
                    tomo_scan_numbers=self.tomo_fields.scan_numbers, scan_type='df1')
        self.set_single_attr_cli('bright_field', 'Bright field', chain_attr_desc=True,
                station=self.station, detector=self.detector, spec_file=tomo_spec_file,
                tomo_scan_numbers=self.tomo_fields.scan_numbers, scan_type='bf1')

    def construct_nxentry(self, nxroot, include_raw_data=True):
        """Add an NXentry for this map to `nxroot` under the map title.

        :param nxroot: NXroot object to which the entry is added; its
            ``default`` attribute is set to the map title.
        :param include_raw_data: when True the detector images are read and
            written to the NXdetector, together with the per-image NXsample
            fields and the NXdata links.
        :raises ValueError: if the number of images of a tomo field scan
            differs from the number of thetas.
        """
        # Construct base NXentry
        nxentry = NXentry()

        # Add an NXentry to the NXroot
        nxroot[self.title] = nxentry
        nxroot.attrs['default'] = self.title
        nxentry.definition = 'NXtomo'

        # Add an NXinstrument to the NXentry
        nxinstrument = NXinstrument()
        nxentry.instrument = nxinstrument

        # Add an NXsource to the NXinstrument
        nxsource = NXsource()
        nxinstrument.source = nxsource
        nxsource.type = 'Synchrotron X-ray Source'
        nxsource.name = 'CHESS'
        nxsource.probe = 'x-ray'

        # Tag the NXsource with the runinfo (as an attribute)
        nxsource.attrs['cycle'] = self.cycle
        nxsource.attrs['btr'] = self.btr
        nxsource.attrs['station'] = self.station

        # Add an NXdetector to the NXinstrument (don't fill in data fields yet)
        nxinstrument.detector = self.detector.construct_nxdetector()

        # Add an NXsample to NXentry (don't fill in data fields yet)
        nxsample = NXsample()
        nxentry.sample = nxsample
        nxsample.name = self.sample.name
        nxsample.description = self.sample.description

        # Add an NXcollection to the base NXentry to hold metadata about the spec scans in the map
        # Also obtain the data fields in NXsample and NXdetector
        nxspec_scans = NXcollection()
        nxentry.spec_scans = nxspec_scans
        image_keys = []
        sequence_numbers = []
        image_stacks = []
        rotation_angles = []
        x_translations = []
        z_translations = []
        for field_type in self._field_types:
            field_name = field_type['name']
            field = getattr(self, field_name)
            if field is None:
                # Optional field (e.g. no dark field available)
                continue
            image_key = field_type['image_key']
            # Only the tomo fields have rotation angles
            thetas = self.thetas if field_name == 'tomo_fields' else None
            # Add the scans in a single spec file
            nxspec_scans[field_name] = field.construct_nxcollection(image_key, thetas,
                    self.detector)
            if include_raw_data:
                image_stacks += field.get_detector_data(self.detector.prefix)
            for scan_number in field.scan_numbers:
                # NOTE(review): the parser is not used below; the call is kept in
                # case get_scanparser has side effects -- confirm and remove.
                parser = field.get_scanparser(scan_number)
                scan_info = field.stack_info[field.get_scan_index(scan_number)]
                num_image = scan_info['num_image']
                image_keys += num_image*[image_key]
                sequence_numbers += list(range(num_image))
                if thetas is None:
                    rotation_angles += num_image*[0.0]
                else:
                    # Explicit check instead of an assert, which is stripped
                    # when Python runs with -O
                    if num_image != len(thetas):
                        raise ValueError(f'Number of images ({num_image}) in scan '+
                                f'{scan_number} of {field_name} does not match the number '+
                                f'of thetas ({len(thetas)})')
                    rotation_angles += thetas
                x_translations += num_image*[scan_info['ref_x']]
                z_translations += num_image*[scan_info['ref_z']]

        if include_raw_data:
            # Add image data to NXdetector
            nxinstrument.detector.image_key = image_keys
            nxinstrument.detector.sequence_number = sequence_numbers
            nxinstrument.detector.data = np.concatenate(image_stacks)

            # Add image data to NXsample
            nxsample.rotation_angle = rotation_angles
            nxsample.rotation_angle.attrs['units'] = 'degrees'
            nxsample.x_translation = x_translations
            nxsample.x_translation.attrs['units'] = 'mm'
            nxsample.z_translation = z_translations
            nxsample.z_translation.attrs['units'] = 'mm'

            # Add an NXdata to NXentry
            nxdata = NXdata()
            nxentry.data = nxdata
            nxdata.makelink(nxentry.instrument.detector.data, name='data')
            nxdata.makelink(nxentry.instrument.detector.image_key)
            nxdata.makelink(nxentry.sample.rotation_angle)
            nxdata.makelink(nxentry.sample.x_translation)
            nxdata.makelink(nxentry.sample.z_translation)
1064
1065
class TomoWorkflow(BaseModel):
    """Workflow configuration holding one or more sample maps."""
    sample_maps: conlist(item_type=MapConfig, min_items=1) = [MapConfig.construct()]

    @classmethod
    def construct_from_nexus(cls, filename):
        """Load `filename` with ``nxload`` and build one MapConfig for each
        item found in the loaded root.
        """
        nxroot = nxload(filename)
        maps = [MapConfig.construct_from_nxentry(entry) for _, entry in nxroot.items()]
        return(cls(sample_maps=maps))

    def cli(self):
        """Interactively configure the list of sample maps."""
        print('\n -- Configure a map -- ')
        self.set_list_attr_cli('sample_maps', 'sample map')

    def construct_nxfile(self, filename, mode='w-'):
        """Construct the NXentry of every sample map in a new NXroot and
        save it to `filename` with the given file `mode`.
        """
        t0 = time()
        nxroot = NXroot()
        for sample_map in self.sample_maps:
            logger.info(f'Start constructing the {sample_map.title} map.')
            # The scan parser class depends on the station of each map
            import_scanparser(sample_map.station)
            sample_map.construct_nxentry(nxroot)
        logger.info(f'Constructed all sample maps in {time()-t0:.2f} seconds.')
        logger.info(f'Start saving all sample maps to {filename}.')
        nxroot.save(filename, mode=mode)

    def write_to_nexus(self, filename):
        """Construct and save all sample maps to `filename` using file mode
        'w', and log the elapsed time.
        """
        t0 = time()
        self.construct_nxfile(filename, mode='w')
        logger.info(f'Saved all sample maps to {filename} in {time()-t0:.2f} seconds.')