Mercurial > repos > rv43 > chess_tomo
comparison general.py @ 3:fc38431f257f draft
planemo upload for repository https://github.com/rolfverberg/galaxytools commit f8c4bdb31c20c468045ad5e6eb255a293244bc6c
| author | rv43 |
|---|---|
| date | Mon, 20 Mar 2023 18:30:26 +0000 |
| parents | |
| children | |
comparison
equal
deleted
inserted
replaced
| 2:3fd7398a3a7f | 3:fc38431f257f |
|---|---|
| 1 #!/usr/bin/env python3 | |
| 2 | |
| 3 #FIX write a function that returns a list of peak indices for a given plot | |
| 4 #FIX use raise_error concept on more functions to optionally raise an error | |
| 5 | |
| 6 # -*- coding: utf-8 -*- | |
| 7 """ | |
| 8 Created on Mon Dec 6 15:36:22 2021 | |
| 9 | |
| 10 @author: rv43 | |
| 11 """ | |
| 12 | |
| 13 import logging | |
| 14 logger=logging.getLogger(__name__) | |
| 15 | |
| 16 import os | |
| 17 import sys | |
| 18 import re | |
| 19 try: | |
| 20 from yaml import safe_load, safe_dump | |
| 21 except: | |
| 22 pass | |
| 23 try: | |
| 24 import h5py | |
| 25 except: | |
| 26 pass | |
| 27 import numpy as np | |
| 28 try: | |
| 29 import matplotlib.pyplot as plt | |
| 30 import matplotlib.lines as mlines | |
| 31 from matplotlib import transforms | |
| 32 from matplotlib.widgets import Button | |
| 33 except: | |
| 34 pass | |
| 35 | |
| 36 from ast import literal_eval | |
| 37 try: | |
| 38 from asteval import Interpreter, get_ast_names | |
| 39 except: | |
| 40 pass | |
| 41 from copy import deepcopy | |
| 42 try: | |
| 43 from sympy import diff, simplify | |
| 44 except: | |
| 45 pass | |
| 46 from time import time | |
| 47 | |
| 48 | |
def depth_list(L):
    """Return the nesting depth of list L, or False if L is not a list.

    A flat list has depth 1, a list of lists depth 2, and so on. An empty list
    has depth 1 (max() gets default=0 instead of raising on an empty sequence).
    """
    return(isinstance(L, list) and max(map(depth_list, L), default=0)+1)
def depth_tuple(T):
    """Return the nesting depth of tuple T, or False if T is not a tuple.

    A flat tuple has depth 1, a tuple of tuples depth 2, and so on. An empty
    tuple has depth 1 (max() gets default=0 instead of raising on an empty
    sequence).
    """
    return(isinstance(T, tuple) and max(map(depth_tuple, T), default=0)+1)
def unwrap_tuple(T):
    """Strip redundant single-element tuple wrappers from T.

    As long as T is a 1-tuple whose nesting depth exceeds 1, T is replaced by
    its only element, e.g. (((1, 2),),) -> (1, 2). Anything else is returned
    unchanged.
    """
    while depth_tuple(T) > 1 and len(T) == 1:
        (T,) = T
    return T
| 55 | |
def illegal_value(value, name, location=None, raise_error=False, log=True):
    """Report an illegal value.

    value: the offending value (its str() and type appear in the message)
    name: name of the offending variable; omitted from the message unless a str
    location: where the problem was found; omitted unless a str
    log: log the message at error level
    raise_error: raise ValueError with the message
    """
    where = f'in {location} ' if isinstance(location, str) else ''
    if isinstance(name, str):
        error_msg = f'Illegal value for {name} {where}({value}, {type(value)})'
    else:
        error_msg = f'Illegal value {where}({value}, {type(value)})'
    if log:
        logger.error(error_msg)
    if raise_error:
        raise ValueError(error_msg)
| 69 | |
def illegal_combination(value1, name1, value2, name2, location=None, raise_error=False,
        log=True):
    """Report a mutually incompatible pair of values.

    The message names both variables when name1 is a str and always lists
    both values with their types.
    location: where the problem was found; omitted unless a str
    log: log the message at error level
    raise_error: raise ValueError with the message
    """
    where = '' if not isinstance(location, str) else f'in {location} '
    details = f'({value1}, {type(value1)} and {value2}, {type(value2)})'
    if isinstance(name1, str):
        head = f'Illegal combination for {name1} and {name2} {where}'
    else:
        head = f'Illegal combination {where}'
    error_msg = head + details
    if log:
        logger.error(error_msg)
    if raise_error:
        raise ValueError(error_msg)
| 86 | |
def test_ge_gt_le_lt(ge, gt, le, lt, func, location=None, raise_error=False, log=True):
    """Check individual and mutual validity of the ge, gt, le, lt qualifiers.

    The qualifiers describe a range: ge <= v or gt < v for the lower end and
    v <= le or v < lt for the upper end. ge and gt are mutually exclusive, as
    are le and lt, and the lower bound must lie below the upper bound (ge may
    equal le).
    func: is_int or is_num, used to test that each given qualifier is an int or
        a number
    location, raise_error, log: passed on to illegal_value/illegal_combination
        to report the first problem found
    Return: True upon success or False when invalid or mutually exclusive
    """
    if ge is None and gt is None and le is None and lt is None:
        return(True)
    # The type probes run silently (log=False) so each problem is reported
    # exactly once, by illegal_value below, and only if the caller asked to log.
    if ge is not None:
        if not func(ge, log=False):
            illegal_value(ge, 'ge', location, raise_error, log)
            return(False)
        if gt is not None:
            illegal_combination(ge, 'ge', gt, 'gt', location, raise_error, log)
            return(False)
    elif gt is not None and not func(gt, log=False):
        illegal_value(gt, 'gt', location, raise_error, log)
        return(False)
    if le is not None:
        if not func(le, log=False):
            illegal_value(le, 'le', location, raise_error, log)
            return(False)
        if lt is not None:
            illegal_combination(le, 'le', lt, 'lt', location, raise_error, log)
            return(False)
    elif lt is not None and not func(lt, log=False):
        illegal_value(lt, 'lt', location, raise_error, log)
        return(False)
    # Mutual consistency of the lower and upper bound.
    if ge is not None:
        if le is not None and ge > le:
            illegal_combination(ge, 'ge', le, 'le', location, raise_error, log)
            return(False)
        elif lt is not None and ge >= lt:
            illegal_combination(ge, 'ge', lt, 'lt', location, raise_error, log)
            return(False)
    elif gt is not None:
        if le is not None and gt >= le:
            illegal_combination(gt, 'gt', le, 'le', location, raise_error, log)
            return(False)
        elif lt is not None and gt >= lt:
            illegal_combination(gt, 'gt', lt, 'lt', location, raise_error, log)
            return(False)
    return(True)
| 129 | |
def range_string_ge_gt_le_lt(ge=None, gt=None, le=None, lt=None):
    """Return a human-readable range string for the ge, gt, le, lt qualifiers.

    One-sided ranges read '>= a', '> a', '<= b' or '< b'; two-sided ranges use
    interval notation such as '[a, b)'. ge takes precedence over gt and le over
    lt. The inputs are not validated; do that as needed before calling.
    """
    has_lower = ge is not None or gt is not None
    has_upper = le is not None or lt is not None
    pieces = []
    if ge is not None:
        pieces.append(f'[{ge}, ' if has_upper else f'>= {ge}')
    elif gt is not None:
        pieces.append(f'({gt}, ' if has_upper else f'> {gt}')
    if le is not None:
        pieces.append(f'{le}]' if has_lower else f'<= {le}')
    elif lt is not None:
        pieces.append(f'{lt})' if has_lower else f'< {lt}')
    return ''.join(pieces)
| 156 | |
def is_int(v, ge=None, gt=None, le=None, lt=None, raise_error=False, log=True):
    """Test whether v is an int within the range set by the qualifiers.

    Any combination of ge <= v, gt < v, v <= le and v < lt may be given.
    Return: True if yes or False if no
    """
    return _is_int_or_num(v, 'int', ge=ge, gt=gt, le=le, lt=lt, raise_error=raise_error,
            log=log)
| 162 | |
def is_num(v, ge=None, gt=None, le=None, lt=None, raise_error=False, log=True):
    """Test whether v is a number (int or float) within the range set by the qualifiers.

    Any combination of ge <= v, gt < v, v <= le and v < lt may be given.
    Return: True if yes or False if no
    """
    return _is_int_or_num(v, 'num', ge=ge, gt=gt, le=le, lt=lt, raise_error=raise_error,
            log=log)
| 168 | |
def _is_int_or_num(v, type_str, ge=None, gt=None, le=None, lt=None, raise_error=False,
        log=True):
    """Shared implementation of is_int and is_num.

    type_str: 'int' requires an int; 'num' accepts an int or a float
    ge, gt, le, lt: optional range qualifiers, validated by test_ge_gt_le_lt
    Only the first violated bound (in ge, gt, le, lt order) is reported: logged
    when log is True and raised as ValueError when raise_error is True.
    Return: True if v has the right type and lies in range, False otherwise
    """
    if type_str == 'int':
        allowed, checker = int, is_int
    elif type_str == 'num':
        allowed, checker = (int, float), is_num
    else:
        illegal_value(type_str, 'type_str', '_is_int_or_num', raise_error, log)
        return False
    if not isinstance(v, allowed):
        illegal_value(v, 'v', '_is_int_or_num', raise_error, log)
        return False
    if not test_ge_gt_le_lt(ge, gt, le, lt, checker, '_is_int_or_num', raise_error, log):
        return False
    if ge is not None and v < ge:
        error_msg = f'Value {v} out of range: {v} !>= {ge}'
    elif gt is not None and v <= gt:
        error_msg = f'Value {v} out of range: {v} !> {gt}'
    elif le is not None and v > le:
        error_msg = f'Value {v} out of range: {v} !<= {le}'
    elif lt is not None and v >= lt:
        error_msg = f'Value {v} out of range: {v} !< {lt}'
    else:
        return True
    if log:
        logger.error(error_msg)
    if raise_error:
        raise ValueError(error_msg)
    return False
| 208 | |
def is_int_pair(v, ge=None, gt=None, le=None, lt=None, raise_error=False, log=True):
    """Test whether v is a pair of ints, each within the range set by the qualifiers.

    A qualifier may be a scalar, applied to both elements (ge <= v[i] <= le,
    gt < v[i] < lt, ...), or a pair applied element-wise (ge[i] <= v[i] ...).
    Return: True if yes or False if no
    """
    return _is_int_or_num_pair(v, 'int', ge=ge, gt=gt, le=le, lt=lt,
            raise_error=raise_error, log=log)
| 215 | |
def is_num_pair(v, ge=None, gt=None, le=None, lt=None, raise_error=False, log=True):
    """Test whether v is a pair of numbers, each within the range set by the qualifiers.

    A qualifier may be a scalar, applied to both elements (ge <= v[i] <= le,
    gt < v[i] < lt, ...), or a pair applied element-wise (ge[i] <= v[i] ...).
    Return: True if yes or False if no
    """
    return _is_int_or_num_pair(v, 'num', ge=ge, gt=gt, le=le, lt=lt,
            raise_error=raise_error, log=log)
| 222 | |
def _is_int_or_num_pair(v, type_str, ge=None, gt=None, le=None, lt=None, raise_error=False,
        log=True):
    """Shared implementation of is_int_pair and is_num_pair.

    v: tuple or list of exactly two ints (type_str 'int') or ints/floats ('num')
    ge, gt, le, lt: each None, a scalar applied to both elements, or a pair
        applied element-wise
    Return: True if v is valid and both elements are in range, False otherwise
    """
    if type_str == 'int':
        scalar_types, func = int, is_int
    elif type_str == 'num':
        scalar_types, func = (int, float), is_num
    else:
        illegal_value(type_str, 'type_str', '_is_int_or_num_pair', raise_error, log)
        return(False)
    if not (isinstance(v, (tuple, list)) and len(v) == 2 and
            all(isinstance(x, scalar_types) for x in v)):
        illegal_value(v, 'v', '_is_int_or_num_pair', raise_error, log)
        return(False)
    if ge is None and gt is None and le is None and lt is None:
        return(True)
    bounds = []
    for bound in (ge, gt, le, lt):
        # The scalar probe must be silent: a pair is a legitimate bound, so a
        # failed scalar check is not an error by itself. Pairs are validated
        # (and reported, per log/raise_error) by the recursive call.
        if bound is None or func(bound, log=False):
            bounds.append(2*[bound])
        elif _is_int_or_num_pair(bound, type_str, raise_error=raise_error, log=log):
            bounds.append(bound)
        else:
            return(False)
    ge, gt, le, lt = bounds
    return(all(func(v[i], ge[i], gt[i], le[i], lt[i], raise_error, log) for i in (0, 1)))
| 262 | |
def is_int_series(l, ge=None, gt=None, le=None, lt=None, raise_error=False, log=True):
    """Test whether l is a tuple or list of ints, each within the range set by the
    qualifiers (ge <= l[i] <= le, gt < l[i] < lt, or some combination).

    Checking stops at the first invalid element.
    Return: True if yes or False if no
    """
    if not test_ge_gt_le_lt(ge, gt, le, lt, is_int, 'is_int_series', raise_error, log):
        return False
    if not isinstance(l, (tuple, list)):
        illegal_value(l, 'l', 'is_int_series', raise_error, log)
        return False
    return all(is_int(item, ge, gt, le, lt, raise_error, log) for item in l)
| 275 | |
def is_num_series(l, ge=None, gt=None, le=None, lt=None, raise_error=False, log=True):
    """Value is a tuple or list of numbers, each in range ge <= l[i] <= le or
    gt < l[i] < lt or some combination.

    The qualifiers are validated as numbers (is_num), so float bounds are allowed.
    Checking stops at the first invalid element.
    Return: True if yes or False if no
    """
    if not test_ge_gt_le_lt(ge, gt, le, lt, is_num, 'is_num_series', raise_error, log):
        return(False)
    if not isinstance(l, (tuple, list)):
        illegal_value(l, 'l', 'is_num_series', raise_error, log)
        return(False)
    return(all(is_num(v, ge, gt, le, lt, raise_error, log) for v in l))
| 288 | |
def is_str_series(l, raise_error=False, log=True):
    """Test whether l is a tuple or list whose items are all strings.

    Invalid input is reported through illegal_value.
    Return: True if yes or False if no
    """
    if isinstance(l, (tuple, list)) and all(isinstance(item, str) for item in l):
        return True
    illegal_value(l, 'l', 'is_str_series', raise_error, log)
    return False
| 297 | |
def is_dict_series(l, raise_error=False, log=True):
    """Test whether l is a tuple or list whose items are all dictionaries.

    Invalid input is reported through illegal_value.
    Return: True if yes or False if no
    """
    if isinstance(l, (tuple, list)) and all(isinstance(item, dict) for item in l):
        return True
    illegal_value(l, 'l', 'is_dict_series', raise_error, log)
    return False
| 306 | |
def is_dict_nums(l, raise_error=False, log=True):
    """Test whether l is a dictionary whose values are all single numbers.

    Invalid input is reported through illegal_value.
    Return: True if yes or False if no
    """
    if isinstance(l, dict) and all(is_num(val, log=False) for val in l.values()):
        return True
    illegal_value(l, 'l', 'is_dict_nums', raise_error, log)
    return False
| 315 | |
def is_dict_strings(l, raise_error=False, log=True):
    """Test whether l is a dictionary whose values are all strings.

    Invalid input is reported through illegal_value.
    Return: True if yes or False if no
    """
    if isinstance(l, dict) and all(isinstance(val, str) for val in l.values()):
        return True
    illegal_value(l, 'l', 'is_dict_strings', raise_error, log)
    return False
| 324 | |
def is_index(v, ge=0, lt=None, raise_error=False, log=True):
    """Test whether v is an array index in the range ge <= v < lt.

    NOTE lt IS NOT included! An int lt that does not exceed ge is reported as
    an illegal combination.
    Return: True if yes or False if no
    """
    if isinstance(lt, int) and lt <= ge:
        illegal_combination(ge, 'ge', lt, 'lt', 'is_index', raise_error, log)
        return False
    return is_int(v, ge=ge, lt=lt, raise_error=raise_error, log=log)
| 334 | |
def is_index_range(v, ge=0, le=None, lt=None, raise_error=False, log=True):
    """Value is an array index range in range ge <= v[0] <= v[1] <= le or ge <= v[0] <= v[1] < lt.
    NOTE le IS included!

    v: pair (tuple or list) of ints
    ge: lower bound for v[0]; None means no lower bound
    le, lt: optional, mutually exclusive upper bound for v[1]
    Return: True if yes or False if no (raises ValueError instead when raise_error is True)
    """
    if not is_int_pair(v, raise_error=raise_error, log=log):
        return(False)
    if not test_ge_gt_le_lt(ge, None, le, lt, is_int, 'is_index_range', raise_error, log):
        return(False)
    # ge=None passes test_ge_gt_le_lt, but comparing against None raises TypeError.
    ordered = v[0] <= v[1] if ge is None else ge <= v[0] <= v[1]
    if not ordered or (le is not None and v[1] > le) or (lt is not None and v[1] >= lt):
        lower = '' if ge is None else f'{ge} <= '
        if le is not None:
            upper = f' <= {le}'
        elif lt is not None:
            upper = f' < {lt}'
        else:
            upper = ''
        error_msg = f'Value {v} out of range: !({lower}{v[0]} <= {v[1]}{upper})'
        if log:
            logger.error(error_msg)
        if raise_error:
            raise ValueError(error_msg)
        return(False)
    return(True)
| 354 | |
def index_nearest(a, value):
    """Return the index of the element of the 1D array-like a closest to value.

    For a nonzero value, an exact tie (value midway between two elements)
    resolves toward the larger element: value is first nudged upward by
    |value| times the machine epsilon. Scaling by |value| rather than
    multiplying by (1+eps) keeps the nudge upward for negative values too.
    Raises ValueError if a has more than one dimension.
    """
    a = np.asarray(a)
    if a.ndim > 1:
        raise ValueError(f'Invalid array dimension for parameter a ({a.ndim}, {a})')
    # Round up for .5 (for positive values identical to value*(1+eps))
    value = value + abs(value)*sys.float_info.epsilon
    return(int(np.argmin(np.abs(a-value))))
| 362 | |
def index_nearest_low(a, value):
    """Return the index of the element of a nearest to value, stepping down one.

    Starts from the index of the element closest to value and moves one
    position down when that element lies above value (unless already at
    index 0). Presumably intended for a sorted in ascending order — confirm
    with callers. Raises ValueError if a has more than one dimension.
    """
    arr = np.asarray(a)
    if arr.ndim > 1:
        raise ValueError(f'Invalid array dimension for parameter a ({arr.ndim}, {arr})')
    nearest = int(np.argmin(np.abs(arr-value)))
    step_down = value < arr[nearest] and nearest > 0
    return nearest - 1 if step_down else nearest
| 371 | |
def index_nearest_upp(a, value):
    """Return the index of the element of a nearest to value, stepping up one.

    Starts from the index of the element closest to value and moves one
    position up when that element lies below value (unless already at the
    last index). Presumably intended for a sorted in ascending order — confirm
    with callers. Raises ValueError if a has more than one dimension.
    """
    arr = np.asarray(a)
    if arr.ndim > 1:
        raise ValueError(f'Invalid array dimension for parameter a ({arr.ndim}, {arr})')
    nearest = int(np.argmin(np.abs(arr-value)))
    step_up = value > arr[nearest] and nearest < arr.size-1
    return nearest + 1 if step_up else nearest
| 380 | |
def round_to_n(x, n=1):
    """Round x to n significant figures, returning a value of type(x).

    Zero is returned as the int 0.
    """
    if x == 0.0:
        return 0
    exponent = int(np.floor(np.log10(abs(x))))
    return type(x)(round(x, n-1-exponent))
| 386 | |
def round_up_to_n(x, n=1):
    """Round x away from zero to n significant figures, returning type(x).

    Zero is returned unchanged; the ratio test below would otherwise divide by
    the 0 that round_to_n returns for it.
    """
    if x == 0:
        return(x)
    xr = round_to_n(x, n)
    # round_to_n went toward zero: add one unit in the n-th significant digit.
    if abs(x/xr) > 1.0:
        xr += np.sign(x)*10**(np.floor(np.log10(abs(x)))+1-n)
    return(type(x)(xr))
| 392 | |
def trunc_to_n(x, n=1):
    """Truncate x toward zero to n significant figures, returning type(x).

    Zero is returned unchanged; the ratio test below would otherwise divide by
    zero.
    """
    if x == 0:
        return(x)
    xr = round_to_n(x, n)
    # round_to_n went away from zero: remove one unit in the n-th significant digit.
    if abs(xr/x) > 1.0:
        xr -= np.sign(x)*10**(np.floor(np.log10(abs(x)))+1-n)
    return(type(x)(xr))
| 398 | |
def almost_equal(a, b, sig_figs):
    """Return True if a and b agree to sig_figs significant figures.

    The difference a-b is rounded to sig_figs significant figures and compared
    against the fixed threshold 10**(1-sig_figs), so this is an absolute (not a
    relative) test.
    Raises ValueError if a or b is not an int or float.
    """
    # Validate silently; the ValueError below carries the full message.
    if not (is_num(a, log=False) and is_num(b, log=False)):
        raise ValueError(f'Invalid value for a or b in almost_equal (a: {a}, {type(a)}, '+
                f'b: {b}, {type(b)})')
    return(abs(round_to_n(a-b, sig_figs)) < pow(10, -sig_figs+1))
| 406 | |
def string_to_list(s, split_on_dash=True, remove_duplicates=True, sort=True):
    """Return a list of numbers by splitting/expanding a string on any combination of
    commas, whitespaces, or dashes (when split_on_dash=True)
    e.g: '1, 3, 5-8, 12 ' -> [1, 3, 5, 6, 7, 8, 12]

    With split_on_dash, 'a-b' expands to the integers a..b inclusive and requires
    b > a (so negative numbers cannot be entered in that mode).
    remove_duplicates: drop repeated values, keeping the first occurrence
    sort: sort the result in ascending order
    Return: the list ([] for an empty or all-whitespace string), or None if s is
        not a string or cannot be parsed
    """
    if not isinstance(s, str):
        illegal_value(s, 's', 'string_to_list')
        return(None)
    s = s.strip()
    if not s:
        return([])
    # Raw strings: '\s' in a plain string literal is an invalid escape sequence.
    items = re.split(r'\s+,\s+|\s+,|,\s+|\s+|,', s)
    try:
        if split_on_dash:
            l = []
            for item in items:
                l2 = [literal_eval(x) for x in re.split(r'\s+-\s+|\s+-|-\s+|\s+|-', item)]
                if len(l2) == 1:
                    l += l2
                elif len(l2) == 2 and l2[1] > l2[0]:
                    l += list(range(l2[0], l2[1]+1))
                else:
                    raise ValueError
        else:
            l = [literal_eval(x) for x in items]
        if remove_duplicates:
            l = list(dict.fromkeys(l))
        if sort:
            l = sorted(l)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        return(None)
    return(l)
| 441 | |
def get_trailing_int(string):
    """Return the integer spelled by the digits at the end of string, or None."""
    match = re.search(r'\d+$', string)
    return None if match is None else int(match.group())
| 449 | |
def input_int(s=None, ge=None, gt=None, le=None, lt=None, default=None, inset=None,
        raise_error=False, log=True):
    """Interactively prompt for an int; see _input_int_or_num for the parameters."""
    return _input_int_or_num('int', s=s, ge=ge, gt=gt, le=le, lt=lt, default=default,
            inset=inset, raise_error=raise_error, log=log)
| 453 | |
def input_num(s=None, ge=None, gt=None, le=None, lt=None, default=None, raise_error=False,
        log=True):
    """Interactively prompt for a number; see _input_int_or_num for the parameters."""
    return _input_int_or_num('num', s=s, ge=ge, gt=gt, le=le, lt=lt, default=default,
            inset=None, raise_error=raise_error, log=log)
| 457 | |
def _input_int_or_num(type_str, s=None, ge=None, gt=None, le=None, lt=None, default=None,
        inset=None, raise_error=False, log=True):
    """Prompt until the user enters a valid int (type_str 'int') or number ('num').

    s: prompt text; a generic prompt is used when None
    ge, gt, le, lt: optional range qualifiers the entered value must satisfy
    default: value used for an empty input line; must satisfy the qualifiers
    inset: optional tuple/list of ints the entered value must belong to
    raise_error, log: how invalid arguments and unexpected errors are reported
    Return: the entered (or default) value, or None when the arguments are invalid
        or reading input fails unexpectedly (e.g. EOFError)
    """
    if type_str == 'int':
        func = is_int
    elif type_str == 'num':
        func = is_num
    else:
        illegal_value(type_str, 'type_str', '_input_int_or_num', raise_error, log)
        return(None)
    if not test_ge_gt_le_lt(ge, gt, le, lt, func, '_input_int_or_num', raise_error, log):
        return(None)
    if default is not None:
        if not _is_int_or_num(default, type_str, raise_error=raise_error, log=log):
            return(None)
        if ge is not None and default < ge:
            illegal_combination(ge, 'ge', default, 'default', '_input_int_or_num', raise_error,
                    log)
            return(None)
        if gt is not None and default <= gt:
            illegal_combination(gt, 'gt', default, 'default', '_input_int_or_num', raise_error,
                    log)
            return(None)
        if le is not None and default > le:
            illegal_combination(le, 'le', default, 'default', '_input_int_or_num', raise_error,
                    log)
            return(None)
        if lt is not None and default >= lt:
            illegal_combination(lt, 'lt', default, 'default', '_input_int_or_num', raise_error,
                    log)
            return(None)
        default_string = f' [{default}]'
    else:
        default_string = ''
    if inset is not None:
        if not isinstance(inset, (tuple, list)) or not all(isinstance(i, int) for i in inset):
            illegal_value(inset, 'inset', '_input_int_or_num', raise_error, log)
            return(None)
    v_range = f'{range_string_ge_gt_le_lt(ge, gt, le, lt)}'
    if len(v_range):
        v_range = f' {v_range}'
    if s is None:
        if type_str == 'int':
            prompt = f'Enter an integer{v_range}{default_string}: '
        else:
            prompt = f'Enter a number{v_range}{default_string}: '
    else:
        prompt = f'{s}{v_range}{default_string}: '
    # Re-prompt in a loop (not recursively) until a valid value is entered.
    while True:
        print(prompt)
        try:
            i = input()
            if isinstance(i, str) and not len(i):
                v = default
                print(f'{v}')
            else:
                v = literal_eval(i)
            if inset and v not in inset:
                raise ValueError(f'{v} not part of the set {inset}')
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            v = None
        except Exception as exc:
            # E.g. EOFError on a closed stdin: retrying cannot help, so give up.
            # KeyboardInterrupt is deliberately not caught.
            if log:
                logger.error('Unexpected error')
            if raise_error:
                raise ValueError('Unexpected error') from exc
            return(None)
        if _is_int_or_num(v, type_str, ge, gt, le, lt):
            return(v)
| 525 | |
def input_int_list(s=None, ge=None, le=None, split_on_dash=True, remove_duplicates=True,
        sort=True, raise_error=False, log=True):
    """Prompt the user for a list of integers.

    The entered string is split on any combination of commas, whitespace, or
    dashes (when split_on_dash is True), where 'a-b' expands to a..b,
    e.g: '1 3,5-8 , 12 ' -> [1, 3, 5, 6, 7, 8, 12]
    ge, le: optional bounds every value must satisfy
    remove_duplicates: removes duplicates if True (may also change the order)
    sort: sort in ascending order if True
    Return: the list, or None upon invalid arguments
    """
    return _input_int_or_num_list('int', s=s, ge=ge, le=le, split_on_dash=split_on_dash,
            remove_duplicates=remove_duplicates, sort=sort, raise_error=raise_error, log=log)
| 537 | |
def input_num_list(s=None, ge=None, le=None, remove_duplicates=True, sort=True, raise_error=False,
        log=True):
    """Prompt the user for a list of numbers.

    The entered string is split on any combination of commas or whitespace
    (dashes are not ranges here), e.g: '1.0, 3, 5.8, 12 ' -> [1.0, 3, 5.8, 12]
    ge, le: optional bounds every value must satisfy
    remove_duplicates: removes duplicates if True (may also change the order)
    sort: sort in ascending order if True
    Return: the list, or None upon invalid arguments
    """
    return _input_int_or_num_list('num', s=s, ge=ge, le=le, split_on_dash=False,
            remove_duplicates=remove_duplicates, sort=sort, raise_error=raise_error, log=log)
| 549 | |
def _input_int_or_num_list(type_str, s=None, ge=None, le=None, split_on_dash=True,
        remove_duplicates=True, sort=True, raise_error=False, log=True):
    """Prompt until the user enters a valid list of ints ('int') or numbers ('num').

    s: prompt text; a generic prompt is used when None
    ge, le: optional bounds every value must satisfy
    split_on_dash, remove_duplicates, sort: passed on to string_to_list
    Return: the entered list, or None upon invalid arguments
    """
    #FIX do we want a limit on max dimension?
    if type_str == 'int':
        func, kind = is_int, 'integers'
    elif type_str == 'num':
        func, kind = is_num, 'numbers'
    else:
        illegal_value(type_str, 'type_str', '_input_int_or_num_list', raise_error, log)
        return(None)
    if not test_ge_gt_le_lt(ge, None, le, None, func, '_input_int_or_num_list', raise_error,
            log):
        return(None)
    v_range = f'{range_string_ge_gt_le_lt(ge=ge, le=le)}'
    if len(v_range):
        v_range = f' (each value in {v_range})'
    if s is None:
        prompt = f'Enter a series of {kind}{v_range}: '
    else:
        prompt = f'{s}{v_range}: '
    if split_on_dash:
        hint = ('Invalid input: enter a valid set of dash/comma/whitespace separated '+
                f'{kind} e.g. 1 3,5-8 , 12')
    elif type_str == 'int':
        hint = ('Invalid input: enter a valid set of comma/whitespace separated integers '+
                'e.g. 1 3,5 8 , 12')
    else:
        hint = ('Invalid input: enter a valid set of comma/whitespace separated numbers '+
                'e.g. 1.0 3,5.8 , 12')
    # Re-prompt in a loop (not recursively) until a valid list is entered.
    while True:
        print(prompt)
        try:
            l = string_to_list(input(), split_on_dash, remove_duplicates, sort)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            l = None
        except Exception:
            print('Unexpected error')
            raise
        if isinstance(l, list) and all(_is_int_or_num(v, type_str, ge=ge, le=le) for v in l):
            return(l)
        print(hint)
| 589 | |
def input_yesno(s=None, default=None):
    """Prompt until the user answers yes or no.

    Any non-empty prefix of 'yes' or 'no' is accepted, case-insensitively and
    ignoring surrounding whitespace ('y', 'Ye', 'NO', ...).
    s: prompt text; a generic prompt is used when None
    default: 'yes' or 'no' (or a prefix) used when the user enters an empty line
    Return: True for yes, False for no, or None if default is invalid
    """
    def _match(answer):
        # 'y'/'n' for a non-empty prefix of 'yes'/'no', else None. A plain
        # substring test would also accept '', 'e', 's' and 'o'.
        answer = answer.strip().lower()
        if answer and 'yes'.startswith(answer):
            return('y')
        if answer and 'no'.startswith(answer):
            return('n')
        return(None)

    if default is not None:
        if not isinstance(default, str) or _match(default) is None:
            illegal_value(default, 'default', 'input_yesno')
            return(None)
        default = _match(default)
        default_string = f' [{default}]'
    else:
        default_string = ''
    if s is None:
        prompt = f'Enter yes or no{default_string}: '
    else:
        prompt = f'{s}{default_string}: '
    # Re-prompt in a loop (not recursively) until a valid answer is entered.
    while True:
        print(prompt)
        i = input()
        if isinstance(i, str) and not len(i):
            i = default
            print(f'{i}')
        answer = None if i is None else _match(i)
        if answer == 'y':
            return(True)
        if answer == 'n':
            return(False)
        print('Invalid input, enter yes or no')
| 621 | |
def input_menu(items, default=None, header=None):
    """Print a numbered menu of items and prompt until a valid choice is entered.

    items: non-empty tuple or list of strings, numbered from 1 in the menu
    default: optional item (must be in items) chosen on an empty input line
    header: optional menu title; a generic one is used when None. It is kept on
        every re-prompt.
    Return: the 0-based index of the chosen item, or None if items or default
        is invalid
    """
    if (not isinstance(items, (tuple, list)) or not items or
            not all(isinstance(i, str) for i in items)):
        # An empty menu could never be answered, so it is rejected as well.
        illegal_value(items, 'items', 'input_menu')
        return(None)
    if default is not None:
        if not (isinstance(default, str) and default in items):
            logger.error(f'Invalid value for default ({default}), must be in {items}')
            return(None)
        default_string = f' [{items.index(default)+1}]'
    else:
        default_string = ''
    # Re-prompt in a loop (not recursively) until a valid choice is entered.
    while True:
        if header is None:
            print(f'Choose one of the following items (1, {len(items)}){default_string}:')
        else:
            print(f'{header} (1, {len(items)}){default_string}:')
        for i, item in enumerate(items):
            print(f' {i+1}: {item}')
        try:
            choice = input()
            if isinstance(choice, str) and not len(choice):
                # items.index(None) raises ValueError when there is no default.
                choice = items.index(default)
                print(f'{choice+1}')
            else:
                choice = literal_eval(choice)
                if isinstance(choice, int) and 1 <= choice <= len(items):
                    choice -= 1
                else:
                    raise ValueError
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            choice = None
        except Exception:
            print('Unexpected error')
            raise
        if choice is not None:
            return(choice)
        print(f'Invalid choice, enter a number between 1 and {len(items)}')
| 660 | |
def assert_no_duplicates_in_list_of_dicts(l: list, raise_error=False) -> list:
    """Return l if it is a list of dicts in which no two dicts are equal.

    Otherwise returns None after reporting the problem: invalid input through
    illegal_value, duplicates by raising ValueError (raise_error=True) or by
    logging an error.
    """
    if not isinstance(l, list):
        illegal_value(l, 'l', 'assert_no_duplicates_in_list_of_dicts', raise_error)
        return(None)
    if not all(isinstance(d, dict) for d in l):
        illegal_value(l, 'l', 'assert_no_duplicates_in_list_of_dicts', raise_error)
        return(None)
    try:
        # Fast path: with hashable values, two dicts are equal iff their item
        # sets are; frozenset also avoids sorting keys of mixed types.
        has_duplicates = len({frozenset(d.items()) for d in l}) != len(l)
    except TypeError:
        # Unhashable values (e.g. lists): fall back to pairwise comparison.
        has_duplicates = any(d == other for i, d in enumerate(l) for other in l[i+1:])
    if has_duplicates:
        if raise_error:
            raise ValueError(f'Duplicate items found in {l}')
        logger.error(f'Duplicate items found in {l}')
        return(None)
    return(l)
| 676 | |
def assert_no_duplicate_key_in_list_of_dicts(l: list, key: str, raise_error=False) -> list:
    """Return l if every dict in it has a non-None value for key and no two share it.

    Otherwise returns None after reporting the problem: invalid input through
    illegal_value, duplicate/missing keys by raising ValueError
    (raise_error=True) or by logging an error. The key values must be hashable.
    """
    if not isinstance(key, str):
        illegal_value(key, 'key', 'assert_no_duplicate_key_in_list_of_dicts', raise_error)
        return(None)
    if not isinstance(l, list):
        illegal_value(l, 'l', 'assert_no_duplicate_key_in_list_of_dicts', raise_error)
        return(None)
    if not all(isinstance(d, dict) for d in l):
        illegal_value(l, 'l', 'assert_no_duplicate_key_in_list_of_dicts', raise_error)
        return(None)
    keys = [d.get(key, None) for d in l]
    # A None value counts as missing.
    if None in keys or len(set(keys)) != len(l):
        if raise_error:
            raise ValueError(f'Duplicate or missing key ({key}) found in {l}')
        logger.error(f'Duplicate or missing key ({key}) found in {l}')
        return(None)
    return(l)
| 696 | |
def assert_no_duplicate_attr_in_list_of_objs(l: list, attr: str, raise_error=False) -> list:
    """Return l if every object in it has a non-None attribute attr and no two share it.

    Otherwise returns None after reporting the problem: invalid input through
    illegal_value, duplicate/missing attributes by raising ValueError
    (raise_error=True) or by logging an error. The attribute values must be
    hashable.
    """
    if not isinstance(attr, str):
        illegal_value(attr, 'attr', 'assert_no_duplicate_attr_in_list_of_objs', raise_error)
        return(None)
    if not isinstance(l, list):
        illegal_value(l, 'l', 'assert_no_duplicate_attr_in_list_of_objs', raise_error)
        return(None)
    attrs = [getattr(obj, attr, None) for obj in l]
    # A None value counts as missing.
    if None in attrs or len(set(attrs)) != len(l):
        if raise_error:
            raise ValueError(f'Duplicate or missing attr ({attr}) found in {l}')
        logger.error(f'Duplicate or missing attr ({attr}) found in {l}')
        return(None)
    return(l)
| 713 | |
def file_exists_and_readable(path):
    """Return path unchanged if it names an existing regular file readable by this process.

    Raises ValueError when path is not a regular file or is not readable.
    """
    if os.path.isfile(path):
        if os.access(path, os.R_OK):
            return path
        raise ValueError(f'{path} is not accessible for reading')
    raise ValueError(f'{path} is not a valid file')
| 721 | |
def create_mask(x, bounds=None, exclude_bounds=False, current_mask=None):
    """Return a boolean mask over the samples of x.

    x: non-empty tuple, list or numpy array of sample positions
    bounds: optional pair of numbers in the same units as x. When given, the
        mask is True strictly inside (min(bounds), max(bounds)), or strictly
        outside that interval when exclude_bounds is True; without bounds every
        sample is selected. Invalid bounds are ignored with a warning.
    current_mask: optional mask of the same length as x that is OR-ed into the
        result. An invalid one is ignored with a warning.
    Return: numpy bool array of length len(x), or None for an invalid x
    """
    if not isinstance(x, (tuple, list, np.ndarray)) or not len(x):
        logger.warning(f'Invalid input array ({x}, {type(x)})')
        return(None)
    # The comparisons below need array semantics; for a tuple or list,
    # x < number would raise TypeError.
    x = np.asarray(x)
    if bounds is not None and not is_num_pair(bounds, log=False):
        logger.warning(f'Invalid bounds parameter ({bounds} {type(bounds)}, input ignored')
        bounds = None
    if bounds is not None:
        if exclude_bounds:
            mask = np.logical_or(x < min(bounds), x > max(bounds))
        else:
            mask = np.logical_and(x > min(bounds), x < max(bounds))
    else:
        mask = np.ones(len(x), dtype=bool)
    if current_mask is not None:
        if not isinstance(current_mask, (tuple, list, np.ndarray)) or len(current_mask) != len(x):
            logger.warning(f'Invalid current_mask ({current_mask}, {type(current_mask)}), '+
                    'input ignored')
        else:
            mask = np.logical_or(mask, current_mask)
    if not mask.any():
        logger.warning('Entire data array is masked')
    return(mask)
| 746 | |
def eval_expr(name, expr, expr_variables, user_variables=None, max_depth=10, raise_error=False,
        log=True, **kwargs):
    """Evaluate an expression whose variables may themselves be named expressions.

    Every variable in *expr* that is a key of *expr_variables* (and not yet
    in the asteval symbol table) is evaluated first, recursively, and its
    value is stored in the shared symbol table before *expr* itself is
    evaluated with ``ast.eval``.

    :param name: Name of the expression being evaluated; used to detect
        circular references.
    :param expr: Expression string.
    :param expr_variables: dict mapping variable names to expression strings.
    :param user_variables: Optional dict mapping names to numbers; added to
        the symbol table.
    :param max_depth: Maximum recursion depth (int > 1).
    :param raise_error: Raise ValueError on error instead of only returning
        None.
    :param log: Log errors.
    :param kwargs: Used by the recursion only: ``chain`` (names on the
        current evaluation path) and ``ast`` (the shared asteval
        Interpreter).
    :return: The value of ``ast.eval(expr)``, or None on error.
    :raises ValueError: On a circular reference or excessive depth when
        *raise_error* is True.
    """
    if not isinstance(name, str):
        illegal_value(name, 'name', 'eval_expr', raise_error, log)
        return(None)
    if not isinstance(expr, str):
        illegal_value(expr, 'expr', 'eval_expr', raise_error, log)
        return(None)
    if not is_dict_strings(expr_variables, log=False):
        illegal_value(expr_variables, 'expr_variables', 'eval_expr', raise_error, log)
        return(None)
    if user_variables is not None and not is_dict_nums(user_variables, log=False):
        illegal_value(user_variables, 'user_variables', 'eval_expr', raise_error, log)
        return(None)
    if not is_int(max_depth, gt=1, log=False):
        illegal_value(max_depth, 'max_depth', 'eval_expr', raise_error, log)
        return(None)
    if not isinstance(raise_error, bool):
        illegal_value(raise_error, 'raise_error', 'eval_expr', raise_error, log)
        return(None)
    if not isinstance(log, bool):
        illegal_value(log, 'log', 'eval_expr', raise_error, log)
        return(None)
    if 'chain' in kwargs:
        chain = kwargs.pop('chain')
        if not is_str_series(chain):
            illegal_value(chain, 'chain', 'eval_expr', raise_error, log)
            return(None)
    else:
        chain = []
    if len(chain) > max_depth:
        # f-string so the actual limit appears in the message
        error_msg = f'Exceeded maximum depth ({max_depth}) in eval_expr'
        if log:
            logger.error(error_msg)
        if raise_error:
            raise ValueError(error_msg)
        return(None)
    if name not in chain:
        chain.append(name)
    if 'ast' in kwargs:
        ast = kwargs.pop('ast')
    else:
        ast = Interpreter()
    if user_variables is not None:
        ast.symtable.update(user_variables)
    # Variables of expr that are themselves unevaluated named expressions
    chain_vars = [var for var in get_ast_names(ast.parse(expr))
            if var in expr_variables and var not in ast.symtable]
    save_chain = chain.copy()
    for var in chain_vars:
        if var in chain:
            error_msg = f'Circular variable {var} in eval_expr'
            if log:
                logger.error(error_msg)
            if raise_error:
                raise ValueError(error_msg)
            return(None)
        if var in ast.user_defined_symbols():
            # Already evaluated by an earlier recursive call in this loop
            val = ast.symtable[var]
        else:
            # user_variables are already in the shared symbol table, so they
            # are not passed down again
            val = eval_expr(var, expr_variables[var], expr_variables, max_depth=max_depth,
                    raise_error=raise_error, log=log, chain=chain, ast=ast)
            if val is None:
                return(None)
            ast.symtable[var] = val
        # The recursive call appends to chain in place; restore this level's path
        chain = save_chain.copy()
    val = ast.eval(expr)
    return(val)
| 827 | |
def full_gradient(expr, x, expr_name=None, expr_variables=None, valid_variables=None, max_depth=10,
        raise_error=False, log=True, **kwargs):
    """Compute the full (total) derivative d(expr)/dx by the chain rule.

    The result is the direct sympy derivative ``diff(expr, x)`` plus, for
    every variable ``var`` of *expr* that is a key of *expr_variables*
    (other than *x*), the term ``diff(expr, var) * d(var)/dx``, where
    ``d(var)/dx`` is computed recursively from ``expr_variables[var]``.

    :param expr: Expression (string or sympy expression) to differentiate.
    :param x: Name of the variable to differentiate with respect to.
    :param expr_name: Optional name of *expr*; used to detect circular
        references. If it equals *x*, 1.0 is returned.
    :param expr_variables: Optional dict mapping variable names to
        expression strings.
    :param valid_variables: Optional series of allowed variable names; any
        other named-expression variable in *expr* is an error.
    :param max_depth: Maximum recursion depth (int > 1).
    :param raise_error: Raise ValueError on error instead of only returning
        None.
    :param log: Log errors.
    :param kwargs: Used by the recursion only: ``chain`` (names on the
        current path).
    :return: The simplified sympy derivative, 1.0 when *expr_name* == *x*,
        or None on error.
    :raises ValueError: On an unknown or circular variable or excessive
        depth when *raise_error* is True.
    """
    if not isinstance(x, str):
        illegal_value(x, 'x', 'full_gradient', raise_error, log)
        return(None)
    if expr_name is not None and not isinstance(expr_name, str):
        illegal_value(expr_name, 'expr_name', 'full_gradient', raise_error, log)
        return(None)
    if expr_variables is not None and not is_dict_strings(expr_variables, log=False):
        illegal_value(expr_variables, 'expr_variables', 'full_gradient', raise_error, log)
        return(None)
    if valid_variables is not None and not is_str_series(valid_variables, log=False):
        illegal_value(valid_variables, 'valid_variables', 'full_gradient', raise_error, log)
        return(None)
    if not is_int(max_depth, gt=1, log=False):
        illegal_value(max_depth, 'max_depth', 'full_gradient', raise_error, log)
        return(None)
    if not isinstance(raise_error, bool):
        illegal_value(raise_error, 'raise_error', 'full_gradient', raise_error, log)
        return(None)
    if not isinstance(log, bool):
        illegal_value(log, 'log', 'full_gradient', raise_error, log)
        return(None)
    if expr_name is not None and expr_name == x:
        return(1.0)
    if 'chain' in kwargs:
        chain = kwargs.pop('chain')
        if not is_str_series(chain):
            illegal_value(chain, 'chain', 'full_gradient', raise_error, log)
            return(None)
    else:
        chain = []
    if len(chain) > max_depth:
        error_msg = f'Exceeded maximum depth ({max_depth}) in full_gradient'
        if log:
            logger.error(error_msg)
        if raise_error:
            raise ValueError(error_msg)
        return(None)
    if expr_name is not None and expr_name not in chain:
        chain.append(expr_name)
    # The interpreter is only used to parse expr and list its variable names
    ast = Interpreter()
    if expr_variables is None:
        chain_vars = []
    else:
        chain_vars = [var for var in get_ast_names(ast.parse(f'{expr}'))
                if var in expr_variables and var != x and var not in ast.symtable]
    if valid_variables is not None:
        unknown_vars = [var for var in chain_vars if var not in valid_variables]
        if len(unknown_vars):
            error_msg = f'Unknown variable {unknown_vars} in {expr}'
            if log:
                logger.error(error_msg)
            if raise_error:
                raise ValueError(error_msg)
            return(None)
    dexpr_dx = diff(expr, x)
    save_chain = chain.copy()
    for var in chain_vars:
        if var in chain:
            error_msg = f'Circular variable {var} in full_gradient'
            if log:
                logger.error(error_msg)
            if raise_error:
                raise ValueError(error_msg)
            return(None)
        dexpr_dvar = diff(expr, var)
        if dexpr_dvar:
            dvar_dx = full_gradient(expr_variables[var], x, expr_name=var,
                    expr_variables=expr_variables, valid_variables=valid_variables,
                    max_depth=max_depth, raise_error=raise_error, log=log, chain=chain)
            if dvar_dx is None:
                # Propagate a failure in the sub-gradient instead of silently
                # treating it as a zero contribution
                return(None)
            if dvar_dx:
                # Chain-rule term; the combined string is parsed by simplify()
                dexpr_dx = f'{dexpr_dx}+({dexpr_dvar})*({dvar_dx})'
        # The recursive call appends to chain in place; restore this level's path
        chain = save_chain.copy()
    return(simplify(dexpr_dx))
| 916 | |
def bounds_from_mask(mask, return_include_bounds:bool=True):
    """Return the inclusive index ranges of the runs in *mask* equal to a given value.

    :param mask: Sequence of booleans (numpy array, list or tuple).
    :param return_include_bounds: If True, return the runs where *mask* is
        True; otherwise the runs where it is False.
    :return: list of ``(low, upp)`` tuples of inclusive indices, in
        increasing order; empty if there is no such run.
    """
    bounds = []
    start = None
    for i, m in enumerate(mask):
        if m == return_include_bounds:
            if start is None:
                start = i
        elif start is not None:
            bounds.append((start, i-1))
            start = None
    if start is not None:
        # len() rather than .size so plain lists/tuples work too
        bounds.append((start, len(mask)-1))
    return(bounds)
| 929 | |
def draw_mask_1d(ydata, xdata=None, current_index_ranges=None, current_mask=None,
        select_mask=True, num_index_ranges_max=None, title=None, legend=None, test_mode=False):
    """Let the user select (or deselect) index ranges of 1D data with the mouse.

    The data is plotted, existing included ranges are shaded green and
    excluded ranges red. The user clicks and drags to add ranges. "Clear"
    removes the last pending selection, or all current ranges if none is
    pending, and "Confirm" closes the figure.

    :param ydata: 1D data to plot.
    :param xdata: Optional strictly increasing x-coordinates with the same
        size as *ydata*; defaults to the sample indices.
    :param current_index_ranges: Optional list/tuple of inclusive
        ``(low, upp)`` index pairs to start from. Pairs with low > upp or
        entirely outside the data are skipped; the rest are clipped.
    :param current_mask: Optional initial boolean mask, one entry per sample.
    :param select_mask: If True the user selects ranges to include,
        otherwise ranges to exclude.
    :param num_index_ranges_max: Not implemented yet (a warning is logged).
    :param title: Plot title fragment; defaults to 'select ranges of data'.
    :param legend: Optional legend label for the data curve.
    :param test_mode: Skip the interactive figure (for testing).
    :return: ``(mask, index_ranges)`` where *mask* is a boolean numpy array
        and *index_ranges* lists the inclusive ``(low, upp)`` index pairs
        where *mask* is True; ``(None, None)`` on invalid input.
    """
    #FIX make color blind friendly
    def draw_selections(ax, current_include, current_exclude, selected_index_ranges):
        """Redraw the data and shade the current and newly selected index ranges."""
        def shade(index_ranges, color):
            # Extend each inclusive index range to the midpoints with its
            # neighbouring samples so adjacent ranges tile without gaps
            for (low, upp) in index_ranges:
                xlow = 0.5*(xdata[max(0, low-1)]+xdata[low])
                xupp = 0.5*(xdata[upp]+xdata[min(num_data-1, upp+1)])
                ax.axvspan(xlow, xupp, facecolor=color, alpha=0.5)

        ax.clear()
        ax.set_title(title)
        ax.plot(xdata, ydata, 'k')
        # Create the legend only after the data line exists; called before
        # plotting (as before) it has no handle to attach the label to
        if legend is not None:
            ax.legend([legend])
        shade(current_include, 'green')
        shade(current_exclude, 'red')
        shade(selected_index_ranges, selection_color)
        ax.get_figure().canvas.draw()

    def onclick(event):
        """Start a new selection at the sample nearest the click."""
        if event.inaxes in [fig.axes[0]]:
            selected_index_ranges.append(index_nearest_upp(xdata, event.xdata))

    def onrelease(event):
        """Close the pending selection, or drop it if released outside the plot."""
        if len(selected_index_ranges) > 0:
            # A bare int is a selection started by onclick but not yet closed
            if isinstance(selected_index_ranges[-1], int):
                if event.inaxes in [fig.axes[0]]:
                    start = selected_index_ranges[-1]
                    end = index_nearest_low(xdata, event.xdata)
                    selected_index_ranges[-1] = (start, end) if start <= end else (end, start)
                    draw_selections(event.inaxes, current_include, current_exclude,
                            selected_index_ranges)
                else:
                    selected_index_ranges.pop(-1)

    def confirm_selection(event):
        """Close the figure, which ends the blocking plt.show()."""
        plt.close()

    def clear_last_selection(event):
        """Remove the last new selection, or all current ranges if there is none."""
        if len(selected_index_ranges):
            selected_index_ranges.pop(-1)
        else:
            # Empty the shared lists in place so the display callbacks see it
            while len(current_include):
                current_include.pop()
            while len(current_exclude):
                current_exclude.pop()
            selected_mask.fill(False)
        draw_selections(ax, current_include, current_exclude, selected_index_ranges)

    def update_mask(mask, selected_index_ranges, unselected_index_ranges):
        """Set *mask* True over the selected ranges, then False over the unselected ones."""
        for (low, upp) in selected_index_ranges:
            selected_mask = np.logical_and(xdata >= xdata[low], xdata <= xdata[upp])
            mask = np.logical_or(mask, selected_mask)
        for (low, upp) in unselected_index_ranges:
            unselected_mask = np.logical_and(xdata >= xdata[low], xdata <= xdata[upp])
            mask[unselected_mask] = False
        return(mask)

    def update_index_ranges(mask):
        """Return the inclusive (low, upp) index ranges where *mask* is True."""
        index_ranges = []
        start = None
        for i, m in enumerate(mask):
            if m:
                if start is None:
                    start = i
            elif start is not None:
                index_ranges.append((start, i-1))
                start = None
        if start is not None:
            index_ranges.append((start, num_data-1))
        return(index_ranges)

    # Check inputs
    ydata = np.asarray(ydata)
    if ydata.ndim > 1:
        logger.warning(f'Invalid ydata dimension ({ydata.ndim})')
        return(None, None)
    num_data = ydata.size
    if xdata is None:
        xdata = np.arange(num_data)
    else:
        xdata = np.asarray(xdata, dtype=np.float64)
        if xdata.ndim > 1 or xdata.size != num_data:
            logger.warning(f'Invalid xdata shape ({xdata.shape})')
            return(None, None)
        if not np.all(xdata[:-1] < xdata[1:]):
            logger.warning('Invalid xdata: must be monotonically increasing')
            return(None, None)
    if current_index_ranges is not None:
        if not isinstance(current_index_ranges, (tuple, list)):
            logger.warning(f'Invalid current_index_ranges parameter ({current_index_ranges}, '+
                    f'{type(current_index_ranges)})')
            return(None, None)
    if current_mask is not None:
        current_mask = np.asarray(current_mask, dtype=bool)
        if current_mask.shape != xdata.shape:
            logger.warning(f'Invalid current_mask shape ({current_mask.shape})')
            return(None, None)
    if not isinstance(select_mask, bool):
        logger.warning(f'Invalid select_mask parameter ({select_mask}, {type(select_mask)})')
        return(None, None)
    if num_index_ranges_max is not None:
        logger.warning('num_index_ranges_max input not yet implemented in draw_mask_1d')
    if title is None:
        title = 'select ranges of data'
    elif not isinstance(title, str):
        illegal_value(title, 'title', 'draw_mask_1d')
        title = ''
    if legend is not None and not isinstance(legend, str):
        illegal_value(legend, 'legend', 'draw_mask_1d')
        legend = None

    if select_mask:
        title = f'Click and drag to {title} you wish to include'
        selection_color = 'green'
    else:
        title = f'Click and drag to {title} you wish to exclude'
        selection_color = 'red'

    # Set initial selected mask and the selected/unselected index ranges as needed
    selected_index_ranges = []
    unselected_index_ranges = []
    selected_mask = np.full(xdata.shape, False, dtype=bool)
    if current_index_ranges is None:
        if current_mask is None:
            if not select_mask:
                selected_index_ranges = [(0, num_data-1)]
                selected_mask = np.full(xdata.shape, True, dtype=bool)
        else:
            selected_mask = np.copy(current_mask)
    if current_index_ranges is not None and len(current_index_ranges):
        current_index_ranges = sorted([(low, upp) for (low, upp) in current_index_ranges])
        for (low, upp) in current_index_ranges:
            # Skip empty or fully out-of-range pairs; clip the others
            if low > upp or low >= num_data or upp < 0:
                continue
            low = max(low, 0)
            upp = min(upp, num_data-1)
            selected_index_ranges.append((low, upp))
        selected_mask = update_mask(selected_mask, selected_index_ranges, unselected_index_ranges)
    if current_index_ranges is not None and current_mask is not None:
        selected_mask = np.logical_and(current_mask, selected_mask)
    # Normalize to disjoint, sorted ranges (overlapping/adjacent input ranges
    # would otherwise produce reversed or empty exclude spans below)
    selected_index_ranges = update_index_ranges(selected_mask)

    # Set up range selections for display
    current_include = selected_index_ranges
    current_exclude = []
    selected_index_ranges = []
    if not len(current_include):
        if select_mask:
            current_exclude = [(0, num_data-1)]
        else:
            current_include = [(0, num_data-1)]
    else:
        # The excluded ranges are the gaps between the included ones
        if current_include[0][0] > 0:
            current_exclude.append((0, current_include[0][0]-1))
        for (_, prev_upp), (next_low, _) in zip(current_include[:-1], current_include[1:]):
            current_exclude.append((prev_upp+1, next_low-1))
        if current_include[-1][1] < num_data-1:
            current_exclude.append((current_include[-1][1]+1, num_data-1))

    if not test_mode:

        # Set up matplotlib figure
        plt.close('all')
        fig, ax = plt.subplots()
        plt.subplots_adjust(bottom=0.2)
        draw_selections(ax, current_include, current_exclude, selected_index_ranges)

        # Set up event handling for click-and-drag range selection
        cid_click = fig.canvas.mpl_connect('button_press_event', onclick)
        cid_release = fig.canvas.mpl_connect('button_release_event', onrelease)

        # Set up confirm / clear range selection buttons
        confirm_b = Button(plt.axes([0.75, 0.05, 0.15, 0.075]), 'Confirm')
        clear_b = Button(plt.axes([0.59, 0.05, 0.15, 0.075]), 'Clear')
        cid_confirm = confirm_b.on_clicked(confirm_selection)
        cid_clear = clear_b.on_clicked(clear_last_selection)

        # Show figure
        plt.show(block=True)

        # Disconnect callbacks when figure is closed
        fig.canvas.mpl_disconnect(cid_click)
        fig.canvas.mpl_disconnect(cid_release)
        confirm_b.disconnect(cid_confirm)
        clear_b.disconnect(cid_clear)

    # Swap selection depending on select_mask
    if not select_mask:
        selected_index_ranges, unselected_index_ranges = unselected_index_ranges, \
                selected_index_ranges

    # Update the mask with the currently selected/unselected x-ranges
    selected_mask = update_mask(selected_mask, selected_index_ranges, unselected_index_ranges)

    # Update the currently included index ranges (where mask is True)
    current_include = update_index_ranges(selected_mask)

    return(selected_mask, current_include)
| 1131 | |
def select_peaks(ydata:np.ndarray, x_values:np.ndarray=None, x_mask:np.ndarray=None,
        peak_x_values:np.ndarray=np.array([]), peak_x_indices:np.ndarray=np.array([]),
        return_peak_x_values:bool=False, return_peak_x_indices:bool=False,
        return_peak_input_indices:bool=False, return_sorted:bool=False,
        title:str=None, xlabel:str=None, ylabel:str=None) -> np.ndarray :
    """Let the user pick a subset of candidate peaks on an interactive plot.

    Each candidate peak is drawn as a vertical line; clicking a line toggles
    it between excluded and included. Peaks whose nearest sample is masked
    out (``x_mask`` False) are drawn dotted gray and cannot be picked. The
    "Confirm" button closes the figure.

    :param ydata: Reference data to plot.
    :param x_values: Optional x-coordinates of *ydata*; defaults to the
        sample indices. Required with *peak_x_values* or
        *return_peak_x_values*.
    :param x_mask: Optional boolean mask over the samples; masked-out
        regions are grayed out.
    :param peak_x_values: Candidate peak positions in x-units.
    :param peak_x_indices: Candidate peak positions as sample indices.
        Exactly one of *peak_x_values* / *peak_x_indices* must be non-empty.
    :param return_peak_x_values: Return the selected peaks' x-values.
    :param return_peak_x_indices: Return the selected peaks' sample indices.
    :param return_peak_input_indices: Return the positions of the selected
        peaks in the candidate list. Exactly one of the three return_*
        flags must be True.
    :param return_sorted: Sort the result ascending; otherwise it is in the
        order the peaks were selected.
    :param title: Plot title.
    :param xlabel: x-axis label.
    :param ylabel: y-axis label.
    :return: numpy array of the selected peaks in the requested format.
    :raises RuntimeError: If the arguments are inconsistent.
    """
    # Normalize array-likes so lists/tuples (and scalars) work below
    peak_x_values = np.atleast_1d(np.asarray(peak_x_values))
    peak_x_indices = np.atleast_1d(np.asarray(peak_x_indices))
    if x_values is not None:
        x_values = np.asarray(x_values)

    # Check arguments
    if ((len(peak_x_values) > 0 or return_peak_x_values)
            and (x_values is None or not len(x_values) > 0)):
        raise RuntimeError('Cannot use peak_x_values or return_peak_x_values without x_values')
    if (len(peak_x_values) > 0) == (len(peak_x_indices) > 0):
        raise RuntimeError('Use exactly one of peak_x_values or peak_x_indices')
    return_flags = (return_peak_x_values, return_peak_x_indices, return_peak_input_indices)
    if sum(bool(flag) for flag in return_flags) != 1:
        raise RuntimeError('Exactly one of return_peak_x_values, return_peak_x_indices, or '+
                'return_peak_input_indices must be True')

    EXCLUDE_PEAK_PROPERTIES = {'color': 'black', 'linestyle': '--','linewidth': 1,
            'marker': 10, 'markersize': 5, 'fillstyle': 'none'}
    INCLUDE_PEAK_PROPERTIES = {'color': 'green', 'linestyle': '-', 'linewidth': 2,
            'marker': 10, 'markersize': 10, 'fillstyle': 'full'}
    MASKED_PEAK_PROPERTIES = {'color': 'gray', 'linestyle': ':', 'linewidth': 1}

    # Setup reference data & plot
    x_indices = np.arange(len(ydata))
    if x_values is None:
        x_values = x_indices
    if x_mask is None:
        x_mask = np.full(x_values.shape, True, dtype=bool)
    else:
        x_mask = np.asarray(x_mask, dtype=bool)
    fig, ax = plt.subplots()
    handles = ax.plot(x_values, ydata, label='Reference data')
    handles.append(mlines.Line2D([], [], label='Excluded / unselected HKL', **EXCLUDE_PEAK_PROPERTIES))
    handles.append(mlines.Line2D([], [], label='Included / selected HKL', **INCLUDE_PEAK_PROPERTIES))
    handles.append(mlines.Line2D([], [], label='HKL in masked region (unselectable)', **MASKED_PEAK_PROPERTIES))
    ax.legend(handles=handles, loc='upper right')
    ax.set(title=title, xlabel=xlabel, ylabel=ylabel)

    # Plot vertical line at each peak
    def value_to_index(x_value):
        # Index of the sample whose x-value is nearest to x_value
        return int(np.argmin(np.abs(x_values - x_value)))
    if len(peak_x_indices) > 0:
        peak_x_values = x_values[peak_x_indices]
    else:
        peak_x_indices = np.array([value_to_index(value) for value in peak_x_values])
    peak_vlines = []
    for loc in peak_x_values:
        if x_mask[value_to_index(loc)]:
            peak_vline = ax.axvline(loc, **EXCLUDE_PEAK_PROPERTIES)
            # Only peaks in unmasked regions are pickable
            peak_vline.set_picker(5)
        else:
            peak_vline = ax.axvline(loc, **MASKED_PEAK_PROPERTIES)
        peak_vlines.append(peak_vline)

    # Indicate masked regions by gray-ing out the axes facecolor
    mask_exclude_bounds = bounds_from_mask(x_mask, return_include_bounds=False)
    for (low, upp) in mask_exclude_bounds:
        ax.axvspan(x_values[low], x_values[upp], facecolor='gray', alpha=0.5)

    # Setup peak picking
    selected_peak_input_indices = []
    def onpick(event):
        # Toggle the picked peak line between excluded and included
        try:
            peak_index = peak_vlines.index(event.artist)
        except ValueError:
            # The picked artist is not one of the peak lines
            return
        peak_vline = event.artist
        if peak_index in selected_peak_input_indices:
            peak_vline.set(**EXCLUDE_PEAK_PROPERTIES)
            selected_peak_input_indices.remove(peak_index)
        else:
            peak_vline.set(**INCLUDE_PEAK_PROPERTIES)
            selected_peak_input_indices.append(peak_index)
        plt.draw()
    cid_pick_peak = fig.canvas.mpl_connect('pick_event', onpick)

    # Setup "Confirm" button
    def confirm_selection(event):
        plt.close()
    plt.subplots_adjust(bottom=0.2)
    confirm_b = Button(plt.axes([0.75, 0.05, 0.15, 0.075]), 'Confirm')
    cid_confirm = confirm_b.on_clicked(confirm_selection)

    # Show figure for user interaction
    plt.show()

    # Disconnect callbacks when figure is closed
    fig.canvas.mpl_disconnect(cid_pick_peak)
    confirm_b.disconnect(cid_confirm)

    if return_peak_input_indices:
        selected_peaks = np.array(selected_peak_input_indices, dtype=int)
    elif return_peak_x_values:
        selected_peaks = peak_x_values[selected_peak_input_indices]
    else:
        selected_peaks = peak_x_indices[selected_peak_input_indices]

    if return_sorted:
        selected_peaks.sort()

    return(selected_peaks)
| 1235 | |
def find_image_files(path, filetype, name=None):
    """Locate the image files and report the available image index range.

    :param path: For 'tif', a directory holding ``.tif`` files whose names
        contain an index; for 'h5', an HDF5 file with an
        ``entry/instrument/detector/data`` dataset.
    :param filetype: 'tif' or 'h5'.
    :param name: Optional label used in log messages.
    :return: ``(first_index, num_img, paths)``, or ``(-1, 0, [])`` on failure.
    """
    if isinstance(name, str):
        name = f'{name.strip()} '
    else:
        name = ''
    # Find available index range
    if filetype == 'tif':
        if not isinstance(path, str) or not os.path.isdir(path):
            illegal_value(path, 'path', 'find_image_files')
            return(-1, 0, [])
        index_regex = re.compile(r'\d+')
        # At this point only tiffs
        files = [f for f in os.listdir(path) if os.path.isfile(os.path.join(path, f)) and
                f.endswith('.tif') and index_regex.search(f)]
        # Order by the (first) numeric index in the name so indices without
        # zero padding sort numerically (img_2 before img_10); ties fall back
        # to the name itself
        files.sort(key=lambda f: (int(index_regex.search(f).group()), f))
        num_img = len(files)
        if num_img < 1:
            logger.warning(f'No available {name}files')
            return(-1, 0, [])
        # Every file matched index_regex above, so search() cannot be None
        first_index = int(index_regex.search(files[0]).group())
        last_index = int(index_regex.search(files[-1]).group())
        if num_img != last_index-first_index+1:
            logger.error(f'Non-consecutive set of indices for {name}images')
            return(-1, 0, [])
        paths = [os.path.join(path, f) for f in files]
    elif filetype == 'h5':
        if not isinstance(path, str) or not os.path.isfile(path):
            illegal_value(path, 'path', 'find_image_files')
            return(-1, 0, [])
        # At this point only h5 in alamo2 detector style
        first_index = 0
        with h5py.File(path, 'r') as f:
            num_img = f['entry/instrument/detector/data'].shape[0]
            last_index = num_img-1
        paths = [path]
    else:
        illegal_value(filetype, 'filetype', 'find_image_files')
        return(-1, 0, [])
    logger.info(f'Number of available {name}images: {num_img}')
    logger.info(f'Index range of available {name}images: [{first_index}, '+
            f'{last_index}]')

    return(first_index, num_img, paths)
| 1283 | |
def select_image_range(first_index, offset, num_available, num_img=None, name=None,
        num_required=None):
    """Interactively choose which images to use.

    If valid current values are given, the user is first asked (through
    input_yesno) whether to keep them. Otherwise a menu (input_menu) offers
    to use all images or to pick an offset/count or first/last index.

    :param first_index: Index of the first available image.
    :param offset: Current offset of the first image to use.
    :param num_available: Number of available images (int > 0).
    :param num_img: Optional current number of images to use.
    :param name: Optional label used in prompts and log messages.
    :param num_required: Optional exact number of images needed.
    :return: ``(first_index, offset, num_img)``, or ``(0, 0, 0)`` on failure.
    """
    if isinstance(name, str):
        name = f'{name.strip()} '
    else:
        name = ''
    # Check existing values
    if not is_int(num_available, gt=0):
        logger.warning(f'No available {name}images')
        return(0, 0, 0)
    if num_img is not None and not is_int(num_img, ge=0):
        illegal_value(num_img, 'num_img', 'select_image_range')
        return(0, 0, 0)
    if is_int(first_index, ge=0) and is_int(offset, ge=0):
        if num_required is None:
            if input_yesno(f'\nCurrent {name}first image index/offset = {first_index}/{offset}, '+
                    'use these values (y/n)? ', 'y'):
                if num_img is not None:
                    if input_yesno(f'Current number of {name}images = {num_img}, '+
                            'use this value (y/n)? ', 'y'):
                        return(first_index, offset, num_img)
                else:
                    if input_yesno(f'Number of available {name}images = {num_available}, '+
                            'use all (y/n)? ', 'y'):
                        return(first_index, offset, num_available)
        else:
            if input_yesno(f'\nCurrent {name}first image offset = {offset}, '+
                    'use this value (y/n)? ', 'y'):
                return(first_index, offset, num_required)

    # Check range against requirements
    if num_required is None:
        if num_available == 1:
            return(first_index, 0, 1)
    else:
        if not is_int(num_required, ge=1):
            illegal_value(num_required, 'num_required', 'select_image_range')
            return(0, 0, 0)
        if num_available < num_required:
            logger.error(f'Unable to find the required {name}images ({num_available} out of '+
                    f'{num_required})')
            return(0, 0, 0)

    # Select index range
    print(f'\nThe number of available {name}images is {num_available}')
    if num_required is None:
        # Indices are inclusive: the last available image is
        # first_index+num_available-1
        last_index = first_index+num_available-1
        use_all = f'Use all ([{first_index}, {last_index}])'
        pick_offset = 'Pick the first image index offset and the number of images'
        pick_bounds = 'Pick the first and last image index'
        choice = input_menu([use_all, pick_offset, pick_bounds], default=pick_offset)
        if not choice:
            offset = 0
            num_img = num_available
        elif choice == 1:
            offset = input_int('Enter the first index offset', ge=0, le=last_index-first_index)
            if first_index+offset == last_index:
                num_img = 1
            else:
                num_img = input_int('Enter the number of images', ge=1, le=num_available-offset)
        else:
            offset = input_int('Enter the first index', ge=first_index, le=last_index)
            num_img = 1-offset+input_int('Enter the last index', ge=offset, le=last_index)
            offset -= first_index
    else:
        use_all = f'Use ([{first_index}, {first_index+num_required-1}])'
        pick_offset = 'Pick the first index offset'
        choice = input_menu([use_all, pick_offset], pick_offset)
        offset = 0
        if choice == 1:
            offset = input_int('Enter the first index offset', ge=0, le=num_available-num_required)
        num_img = num_required

    return(first_index, offset, num_img)
| 1358 | |
def load_image(f, img_x_bounds=None, img_y_bounds=None):
    """Load a single image from file, optionally cropped.

    :param f: Image file path (read with ``plt.imread``).
    :param img_x_bounds: Optional ``(start, stop)`` tuple or list of row
        slice bounds with ``0 <= start < stop <= number of rows``; defaults
        to all rows.
    :param img_y_bounds: The same for columns.
    :return: The (cropped) image array, or None on error.
    """
    if not os.path.isfile(f):
        logger.error(f'Unable to load {f}')
        return(None)
    img_read = plt.imread(f)
    if not img_x_bounds:
        img_x_bounds = (0, img_read.shape[0])
    elif (not isinstance(img_x_bounds, (tuple, list)) or len(img_x_bounds) != 2 or
            not (0 <= img_x_bounds[0] < img_x_bounds[1] <= img_read.shape[0])):
        logger.error(f'inconsistent row dimension in {f}')
        return(None)
    if not img_y_bounds:
        img_y_bounds = (0, img_read.shape[1])
    # Accept tuples as well as lists, consistent with img_x_bounds
    elif (not isinstance(img_y_bounds, (tuple, list)) or len(img_y_bounds) != 2 or
            not (0 <= img_y_bounds[0] < img_y_bounds[1] <= img_read.shape[1])):
        logger.error(f'inconsistent column dimension in {f}')
        return(None)
    return(img_read[img_x_bounds[0]:img_x_bounds[1],img_y_bounds[0]:img_y_bounds[1]])
| 1381 | |
def load_image_stack(files, filetype, img_offset, num_img, num_img_skip=0,
        img_x_bounds=None, img_y_bounds=None):
    """Load a set of images and return them as a stack.

    :param files: For 'tif', the list of image file paths; for 'h5', a list
        whose first entry is the HDF5 file path.
    :param filetype: 'tif' or 'h5'.
    :param img_offset: Index of the first image to load.
    :param num_img: Length of the image range starting at *img_offset*.
    :param num_img_skip: Number of images skipped between loaded images.
    :param img_x_bounds: Optional ``(start, stop)`` row bounds (tuple/list).
    :param img_y_bounds: Optional ``(start, stop)`` column bounds (tuple/list).
    :return: numpy array stacked along axis 0, or an empty array on error.
    """
    logger.debug(f'img_offset = {img_offset}')
    logger.debug(f'num_img = {num_img}')
    logger.debug(f'num_img_skip = {num_img_skip}')
    logger.debug(f'\nfiles:\n{files}\n')
    img_stack = np.array([])
    if filetype == 'tif':
        img_read_stack = []
        i = 1
        t0 = time()
        for f in files[img_offset:img_offset+num_img:num_img_skip+1]:
            if not i%20:
                logger.info(f' loading {i}/{num_img}: {f}')
            else:
                logger.debug(f' loading {i}/{num_img}: {f}')
            img_read = load_image(f, img_x_bounds, img_y_bounds)
            if img_read is None:
                # load_image already logged the reason; stacking None would
                # fail with an unrelated error
                logger.error(f'Unable to load image stack, failed to read {f}')
                return(img_stack)
            img_read_stack.append(img_read)
            i += num_img_skip+1
        if not img_read_stack:
            logger.error('No images selected to load')
            return(img_stack)
        img_stack = np.stack(img_read_stack)
        logger.info(f'... done in {time()-t0:.2f} seconds!')
        logger.debug(f'img_stack shape = {np.shape(img_stack)}')
    elif filetype == 'h5':
        # Reject non-strings and strings that are not existing files
        if not isinstance(files[0], str) or not os.path.isfile(files[0]):
            illegal_value(files[0], 'files[0]', 'load_image_stack')
            return(img_stack)
        t0 = time()
        logger.info(f'Loading {files[0]}')
        with h5py.File(files[0], 'r') as f:
            dataset = f['entry/instrument/detector/data']
            shape = dataset.shape
            if len(shape) != 3:
                logger.error(f'inconsistent dimensions in {files[0]}')
                return(img_stack)
            if not img_x_bounds:
                img_x_bounds = (0, shape[1])
            elif (not isinstance(img_x_bounds, (tuple, list)) or len(img_x_bounds) != 2 or
                    not (0 <= img_x_bounds[0] < img_x_bounds[1] <= shape[1])):
                logger.error(f'inconsistent row dimension in {files[0]} {img_x_bounds} '+
                        f'{shape[1]}')
                return(img_stack)
            if not img_y_bounds:
                img_y_bounds = (0, shape[2])
            elif (not isinstance(img_y_bounds, (tuple, list)) or len(img_y_bounds) != 2 or
                    not (0 <= img_y_bounds[0] < img_y_bounds[1] <= shape[2])):
                logger.error(f'inconsistent column dimension in {files[0]}')
                return(img_stack)
            img_stack = dataset[
                    img_offset:img_offset+num_img:num_img_skip+1,
                    img_x_bounds[0]:img_x_bounds[1],img_y_bounds[0]:img_y_bounds[1]]
        logger.info(f'... done in {time()-t0:.2f} seconds!')
    else:
        illegal_value(filetype, 'filetype', 'load_image_stack')
    return(img_stack)
| 1437 | |
def combine_tiffs_in_h5(files, num_img, h5_filename):
    """Load *num_img* tif images (from offset 0) and write them to a new HDF5 file.

    The stack is stored in the dataset ``entry/instrument/detector/data`` of
    *h5_filename*, which is opened in mode 'w'.

    :return: A one-element list holding *h5_filename*.
    """
    stack = load_image_stack(files, 'tif', 0, num_img)
    with h5py.File(h5_filename, 'w') as h5_file:
        h5_file.create_dataset('entry/instrument/detector/data', data=stack)
    # Release the in-memory stack explicitly
    del stack
    return [h5_filename]
| 1444 | |
def clear_imshow(title=None):
    """Close a figure opened by quick_imshow.

    Matplotlib interactive mode is always switched off first. The figure
    labelled `title` is then closed; when `title` is None the quick_imshow
    default label 'quick imshow' is used. A non-string title is reported
    through illegal_value and no figure is closed.
    """
    plt.ioff()
    if title is not None and not isinstance(title, str):
        illegal_value(title, 'title', 'clear_imshow')
        return
    plt.close(fig='quick imshow' if title is None else title)
| 1453 | |
def clear_plot(title=None):
    """Close a figure opened by quick_plot.

    Matplotlib interactive mode is always switched off first. The figure
    labelled `title` is then closed; when `title` is None the quick_plot
    default label 'quick plot' is used. A non-string title is reported
    through illegal_value and no figure is closed.
    """
    plt.ioff()
    if title is not None and not isinstance(title, str):
        illegal_value(title, 'title', 'clear_plot')
        return
    plt.close(fig='quick plot' if title is None else title)
| 1462 | |
def quick_imshow(a, title=None, path=None, name=None, save_fig=False, save_only=False,
        clear=True, extent=None, show_grid=False, grid_color='w', grid_linewidth=1,
        block=False, **kwargs):
    """Display and/or save a quick image plot of an array with matplotlib.

    :param a: the image; anything np.asanyarray accepts (2D, or 3D with 3 or 4
        channels along the last axis)
    :param title: figure title/label (default 'quick imshow'); identifies the figure
        when clearing or closing it
    :param path: directory for the saved image (default: the current directory)
    :param name: file name of the saved image (default: the title with runs of
        whitespace replaced by '_' and '.png' appended)
    :param save_fig: save the figure in addition to displaying it
    :param save_only: save the figure without displaying it, then close it
    :param clear: first close any existing figure with the same title
    :param extent: imshow extent (default: (0, number of columns, number of rows, 0))
    :param show_grid: overlay a grid on the image
    :param grid_color: grid line color
    :param grid_linewidth: grid line width
    :param block: display with plt.show(block=True) and interactive mode off;
        otherwise interactive mode is switched on
    :param kwargs: additional keyword arguments passed to plt.imshow; a 'cmap' given
        with an RGB(A) image is only honored when that image is effectively grayscale
    :return: None (invalid arguments are reported through illegal_value)
    """
    if title is not None and not isinstance(title, str):
        illegal_value(title, 'title', 'quick_imshow')
        return
    if path is not None and not isinstance(path, str):
        illegal_value(path, 'path', 'quick_imshow')
        return
    if not isinstance(save_fig, bool):
        illegal_value(save_fig, 'save_fig', 'quick_imshow')
        return
    if not isinstance(save_only, bool):
        illegal_value(save_only, 'save_only', 'quick_imshow')
        return
    if not isinstance(clear, bool):
        illegal_value(clear, 'clear', 'quick_imshow')
        return
    if not isinstance(block, bool):
        illegal_value(block, 'block', 'quick_imshow')
        return
    if not title:
        title = 'quick imshow'
    if name is None:
        ttitle = re.sub(r"\s+", '_', title)
        filename = f'{ttitle}.png'
    else:
        filename = name
    path = filename if path is None else f'{path}/{filename}'

    # Accept lists and other array-likes; asanyarray keeps ndarray subclasses
    # (e.g. masked arrays) intact
    a = np.asanyarray(a)
    if 'cmap' in kwargs and a.ndim == 3 and a.shape[2] in (3, 4):
        # A colormap only applies to single-channel data, so an RGB(A) image is
        # collapsed to its first channel only when it is effectively grayscale:
        # R == G == B for every pixel (checking 'or', not 'and', across channels)
        # and, for RGBA, a constant alpha channel
        use_cmap = not bool(np.any(a[:,:,1:3] != a[:,:,:1]))
        if a.shape[2] == 4 and a[:,:,-1].min() != a[:,:,-1].max():
            use_cmap = False
        if use_cmap:
            a = a[:,:,0]
        else:
            logger.warning('Image incompatible with cmap option, ignore cmap')
            kwargs.pop('cmap')
    if extent is None:
        extent = (0, a.shape[1], a.shape[0], 0)
    if clear:
        try:
            plt.close(fig=title)
        except Exception:
            # Closing is best effort only
            pass
    if not save_only:
        if block:
            plt.ioff()
        else:
            plt.ion()
    # Always create/activate the figure labelled `title`: in save_only mode this
    # keeps imshow from drawing onto (and leaking) whatever figure happened to be
    # current, so the plt.close(fig=title) below really closes what was drawn
    plt.figure(title)
    plt.imshow(a, extent=extent, **kwargs)
    if show_grid:
        ax = plt.gca()
        ax.grid(color=grid_color, linewidth=grid_linewidth)
    if save_only:
        plt.savefig(path)
        plt.close(fig=title)
    else:
        if save_fig:
            plt.savefig(path)
        if block:
            plt.show(block=block)
| 1538 | |
def quick_plot(*args, xerr=None, yerr=None, vlines=None, title=None, xlim=None, ylim=None,
        xlabel=None, ylabel=None, legend=None, path=None, name=None, show_grid=False,
        save_fig=False, save_only=False, clear=True, block=False, **kwargs):
    """Display and/or save a quick line plot with matplotlib.

    :param args: positional plot arguments: either a single curve, e.g. (x, y), passed
        to plt.plot (or plt.errorbar when error bars are given), or a tuple of such
        tuples, each passed to plt.plot
    :param xerr: x error bars (single curve only, ignored for multiple curves)
    :param yerr: y error bars (single curve only, ignored for multiple curves)
    :param vlines: a number or an iterable of x positions marked with red dashed
        vertical lines
    :param title: figure title/label (default 'quick plot'); identifies the figure when
        clearing or closing it
    :param xlim: x-axis limits as a length-2 tuple, list or numpy array
    :param ylim: y-axis limits as a length-2 tuple, list or numpy array
    :param xlabel: x-axis label
    :param ylabel: y-axis label
    :param legend: tuple or list of legend labels
    :param path: directory for the saved image (default: the current directory)
    :param name: file name of the saved image (default: the title with runs of
        whitespace replaced by '_' and '.png' appended)
    :param show_grid: draw a black grid
    :param save_fig: save the figure in addition to displaying it
    :param save_only: save the figure without displaying it, then close it
    :param clear: first close any existing figure with the same title
    :param block: display with plt.show(block=True) and interactive mode off;
        otherwise interactive mode is switched on
    :param kwargs: additional keyword arguments for the plotting calls; they are also
        passed to plt.axvline, minus any color/linestyle (or 'c'/'ls') keys

    Invalid title, limit, label or legend values are reported through illegal_value
    and ignored; invalid path or flag values are reported and nothing is plotted.
    """
    if title is not None and not isinstance(title, str):
        illegal_value(title, 'title', 'quick_plot')
        title = None
    # A limit must be a 2-element sequence; with the previous 'and' a scalar crashed
    # in len() and wrong-length lists slipped through
    if xlim is not None and (not isinstance(xlim, (tuple, list, np.ndarray))
            or len(xlim) != 2):
        illegal_value(xlim, 'xlim', 'quick_plot')
        xlim = None
    if ylim is not None and (not isinstance(ylim, (tuple, list, np.ndarray))
            or len(ylim) != 2):
        illegal_value(ylim, 'ylim', 'quick_plot')
        ylim = None
    if xlabel is not None and not isinstance(xlabel, str):
        illegal_value(xlabel, 'xlabel', 'quick_plot')
        xlabel = None
    if ylabel is not None and not isinstance(ylabel, str):
        illegal_value(ylabel, 'ylabel', 'quick_plot')
        ylabel = None
    if legend is not None and not isinstance(legend, (tuple, list)):
        illegal_value(legend, 'legend', 'quick_plot')
        legend = None
    if path is not None and not isinstance(path, str):
        illegal_value(path, 'path', 'quick_plot')
        return
    if not isinstance(show_grid, bool):
        illegal_value(show_grid, 'show_grid', 'quick_plot')
        return
    if not isinstance(save_fig, bool):
        illegal_value(save_fig, 'save_fig', 'quick_plot')
        return
    if not isinstance(save_only, bool):
        illegal_value(save_only, 'save_only', 'quick_plot')
        return
    if not isinstance(clear, bool):
        illegal_value(clear, 'clear', 'quick_plot')
        return
    if not isinstance(block, bool):
        illegal_value(block, 'block', 'quick_plot')
        return
    if not title:
        title = 'quick plot'
    if name is None:
        ttitle = re.sub(r"\s+", '_', title)
        filename = f'{ttitle}.png'
    else:
        filename = name
    path = filename if path is None else f'{path}/{filename}'
    if clear:
        try:
            plt.close(fig=title)
        except Exception:
            # Closing is best effort only
            pass
    args = unwrap_tuple(args)
    multiple_curves = depth_tuple(args) > 1
    if multiple_curves and (xerr is not None or yerr is not None):
        logger.warning('Error bars ignored for multiple curves')
    if not save_only:
        if block:
            plt.ioff()
        else:
            plt.ion()
    # Always create/activate the figure labelled `title`: in save_only mode this
    # keeps the plot from landing on (and leaking) whatever figure was current
    plt.figure(title)
    if multiple_curves:
        for curve in args:
            plt.plot(*curve, **kwargs)
    elif xerr is None and yerr is None:
        plt.plot(*args, **kwargs)
    else:
        plt.errorbar(*args, xerr=xerr, yerr=yerr, **kwargs)
    if vlines is not None:
        if isinstance(vlines, (int, float, np.number)):
            vlines = [vlines]
        # The markers are always red dashed lines: drop curve colors/linestyles
        # (and their aliases) that would otherwise collide with these keywords
        vline_kwargs = {key: val for key, val in kwargs.items()
                        if key not in ('color', 'c', 'linestyle', 'ls')}
        for v in vlines:
            plt.axvline(v, color='r', linestyle='--', **vline_kwargs)
    if xlim is not None:
        plt.xlim(xlim)
    if ylim is not None:
        plt.ylim(ylim)
    if xlabel is not None:
        plt.xlabel(xlabel)
    if ylabel is not None:
        plt.ylabel(ylabel)
    if show_grid:
        ax = plt.gca()
        ax.grid(color='k')
    if legend is not None:
        plt.legend(legend)
    if save_only:
        plt.savefig(path)
        plt.close(fig=title)
    else:
        if save_fig:
            plt.savefig(path)
        if block:
            plt.show(block=block)
| 1644 | |
def select_array_bounds(a, x_low=None, x_upp=None, num_x_min=None, ask_bounds=False,
        title='select array bounds'):
    """Interactively select the lower and upper data bounds for a 1D numpy array.

    The array is plotted with quick_plot. The user either sets a bound or zooms in
    on an index range, and finally accepts or rejects the bounds. Rejecting restarts
    the selection from scratch.

    :param a: 1D array-like (tuples and lists are converted to an array)
    :param x_low: preset lower bound (inclusive); prompted for when None
    :param x_upp: preset upper bound (exclusive); prompted for when None
    :param num_x_min: minimum number of points between the bounds (default 1);
        values below 1 or above len(a) are ignored with a warning
    :param ask_bounds: when a bound is preset, first ask whether to keep the current
        bounds as they are
    :param title: title of the quick_plot figure used for the selection
    :return: (x_low, x_upp), or None on invalid input (reported through illegal_value)
    """
    # asanyarray accepts tuples/lists as before and guarantees an ndim attribute for
    # the check below (the old isinstance test crashed on inputs lacking ndim)
    a = np.asanyarray(a)
    if a.ndim != 1:
        illegal_value(a.ndim, 'array type or dimension', 'select_array_bounds')
        return(None)
    len_a = len(a)
    if num_x_min is None:
        num_x_min = 1
    elif num_x_min < 1 or num_x_min > len_a:
        # 1 is the default and thus valid; only warn about truly invalid values
        logger.warning('Invalid value for num_x_min in select_array_bounds, input ignored')
        num_x_min = 1

    # Ask to use current bounds
    if ask_bounds and (x_low is not None or x_upp is not None):
        if x_low is None:
            x_low = 0
        if not is_int(x_low, ge=0, le=len_a-num_x_min):
            illegal_value(x_low, 'x_low', 'select_array_bounds')
            return(None)
        if x_upp is None:
            x_upp = len_a
        if not is_int(x_upp, ge=x_low+num_x_min, le=len_a):
            illegal_value(x_upp, 'x_upp', 'select_array_bounds')
            return(None)
        quick_plot((range(len_a), a), vlines=(x_low,x_upp), title=title)
        if not input_yesno(f'\nCurrent array bounds: [{x_low}, {x_upp}] '+
                'use these values (y/n)?', 'y'):
            x_low = None
            x_upp = None
        else:
            clear_plot(title)
            return(x_low, x_upp)

    if x_low is None:
        x_min = 0
        x_max = len_a
        x_low_max = len_a-num_x_min
        while True:
            quick_plot(range(x_min, x_max), a[x_min:x_max], title=title)
            zoom_flag = input_yesno('Set lower data bound (y) or zoom in (n)?', 'y')
            if zoom_flag:
                x_low = input_int('    Set lower data bound', ge=0, le=x_low_max)
                break
            else:
                x_min = input_int('    Set lower zoom index', ge=0, le=x_low_max)
                x_max = input_int('    Set upper zoom index', ge=x_min+1, le=x_low_max+1)
    else:
        if not is_int(x_low, ge=0, le=len_a-num_x_min):
            illegal_value(x_low, 'x_low', 'select_array_bounds')
            return(None)
    if x_upp is None:
        x_min = x_low+num_x_min
        x_max = len_a
        x_upp_min = x_min
        while True:
            quick_plot(range(x_min, x_max), a[x_min:x_max], title=title)
            zoom_flag = input_yesno('Set upper data bound (y) or zoom in (n)?', 'y')
            if zoom_flag:
                x_upp = input_int('    Set upper data bound', ge=x_upp_min, le=len_a)
                break
            else:
                # This prompt sets the lower end of the zoom window
                x_min = input_int('    Set lower zoom index', ge=x_upp_min, le=len_a-1)
                x_max = input_int('    Set upper zoom index', ge=x_min+1, le=len_a)
    else:
        if not is_int(x_upp, ge=x_low+num_x_min, le=len_a):
            illegal_value(x_upp, 'x_upp', 'select_array_bounds')
            return(None)
    print(f'lower bound = {x_low} (inclusive)\nupper bound = {x_upp} (exclusive)')
    quick_plot((range(len_a), a), vlines=(x_low,x_upp), title=title)
    if not input_yesno('Accept these bounds (y/n)?', 'y'):
        x_low, x_upp = select_array_bounds(a, None, None, num_x_min, title=title)
    clear_plot(title)
    return(x_low, x_upp)
| 1723 | |
def select_image_bounds(a, axis, low=None, upp=None, num_min=None, title='select array bounds',
        raise_error=False):
    """Interactively select the lower and upper data bounds for a 2D numpy array.

    The image is shown with quick_imshow. The user either sets a bound or zooms in
    on an index range along `axis`, and finally accepts or rejects the bounds, which
    are marked in the image with its maximum value.

    :param a: 2D array-like
    :param axis: 0 to select row bounds, 1 to select column bounds
    :param low: preset lower bound (inclusive); prompted for when None
    :param upp: preset upper bound (exclusive); prompted for when None
    :param num_min: minimum number of indices between the bounds (default 1); values
        below 1 or above a.shape[axis] are ignored with a warning
    :param title: title of the quick_imshow figure used for the selection
    :param raise_error: passed to illegal_value for invalid input
    :return: (low, upp), or None on invalid input
    """
    a = np.asarray(a)
    if a.ndim != 2:
        illegal_value(a.ndim, 'array dimension', location='select_image_bounds',
                raise_error=raise_error)
        return(None)
    if axis < 0 or axis >= a.ndim:
        illegal_value(axis, 'axis', location='select_image_bounds', raise_error=raise_error)
        return(None)
    low_save = low
    upp_save = upp
    num_min_save = num_min
    if num_min is None:
        num_min = 1
    elif num_min < 1 or num_min > a.shape[axis]:
        # 1 is the default and thus valid; only warn about truly invalid values
        logger.warning('Invalid input for num_min in select_image_bounds, input ignored')
        num_min = 1
    if low is None:
        min_ = 0
        max_ = a.shape[axis]
        low_max = a.shape[axis]-num_min
        while True:
            if axis:
                quick_imshow(a[:,min_:max_], title=title, aspect='auto',
                        extent=[min_,max_,a.shape[0],0])
            else:
                quick_imshow(a[min_:max_,:], title=title, aspect='auto',
                        extent=[0,a.shape[1], max_,min_])
            zoom_flag = input_yesno('Set lower data bound (y) or zoom in (n)?', 'y')
            if zoom_flag:
                low = input_int('    Set lower data bound', ge=0, le=low_max)
                break
            else:
                min_ = input_int('    Set lower zoom index', ge=0, le=low_max)
                max_ = input_int('    Set upper zoom index', ge=min_+1, le=low_max+1)
    else:
        if not is_int(low, ge=0, le=a.shape[axis]-num_min):
            illegal_value(low, 'low', location='select_image_bounds', raise_error=raise_error)
            return(None)
    if upp is None:
        min_ = low+num_min
        max_ = a.shape[axis]
        upp_min = min_
        while True:
            if axis:
                quick_imshow(a[:,min_:max_], title=title, aspect='auto',
                        extent=[min_,max_,a.shape[0],0])
            else:
                quick_imshow(a[min_:max_,:], title=title, aspect='auto',
                        extent=[0,a.shape[1], max_,min_])
            zoom_flag = input_yesno('Set upper data bound (y) or zoom in (n)?', 'y')
            if zoom_flag:
                upp = input_int('    Set upper data bound', ge=upp_min, le=a.shape[axis])
                break
            else:
                # This prompt sets the lower end of the zoom window
                min_ = input_int('    Set lower zoom index', ge=upp_min, le=a.shape[axis]-1)
                max_ = input_int('    Set upper zoom index', ge=min_+1, le=a.shape[axis])
    else:
        if not is_int(upp, ge=low+num_min, le=a.shape[axis]):
            illegal_value(upp, 'upp', location='select_image_bounds', raise_error=raise_error)
            return(None)
    bounds = (low, upp)
    # Mark the selected first and last row/column with the image maximum
    a_tmp = np.copy(a)
    a_tmp_max = a.max()
    if axis:
        a_tmp[:,bounds[0]] = a_tmp_max
        a_tmp[:,bounds[1]-1] = a_tmp_max
    else:
        a_tmp[bounds[0],:] = a_tmp_max
        a_tmp[bounds[1]-1,:] = a_tmp_max
    print(f'lower bound = {low} (inclusive)\nupper bound = {upp} (exclusive)')
    quick_imshow(a_tmp, title=title, aspect='auto')
    del a_tmp
    if not input_yesno('Accept these bounds (y/n)?', 'y'):
        if low_save is not None and upp_save is not None:
            # Both bounds were preset: retrying with them would just show the very
            # same bounds again, so let the user choose both interactively instead
            low_save = None
            upp_save = None
        bounds = select_image_bounds(a, axis, low=low_save, upp=upp_save, num_min=num_min_save,
                title=title, raise_error=raise_error)
    return(bounds)
| 1805 | |
def select_one_image_bound(a, axis, bound=None, bound_name=None, title='select array bounds',
        default='y', raise_error=False):
    """Interactively select a single row or column index of a 2D numpy array.

    When `bound` is None the user picks it from quick_imshow views of `a`, optionally
    zooming in on an index window along `axis` first. The chosen (or preset) index is
    then marked in the image with its maximum value, and the user is asked to accept
    it. Rejecting starts a fresh interactive selection.

    :param a: 2D array-like
    :param axis: 0 to select a row index, 1 to select a column index
    :param bound: preset index; prompted for when None
    :param bound_name: name used in the prompts (default 'data bound')
    :param title: title of the quick_imshow figure used for the selection
    :param default: default answer to the accept prompt
    :param raise_error: passed to illegal_value for invalid input
    :return: the selected index, or None on invalid input
    """
    a = np.asarray(a)
    if a.ndim != 2:
        illegal_value(a.ndim, 'array dimension', location='select_one_image_bound',
                raise_error=raise_error)
        return(None)
    if axis < 0 or axis >= a.ndim:
        illegal_value(axis, 'axis', location='select_one_image_bound', raise_error=raise_error)
        return(None)
    if bound_name is None:
        bound_name = 'data bound'
    last_index = a.shape[axis]-1

    def show_window(lo, hi):
        # Show only indices [lo, hi) along `axis`, labelled with their true indices
        if axis:
            quick_imshow(a[:,lo:hi], title=title, aspect='auto',
                    extent=[lo,hi,a.shape[0],0])
        else:
            quick_imshow(a[lo:hi,:], title=title, aspect='auto',
                    extent=[0,a.shape[1], hi,lo])

    if bound is None:
        lo, hi = 0, last_index+1
        while True:
            show_window(lo, hi)
            if input_yesno(f'Set {bound_name} (y) or zoom in (n)?', 'y'):
                bound = input_int(f'    Set {bound_name}', ge=0, le=last_index)
                clear_imshow(title)
                break
            lo = input_int('    Set lower zoom index', ge=0, le=last_index)
            hi = input_int('    Set upper zoom index', ge=lo+1, le=last_index+1)
    elif not is_int(bound, ge=0, le=last_index):
        illegal_value(bound, 'bound', location='select_one_image_bound', raise_error=raise_error)
        return(None)
    else:
        print(f'Current {bound_name} = {bound}')

    # Highlight the selected row/column with the image maximum
    marked = np.copy(a)
    peak = a.max()
    if axis:
        marked[:,bound] = peak
    else:
        marked[bound,:] = peak
    quick_imshow(marked, title=title, aspect='auto')
    del marked
    if not input_yesno(f'Accept this {bound_name} (y/n)?', default):
        bound = select_one_image_bound(a, axis, bound_name=bound_name, title=title)
    clear_imshow(title)
    return(bound)
| 1857 | |
| 1858 | |
class Config:
    """Base class for processing a config file or dictionary.

    The configuration ends up as a dictionary in `self.config`. It is loaded from a
    YAML ('.yml', '.yaml', '.dat') or 'key = value' text ('.txt') file, or taken
    from an existing dictionary.
    """
    def __init__(self, config_file=None, config_dict=None):
        """Create a config, optionally loading it right away.

        :param config_file: config file to load (takes precedence)
        :param config_dict: dictionary to use when no config_file is given
        """
        # The loaded configuration (empty until something is loaded successfully)
        self.config = {}
        # True once a configuration has been loaded successfully
        self.load_flag = False
        # Extension of the last file passed to load_file (None if none)
        self.suffix = None

        # Load config file
        if config_file is not None and config_dict is not None:
            logger.warning('Ignoring config_dict (both config_file and config_dict are specified)')
        if config_file is not None:
            self.load_file(config_file)
        elif config_dict is not None:
            self.load_dict(config_dict)

    def load_file(self, config_file):
        """Load a config file into self.config.

        '.yml', '.yaml' and '.dat' files are parsed as YAML with safe_load. '.txt'
        files hold one 'key = value' pair per line: text after '#' is dropped (also
        inside quoted values), lines without '=' are skipped, the line is split on
        its first '=' and the value is parsed with ast.literal_eval. On failure an
        error is logged, self.config is left empty and self.load_flag is False.
        """
        if self.load_flag:
            logger.warning('Overwriting any previously loaded config file')
            self.config = {}
            # Reset so that a failed reload does not leave a stale "loaded" state
            self.load_flag = False

        # Ensure config file exists
        if not os.path.isfile(config_file):
            logger.error(f'Unable to load {config_file}')
            return

        # Load config file (for now for Galaxy, allow .dat extension)
        self.suffix = os.path.splitext(config_file)[1]
        if self.suffix in ('.yml', '.yaml', '.dat'):
            with open(config_file, 'r') as f:
                self.config = safe_load(f)
        elif self.suffix == '.txt':
            with open(config_file, 'r') as f:
                lines = f.read().splitlines()
            config = {}
            for line in lines:
                content = line.split('#')[0]
                if '=' not in content:
                    continue
                # Split on the first '=' only, so values may themselves contain '='
                key, value = content.split('=', 1)
                config[key.strip()] = literal_eval(value.strip())
            self.config = config
        else:
            illegal_value(self.suffix, 'config file extension', 'Config.load_file')
            # Nothing was loaded: do not flag the (empty) config as loaded
            return

        # Make sure config file was correctly loaded
        if isinstance(self.config, dict):
            self.load_flag = True
        else:
            logger.error(f'Unable to load dictionary from config file: {config_file}')
            self.config = {}

    def load_dict(self, config_dict):
        """Take a dictionary and place it (not a copy) into self.config.

        A non-dict argument is reported through illegal_value and leaves
        self.config empty.
        """
        if self.load_flag:
            logger.warning('Overwriting the previously loaded config file')

        if isinstance(config_dict, dict):
            self.config = config_dict
            self.load_flag = True
        else:
            illegal_value(config_dict, 'dictionary config object', 'Config.load_dict')
            self.config = {}

    def save_file(self, config_file):
        """Save self.config as YAML (with safe_dump) to config_file.

        An extension other than '.yml'/'.yaml' is reported through illegal_value,
        but the file is still written.
        """
        suffix = os.path.splitext(config_file)[1]
        if suffix != '.yml' and suffix != '.yaml':
            illegal_value(suffix, 'config file extension', 'Config.save_file')

        # Check if config file exists
        if os.path.isfile(config_file):
            logger.info(f'Updating {config_file}')
        else:
            logger.info(f'Saving {config_file}')

        # Save config file
        with open(config_file, 'w') as f:
            safe_dump(self.config, f)

    def validate(self, pars_required, pars_missing=None):
        """Return False (and log them) if any required items are missing.

        :param pars_required: iterable of item paths; nested levels are separated
            by ':' (e.g. 'a:b:0'), and levels that parse as integers are used as
            integer keys/indices
        :param pars_missing: unused; kept for backward compatibility (the missing
            items are logged, not returned)
        :return: True when all required items are present, False otherwise
        """
        if not self.load_flag:
            logger.error('Load a config file prior to calling Config.validate')

        def validate_nested_pars(config, par):
            # Resolve the first ':'-separated level, then recurse into the rest
            first_level_par, sep, rest = par.partition(':')
            try:
                first_level_par = int(first_level_par)
            except ValueError:
                pass
            try:
                next_level_config = config[first_level_par]
            except (KeyError, IndexError, TypeError):
                # Missing key, index out of range, or a non-container level
                return(False)
            if sep:
                return(validate_nested_pars(next_level_config, rest))
            return(True)

        missing = [p for p in pars_required if not validate_nested_pars(self.config, p)]
        if missing:
            logger.error(f'Missing item(s) in configuration: {", ".join(missing)}')
            return(False)
        return(True)
