diff detector.py @ 65:f31ef7bfb430 draft

"planemo upload for repository https://github.com/rolfverberg/galaxytools commit d55db09b45d0b542f966cef17892858bb55d94f7"
author rv43
date Thu, 18 Aug 2022 14:57:39 +0000
parents 15288e9746e0
children
line wrap: on
line diff
--- a/detector.py	Thu Aug 18 14:49:16 2022 +0000
+++ b/detector.py	Thu Aug 18 14:57:39 2022 +0000
@@ -4,13 +4,12 @@
 from functools import cache
 from copy import deepcopy
 
-from general import *
+from general import illegal_value, is_int, is_num, input_yesno
 
 #from hexrd.instrument import HEDMInstrument, PlanarDetector
 
 class DetectorConfig:
     def __init__(self, config_source):
-
         self._config_source = config_source
 
         if isinstance(self._config_source, ((str, bytes, os.PathLike, int))):
@@ -26,7 +25,8 @@
         self._valid = self._validate()
 
         if not self.valid:
-            logging.error(f'Cannot create a valid instance of {self.__class__.__name__} from {self._config_source}')
+            logging.error(f'Cannot create a valid instance of {self.__class__.__name__} '+
+                    f'from {self._config_source}')
 
     def __repr__(self):
         return(f'{self.__class__.__name__}({self._config_source.__repr__()})')
@@ -36,15 +36,18 @@
     @property
     def config_file(self):
         return(self._config_file)
+
     @property
     def config(self):
         return(deepcopy(self._config))
+
     @property
     def valid(self):
         return(self._valid)
 
     def load_config_file(self):
         raise(NotImplementedError)
+
     def validate(self):
         raise(NotImplementedError)
 
@@ -54,20 +57,24 @@
             return(False)
         else:
             return(self.load_config_file())
+
     def _validate(self):
         if not self.config:
             logging.error('A configuration must be loaded prior to calling Detector._validate')
             return(False)
         else:
             return(self.validate())
+
     def _write_to_file(self, out_file):
         out_file = os.path.abspath(out_file)
 
         current_config_valid = self.validate()
         if not current_config_valid:
-            write_invalid_config = input_yesno(s=f'This {self.__class__.__name__} is not currently valid. Write the configuration to {out_file} anyways?', default='no')
+            write_invalid_config = input_yesno(s=f'This {self.__class__.__name__} is currently '+
+                    f'invalid. Write the configuration to {out_file} anyways?', default='no')
             if not write_invalid_config:
-                logging.info(f'In accordance with user input, the invalid configuration will not be written to {out_file}')
+                logging.info('In accordance with user input, the invalid configuration will '+
+                        f'not be written to {out_file}')
                 return 
 
         if os.access(out_file, os.W_OK):
@@ -76,7 +83,8 @@
                 if overwrite:
                     self.write_to_file(out_file)
                 else:
-                    logging.info(f'In accordance with user input, {out_file} will not be overwritten')
+                    logging.info(f'In accordance with user input, {out_file} will not be '+
+                            'overwritten')
             else:
                 self.write_to_file(out_file)
         else:
@@ -91,21 +99,29 @@
         super().__init__(config_source)
 
     def load_config_file(self):
-        with open(self.config_file, 'r') as infile:
+        if not os.path.splitext(self._config_file)[1]:
+            if os.path.isfile(f'{self._config_file}.yml'):
+                self._config_file = f'{self._config_file}.yml'
+            if os.path.isfile(f'{self._config_file}.yaml'):
+                self._config_file = f'{self._config_file}.yaml'
+        if not os.path.isfile(self._config_file):
+            logging.error(f'Unable to load {self._config_file}')
+            return(False)
+        with open(self._config_file, 'r') as infile:
             config = yaml.safe_load(infile)
         if isinstance(config, dict):
             return(config)
         else:
-            logging.error(f'Unable to load {self.config_file} as a dictionary')
+            logging.error(f'Unable to load {self._config_file} as a dictionary')
             return(False)
 
     def validate(self):
         if not self._validate_yaml_pars:
-            logging.warning('There are no required parameters provided for this detector configuration.')
+            logging.warning('There are no required parameters provided for this detector '+
+                    'configuration')
             return(True)
 
         def validate_nested_pars(config, validate_yaml_par):
-            
             yaml_par_levels = validate_yaml_par.split(':')
             first_level_par = yaml_par_levels[0]
             try:
@@ -122,7 +138,8 @@
             except:
                 return(False)
 
-        pars_missing = [p for p in self._validate_yaml_pars if not validate_nested_pars(self.config, p)]
+        pars_missing = [p for p in self._validate_yaml_pars 
+                if not validate_nested_pars(self.config, p)]
         if len(pars_missing) > 0:
             logging.error(f'Missing item(s) in configuration: {", ".join(pars_missing)}')
             return(False)
@@ -149,7 +166,7 @@
         lens_magnification = self.config.get('lens_magnification')
         if not isinstance(lens_magnification, (int, float)) or lens_magnification <= 0.:
             illegal_value(lens_magnification, 'lens_magnification', 'detector file')
-            logging.warning('Using default lens_magnification value of 1.0.')
+            logging.warning('Using default lens_magnification value of 1.0')
             return(1.0)
         else:
             return(lens_magnification)
@@ -191,7 +208,6 @@
         if not is_int(num_columns, 1):
             illegal_value(num_columns, 'columns', 'detector file')
             return(None)
-
         return(num_rows, num_columns)
 
 
@@ -231,7 +247,8 @@
             
     @property
     def bin_energies(self):
-        return(self.slope * np.linspace(0, self.max_E, self.num_bins, endpoint=False) + self.intercept)
+        return(self.slope * np.linspace(0, self.max_E, self.num_bins, endpoint=False) + 
+                self.intercept)
 
     @property
     def tth_angle(self):