Mercurial > repos > rv43 > test_tomo_reconstruct
comparison general.py @ 0:98e23dff1de2 draft default tip
planemo upload for repository https://github.com/rolfverberg/galaxytools commit f8c4bdb31c20c468045ad5e6eb255a293244bc6c-dirty
author | rv43 |
---|---|
date | Tue, 21 Mar 2023 16:22:42 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:98e23dff1de2 |
---|---|
1 #!/usr/bin/env python3 | |
2 | |
3 #FIX write a function that returns a list of peak indices for a given plot | |
4 #FIX use raise_error concept on more functions to optionally raise an error | |
5 | |
6 # -*- coding: utf-8 -*- | |
7 """ | |
8 Created on Mon Dec 6 15:36:22 2021 | |
9 | |
10 @author: rv43 | |
11 """ | |
12 | |
13 import logging | |
14 logger=logging.getLogger(__name__) | |
15 | |
16 import os | |
17 import sys | |
18 import re | |
19 try: | |
20 from yaml import safe_load, safe_dump | |
21 except: | |
22 pass | |
23 try: | |
24 import h5py | |
25 except: | |
26 pass | |
27 import numpy as np | |
28 try: | |
29 import matplotlib.pyplot as plt | |
30 import matplotlib.lines as mlines | |
31 from matplotlib import transforms | |
32 from matplotlib.widgets import Button | |
33 except: | |
34 pass | |
35 | |
36 from ast import literal_eval | |
37 try: | |
38 from asteval import Interpreter, get_ast_names | |
39 except: | |
40 pass | |
41 from copy import deepcopy | |
42 try: | |
43 from sympy import diff, simplify | |
44 except: | |
45 pass | |
46 from time import time | |
47 | |
48 | |
def depth_list(L):
    """Return the nesting depth of a (possibly nested) list.

    A non-list returns False (which acts as 0 in arithmetic), a flat list --
    including an empty one -- returns 1, and every further level of list
    nesting adds one.
    """
    # default=0 keeps an empty list from raising ValueError inside max()
    return isinstance(L, list) and max(map(depth_list, L), default=0) + 1
def depth_tuple(T):
    """Return the nesting depth of a (possibly nested) tuple.

    A non-tuple returns False (which acts as 0 in arithmetic), a flat tuple --
    including an empty one -- returns 1, and every further level of tuple
    nesting adds one.
    """
    # default=0 keeps an empty tuple from raising ValueError inside max()
    return isinstance(T, tuple) and max(map(depth_tuple, T), default=0) + 1
def unwrap_tuple(T):
    """Strip redundant single-element tuple wrappers from T.

    As long as T is a one-element tuple whose nesting depth (depth_tuple) is
    greater than one, T is replaced by its only element. Anything else is
    returned unchanged.
    """
    while depth_tuple(T) > 1 and len(T) == 1:
        (T,) = T
    return T
55 | |
def illegal_value(value, name, location=None, raise_error=False, log=True):
    """Report an illegal value.

    The message names the offending parameter when name is a string and adds
    'in <location>' when location is a string. It is logged as an error when
    log is True and raised as a ValueError when raise_error is True.
    """
    where = f'in {location} ' if isinstance(location, str) else ''
    if not isinstance(name, str):
        error_msg = f'Illegal value {where}({value}, {type(value)})'
    else:
        error_msg = f'Illegal value for {name} {where}({value}, {type(value)})'
    if log:
        logger.error(error_msg)
    if raise_error:
        raise ValueError(error_msg)
69 | |
def illegal_combination(value1, name1, value2, name2, location=None, raise_error=False,
        log=True):
    """Report an illegal combination of two values.

    Both parameter names appear in the message when name1 is a string, and
    'in <location>' is added when location is a string. The message is logged
    as an error when log is True and raised as a ValueError when raise_error
    is True.
    """
    where = f'in {location} ' if isinstance(location, str) else ''
    values = f'({value1}, {type(value1)} and {value2}, {type(value2)})'
    if not isinstance(name1, str):
        error_msg = f'Illegal combination {where}' + values
    else:
        error_msg = f'Illegal combination for {name1} and {name2} {where}' + values
    if log:
        logger.error(error_msg)
    if raise_error:
        raise ValueError(error_msg)
86 | |
def test_ge_gt_le_lt(ge, gt, le, lt, func, location=None, raise_error=False, log=True):
    """Check the individual and mutual validity of the ge, gt, le, lt qualifiers.

    func: is_int or is_num, used to validate each given qualifier.
    At most one lower qualifier (ge or gt) and one upper qualifier (le or lt)
    may be given, and together they must leave a non-empty range.
    Return: True upon success (also when every qualifier is None), or False
    when a qualifier is invalid or the qualifiers are mutually exclusive.
    Problems are reported via illegal_value / illegal_combination with
    raise_error and log passed on.
    """
    if all(q is None for q in (ge, gt, le, lt)):
        return True

    def _one_side_valid(inclusive, inclusive_name, exclusive, exclusive_name):
        # One side of the range: the inclusive and exclusive forms may not both
        # be given, and whichever one is given must pass func
        if inclusive is not None:
            if not func(inclusive):
                illegal_value(inclusive, inclusive_name, location, raise_error, log)
                return False
            if exclusive is not None:
                illegal_combination(inclusive, inclusive_name, exclusive, exclusive_name,
                        location, raise_error, log)
                return False
        elif exclusive is not None and not func(exclusive):
            illegal_value(exclusive, exclusive_name, location, raise_error, log)
            return False
        return True

    if not (_one_side_valid(ge, 'ge', gt, 'gt') and _one_side_valid(le, 'le', lt, 'lt')):
        return False

    # Mutual check between the single lower and the single upper qualifier
    if ge is not None:
        lower, lower_name, lower_strict = ge, 'ge', False
    elif gt is not None:
        lower, lower_name, lower_strict = gt, 'gt', True
    else:
        return True
    if le is not None:
        is_empty = lower >= le if lower_strict else lower > le
        if is_empty:
            illegal_combination(lower, lower_name, le, 'le', location, raise_error, log)
            return False
    elif lt is not None and lower >= lt:
        illegal_combination(lower, lower_name, lt, 'lt', location, raise_error, log)
        return False
    return True
129 | |
def range_string_ge_gt_le_lt(ge=None, gt=None, le=None, lt=None):
    """Return a range string representation matching the ge, gt, le, lt qualifiers.

    A one-sided range is written as '>= a', '> a', '<= b' or '< b'; a
    two-sided one in interval notation such as '[a, b)'. ge takes precedence
    over gt and le over lt. The inputs are not validated; do that as needed
    before calling.
    """
    has_lower = ge is not None or gt is not None
    has_upper = le is not None or lt is not None
    parts = []
    if ge is not None:
        parts.append(f'[{ge}, ' if has_upper else f'>= {ge}')
    elif gt is not None:
        parts.append(f'({gt}, ' if has_upper else f'> {gt}')
    if le is not None:
        parts.append(f'{le}]' if has_lower else f'<= {le}')
    elif lt is not None:
        parts.append(f'{lt})' if has_lower else f'< {lt}')
    return ''.join(parts)
156 | |
def is_int(v, ge=None, gt=None, le=None, lt=None, raise_error=False, log=True):
    """Check that v is an int within every bound that is given.

    The type check is isinstance(v, int). Bounds: ge <= v, gt < v, v <= le,
    v < lt, in any valid combination (None means unused).
    Return: True if yes or False if no (see _is_int_or_num for error reporting)
    """
    return _is_int_or_num(v, 'int', ge=ge, gt=gt, le=le, lt=lt, raise_error=raise_error,
            log=log)
162 | |
def is_num(v, ge=None, gt=None, le=None, lt=None, raise_error=False, log=True):
    """Check that v is a number (int or float) within every bound that is given.

    Bounds: ge <= v, gt < v, v <= le, v < lt, in any valid combination (None
    means unused).
    Return: True if yes or False if no (see _is_int_or_num for error reporting)
    """
    return _is_int_or_num(v, 'num', ge=ge, gt=gt, le=le, lt=lt, raise_error=raise_error,
            log=log)
168 | |
def _is_int_or_num(v, type_str, ge=None, gt=None, le=None, lt=None, raise_error=False,
        log=True):
    """Shared implementation of is_int (type_str 'int') and is_num ('num').

    The type of v, the bounds themselves (via test_ge_gt_le_lt) and finally
    the range of v are checked, in that order. The first failure returns
    False; the failure is logged as an error when log is True and raised as a
    ValueError when raise_error is True.
    """
    # Accepted types for v, and the validator applied to the bounds
    if type_str == 'int':
        accepted, bound_check = int, is_int
    elif type_str == 'num':
        accepted, bound_check = (int, float), is_num
    else:
        illegal_value(type_str, 'type_str', '_is_int_or_num', raise_error, log)
        return False
    if not isinstance(v, accepted):
        illegal_value(v, 'v', '_is_int_or_num', raise_error, log)
        return False
    if not test_ge_gt_le_lt(ge, gt, le, lt, bound_check, '_is_int_or_num', raise_error, log):
        return False

    error_msg = None
    if ge is not None and v < ge:
        error_msg = f'Value {v} out of range: {v} !>= {ge}'
    elif gt is not None and v <= gt:
        error_msg = f'Value {v} out of range: {v} !> {gt}'
    elif le is not None and v > le:
        error_msg = f'Value {v} out of range: {v} !<= {le}'
    elif lt is not None and v >= lt:
        error_msg = f'Value {v} out of range: {v} !< {lt}'
    if error_msg is None:
        return True
    if log:
        logger.error(error_msg)
    if raise_error:
        raise ValueError(error_msg)
    return False
208 | |
def is_int_pair(v, ge=None, gt=None, le=None, lt=None, raise_error=False, log=True):
    """Check that v is a tuple/list of two ints, each within the given bounds.

    Each bound may be a scalar applied to both elements (ge <= v[i] <= le,
    gt < v[i] < lt, ...) or a pair giving per-element bounds
    (ge[i] <= v[i] <= le[i], ...).
    Return: True if yes or False if no
    """
    return _is_int_or_num_pair(v, 'int', ge=ge, gt=gt, le=le, lt=lt,
            raise_error=raise_error, log=log)
215 | |
def is_num_pair(v, ge=None, gt=None, le=None, lt=None, raise_error=False, log=True):
    """Check that v is a tuple/list of two numbers, each within the given bounds.

    Each bound may be a scalar applied to both elements (ge <= v[i] <= le,
    gt < v[i] < lt, ...) or a pair giving per-element bounds
    (ge[i] <= v[i] <= le[i], ...).
    Return: True if yes or False if no
    """
    return _is_int_or_num_pair(v, 'num', ge=ge, gt=gt, le=le, lt=lt,
            raise_error=raise_error, log=log)
222 | |
def _is_int_or_num_pair(v, type_str, ge=None, gt=None, le=None, lt=None, raise_error=False,
        log=True):
    """Shared implementation of is_int_pair (type_str 'int') and is_num_pair ('num').

    v must be a tuple/list of two ints ('int') or ints/floats ('num'). Each
    bound may be None, a scalar applied to both elements, or a pair of
    per-element bounds.
    Return: True if v satisfies every bound, else False. Failures are logged
    as errors when log is True and raised as ValueError when raise_error is
    True.
    """
    if type_str == 'int':
        scalar_types, func = int, is_int
    elif type_str == 'num':
        scalar_types, func = (int, float), is_num
    else:
        illegal_value(type_str, 'type_str', '_is_int_or_num_pair', raise_error, log)
        return False
    if not (isinstance(v, (tuple, list)) and len(v) == 2
            and all(isinstance(x, scalar_types) for x in v)):
        illegal_value(v, 'v', '_is_int_or_num_pair', raise_error, log)
        return False
    if ge is None and gt is None and le is None and lt is None:
        return True
    per_element = []
    for bound in (ge, gt, le, lt):
        # Probe for a scalar bound silently: a pair bound is legal too and must
        # not produce a spurious error log before the pair check below
        if bound is None or func(bound, log=False):
            per_element.append(2*[bound])
        elif _is_int_or_num_pair(bound, type_str, raise_error=raise_error, log=log):
            per_element.append(bound)
        else:
            return False
    ge, gt, le, lt = per_element
    return (func(v[0], ge[0], gt[0], le[0], lt[0], raise_error, log) and
            func(v[1], ge[1], gt[1], le[1], lt[1], raise_error, log))
262 | |
def is_int_series(l, ge=None, gt=None, le=None, lt=None, raise_error=False, log=True):
    """Check that l is a tuple or list of ints, each within the given bounds.

    Bounds: ge <= l[i] <= le or gt < l[i] < lt or some valid combination.
    Checking stops at the first failing element.
    Return: True if yes or False if no
    """
    if not test_ge_gt_le_lt(ge, gt, le, lt, is_int, 'is_int_series', raise_error, log):
        return False
    if not isinstance(l, (tuple, list)):
        illegal_value(l, 'l', 'is_int_series', raise_error, log)
        return False
    return all(is_int(item, ge, gt, le, lt, raise_error, log) for item in l)
275 | |
def is_num_series(l, ge=None, gt=None, le=None, lt=None, raise_error=False, log=True):
    """Check that l is a tuple or list of numbers, each within the given bounds.

    Bounds: ge <= l[i] <= le or gt < l[i] < lt or some valid combination. The
    bounds are validated as numbers (is_num), so non-integer bounds such as
    ge=0.5 are accepted.
    Return: True if yes or False if no
    """
    # Bounds must be validated with is_num, not is_int, for a number series
    if not test_ge_gt_le_lt(ge, gt, le, lt, is_num, 'is_num_series', raise_error, log):
        return False
    if not isinstance(l, (tuple, list)):
        illegal_value(l, 'l', 'is_num_series', raise_error, log)
        return False
    return all(is_num(v, ge, gt, le, lt, raise_error, log) for v in l)
288 | |
def is_str_series(l, raise_error=False, log=True):
    """Check that l is a tuple or list whose items are all strings.

    Return: True if yes; otherwise report l via illegal_value and return False.
    """
    if isinstance(l, (tuple, list)) and all(isinstance(item, str) for item in l):
        return True
    illegal_value(l, 'l', 'is_str_series', raise_error, log)
    return False
297 | |
def is_dict_series(l, raise_error=False, log=True):
    """Check that l is a tuple or list whose items are all dictionaries.

    Return: True if yes; otherwise report l via illegal_value and return False.
    """
    if isinstance(l, (tuple, list)) and all(isinstance(item, dict) for item in l):
        return True
    illegal_value(l, 'l', 'is_dict_series', raise_error, log)
    return False
306 | |
def is_dict_nums(l, raise_error=False, log=True):
    """Check that l is a dictionary whose values are all single numbers.

    Values are tested silently with is_num(log=False).
    Return: True if yes; otherwise report l via illegal_value and return False.
    """
    if isinstance(l, dict) and all(is_num(val, log=False) for val in l.values()):
        return True
    illegal_value(l, 'l', 'is_dict_nums', raise_error, log)
    return False
315 | |
def is_dict_strings(l, raise_error=False, log=True):
    """Check that l is a dictionary whose values are all strings.

    Return: True if yes; otherwise report l via illegal_value and return False.
    """
    if isinstance(l, dict) and all(isinstance(val, str) for val in l.values()):
        return True
    illegal_value(l, 'l', 'is_dict_strings', raise_error, log)
    return False
324 | |
def is_index(v, ge=0, lt=None, raise_error=False, log=True):
    """Check that v is an array index in the range ge <= v < lt.

    NOTE lt IS NOT included! An int lt must exceed ge.
    Return: True if yes or False if no
    """
    if isinstance(lt, int) and lt <= ge:
        illegal_combination(ge, 'ge', lt, 'lt', 'is_index', raise_error, log)
        return False
    return is_int(v, ge=ge, lt=lt, raise_error=raise_error, log=log)
334 | |
def is_index_range(v, ge=0, le=None, lt=None, raise_error=False, log=True):
    """Check that v is an array index range ge <= v[0] <= v[1] <= le or
    ge <= v[0] <= v[1] < lt.

    NOTE le IS included, lt is NOT. ge=None means no lower bound; with both
    le and lt None there is no upper bound.
    Return: True if yes or False if no. Failures are logged as errors when log
    is True and raised as ValueError when raise_error is True.
    """
    if not is_int_pair(v, raise_error=raise_error, log=log):
        return False
    if not test_ge_gt_le_lt(ge, None, le, lt, is_int, 'is_index_range', raise_error, log):
        return False
    if ((ge is not None and v[0] < ge) or v[0] > v[1]
            or (le is not None and v[1] > le) or (lt is not None and v[1] >= lt)):
        # Only mention the bounds that were actually given (previously a
        # missing upper bound was printed as '< None')
        lower = '' if ge is None else f'{ge} <= '
        if le is not None:
            upper = f' <= {le}'
        elif lt is not None:
            upper = f' < {lt}'
        else:
            upper = ''
        error_msg = f'Value {v} out of range: !({lower}{v[0]} <= {v[1]}{upper})'
        if log:
            logger.error(error_msg)
        if raise_error:
            raise ValueError(error_msg)
        return False
    return True
354 | |
def index_nearest(a, value):
    """Return the index of the element of array-like a closest to value.

    a must be at most one-dimensional (ValueError otherwise). value is first
    multiplied by (1 + machine epsilon), so for a positive value lying exactly
    halfway between two elements the larger one wins ("round up for .5"); for
    a negative value the nudge is toward more negative numbers.
    """
    arr = np.asarray(a)
    if arr.ndim > 1:
        raise ValueError(f'Invalid array dimension for parameter a ({arr.ndim}, {arr})')
    value *= 1.0+sys.float_info.epsilon
    distances = np.abs(arr-value)
    return int(np.argmin(distances))
362 | |
def index_nearest_low(a, value):
    """Return the index of the nearest element of a, stepped down when needed.

    The nearest element is found first; if value lies below it, the index is
    moved one position down (never below 0). This yields the element at or
    below value only if a is sorted in ascending order -- assumed, not checked.
    a must be at most one-dimensional (ValueError otherwise).
    """
    arr = np.asarray(a)
    if arr.ndim > 1:
        raise ValueError(f'Invalid array dimension for parameter a ({arr.ndim}, {arr})')
    idx = int(np.argmin(np.abs(arr-value)))
    step_down = value < arr[idx] and idx > 0
    return idx-1 if step_down else idx
371 | |
def index_nearest_upp(a, value):
    """Return the index of the nearest element of a, stepped up when needed.

    The nearest element is found first; if value lies above it, the index is
    moved one position up (never past the last element). This yields the
    element at or above value only if a is sorted in ascending order --
    assumed, not checked. a must be at most one-dimensional (ValueError
    otherwise).
    """
    arr = np.asarray(a)
    if arr.ndim > 1:
        raise ValueError(f'Invalid array dimension for parameter a ({arr.ndim}, {arr})')
    idx = int(np.argmin(np.abs(arr-value)))
    step_up = value > arr[idx] and idx < arr.size-1
    return idx+1 if step_up else idx
380 | |
def round_to_n(x, n=1):
    """Round x to n significant figures, returning a value of the same type as x."""
    if x == 0:
        # Zero has no significant figures; keep the return type consistent
        # with the non-zero branch (previously an int 0 was always returned)
        return type(x)(0)
    return type(x)(round(x, n-1-int(np.floor(np.log10(abs(x))))))
386 | |
def round_up_to_n(x, n=1):
    """Round the magnitude of x up (away from zero) to n significant figures.

    x is rounded with round_to_n; if that reduced its magnitude, one unit in
    the n-th significant figure is added in the direction of x's sign. The
    result has the same type as x.
    """
    if x == 0:
        # Avoid dividing by the rounded zero below (ZeroDivisionError)
        return type(x)(0)
    xr = round_to_n(x, n)
    if abs(x/xr) > 1.0:
        xr += np.sign(x)*10**(np.floor(np.log10(abs(x)))+1-n)
    return type(x)(xr)
392 | |
def trunc_to_n(x, n=1):
    """Truncate the magnitude of x (toward zero) to n significant figures.

    x is rounded with round_to_n; if that increased its magnitude, one unit
    in the n-th significant figure is removed in the direction of x's sign.
    The result has the same type as x.
    """
    if x == 0:
        # Avoid dividing by zero in abs(xr/x) below (ZeroDivisionError)
        return type(x)(0)
    xr = round_to_n(x, n)
    if abs(xr/x) > 1.0:
        xr -= np.sign(x)*10**(np.floor(np.log10(abs(x)))+1-n)
    return type(x)(xr)
398 | |
def almost_equal(a, b, sig_figs):
    """Return whether a and b are almost equal.

    The difference a-b is rounded to sig_figs significant figures (round_to_n)
    and its magnitude is compared against 10**(1-sig_figs). This is an
    absolute, not a relative, tolerance.
    Raises ValueError when a or b is not a number.
    """
    # Validate silently: the raised ValueError below already reports the problem
    if not (is_num(a, log=False) and is_num(b, log=False)):
        raise ValueError(f'Invalid value for a or b in almost_equal (a: {a}, {type(a)}, '+
                f'b: {b}, {type(b)})')
    return abs(round_to_n(a-b, sig_figs)) < pow(10, -sig_figs+1)
406 | |
def string_to_list(s, split_on_dash=True, remove_duplicates=True, sort=True):
    """Return a list of numbers by splitting/expanding a string on any combination of
    commas, whitespaces, or dashes (when split_on_dash=True)
    e.g: '1, 3, 5-8, 12 ' -> [1, 3, 5, 6, 7, 8, 12]

    Each item is parsed with ast.literal_eval. With split_on_dash, 'a-b'
    expands to the integers a..b inclusive (b must be greater than a).
    remove_duplicates: drop repeated values, keeping first occurrences.
    sort: sort the result in ascending order.
    Return: the list, [] for an empty string, or None when s is not a string
    or cannot be parsed.
    """
    if not isinstance(s, str):
        # illegal_value requires the parameter name as its second argument
        illegal_value(s, 's', 'string_to_list')
        return None
    if not len(s):
        return []
    # Raw strings: '\s' in a plain string literal is an invalid escape sequence
    items = re.split(r'\s+,\s+|\s+,|,\s+|\s+|,', s.strip())
    try:
        if split_on_dash:
            l = []
            for item in items:
                l2 = [literal_eval(x) for x in re.split(r'\s+-\s+|\s+-|-\s+|\s+|-', item)]
                if len(l2) == 1:
                    l += l2
                elif len(l2) == 2 and l2[1] > l2[0]:
                    l += list(range(l2[0], l2[1]+1))
                else:
                    raise ValueError
        else:
            # Parse errors here now return None as well instead of propagating
            l = [literal_eval(x) for x in items]
        # Unhashable or unorderable parsed values raise TypeError here
        if remove_duplicates:
            l = list(dict.fromkeys(l))
        if sort:
            l = sorted(l)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        return None
    return l
441 | |
def get_trailing_int(string):
    """Return the integer formed by the digits at the end of string, or None."""
    match = re.search(r'\d+$', string)
    return None if match is None else int(match.group())
449 | |
def input_int(s=None, ge=None, gt=None, le=None, lt=None, default=None, inset=None,
        raise_error=False, log=True):
    """Interactively prompt for an integer; see _input_int_or_num for the arguments."""
    return _input_int_or_num('int', s=s, ge=ge, gt=gt, le=le, lt=lt, default=default,
            inset=inset, raise_error=raise_error, log=log)
453 | |
def input_num(s=None, ge=None, gt=None, le=None, lt=None, default=None, raise_error=False,
        log=True):
    """Interactively prompt for a number; see _input_int_or_num for the arguments."""
    return _input_int_or_num('num', s=s, ge=ge, gt=gt, le=le, lt=lt, default=default,
            inset=None, raise_error=raise_error, log=log)
457 | |
def _input_int_or_num(type_str, s=None, ge=None, gt=None, le=None, lt=None, default=None,
        inset=None, raise_error=False, log=True):
    """Prompt the user for an integer (type_str 'int') or a number ('num').

    s: prompt text (a generic prompt is used when None). The range implied by
        ge/gt/le/lt and the default, if any, are appended to it.
    default: value used when the user just presses enter; it is validated
        against the type and the bounds up front.
    inset: optional tuple/list of ints the answer must belong to.
    Invalid answers re-prompt (recursively). Return None when the arguments
    themselves are invalid.
    """
    if type_str == 'int':
        if not test_ge_gt_le_lt(ge, gt, le, lt, is_int, '_input_int_or_num', raise_error, log):
            return None
    elif type_str == 'num':
        if not test_ge_gt_le_lt(ge, gt, le, lt, is_num, '_input_int_or_num', raise_error, log):
            return None
    else:
        illegal_value(type_str, 'type_str', '_input_int_or_num', raise_error, log)
        return None
    if default is not None:
        if not _is_int_or_num(default, type_str, raise_error=raise_error, log=log):
            return None
        if ge is not None and default < ge:
            illegal_combination(ge, 'ge', default, 'default', '_input_int_or_num', raise_error,
                    log)
            return None
        if gt is not None and default <= gt:
            illegal_combination(gt, 'gt', default, 'default', '_input_int_or_num', raise_error,
                    log)
            return None
        if le is not None and default > le:
            illegal_combination(le, 'le', default, 'default', '_input_int_or_num', raise_error,
                    log)
            return None
        if lt is not None and default >= lt:
            illegal_combination(lt, 'lt', default, 'default', '_input_int_or_num', raise_error,
                    log)
            return None
        default_string = f' [{default}]'
    else:
        default_string = ''
    if inset is not None:
        if (not isinstance(inset, (tuple, list))
                or any(not isinstance(i, int) for i in inset)):
            illegal_value(inset, 'inset', '_input_int_or_num', raise_error, log)
            return None
    v_range = f'{range_string_ge_gt_le_lt(ge, gt, le, lt)}'
    if len(v_range):
        v_range = f' {v_range}'
    if s is None:
        if type_str == 'int':
            print(f'Enter an integer{v_range}{default_string}: ')
        else:
            print(f'Enter a number{v_range}{default_string}: ')
    else:
        print(f'{s}{v_range}{default_string}: ')
    try:
        i = input()
        if isinstance(i, str) and not len(i):
            v = default
            print(f'{v}')
        else:
            v = literal_eval(i)
        if inset and v not in inset:
            raise ValueError(f'{v} not part of the set {inset}')
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        # Unparsable answer or not in inset: invalid, re-prompted below
        v = None
    except Exception as exc:
        # Anything else (e.g. EOFError on a closed stdin) cannot be fixed by
        # re-prompting. Previously v was left unbound here, which crashed with
        # UnboundLocalError when raise_error was False.
        if log:
            logger.error('Unexpected error')
        if raise_error:
            raise ValueError('Unexpected error') from exc
        raise
    if not _is_int_or_num(v, type_str, ge, gt, le, lt):
        v = _input_int_or_num(type_str, s, ge, gt, le, lt, default, inset, raise_error, log)
    return v
525 | |
def input_int_list(s=None, ge=None, le=None, split_on_dash=True, remove_duplicates=True,
        sort=True, raise_error=False, log=True):
    """Prompt the user for a list of integers.

    The entered string is split on any combination of commas, whitespace, or
    (when split_on_dash is True) dashes, where 'a-b' expands to a..b,
    e.g. '1 3,5-8 , 12 ' -> [1, 3, 5, 6, 7, 8, 12]. Every value must satisfy
    the ge/le bounds that are given.
    remove_duplicates: removes duplicates if True (may also change the order)
    sort: sort in ascending order if True
    Invalid input re-prompts; None is returned when the arguments are invalid.
    """
    return _input_int_or_num_list('int', s=s, ge=ge, le=le, split_on_dash=split_on_dash,
            remove_duplicates=remove_duplicates, sort=sort, raise_error=raise_error, log=log)
537 | |
def input_num_list(s=None, ge=None, le=None, remove_duplicates=True, sort=True, raise_error=False,
        log=True):
    """Prompt the user for a list of numbers.

    The entered string is split on any combination of commas or whitespace
    and each item is parsed as a Python literal (ints stay ints), e.g.
    '1.0, 3, 5.8, 12 ' -> [1.0, 3, 5.8, 12]. Every value must satisfy the
    ge/le bounds that are given.
    remove_duplicates: removes duplicates if True (may also change the order)
    sort: sort in ascending order if True
    Invalid input re-prompts; None is returned when the arguments are invalid.
    """
    return _input_int_or_num_list('num', s=s, ge=ge, le=le, split_on_dash=False,
            remove_duplicates=remove_duplicates, sort=sort, raise_error=raise_error, log=log)
549 | |
def _input_int_or_num_list(type_str, s=None, ge=None, le=None, split_on_dash=True,
        remove_duplicates=True, sort=True, raise_error=False, log=True):
    """Prompt for a list of integers (type_str 'int') or numbers ('num').

    The answer is parsed with string_to_list. Every value must satisfy the
    ge/le bounds that are given; invalid input prints a hint and re-prompts
    (recursively). Return None when the arguments themselves are invalid.
    """
    #FIX do we want a limit on max dimension?
    if type_str == 'int':
        validator = is_int
    elif type_str == 'num':
        validator = is_num
    else:
        illegal_value(type_str, 'type_str', '_input_int_or_num_list', raise_error, log)
        return None
    if not test_ge_gt_le_lt(ge, None, le, None, validator, '_input_int_or_num_list',
            raise_error, log):
        return None
    # Name the expected kind of value correctly in the prompts
    kind = 'integers' if type_str == 'int' else 'numbers'
    v_range = f'{range_string_ge_gt_le_lt(ge=ge, le=le)}'
    if len(v_range):
        v_range = f' (each value in {v_range})'
    if s is None:
        print(f'Enter a series of {kind}{v_range}: ')
    else:
        print(f'{s}{v_range}: ')
    try:
        l = string_to_list(input(), split_on_dash, remove_duplicates, sort)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        l = None
    except Exception:
        print('Unexpected error')
        raise
    if (not isinstance(l, list) or
            not all(_is_int_or_num(v, type_str, ge=ge, le=le) for v in l)):
        if split_on_dash:
            print(f'Invalid input: enter a valid set of dash/comma/whitespace separated {kind} '+
                    'e.g. 1 3,5-8 , 12')
        else:
            print(f'Invalid input: enter a valid set of comma/whitespace separated {kind} '+
                    'e.g. 1 3,5 8 , 12')
        l = _input_int_or_num_list(type_str, s, ge, le, split_on_dash, remove_duplicates, sort,
                raise_error, log)
    return l
589 | |
def input_yesno(s=None, default=None):
    """Prompt the user for a yes/no answer and return True (yes) or False (no).

    Any non-empty prefix of 'yes' or 'no' is accepted, case-insensitively.
    default: optional answer ('y'/'yes'/'n'/'no', ...) used when the user just
        presses enter.
    Invalid answers re-prompt (recursively). Return None when default is
    invalid.
    """
    # A plain substring test ("x in 'yes'") would also accept '', 'e', 's',
    # 'es' and 'o'; only accept prefixes of the full words
    yes_answers = ('y', 'ye', 'yes')
    no_answers = ('n', 'no')
    if default is not None:
        if not isinstance(default, str):
            illegal_value(default, 'default', 'input_yesno')
            return None
        if default.lower() in yes_answers:
            default = 'y'
        elif default.lower() in no_answers:
            default = 'n'
        else:
            illegal_value(default, 'default', 'input_yesno')
            return None
        default_string = f' [{default}]'
    else:
        default_string = ''
    if s is None:
        print(f'Enter yes or no{default_string}: ')
    else:
        print(f'{s}{default_string}: ')
    i = input()
    if isinstance(i, str) and not len(i):
        i = default
        print(f'{i}')
    if i is not None and i.lower() in yes_answers:
        v = True
    elif i is not None and i.lower() in no_answers:
        v = False
    else:
        print('Invalid input, enter yes or no')
        v = input_yesno(s, default)
    return v
621 | |
def input_menu(items, default=None, header=None):
    """Print a numbered menu of items and return the 0-based index chosen.

    items: tuple/list of strings, shown numbered from 1.
    default: optional item (must be in items) chosen when the user just
        presses enter.
    header: optional text that replaces the generic prompt line.
    Invalid choices re-prompt (recursively) with the same default and header.
    Return None when the arguments themselves are invalid.
    """
    if not isinstance(items, (tuple, list)) or any(not isinstance(i, str) for i in items):
        illegal_value(items, 'items', 'input_menu')
        return None
    if default is not None:
        if not (isinstance(default, str) and default in items):
            logger.error(f'Invalid value for default ({default}), must be in {items}')
            return None
        default_string = f' [{items.index(default)+1}]'
    else:
        default_string = ''
    if header is None:
        print(f'Choose one of the following items (1, {len(items)}){default_string}:')
    else:
        print(f'{header} (1, {len(items)}){default_string}:')
    for i, choice in enumerate(items):
        print(f' {i+1}: {choice}')
    try:
        choice = input()
        if isinstance(choice, str) and not len(choice):
            # With no default, items.index(None) raises ValueError -> invalid
            choice = items.index(default)
            print(f'{choice+1}')
        else:
            choice = literal_eval(choice)
            if isinstance(choice, int) and 1 <= choice <= len(items):
                choice -= 1
            else:
                raise ValueError
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        choice = None
    except Exception:
        print('Unexpected error')
        raise
    if choice is None:
        print(f'Invalid choice, enter a number between 1 and {len(items)}')
        # Keep the header on retry (it used to be dropped)
        choice = input_menu(items, default, header)
    return choice
660 | |
def assert_no_duplicates_in_list_of_dicts(l: list, raise_error=False) -> list:
    """Return l if it is a list of dicts none of which are equal to each other.

    Otherwise return None, after logging an error or, when raise_error is
    True, raising ValueError.
    """
    if not isinstance(l, list):
        illegal_value(l, 'l', 'assert_no_duplicates_in_list_of_dicts', raise_error)
        return None
    if any(not isinstance(d, dict) for d in l):
        illegal_value(l, 'l', 'assert_no_duplicates_in_list_of_dicts', raise_error)
        return None
    try:
        # Fast path: make each dict hashable via its sorted items
        num_unique = len({tuple(sorted(d.items())) for d in l})
    except TypeError:
        # Unhashable values (e.g. lists) or unsortable keys: fall back to
        # pairwise equality, which works for any dict contents
        num_unique = sum(1 for i, d in enumerate(l) if d not in l[:i])
    if num_unique != len(l):
        if raise_error:
            raise ValueError(f'Duplicate items found in {l}')
        else:
            logger.error(f'Duplicate items found in {l}')
        return None
    return l
676 | |
def assert_no_duplicate_key_in_list_of_dicts(l: list, key: str, raise_error=False) -> list:
    """Return l if every dict in it has a distinct, non-None value for key.

    Otherwise return None, after logging an error or, when raise_error is
    True, raising ValueError.
    """
    if not isinstance(key, str):
        illegal_value(key, 'key', 'assert_no_duplicate_key_in_list_of_dicts', raise_error)
        return None
    if not isinstance(l, list):
        illegal_value(l, 'l', 'assert_no_duplicate_key_in_list_of_dicts', raise_error)
        return None
    if any(not isinstance(d, dict) for d in l):
        illegal_value(l, 'l', 'assert_no_duplicate_key_in_list_of_dicts', raise_error)
        return None
    keys = [d.get(key, None) for d in l]
    if None in keys:
        invalid = True
    else:
        try:
            num_unique = len(set(keys))
        except TypeError:
            # Unhashable key values: fall back to pairwise equality
            num_unique = sum(1 for i, k in enumerate(keys) if k not in keys[:i])
        invalid = num_unique != len(l)
    if invalid:
        if raise_error:
            raise ValueError(f'Duplicate or missing key ({key}) found in {l}')
        else:
            logger.error(f'Duplicate or missing key ({key}) found in {l}')
        return None
    return l
696 | |
def assert_no_duplicate_attr_in_list_of_objs(l: list, attr: str, raise_error=False) -> list:
    """Return l if every object in it has a distinct, non-None attribute attr.

    Otherwise return None, after logging an error or, when raise_error is
    True, raising ValueError.
    """
    if not isinstance(attr, str):
        illegal_value(attr, 'attr', 'assert_no_duplicate_attr_in_list_of_objs', raise_error)
        return None
    if not isinstance(l, list):
        illegal_value(l, 'l', 'assert_no_duplicate_attr_in_list_of_objs', raise_error)
        return None
    attrs = [getattr(obj, attr, None) for obj in l]
    if None in attrs:
        invalid = True
    else:
        try:
            num_unique = len(set(attrs))
        except TypeError:
            # Unhashable attribute values: fall back to pairwise equality
            num_unique = sum(1 for i, a in enumerate(attrs) if a not in attrs[:i])
        invalid = num_unique != len(l)
    if invalid:
        if raise_error:
            raise ValueError(f'Duplicate or missing attr ({attr}) found in {l}')
        else:
            logger.error(f'Duplicate or missing attr ({attr}) found in {l}')
        return None
    return l
713 | |
def file_exists_and_readable(path):
    """Return path unchanged if it names an existing regular file readable by us.

    Raises ValueError when path is not a file or is not readable.
    """
    if not os.path.isfile(path):
        raise ValueError(f'{path} is not a valid file')
    if not os.access(path, os.R_OK):
        raise ValueError(f'{path} is not accessible for reading')
    return path
721 | |
def create_mask(x, bounds=None, exclude_bounds=False, current_mask=None):
    """Create a boolean mask for the values in x.

    x: non-empty tuple, list or numpy array of values.
    bounds: optional pair of numbers in the same units as x. When given, the
        mask is True where min(bounds) < x < max(bounds) or, with
        exclude_bounds=True, where x < min(bounds) or x > max(bounds). Both
        comparisons are strict, so values equal to a bound are False either
        way. Invalid bounds are ignored with a warning.
    current_mask: optional mask of the same length as x that is OR-ed into the
        result; an invalid one is ignored with a warning.
    Return: a numpy boolean array of length len(x), or None when x is invalid.
    """
    if not isinstance(x, (tuple, list, np.ndarray)) or not len(x):
        logger.warning(f'Invalid input array ({x}, {type(x)})')
        return None
    if bounds is not None and not is_num_pair(bounds):
        logger.warning(f'Invalid bounds parameter ({bounds}, {type(bounds)}), input ignored')
        bounds = None
    if bounds is not None:
        # Element-wise comparison needs an array: a plain tuple or list raises
        # TypeError when compared with a number
        x = np.asarray(x)
        lower, upper = min(bounds), max(bounds)
        if exclude_bounds:
            mask = np.logical_or(x < lower, x > upper)
        else:
            mask = np.logical_and(x > lower, x < upper)
    else:
        mask = np.ones(len(x), dtype=bool)
    if current_mask is not None:
        if not isinstance(current_mask, (tuple, list, np.ndarray)) or len(current_mask) != len(x):
            logger.warning(f'Invalid current_mask ({current_mask}, {type(current_mask)}), '+
                    'input ignored')
        else:
            mask = np.logical_or(mask, current_mask)
    if not np.any(mask):
        logger.warning('Entire data array is masked')
    return mask
746 | |
def eval_expr(name, expr, expr_variables, user_variables=None, max_depth=10, raise_error=False,
        log=True, **kwargs):
    """Evaluate an expression that may reference other named expressions.

    name: label of the expression being evaluated (used to detect circular
        references).
    expr: expression string, evaluated with an asteval Interpreter.
    expr_variables: dict mapping variable names to their own expression
        strings. Any such name in expr that the interpreter does not know yet
        is evaluated recursively first.
    user_variables: optional dict of numbers added to the interpreter's symbol
        table.
    max_depth: maximum length of the evaluation chain (an int > 1).
    kwargs: 'chain' (names currently being evaluated) and 'ast' (the shared
        Interpreter) are passed along by the recursion.
    Return: the evaluated value, or None on an error. Errors are logged when
    log is True and raised as ValueError when raise_error is True.
    """
    if not isinstance(name, str):
        illegal_value(name, 'name', 'eval_expr', raise_error, log)
        return None
    if not isinstance(expr, str):
        illegal_value(expr, 'expr', 'eval_expr', raise_error, log)
        return None
    if not is_dict_strings(expr_variables, log=False):
        illegal_value(expr_variables, 'expr_variables', 'eval_expr', raise_error, log)
        return None
    if user_variables is not None and not is_dict_nums(user_variables, log=False):
        illegal_value(user_variables, 'user_variables', 'eval_expr', raise_error, log)
        return None
    if not is_int(max_depth, gt=1, log=False):
        illegal_value(max_depth, 'max_depth', 'eval_expr', raise_error, log)
        return None
    if not isinstance(raise_error, bool):
        illegal_value(raise_error, 'raise_error', 'eval_expr', raise_error, log)
        return None
    if not isinstance(log, bool):
        illegal_value(log, 'log', 'eval_expr', raise_error, log)
        return None
    if 'chain' in kwargs:
        chain = kwargs.pop('chain')
        if not is_str_series(chain):
            illegal_value(chain, 'chain', 'eval_expr', raise_error, log)
            return None
    else:
        chain = []
    if len(chain) > max_depth:
        # f-string: the message used to print '{max_depth}' literally
        error_msg = f'Exceeded maximum depth ({max_depth}) in eval_expr'
        if log:
            logger.error(error_msg)
        if raise_error:
            raise ValueError(error_msg)
        return None
    if name not in chain:
        chain.append(name)
    if 'ast' in kwargs:
        ast = kwargs.pop('ast')
    else:
        ast = Interpreter()
    if user_variables is not None:
        ast.symtable.update(user_variables)
    # Names used in expr that are defined in expr_variables but still unknown
    # to the interpreter must be evaluated first
    chain_vars = [var for var in get_ast_names(ast.parse(expr))
            if var in expr_variables and var not in ast.symtable]
    save_chain = chain.copy()
    for var in chain_vars:
        if var in chain:
            # var is already being evaluated further up: circular definition
            error_msg = f'Circular variable {var} in eval_expr'
            if log:
                logger.error(error_msg)
            if raise_error:
                raise ValueError(error_msg)
            return None
        if var in ast.user_defined_symbols():
            # Added to the symbol table after chain_vars was built (e.g. by an
            # earlier recursive evaluation in this loop)
            val = ast.symtable[var]
        else:
            val = eval_expr(var, expr_variables[var], expr_variables, max_depth=max_depth,
                    raise_error=raise_error, log=log, chain=chain, ast=ast)
            if val is None:
                return None
            ast.symtable[var] = val
        # The recursive call appended to the shared chain list; restore it
        chain = save_chain.copy()
    return ast.eval(expr)
827 | |
def full_gradient(expr, x, expr_name=None, expr_variables=None, valid_variables=None, max_depth=10,
        raise_error=False, log=True, **kwargs):
    """Compute the full gradient dexpr/dx.

    Applies the chain rule over the variables of `expr` that are keys of
    `expr_variables`:
    d(expr)/dx = ∂(expr)/∂x + Σ_var ∂(expr)/∂var * d(var)/dx,
    where each d(var)/dx is computed recursively from
    `expr_variables[var]`. Differentiation and simplification use sympy.

    :param expr: The expression to differentiate.
    :param x: Name of the variable to differentiate with respect to.
    :type x: str
    :param expr_name: Name of `expr`. When it equals `x` the gradient
        is 1.0. Defaults to None.
    :type expr_name: str, optional
    :param expr_variables: Mapping of variable names to their expressions,
        defaults to None.
    :type expr_variables: dict, optional
    :param valid_variables: Allowed names for the chained variables,
        defaults to None (no restriction).
    :type valid_variables: list[str], optional
    :param max_depth: Maximum length of the chain of nested variables,
        defaults to 10.
    :type max_depth: int, optional
    :param raise_error: Passed to illegal_value for invalid arguments.
        When True, a ValueError is also raised for depth, circular or
        unknown-variable errors. Defaults to False.
    :type raise_error: bool, optional
    :param log: Log errors, defaults to True.
    :type log: bool, optional
    :param kwargs: Internal recursion state only: `chain` (names currently
        being differentiated).
    :return: The simplified gradient, 1.0 if `expr_name == x`, or None on
        failure.
    """
    if not isinstance(x, str):
        illegal_value(x, 'x', 'full_gradient', raise_error, log)
        return(None)
    if expr_name is not None and not isinstance(expr_name, str):
        illegal_value(expr_name, 'expr_name', 'full_gradient', raise_error, log)
        return(None)
    if expr_variables is not None and not is_dict_strings(expr_variables, log=False):
        illegal_value(expr_variables, 'expr_variables', 'full_gradient', raise_error, log)
        return(None)
    if valid_variables is not None and not is_str_series(valid_variables, log=False):
        illegal_value(valid_variables, 'valid_variables', 'full_gradient', raise_error, log)
        return(None)
    if not is_int(max_depth, gt=1, log=False):
        illegal_value(max_depth, 'max_depth', 'full_gradient', raise_error, log)
        return(None)
    if not isinstance(raise_error, bool):
        illegal_value(raise_error, 'raise_error', 'full_gradient', raise_error, log)
        return(None)
    if not isinstance(log, bool):
        illegal_value(log, 'log', 'full_gradient', raise_error, log)
        return(None)
    if expr_name is not None and expr_name == x:
        return(1.0)
    if 'chain' in kwargs:
        chain = kwargs.pop('chain')
        if not is_str_series(chain):
            illegal_value(chain, 'chain', 'full_gradient', raise_error, log)
            return(None)
    else:
        chain = []
    if len(chain) > max_depth:
        error_msg = f'Exceeded maximum depth ({max_depth}) in full_gradient'
        if log:
            logger.error(error_msg)
        if raise_error:
            raise ValueError(error_msg)
        return(None)
    if expr_name is not None and expr_name not in chain:
        chain.append(expr_name)
    ast = Interpreter()
    if expr_variables is None:
        chain_vars = []
    else:
        chain_vars = [var for var in get_ast_names(ast.parse(f'{expr}'))
                if var in expr_variables and var != x and var not in ast.symtable]
    if valid_variables is not None:
        unknown_vars = [var for var in chain_vars if var not in valid_variables]
        if len(unknown_vars):
            error_msg = f'Unknown variable {unknown_vars} in {expr}'
            if log:
                logger.error(error_msg)
            if raise_error:
                raise ValueError(error_msg)
            return(None)
    dexpr_dx = diff(expr, x)
    # The recursive calls append to `chain`; restore it after each variable
    save_chain = chain.copy()
    for var in chain_vars:
        if var in chain:
            error_msg = f'Circular variable {var} in full_gradient'
            if log:
                logger.error(error_msg)
            if raise_error:
                raise ValueError(error_msg)
            return(None)
        dexpr_dvar = diff(expr, var)
        if dexpr_dvar:
            dvar_dx = full_gradient(expr_variables[var], x, expr_name=var,
                    expr_variables=expr_variables, valid_variables=valid_variables,
                    max_depth=max_depth, raise_error=raise_error, log=log, chain=chain)
            # Propagate a failure in the sub-gradient instead of treating it as zero
            if dvar_dx is None:
                return(None)
            if dvar_dx:
                dexpr_dx = f'{dexpr_dx}+({dexpr_dvar})*({dvar_dx})'
        chain = save_chain.copy()
    return(simplify(dexpr_dx))
916 | |
def bounds_from_mask(mask, return_include_bounds:bool=True):
    """Return the inclusive index bounds of the runs in a 1D mask.

    :param mask: 1D boolean mask (numpy array or any sequence supporting
        ``len`` and iteration).
    :param return_include_bounds: When True (default), return the runs
        where the mask equals True. Otherwise return the runs where it
        does not.
    :type return_include_bounds: bool, optional
    :return: Inclusive ``(low, upp)`` index bounds, in increasing order.
    :rtype: list[tuple(int, int)]
    """
    bounds = []
    run_start = None
    for i, m in enumerate(mask):
        if m == return_include_bounds:
            if run_start is None:
                run_start = i
        elif run_start is not None:
            bounds.append((run_start, i-1))
            run_start = None
    if run_start is not None:
        # len() instead of mask.size so plain sequences work as well
        bounds.append((run_start, len(mask)-1))
    return(bounds)
929 | |
def draw_mask_1d(ydata, xdata=None, current_index_ranges=None, current_mask=None,
        select_mask=True, num_index_ranges_max=None, title=None, legend=None, test_mode=False):
    """Interactively select ranges of 1D data and return the resulting mask.

    The data are plotted with the currently included ranges shaded green
    and the excluded ranges shaded red. The user clicks and drags to mark
    ranges to include (``select_mask=True``) or to exclude
    (``select_mask=False``). "Clear" removes the last selection, or all
    ranges when there is none. "Confirm" closes the figure.

    :param ydata: The 1D data to plot.
    :type ydata: array-like
    :param xdata: Strictly increasing x-coordinates of `ydata`, defaults
        to the indices of `ydata`.
    :type xdata: array-like, optional
    :param current_index_ranges: Initially selected inclusive ``(low, upp)``
        index ranges, clipped to the data range. Defaults to None.
    :type current_index_ranges: list[tuple(int, int)], optional
    :param current_mask: Initial boolean mask. It is combined (logical AND)
        with `current_index_ranges` when both are given. Defaults to None.
    :type current_mask: array-like, optional
    :param select_mask: Select ranges to include (True) or to exclude
        (False), defaults to True.
    :type select_mask: bool, optional
    :param num_index_ranges_max: Not implemented yet (a warning is logged).
    :param title: Title fragment, defaults to 'select ranges of data'.
    :type title: str, optional
    :param legend: Legend label for the data, defaults to None.
    :type legend: str, optional
    :param test_mode: Skip the interactive figure, defaults to False.
    :type test_mode: bool, optional
    :return: The boolean mask and its included inclusive ``(low, upp)``
        index ranges, or ``(None, None)`` for invalid input.
    :rtype: tuple(numpy.ndarray, list)
    """
    #FIX make color blind friendly
    def draw_selections(ax, current_include, current_exclude, selected_index_ranges):
        def shade(index_ranges, color):
            # Extend each inclusive index range to the midpoints with its neighbors
            for (low, upp) in index_ranges:
                xlow = 0.5*(xdata[max(0, low-1)]+xdata[low])
                xupp = 0.5*(xdata[upp]+xdata[min(num_data-1, upp+1)])
                ax.axvspan(xlow, xupp, facecolor=color, alpha=0.5)
        ax.clear()
        ax.set_title(title)
        ax.legend([legend])
        ax.plot(xdata, ydata, 'k')
        shade(current_include, 'green')
        shade(current_exclude, 'red')
        shade(selected_index_ranges, selection_color)
        ax.get_figure().canvas.draw()

    def onclick(event):
        # Start a new selection: store its first index until the button is released
        if event.inaxes in [fig.axes[0]]:
            selected_index_ranges.append(index_nearest_upp(xdata, event.xdata))

    def onrelease(event):
        # Complete a pending selection (a bare index) into a (low, upp) tuple
        if len(selected_index_ranges) > 0 and not isinstance(selected_index_ranges[-1], tuple):
            if event.inaxes in [fig.axes[0]]:
                start = selected_index_ranges[-1]
                end = index_nearest_low(xdata, event.xdata)
                selected_index_ranges[-1] = (start, end) if start <= end else (end, start)
                draw_selections(event.inaxes, current_include, current_exclude,
                        selected_index_ranges)
            else:
                # Released outside the data axes: discard the pending selection
                selected_index_ranges.pop(-1)

    def confirm_selection(event):
        plt.close()

    def clear_last_selection(event):
        if len(selected_index_ranges):
            selected_index_ranges.pop(-1)
        else:
            # Nothing pending: clear the initial include/exclude ranges and the mask
            current_include.clear()
            current_exclude.clear()
            selected_mask.fill(False)
        draw_selections(ax, current_include, current_exclude, selected_index_ranges)

    def update_mask(mask, selected_index_ranges, unselected_index_ranges):
        # Add the selected x-ranges to the mask and remove the unselected ones
        for (low, upp) in selected_index_ranges:
            range_mask = np.logical_and(xdata >= xdata[low], xdata <= xdata[upp])
            mask = np.logical_or(mask, range_mask)
        for (low, upp) in unselected_index_ranges:
            range_mask = np.logical_and(xdata >= xdata[low], xdata <= xdata[upp])
            mask[range_mask] = False
        return(mask)

    def update_index_ranges(mask):
        # Return the inclusive (low, upp) index ranges where mask is True
        index_ranges = []
        low = None
        for i, m in enumerate(mask):
            if m:
                if low is None:
                    low = i
            elif low is not None:
                index_ranges.append((low, i-1))
                low = None
        if low is not None:
            index_ranges.append((low, num_data-1))
        return(index_ranges)

    # Check inputs
    ydata = np.asarray(ydata)
    if ydata.ndim > 1:
        logger.warning(f'Invalid ydata dimension ({ydata.ndim})')
        return(None, None)
    num_data = ydata.size
    if xdata is None:
        xdata = np.arange(num_data)
    else:
        xdata = np.asarray(xdata, dtype=np.float64)
        if xdata.ndim > 1 or xdata.size != num_data:
            logger.warning(f'Invalid xdata shape ({xdata.shape})')
            return(None, None)
        if not np.all(xdata[:-1] < xdata[1:]):
            logger.warning('Invalid xdata: must be monotonically increasing')
            return(None, None)
    if current_index_ranges is not None:
        if not isinstance(current_index_ranges, (tuple, list)):
            logger.warning(f'Invalid current_index_ranges parameter ({current_index_ranges}, '+
                    f'{type(current_index_ranges)})')
            return(None, None)
    if not isinstance(select_mask, bool):
        logger.warning(f'Invalid select_mask parameter ({select_mask}, {type(select_mask)})')
        return(None, None)
    if num_index_ranges_max is not None:
        logger.warning('num_index_ranges_max input not yet implemented in draw_mask_1d')
    if title is None:
        title = 'select ranges of data'
    elif not isinstance(title, str):
        illegal_value(title, 'title', 'draw_mask_1d')
        title = ''
    if legend is not None and not isinstance(legend, str):
        illegal_value(legend, 'legend', 'draw_mask_1d')
        legend = None

    if select_mask:
        title = f'Click and drag to {title} you wish to include'
        selection_color = 'green'
    else:
        title = f'Click and drag to {title} you wish to exclude'
        selection_color = 'red'

    # Set initial selected mask and the selected/unselected index ranges as needed
    selected_index_ranges = []
    unselected_index_ranges = []
    selected_mask = np.full(xdata.shape, False, dtype=bool)
    if current_index_ranges is None:
        if current_mask is None:
            if not select_mask:
                selected_index_ranges = [(0, num_data-1)]
                selected_mask = np.full(xdata.shape, True, dtype=bool)
        else:
            selected_mask = np.copy(np.asarray(current_mask, dtype=bool))
    if current_index_ranges is not None and len(current_index_ranges):
        current_index_ranges = sorted((low, upp) for (low, upp) in current_index_ranges)
        for (low, upp) in current_index_ranges:
            if low > upp or low >= num_data or upp < 0:
                continue
            selected_index_ranges.append((max(low, 0), min(upp, num_data-1)))
        selected_mask = update_mask(selected_mask, selected_index_ranges, unselected_index_ranges)
    if current_index_ranges is not None and current_mask is not None:
        selected_mask = np.logical_and(current_mask, selected_mask)
    # Always derive the ranges from the mask so overlapping input ranges are merged
    # (otherwise the exclude ranges computed below can become inverted)
    selected_index_ranges = update_index_ranges(selected_mask)

    # Set up range selections for display
    current_include = selected_index_ranges
    current_exclude = []
    selected_index_ranges = []
    if not len(current_include):
        if select_mask:
            current_exclude = [(0, num_data-1)]
        else:
            current_include = [(0, num_data-1)]
    else:
        if current_include[0][0] > 0:
            current_exclude.append((0, current_include[0][0]-1))
        for i in range(1, len(current_include)):
            current_exclude.append((current_include[i-1][1]+1, current_include[i][0]-1))
        if current_include[-1][1] < num_data-1:
            current_exclude.append((current_include[-1][1]+1, num_data-1))

    if not test_mode:

        # Set up matplotlib figure
        plt.close('all')
        fig, ax = plt.subplots()
        plt.subplots_adjust(bottom=0.2)
        draw_selections(ax, current_include, current_exclude, selected_index_ranges)

        # Set up event handling for click-and-drag range selection
        cid_click = fig.canvas.mpl_connect('button_press_event', onclick)
        cid_release = fig.canvas.mpl_connect('button_release_event', onrelease)

        # Set up confirm / clear range selection buttons
        confirm_b = Button(plt.axes([0.75, 0.05, 0.15, 0.075]), 'Confirm')
        clear_b = Button(plt.axes([0.59, 0.05, 0.15, 0.075]), 'Clear')
        cid_confirm = confirm_b.on_clicked(confirm_selection)
        cid_clear = clear_b.on_clicked(clear_last_selection)

        # Show figure
        plt.show(block=True)

        # Disconnect callbacks when figure is closed
        fig.canvas.mpl_disconnect(cid_click)
        fig.canvas.mpl_disconnect(cid_release)
        confirm_b.disconnect(cid_confirm)
        clear_b.disconnect(cid_clear)

    # Swap selection depending on select_mask
    if not select_mask:
        selected_index_ranges, unselected_index_ranges = unselected_index_ranges, \
                selected_index_ranges

    # Update the mask with the currently selected/unselected x-ranges
    selected_mask = update_mask(selected_mask, selected_index_ranges, unselected_index_ranges)

    # Update the currently included index ranges (where mask is True)
    current_include = update_index_ranges(selected_mask)

    return(selected_mask, current_include)
1131 | |
def select_peaks(ydata:np.ndarray, x_values:np.ndarray=None, x_mask:np.ndarray=None,
        peak_x_values:np.ndarray=np.array([]), peak_x_indices:np.ndarray=np.array([]),
        return_peak_x_values:bool=False, return_peak_x_indices:bool=False,
        return_peak_input_indices:bool=False, return_sorted:bool=False,
        title:str=None, xlabel:str=None, ylabel:str=None) -> list :
    """Interactively select peaks from a set of candidate peak positions.

    Plots `ydata` with a vertical line at every candidate peak. Clicking a
    line toggles its selection. Lines in masked-out regions (where
    `x_mask` is False) are drawn gray and cannot be picked. Closing the
    figure with "Confirm" ends the selection.

    :param ydata: The reference data to plot.
    :param x_values: x-coordinates of `ydata`, defaults to its indices.
    :param x_mask: Boolean mask of selectable data points, defaults to
        all True.
    :param peak_x_values: Candidate peak x-values. Use exactly one of
        `peak_x_values` and `peak_x_indices`.
    :param peak_x_indices: Candidate peak indices into `x_values`.
    :param return_peak_x_values: Return the selected peaks' x-values.
    :param return_peak_x_indices: Return the selected peaks' indices into
        `x_values`.
    :param return_peak_input_indices: Return the selected peaks' indices
        into the candidate peak list.
    :param return_sorted: Sort the returned values, defaults to False.
    :param title: Plot title.
    :param xlabel: Plot x-axis label.
    :param ylabel: Plot y-axis label.
    :raises RuntimeError: For an invalid combination of arguments.
    :return: The selected peaks in the requested format.
    :rtype: numpy.ndarray
    """
    # Accept any array-like for the candidate peaks (needed for fancy indexing below)
    peak_x_values = np.asarray(peak_x_values)
    peak_x_indices = np.asarray(peak_x_indices)

    # Check arguments
    if ((len(peak_x_values) > 0 or return_peak_x_values)
            and (x_values is None or not len(x_values) > 0)):
        raise RuntimeError('Cannot use peak_x_values or return_peak_x_values without x_values')
    if not ((len(peak_x_values) > 0) ^ (len(peak_x_indices) > 0)):
        raise RuntimeError('Use exactly one of peak_x_values or peak_x_indices')
    num_return_formats = sum(bool(flag) for flag in
            (return_peak_x_values, return_peak_x_indices, return_peak_input_indices))
    if num_return_formats != 1:
        raise RuntimeError('Exactly one of return_peak_x_values, return_peak_x_indices, or '+
                'return_peak_input_indices must be True')

    EXCLUDE_PEAK_PROPERTIES = {'color': 'black', 'linestyle': '--','linewidth': 1,
            'marker': 10, 'markersize': 5, 'fillstyle': 'none'}
    INCLUDE_PEAK_PROPERTIES = {'color': 'green', 'linestyle': '-', 'linewidth': 2,
            'marker': 10, 'markersize': 10, 'fillstyle': 'full'}
    MASKED_PEAK_PROPERTIES = {'color': 'gray', 'linestyle': ':', 'linewidth': 1}

    # Setup reference data & plot
    x_indices = np.arange(len(ydata))
    if x_values is None:
        x_values = x_indices
    else:
        x_values = np.asarray(x_values)
    if x_mask is None:
        x_mask = np.full(x_values.shape, True, dtype=bool)
    else:
        x_mask = np.asarray(x_mask, dtype=bool)
    fig, ax = plt.subplots()
    handles = ax.plot(x_values, ydata, label='Reference data')
    handles.append(mlines.Line2D([], [], label='Excluded / unselected HKL', **EXCLUDE_PEAK_PROPERTIES))
    handles.append(mlines.Line2D([], [], label='Included / selected HKL', **INCLUDE_PEAK_PROPERTIES))
    handles.append(mlines.Line2D([], [], label='HKL in masked region (unselectable)', **MASKED_PEAK_PROPERTIES))
    ax.legend(handles=handles, loc='upper right')
    ax.set(title=title, xlabel=xlabel, ylabel=ylabel)

    # Plot vertical line at each peak
    def value_to_index(x_value):
        # Index of the data point nearest to x_value
        return(int(np.argmin(abs(x_values - x_value))))
    if len(peak_x_indices) > 0:
        peak_x_values = x_values[peak_x_indices]
    else:
        peak_x_indices = np.array([value_to_index(value) for value in peak_x_values], dtype=int)
    peak_vlines = []
    for loc in peak_x_values:
        if x_mask[value_to_index(loc)]:
            peak_vline = ax.axvline(loc, **EXCLUDE_PEAK_PROPERTIES)
            # Only peaks outside the masked regions are pickable
            peak_vline.set_picker(5)
        else:
            peak_vline = ax.axvline(loc, **MASKED_PEAK_PROPERTIES)
        peak_vlines.append(peak_vline)

    # Indicate masked regions by gray-ing out the axes facecolor
    mask_exclude_bounds = bounds_from_mask(x_mask, return_include_bounds=False)
    for (low, upp) in mask_exclude_bounds:
        xlow = x_values[low]
        xupp = x_values[upp]
        ax.axvspan(xlow, xupp, facecolor='gray', alpha=0.5)

    # Setup peak picking
    selected_peak_input_indices = []
    def onpick(event):
        try:
            peak_index = peak_vlines.index(event.artist)
        except ValueError:
            # Picked artist is not one of the peak lines
            return
        peak_vline = event.artist
        if peak_index in selected_peak_input_indices:
            peak_vline.set(**EXCLUDE_PEAK_PROPERTIES)
            selected_peak_input_indices.remove(peak_index)
        else:
            peak_vline.set(**INCLUDE_PEAK_PROPERTIES)
            selected_peak_input_indices.append(peak_index)
        plt.draw()
    cid_pick_peak = fig.canvas.mpl_connect('pick_event', onpick)

    # Setup "Confirm" button
    def confirm_selection(event):
        plt.close()
    plt.subplots_adjust(bottom=0.2)
    confirm_b = Button(plt.axes([0.75, 0.05, 0.15, 0.075]), 'Confirm')
    cid_confirm = confirm_b.on_clicked(confirm_selection)

    # Show figure for user interaction
    plt.show()

    # Disconnect callbacks when figure is closed
    fig.canvas.mpl_disconnect(cid_pick_peak)
    confirm_b.disconnect(cid_confirm)

    if return_peak_input_indices:
        selected_peaks = np.array(selected_peak_input_indices)
    elif return_peak_x_values:
        selected_peaks = peak_x_values[selected_peak_input_indices]
    else:
        selected_peaks = peak_x_indices[selected_peak_input_indices]

    if return_sorted:
        selected_peaks.sort()

    return(selected_peaks)
1235 | |
def find_image_files(path, filetype, name=None):
    """Find the available image files for a data set.

    For ``filetype == 'tif'``, `path` is a directory of ``*.tif`` files
    whose names contain an integer index (the first run of digits in the
    file name). The indices must form a consecutive range. For
    ``filetype == 'h5'``, `path` is an h5 file whose
    ``entry/instrument/detector/data`` dataset holds the images along its
    first axis.

    :param path: Directory (tif) or file (h5) path.
    :type path: str
    :param filetype: 'tif' or 'h5'.
    :type filetype: str
    :param name: Label used in log messages, defaults to None.
    :type name: str, optional
    :return: ``(first_index, num_img, paths)``, or ``(-1, 0, [])`` on
        failure.
    :rtype: tuple
    """
    if isinstance(name, str):
        name = f'{name.strip()} '
    else:
        name = ''
    # Find available index range
    if filetype == 'tif':
        if not isinstance(path, str) or not os.path.isdir(path):
            illegal_value(path, 'path', 'find_image_files')
            return(-1, 0, [])
        index_regex = re.compile(r'\d+')
        # At this point only tiffs: map each indexed tif file to its integer index
        indexed_files = {}
        for f in os.listdir(path):
            if not f.endswith('.tif') or not os.path.isfile(os.path.join(path, f)):
                continue
            match = index_regex.search(f)
            if match:
                indexed_files[f] = int(match.group())
        # Sort numerically (not lexicographically) so that unpadded indices such as
        # img_2.tif and img_10.tif are ordered correctly; ties fall back to the name
        files = sorted(indexed_files, key=lambda f: (indexed_files[f], f))
        num_img = len(files)
        if num_img < 1:
            logger.warning(f'No available {name}files')
            return(-1, 0, [])
        first_index = indexed_files[files[0]]
        last_index = indexed_files[files[-1]]
        if num_img != last_index-first_index+1:
            logger.error(f'Non-consecutive set of indices for {name}images')
            return(-1, 0, [])
        paths = [os.path.join(path, f) for f in files]
    elif filetype == 'h5':
        if not isinstance(path, str) or not os.path.isfile(path):
            illegal_value(path, 'path', 'find_image_files')
            return(-1, 0, [])
        # At this point only h5 in alamo2 detector style
        first_index = 0
        with h5py.File(path, 'r') as f:
            num_img = f['entry/instrument/detector/data'].shape[0]
        last_index = num_img-1
        paths = [path]
    else:
        illegal_value(filetype, 'filetype', 'find_image_files')
        return(-1, 0, [])
    logger.info(f'Number of available {name}images: {num_img}')
    logger.info(f'Index range of available {name}images: [{first_index}, '+
            f'{last_index}]')

    return(first_index, num_img, paths)
1283 | |
def select_image_range(first_index, offset, num_available, num_img=None, name=None,
        num_required=None):
    """Interactively select the range of images to use.

    When a valid `first_index`/`offset` pair is given, the user is first
    asked whether to keep it. Otherwise (or when declined) the user picks
    the range from a menu.

    :param first_index: Index of the first available image.
    :type first_index: int
    :param offset: Current offset of the first image to use relative to
        `first_index`.
    :type offset: int
    :param num_available: Number of available images.
    :type num_available: int
    :param num_img: Current number of images to use, defaults to None.
    :type num_img: int, optional
    :param name: Label used in messages, defaults to None.
    :type name: str, optional
    :param num_required: Exact number of images required, defaults to
        None (any number).
    :type num_required: int, optional
    :return: ``(first_index, offset, num_img)``, or ``(0, 0, 0)`` on
        failure.
    :rtype: tuple
    """
    if isinstance(name, str):
        name = f'{name.strip()} '
    else:
        name = ''
    # Check existing values
    if not is_int(num_available, gt=0):
        logger.warning(f'No available {name}images')
        return(0, 0, 0)
    if num_img is not None and not is_int(num_img, ge=0):
        illegal_value(num_img, 'num_img', 'select_image_range')
        return(0, 0, 0)
    if is_int(first_index, ge=0) and is_int(offset, ge=0):
        if num_required is None:
            if input_yesno(f'\nCurrent {name}first image index/offset = {first_index}/{offset}, '+
                    'use these values (y/n)?', 'y'):
                if num_img is not None:
                    if input_yesno(f'Current number of {name}images = {num_img}, '+
                            'use this value (y/n)? ', 'y'):
                        return(first_index, offset, num_img)
                else:
                    if input_yesno(f'Number of available {name}images = {num_available}, '+
                            'use all (y/n)? ', 'y'):
                        return(first_index, offset, num_available)
        else:
            if input_yesno(f'\nCurrent {name}first image offset = {offset}, '+
                    'use this value (y/n)?', 'y'):
                return(first_index, offset, num_required)

    # Check range against requirements
    if num_required is None:
        if num_available == 1:
            return(first_index, 0, 1)
    else:
        if not is_int(num_required, ge=1):
            illegal_value(num_required, 'num_required', 'select_image_range')
            return(0, 0, 0)
        if num_available < num_required:
            logger.error(f'Unable to find the required {name}images ({num_available} out of '+
                    f'{num_required})')
            return(0, 0, 0)

    # Select index range
    print(f'\nThe number of available {name}images is {num_available}')
    if num_required is None:
        # Inclusive index of the last available image
        last_index = first_index+num_available-1
        use_all = f'Use all ([{first_index}, {last_index}])'
        pick_offset = 'Pick the first image index offset and the number of images'
        pick_bounds = 'Pick the first and last image index'
        choice = input_menu([use_all, pick_offset, pick_bounds], default=pick_offset)
        if not choice:
            offset = 0
            num_img = num_available
        elif choice == 1:
            offset = input_int('Enter the first index offset', ge=0, le=last_index-first_index)
            if first_index+offset == last_index:
                num_img = 1
            else:
                num_img = input_int('Enter the number of images', ge=1, le=num_available-offset)
        else:
            offset = input_int('Enter the first index', ge=first_index, le=last_index)
            num_img = 1-offset+input_int('Enter the last index', ge=offset, le=last_index)
            offset -= first_index
    else:
        use_all = f'Use ([{first_index}, {first_index+num_required-1}])'
        pick_offset = 'Pick the first index offset'
        choice = input_menu([use_all, pick_offset], pick_offset)
        offset = 0
        if choice == 1:
            offset = input_int('Enter the first index offset', ge=0, le=num_available-num_required)
        num_img = num_required

    return(first_index, offset, num_img)
1358 | |
def load_image(f, img_x_bounds=None, img_y_bounds=None):
    """Load a single image from file, optionally cropped.

    :param f: Image file name, read with ``plt.imread``.
    :type f: str
    :param img_x_bounds: ``(low, upp)`` row bounds (upp exclusive),
        defaults to all rows.
    :type img_x_bounds: tuple or list, optional
    :param img_y_bounds: ``(low, upp)`` column bounds (upp exclusive),
        defaults to all columns.
    :type img_y_bounds: tuple or list, optional
    :return: The (cropped) image, or None if the file does not exist or
        the bounds are inconsistent with the image shape.
    """
    if not os.path.isfile(f):
        logger.error(f'Unable to load {f}')
        return(None)
    img_read = plt.imread(f)
    if not img_x_bounds:
        img_x_bounds = (0, img_read.shape[0])
    elif (not isinstance(img_x_bounds, (tuple, list)) or len(img_x_bounds) != 2 or
            not (0 <= img_x_bounds[0] < img_x_bounds[1] <= img_read.shape[0])):
        logger.error(f'inconsistent row dimension in {f}')
        return(None)
    if not img_y_bounds:
        img_y_bounds = (0, img_read.shape[1])
    # Accept tuples as well as lists, consistent with img_x_bounds
    elif (not isinstance(img_y_bounds, (tuple, list)) or len(img_y_bounds) != 2 or
            not (0 <= img_y_bounds[0] < img_y_bounds[1] <= img_read.shape[1])):
        logger.error(f'inconsistent column dimension in {f}')
        return(None)
    return(img_read[img_x_bounds[0]:img_x_bounds[1],img_y_bounds[0]:img_y_bounds[1]])
1381 | |
def load_image_stack(files, filetype, img_offset, num_img, num_img_skip=0,
        img_x_bounds=None, img_y_bounds=None):
    """Load a set of images and return them as a stack.

    Every ``(num_img_skip+1)``-th image from
    ``[img_offset, img_offset+num_img)`` is loaded.

    :param files: tif file names, or a one-element list with the h5 file
        name.
    :type files: list[str]
    :param filetype: 'tif' or 'h5'.
    :type filetype: str
    :param img_offset: Offset of the first image to load.
    :type img_offset: int
    :param num_img: Number of images in the range to load from.
    :type num_img: int
    :param num_img_skip: Number of images to skip between loaded images,
        defaults to 0.
    :type num_img_skip: int, optional
    :param img_x_bounds: ``(low, upp)`` row bounds (upp exclusive),
        defaults to all rows.
    :param img_y_bounds: ``(low, upp)`` column bounds (upp exclusive),
        defaults to all columns.
    :return: The image stack (images along the first axis), or an empty
        array on failure.
    :rtype: numpy.ndarray
    """
    logger.debug(f'img_offset = {img_offset}')
    logger.debug(f'num_img = {num_img}')
    logger.debug(f'num_img_skip = {num_img_skip}')
    logger.debug(f'\nfiles:\n{files}\n')
    img_stack = np.array([])
    if filetype == 'tif':
        img_read_stack = []
        i = 1
        t0 = time()
        for f in files[img_offset:img_offset+num_img:num_img_skip+1]:
            if not i%20:
                logger.info(f' loading {i}/{num_img}: {f}')
            else:
                logger.debug(f' loading {i}/{num_img}: {f}')
            img_read = load_image(f, img_x_bounds, img_y_bounds)
            if img_read is None:
                # load_image already logged the reason; np.stack would fail on None
                return(img_stack)
            img_read_stack.append(img_read)
            i += num_img_skip+1
        if not img_read_stack:
            logger.error('No tif images to load')
            return(img_stack)
        img_stack = np.stack(img_read_stack)
        logger.info(f'... done in {time()-t0:.2f} seconds!')
        logger.debug(f'img_stack shape = {np.shape(img_stack)}')
        del img_read_stack
    elif filetype == 'h5':
        if not len(files):
            illegal_value(files, 'files', 'load_image_stack')
            return(img_stack)
        # Reject a non-string name or a name that is not an existing file
        if not isinstance(files[0], str) or not os.path.isfile(files[0]):
            illegal_value(files[0], 'files[0]', 'load_image_stack')
            return(img_stack)
        t0 = time()
        logger.info(f'Loading {files[0]}')
        with h5py.File(files[0], 'r') as f:
            shape = f['entry/instrument/detector/data'].shape
            if len(shape) != 3:
                logger.error(f'inconsistent dimensions in {files[0]}')
                return(img_stack)
            if not img_x_bounds:
                img_x_bounds = (0, shape[1])
            elif (not isinstance(img_x_bounds, (tuple, list)) or len(img_x_bounds) != 2 or
                    not (0 <= img_x_bounds[0] < img_x_bounds[1] <= shape[1])):
                logger.error(f'inconsistent row dimension in {files[0]} {img_x_bounds} '+
                        f'{shape[1]}')
                return(img_stack)
            if not img_y_bounds:
                img_y_bounds = (0, shape[2])
            elif (not isinstance(img_y_bounds, (tuple, list)) or len(img_y_bounds) != 2 or
                    not (0 <= img_y_bounds[0] < img_y_bounds[1] <= shape[2])):
                logger.error(f'inconsistent column dimension in {files[0]}')
                return(img_stack)
            img_stack = f.get('entry/instrument/detector/data')[
                    img_offset:img_offset+num_img:num_img_skip+1,
                    img_x_bounds[0]:img_x_bounds[1],img_y_bounds[0]:img_y_bounds[1]]
        logger.info(f'... done in {time()-t0:.2f} seconds!')
    else:
        illegal_value(filetype, 'filetype', 'load_image_stack')
    return(img_stack)
1437 | |
def combine_tiffs_in_h5(files, num_img, h5_filename):
    """Stack the first `num_img` tif images of `files` into one h5 file.

    The images are loaded with load_image_stack and written to the
    'entry/instrument/detector/data' dataset of `h5_filename`, which is
    opened in 'w' mode.

    :return: A one-element list holding `h5_filename`.
    :rtype: list[str]
    """
    stack = load_image_stack(files, 'tif', 0, num_img)
    with h5py.File(h5_filename, 'w') as h5_file:
        h5_file.create_dataset('entry/instrument/detector/data', data=stack)
    # Drop the reference to the (possibly large) stack before returning
    del stack
    return([h5_filename])
1444 | |
def clear_imshow(title=None):
    """Close the matplotlib figure labelled `title`.

    Interactive mode is switched off first. `title` defaults to
    'quick imshow' (the default figure title of quick_imshow). A
    non-string title is reported via illegal_value and nothing is closed.
    """
    plt.ioff()
    if title is not None and not isinstance(title, str):
        illegal_value(title, 'title', 'clear_imshow')
        return
    plt.close(fig='quick imshow' if title is None else title)
1453 | |
def clear_plot(title=None):
    """Close the matplotlib figure labelled `title`.

    Interactive mode is switched off first. `title` defaults to
    'quick plot'. A non-string title is reported via illegal_value and
    nothing is closed.
    """
    plt.ioff()
    if title is not None and not isinstance(title, str):
        illegal_value(title, 'title', 'clear_plot')
        return
    plt.close(fig='quick plot' if title is None else title)
1462 | |
def quick_imshow(a, title=None, path=None, name=None, save_fig=False, save_only=False,
        clear=True, extent=None, show_grid=False, grid_color='w', grid_linewidth=1,
        block=False, **kwargs):
    """Display and/or save an image with matplotlib's ``imshow``.

    :param a: Image data (2D, or 3D RGB(A) with 3 or 4 channels in the last axis).
    :param title: Figure title/identifier (default 'quick imshow'). With whitespace
        replaced by '_', it also forms the file name when ``name`` is None.
    :param path: Directory to save the figure in (default: current directory).
    :param name: File name to save the figure as (used as-is, no extension added).
    :param save_fig: Also save the figure when displaying it.
    :param save_only: Save the figure without displaying it.
    :param clear: Close an existing figure with the same title first.
    :param extent: Image extent for ``imshow`` (default: pixel index extent).
    :param show_grid: Draw a grid over the image.
    :param grid_color: Grid line color.
    :param grid_linewidth: Grid line width.
    :param block: Block until the figure is closed (interactive mode off).
    :param kwargs: Extra keyword arguments for ``plt.imshow``.
    """
    if title is not None and not isinstance(title, str):
        illegal_value(title, 'title', 'quick_imshow')
        return
    if path is not None and not isinstance(path, str):
        illegal_value(path, 'path', 'quick_imshow')
        return
    if not isinstance(save_fig, bool):
        illegal_value(save_fig, 'save_fig', 'quick_imshow')
        return
    if not isinstance(save_only, bool):
        illegal_value(save_only, 'save_only', 'quick_imshow')
        return
    if not isinstance(clear, bool):
        illegal_value(clear, 'clear', 'quick_imshow')
        return
    if not isinstance(block, bool):
        illegal_value(block, 'block', 'quick_imshow')
        return
    if not title:
        title = 'quick imshow'
    if name is None:
        ttitle = re.sub(r"\s+", '_', title)
        if path is None:
            path = f'{ttitle}.png'
        else:
            path = f'{path}/{ttitle}.png'
    else:
        if path is None:
            path = name
        else:
            path = f'{path}/{name}'
    # Accept lists etc.; asanyarray (unlike asarray) keeps masked arrays intact
    a = np.asanyarray(a)
    if 'cmap' in kwargs and a.ndim == 3 and a.shape[2] in (3, 4):
        # imshow ignores cmap for RGB(A) data, so keep it only when the image is
        # effectively single channel: constant alpha and R == G == B everywhere
        use_cmap = True
        if a.shape[2] == 4 and a[:,:,-1].min() != a[:,:,-1].max():
            use_cmap = False
        if np.any(a[:,:,0] != a[:,:,1]) or np.any(a[:,:,0] != a[:,:,2]):
            use_cmap = False
        if use_cmap:
            a = a[:,:,0]
        else:
            logger.warning('Image incompatible with cmap option, ignore cmap')
            kwargs.pop('cmap')
    if extent is None:
        extent = (0, a.shape[1], a.shape[0], 0)
    if clear:
        try:
            plt.close(fig=title)
        except Exception:
            # Best effort: nothing to close
            pass
    if not save_only:
        if block:
            plt.ioff()
        else:
            plt.ion()
    plt.figure(title)
    plt.imshow(a, extent=extent, **kwargs)
    if show_grid:
        ax = plt.gca()
        ax.grid(color=grid_color, linewidth=grid_linewidth)
    if save_only:
        plt.savefig(path)
        plt.close(fig=title)
    else:
        if save_fig:
            plt.savefig(path)
        if block:
            plt.show(block=block)
1523 plt.figure(title) | |
1524 plt.imshow(a, extent=extent, **kwargs) | |
1525 if show_grid: | |
1526 ax = plt.gca() | |
1527 ax.grid(color=grid_color, linewidth=grid_linewidth) | |
1528 # if title != 'quick imshow': | |
1529 # plt.title = title | |
1530 if save_only: | |
1531 plt.savefig(path) | |
1532 plt.close(fig=title) | |
1533 else: | |
1534 if save_fig: | |
1535 plt.savefig(path) | |
1536 if block: | |
1537 plt.show(block=block) | |
1538 | |
1539 def quick_plot(*args, xerr=None, yerr=None, vlines=None, title=None, xlim=None, ylim=None, | |
1540 xlabel=None, ylabel=None, legend=None, path=None, name=None, show_grid=False, | |
1541 save_fig=False, save_only=False, clear=True, block=False, **kwargs): | |
1542 if title is not None and not isinstance(title, str): | |
1543 illegal_value(title, 'title', 'quick_plot') | |
1544 title = None | |
1545 if xlim is not None and not isinstance(xlim, (tuple, list)) and len(xlim) != 2: | |
1546 illegal_value(xlim, 'xlim', 'quick_plot') | |
1547 xlim = None | |
1548 if ylim is not None and not isinstance(ylim, (tuple, list)) and len(ylim) != 2: | |
1549 illegal_value(ylim, 'ylim', 'quick_plot') | |
1550 ylim = None | |
1551 if xlabel is not None and not isinstance(xlabel, str): | |
1552 illegal_value(xlabel, 'xlabel', 'quick_plot') | |
1553 xlabel = None | |
1554 if ylabel is not None and not isinstance(ylabel, str): | |
1555 illegal_value(ylabel, 'ylabel', 'quick_plot') | |
1556 ylabel = None | |
1557 if legend is not None and not isinstance(legend, (tuple, list)): | |
1558 illegal_value(legend, 'legend', 'quick_plot') | |
1559 legend = None | |
1560 if path is not None and not isinstance(path, str): | |
1561 illegal_value(path, 'path', 'quick_plot') | |
1562 return | |
1563 if not isinstance(show_grid, bool): | |
1564 illegal_value(show_grid, 'show_grid', 'quick_plot') | |
1565 return | |
1566 if not isinstance(save_fig, bool): | |
1567 illegal_value(save_fig, 'save_fig', 'quick_plot') | |
1568 return | |
1569 if not isinstance(save_only, bool): | |
1570 illegal_value(save_only, 'save_only', 'quick_plot') | |
1571 return | |
1572 if not isinstance(clear, bool): | |
1573 illegal_value(clear, 'clear', 'quick_plot') | |
1574 return | |
1575 if not isinstance(block, bool): | |
1576 illegal_value(block, 'block', 'quick_plot') | |
1577 return | |
1578 if title is None: | |
1579 title = 'quick plot' | |
1580 # else: | |
1581 # title = re.sub(r"\s+", '_', title) | |
1582 if name is None: | |
1583 ttitle = re.sub(r"\s+", '_', title) | |
1584 if path is None: | |
1585 path = f'{ttitle}.png' | |
1586 else: | |
1587 path = f'{path}/{ttitle}.png' | |
1588 else: | |
1589 if path is None: | |
1590 path = name | |
1591 else: | |
1592 path = f'{path}/{name}' | |
1593 if clear: | |
1594 try: | |
1595 plt.close(fig=title) | |
1596 except: | |
1597 pass | |
1598 args = unwrap_tuple(args) | |
1599 if depth_tuple(args) > 1 and (xerr is not None or yerr is not None): | |
1600 logger.warning('Error bars ignored form multiple curves') | |
1601 if not save_only: | |
1602 if block: | |
1603 plt.ioff() | |
1604 else: | |
1605 plt.ion() | |
1606 plt.figure(title) | |
1607 if depth_tuple(args) > 1: | |
1608 for y in args: | |
1609 plt.plot(*y, **kwargs) | |
1610 else: | |
1611 if xerr is None and yerr is None: | |
1612 plt.plot(*args, **kwargs) | |
1613 else: | |
1614 plt.errorbar(*args, xerr=xerr, yerr=yerr, **kwargs) | |
1615 if vlines is not None: | |
1616 if isinstance(vlines, (int, float)): | |
1617 vlines = [vlines] | |
1618 for v in vlines: | |
1619 plt.axvline(v, color='r', linestyle='--', **kwargs) | |
1620 # if vlines is not None: | |
1621 # for s in tuple(([x, x], list(plt.gca().get_ylim())) for x in vlines): | |
1622 # plt.plot(*s, color='red', **kwargs) | |
1623 if xlim is not None: | |
1624 plt.xlim(xlim) | |
1625 if ylim is not None: | |
1626 plt.ylim(ylim) | |
1627 if xlabel is not None: | |
1628 plt.xlabel(xlabel) | |
1629 if ylabel is not None: | |
1630 plt.ylabel(ylabel) | |
1631 if show_grid: | |
1632 ax = plt.gca() | |
1633 ax.grid(color='k')#, linewidth=1) | |
1634 if legend is not None: | |
1635 plt.legend(legend) | |
1636 if save_only: | |
1637 plt.savefig(path) | |
1638 plt.close(fig=title) | |
1639 else: | |
1640 if save_fig: | |
1641 plt.savefig(path) | |
1642 if block: | |
1643 plt.show(block=block) | |
1644 | |
1645 def select_array_bounds(a, x_low=None, x_upp=None, num_x_min=None, ask_bounds=False, | |
1646 title='select array bounds'): | |
1647 """Interactively select the lower and upper data bounds for a numpy array. | |
1648 """ | |
1649 if isinstance(a, (tuple, list)): | |
1650 a = np.array(a) | |
1651 if not isinstance(a, np.ndarray) or a.ndim != 1: | |
1652 illegal_value(a.ndim, 'array type or dimension', 'select_array_bounds') | |
1653 return(None) | |
1654 len_a = len(a) | |
1655 if num_x_min is None: | |
1656 num_x_min = 1 | |
1657 else: | |
1658 if num_x_min < 2 or num_x_min > len_a: | |
1659 logger.warning('Invalid value for num_x_min in select_array_bounds, input ignored') | |
1660 num_x_min = 1 | |
1661 | |
1662 # Ask to use current bounds | |
1663 if ask_bounds and (x_low is not None or x_upp is not None): | |
1664 if x_low is None: | |
1665 x_low = 0 | |
1666 if not is_int(x_low, ge=0, le=len_a-num_x_min): | |
1667 illegal_value(x_low, 'x_low', 'select_array_bounds') | |
1668 return(None) | |
1669 if x_upp is None: | |
1670 x_upp = len_a | |
1671 if not is_int(x_upp, ge=x_low+num_x_min, le=len_a): | |
1672 illegal_value(x_upp, 'x_upp', 'select_array_bounds') | |
1673 return(None) | |
1674 quick_plot((range(len_a), a), vlines=(x_low,x_upp), title=title) | |
1675 if not input_yesno(f'\nCurrent array bounds: [{x_low}, {x_upp}] '+ | |
1676 'use these values (y/n)?', 'y'): | |
1677 x_low = None | |
1678 x_upp = None | |
1679 else: | |
1680 clear_plot(title) | |
1681 return(x_low, x_upp) | |
1682 | |
1683 if x_low is None: | |
1684 x_min = 0 | |
1685 x_max = len_a | |
1686 x_low_max = len_a-num_x_min | |
1687 while True: | |
1688 quick_plot(range(x_min, x_max), a[x_min:x_max], title=title) | |
1689 zoom_flag = input_yesno('Set lower data bound (y) or zoom in (n)?', 'y') | |
1690 if zoom_flag: | |
1691 x_low = input_int(' Set lower data bound', ge=0, le=x_low_max) | |
1692 break | |
1693 else: | |
1694 x_min = input_int(' Set lower zoom index', ge=0, le=x_low_max) | |
1695 x_max = input_int(' Set upper zoom index', ge=x_min+1, le=x_low_max+1) | |
1696 else: | |
1697 if not is_int(x_low, ge=0, le=len_a-num_x_min): | |
1698 illegal_value(x_low, 'x_low', 'select_array_bounds') | |
1699 return(None) | |
1700 if x_upp is None: | |
1701 x_min = x_low+num_x_min | |
1702 x_max = len_a | |
1703 x_upp_min = x_min | |
1704 while True: | |
1705 quick_plot(range(x_min, x_max), a[x_min:x_max], title=title) | |
1706 zoom_flag = input_yesno('Set upper data bound (y) or zoom in (n)?', 'y') | |
1707 if zoom_flag: | |
1708 x_upp = input_int(' Set upper data bound', ge=x_upp_min, le=len_a) | |
1709 break | |
1710 else: | |
1711 x_min = input_int(' Set upper zoom index', ge=x_upp_min, le=len_a-1) | |
1712 x_max = input_int(' Set upper zoom index', ge=x_min+1, le=len_a) | |
1713 else: | |
1714 if not is_int(x_upp, ge=x_low+num_x_min, le=len_a): | |
1715 illegal_value(x_upp, 'x_upp', 'select_array_bounds') | |
1716 return(None) | |
1717 print(f'lower bound = {x_low} (inclusive)\nupper bound = {x_upp} (exclusive)]') | |
1718 quick_plot((range(len_a), a), vlines=(x_low,x_upp), title=title) | |
1719 if not input_yesno('Accept these bounds (y/n)?', 'y'): | |
1720 x_low, x_upp = select_array_bounds(a, None, None, num_x_min, title=title) | |
1721 clear_plot(title) | |
1722 return(x_low, x_upp) | |
1723 | |
1724 def select_image_bounds(a, axis, low=None, upp=None, num_min=None, title='select array bounds', | |
1725 raise_error=False): | |
1726 """Interactively select the lower and upper data bounds for a 2D numpy array. | |
1727 """ | |
1728 a = np.asarray(a) | |
1729 if a.ndim != 2: | |
1730 illegal_value(a.ndim, 'array dimension', location='select_image_bounds', | |
1731 raise_error=raise_error) | |
1732 return(None) | |
1733 if axis < 0 or axis >= a.ndim: | |
1734 illegal_value(axis, 'axis', location='select_image_bounds', raise_error=raise_error) | |
1735 return(None) | |
1736 low_save = low | |
1737 upp_save = upp | |
1738 num_min_save = num_min | |
1739 if num_min is None: | |
1740 num_min = 1 | |
1741 else: | |
1742 if num_min < 2 or num_min > a.shape[axis]: | |
1743 logger.warning('Invalid input for num_min in select_image_bounds, input ignored') | |
1744 num_min = 1 | |
1745 if low is None: | |
1746 min_ = 0 | |
1747 max_ = a.shape[axis] | |
1748 low_max = a.shape[axis]-num_min | |
1749 while True: | |
1750 if axis: | |
1751 quick_imshow(a[:,min_:max_], title=title, aspect='auto', | |
1752 extent=[min_,max_,a.shape[0],0]) | |
1753 else: | |
1754 quick_imshow(a[min_:max_,:], title=title, aspect='auto', | |
1755 extent=[0,a.shape[1], max_,min_]) | |
1756 zoom_flag = input_yesno('Set lower data bound (y) or zoom in (n)?', 'y') | |
1757 if zoom_flag: | |
1758 low = input_int(' Set lower data bound', ge=0, le=low_max) | |
1759 break | |
1760 else: | |
1761 min_ = input_int(' Set lower zoom index', ge=0, le=low_max) | |
1762 max_ = input_int(' Set upper zoom index', ge=min_+1, le=low_max+1) | |
1763 else: | |
1764 if not is_int(low, ge=0, le=a.shape[axis]-num_min): | |
1765 illegal_value(low, 'low', location='select_image_bounds', raise_error=raise_error) | |
1766 return(None) | |
1767 if upp is None: | |
1768 min_ = low+num_min | |
1769 max_ = a.shape[axis] | |
1770 upp_min = min_ | |
1771 while True: | |
1772 if axis: | |
1773 quick_imshow(a[:,min_:max_], title=title, aspect='auto', | |
1774 extent=[min_,max_,a.shape[0],0]) | |
1775 else: | |
1776 quick_imshow(a[min_:max_,:], title=title, aspect='auto', | |
1777 extent=[0,a.shape[1], max_,min_]) | |
1778 zoom_flag = input_yesno('Set upper data bound (y) or zoom in (n)?', 'y') | |
1779 if zoom_flag: | |
1780 upp = input_int(' Set upper data bound', ge=upp_min, le=a.shape[axis]) | |
1781 break | |
1782 else: | |
1783 min_ = input_int(' Set upper zoom index', ge=upp_min, le=a.shape[axis]-1) | |
1784 max_ = input_int(' Set upper zoom index', ge=min_+1, le=a.shape[axis]) | |
1785 else: | |
1786 if not is_int(upp, ge=low+num_min, le=a.shape[axis]): | |
1787 illegal_value(upp, 'upp', location='select_image_bounds', raise_error=raise_error) | |
1788 return(None) | |
1789 bounds = (low, upp) | |
1790 a_tmp = np.copy(a) | |
1791 a_tmp_max = a.max() | |
1792 if axis: | |
1793 a_tmp[:,bounds[0]] = a_tmp_max | |
1794 a_tmp[:,bounds[1]-1] = a_tmp_max | |
1795 else: | |
1796 a_tmp[bounds[0],:] = a_tmp_max | |
1797 a_tmp[bounds[1]-1,:] = a_tmp_max | |
1798 print(f'lower bound = {low} (inclusive)\nupper bound = {upp} (exclusive)') | |
1799 quick_imshow(a_tmp, title=title, aspect='auto') | |
1800 del a_tmp | |
1801 if not input_yesno('Accept these bounds (y/n)?', 'y'): | |
1802 bounds = select_image_bounds(a, axis, low=low_save, upp=upp_save, num_min=num_min_save, | |
1803 title=title) | |
1804 return(bounds) | |
1805 | |
1806 def select_one_image_bound(a, axis, bound=None, bound_name=None, title='select array bounds', | |
1807 default='y', raise_error=False): | |
1808 """Interactively select a data boundary for a 2D numpy array. | |
1809 """ | |
1810 a = np.asarray(a) | |
1811 if a.ndim != 2: | |
1812 illegal_value(a.ndim, 'array dimension', location='select_one_image_bound', | |
1813 raise_error=raise_error) | |
1814 return(None) | |
1815 if axis < 0 or axis >= a.ndim: | |
1816 illegal_value(axis, 'axis', location='select_one_image_bound', raise_error=raise_error) | |
1817 return(None) | |
1818 if bound_name is None: | |
1819 bound_name = 'data bound' | |
1820 if bound is None: | |
1821 min_ = 0 | |
1822 max_ = a.shape[axis] | |
1823 bound_max = a.shape[axis]-1 | |
1824 while True: | |
1825 if axis: | |
1826 quick_imshow(a[:,min_:max_], title=title, aspect='auto', | |
1827 extent=[min_,max_,a.shape[0],0]) | |
1828 else: | |
1829 quick_imshow(a[min_:max_,:], title=title, aspect='auto', | |
1830 extent=[0,a.shape[1], max_,min_]) | |
1831 zoom_flag = input_yesno(f'Set {bound_name} (y) or zoom in (n)?', 'y') | |
1832 if zoom_flag: | |
1833 bound = input_int(f' Set {bound_name}', ge=0, le=bound_max) | |
1834 clear_imshow(title) | |
1835 break | |
1836 else: | |
1837 min_ = input_int(' Set lower zoom index', ge=0, le=bound_max) | |
1838 max_ = input_int(' Set upper zoom index', ge=min_+1, le=bound_max+1) | |
1839 | |
1840 elif not is_int(bound, ge=0, le=a.shape[axis]-1): | |
1841 illegal_value(bound, 'bound', location='select_one_image_bound', raise_error=raise_error) | |
1842 return(None) | |
1843 else: | |
1844 print(f'Current {bound_name} = {bound}') | |
1845 a_tmp = np.copy(a) | |
1846 a_tmp_max = a.max() | |
1847 if axis: | |
1848 a_tmp[:,bound] = a_tmp_max | |
1849 else: | |
1850 a_tmp[bound,:] = a_tmp_max | |
1851 quick_imshow(a_tmp, title=title, aspect='auto') | |
1852 del a_tmp | |
1853 if not input_yesno(f'Accept this {bound_name} (y/n)?', default): | |
1854 bound = select_one_image_bound(a, axis, bound_name=bound_name, title=title) | |
1855 clear_imshow(title) | |
1856 return(bound) | |
1857 | |
1858 | |
1859 class Config: | |
1860 """Base class for processing a config file or dictionary. | |
1861 """ | |
1862 def __init__(self, config_file=None, config_dict=None): | |
1863 self.config = {} | |
1864 self.load_flag = False | |
1865 self.suffix = None | |
1866 | |
1867 # Load config file | |
1868 if config_file is not None and config_dict is not None: | |
1869 logger.warning('Ignoring config_dict (both config_file and config_dict are specified)') | |
1870 if config_file is not None: | |
1871 self.load_file(config_file) | |
1872 elif config_dict is not None: | |
1873 self.load_dict(config_dict) | |
1874 | |
1875 def load_file(self, config_file): | |
1876 """Load a config file. | |
1877 """ | |
1878 if self.load_flag: | |
1879 logger.warning('Overwriting any previously loaded config file') | |
1880 self.config = {} | |
1881 | |
1882 # Ensure config file exists | |
1883 if not os.path.isfile(config_file): | |
1884 logger.error(f'Unable to load {config_file}') | |
1885 return | |
1886 | |
1887 # Load config file (for now for Galaxy, allow .dat extension) | |
1888 self.suffix = os.path.splitext(config_file)[1] | |
1889 if self.suffix == '.yml' or self.suffix == '.yaml' or self.suffix == '.dat': | |
1890 with open(config_file, 'r') as f: | |
1891 self.config = safe_load(f) | |
1892 elif self.suffix == '.txt': | |
1893 with open(config_file, 'r') as f: | |
1894 lines = f.read().splitlines() | |
1895 self.config = {item[0].strip():literal_eval(item[1].strip()) for item in | |
1896 [line.split('#')[0].split('=') for line in lines if '=' in line.split('#')[0]]} | |
1897 else: | |
1898 illegal_value(self.suffix, 'config file extension', 'Config.load_file') | |
1899 | |
1900 # Make sure config file was correctly loaded | |
1901 if isinstance(self.config, dict): | |
1902 self.load_flag = True | |
1903 else: | |
1904 logger.error(f'Unable to load dictionary from config file: {config_file}') | |
1905 self.config = {} | |
1906 | |
1907 def load_dict(self, config_dict): | |
1908 """Takes a dictionary and places it into self.config. | |
1909 """ | |
1910 if self.load_flag: | |
1911 logger.warning('Overwriting the previously loaded config file') | |
1912 | |
1913 if isinstance(config_dict, dict): | |
1914 self.config = config_dict | |
1915 self.load_flag = True | |
1916 else: | |
1917 illegal_value(config_dict, 'dictionary config object', 'Config.load_dict') | |
1918 self.config = {} | |
1919 | |
1920 def save_file(self, config_file): | |
1921 """Save the config file (as a yaml file only right now). | |
1922 """ | |
1923 suffix = os.path.splitext(config_file)[1] | |
1924 if suffix != '.yml' and suffix != '.yaml': | |
1925 illegal_value(suffix, 'config file extension', 'Config.save_file') | |
1926 | |
1927 # Check if config file exists | |
1928 if os.path.isfile(config_file): | |
1929 logger.info(f'Updating {config_file}') | |
1930 else: | |
1931 logger.info(f'Saving {config_file}') | |
1932 | |
1933 # Save config file | |
1934 with open(config_file, 'w') as f: | |
1935 safe_dump(self.config, f) | |
1936 | |
1937 def validate(self, pars_required, pars_missing=None): | |
1938 """Returns False if any required keys are missing. | |
1939 """ | |
1940 if not self.load_flag: | |
1941 logger.error('Load a config file prior to calling Config.validate') | |
1942 | |
1943 def validate_nested_pars(config, par): | |
1944 par_levels = par.split(':') | |
1945 first_level_par = par_levels[0] | |
1946 try: | |
1947 first_level_par = int(first_level_par) | |
1948 except: | |
1949 pass | |
1950 try: | |
1951 next_level_config = config[first_level_par] | |
1952 if len(par_levels) > 1: | |
1953 next_level_par = ':'.join(par_levels[1:]) | |
1954 return(validate_nested_pars(next_level_config, next_level_par)) | |
1955 else: | |
1956 return(True) | |
1957 except: | |
1958 return(False) | |
1959 | |
1960 pars_missing = [p for p in pars_required if not validate_nested_pars(self.config, p)] | |
1961 if len(pars_missing) > 0: | |
1962 logger.error(f'Missing item(s) in configuration: {", ".join(pars_missing)}') | |
1963 return(False) | |
1964 else: | |
1965 return(True) |