comparison general.py @ 0:98e23dff1de2 draft default tip

planemo upload for repository https://github.com/rolfverberg/galaxytools commit f8c4bdb31c20c468045ad5e6eb255a293244bc6c-dirty
author rv43
date Tue, 21 Mar 2023 16:22:42 +0000
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:98e23dff1de2
1 #!/usr/bin/env python3
2
3 #FIX write a function that returns a list of peak indices for a given plot
4 #FIX use raise_error concept on more functions to optionally raise an error
5
6 # -*- coding: utf-8 -*-
7 """
8 Created on Mon Dec 6 15:36:22 2021
9
10 @author: rv43
11 """
12
13 import logging
14 logger=logging.getLogger(__name__)
15
16 import os
17 import sys
18 import re
19 try:
20 from yaml import safe_load, safe_dump
21 except:
22 pass
23 try:
24 import h5py
25 except:
26 pass
27 import numpy as np
28 try:
29 import matplotlib.pyplot as plt
30 import matplotlib.lines as mlines
31 from matplotlib import transforms
32 from matplotlib.widgets import Button
33 except:
34 pass
35
36 from ast import literal_eval
37 try:
38 from asteval import Interpreter, get_ast_names
39 except:
40 pass
41 from copy import deepcopy
42 try:
43 from sympy import diff, simplify
44 except:
45 pass
46 from time import time
47
48
def depth_list(L):
    """Return the nesting depth of a (possibly nested) list.

    A flat list has depth 1, a list that contains a list has depth 2, etc.
    An empty list has depth 1. A non-list argument returns False (not 0).
    """
    # default=0 keeps empty (sub)lists from crashing max() on an empty iterable
    return(isinstance(L, list) and max(map(depth_list, L), default=0)+1)
def depth_tuple(T):
    """Return the nesting depth of a (possibly nested) tuple.

    A flat tuple has depth 1, a tuple that contains a tuple has depth 2, etc.
    An empty tuple has depth 1. A non-tuple argument returns False (not 0).
    """
    # default=0 keeps empty (sub)tuples from crashing max() on an empty iterable
    return(isinstance(T, tuple) and max(map(depth_tuple, T), default=0)+1)
def unwrap_tuple(T):
    """Strip redundant single-element tuple wrappers from T.

    While T is a tuple nested more than one level deep that holds exactly one
    element, T is replaced by that element, e.g. ((1, 2),) -> (1, 2) and
    (((1,),),) -> (1,). Anything else is returned unchanged.
    """
    while depth_tuple(T) > 1 and len(T) == 1:
        (T,) = T
    return(T)
55
def illegal_value(value, name, location=None, raise_error=False, log=True):
    """Report an illegal value.

    The message contains value and its type, preceded by the parameter name
    (when name is a str) and 'in <location>' (when location is a str).
    It is logged at error level when log is True, and raised as ValueError
    when raise_error is True.
    """
    where = f'in {location} ' if isinstance(location, str) else ''
    details = f'{where}({value}, {type(value)})'
    if isinstance(name, str):
        error_msg = f'Illegal value for {name} {details}'
    else:
        error_msg = f'Illegal value {details}'
    if log:
        logger.error(error_msg)
    if raise_error:
        raise ValueError(error_msg)
69
def illegal_combination(value1, name1, value2, name2, location=None, raise_error=False,
        log=True):
    """Report an illegal combination of two values.

    The message contains both values and their types, preceded by both names
    (when name1 is a str) and 'in <location>' (when location is a str).
    It is logged at error level when log is True, and raised as ValueError
    when raise_error is True.
    """
    where = f'in {location} ' if isinstance(location, str) else ''
    values = f'({value1}, {type(value1)} and {value2}, {type(value2)})'
    if isinstance(name1, str):
        error_msg = f'Illegal combination for {name1} and {name2} {where}{values}'
    else:
        error_msg = f'Illegal combination {where}{values}'
    if log:
        logger.error(error_msg)
    if raise_error:
        raise ValueError(error_msg)
86
87 def test_ge_gt_le_lt(ge, gt, le, lt, func, location=None, raise_error=False, log=True):
88 """Check individual and mutual validity of ge, gt, le, lt qualifiers
89 func: is_int or is_num to test for int or numbers
90 Return: True upon success or False when mutually exlusive
91 """
92 if ge is None and gt is None and le is None and lt is None:
93 return(True)
94 if ge is not None:
95 if not func(ge):
96 illegal_value(ge, 'ge', location, raise_error, log)
97 return(False)
98 if gt is not None:
99 illegal_combination(ge, 'ge', gt, 'gt', location, raise_error, log)
100 return(False)
101 elif gt is not None and not func(gt):
102 illegal_value(gt, 'gt', location, raise_error, log)
103 return(False)
104 if le is not None:
105 if not func(le):
106 illegal_value(le, 'le', location, raise_error, log)
107 return(False)
108 if lt is not None:
109 illegal_combination(le, 'le', lt, 'lt', location, raise_error, log)
110 return(False)
111 elif lt is not None and not func(lt):
112 illegal_value(lt, 'lt', location, raise_error, log)
113 return(False)
114 if ge is not None:
115 if le is not None and ge > le:
116 illegal_combination(ge, 'ge', le, 'le', location, raise_error, log)
117 return(False)
118 elif lt is not None and ge >= lt:
119 illegal_combination(ge, 'ge', lt, 'lt', location, raise_error, log)
120 return(False)
121 elif gt is not None:
122 if le is not None and gt >= le:
123 illegal_combination(gt, 'gt', le, 'le', location, raise_error, log)
124 return(False)
125 elif lt is not None and gt >= lt:
126 illegal_combination(gt, 'gt', lt, 'lt', location, raise_error, log)
127 return(False)
128 return(True)
129
def range_string_ge_gt_le_lt(ge=None, gt=None, le=None, lt=None):
    """Return a range string representation matching the ge, gt, le, lt qualifiers.

    One-sided ranges read like '>= 1' or '< 3'; two-sided ranges use interval
    notation such as '[1, 5]' or '(0, 5)'. Without qualifiers '' is returned.
    The inputs are not validated; do that as needed before calling.
    """
    has_lower = ge is not None or gt is not None
    has_upper = le is not None or lt is not None
    if ge is not None:
        lower = f'[{ge}, ' if has_upper else f'>= {ge}'
    elif gt is not None:
        lower = f'({gt}, ' if has_upper else f'> {gt}'
    else:
        lower = ''
    if le is not None:
        upper = f'{le}]' if has_lower else f'<= {le}'
    elif lt is not None:
        upper = f'{lt})' if has_lower else f'< {lt}'
    else:
        upper = ''
    return(lower + upper)
156
def is_int(v, ge=None, gt=None, le=None, lt=None, raise_error=False, log=True):
    """Check that v is an int, optionally within bounds.

    Bounds (any valid combination): ge <= v, gt < v, v <= le, v < lt.
    Return: True if so, False otherwise (logged/raised per log/raise_error).
    """
    return(_is_int_or_num(v, 'int', ge=ge, gt=gt, le=le, lt=lt, raise_error=raise_error,
        log=log))
162
def is_num(v, ge=None, gt=None, le=None, lt=None, raise_error=False, log=True):
    """Check that v is a number (int or float), optionally within bounds.

    Bounds (any valid combination): ge <= v, gt < v, v <= le, v < lt.
    Return: True if so, False otherwise (logged/raised per log/raise_error).
    """
    return(_is_int_or_num(v, 'num', ge=ge, gt=gt, le=le, lt=lt, raise_error=raise_error,
        log=log))
168
def _is_int_or_num(v, type_str, ge=None, gt=None, le=None, lt=None, raise_error=False,
        log=True):
    """Shared implementation of is_int and is_num.

    type_str: 'int' accepts int; 'num' accepts int or float.
    The bounds are first validated with test_ge_gt_le_lt. Returns True when v
    has an accepted type and satisfies every given bound, False otherwise.
    Failures are logged when log is True and raised as ValueError when
    raise_error is True.
    """
    if type_str == 'int':
        accepted, bound_check = int, is_int
    elif type_str == 'num':
        accepted, bound_check = (int, float), is_num
    else:
        illegal_value(type_str, 'type_str', '_is_int_or_num', raise_error, log)
        return(False)
    if not isinstance(v, accepted):
        illegal_value(v, 'v', '_is_int_or_num', raise_error, log)
        return(False)
    if not test_ge_gt_le_lt(ge, gt, le, lt, bound_check, '_is_int_or_num', raise_error, log):
        return(False)
    # Report only the first violated bound, checked in the order ge, gt, le, lt
    error_msg = None
    if ge is not None and v < ge:
        error_msg = f'Value {v} out of range: {v} !>= {ge}'
    elif gt is not None and v <= gt:
        error_msg = f'Value {v} out of range: {v} !> {gt}'
    elif le is not None and v > le:
        error_msg = f'Value {v} out of range: {v} !<= {le}'
    elif lt is not None and v >= lt:
        error_msg = f'Value {v} out of range: {v} !< {lt}'
    if error_msg is None:
        return(True)
    if log:
        logger.error(error_msg)
    if raise_error:
        raise ValueError(error_msg)
    return(False)
208
def is_int_pair(v, ge=None, gt=None, le=None, lt=None, raise_error=False, log=True):
    """Check that v is a pair (tuple/list of length 2) of ints within bounds.

    Each bound is either a single int applied to both elements
    (ge <= v[i] <= le, gt < v[i] < lt, ...) or a pair applied element-wise
    (ge[i] <= v[i] <= le[i], ...).
    Return: True if so, False otherwise.
    """
    return(_is_int_or_num_pair(v, 'int', ge=ge, gt=gt, le=le, lt=lt,
        raise_error=raise_error, log=log))
215
def is_num_pair(v, ge=None, gt=None, le=None, lt=None, raise_error=False, log=True):
    """Check that v is a pair (tuple/list of length 2) of numbers within bounds.

    Each bound is either a single number applied to both elements
    (ge <= v[i] <= le, gt < v[i] < lt, ...) or a pair applied element-wise
    (ge[i] <= v[i] <= le[i], ...).
    Return: True if so, False otherwise.
    """
    return(_is_int_or_num_pair(v, 'num', ge=ge, gt=gt, le=le, lt=lt,
        raise_error=raise_error, log=log))
222
def _is_int_or_num_pair(v, type_str, ge=None, gt=None, le=None, lt=None, raise_error=False,
        log=True):
    """Shared implementation of is_int_pair and is_num_pair.

    type_str: 'int' accepts ints; 'num' accepts ints or floats.
    v must be a tuple or list of exactly two accepted values. Each of ge, gt,
    le, lt may be None, a single value (applied to both elements) or a pair
    (applied element-wise). Returns True when v is valid and within bounds,
    False otherwise (logged/raised per log/raise_error).
    """
    if type_str == 'int':
        accepted, func = int, is_int
    elif type_str == 'num':
        accepted, func = (int, float), is_num
    else:
        illegal_value(type_str, 'type_str', '_is_int_or_num_pair', raise_error, log)
        return(False)
    if not (isinstance(v, (tuple, list)) and len(v) == 2 and
            all(isinstance(vi, accepted) for vi in v)):
        illegal_value(v, 'v', '_is_int_or_num_pair', raise_error, log)
        return(False)
    if ge is None and gt is None and le is None and lt is None:
        return(True)
    pair_bounds = []
    for bound in (ge, gt, le, lt):
        # Probe silently (log=False) whether the bound is a scalar: a pair-valued
        # bound is legitimate and must not produce an 'Illegal value' error here.
        if bound is None or func(bound, log=False):
            pair_bounds.append(2*[bound])
        elif _is_int_or_num_pair(bound, type_str, raise_error=raise_error, log=log):
            pair_bounds.append(bound)
        else:
            return(False)
    ge, gt, le, lt = pair_bounds
    return(all(func(v[i], ge[i], gt[i], le[i], lt[i], raise_error, log) for i in (0, 1)))
262
def is_int_series(l, ge=None, gt=None, le=None, lt=None, raise_error=False, log=True):
    """Check that l is a tuple or list of ints, each within the optional bounds.

    Bounds follow is_int: ge <= l[i], gt < l[i], l[i] <= le, l[i] < lt.
    Return: True if so, False otherwise (logged/raised per log/raise_error).
    """
    bounds_ok = test_ge_gt_le_lt(ge, gt, le, lt, is_int, 'is_int_series', raise_error, log)
    if not bounds_ok:
        return(False)
    if not isinstance(l, (tuple, list)):
        illegal_value(l, 'l', 'is_int_series', raise_error, log)
        return(False)
    return(all(is_int(item, ge, gt, le, lt, raise_error, log) for item in l))
275
def is_num_series(l, ge=None, gt=None, le=None, lt=None, raise_error=False, log=True):
    """Check that l is a tuple or list of numbers, each within the optional bounds.

    Bounds follow is_num: ge <= l[i], gt < l[i], l[i] <= le, l[i] < lt, and
    may themselves be floats.
    Return: True if so, False otherwise (logged/raised per log/raise_error).
    """
    # Validate the bounds as numbers (not ints) so float bounds are accepted
    if not test_ge_gt_le_lt(ge, gt, le, lt, is_num, 'is_num_series', raise_error, log):
        return(False)
    if not isinstance(l, (tuple, list)):
        illegal_value(l, 'l', 'is_num_series', raise_error, log)
        return(False)
    return(all(is_num(v, ge, gt, le, lt, raise_error, log) for v in l))
288
def is_str_series(l, raise_error=False, log=True):
    """Check that l is a tuple or list whose elements are all strings."""
    if isinstance(l, (tuple, list)) and all(isinstance(item, str) for item in l):
        return(True)
    illegal_value(l, 'l', 'is_str_series', raise_error, log)
    return(False)
297
def is_dict_series(l, raise_error=False, log=True):
    """Check that l is a tuple or list whose elements are all dicts."""
    valid = isinstance(l, (tuple, list)) and all(isinstance(item, dict) for item in l)
    if not valid:
        illegal_value(l, 'l', 'is_dict_series', raise_error, log)
    return(valid)
306
def is_dict_nums(l, raise_error=False, log=True):
    """Check that l is a dict whose values are all single numbers (int or float)."""
    if isinstance(l, dict) and all(is_num(value, log=False) for value in l.values()):
        return(True)
    illegal_value(l, 'l', 'is_dict_nums', raise_error, log)
    return(False)
315
def is_dict_strings(l, raise_error=False, log=True):
    """Check that l is a dict whose values are all strings."""
    valid = isinstance(l, dict) and all(isinstance(value, str) for value in l.values())
    if not valid:
        illegal_value(l, 'l', 'is_dict_strings', raise_error, log)
    return(valid)
324
def is_index(v, ge=0, lt=None, raise_error=False, log=True):
    """Check that v is an integer array index with ge <= v < lt.

    NOTE lt IS NOT included. An int lt that does not exceed ge is reported
    as an illegal combination.
    """
    if isinstance(lt, int) and lt <= ge:
        illegal_combination(ge, 'ge', lt, 'lt', 'is_index', raise_error, log)
        return(False)
    return(is_int(v, ge=ge, lt=lt, raise_error=raise_error, log=log))
334
def is_index_range(v, ge=0, le=None, lt=None, raise_error=False, log=True):
    """Check that v is an integer index pair with ge <= v[0] <= v[1] and
    v[1] <= le (when le is given) or v[1] < lt (when lt is given).

    NOTE le IS included, lt is not.
    """
    if not is_int_pair(v, raise_error=raise_error, log=log):
        return(False)
    if not test_ge_gt_le_lt(ge, None, le, lt, is_int, 'is_index_range', raise_error, log):
        return(False)
    first, last = v
    in_range = ge <= first <= last
    if in_range and le is not None:
        in_range = last <= le
    if in_range and lt is not None:
        in_range = last < lt
    if in_range:
        return(True)
    if le is None:
        error_msg = f'Value {v} out of range: !({ge} <= {v[0]} <= {v[1]} < {lt})'
    else:
        error_msg = f'Value {v} out of range: !({ge} <= {v[0]} <= {v[1]} <= {le})'
    if log:
        logger.error(error_msg)
    if raise_error:
        raise ValueError(error_msg)
    return(False)
354
def index_nearest(a, value):
    """Return the index of the element of the 1-D array-like a closest to value.

    An exact tie (value halfway between two elements) resolves toward the
    larger element ("round up for .5") for any nonzero value, negative ones
    included. value itself is never modified.

    Raises ValueError if a has more than one dimension.
    """
    a = np.asarray(a)
    if a.ndim > 1:
        raise ValueError(f'Invalid array dimension for parameter a ({a.ndim}, {a})')
    # Round up for .5: nudge value upward by at least one ulp. abs() keeps the
    # nudge upward for negative values, and binding a new object (instead of
    # *=) avoids mutating a caller's numpy array in place.
    value = value + abs(value)*sys.float_info.epsilon
    return(int(np.argmin(np.abs(a-value))))
362
def index_nearest_low(a, value):
    """Return the index of the element of 1-D a nearest to value, stepping back
    one position when that element exceeds value (unless it is the first).

    For an ascending a this is the last index with a[i] <= value, or 0 when
    value < a[0]. Raises ValueError if a has more than one dimension.
    """
    arr = np.asarray(a)
    if arr.ndim > 1:
        raise ValueError(f'Invalid array dimension for parameter a ({arr.ndim}, {arr})')
    nearest = int(np.argmin(np.abs(arr-value)))
    return(nearest-1 if value < arr[nearest] and nearest > 0 else nearest)
371
def index_nearest_upp(a, value):
    """Return the index of the element of 1-D a nearest to value, stepping
    forward one position when that element is below value (unless it is the last).

    For an ascending a this is the first index with a[i] >= value, or the last
    index when value > a[-1]. Raises ValueError if a has more than one dimension.
    """
    arr = np.asarray(a)
    if arr.ndim > 1:
        raise ValueError(f'Invalid array dimension for parameter a ({arr.ndim}, {arr})')
    nearest = int(np.argmin(np.abs(arr-value)))
    return(nearest+1 if value > arr[nearest] and nearest < arr.size-1 else nearest)
380
def round_to_n(x, n=1):
    """Round x to n significant figures, returning a value of type(x).

    x == 0 returns the int 0.
    """
    if x == 0.0:
        return(0)
    exponent = int(np.floor(np.log10(abs(x))))
    ndigits = n-1-exponent
    return(type(x)(round(x, ndigits)))
386
def round_up_to_n(x, n=1):
    """Round the magnitude of x up to n significant figures, keeping type(x).

    e.g. 0.94 -> 1.0 and -0.94 -> -1.0 for n=1. x == 0 returns type(x)(0)
    (previously this divided by zero).
    """
    if x == 0:
        return(type(x)(0))
    xr = round_to_n(x, n)
    # round_to_n went toward zero: add one unit in the n-th significant digit
    if abs(x/xr) > 1.0:
        xr += np.sign(x)*10**(np.floor(np.log10(abs(x)))+1-n)
    return(type(x)(xr))
392
def trunc_to_n(x, n=1):
    """Truncate x toward zero to n significant figures, keeping type(x).

    e.g. 0.96 -> 0.9 for n=1. x == 0 returns type(x)(0) (previously this
    divided by zero).
    """
    if x == 0:
        return(type(x)(0))
    xr = round_to_n(x, n)
    # round_to_n went away from zero: remove one unit in the n-th significant digit
    if abs(xr/x) > 1.0:
        xr -= np.sign(x)*10**(np.floor(np.log10(abs(x)))+1-n)
    return(type(x)(xr))
398
def almost_equal(a, b, sig_figs):
    """Return whether numbers a and b are equal to within sig_figs figures.

    The difference a-b is rounded to sig_figs significant figures and its
    absolute value is compared against 10**(1-sig_figs).
    Raises ValueError if a or b is not a number.
    """
    # log=False: the ValueError below already carries the full diagnosis
    if not (is_num(a, log=False) and is_num(b, log=False)):
        raise ValueError(f'Invalid value for a or b in almost_equal (a: {a}, {type(a)}, '+
            f'b: {b}, {type(b)})')
    return(abs(round_to_n(a-b, sig_figs)) < pow(10, -sig_figs+1))
406
def string_to_list(s, split_on_dash=True, remove_duplicates=True, sort=True):
    """Return a list of numbers parsed from a string.

    The string is split on any combination of commas and whitespace. When
    split_on_dash is True, a token 'a-b' expands to the integers a..b
    inclusive (b must be larger than a), e.g.
    '1, 3, 5-8, 12 ' -> [1, 3, 5, 6, 7, 8, 12].
    Tokens are parsed with ast.literal_eval, so '1.5' yields a float.

    remove_duplicates: drop repeated values, keeping the first occurrence.
    sort: return the values in ascending order.

    Returns [] for an empty or whitespace-only string, and None when s is not
    a string or cannot be parsed.
    """
    if not isinstance(s, str):
        illegal_value(s, 's', 'string_to_list')
        return(None)
    s = s.strip()
    if not s:
        return([])
    tokens = re.split(r'\s+,\s+|\s+,|,\s+|\s+|,', s)
    try:
        if split_on_dash:
            l = []
            for token in tokens:
                parts = [literal_eval(x) for x in re.split(r'\s+-\s+|\s+-|-\s+|\s+|-', token)]
                if len(parts) == 1:
                    l += parts
                elif len(parts) == 2 and parts[1] > parts[0]:
                    l += list(range(parts[0], parts[1]+1))
                else:
                    raise ValueError
        else:
            l = [literal_eval(x) for x in tokens]
        if remove_duplicates:
            l = list(dict.fromkeys(l))
        if sort:
            l = sorted(l)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        return(None)
    return(l)
441
def get_trailing_int(string):
    """Return the integer formed by the digits at the end of string, or None."""
    match = re.search(r'\d+$', string)
    return(None if match is None else int(match.group()))
449
def input_int(s=None, ge=None, gt=None, le=None, lt=None, default=None, inset=None,
        raise_error=False, log=True):
    """Prompt on stdin for an integer, optionally bounded and/or restricted to inset.

    See _input_int_or_num for the meaning of the arguments.
    """
    return(_input_int_or_num('int', s=s, ge=ge, gt=gt, le=le, lt=lt, default=default,
        inset=inset, raise_error=raise_error, log=log))
453
def input_num(s=None, ge=None, gt=None, le=None, lt=None, default=None, raise_error=False,
        log=True):
    """Prompt on stdin for a number (int or float), optionally bounded.

    See _input_int_or_num for the meaning of the arguments.
    """
    return(_input_int_or_num('num', s=s, ge=ge, gt=gt, le=le, lt=lt, default=default,
        inset=None, raise_error=raise_error, log=log))
457
def _input_int_or_num(type_str, s=None, ge=None, gt=None, le=None, lt=None, default=None,
        inset=None, raise_error=False, log=True):
    """Prompt on stdin until a valid int or number is entered and return it.

    type_str: 'int' or 'num'.
    s: prompt text; a generic prompt is used when None.
    ge, gt, le, lt: optional bounds, validated with test_ge_gt_le_lt.
    default: value used for an empty entry; it must satisfy the bounds.
    inset: optional tuple/list of ints the entered value must belong to.

    Returns the accepted value. Returns None when the arguments are invalid or
    reading the input fails unexpectedly (e.g. EOFError on a closed stdin);
    with raise_error True these cases raise ValueError instead.
    """
    if type_str == 'int':
        bound_check = is_int
    elif type_str == 'num':
        bound_check = is_num
    else:
        illegal_value(type_str, 'type_str', '_input_int_or_num', raise_error, log)
        return(None)
    if not test_ge_gt_le_lt(ge, gt, le, lt, bound_check, '_input_int_or_num', raise_error, log):
        return(None)
    if default is not None:
        if not _is_int_or_num(default, type_str, raise_error=raise_error, log=log):
            return(None)
        if ge is not None and default < ge:
            illegal_combination(ge, 'ge', default, 'default', '_input_int_or_num', raise_error,
                log)
            return(None)
        if gt is not None and default <= gt:
            illegal_combination(gt, 'gt', default, 'default', '_input_int_or_num', raise_error,
                log)
            return(None)
        if le is not None and default > le:
            illegal_combination(le, 'le', default, 'default', '_input_int_or_num', raise_error,
                log)
            return(None)
        if lt is not None and default >= lt:
            illegal_combination(lt, 'lt', default, 'default', '_input_int_or_num', raise_error,
                log)
            return(None)
        default_string = f' [{default}]'
    else:
        default_string = ''
    if inset is not None:
        if (not isinstance(inset, (tuple, list)) or
                not all(isinstance(i, int) for i in inset)):
            illegal_value(inset, 'inset', '_input_int_or_num', raise_error, log)
            return(None)
    v_range = range_string_ge_gt_le_lt(ge, gt, le, lt)
    if v_range:
        v_range = f' {v_range}'
    if s is not None:
        prompt = f'{s}{v_range}{default_string}: '
    elif type_str == 'int':
        prompt = f'Enter an integer{v_range}{default_string}: '
    else:
        prompt = f'Enter a number{v_range}{default_string}: '
    # Re-prompt in a loop rather than recursively so repeated invalid entries
    # cannot exhaust the recursion limit
    while True:
        print(prompt)
        try:
            i = input()
            if isinstance(i, str) and not len(i):
                v = default
                print(f'{v}')
            else:
                v = literal_eval(i)
            if inset and v not in inset:
                raise ValueError(f'{v} not part of the set {inset}')
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            v = None
        except Exception as err:
            # Previously a bare except left v unbound here (UnboundLocalError)
            # and also swallowed KeyboardInterrupt
            if log:
                logger.error('Unexpected error')
            if raise_error:
                raise ValueError('Unexpected error') from err
            return(None)
        if _is_int_or_num(v, type_str, ge, gt, le, lt):
            return(v)
525
def input_int_list(s=None, ge=None, le=None, split_on_dash=True, remove_duplicates=True,
        sort=True, raise_error=False, log=True):
    """Prompt on stdin for a list of integers, each in [ge, le] when given.

    The entry is split on any combination of commas, whitespace and, when
    split_on_dash is True, dashes denoting inclusive ranges,
    e.g. '1 3,5-8 , 12 ' -> [1, 3, 5, 6, 7, 8, 12].
    remove_duplicates: drop repeated values (may also change the order).
    sort: return the values in ascending order.
    Returns None when the arguments are invalid; invalid entries re-prompt.
    """
    return(_input_int_or_num_list('int', s=s, ge=ge, le=le, split_on_dash=split_on_dash,
        remove_duplicates=remove_duplicates, sort=sort, raise_error=raise_error, log=log))
537
def input_num_list(s=None, ge=None, le=None, remove_duplicates=True, sort=True, raise_error=False,
        log=True):
    """Prompt on stdin for a list of numbers, each in [ge, le] when given.

    The entry is split on any combination of commas and whitespace; each token
    keeps its literal type, e.g. '1.0, 3, 5.8, 12 ' -> [1.0, 3, 5.8, 12].
    remove_duplicates: drop repeated values (may also change the order).
    sort: return the values in ascending order.
    Returns None when the arguments are invalid; invalid entries re-prompt.
    """
    return(_input_int_or_num_list('num', s=s, ge=ge, le=le, split_on_dash=False,
        remove_duplicates=remove_duplicates, sort=sort, raise_error=raise_error, log=log))
549
def _input_int_or_num_list(type_str, s=None, ge=None, le=None, split_on_dash=True,
        remove_duplicates=True, sort=True, raise_error=False, log=True):
    """Prompt on stdin until a valid list of ints or numbers is entered.

    type_str: 'int' or 'num'. The entry is parsed with string_to_list using
    split_on_dash, remove_duplicates and sort, and every value must be of the
    requested type and lie in [ge, le] (each bound optional).
    Returns the list, or None when the arguments are invalid.
    """
    #FIX do we want a limit on max dimension?
    if type_str == 'int':
        bound_check, noun = is_int, 'integers'
    elif type_str == 'num':
        bound_check, noun = is_num, 'numbers'
    else:
        illegal_value(type_str, 'type_str', '_input_int_or_num_list', raise_error, log)
        return(None)
    if not test_ge_gt_le_lt(ge, None, le, None, bound_check, '_input_int_or_num_list',
            raise_error, log):
        return(None)
    v_range = range_string_ge_gt_le_lt(ge=ge, le=le)
    if v_range:
        v_range = f' (each value in {v_range})'
    prompt = f'Enter a series of {noun}{v_range}: ' if s is None else f'{s}{v_range}: '
    # The hint names the requested type (previously 'integers' also for numbers)
    if split_on_dash:
        hint = (f'Invalid input: enter a valid set of dash/comma/whitespace separated {noun} '+
            'e.g. 1 3,5-8 , 12')
    else:
        hint = (f'Invalid input: enter a valid set of comma/whitespace separated {noun} '+
            'e.g. 1 3,5 8 , 12')
    # Re-prompt in a loop rather than recursively
    while True:
        print(prompt)
        try:
            l = string_to_list(input(), split_on_dash, remove_duplicates, sort)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            l = None
        except Exception:
            print('Unexpected error')
            raise
        if isinstance(l, list) and all(_is_int_or_num(v, type_str, ge=ge, le=le) for v in l):
            return(l)
        print(hint)
589
def input_yesno(s=None, default=None):
    """Prompt on stdin for a yes/no answer; return True for yes, False for no.

    Any non-empty, case-insensitive prefix of 'yes' or 'no' is accepted
    (e.g. 'y', 'Ye', 'NO'). default, given the same way, is used for an empty
    entry. The prompt repeats until a valid answer is entered.
    Returns None when default is invalid.
    """
    def _is_prefix(answer, word):
        # Exact prefix match; the former substring test ('x' in 'yes')
        # accepted answers such as 'e', 's' or 'o' and an empty default
        return(bool(answer) and word.startswith(answer.lower()))

    if default is not None:
        if not isinstance(default, str):
            illegal_value(default, 'default', 'input_yesno')
            return(None)
        if _is_prefix(default, 'yes'):
            default = 'y'
        elif _is_prefix(default, 'no'):
            default = 'n'
        else:
            illegal_value(default, 'default', 'input_yesno')
            return(None)
        default_string = f' [{default}]'
    else:
        default_string = ''
    prompt = f'Enter yes or no{default_string}: ' if s is None else f'{s}{default_string}: '
    while True:
        print(prompt)
        i = input()
        if isinstance(i, str) and not len(i):
            i = default
            print(f'{i}')
        if i is not None and _is_prefix(i, 'yes'):
            return(True)
        if i is not None and _is_prefix(i, 'no'):
            return(False)
        print('Invalid input, enter yes or no')
621
def input_menu(items, default=None, header=None):
    """Prompt on stdin to choose one of items; return the 0-based index chosen.

    items: tuple or list of strings, listed and numbered from 1.
    default: optional item (must be in items) selected by an empty entry.
    header: optional title replacing the generic one; it is kept when the
        prompt is repeated after an invalid choice.
    Returns None when items or default is invalid.
    """
    if (not isinstance(items, (tuple, list)) or
            not all(isinstance(item, str) for item in items)):
        illegal_value(items, 'items', 'input_menu')
        return(None)
    if default is not None:
        if not (isinstance(default, str) and default in items):
            logger.error(f'Invalid value for default ({default}), must be in {items}')
            return(None)
        default_string = f' [{items.index(default)+1}]'
    else:
        default_string = ''
    if header is None:
        title = f'Choose one of the following items (1, {len(items)}){default_string}:'
    else:
        title = f'{header} (1, {len(items)}){default_string}:'
    while True:
        print(title)
        for i, item in enumerate(items):
            print(f' {i+1}: {item}')
        try:
            choice = input()
            if isinstance(choice, str) and not len(choice):
                # Without a default, items.index(None) raises ValueError,
                # which is handled below as an invalid choice
                choice = items.index(default)
                print(f'{choice+1}')
            else:
                choice = literal_eval(choice)
                if isinstance(choice, int) and 1 <= choice <= len(items):
                    choice -= 1
                else:
                    raise ValueError
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            choice = None
        except Exception:
            print('Unexpected error')
            raise
        if choice is not None:
            return(choice)
        print(f'Invalid choice, enter a number between 1 and {len(items)}')
660
def assert_no_duplicates_in_list_of_dicts(l: list, raise_error=False) -> list:
    """Return l if it is a list of dicts in which no two dicts are equal.

    Dicts are compared with ==, so key order does not matter and values need
    not be hashable or sortable (e.g. lists). Returns None on failure after
    logging the problem, or raises ValueError when raise_error is True.
    """
    if not isinstance(l, list) or not all(isinstance(d, dict) for d in l):
        illegal_value(l, 'l', 'assert_no_duplicates_in_list_of_dicts', raise_error)
        return(None)
    # Pairwise == comparison instead of a set of sorted item tuples, which raised
    # TypeError for unhashable values or non-comparable keys (quadratic, but
    # presumably these lists are short)
    if any(d in l[:i] for i, d in enumerate(l)):
        if raise_error:
            raise ValueError(f'Duplicate items found in {l}')
        logger.error(f'Duplicate items found in {l}')
        return(None)
    return(l)
676
def assert_no_duplicate_key_in_list_of_dicts(l: list, key: str, raise_error=False) -> list:
    """Return l if every dict in it has key and no two share the same value for it.

    Returns None on failure after logging the problem, or raises ValueError
    when raise_error is True.
    """
    if not isinstance(key, str):
        illegal_value(key, 'key', 'assert_no_duplicate_key_in_list_of_dicts', raise_error)
        return(None)
    # The location now names this function (it previously reused another's name)
    if not isinstance(l, list) or not all(isinstance(d, dict) for d in l):
        illegal_value(l, 'l', 'assert_no_duplicate_key_in_list_of_dicts', raise_error)
        return(None)
    keys = [d.get(key, None) for d in l]
    if None in keys or len(set(keys)) != len(l):
        if raise_error:
            raise ValueError(f'Duplicate or missing key ({key}) found in {l}')
        logger.error(f'Duplicate or missing key ({key}) found in {l}')
        return(None)
    return(l)
696
def assert_no_duplicate_attr_in_list_of_objs(l: list, attr: str, raise_error=False) -> list:
    """Return l if every object in it has attr (not None) and no two share its value.

    Returns None on failure after logging the problem, or raises ValueError
    when raise_error is True.
    """
    if not isinstance(attr, str):
        illegal_value(attr, 'attr', 'assert_no_duplicate_attr_in_list_of_objs', raise_error)
        return(None)
    # The location now names this function (it was '..._key_in_list_of_objs')
    if not isinstance(l, list):
        illegal_value(l, 'l', 'assert_no_duplicate_attr_in_list_of_objs', raise_error)
        return(None)
    attrs = [getattr(obj, attr, None) for obj in l]
    if None in attrs or len(set(attrs)) != len(l):
        if raise_error:
            raise ValueError(f'Duplicate or missing attr ({attr}) found in {l}')
        logger.error(f'Duplicate or missing attr ({attr}) found in {l}')
        return(None)
    return(l)
713
def file_exists_and_readable(path):
    """Return path unchanged if it names a regular file readable by this process.

    Raises ValueError if path is not a file or is not readable.
    """
    problem = None
    if not os.path.isfile(path):
        problem = 'is not a valid file'
    elif not os.access(path, os.R_OK):
        problem = 'is not accessible for reading'
    if problem is not None:
        raise ValueError(f'{path} {problem}')
    return(path)
721
def create_mask(x, bounds=None, exclude_bounds=False, current_mask=None):
    """Return a boolean mask over the values in x.

    x: non-empty tuple, list or numpy array of numbers.
    bounds: optional pair of numbers in the same units as x. The mask is True
        where min(bounds) < x < max(bounds), or, when exclude_bounds is True,
        where x < min(bounds) or x > max(bounds). Without (valid) bounds the
        mask is all True.
    current_mask: optional mask of the same length, OR-ed into the result.

    Returns None (with a warning) for an invalid x. Invalid bounds or
    current_mask are ignored with a warning; a warning is also issued when
    the resulting mask has no True entry.
    """
    if not isinstance(x, (tuple, list, np.ndarray)) or not len(x):
        logger.warning(f'Invalid input array ({x}, {type(x)})')
        return(None)
    if bounds is not None and not is_num_pair(bounds, log=False):
        logger.warning(f'Invalid bounds parameter ({bounds} {type(bounds)}), input ignored')
        bounds = None
    # Convert so the element-wise comparisons below also work for tuple/list input
    x = np.asarray(x)
    if bounds is not None:
        low, upp = min(bounds), max(bounds)
        if exclude_bounds:
            mask = np.logical_or(x < low, x > upp)
        else:
            mask = np.logical_and(x > low, x < upp)
    else:
        mask = np.ones(len(x), dtype=bool)
    if current_mask is not None:
        if not isinstance(current_mask, (tuple, list, np.ndarray)) or len(current_mask) != len(x):
            logger.warning(f'Invalid current_mask ({current_mask}, {type(current_mask)}), '+
                'input ignored')
        else:
            mask = np.logical_or(mask, current_mask)
    if not np.any(mask):
        logger.warning('Entire data array is masked')
    return(mask)
746
def eval_expr(name, expr, expr_variables, user_variables=None, max_depth=10, raise_error=False,
        log=True, **kwargs):
    """Evaluate an expression whose variables may themselves be expressions.

    name: name of the expression being evaluated; tracked to detect cycles.
    expr: expression string, evaluated with an asteval Interpreter.
    expr_variables: dict mapping variable names to their own expression
        strings. Those referenced in expr and not yet in the interpreter's
        symbol table are evaluated first (recursively) and stored there.
    user_variables: optional dict of numbers added to the symbol table.
    max_depth: maximum length of the chain of nested evaluations (int > 1).
    kwargs: for internal recursion only: 'chain' (list of names currently
        being evaluated) and 'ast' (the shared Interpreter).

    Returns the value from the interpreter's eval. Returns None for invalid
    arguments, a circular variable or exceeding max_depth (these raise
    ValueError instead when raise_error is True), and when a referenced
    variable evaluates to None.
    """
    if not isinstance(name, str):
        illegal_value(name, 'name', 'eval_expr', raise_error, log)
        return(None)
    if not isinstance(expr, str):
        illegal_value(expr, 'expr', 'eval_expr', raise_error, log)
        return(None)
    if not is_dict_strings(expr_variables, log=False):
        illegal_value(expr_variables, 'expr_variables', 'eval_expr', raise_error, log)
        return(None)
    if user_variables is not None and not is_dict_nums(user_variables, log=False):
        illegal_value(user_variables, 'user_variables', 'eval_expr', raise_error, log)
        return(None)
    if not is_int(max_depth, gt=1, log=False):
        illegal_value(max_depth, 'max_depth', 'eval_expr', raise_error, log)
        return(None)
    if not isinstance(raise_error, bool):
        illegal_value(raise_error, 'raise_error', 'eval_expr', raise_error, log)
        return(None)
    if not isinstance(log, bool):
        illegal_value(log, 'log', 'eval_expr', raise_error, log)
        return(None)
    if 'chain' in kwargs:
        chain = kwargs.pop('chain')
        if not is_str_series(chain):
            illegal_value(chain, 'chain', 'eval_expr', raise_error, log)
            return(None)
    else:
        chain = []
    if len(chain) > max_depth:
        # f-string: the depth was previously printed as a literal '{max_depth}'
        error_msg = f'Exceeded maximum depth ({max_depth}) in eval_expr'
        if log:
            logger.error(error_msg)
        if raise_error:
            raise ValueError(error_msg)
        return(None)
    # Note: appends to the chain list shared with the caller; the caller
    # restores its own copy after each recursive call
    if name not in chain:
        chain.append(name)
    if 'ast' in kwargs:
        ast = kwargs.pop('ast')
    else:
        ast = Interpreter()
    if user_variables is not None:
        ast.symtable.update(user_variables)
    chain_vars = [var for var in get_ast_names(ast.parse(expr))
            if var in expr_variables and var not in ast.symtable]
    save_chain = chain.copy()
    for var in chain_vars:
        if var in chain:
            error_msg = f'Circular variable {var} in eval_expr'
            if log:
                logger.error(error_msg)
            if raise_error:
                raise ValueError(error_msg)
            return(None)
        if var in ast.user_defined_symbols():
            # Already evaluated while resolving an earlier variable
            val = ast.symtable[var]
        else:
            val = eval_expr(var, expr_variables[var], expr_variables, max_depth=max_depth,
                    raise_error=raise_error, log=log, chain=chain, ast=ast)
            if val is None:
                return(None)
            ast.symtable[var] = val
        chain = save_chain.copy()
    return(ast.eval(expr))
827
def full_gradient(expr, x, expr_name=None, expr_variables=None, valid_variables=None, max_depth=10,
        raise_error=False, log=True, **kwargs):
    """Compute the full gradient d(expr)/d(x) with sympy, applying the chain rule.

    expr: expression (string) to differentiate.
    x: name of the variable to differentiate with respect to.
    expr_name: optional name of expr; d(x)/d(x) short-circuits to 1.0.
    expr_variables: optional dict mapping variable names to their own
        expression strings. For each such variable v referenced in expr, the
        term d(expr)/d(v) * d(v)/d(x) is added, with d(v)/d(x) computed
        recursively.
    valid_variables: optional series of allowed variable names; referencing
        an expr_variables entry outside it is an error.
    max_depth: maximum length of the chain of nested expressions (int > 1).
    kwargs: for internal recursion only: 'chain' (names currently expanded).

    Returns the simplified sympy expression (1.0 when expr_name == x), or None
    for invalid arguments, unknown or circular variables, exceeding max_depth
    or a failed nested gradient (ValueError is raised instead when raise_error
    is True, except for the nested case whose own error was already handled).
    """
    if not isinstance(x, str):
        illegal_value(x, 'x', 'full_gradient', raise_error, log)
        return(None)
    if expr_name is not None and not isinstance(expr_name, str):
        illegal_value(expr_name, 'expr_name', 'full_gradient', raise_error, log)
        return(None)
    if expr_variables is not None and not is_dict_strings(expr_variables, log=False):
        illegal_value(expr_variables, 'expr_variables', 'full_gradient', raise_error, log)
        return(None)
    if valid_variables is not None and not is_str_series(valid_variables, log=False):
        illegal_value(valid_variables, 'valid_variables', 'full_gradient', raise_error, log)
        # Previously execution continued here with the invalid argument
        return(None)
    if not is_int(max_depth, gt=1, log=False):
        illegal_value(max_depth, 'max_depth', 'full_gradient', raise_error, log)
        return(None)
    if not isinstance(raise_error, bool):
        illegal_value(raise_error, 'raise_error', 'full_gradient', raise_error, log)
        return(None)
    if not isinstance(log, bool):
        illegal_value(log, 'log', 'full_gradient', raise_error, log)
        return(None)
    if expr_name is not None and expr_name == x:
        return(1.0)
    if 'chain' in kwargs:
        chain = kwargs.pop('chain')
        if not is_str_series(chain):
            illegal_value(chain, 'chain', 'full_gradient', raise_error, log)
            return(None)
    else:
        chain = []
    if len(chain) > max_depth:
        # f-string: the depth was previously printed as a literal '{max_depth}'
        error_msg = f'Exceeded maximum depth ({max_depth}) in full_gradient'
        if log:
            logger.error(error_msg)
        if raise_error:
            raise ValueError(error_msg)
        return(None)
    if expr_name is not None and expr_name not in chain:
        chain.append(expr_name)
    # The interpreter is only used to extract the names referenced in expr
    ast = Interpreter()
    if expr_variables is None:
        chain_vars = []
    else:
        chain_vars = [var for var in get_ast_names(ast.parse(f'{expr}'))
                if var in expr_variables and var != x and var not in ast.symtable]
    if valid_variables is not None:
        unknown_vars = [var for var in chain_vars if var not in valid_variables]
        if unknown_vars:
            error_msg = f'Unknown variable {unknown_vars} in {expr}'
            if log:
                logger.error(error_msg)
            if raise_error:
                raise ValueError(error_msg)
            return(None)
    dexpr_dx = diff(expr, x)
    save_chain = chain.copy()
    for var in chain_vars:
        if var in chain:
            error_msg = f'Circular variable {var} in full_gradient'
            if log:
                logger.error(error_msg)
            if raise_error:
                raise ValueError(error_msg)
            return(None)
        dexpr_dvar = diff(expr, var)
        if dexpr_dvar:
            dvar_dx = full_gradient(expr_variables[var], x, expr_name=var,
                    expr_variables=expr_variables, valid_variables=valid_variables,
                    max_depth=max_depth, raise_error=raise_error, log=log, chain=chain)
            # Propagate a nested failure instead of treating it as a zero term
            if dvar_dx is None:
                return(None)
            if dvar_dx:
                # Chain rule term, combined as a string and parsed by simplify below
                dexpr_dx = f'{dexpr_dx}+({dexpr_dvar})*({dvar_dx})'
        chain = save_chain.copy()
    return(simplify(dexpr_dx))
916
def bounds_from_mask(mask, return_include_bounds:bool=True):
    """Return the inclusive index ranges of the runs in a 1D mask.

    :param mask: 1D sequence of booleans (numpy array, list or tuple).
    :param return_include_bounds: Return the runs where the mask equals this
        value (True: included ranges, False: excluded ranges).
    :return: List of ``(low, upp)`` tuples with inclusive bounds.
    """
    bounds = []
    start = None
    for i, m in enumerate(mask):
        if m == return_include_bounds:
            if start is None:
                start = i
        elif start is not None:
            bounds.append((start, i-1))
            start = None
    if start is not None:
        # Close a run that extends to the end of the mask; len() (unlike
        # .size) also works for plain lists and tuples
        bounds.append((start, len(mask)-1))
    return(bounds)
929
def draw_mask_1d(ydata, xdata=None, current_index_ranges=None, current_mask=None,
        select_mask=True, num_index_ranges_max=None, title=None, legend=None, test_mode=False):
    """Interactively select (or deselect) index ranges of 1D data and return a mask.

    The data are plotted and the user click-and-drags ranges to include
    (``select_mask=True``) or exclude (``select_mask=False``).

    :param ydata: 1D data to display.
    :param xdata: Optional strictly increasing x-coordinates (defaults to the
        data indices).
    :param current_index_ranges: Optional initial list of inclusive
        ``(low, upp)`` index ranges; out-of-range bounds are clipped.
    :param current_mask: Optional initial boolean mask; combined with
        ``current_index_ranges`` by a logical AND when both are given.
    :param select_mask: Whether drawn ranges are included (True) or
        excluded (False).
    :param num_index_ranges_max: Not implemented; a warning is logged if set.
    :param title: Title fragment (default 'select ranges of data').
    :param legend: Optional legend label for the data curve.
    :param test_mode: Skip the interactive figure.
    :return: ``(selected_mask, current_include)`` with ``current_include`` the
        inclusive index ranges where the mask is True, or ``(None, None)`` on
        invalid input.
    """
    #FIX make color blind friendly
    def draw_selections(ax, current_include, current_exclude, selected_index_ranges):
        # Redraw the data with included (green), excluded (red) and newly
        # selected (selection_color) ranges shaded
        ax.clear()
        ax.set_title(title)
        ax.plot(xdata, ydata, 'k')
        if legend is not None:
            # A labels-only legend attaches to existing artists, so it must
            # follow ax.plot
            ax.legend([legend])
        for index_ranges, color in ((current_include, 'green'), (current_exclude, 'red'),
                (selected_index_ranges, selection_color)):
            for (low, upp) in index_ranges:
                xlow = 0.5*(xdata[max(0, low-1)]+xdata[low])
                xupp = 0.5*(xdata[upp]+xdata[min(num_data-1, upp+1)])
                ax.axvspan(xlow, xupp, facecolor=color, alpha=0.5)
        ax.get_figure().canvas.draw()

    def onclick(event):
        # Start a new range at the index nearest to the click
        if event.inaxes in [fig.axes[0]]:
            selected_index_ranges.append(index_nearest_upp(xdata, event.xdata))

    def onrelease(event):
        # Complete the range started in onclick, or drop it when released
        # outside the data axes
        if len(selected_index_ranges) > 0:
            if isinstance(selected_index_ranges[-1], int):
                if event.inaxes in [fig.axes[0]]:
                    event.xdata = index_nearest_low(xdata, event.xdata)
                    if selected_index_ranges[-1] <= event.xdata:
                        selected_index_ranges[-1] = (selected_index_ranges[-1], event.xdata)
                    else:
                        selected_index_ranges[-1] = (event.xdata, selected_index_ranges[-1])
                    draw_selections(event.inaxes, current_include, current_exclude,
                            selected_index_ranges)
                else:
                    selected_index_ranges.pop(-1)

    def confirm_selection(event):
        plt.close()

    def clear_last_selection(event):
        # Undo the last new range, or reset everything when there is none
        if len(selected_index_ranges):
            selected_index_ranges.pop(-1)
        else:
            while len(current_include):
                current_include.pop()
            while len(current_exclude):
                current_exclude.pop()
            selected_mask.fill(False)
        draw_selections(ax, current_include, current_exclude, selected_index_ranges)

    def update_mask(mask, selected_index_ranges, unselected_index_ranges):
        # Add the selected x-ranges to the mask, then remove the unselected ones
        for (low, upp) in selected_index_ranges:
            selected_mask = np.logical_and(xdata >= xdata[low], xdata <= xdata[upp])
            mask = np.logical_or(mask, selected_mask)
        for (low, upp) in unselected_index_ranges:
            unselected_mask = np.logical_and(xdata >= xdata[low], xdata <= xdata[upp])
            mask[unselected_mask] = False
        return(mask)

    def update_index_ranges(mask):
        # Inclusive (low, upp) index ranges of the runs where mask is True
        index_ranges = []
        start = None
        for i, m in enumerate(mask):
            if m == True:
                if start is None:
                    start = i
            elif start is not None:
                index_ranges.append((start, i-1))
                start = None
        if start is not None:
            index_ranges.append((start, num_data-1))
        return(index_ranges)

    # Check inputs
    ydata = np.asarray(ydata)
    if ydata.ndim > 1:
        logger.warning(f'Invalid ydata dimension ({ydata.ndim})')
        return(None, None)
    num_data = ydata.size
    if xdata is None:
        xdata = np.arange(num_data)
    else:
        xdata = np.asarray(xdata, dtype=np.float64)
        if xdata.ndim > 1 or xdata.size != num_data:
            logger.warning(f'Invalid xdata shape ({xdata.shape})')
            return(None, None)
        if not np.all(xdata[:-1] < xdata[1:]):
            logger.warning('Invalid xdata: must be monotonically increasing')
            return(None, None)
    if current_index_ranges is not None:
        if not isinstance(current_index_ranges, (tuple, list)):
            logger.warning(f'Invalid current_index_ranges parameter ({current_index_ranges}, '+
                    f'{type(current_index_ranges)})')
            return(None, None)
    if not isinstance(select_mask, bool):
        logger.warning(f'Invalid select_mask parameter ({select_mask}, {type(select_mask)})')
        return(None, None)
    if num_index_ranges_max is not None:
        logger.warning('num_index_ranges_max input not yet implemented in draw_mask_1d')
    if title is None:
        title = 'select ranges of data'
    elif not isinstance(title, str):
        illegal_value(title, 'title', 'draw_mask_1d')
        title = ''
    if legend is not None and not isinstance(legend, str):
        illegal_value(legend, 'legend', 'draw_mask_1d')
        legend = None

    if select_mask:
        title = f'Click and drag to {title} you wish to include'
        selection_color = 'green'
    else:
        title = f'Click and drag to {title} you wish to exclude'
        selection_color = 'red'

    # Set initial selected mask and the selected/unselected index ranges as needed
    selected_index_ranges = []
    unselected_index_ranges = []
    selected_mask = np.full(xdata.shape, False, dtype=bool)
    if current_index_ranges is None:
        if current_mask is None:
            if not select_mask:
                selected_index_ranges = [(0, num_data-1)]
                selected_mask = np.full(xdata.shape, True, dtype=bool)
        else:
            selected_mask = np.copy(np.asarray(current_mask, dtype=bool))
    if current_index_ranges is not None and len(current_index_ranges):
        current_index_ranges = sorted([(low, upp) for (low, upp) in current_index_ranges])
        for (low, upp) in current_index_ranges:
            # Skip empty/out-of-range ranges and clip the others to the data
            if low > upp or low >= num_data or upp < 0:
                continue
            if low < 0:
                low = 0
            if upp >= num_data:
                upp = num_data-1
            selected_index_ranges.append((low, upp))
        selected_mask = update_mask(selected_mask, selected_index_ranges, unselected_index_ranges)
    if current_index_ranges is not None and current_mask is not None:
        selected_mask = np.logical_and(current_mask, selected_mask)
    if current_mask is not None:
        selected_index_ranges = update_index_ranges(selected_mask)

    # Set up range selections for display
    current_include = selected_index_ranges
    current_exclude = []
    selected_index_ranges = []
    if not len(current_include):
        if select_mask:
            current_exclude = [(0, num_data-1)]
        else:
            current_include = [(0, num_data-1)]
    else:
        if current_include[0][0] > 0:
            current_exclude.append((0, current_include[0][0]-1))
        for i in range(1, len(current_include)):
            current_exclude.append((current_include[i-1][1]+1, current_include[i][0]-1))
        if current_include[-1][1] < num_data-1:
            current_exclude.append((current_include[-1][1]+1, num_data-1))

    if not test_mode:

        # Set up matplotlib figure
        plt.close('all')
        fig, ax = plt.subplots()
        plt.subplots_adjust(bottom=0.2)
        draw_selections(ax, current_include, current_exclude, selected_index_ranges)

        # Set up event handling for click-and-drag range selection
        cid_click = fig.canvas.mpl_connect('button_press_event', onclick)
        cid_release = fig.canvas.mpl_connect('button_release_event', onrelease)

        # Set up confirm / clear range selection buttons
        confirm_b = Button(plt.axes([0.75, 0.05, 0.15, 0.075]), 'Confirm')
        clear_b = Button(plt.axes([0.59, 0.05, 0.15, 0.075]), 'Clear')
        cid_confirm = confirm_b.on_clicked(confirm_selection)
        cid_clear = clear_b.on_clicked(clear_last_selection)

        # Show figure
        plt.show(block=True)

        # Disconnect callbacks when figure is closed
        fig.canvas.mpl_disconnect(cid_click)
        fig.canvas.mpl_disconnect(cid_release)
        confirm_b.disconnect(cid_confirm)
        clear_b.disconnect(cid_clear)

    # Swap selection depending on select_mask
    if not select_mask:
        selected_index_ranges, unselected_index_ranges = unselected_index_ranges, \
                selected_index_ranges

    # Update the mask with the currently selected/unselected x-ranges
    selected_mask = update_mask(selected_mask, selected_index_ranges, unselected_index_ranges)

    # Update the currently included index ranges (where mask is True)
    current_include = update_index_ranges(selected_mask)

    return(selected_mask, current_include)
1131
def select_peaks(ydata:np.ndarray, x_values:np.ndarray=None, x_mask:np.ndarray=None,
        peak_x_values:np.ndarray=None, peak_x_indices:np.ndarray=None,
        return_peak_x_values:bool=False, return_peak_x_indices:bool=False,
        return_peak_input_indices:bool=False, return_sorted:bool=False,
        title:str=None, xlabel:str=None, ylabel:str=None) -> np.ndarray:
    """Interactively select peaks by clicking on vertical lines drawn at each peak.

    Exactly one of ``peak_x_values`` / ``peak_x_indices`` must be given, and
    exactly one of the ``return_*`` format flags must be True.

    :param ydata: Reference data to plot.
    :param x_values: Optional x-coordinates of ``ydata`` (default: indices).
    :param x_mask: Optional boolean mask; peaks in masked-out regions are drawn
        but cannot be selected.
    :param peak_x_values: Peak positions as x-values (requires ``x_values``).
    :param peak_x_indices: Peak positions as indices into ``x_values``.
    :param return_peak_x_values: Return the selected peaks' x-values.
    :param return_peak_x_indices: Return the selected peaks' indices.
    :param return_peak_input_indices: Return the positions of the selected
        peaks in the input peak list.
    :param return_sorted: Sort the returned values.
    :param title: Plot title.
    :param xlabel: Plot x-axis label.
    :param ylabel: Plot y-axis label.
    :return: numpy array with the selected peaks in the requested format.
    :raises RuntimeError: On an inconsistent combination of arguments.
    """
    # Normalize inputs (None defaults avoid shared mutable ndarray defaults)
    peak_x_values = np.asarray([] if peak_x_values is None else peak_x_values)
    peak_x_indices = np.asarray([] if peak_x_indices is None else peak_x_indices, dtype=int)
    if x_values is not None:
        x_values = np.asarray(x_values)

    # Check arguments
    if ((len(peak_x_values) > 0 or return_peak_x_values)
            and (x_values is None or len(x_values) == 0)):
        raise RuntimeError('Cannot use peak_x_values or return_peak_x_values without x_values')
    if (len(peak_x_values) > 0) == (len(peak_x_indices) > 0):
        raise RuntimeError('Use exactly one of peak_x_values or peak_x_indices')
    num_return_formats = sum(bool(flag) for flag in
            (return_peak_x_values, return_peak_x_indices, return_peak_input_indices))
    if num_return_formats != 1:
        raise RuntimeError('Exactly one of return_peak_x_values, return_peak_x_indices, or '+
                'return_peak_input_indices must be True')

    EXCLUDE_PEAK_PROPERTIES = {'color': 'black', 'linestyle': '--','linewidth': 1,
            'marker': 10, 'markersize': 5, 'fillstyle': 'none'}
    INCLUDE_PEAK_PROPERTIES = {'color': 'green', 'linestyle': '-', 'linewidth': 2,
            'marker': 10, 'markersize': 10, 'fillstyle': 'full'}
    MASKED_PEAK_PROPERTIES = {'color': 'gray', 'linestyle': ':', 'linewidth': 1}

    # Setup reference data & plot
    x_indices = np.arange(len(ydata))
    if x_values is None:
        x_values = x_indices
    if x_mask is None:
        x_mask = np.full(x_values.shape, True, dtype=bool)
    else:
        x_mask = np.asarray(x_mask, dtype=bool)
    fig, ax = plt.subplots()
    handles = ax.plot(x_values, ydata, label='Reference data')
    handles.append(mlines.Line2D([], [], label='Excluded / unselected HKL', **EXCLUDE_PEAK_PROPERTIES))
    handles.append(mlines.Line2D([], [], label='Included / selected HKL', **INCLUDE_PEAK_PROPERTIES))
    handles.append(mlines.Line2D([], [], label='HKL in masked region (unselectable)', **MASKED_PEAK_PROPERTIES))
    ax.legend(handles=handles, loc='upper right')
    ax.set(title=title, xlabel=xlabel, ylabel=ylabel)

    # Plot vertical line at each peak
    def value_to_index(x_value):
        # Index of the x_values entry nearest to x_value
        return int(np.argmin(np.abs(x_values - x_value)))
    if len(peak_x_indices) > 0:
        peak_x_values = x_values[peak_x_indices]
    else:
        peak_x_indices = np.array([value_to_index(loc) for loc in peak_x_values], dtype=int)
    peak_vlines = []
    for loc in peak_x_values:
        # Only peaks in unmasked regions are pickable
        if x_mask[value_to_index(loc)]:
            peak_vline = ax.axvline(loc, **EXCLUDE_PEAK_PROPERTIES)
            peak_vline.set_picker(5)
        else:
            peak_vline = ax.axvline(loc, **MASKED_PEAK_PROPERTIES)
        peak_vlines.append(peak_vline)

    # Indicate masked regions by gray-ing out the axes facecolor
    mask_exclude_bounds = bounds_from_mask(x_mask, return_include_bounds=False)
    for (low, upp) in mask_exclude_bounds:
        xlow = x_values[low]
        xupp = x_values[upp]
        ax.axvspan(xlow, xupp, facecolor='gray', alpha=0.5)

    # Setup peak picking: clicking a peak line toggles its selection
    selected_peak_input_indices = []
    def onpick(event):
        try:
            peak_index = peak_vlines.index(event.artist)
        except ValueError:
            # The picked artist is not one of the peak lines
            return
        peak_vline = event.artist
        if peak_index in selected_peak_input_indices:
            peak_vline.set(**EXCLUDE_PEAK_PROPERTIES)
            selected_peak_input_indices.remove(peak_index)
        else:
            peak_vline.set(**INCLUDE_PEAK_PROPERTIES)
            selected_peak_input_indices.append(peak_index)
        plt.draw()
    cid_pick_peak = fig.canvas.mpl_connect('pick_event', onpick)

    # Setup "Confirm" button
    def confirm_selection(event):
        plt.close()
    plt.subplots_adjust(bottom=0.2)
    confirm_b = Button(plt.axes([0.75, 0.05, 0.15, 0.075]), 'Confirm')
    cid_confirm = confirm_b.on_clicked(confirm_selection)

    # Show figure for user interaction
    plt.show()

    # Disconnect callbacks when figure is closed
    fig.canvas.mpl_disconnect(cid_pick_peak)
    confirm_b.disconnect(cid_confirm)

    if return_peak_input_indices:
        selected_peaks = np.array(selected_peak_input_indices, dtype=int)
    elif return_peak_x_values:
        selected_peaks = peak_x_values[selected_peak_input_indices]
    else:
        selected_peaks = peak_x_indices[selected_peak_input_indices]

    if return_sorted:
        selected_peaks.sort()

    return(selected_peaks)
1235
def find_image_files(path, filetype, name=None):
    """Locate the available images for a scan.

    For ``filetype == 'tif'`` ``path`` must be a directory holding
    consecutively indexed ``.tif`` files; for ``filetype == 'h5'`` ``path``
    must be an HDF5 file holding the image stack at
    ``entry/instrument/detector/data``.

    :param path: Directory (tif) or file (h5) path.
    :param filetype: 'tif' or 'h5'.
    :param name: Optional label used in log messages.
    :return: ``(first_index, num_img, paths)``, or ``(-1, 0, [])`` on failure.
    """
    name = f'{name.strip()} ' if isinstance(name, str) else ''

    # Find available index range
    if filetype == 'h5':
        if not isinstance(path, str) or not os.path.isfile(path):
            illegal_value(path, 'path', 'find_image_files')
            return(-1, 0, [])
        # At this point only h5 in alamo2 detector style
        first_index = 0
        with h5py.File(path, 'r') as h5_file:
            num_img = h5_file['entry/instrument/detector/data'].shape[0]
        last_index = num_img-1
        paths = [path]
    elif filetype == 'tif':
        if not isinstance(path, str) or not os.path.isdir(path):
            illegal_value(path, 'path', 'find_image_files')
            return(-1, 0, [])
        index_regex = re.compile(r'\d+')
        # At this point only tiffs
        files = sorted(
                entry for entry in os.listdir(path)
                if os.path.isfile(os.path.join(path, entry)) and entry.endswith('.tif')
                and index_regex.search(entry))
        num_img = len(files)
        if not num_img:
            logger.warning(f'No available {name}files')
            return(-1, 0, [])
        first_match = index_regex.search(files[0])
        last_match = index_regex.search(files[-1])
        if first_match is None or last_match is None:
            logger.error(f'Unable to find correctly indexed {name}images')
            return(-1, 0, [])
        first_index = int(first_match.group())
        last_index = int(last_match.group())
        if last_index-first_index+1 != num_img:
            logger.error(f'Non-consecutive set of indices for {name}images')
            return(-1, 0, [])
        paths = [os.path.join(path, entry) for entry in files]
    else:
        illegal_value(filetype, 'filetype', 'find_image_files')
        return(-1, 0, [])

    logger.info(f'Number of available {name}images: {num_img}')
    logger.info(f'Index range of available {name}images: [{first_index}, {last_index}]')
    return(first_index, num_img, paths)
1283
def select_image_range(first_index, offset, num_available, num_img=None, name=None,
        num_required=None):
    """Interactively select the range of images to use.

    :param first_index: Index of the first available image.
    :param offset: Current offset of the first image to use.
    :param num_available: Number of available images.
    :param num_img: Optional current number of images to use.
    :param name: Optional label used in the prompts.
    :param num_required: Optional required number of images.
    :return: ``(first_index, offset, num_img)``, or ``(0, 0, 0)`` on error.
    """
    if isinstance(name, str):
        name = f'{name.strip()} '
    else:
        name = ''
    # Check existing values
    if not is_int(num_available, gt=0):
        logger.warning(f'No available {name}images')
        return(0, 0, 0)
    if num_img is not None and not is_int(num_img, ge=0):
        illegal_value(num_img, 'num_img', 'select_image_range')
        return(0, 0, 0)
    if is_int(first_index, ge=0) and is_int(offset, ge=0):
        if num_required is None:
            if input_yesno(f'\nCurrent {name}first image index/offset = {first_index}/{offset}, '+
                    'use these values (y/n)?', 'y'):
                if num_img is not None:
                    if input_yesno(f'Current number of {name}images = {num_img}, '+
                            'use this value (y/n)? ', 'y'):
                        return(first_index, offset, num_img)
                else:
                    if input_yesno(f'Number of available {name}images = {num_available}, '+
                            'use all (y/n)? ', 'y'):
                        return(first_index, offset, num_available)
        else:
            if input_yesno(f'\nCurrent {name}first image offset = {offset}, '+
                    'use this value (y/n)?', 'y'):
                return(first_index, offset, num_required)

    # Check range against requirements
    if num_required is None:
        if num_available == 1:
            return(first_index, 0, 1)
    else:
        if not is_int(num_required, ge=1):
            illegal_value(num_required, 'num_required', 'select_image_range')
            return(0, 0, 0)
        if num_available < num_required:
            logger.error(f'Unable to find the required {name}images ({num_available} out of '+
                    f'{num_required})')
            return(0, 0, 0)

    # Select index range
    print(f'\nThe number of available {name}images is {num_available}')
    if num_required is None:
        # Inclusive index of the last available image (consistent with the
        # num_required branch below); prevents offsets/last indices beyond
        # the available images
        last_index = first_index+num_available-1
        use_all = f'Use all ([{first_index}, {last_index}])'
        pick_offset = 'Pick the first image index offset and the number of images'
        pick_bounds = 'Pick the first and last image index'
        choice = input_menu([use_all, pick_offset, pick_bounds], default=pick_offset)
        if not choice:
            offset = 0
            num_img = num_available
        elif choice == 1:
            offset = input_int('Enter the first index offset', ge=0, le=last_index-first_index)
            if first_index+offset == last_index:
                num_img = 1
            else:
                num_img = input_int('Enter the number of images', ge=1, le=num_available-offset)
        else:
            offset = input_int('Enter the first index', ge=first_index, le=last_index)
            num_img = 1-offset+input_int('Enter the last index', ge=offset, le=last_index)
            offset -= first_index
    else:
        use_all = f'Use ([{first_index}, {first_index+num_required-1}])'
        pick_offset = 'Pick the first index offset'
        choice = input_menu([use_all, pick_offset], pick_offset)
        offset = 0
        if choice == 1:
            offset = input_int('Enter the first index offset', ge=0, le=num_available-num_required)
        num_img = num_required

    return(first_index, offset, num_img)
1358
def load_image(f, img_x_bounds=None, img_y_bounds=None):
    """Load a single image from file, optionally cropped.

    :param f: Path of the image file (read with ``plt.imread``).
    :param img_x_bounds: Optional ``(low, upp)`` row bounds, upper exclusive.
    :param img_y_bounds: Optional ``(low, upp)`` column bounds, upper exclusive.
    :return: The (cropped) image array, or None on error.
    """
    if not os.path.isfile(f):
        logger.error(f'Unable to load {f}')
        return(None)
    img_read = plt.imread(f)
    if not img_x_bounds:
        img_x_bounds = (0, img_read.shape[0])
    elif (not isinstance(img_x_bounds, (tuple, list)) or len(img_x_bounds) != 2 or
            not (0 <= img_x_bounds[0] < img_x_bounds[1] <= img_read.shape[0])):
        logger.error(f'inconsistent row dimension in {f}')
        return(None)
    if not img_y_bounds:
        img_y_bounds = (0, img_read.shape[1])
    # Accept tuples as well as lists, consistent with img_x_bounds
    elif (not isinstance(img_y_bounds, (tuple, list)) or len(img_y_bounds) != 2 or
            not (0 <= img_y_bounds[0] < img_y_bounds[1] <= img_read.shape[1])):
        logger.error(f'inconsistent column dimension in {f}')
        return(None)
    return(img_read[img_x_bounds[0]:img_x_bounds[1],img_y_bounds[0]:img_y_bounds[1]])
1381
def load_image_stack(files, filetype, img_offset, num_img, num_img_skip=0,
        img_x_bounds=None, img_y_bounds=None):
    """Load a set of images and return them as a stack.

    :param files: List of tif file paths, or a one-element list with the h5 file path.
    :param filetype: 'tif' or 'h5'.
    :param img_offset: Index of the first image to load.
    :param num_img: Number of images in the range to load.
    :param num_img_skip: Number of images to skip between loaded images.
    :param img_x_bounds: Optional ``(low, upp)`` row bounds, upper exclusive.
    :param img_y_bounds: Optional ``(low, upp)`` column bounds, upper exclusive.
    :return: The image stack, or an empty array on error.
    """
    logger.debug(f'img_offset = {img_offset}')
    logger.debug(f'num_img = {num_img}')
    logger.debug(f'num_img_skip = {num_img_skip}')
    logger.debug(f'\nfiles:\n{files}\n')
    img_stack = np.array([])
    if filetype == 'tif':
        img_read_stack = []
        i = 1
        t0 = time()
        for f in files[img_offset:img_offset+num_img:num_img_skip+1]:
            if not i%20:
                logger.info(f'    loading {i}/{num_img}: {f}')
            else:
                logger.debug(f'    loading {i}/{num_img}: {f}')
            img_read = load_image(f, img_x_bounds, img_y_bounds)
            if img_read is None:
                # load_image already logged the reason; np.stack would
                # otherwise fail with an unrelated shape error
                return(img_stack)
            img_read_stack.append(img_read)
            i += num_img_skip+1
        if not img_read_stack:
            logger.error('No images selected to load')
            return(img_stack)
        img_stack = np.stack(img_read_stack)
        logger.info(f'... done in {time()-t0:.2f} seconds!')
        logger.debug(f'img_stack shape = {np.shape(img_stack)}')
        del img_read_stack
    elif filetype == 'h5':
        if not isinstance(files[0], str) or not os.path.isfile(files[0]):
            illegal_value(files[0], 'files[0]', 'load_image_stack')
            return(img_stack)
        t0 = time()
        logger.info(f'Loading {files[0]}')
        with h5py.File(files[0], 'r') as f:
            dataset = f['entry/instrument/detector/data']
            shape = dataset.shape
            if len(shape) != 3:
                logger.error(f'inconsistent dimensions in {files[0]}')
                return(img_stack)
            if not img_x_bounds:
                img_x_bounds = (0, shape[1])
            elif (not isinstance(img_x_bounds, (tuple, list)) or len(img_x_bounds) != 2 or
                    not (0 <= img_x_bounds[0] < img_x_bounds[1] <= shape[1])):
                logger.error(f'inconsistent row dimension in {files[0]} {img_x_bounds} '+
                        f'{shape[1]}')
                return(img_stack)
            if not img_y_bounds:
                img_y_bounds = (0, shape[2])
            elif (not isinstance(img_y_bounds, (tuple, list)) or len(img_y_bounds) != 2 or
                    not (0 <= img_y_bounds[0] < img_y_bounds[1] <= shape[2])):
                logger.error(f'inconsistent column dimension in {files[0]}')
                return(img_stack)
            img_stack = dataset[
                    img_offset:img_offset+num_img:num_img_skip+1,
                    img_x_bounds[0]:img_x_bounds[1],img_y_bounds[0]:img_y_bounds[1]]
        logger.info(f'... done in {time()-t0:.2f} seconds!')
    else:
        illegal_value(filetype, 'filetype', 'load_image_stack')
    return(img_stack)
1437
def combine_tiffs_in_h5(files, num_img, h5_filename):
    """Stack the first ``num_img`` tif images in ``files`` into one HDF5 file.

    The stack is written to the dataset ``entry/instrument/detector/data``
    of ``h5_filename``, which is overwritten if it exists.

    :return: A one-element list holding ``h5_filename``.
    """
    stack = load_image_stack(files, 'tif', 0, num_img)
    with h5py.File(h5_filename, 'w') as h5_file:
        h5_file.create_dataset('entry/instrument/detector/data', data=stack)
    del stack
    return([h5_filename])
1444
def clear_imshow(title=None):
    """Close the figure labelled ``title`` (default 'quick imshow').

    Interactive mode is switched off first. A non-string title is reported
    through illegal_value and nothing is closed.
    """
    plt.ioff()
    if title is None:
        figure_label = 'quick imshow'
    elif isinstance(title, str):
        figure_label = title
    else:
        illegal_value(title, 'title', 'clear_imshow')
        return
    plt.close(fig=figure_label)
1453
def clear_plot(title=None):
    """Close the figure labelled ``title`` (default 'quick plot').

    Interactive mode is switched off first. A non-string title is reported
    through illegal_value and nothing is closed.
    """
    plt.ioff()
    if title is None:
        figure_label = 'quick plot'
    elif isinstance(title, str):
        figure_label = title
    else:
        illegal_value(title, 'title', 'clear_plot')
        return
    plt.close(fig=figure_label)
1462
def quick_imshow(a, title=None, path=None, name=None, save_fig=False, save_only=False,
        clear=True, extent=None, show_grid=False, grid_color='w', grid_linewidth=1,
        block=False, **kwargs):
    """Quickly display and/or save an image with ``plt.imshow``.

    :param a: Image array (2D, or 3D RGB/RGBA with the color channel last).
    :param title: Figure label (default 'quick imshow'); with whitespace
        replaced by underscores it is also the default output file name.
    :param path: Optional output directory.
    :param name: Optional output file name (overrides the title-based name).
    :param save_fig: Save the figure in addition to showing it.
    :param save_only: Only save the figure, then close it.
    :param clear: Close an existing figure with this label first.
    :param extent: imshow extent (defaults to the pixel extent of ``a``).
    :param show_grid: Draw a grid over the image.
    :param grid_color: Grid line color.
    :param grid_linewidth: Grid line width.
    :param block: Block in ``plt.show``; otherwise interactive mode is used.
    :param kwargs: Passed on to ``plt.imshow``. For 3D input a ``cmap`` is only
        honored when the image is gray-scale (identical color channels and,
        for RGBA, a constant alpha channel); otherwise it is dropped.
    """
    if title is not None and not isinstance(title, str):
        illegal_value(title, 'title', 'quick_imshow')
        return
    if path is not None and not isinstance(path, str):
        illegal_value(path, 'path', 'quick_imshow')
        return
    if not isinstance(save_fig, bool):
        illegal_value(save_fig, 'save_fig', 'quick_imshow')
        return
    if not isinstance(save_only, bool):
        illegal_value(save_only, 'save_only', 'quick_imshow')
        return
    if not isinstance(clear, bool):
        illegal_value(clear, 'clear', 'quick_imshow')
        return
    if not isinstance(block, bool):
        illegal_value(block, 'block', 'quick_imshow')
        return
    a = np.asarray(a)
    if not title:
        title='quick imshow'
    if name is None:
        ttitle = re.sub(r"\s+", '_', title)
        path = f'{ttitle}.png' if path is None else f'{path}/{ttitle}.png'
    else:
        path = name if path is None else f'{path}/{name}'
    if 'cmap' in kwargs and a.ndim == 3 and a.shape[2] in (3, 4):
        # A pixel is gray only when ALL color channels are equal, so any
        # difference in either channel disqualifies the colormap
        use_cmap = (np.array_equal(a[:,:,0], a[:,:,1])
                and np.array_equal(a[:,:,0], a[:,:,2]))
        if a.shape[2] == 4 and a[:,:,-1].min() != a[:,:,-1].max():
            use_cmap = False
        if use_cmap:
            a = a[:,:,0]
        else:
            logger.warning('Image incompatible with cmap option, ignore cmap')
            kwargs.pop('cmap')
    if extent is None:
        extent = (0, a.shape[1], a.shape[0], 0)
    if clear:
        try:
            plt.close(fig=title)
        except Exception:
            # Best effort: there may be no figure to close
            pass
    if not save_only:
        if block:
            plt.ioff()
        else:
            plt.ion()
    plt.figure(title)
    plt.imshow(a, extent=extent, **kwargs)
    if show_grid:
        ax = plt.gca()
        ax.grid(color=grid_color, linewidth=grid_linewidth)
    if save_only:
        plt.savefig(path)
        plt.close(fig=title)
    else:
        if save_fig:
            plt.savefig(path)
        if block:
            plt.show(block=block)
1538
def quick_plot(*args, xerr=None, yerr=None, vlines=None, title=None, xlim=None, ylim=None,
        xlabel=None, ylabel=None, legend=None, path=None, name=None, show_grid=False,
        save_fig=False, save_only=False, clear=True, block=False, **kwargs):
    """Quickly plot one or more curves and/or save the figure.

    :param args: Arguments for ``plt.plot`` / ``plt.errorbar``, or a tuple of
        such argument tuples to plot several curves.
    :param xerr: Optional x error bars (single curve only).
    :param yerr: Optional y error bars (single curve only).
    :param vlines: Optional x position(s) of red dashed vertical lines.
    :param title: Figure label (default 'quick plot'); with whitespace
        replaced by underscores it is also the default output file name.
    :param xlim: Optional (min, max) x-axis limits.
    :param ylim: Optional (min, max) y-axis limits.
    :param xlabel: Optional x-axis label.
    :param ylabel: Optional y-axis label.
    :param legend: Optional tuple/list of legend labels.
    :param path: Optional output directory.
    :param name: Optional output file name (overrides the title-based name).
    :param show_grid: Draw a grid.
    :param save_fig: Save the figure in addition to showing it.
    :param save_only: Only save the figure, then close it.
    :param clear: Close an existing figure with this label first.
    :param block: Block in ``plt.show``; otherwise interactive mode is used.
    :param kwargs: Passed on to the plotting calls.
    """
    if title is not None and not isinstance(title, str):
        illegal_value(title, 'title', 'quick_plot')
        title = None
    if xlim is not None and (not isinstance(xlim, (tuple, list)) or len(xlim) != 2):
        illegal_value(xlim, 'xlim', 'quick_plot')
        xlim = None
    if ylim is not None and (not isinstance(ylim, (tuple, list)) or len(ylim) != 2):
        illegal_value(ylim, 'ylim', 'quick_plot')
        ylim = None
    if xlabel is not None and not isinstance(xlabel, str):
        illegal_value(xlabel, 'xlabel', 'quick_plot')
        xlabel = None
    if ylabel is not None and not isinstance(ylabel, str):
        illegal_value(ylabel, 'ylabel', 'quick_plot')
        ylabel = None
    if legend is not None and not isinstance(legend, (tuple, list)):
        illegal_value(legend, 'legend', 'quick_plot')
        legend = None
    if path is not None and not isinstance(path, str):
        illegal_value(path, 'path', 'quick_plot')
        return
    if not isinstance(show_grid, bool):
        illegal_value(show_grid, 'show_grid', 'quick_plot')
        return
    if not isinstance(save_fig, bool):
        illegal_value(save_fig, 'save_fig', 'quick_plot')
        return
    if not isinstance(save_only, bool):
        illegal_value(save_only, 'save_only', 'quick_plot')
        return
    if not isinstance(clear, bool):
        illegal_value(clear, 'clear', 'quick_plot')
        return
    if not isinstance(block, bool):
        illegal_value(block, 'block', 'quick_plot')
        return
    if title is None:
        title = 'quick plot'
    if name is None:
        ttitle = re.sub(r"\s+", '_', title)
        path = f'{ttitle}.png' if path is None else f'{path}/{ttitle}.png'
    else:
        path = name if path is None else f'{path}/{name}'
    if clear:
        try:
            plt.close(fig=title)
        except Exception:
            # Best effort: there may be no figure to close
            pass
    args = unwrap_tuple(args)
    if depth_tuple(args) > 1 and (xerr is not None or yerr is not None):
        logger.warning('Error bars ignored from multiple curves')
    if not save_only:
        if block:
            plt.ioff()
        else:
            plt.ion()
    plt.figure(title)
    if depth_tuple(args) > 1:
        for y in args:
            plt.plot(*y, **kwargs)
    else:
        if xerr is None and yerr is None:
            plt.plot(*args, **kwargs)
        else:
            plt.errorbar(*args, xerr=xerr, yerr=yerr, **kwargs)
    if vlines is not None:
        if isinstance(vlines, (int, float, np.number)):
            vlines = [vlines]
        # Vertical lines are always red dashed; drop conflicting style kwargs
        # (and their aliases) to avoid duplicate-keyword errors
        vline_kwargs = {key: value for key, value in kwargs.items()
                if key not in ('color', 'c', 'linestyle', 'ls')}
        for v in vlines:
            plt.axvline(v, color='r', linestyle='--', **vline_kwargs)
    if xlim is not None:
        plt.xlim(xlim)
    if ylim is not None:
        plt.ylim(ylim)
    if xlabel is not None:
        plt.xlabel(xlabel)
    if ylabel is not None:
        plt.ylabel(ylabel)
    if show_grid:
        ax = plt.gca()
        ax.grid(color='k')
    if legend is not None:
        plt.legend(legend)
    if save_only:
        plt.savefig(path)
        plt.close(fig=title)
    else:
        if save_fig:
            plt.savefig(path)
        if block:
            plt.show(block=block)
1644
def select_array_bounds(a, x_low=None, x_upp=None, num_x_min=None, ask_bounds=False,
        title='select array bounds'):
    """Interactively select the lower and upper data bounds for a 1D numpy array.

    :param a: 1D array (tuples and lists are converted).
    :param x_low: Optional lower bound (inclusive index).
    :param x_upp: Optional upper bound (exclusive index).
    :param num_x_min: Minimum number of points between the bounds (default 1).
    :param ask_bounds: When bounds are given, first ask whether to use them.
    :param title: Plot title.
    :return: ``(x_low, x_upp)``, or None on invalid input.
    """
    if isinstance(a, (tuple, list)):
        a = np.array(a)
    if not isinstance(a, np.ndarray) or a.ndim != 1:
        # Non-arrays have no ndim attribute: report their type instead
        illegal_value(a.ndim if isinstance(a, np.ndarray) else type(a),
                'array type or dimension', 'select_array_bounds')
        return(None)
    len_a = len(a)
    if num_x_min is None:
        num_x_min = 1
    elif num_x_min < 1 or num_x_min > len_a:
        logger.warning('Invalid value for num_x_min in select_array_bounds, input ignored')
        num_x_min = 1

    # Ask to use current bounds
    if ask_bounds and (x_low is not None or x_upp is not None):
        if x_low is None:
            x_low = 0
        if not is_int(x_low, ge=0, le=len_a-num_x_min):
            illegal_value(x_low, 'x_low', 'select_array_bounds')
            return(None)
        if x_upp is None:
            x_upp = len_a
        if not is_int(x_upp, ge=x_low+num_x_min, le=len_a):
            illegal_value(x_upp, 'x_upp', 'select_array_bounds')
            return(None)
        quick_plot((range(len_a), a), vlines=(x_low,x_upp), title=title)
        if not input_yesno(f'\nCurrent array bounds: [{x_low}, {x_upp}] '+
                'use these values (y/n)?', 'y'):
            x_low = None
            x_upp = None
        else:
            clear_plot(title)
            return(x_low, x_upp)

    if x_low is None:
        x_min = 0
        x_max = len_a
        x_low_max = len_a-num_x_min
        while True:
            quick_plot(range(x_min, x_max), a[x_min:x_max], title=title)
            zoom_flag = input_yesno('Set lower data bound (y) or zoom in (n)?', 'y')
            if zoom_flag:
                x_low = input_int('    Set lower data bound', ge=0, le=x_low_max)
                break
            x_min = input_int('    Set lower zoom index', ge=0, le=x_low_max)
            x_max = input_int('    Set upper zoom index', ge=x_min+1, le=x_low_max+1)
    elif not is_int(x_low, ge=0, le=len_a-num_x_min):
        illegal_value(x_low, 'x_low', 'select_array_bounds')
        return(None)
    if x_upp is None:
        x_min = x_low+num_x_min
        x_max = len_a
        x_upp_min = x_min
        while True:
            quick_plot(range(x_min, x_max), a[x_min:x_max], title=title)
            zoom_flag = input_yesno('Set upper data bound (y) or zoom in (n)?', 'y')
            if zoom_flag:
                x_upp = input_int('    Set upper data bound', ge=x_upp_min, le=len_a)
                break
            x_min = input_int('    Set lower zoom index', ge=x_upp_min, le=len_a-1)
            x_max = input_int('    Set upper zoom index', ge=x_min+1, le=len_a)
    elif not is_int(x_upp, ge=x_low+num_x_min, le=len_a):
        illegal_value(x_upp, 'x_upp', 'select_array_bounds')
        return(None)
    print(f'lower bound = {x_low} (inclusive)\nupper bound = {x_upp} (exclusive)')
    quick_plot((range(len_a), a), vlines=(x_low,x_upp), title=title)
    if not input_yesno('Accept these bounds (y/n)?', 'y'):
        x_low, x_upp = select_array_bounds(a, None, None, num_x_min, title=title)
    clear_plot(title)
    return(x_low, x_upp)
1723
def select_image_bounds(a, axis, low=None, upp=None, num_min=None, title='select array bounds',
        raise_error=False):
    """Interactively select the lower and upper data bounds for a 2D numpy array.

    Bounds are chosen along one axis of the image. For each bound that is not
    preset, the user may repeatedly zoom in on an index window before entering
    the bound. The selection is then highlighted in the image and the user is
    asked to accept it; on rejection the selection restarts from the original
    inputs.

    :param a: 2D array-like image data
    :param axis: axis of `a` along which the bounds are selected (0 or 1)
    :param low: preset lower bound (inclusive), or None to select interactively
    :param upp: preset upper bound (exclusive), or None to select interactively
    :param num_min: minimum number of indices between the bounds (default 1);
        out-of-range values are ignored with a warning
    :param title: title of the figure used for the interactive displays
    :param raise_error: passed on to illegal_value for invalid inputs
    :return: tuple (low, upp), or None if an input is invalid
    """
    a = np.asarray(a)
    if a.ndim != 2:
        illegal_value(a.ndim, 'array dimension', location='select_image_bounds',
                raise_error=raise_error)
        return(None)
    if axis < 0 or axis >= a.ndim:
        illegal_value(axis, 'axis', location='select_image_bounds', raise_error=raise_error)
        return(None)
    # Keep the original inputs so a rejected selection can restart from them
    low_save = low
    upp_save = upp
    num_min_save = num_min
    size = a.shape[axis]
    if num_min is None:
        num_min = 1
    elif num_min < 1 or num_min > size:
        # 1 (the default) is a valid minimum width; only smaller or oversized
        # values are rejected
        logger.warning('Invalid input for num_min in select_image_bounds, input ignored')
        num_min = 1

    def show_window(min_, max_):
        # Show only the index window [min_, max_) along `axis`; the extent labels
        # the axes with the original (unzoomed) indices, assuming quick_imshow
        # forwards it to imshow
        if axis:
            quick_imshow(a[:,min_:max_], title=title, aspect='auto',
                    extent=[min_,max_,a.shape[0],0])
        else:
            quick_imshow(a[min_:max_,:], title=title, aspect='auto',
                    extent=[0,a.shape[1], max_,min_])

    # Lower bound: select interactively or validate the preset value
    if low is None:
        min_ = 0
        max_ = size
        low_max = size-num_min
        while True:
            show_window(min_, max_)
            if input_yesno('Set lower data bound (y) or zoom in (n)?', 'y'):
                low = input_int(' Set lower data bound', ge=0, le=low_max)
                break
            min_ = input_int(' Set lower zoom index', ge=0, le=low_max)
            max_ = input_int(' Set upper zoom index', ge=min_+1, le=low_max+1)
    elif not is_int(low, ge=0, le=size-num_min):
        illegal_value(low, 'low', location='select_image_bounds', raise_error=raise_error)
        return(None)

    # Upper bound: must leave at least num_min indices above the lower bound
    if upp is None:
        min_ = low+num_min
        max_ = size
        upp_min = min_
        while True:
            show_window(min_, max_)
            if input_yesno('Set upper data bound (y) or zoom in (n)?', 'y'):
                upp = input_int(' Set upper data bound', ge=upp_min, le=size)
                break
            # First prompt sets the zoom window's lower edge, second its upper edge
            min_ = input_int(' Set lower zoom index', ge=upp_min, le=size-1)
            max_ = input_int(' Set upper zoom index', ge=min_+1, le=size)
    elif not is_int(upp, ge=low+num_min, le=size):
        illegal_value(upp, 'upp', location='select_image_bounds', raise_error=raise_error)
        return(None)

    # Highlight the first and the last selected index with the image maximum
    a_tmp = np.copy(a)
    a_tmp_max = a.max()
    if axis:
        a_tmp[:,low] = a_tmp_max
        a_tmp[:,upp-1] = a_tmp_max
    else:
        a_tmp[low,:] = a_tmp_max
        a_tmp[upp-1,:] = a_tmp_max
    print(f'lower bound = {low} (inclusive)\nupper bound = {upp} (exclusive)')
    quick_imshow(a_tmp, title=title, aspect='auto')
    del a_tmp
    bounds = (low, upp)
    if not input_yesno('Accept these bounds (y/n)?', 'y'):
        bounds = select_image_bounds(a, axis, low=low_save, upp=upp_save, num_min=num_min_save,
                title=title, raise_error=raise_error)
    return(bounds)
1805
def select_one_image_bound(a, axis, bound=None, bound_name=None, title='select array bounds',
        default='y', raise_error=False):
    """Interactively select a single data boundary along one axis of a 2D numpy array.

    If `bound` is None the user picks it from the displayed image, optionally
    after zooming in on an index window; otherwise the preset bound is
    validated. The bound is then highlighted in the image and the user is asked
    to accept it; on rejection the selection restarts interactively.

    :param a: 2D array-like image data
    :param axis: axis of `a` along which the bound is selected (0 or 1)
    :param bound: preset bound index, or None to select interactively
    :param bound_name: name used in the prompts (default 'data bound')
    :param title: title of the figure used for the interactive displays
    :param default: default answer for the acceptance prompt
    :param raise_error: passed on to illegal_value for invalid inputs
    :return: the selected bound index, or None if an input is invalid
    """
    a = np.asarray(a)
    where = 'select_one_image_bound'
    if a.ndim != 2:
        illegal_value(a.ndim, 'array dimension', location=where, raise_error=raise_error)
        return(None)
    if axis < 0 or axis >= a.ndim:
        illegal_value(axis, 'axis', location=where, raise_error=raise_error)
        return(None)
    if bound_name is None:
        bound_name = 'data bound'
    n = a.shape[axis]

    if bound is None:
        last = n-1
        lo, hi = 0, n
        while True:
            # Display only the current zoom window [lo, hi) along axis
            if axis:
                quick_imshow(a[:,lo:hi], title=title, aspect='auto',
                        extent=[lo,hi,a.shape[0],0])
            else:
                quick_imshow(a[lo:hi,:], title=title, aspect='auto',
                        extent=[0,a.shape[1], hi,lo])
            if input_yesno(f'Set {bound_name} (y) or zoom in (n)?', 'y'):
                bound = input_int(f' Set {bound_name}', ge=0, le=last)
                clear_imshow(title)
                break
            lo = input_int(' Set lower zoom index', ge=0, le=last)
            hi = input_int(' Set upper zoom index', ge=lo+1, le=last+1)
    elif not is_int(bound, ge=0, le=n-1):
        illegal_value(bound, 'bound', location=where, raise_error=raise_error)
        return(None)
    else:
        print(f'Current {bound_name} = {bound}')

    # Mark the chosen row/column with the image maximum so it stands out
    marked = np.copy(a)
    peak = a.max()
    if axis:
        marked[:,bound] = peak
    else:
        marked[bound,:] = peak
    quick_imshow(marked, title=title, aspect='auto')
    del marked
    if not input_yesno(f'Accept this {bound_name} (y/n)?', default):
        bound = select_one_image_bound(a, axis, bound_name=bound_name, title=title)
    clear_imshow(title)
    return(bound)
1857
1858
1859 class Config:
1860 """Base class for processing a config file or dictionary.
1861 """
1862 def __init__(self, config_file=None, config_dict=None):
1863 self.config = {}
1864 self.load_flag = False
1865 self.suffix = None
1866
1867 # Load config file
1868 if config_file is not None and config_dict is not None:
1869 logger.warning('Ignoring config_dict (both config_file and config_dict are specified)')
1870 if config_file is not None:
1871 self.load_file(config_file)
1872 elif config_dict is not None:
1873 self.load_dict(config_dict)
1874
1875 def load_file(self, config_file):
1876 """Load a config file.
1877 """
1878 if self.load_flag:
1879 logger.warning('Overwriting any previously loaded config file')
1880 self.config = {}
1881
1882 # Ensure config file exists
1883 if not os.path.isfile(config_file):
1884 logger.error(f'Unable to load {config_file}')
1885 return
1886
1887 # Load config file (for now for Galaxy, allow .dat extension)
1888 self.suffix = os.path.splitext(config_file)[1]
1889 if self.suffix == '.yml' or self.suffix == '.yaml' or self.suffix == '.dat':
1890 with open(config_file, 'r') as f:
1891 self.config = safe_load(f)
1892 elif self.suffix == '.txt':
1893 with open(config_file, 'r') as f:
1894 lines = f.read().splitlines()
1895 self.config = {item[0].strip():literal_eval(item[1].strip()) for item in
1896 [line.split('#')[0].split('=') for line in lines if '=' in line.split('#')[0]]}
1897 else:
1898 illegal_value(self.suffix, 'config file extension', 'Config.load_file')
1899
1900 # Make sure config file was correctly loaded
1901 if isinstance(self.config, dict):
1902 self.load_flag = True
1903 else:
1904 logger.error(f'Unable to load dictionary from config file: {config_file}')
1905 self.config = {}
1906
1907 def load_dict(self, config_dict):
1908 """Takes a dictionary and places it into self.config.
1909 """
1910 if self.load_flag:
1911 logger.warning('Overwriting the previously loaded config file')
1912
1913 if isinstance(config_dict, dict):
1914 self.config = config_dict
1915 self.load_flag = True
1916 else:
1917 illegal_value(config_dict, 'dictionary config object', 'Config.load_dict')
1918 self.config = {}
1919
1920 def save_file(self, config_file):
1921 """Save the config file (as a yaml file only right now).
1922 """
1923 suffix = os.path.splitext(config_file)[1]
1924 if suffix != '.yml' and suffix != '.yaml':
1925 illegal_value(suffix, 'config file extension', 'Config.save_file')
1926
1927 # Check if config file exists
1928 if os.path.isfile(config_file):
1929 logger.info(f'Updating {config_file}')
1930 else:
1931 logger.info(f'Saving {config_file}')
1932
1933 # Save config file
1934 with open(config_file, 'w') as f:
1935 safe_dump(self.config, f)
1936
1937 def validate(self, pars_required, pars_missing=None):
1938 """Returns False if any required keys are missing.
1939 """
1940 if not self.load_flag:
1941 logger.error('Load a config file prior to calling Config.validate')
1942
1943 def validate_nested_pars(config, par):
1944 par_levels = par.split(':')
1945 first_level_par = par_levels[0]
1946 try:
1947 first_level_par = int(first_level_par)
1948 except:
1949 pass
1950 try:
1951 next_level_config = config[first_level_par]
1952 if len(par_levels) > 1:
1953 next_level_par = ':'.join(par_levels[1:])
1954 return(validate_nested_pars(next_level_config, next_level_par))
1955 else:
1956 return(True)
1957 except:
1958 return(False)
1959
1960 pars_missing = [p for p in pars_required if not validate_nested_pars(self.config, p)]
1961 if len(pars_missing) > 0:
1962 logger.error(f'Missing item(s) in configuration: {", ".join(pars_missing)}')
1963 return(False)
1964 else:
1965 return(True)