Mercurial > repos > rv43 > test_tomo_reconstruct
comparison workflow/models.py @ 0:98e23dff1de2 draft default tip
planemo upload for repository https://github.com/rolfverberg/galaxytools commit f8c4bdb31c20c468045ad5e6eb255a293244bc6c-dirty
| author | rv43 |
|---|---|
| date | Tue, 21 Mar 2023 16:22:42 +0000 |
| parents | |
| children | |
comparison
equal
deleted
inserted
replaced
| -1:000000000000 | 0:98e23dff1de2 |
|---|---|
| 1 #!/usr/bin/env python3 | |
| 2 | |
| 3 import logging | |
| 4 logger = logging.getLogger(__name__) | |
| 5 | |
| 6 import logging | |
| 7 | |
| 8 import numpy as np | |
| 9 import os | |
| 10 import yaml | |
| 11 | |
| 12 from functools import cache | |
| 13 from pathlib import PosixPath | |
| 14 from pydantic import BaseModel as PydanticBaseModel | |
| 15 from pydantic import validator, ValidationError, conint, confloat, constr, conlist, FilePath, \ | |
| 16 PrivateAttr | |
| 17 from nexusformat.nexus import * | |
| 18 from time import time | |
| 19 from typing import Optional, Literal | |
| 20 from typing_extensions import TypedDict | |
| 21 try: | |
| 22 from pyspec.file.spec import FileSpec | |
| 23 except: | |
| 24 pass | |
| 25 | |
| 26 try: | |
| 27 from msnctools.general import is_int, is_num, input_int, input_int_list, input_num, \ | |
| 28 input_yesno, input_menu, index_nearest, string_to_list, file_exists_and_readable | |
| 29 except: | |
| 30 from general import is_int, is_num, input_int, input_int_list, input_num, \ | |
| 31 input_yesno, input_menu, index_nearest, string_to_list, file_exists_and_readable | |
| 32 | |
| 33 | |
def import_scanparser(station):
    """Bind the module-global ``ScanParser`` to the scan parser class for a station.

    The class is looked up first in ``msnctools.scanparsers`` and then in a local
    ``scanparsers`` module. If neither import succeeds, a warning is logged and
    ``ScanParser`` is left untouched (best-effort, as before).

    :param station: Station name: 'id1a3' or 'id3a' select ``SMBRotationScanParser``,
        'id3b' selects ``FMBRotationScanParser``.
    :raises RuntimeError: If ``station`` is not one of the supported names.
    """
    import importlib

    if station in ('id1a3', 'id3a'):
        class_name = 'SMBRotationScanParser'
    # One-element tuple: a bare ('id3b') is a string, and ``in`` would then accept
    # any substring of it such as 'id3' or ''.
    elif station in ('id3b',):
        class_name = 'FMBRotationScanParser'
    else:
        raise RuntimeError(f'Invalid station: {station}')

    for module_name in ('msnctools.scanparsers', 'scanparsers'):
        try:
            module = importlib.import_module(module_name)
            parser_class = getattr(module, class_name)
        except (ImportError, AttributeError):
            continue
        globals()['ScanParser'] = parser_class
        return
    logger.warning(f'Could not import {class_name} for station {station}')
| 57 | |
@cache
def get_available_scan_numbers(spec_file:str):
    """Return the scan numbers in ``spec_file`` for which a ``ScanParser`` can be built.

    Scans whose ``ScanParser`` construction raises an exception are skipped.
    The result is cached per ``spec_file`` for the lifetime of the process, so
    the same list object is returned every time; callers must not mutate it.
    """
    def parsable(scan_number):
        # Only construction failures exclude a scan; KeyboardInterrupt and
        # SystemExit are no longer swallowed.
        try:
            ScanParser(spec_file, scan_number)
        except Exception:
            return False
        return True

    scan_numbers = FileSpec(spec_file).scans.keys()
    return [scan_number for scan_number in scan_numbers if parsable(scan_number)]
| 72 | |
@cache
def get_scanparser(spec_file:str, scan_number:int):
    """Return a cached ``ScanParser`` for ``scan_number`` in ``spec_file``.

    Returns None when the scan is not among those reported by
    ``get_available_scan_numbers``.
    """
    is_available = scan_number in get_available_scan_numbers(spec_file)
    return ScanParser(spec_file, scan_number) if is_available else None
| 79 | |
| 80 | |
class BaseModel(PydanticBaseModel):
    """Pydantic base model with YAML/NeXus I/O helpers and interactive CLI setters.

    ``validate_assignment`` is enabled, so every ``setattr`` made by the CLI
    helpers below is validated against the field's declared type.
    """

    class Config:
        validate_assignment = True
        arbitrary_types_allowed = True

    @classmethod
    def construct_from_cli(cls):
        """Create an unvalidated instance and populate it interactively via ``cli``."""
        obj = cls.construct()
        obj.cli()
        return obj

    @classmethod
    def construct_from_yaml(cls, filename):
        """Create and validate an instance from the mapping stored in a YAML file.

        :raises ValueError: If the file cannot be read or parsed, or its top level
            is not a mapping. Model validation errors propagate unchanged.
        """
        try:
            with open(filename, 'r') as infile:
                # NOTE(review): yaml.CLoader can construct arbitrary Python objects;
                # only load trusted files (yaml.CSafeLoader would be the safe choice).
                indict = yaml.load(infile, Loader=yaml.CLoader)
        except Exception as e:
            raise ValueError(f'Could not load a dictionary from {filename}') from e
        if not isinstance(indict, dict):
            raise ValueError(f'Could not load a dictionary from {filename}')
        return cls(**indict)

    @classmethod
    def construct_from_file(cls, filename):
        """Create an instance from a YAML or NeXus file, chosen by file extension.

        NeXus input is delegated to ``cls.construct_from_nexus``, which is not
        defined on this base class and must be provided by the subclass.

        :raises TypeError: For unsupported file extensions.
        """
        file_exists_and_readable(filename)
        filename = os.path.abspath(filename)
        fileformat = os.path.splitext(filename)[1]
        extension = fileformat.lower()
        t0 = time()
        if extension in ('.yaml', '.yml'):
            obj = cls.construct_from_yaml(filename)
        elif extension in ('.nxs', '.nx5', '.h5', '.hdf5'):
            obj = cls.construct_from_nexus(filename)
        else:
            logger.error(f'Unsupported file extension for constructing a model: {fileformat}')
            raise TypeError(f'Unrecognized file extension: {fileformat}')
        logger.info(f'Constructed a model from {filename} in {time()-t0:.2f} seconds.')
        return obj

    def dict_for_yaml(self, exclude_fields=None):
        """Return the non-None fields as plain data suitable for ``yaml.dump``.

        Nested models, and lists consisting only of models, are converted
        recursively. Path-like values are converted to strings.

        :param exclude_fields: Field names to omit (top level only); None omits none.
        """
        if exclude_fields is None:
            exclude_fields = ()
        yaml_dict = {}
        for field_name in self.__fields__:
            if field_name in exclude_fields:
                continue
            field_value = getattr(self, field_name, None)
            if field_value is None:
                continue
            if isinstance(field_value, BaseModel):
                yaml_dict[field_name] = field_value.dict_for_yaml()
            elif isinstance(field_value, list) and all(
                    isinstance(item, BaseModel) for item in field_value):
                yaml_dict[field_name] = [item.dict_for_yaml() for item in field_value]
            elif isinstance(field_value, os.PathLike):
                # Covers PosixPath as well as other pathlib/os path objects.
                yaml_dict[field_name] = os.fspath(field_value)
            else:
                yaml_dict[field_name] = field_value
        return yaml_dict

    def write_to_yaml(self, filename=None):
        """Write ``dict_for_yaml()`` to ``filename`` as YAML.

        With ``filename=None``, or when writing fails, the YAML text is logged at
        INFO level instead.
        """
        yaml_dict = self.dict_for_yaml()
        if filename is not None:
            try:
                with open(filename, 'w') as outfile:
                    yaml.dump(yaml_dict, outfile, sort_keys=False)
            except Exception:
                logger.error(f'Unknown error -- could not write to {filename} in yaml format.')
            else:
                logger.info(f'Successfully wrote this model to {filename}')
                return
        logger.info('Printing yaml representation here:\n'+
                    f'{yaml.dump(yaml_dict, sort_keys=False)}')

    def write_to_file(self, filename, force_overwrite=False):
        """Write the model to a YAML or NeXus file chosen by extension.

        An unwritable YAML target falls back to logging the YAML text. NeXus
        output is delegated to ``self.write_to_nexus``, which the subclass must
        provide. Unsupported extensions are reported and nothing is written.
        """
        file_writeable, fileformat = self.output_file_valid(filename,
                force_overwrite=force_overwrite)
        if fileformat == 'yaml':
            if file_writeable:
                self.write_to_yaml(filename=filename)
            else:
                self.write_to_yaml()
        elif fileformat == 'nexus':
            if file_writeable:
                self.write_to_nexus(filename=filename)
        else:
            # Previously this case was silently ignored.
            logger.error(f'Unsupported file extension for writing a model: {filename}')

    def output_file_valid(self, filename, force_overwrite=False):
        """Check whether ``filename`` may be written, and classify its format.

        Creates the parent directory if it does not exist yet.

        :returns: ``(writeable, fileformat)``, where ``fileformat`` is 'yaml',
            'nexus', or None for an unsupported extension.
        """
        filename = os.path.abspath(filename)
        extension = os.path.splitext(filename)[1].lower()
        if extension in ('.yaml', '.yml'):
            fileformat = 'yaml'
        elif extension in ('.nxs', '.nx5', '.h5', '.hdf5'):
            fileformat = 'nexus'
        else:
            return False, None  # Only yaml and NeXus files allowed for output now.
        if os.path.isfile(filename):
            if not os.access(filename, os.W_OK):
                logger.error(f'Cannot access {filename} for writing.')
                return False, fileformat
            if not force_overwrite:
                logger.error(f'{filename} will not be overwritten.')
                return False, fileformat
        dirname = os.path.dirname(filename)
        if os.path.isdir(dirname):
            if os.access(dirname, os.W_OK):
                return True, fileformat
            logger.error(f'Cannot access {dirname} for writing.')
            return False, fileformat
        try:
            os.makedirs(dirname)
        except OSError:
            logger.error(f'Cannot create {dirname} for output.')
            return False, fileformat
        return True, fileformat

    def set_single_attr_cli(self, attr_name, attr_desc='unknown attribute', list_flag=False,
            **cli_kwargs):
        """Interactively set a single field.

        A field whose type supports ``construct()``/``cli()`` (a nested model) is
        configured through its own ``cli``. Any other field falls back to a text
        prompt, where pressing enter keeps the current value.

        :param attr_name: Name of the field to set.
        :param attr_desc: Human-readable description used in the prompts.
        :param list_flag: Parse text input as a list (via ``string_to_list``).
        :param cli_kwargs: Forwarded to the nested model's ``cli``; if
            ``chain_attr_desc`` is truthy, ``attr_desc`` is forwarded as well.
        """
        if cli_kwargs.get('chain_attr_desc', False):
            cli_kwargs['attr_desc'] = attr_desc
        try:
            self._set_model_attr_cli(attr_name, attr_desc, cli_kwargs)
        except Exception:
            # The field type has no construct()/cli(): prompt for a plain value.
            # KeyboardInterrupt/SystemExit now propagate instead of landing here.
            self._set_value_attr_cli(attr_name, attr_desc, list_flag)

    def _set_model_attr_cli(self, attr_name, attr_desc, cli_kwargs):
        """Configure a nested-model field through its ``cli`` until assignment validates."""
        def fresh_attr():
            return self.__fields__[attr_name].type_.construct()

        attr = getattr(self, attr_name, None)
        if attr is None:
            attr = fresh_attr()
        while True:
            try:
                attr.cli(**cli_kwargs)
            except Exception as e:
                print(e if isinstance(e, ValidationError) else f'{type(e).__name__}: {e}')
                print(f'Removing {attr_desc} configuration')
                attr = fresh_attr()
                continue
            try:
                setattr(self, attr_name, attr)
            except Exception as e:
                print(e if isinstance(e, ValidationError) else f'{type(e).__name__}: {e}')
            else:
                return

    def _set_value_attr_cli(self, attr_name, attr_desc, list_flag):
        """Prompt for a plain field value until it passes validation."""
        while True:
            attr = getattr(self, attr_name, None)
            if attr is None:
                input_value = input(f'Type and enter a value for {attr_desc}: ')
            else:
                input_value = input(f'Type and enter a new value for {attr_desc} or press '+
                        f'enter to keep the current one ({attr}): ')
            if list_flag:
                input_value = string_to_list(input_value, remove_duplicates=False, sort=False)
                if len(input_value) == 0:
                    input_value = attr
            elif input_value == '' and attr is not None:
                # Honor the prompt: an empty answer keeps the current value.
                input_value = attr
            try:
                setattr(self, attr_name, input_value)
            except Exception as e:
                if isinstance(e, ValidationError):
                    print(e)
                else:
                    print(f'Unexpected {type(e).__name__}: {e}')
            else:
                return

    def set_list_attr_cli(self, attr_name, attr_desc='unknown attribute', **cli_kwargs):
        """Interactively review a list-of-models field and optionally append items.

        Each existing item is re-configured through its ``cli`` until the list
        validates. The user is then offered to append new items, each created with
        ``construct_from_cli`` on the field's declared item type.
        """
        if cli_kwargs.get('chain_attr_desc', False):
            cli_kwargs['attr_desc'] = attr_desc
        attr = getattr(self, attr_name, None)
        if attr is None:
            attr = []
        # Check existing items
        for item in attr:
            while True:
                item.cli(**cli_kwargs)
                try:
                    setattr(self, attr_name, attr)
                except Exception as e:
                    print(e if isinstance(e, ValidationError) else f'{type(e).__name__}: {e}')
                else:
                    break
        # Use the declared item type. Taking the class of the last existing item
        # failed with an unbound name when the current list was empty.
        item_type = self.__fields__[attr_name].type_
        append = input_yesno(f'Add a {attr_desc} configuration? (y/n)', 'n')
        while append:
            attr.append(item_type.construct_from_cli())
            try:
                setattr(self, attr_name, attr)
            except Exception as e:
                print(e if isinstance(e, ValidationError) else f'{type(e).__name__}: {e}')
                print(f'Removing last {attr_desc} configuration from the list')
                attr.pop()
            else:
                append = input_yesno(f'Add another {attr_desc} configuration? (y/n)', 'n')
| 306 | |
| 307 | |
class Detector(BaseModel):
    """Detector configuration: ID prefix, pixel grid, pixel size and lens magnification."""
    prefix: constr(strip_whitespace=True, min_length=1)
    rows: conint(gt=0)
    columns: conint(gt=0)
    pixel_size: conlist(item_type=confloat(gt=0), min_items=1, max_items=2)
    lens_magnification: confloat(gt=0) = 1.0

    @property
    def get_pixel_size(self):
        """Effective pixel size: ``pixel_size`` divided by ``lens_magnification``.

        This is a property (accessed without a call), kept for backward compatibility.
        """
        return list(np.asarray(self.pixel_size)/self.lens_magnification)

    def construct_from_yaml(self, filename):
        """Load the detector settings in place from a YAML file.

        NOTE(review): this instance method shadows the ``BaseModel.construct_from_yaml``
        classmethod, so ``Detector.construct_from_file`` cannot use it as intended.

        Keys read::

            detector:
              id: <prefix>
              pixels: {rows: ..., columns: ..., size: ...}
            lens_magnification: ...

        All values are read and validated before any field is changed, so a failed
        load leaves the detector untouched.

        :returns: True on success, or False (with a logged warning) on failure.
        """
        try:
            with open(filename, 'r') as infile:
                # NOTE(review): yaml.CLoader can construct arbitrary Python objects;
                # only load trusted files.
                indict = yaml.load(infile, Loader=yaml.CLoader)
            detector = indict['detector']
            pixels = detector['pixels']
            values = {'prefix': detector['id'],
                      'rows': pixels['rows'],
                      'columns': pixels['columns'],
                      'pixel_size': pixels['size'],
                      'lens_magnification': indict['lens_magnification']}
            # Validate everything up front (assumes type(self) needs no extra fields).
            validated = type(self)(**values)
        except Exception:
            logger.warning(f'Could not load a dictionary from {filename}')
            return False
        for field_name in values:
            setattr(self, field_name, getattr(validated, field_name))
        return True

    def cli(self):
        """Interactively configure all detector fields."""
        print('\n -- Configure the detector -- ')
        self.set_single_attr_cli('prefix', 'detector ID')
        self.set_single_attr_cli('rows', 'number of pixel rows')
        self.set_single_attr_cli('columns', 'number of pixel columns')
        self.set_single_attr_cli('pixel_size', 'pixel size in mm (enter either a single value for '+
                'square pixels or a pair of values for the size in the respective row and column '+
                'directions)', list_flag=True)
        self.set_single_attr_cli('lens_magnification', 'lens magnification')

    def construct_nxdetector(self):
        """Return an NXdetector with ``local_name`` and x/y pixel sizes in mm.

        A single ``pixel_size`` value is used for both directions (square pixels).
        """
        nxdetector = NXdetector()
        nxdetector.local_name = self.prefix
        pixel_size = self.get_pixel_size
        nxdetector.x_pixel_size = pixel_size[0]
        nxdetector.y_pixel_size = pixel_size[0] if len(pixel_size) == 1 else pixel_size[1]
        nxdetector.x_pixel_size.attrs['units'] = 'mm'
        nxdetector.y_pixel_size.attrs['units'] = 'mm'
        return nxdetector
| 359 | |
| 360 | |
# Per-scan bookkeeping entry stored in SpecScans.stack_info:
#   scan_number           -- SPEC scan number the entry belongs to
#   starting_image_offset -- first image index used for the scan (>= 0)
#   num_image             -- number of images used (> 0)
#   ref_x, ref_z          -- values written as the sample x/z translation
#                            (mm) by SpecScans.construct_nxcollection
ScanInfo = TypedDict('ScanInfo', {
    'scan_number': int,
    'starting_image_offset': conint(ge=0),
    'num_image': conint(gt=0),
    'ref_x': float,
    'ref_z': float,
})
| 367 | |
class SpecScans(BaseModel):
    """A selection of scans from one SPEC file plus one ``ScanInfo`` per scan."""
    spec_file: FilePath
    scan_numbers: conlist(item_type=conint(gt=0), min_items=1)
    stack_info: conlist(item_type=ScanInfo, min_items=1) = []

    @validator('spec_file')
    def validate_spec_file(cls, spec_file):
        """Return the absolute path of ``spec_file`` after checking FileSpec can open it."""
        try:
            spec_file = os.path.abspath(spec_file)
            FileSpec(spec_file)
        except Exception as e:
            raise ValueError(f'Invalid SPEC file {spec_file}') from e
        return spec_file

    @validator('scan_numbers')
    def validate_scan_numbers(cls, scan_numbers, values):
        """Check that every scan number exists in the (already validated) SPEC file."""
        spec_file = values.get('spec_file')
        if spec_file is not None:
            spec_scans = FileSpec(spec_file)
            for scan_number in scan_numbers:
                if spec_scans.get_scan_by_number(scan_number) is None:
                    raise ValueError(f'There is no scan number {scan_number} in {spec_file}')
        return scan_numbers

    @validator('stack_info')
    def validate_stack_info(cls, stack_info, values):
        """Check that ``stack_info`` has one entry per scan number in ``scan_numbers``.

        Uses explicit ValueErrors (asserts are stripped under ``python -O``). The
        check is skipped when ``scan_numbers`` failed its own validation and is
        therefore missing from ``values``.
        """
        scan_numbers = values.get('scan_numbers')
        if scan_numbers is None:
            return stack_info
        if len(scan_numbers) != len(stack_info):
            raise ValueError(f'Number of stack_info entries ({len(stack_info)}) does not '
                    f'match the number of scan numbers ({len(scan_numbers)})')
        for scan_info in stack_info:
            if scan_info['scan_number'] not in scan_numbers:
                raise ValueError(f'stack_info scan number {scan_info["scan_number"]} is not '
                        f'in scan_numbers {scan_numbers}')
            is_int(scan_info['starting_image_offset'], ge=0, lt=scan_info['num_image'],
                    raise_error=True)
        return stack_info

    @staticmethod
    def _scans_from_nxcollection(nxcollection):
        """Return ``(sorted scan numbers, stack_info)`` read from ``scan_<N>`` NXsubentries."""
        scan_numbers = []
        stack_info = []
        for nxsubentry_name, nxsubentry in nxcollection.items():
            scan_number = int(nxsubentry_name.split('_')[-1])
            scan_numbers.append(scan_number)
            stack_info.append({'scan_number': scan_number,
                    'starting_image_offset': int(nxsubentry.instrument.detector.frame_start_number),
                    'num_image': len(nxsubentry.sample.rotation_angle),
                    'ref_x': float(nxsubentry.sample.x_translation),
                    'ref_z': float(nxsubentry.sample.z_translation)})
        return sorted(scan_numbers), stack_info

    @classmethod
    def construct_from_nxcollection(cls, nxcollection:NXcollection):
        """Create an instance from an NXcollection written by ``construct_nxcollection``."""
        scan_numbers, stack_info = cls._scans_from_nxcollection(nxcollection)
        return cls(spec_file=nxcollection.attrs['spec_file'], scan_numbers=scan_numbers,
                stack_info=stack_info)

    @property
    def available_scan_numbers(self):
        """Scan numbers in ``spec_file`` that a ``ScanParser`` can handle (cached)."""
        return get_available_scan_numbers(self.spec_file)

    def set_from_nxcollection(self, nxcollection:NXcollection):
        """Overwrite ``spec_file``, ``scan_numbers`` and ``stack_info`` from an NXcollection."""
        self.spec_file = nxcollection.attrs['spec_file']
        scan_numbers, stack_info = self._scans_from_nxcollection(nxcollection)
        self.scan_numbers = scan_numbers
        self.stack_info = stack_info

    def get_scan_index(self, scan_number):
        """Return the index of ``scan_number``'s entry in ``stack_info``, or None.

        :raises ValueError: If the scan number appears more than once.
        """
        matches = [index for index, scan_info in enumerate(self.stack_info)
                if scan_info['scan_number'] == scan_number]
        if len(matches) > 1:
            raise ValueError('Duplicate scan_numbers in image stack')
        return matches[0] if matches else None

    def _get_scan_info(self, scan_number):
        """Return the ``stack_info`` entry for ``scan_number``.

        :raises ValueError: If there is no such entry. Previously this surfaced as
            an unclear TypeError from indexing with None.
        """
        scan_index = self.get_scan_index(scan_number)
        if scan_index is None:
            raise ValueError(f'No stack_info entry for scan number {scan_number}')
        return self.stack_info[scan_index]

    def get_scanparser(self, scan_number):
        """Return the cached ``ScanParser`` for ``scan_number`` (None if unavailable)."""
        return get_scanparser(self.spec_file, scan_number)

    def get_detector_data(self, detector_prefix, scan_number=None, scan_step_index=None):
        """Read detector images for one scan or for every selected scan.

        :param detector_prefix: Detector prefix passed to the scan parser.
        :param scan_number: Scan to read; None reads every scan in ``scan_numbers``.
        :param scan_step_index: Image index relative to the scan's
            ``starting_image_offset``; None reads the scan's ``num_image`` images.
        :returns: A single image when both ``scan_number`` and ``scan_step_index``
            are given, otherwise a list with one entry per scan.
        """
        scan_numbers = self.scan_numbers if scan_number is None else [scan_number]
        image_stacks = []
        # Separate loop name: reusing ``scan_number`` made the return-type check
        # below see a non-None value even when the caller passed None.
        for number in scan_numbers:
            parser = self.get_scanparser(number)
            scan_info = self._get_scan_info(number)
            image_offset = scan_info['starting_image_offset']
            if scan_step_index is None:
                num_image = scan_info['num_image']
                image_stacks.append(parser.get_detector_data(detector_prefix,
                        (image_offset, image_offset+num_image)))
            else:
                image_stacks.append(parser.get_detector_data(detector_prefix,
                        image_offset+scan_step_index))
        if scan_number is not None and scan_step_index is not None:
            # Return a single image for a specific scan_number and scan_step_index request
            return image_stacks[0]
        # Return a list otherwise
        return image_stacks

    def _has_scan_type(self, scan_number, scan_type):
        """Return whether ``scan_number``'s parser reports ``scan_type`` (False on errors)."""
        try:
            return self.get_scanparser(scan_number).scan_type == scan_type
        except Exception:
            return False

    def scan_numbers_cli(self, attr_desc, **kwargs):
        """Interactively choose ``scan_numbers`` among the available scans.

        For stations 'id1a3'/'id3a` given a ``scan_type`` keyword, the candidates
        are restricted as follows:

        - 'ts1' keeps the available scans whose parser reports that type.
        - 'df1' and 'bf1' take the scan two, respectively one, number before each
          of ``tomo_scan_numbers`` and require it to have that type.

        A single candidate is selected without prompting.

        :raises ValueError: If a 'df1'/'bf1' scan is missing or of the wrong type.
        """
        available_scan_numbers = self.available_scan_numbers
        station = kwargs.get('station')
        if station in ('id1a3', 'id3a') and 'scan_type' in kwargs:
            scan_type = kwargs['scan_type']
            if scan_type == 'ts1':
                available_scan_numbers = [scan_number
                        for scan_number in self.available_scan_numbers
                        if self._has_scan_type(scan_number, scan_type)]
            elif scan_type in ('df1', 'bf1'):
                offset = 2 if scan_type == 'df1' else 1
                available_scan_numbers = []
                for tomo_scan_number in kwargs['tomo_scan_numbers']:
                    scan_number = tomo_scan_number-offset
                    parser = self.get_scanparser(scan_number)
                    # Explicit error instead of an assert (stripped under -O).
                    if parser is None or parser.scan_type != scan_type:
                        raise ValueError(f'Scan {scan_number} in {self.spec_file} is not '
                                f'a {scan_type} scan')
                    available_scan_numbers.append(scan_number)
        if len(available_scan_numbers) == 1:
            input_mode = 1
        else:
            menu_options = [f'Select a subset of the available {attr_desc}scan numbers',
                    f'Use all available {attr_desc}scan numbers in {self.spec_file}']
            if hasattr(self, 'scan_numbers'):
                print(f'Currently selected {attr_desc}scan numbers are: {self.scan_numbers}')
                menu_options.append(f'Keep the currently selected {attr_desc}scan numbers')
            print(f'Available scan numbers in {self.spec_file} are: '+
                    f'{available_scan_numbers}')
            input_mode = input_menu(menu_options, header='Choose one of the following options '+
                    'for selecting scan numbers')
        if input_mode == 0:
            while True:
                try:
                    self.scan_numbers = \
                            input_int_list(f'Enter a series of {attr_desc}scan numbers')
                except ValidationError as e:
                    print(e)
                except Exception as e:
                    print(f'Unexpected {type(e).__name__}: {e}')
                else:
                    break
        elif input_mode == 1:
            self.scan_numbers = available_scan_numbers
        # input_mode == 2: keep the current selection.

    def cli(self, **cli_kwargs):
        """Interactively choose the SPEC file and the scan numbers to use."""
        if cli_kwargs.get('attr_desc') is not None:
            attr_desc = f'{cli_kwargs["attr_desc"]} '
        else:
            attr_desc = ''
        print(f'\n -- Configure which scans to use from a single {attr_desc}SPEC file')
        self.set_single_attr_cli('spec_file', attr_desc+'SPEC file path')
        self.scan_numbers_cli(attr_desc)

    def construct_nxcollection(self, image_key, thetas, detector):
        """Build an NXcollection with one ``scan_<N>`` NXsubentry per selected scan.

        :param image_key: Value stored as each scan detector's ``image_key``.
        :param thetas: Rotation angles in degrees, one per image, or None to record
            ``num_image`` zeros for each scan.
        :param detector: ``Detector`` used to build each scan's NXdetector.
        :raises ValueError: If ``thetas`` does not have ``num_image`` entries.
        """
        nxcollection = NXcollection()
        nxcollection.attrs['spec_file'] = str(self.spec_file)
        parser = self.get_scanparser(self.scan_numbers[0])
        nxcollection.attrs['date'] = parser.spec_scan.file_date
        for scan_number in self.scan_numbers:
            scan_info = self._get_scan_info(scan_number)
            # Add an NXsubentry to the NXcollection for each scan
            nxsubentry = NXsubentry()
            nxcollection[f'scan_{scan_number}'] = nxsubentry
            parser = self.get_scanparser(scan_number)
            nxsubentry.start_time = parser.spec_scan.date
            nxsubentry.spec_command = parser.spec_command
            num_image = scan_info['num_image']
            # Per-scan variable. Overwriting ``thetas`` made the zeros generated
            # for the first scan leak into later scans and fail their length check.
            if thetas is None:
                scan_thetas = num_image*[0.0]
            elif num_image != len(thetas):
                raise ValueError(f'Scan {scan_number} has {num_image} images but '
                        f'{len(thetas)} theta values were given')
            else:
                scan_thetas = thetas
            # Add an NXinstrument with an NXdetector to the scan's NXsubentry
            nxsubentry.instrument = NXinstrument()
            nxsubentry.instrument.detector = detector.construct_nxdetector()
            nxsubentry.instrument.detector.frame_start_number = scan_info['starting_image_offset']
            nxsubentry.instrument.detector.image_key = image_key
            # Add an NXsample to the scan's NXsubentry
            nxsubentry.sample = NXsample()
            nxsubentry.sample.rotation_angle = scan_thetas
            nxsubentry.sample.rotation_angle.units = 'degrees'
            nxsubentry.sample.x_translation = scan_info['ref_x']
            nxsubentry.sample.x_translation.units = 'mm'
            nxsubentry.sample.z_translation = scan_info['ref_z']
            nxsubentry.sample.z_translation.units = 'mm'
        return nxcollection
| 590 | |
| 591 | |
class FlatField(SpecScans):
    """SpecScans whose per-scan image ranges are chosen interactively."""

    def image_range_cli(self, attr_desc, detector_prefix):
        """Interactively choose the starting image offset and image count per scan.

        If the user accepts a scan's stored values, its entry is left unchanged.
        Otherwise the entry is created or updated with the chosen range and the
        parser's horizontal/vertical shifts. Entries for scans that are no longer
        in ``scan_numbers`` are dropped, so the result passes ``validate_stack_info``.
        """
        stack_info = self.stack_info
        for scan_number in self.scan_numbers:
            # Parse the available image range
            parser = self.get_scanparser(scan_number)
            image_offset = parser.starting_image_offset
            num_image = parser.get_num_image(detector_prefix.upper())
            scan_index = self.get_scan_index(scan_number)

            # Select the image set
            last_image_index = image_offset+num_image
            print(f'Available good image set index range: [{image_offset}, {last_image_index})')
            if scan_index is not None:
                scan_info = stack_info[scan_index]
                print(f'Current starting image offset and number of images: '+
                        f'{scan_info["starting_image_offset"]} and {scan_info["num_image"]}')
                if input_yesno('Accept these values (y/n)?', 'y'):
                    # Keep the stored entry. Previously ``offset``/``num`` were unbound
                    # here, or silently reused from the previous scan.
                    continue
            print(f'Default starting image offset and number of images: '+
                    f'{image_offset} and {num_image}')
            image_set_approved = input_yesno('Accept these values (y/n)?', 'y')
            offset, num = image_offset, num_image
            while not image_set_approved:
                offset = input_int('Enter the starting image offset', ge=image_offset,
                        lt=last_image_index)
                num = input_int('Enter the number of images', ge=1,
                        le=last_image_index-offset)
                print(f'Current starting image offset and number of images: {offset} and {num}')
                image_set_approved = input_yesno('Accept these values (y/n)?', 'y')
            new_values = {'starting_image_offset': offset, 'num_image': num,
                    'ref_x': parser.horizontal_shift, 'ref_z': parser.vertical_shift}
            if scan_index is not None:
                stack_info[scan_index].update(new_values)
            else:
                stack_info.append({'scan_number': scan_number, **new_values})
        # validate_stack_info requires exactly one entry per selected scan number.
        self.stack_info = [scan_info for scan_info in stack_info
                if scan_info['scan_number'] in self.scan_numbers]

    def cli(self, **cli_kwargs):
        """Interactively configure the SPEC file, scan numbers and image ranges.

        Keyword arguments read: ``attr_desc`` (prompt prefix), ``station``, and
        ``detector`` (required; its ``prefix`` selects the detector images). For
        stations 'id1a3'/'id3a', ``spec_file``, ``tomo_scan_numbers`` and
        ``scan_type`` are also read.

        :raises TypeError: If no ``detector`` keyword argument is given.
        """
        if cli_kwargs.get('attr_desc') is not None:
            attr_desc = f'{cli_kwargs["attr_desc"]} '
        else:
            attr_desc = ''
        station = cli_kwargs.get('station')
        detector = cli_kwargs.get('detector')
        if detector is None:
            # Fail before prompting rather than after every question was answered.
            raise TypeError('FlatField.cli requires a "detector" keyword argument')
        print(f'\n -- Configure the location of the {attr_desc}scan data -- ')
        if station in ('id1a3', 'id3a'):
            self.spec_file = cli_kwargs['spec_file']
            tomo_scan_numbers = cli_kwargs['tomo_scan_numbers']
            scan_type = cli_kwargs['scan_type']
            self.scan_numbers_cli(attr_desc, station=station, tomo_scan_numbers=tomo_scan_numbers,
                    scan_type=scan_type)
        else:
            self.set_single_attr_cli('spec_file', attr_desc+'SPEC file path')
            self.scan_numbers_cli(attr_desc)
        self.image_range_cli(attr_desc, detector.prefix)
| 655 | |
| 656 | |
| 657 class TomoField(SpecScans): | |
| 658 theta_range: dict = {} | |
| 659 | |
@validator('theta_range')
def validate_theta_range(cls, theta_range):
    """Validate ``theta_range``.

    ``theta_range`` must have 3 or 4 keys, including numeric 'start' < 'end' and an
    integer 'num' > 1. An optional 'start_index' must be an integer >= 0.

    :raises ValueError: On a wrong key count, missing keys, or 'end' <= 'start'.
        Missing keys previously escaped as a raw KeyError.
    """
    if len(theta_range) not in (3, 4) or not {'start', 'end', 'num'} <= theta_range.keys():
        raise ValueError(f'Invalid theta range {theta_range}')
    is_num(theta_range['start'], raise_error=True)
    is_num(theta_range['end'], raise_error=True)
    is_int(theta_range['num'], gt=1, raise_error=True)
    if theta_range['end'] <= theta_range['start']:
        raise ValueError(f'Invalid theta range {theta_range}')
    if 'start_index' in theta_range:
        is_int(theta_range['start_index'], ge=0, raise_error=True)
    return theta_range
| 672 | |
@classmethod
def construct_from_nxcollection(cls, nxcollection:NXcollection):
    """Create a TomoField from an NXcollection written by ``construct_nxcollection``.

    Scan numbers and stack info are parsed by ``SpecScans.construct_from_nxcollection``.
    ``theta_range`` is taken from the rotation angles of the first ``scan_*``
    entry. Plain Python numbers are stored, so ``dict_for_yaml`` output dumps
    cleanly instead of carrying numpy scalar objects.
    """
    # Explicit two-argument super() so this also works outside a class-body cell.
    obj = super(TomoField, cls).construct_from_nxcollection(nxcollection)
    for name, nxsubentry in nxcollection.items():
        if 'scan_' in name:
            thetas = np.asarray(nxsubentry.sample.rotation_angle)
            # Assignment is validated (validate_assignment) by validate_theta_range.
            obj.theta_range = {'start': float(thetas[0]), 'end': float(thetas[-1]),
                    'num': int(thetas.size)}
            break
    return obj
| 696 | |
def get_horizontal_shifts(self, scan_number=None):
    """Return the scan parser's ``horizontal_shift`` for one scan or for all scans.

    :param scan_number: Scan to query; None queries every scan in ``scan_numbers``.
    :returns: The bare value when exactly one scan is queried, otherwise a list.
    """
    wanted = self.scan_numbers if scan_number is None else [scan_number]
    shifts = [self.get_scanparser(number).horizontal_shift for number in wanted]
    return shifts[0] if len(shifts) == 1 else shifts
| 710 | |
def get_vertical_shifts(self, scan_number=None):
    """Return the scan parser's ``vertical_shift`` for one scan or for all scans.

    :param scan_number: Scan to query; None queries every scan in ``scan_numbers``.
    :returns: The bare value when exactly one scan is queried, otherwise a list.
    """
    wanted = self.scan_numbers if scan_number is None else [scan_number]
    shifts = [self.get_scanparser(number).vertical_shift for number in wanted]
    return shifts[0] if len(shifts) == 1 else shifts
| 724 | |
def theta_range_cli(self, scan_number, attr_desc, station):
    """Select the theta (rotation angle) range used for the tomo reconstruction.

    The range is chosen interactively only for the first scan in
    ``self.scan_numbers``; any other scan is merely checked for having the
    same SPEC theta range (start, end, num) as the first scan.

    Parameters:
        scan_number: SPEC scan number whose theta range is processed.
        attr_desc: prefix inserted into user-facing messages.
        station: station name; for 'id1a3'/'id3a' the full SPEC range can be
            accepted directly, for 'id3b' default start/end values are
            proposed when the SPEC range covers 180 degrees.

    Sets ``self.theta_range`` (first scan only) to a dict with keys
    'start', 'end', 'num' and 'start_index'.

    Raises:
        ValueError: if the scan's theta range differs from the first scan's,
            or if the SPEC range has fewer than two theta values.
    """
    # Parse the available theta range
    parser = self.get_scanparser(scan_number)
    theta_vals = parser.theta_vals
    spec_theta_start = theta_vals.get('start')
    spec_theta_end = theta_vals.get('end')
    spec_num_theta = theta_vals.get('num')

    # Check for consistency of theta ranges between scans
    first_scan_number = self.scan_numbers[0]
    if scan_number != first_scan_number:
        first_theta_vals = self.get_scanparser(first_scan_number).theta_vals
        if (first_theta_vals.get('start') != spec_theta_start or
                first_theta_vals.get('end') != spec_theta_end or
                first_theta_vals.get('num') != spec_num_theta):
            raise ValueError(f'Incompatible theta ranges between {attr_desc}scans:'+
                    f'\n\tScan {scan_number}: {theta_vals}'+
                    f'\n\tScan {first_scan_number}: {first_theta_vals}')
        return

    # A theta step size (thetas[1]-thetas[0]) needs at least two values
    if spec_num_theta < 2:
        raise ValueError(f'Invalid {attr_desc}theta range: at least two theta values are '+
                f'required, got {spec_num_theta}')

    # Select the theta range for the tomo reconstruction from the first scan
    theta_range_approved = False
    thetas = np.linspace(spec_theta_start, spec_theta_end, spec_num_theta)
    delta_theta = thetas[1]-thetas[0]
    print(f'Theta range obtained from SPEC data: [{spec_theta_start}, {spec_theta_end}]')
    print(f'Theta step size = {delta_theta}')
    print(f'Number of theta values: {spec_num_theta}')
    default_start = None
    default_end = None
    if station in ('id1a3', 'id3a'):
        theta_range_approved = input_yesno('Accept this theta range (y/n)?', 'y')
        if theta_range_approved:
            self.theta_range = {'start': float(spec_theta_start), 'end': float(spec_theta_end),
                    'num': int(spec_num_theta), 'start_index': 0}
            return
    elif station == 'id3b':
        # Exact comparison: ('id3b') is a plain string, so a membership test
        # would match substrings and raise TypeError for station=None.
        if spec_theta_start <= 0.0 and spec_theta_end >= 180.0:
            default_start = 0
            default_end = 180
        elif spec_theta_end-spec_theta_start == 180:
            default_start = spec_theta_start
            default_end = spec_theta_end
    while not theta_range_approved:
        theta_start = input_num('Enter the first theta (included)', ge=spec_theta_start,
                lt=spec_theta_end, default=default_start)
        # Snap the requested value onto the SPEC theta grid
        theta_index_start = index_nearest(thetas, theta_start)
        theta_start = thetas[theta_index_start]
        theta_end = input_num('Enter the last theta (excluded)',
                ge=theta_start+delta_theta, le=spec_theta_end, default=default_end)
        theta_index_end = index_nearest(thetas, theta_end)
        theta_end = thetas[theta_index_end]
        # The end value is excluded, hence no +1
        num_theta = theta_index_end-theta_index_start
        print(f'Selected theta range: [{theta_start}, {theta_start+delta_theta}, ..., '+
                f'{theta_end})')
        print(f'Number of theta values: {num_theta}')
        theta_range_approved = input_yesno('Accept this theta range (y/n)?', 'y')
    self.theta_range = {'start': float(theta_start), 'end': float(theta_end),
            'num': int(num_theta), 'start_index': int(theta_index_start)}
| 782 | |
def image_range_cli(self, attr_desc, detector_prefix):
    """Record, per scan, which detector images match the selected theta range.

    For every scan in ``self.scan_numbers`` the matching entry of
    ``self.stack_info`` is updated in place, or a new entry is appended,
    with the image offset, image count and the parser's horizontal/vertical
    shifts (stored as 'ref_x'/'ref_z').

    Parameters:
        attr_desc: prefix inserted into the error message.
        detector_prefix: detector name; upper-cased before being passed to
            ``parser.get_num_image``.

    Raises:
        ValueError: if a scan has too few images for the requested number of
            thetas starting at ``theta_range['start_index']``.
    """
    stack_info = self.stack_info
    for scan_number in self.scan_numbers:
        # Parse the available image range
        parser = self.get_scanparser(scan_number)
        first_image = parser.starting_image_offset
        available = parser.get_num_image(detector_prefix.upper())
        scan_index = self.get_scan_index(scan_number)

        # Select the image set matching the theta range
        num_theta = self.theta_range['num']
        theta_offset = self.theta_range['start_index']
        if available - theta_offset < num_theta:
            raise ValueError(f'Available {attr_desc}image indices incompatible with thetas:'+
                    f'\n\tNumber of thetas and offset = {num_theta} and {theta_offset}'+
                    f'\n\tNumber of available images {available}')
        image_fields = {
            'starting_image_offset': first_image + theta_offset,
            'num_image': num_theta,
            'ref_x': parser.horizontal_shift,
            'ref_z': parser.vertical_shift,
        }
        if scan_index is None:
            stack_info.append({'scan_number': scan_number, **image_fields})
        else:
            stack_info[scan_index].update(image_fields)
    self.stack_info = stack_info
| 811 | |
def cli(self, **cli_kwargs):
    """Interactively configure the location of a set of SPEC scans.

    Recognized keyword arguments (each read with ``.get``, so all optional):
    'attr_desc' (message prefix), 'cycle', 'btr', 'station', 'detector'
    (an object with a ``prefix`` attribute) and 'sample_name'.

    For stations 'id1a3'/'id3a' the SPEC file is located under
    ``/nfs/chess/<station>/<cycle>/<btr>/<sample dir>/spec.log``; for other
    stations its path is prompted for. Afterwards the theta range of each
    scan and the image ranges are configured.
    """
    desc = cli_kwargs.get('attr_desc')
    attr_desc = '' if desc is None else f'{desc} '
    station = cli_kwargs.get('station')
    detector = cli_kwargs.get('detector')
    print(f'\n -- Configure the location of the {attr_desc}scan data -- ')
    if station not in ('id1a3', 'id3a'):
        # Other stations: ask for the SPEC file path directly
        self.set_single_attr_cli('spec_file', attr_desc+'SPEC file path')
        self.scan_numbers_cli(attr_desc)
    else:
        cycle = cli_kwargs.get('cycle')
        btr = cli_kwargs.get('btr')
        sample_name = cli_kwargs.get('sample_name')
        basedir = f'/nfs/chess/{station}/{cycle}/{btr}'
        runs = [name for name in os.listdir(basedir)
                if os.path.isdir(os.path.join(basedir, name))]
        # Use the sample's own directory when it exists, else let the user pick
        if sample_name in runs:
            index = runs.index(sample_name)
        else:
            index = input_menu(runs, header='Choose a sample directory')
        self.spec_file = f'{basedir}/{runs[index]}/spec.log'
        self.scan_numbers_cli(attr_desc, station=station, scan_type='ts1')
    for number in self.scan_numbers:
        self.theta_range_cli(number, attr_desc, station)
    self.image_range_cli(attr_desc, detector.prefix)
| 840 | |
| 841 | |
class Sample(BaseModel):
    """Sample metadata: name, optional description and per-image positions."""
    name: constr(min_length=1)
    description: Optional[str]
    rotation_angles: Optional[list]
    x_translations: Optional[list]
    z_translations: Optional[list]

    # NXsample field name -> model field name (NeXus uses singular names)
    _nx_list_fields = (('rotation_angle', 'rotation_angles'),
            ('x_translation', 'x_translations'), ('z_translation', 'z_translations'))

    @classmethod
    def construct_from_nxsample(cls, nxsample:NXsample):
        """Build a Sample from an NXsample group.

        'name' is required; 'description', 'rotation_angle', 'x_translation'
        and 'z_translation' are read when present. The NeXus arrays are
        converted to plain Python lists, because the model fields are lists.
        """
        config = {}
        config['name'] = nxsample.name.nxdata
        if 'description' in nxsample:
            config['description'] = nxsample.description.nxdata
        for nx_name, field_name in cls._nx_list_fields:
            if nx_name in nxsample:
                # atleast_1d guards against a scalar field; tolist() yields
                # builtin Python numbers instead of a numpy array
                config[field_name] = np.atleast_1d(nxsample[nx_name].nxdata).tolist()
        return cls(**config)

    def cli(self):
        """Interactively set the sample name and (optional) description."""
        print('\n -- Configure the sample metadata -- ')
        self.set_single_attr_cli('name', 'the sample name')
        self.set_single_attr_cli('description', 'a description of the sample (optional)')
| 870 | |
| 871 | |
class MapConfig(BaseModel):
    """Configuration of one tomography map.

    Holds the run information (cycle, BTR, station), the sample and detector
    settings, and the SPEC scans of the tomo, dark and bright fields. It can
    build the corresponding NXtomo NXentry.
    """
    cycle: constr(strip_whitespace=True, min_length=1)
    btr: constr(strip_whitespace=True, min_length=1)
    title: constr(strip_whitespace=True, min_length=1)
    station: Literal['id1a3', 'id3a', 'id3b'] = None
    sample: Sample
    detector: Detector = Detector.construct()
    tomo_fields: TomoField
    dark_field: Optional[FlatField]
    bright_field: FlatField
    _thetas: list[float] = PrivateAttr()
    # Field attribute names with the image_key written for their images,
    # in the order the images are stacked in the NXentry
    _field_types = ({'name': 'dark_field', 'image_key': 2}, {'name': 'bright_field',
            'image_key': 1}, {'name': 'tomo_fields', 'image_key': 0})

    @classmethod
    def construct_from_nxentry(cls, nxentry:NXentry):
        """Build a MapConfig from an NXentry written by ``construct_nxentry``."""
        config = {}
        config['cycle'] = nxentry.instrument.source.attrs['cycle']
        config['btr'] = nxentry.instrument.source.attrs['btr']
        config['title'] = nxentry.nxname
        config['station'] = nxentry.instrument.source.attrs['station']
        config['sample'] = Sample.construct_from_nxsample(nxentry['sample'])
        for nxobject_name, nxobject in nxentry.spec_scans.items():
            if isinstance(nxobject, NXcollection):
                config[nxobject_name] = SpecScans.construct_from_nxcollection(nxobject)
        return cls(**config)

    @property
    def thetas(self):
        """Rotation angles of the tomo field images.

        Computed on first access from ``tomo_fields.theta_range`` and cached
        in ``_thetas``. Later changes to ``theta_range`` are not picked up.
        """
        try:
            return self._thetas
        except AttributeError:
            # Private attribute not set yet: compute and cache it
            theta_range = self.tomo_fields.theta_range
            self._thetas = list(np.linspace(theta_range['start'], theta_range['end'],
                    theta_range['num']))
            return self._thetas

    def cli(self):
        """Interactively configure the map, its detector and all field scans."""
        print('\n -- Configure a map from a set of SPEC scans (dark, bright, and tomo), '+
                'and / or detector data -- ')
        self.set_single_attr_cli('cycle', 'beam cycle')
        self.set_single_attr_cli('btr', 'BTR')
        self.set_single_attr_cli('title', 'title for the map entry')
        self.set_single_attr_cli('station', 'name of the station at which scans were collected '+
                '(currently choose from: id1a3, id3a, id3b)')
        # Select the scan parser class matching the chosen station
        import_scanparser(self.station)
        self.set_single_attr_cli('sample')
        use_detector_config = False
        # Truthiness also covers a missing or None prefix (len(None) would fail)
        if getattr(self.detector, 'prefix', None):
            use_detector_config = input_yesno(f'Current detector settings:\n{self.detector}\n'+
                    'Keep these settings? (y/n)')
        if not use_detector_config:
            menu_options = ['not listed', 'andor2', 'manta', 'retiga']
            input_mode = input_menu(menu_options, header='Choose one of the following detector '+
                    'configuration options')
            # Index 0 ('not listed') means: enter the detector settings by hand
            if input_mode:
                detector_config_file = f'{menu_options[input_mode]}.yaml'
                have_detector_config = self.detector.construct_from_yaml(detector_config_file)
            else:
                have_detector_config = False
            if not have_detector_config:
                self.set_single_attr_cli('detector', 'detector')
        self.set_single_attr_cli('tomo_fields', 'Tomo field', chain_attr_desc=True,
                cycle=self.cycle, btr=self.btr, station=self.station, detector=self.detector,
                sample_name=self.sample.name)
        if self.station in ('id1a3', 'id3a'):
            # These stations keep dark/bright scans in the tomo SPEC file
            have_dark_field = True
            tomo_spec_file = self.tomo_fields.spec_file
        else:
            have_dark_field = input_yesno('Are Dark field images available? (y/n)')
            tomo_spec_file = None
        if have_dark_field:
            self.set_single_attr_cli('dark_field', 'Dark field', chain_attr_desc=True,
                    station=self.station, detector=self.detector, spec_file=tomo_spec_file,
                    tomo_scan_numbers=self.tomo_fields.scan_numbers, scan_type='df1')
        self.set_single_attr_cli('bright_field', 'Bright field', chain_attr_desc=True,
                station=self.station, detector=self.detector, spec_file=tomo_spec_file,
                tomo_scan_numbers=self.tomo_fields.scan_numbers, scan_type='bf1')

    def construct_nxentry(self, nxroot, include_raw_data=True):
        """Add an NXtomo NXentry named ``self.title`` to ``nxroot``.

        The entry gets an NXinstrument (source + detector), an NXsample and an
        NXcollection 'spec_scans' describing each field. With
        ``include_raw_data`` the detector images, image keys, sequence numbers,
        rotation angles and x/z translations are also stored, and an NXdata
        group linking them is added.

        Raises:
            ValueError: if the number of tomo images of a scan differs from
                the number of thetas.
        """
        # Construct base NXentry
        nxentry = NXentry()

        # Add an NXentry to the NXroot
        nxroot[self.title] = nxentry
        nxroot.attrs['default'] = self.title
        nxentry.definition = 'NXtomo'

        # Add an NXinstrument to the NXentry
        nxinstrument = NXinstrument()
        nxentry.instrument = nxinstrument

        # Add an NXsource to the NXinstrument
        nxsource = NXsource()
        nxinstrument.source = nxsource
        nxsource.type = 'Synchrotron X-ray Source'
        nxsource.name = 'CHESS'
        nxsource.probe = 'x-ray'

        # Tag the NXsource with the runinfo (as an attribute)
        nxsource.attrs['cycle'] = self.cycle
        nxsource.attrs['btr'] = self.btr
        nxsource.attrs['station'] = self.station

        # Add an NXdetector to the NXinstrument (don't fill in data fields yet)
        nxinstrument.detector = self.detector.construct_nxdetector()

        # Add an NXsample to NXentry (don't fill in data fields yet)
        nxsample = NXsample()
        nxentry.sample = nxsample
        nxsample.name = self.sample.name
        nxsample.description = self.sample.description

        # Add an NXcollection to the base NXentry to hold metadata about the spec scans in the map
        # Also obtain the data fields in NXsample and NXdetector
        nxspec_scans = NXcollection()
        nxentry.spec_scans = nxspec_scans
        image_keys = []
        sequence_numbers = []
        image_stacks = []
        rotation_angles = []
        x_translations = []
        z_translations = []
        for field_type in self._field_types:
            field_name = field_type['name']
            field = getattr(self, field_name)
            if field is None:
                continue
            image_key = field_type['image_key']
            # Only tomo images have rotation angles; flat fields get 0.0
            thetas = self.thetas if field_name == 'tomo_fields' else None
            # Add the scans in a single spec file
            nxspec_scans[field_name] = field.construct_nxcollection(image_key, thetas,
                    self.detector)
            if not include_raw_data:
                continue
            image_stacks += field.get_detector_data(self.detector.prefix)
            for scan_number in field.scan_numbers:
                scan_info = field.stack_info[field.get_scan_index(scan_number)]
                num_image = scan_info['num_image']
                image_keys += num_image*[image_key]
                sequence_numbers += list(range(num_image))
                if thetas is None:
                    rotation_angles += num_image*[0.0]
                else:
                    # Explicit check instead of assert (asserts vanish under -O)
                    if num_image != len(thetas):
                        raise ValueError(f'Number of {field_name} images in scan {scan_number} '+
                                f'({num_image}) does not match the number of thetas '+
                                f'({len(thetas)})')
                    rotation_angles += thetas
                x_translations += num_image*[scan_info['ref_x']]
                z_translations += num_image*[scan_info['ref_z']]

        if include_raw_data:
            # Add image data to NXdetector
            nxinstrument.detector.image_key = image_keys
            nxinstrument.detector.sequence_number = sequence_numbers
            nxinstrument.detector.data = np.concatenate(image_stacks)

            # Add image data to NXsample
            nxsample.rotation_angle = rotation_angles
            nxsample.rotation_angle.attrs['units'] = 'degrees'
            nxsample.x_translation = x_translations
            nxsample.x_translation.attrs['units'] = 'mm'
            nxsample.z_translation = z_translations
            nxsample.z_translation.attrs['units'] = 'mm'

            # Add an NXdata to NXentry
            nxdata = NXdata()
            nxentry.data = nxdata
            nxdata.makelink(nxentry.instrument.detector.data, name='data')
            nxdata.makelink(nxentry.instrument.detector.image_key)
            nxdata.makelink(nxentry.sample.rotation_angle)
            nxdata.makelink(nxentry.sample.x_translation)
            nxdata.makelink(nxentry.sample.z_translation)
| 1064 | |
| 1065 | |
class TomoWorkflow(BaseModel):
    """Tomography workflow configuration holding one or more sample maps."""
    sample_maps: conlist(item_type=MapConfig, min_items=1) = [MapConfig.construct()]

    @classmethod
    def construct_from_nexus(cls, filename):
        """Build a workflow from a NeXus file.

        Every NXentry at the root of the file becomes one MapConfig. Other
        top-level objects are skipped, because they cannot describe a map.
        """
        nxroot = nxload(filename)
        sample_maps = []
        for _, nxentry in nxroot.items():
            if isinstance(nxentry, NXentry):
                sample_maps.append(MapConfig.construct_from_nxentry(nxentry))
        return cls(sample_maps=sample_maps)

    def cli(self):
        """Interactively configure the list of sample maps."""
        print('\n -- Configure a map -- ')
        self.set_list_attr_cli('sample_maps', 'sample map')

    def construct_nxfile(self, filename, mode='w-'):
        """Build an NXroot with one NXentry per sample map and save it.

        Parameters:
            filename: output NeXus file path.
            mode: file mode passed through to ``NXroot.save``.
        """
        nxroot = NXroot()
        t0 = time()
        for sample_map in self.sample_maps:
            logger.info('Start constructing the %s map.', sample_map.title)
            # Install the ScanParser class for this map's station
            import_scanparser(sample_map.station)
            sample_map.construct_nxentry(nxroot)
        logger.info('Constructed all sample maps in %.2f seconds.', time()-t0)
        logger.info('Start saving all sample maps to %s.', filename)
        nxroot.save(filename, mode=mode)

    def write_to_nexus(self, filename):
        """Write all sample maps to ``filename`` using file mode 'w'."""
        t0 = time()
        self.construct_nxfile(filename, mode='w')
        logger.info('Saved all sample maps to %s in %.2f seconds.', filename, time()-t0)
