Mercurial > repos > rv43 > tomo
diff general.py @ 69:fba792d5f83b draft
planemo upload for repository https://github.com/rolfverberg/galaxytools commit ab9f412c362a4ab986d00e21d5185cfcf82485c1
author | rv43 |
---|---|
date | Fri, 10 Mar 2023 16:02:04 +0000 |
parents | ba5866d0251d |
children |
line wrap: on
line diff
--- a/general.py Fri Aug 19 20:16:56 2022 +0000 +++ b/general.py Fri Mar 10 16:02:04 2023 +0000 @@ -1,5 +1,8 @@ #!/usr/bin/env python3 +#FIX write a function that returns a list of peak indices for a given plot +#FIX use raise_error concept on more functions to optionally raise an error + # -*- coding: utf-8 -*- """ Created on Mon Dec 6 15:36:22 2021 @@ -8,11 +11,15 @@ """ import logging +logger=logging.getLogger(__name__) import os import sys import re -import yaml +try: + from yaml import safe_load, safe_dump +except: + pass try: import h5py except: @@ -20,314 +27,480 @@ import numpy as np try: import matplotlib.pyplot as plt + import matplotlib.lines as mlines + from matplotlib import transforms from matplotlib.widgets import Button except: pass from ast import literal_eval +try: + from asteval import Interpreter, get_ast_names +except: + pass from copy import deepcopy +try: + from sympy import diff, simplify +except: + pass from time import time -def depth_list(L): return isinstance(L, list) and max(map(depth_list, L))+1 -def depth_tuple(T): return isinstance(T, tuple) and max(map(depth_tuple, T))+1 +def depth_list(L): return(isinstance(L, list) and max(map(depth_list, L))+1) +def depth_tuple(T): return(isinstance(T, tuple) and max(map(depth_tuple, T))+1) def unwrap_tuple(T): if depth_tuple(T) > 1 and len(T) == 1: T = unwrap_tuple(*T) - return T - -def illegal_value(value, name, location=None, exit_flag=False): + return(T) + +def illegal_value(value, name, location=None, raise_error=False, log=True): if not isinstance(location, str): location = '' else: location = f'in {location} ' if isinstance(name, str): - logging.error(f'Illegal value for {name} {location}({value}, {type(value)})') + error_msg = f'Illegal value for {name} {location}({value}, {type(value)})' else: - logging.error(f'Illegal value {location}({value}, {type(value)})') - if exit_flag: - raise ValueError + error_msg = f'Illegal value {location}({value}, {type(value)})' + if log: + logger.error(error_msg) + if raise_error: + raise ValueError(error_msg) -def is_int(v, v_min=None, v_max=None): - """Value is an integer in range v_min <= v <= v_max. - """ - if not isinstance(v, int): - return False - if v_min is not None and not isinstance(v_min, int): - illegal_value(v_min, 'v_min', 'is_int') - return False - if v_max is not None and not isinstance(v_max, int): - illegal_value(v_max, 'v_max', 'is_int') - return False - if v_min is not None and v_max is not None and v_min > v_max: - logging.error(f'Illegal v_min, v_max combination ({v_min}, {v_max})') - return False - if (v_min is not None and v < v_min) or (v_max is not None and v > v_max): - return False - return True +def illegal_combination(value1, name1, value2, name2, location=None, raise_error=False, + log=True): + if not isinstance(location, str): + location = '' + else: + location = f'in {location} ' + if isinstance(name1, str): + error_msg = f'Illegal combination for {name1} and {name2} {location}'+ \ + f'({value1}, {type(value1)} and {value2}, {type(value2)})' + else: + error_msg = f'Illegal combination {location}'+ \ + f'({value1}, {type(value1)} and {value2}, {type(value2)})' + if log: + logger.error(error_msg) + if raise_error: + raise ValueError(error_msg) -def is_int_pair(v, v_min=None, v_max=None): - """Value is an integer pair, each in range v_min <= v[i] <= v_max or - v_min[i] <= v[i] <= v_max[i]. +def test_ge_gt_le_lt(ge, gt, le, lt, func, location=None, raise_error=False, log=True): + """Check individual and mutual validity of ge, gt, le, lt qualifiers + func: is_int or is_num to test for int or numbers + Return: True upon success or False when mutually exlusive """ - if not (isinstance(v, (tuple, list)) and len(v) == 2 and isinstance(v[0], int) and - isinstance(v[1], int)): - return False - if v_min is not None or v_max is not None: - if (v_min is None or isinstance(v_min, int)) and (v_max is None or isinstance(v_max, int)): - if True in [True if not is_int(vi, v_min=v_min, v_max=v_max) else False for vi in v]: - return False - elif is_int_pair(v_min) and is_int_pair(v_max): - if True in [True if v_min[i] > v_max[i] else False for i in range(2)]: - logging.error(f'Illegal v_min, v_max combination ({v_min}, {v_max})') - return False - if True in [True if not is_int(v[i], v_min[i], v_max[i]) else False for i in range(2)]: - return False - elif is_int_pair(v_min) and (v_max is None or isinstance(v_max, int)): - if True in [True if not is_int(v[i], v_min=v_min[i], v_max=v_max) else False - for i in range(2)]: - return False - elif (v_min is None or isinstance(v_min, int)) and is_int_pair(v_max): - if True in [True if not is_int(v[i], v_min=v_min, v_max=v_max[i]) else False - for i in range(2)]: - return False + if ge is None and gt is None and le is None and lt is None: + return(True) + if ge is not None: + if not func(ge): + illegal_value(ge, 'ge', location, raise_error, log) + return(False) + if gt is not None: + illegal_combination(ge, 'ge', gt, 'gt', location, raise_error, log) + return(False) + elif gt is not None and not func(gt): + illegal_value(gt, 'gt', location, raise_error, log) + return(False) + if le is not None: + if not func(le): + illegal_value(le, 'le', location, raise_error, log) + return(False) + if lt is not None: + illegal_combination(le, 'le', lt, 'lt', location, raise_error, log) + return(False) + elif lt is not None and not func(lt): + illegal_value(lt, 'lt', location, raise_error, log) + return(False) + if ge is not None: + if le is not None and ge > le: + illegal_combination(ge, 'ge', le, 'le', location, raise_error, log) + return(False) + elif lt is not None and ge >= lt: + illegal_combination(ge, 'ge', lt, 'lt', location, raise_error, log) + return(False) + elif gt is not None: + if le is not None and gt >= le: + illegal_combination(gt, 'gt', le, 'le', location, raise_error, log) + return(False) + elif lt is not None and gt >= lt: + illegal_combination(gt, 'gt', lt, 'lt', location, raise_error, log) + return(False) + return(True) + +def range_string_ge_gt_le_lt(ge=None, gt=None, le=None, lt=None): + """Return a range string representation matching the ge, gt, le, lt qualifiers + Does not validate the inputs, do that as needed before calling + """ + range_string = '' + if ge is not None: + if le is None and lt is None: + range_string += f'>= {ge}' else: - logging.error(f'Illegal v_min or v_max input ({v_min} {type(v_min)} and '+ - f'{v_max} {type(v_max)})') - return False - return True + range_string += f'[{ge}, ' + elif gt is not None: + if le is None and lt is None: + range_string += f'> {gt}' + else: + range_string += f'({gt}, ' + if le is not None: + if ge is None and gt is None: + range_string += f'<= {le}' + else: + range_string += f'{le}]' + elif lt is not None: + if ge is None and gt is None: + range_string += f'< {lt}' + else: + range_string += f'{lt})' + return(range_string) -def is_int_series(l, v_min=None, v_max=None): - """Value is a tuple or list of integers, each in range v_min <= l[i] <= v_max. +def is_int(v, ge=None, gt=None, le=None, lt=None, raise_error=False, log=True): + """Value is an integer in range ge <= v <= le or gt < v < lt or some combination. + Return: True if yes or False is no """ - if v_min is not None and not isinstance(v_min, int): - illegal_value(v_min, 'v_min', 'is_int_series') - return False - if v_max is not None and not isinstance(v_max, int): - illegal_value(v_max, 'v_max', 'is_int_series') - return False - if not isinstance(l, (tuple, list)): - return False - if True in [True if not is_int(v, v_min=v_min, v_max=v_max) else False for v in l]: - return False - return True + return(_is_int_or_num(v, 'int', ge, gt, le, lt, raise_error, log)) -def is_num(v, v_min=None, v_max=None): - """Value is a number in range v_min <= v <= v_max. +def is_num(v, ge=None, gt=None, le=None, lt=None, raise_error=False, log=True): + """Value is a number in range ge <= v <= le or gt < v < lt or some combination. + Return: True if yes or False is no """ - if not isinstance(v, (int, float)): - return False - if v_min is not None and not isinstance(v_min, (int, float)): - illegal_value(v_min, 'v_min', 'is_num') - return False - if v_max is not None and not isinstance(v_max, (int, float)): - illegal_value(v_max, 'v_max', 'is_num') - return False - if v_min is not None and v_max is not None and v_min > v_max: - logging.error(f'Illegal v_min, v_max combination ({v_min}, {v_max})') - return False - if (v_min is not None and v < v_min) or (v_max is not None and v > v_max): - return False - return True + return(_is_int_or_num(v, 'num', ge, gt, le, lt, raise_error, log)) -def is_num_pair(v, v_min=None, v_max=None): - """Value is a number pair, each in range v_min <= v[i] <= v_max or - v_min[i] <= v[i] <= v_max[i]. +def _is_int_or_num(v, type_str, ge=None, gt=None, le=None, lt=None, raise_error=False, + log=True): + if type_str == 'int': + if not isinstance(v, int): + illegal_value(v, 'v', '_is_int_or_num', raise_error, log) + return(False) + if not test_ge_gt_le_lt(ge, gt, le, lt, is_int, '_is_int_or_num', raise_error, log): + return(False) + elif type_str == 'num': + if not isinstance(v, (int, float)): + illegal_value(v, 'v', '_is_int_or_num', raise_error, log) + return(False) + if not test_ge_gt_le_lt(ge, gt, le, lt, is_num, '_is_int_or_num', raise_error, log): + return(False) + else: + illegal_value(type_str, 'type_str', '_is_int_or_num', raise_error, log) + return(False) + if ge is None and gt is None and le is None and lt is None: + return(True) + error = False + if ge is not None and v < ge: + error = True + error_msg = f'Value {v} out of range: {v} !>= {ge}' + if not error and gt is not None and v <= gt: + error = True + error_msg = f'Value {v} out of range: {v} !> {gt}' + if not error and le is not None and v > le: + error = True + error_msg = f'Value {v} out of range: {v} !<= {le}' + if not error and lt is not None and v >= lt: + error = True + error_msg = f'Value {v} out of range: {v} !< {lt}' + if error: + if log: + logger.error(error_msg) + if raise_error: + raise ValueError(error_msg) + return(False) + return(True) + +def is_int_pair(v, ge=None, gt=None, le=None, lt=None, raise_error=False, log=True): + """Value is an integer pair, each in range ge <= v[i] <= le or gt < v[i] < lt or + ge[i] <= v[i] <= le[i] or gt[i] < v[i] < lt[i] or some combination. + Return: True if yes or False is no + """ + return(_is_int_or_num_pair(v, 'int', ge, gt, le, lt, raise_error, log)) + +def is_num_pair(v, ge=None, gt=None, le=None, lt=None, raise_error=False, log=True): + """Value is a number pair, each in range ge <= v[i] <= le or gt < v[i] < lt or + ge[i] <= v[i] <= le[i] or gt[i] < v[i] < lt[i] or some combination. + Return: True if yes or False is no """ - if not (isinstance(v, (tuple, list)) and len(v) == 2 and isinstance(v[0], (int, float)) and - isinstance(v[1], (int, float))): - return False - if v_min is not None or v_max is not None: - if ((v_min is None or isinstance(v_min, (int, float))) and - (v_max is None or isinstance(v_max, (int, float)))): - if True in [True if not is_num(vi, v_min=v_min, v_max=v_max) else False for vi in v]: - return False - elif is_num_pair(v_min) and is_num_pair(v_max): - if True in [True if v_min[i] > v_max[i] else False for i in range(2)]: - logging.error(f'Illegal v_min, v_max combination ({v_min}, {v_max})') - return False - if True in [True if not is_num(v[i], v_min[i], v_max[i]) else False for i in range(2)]: - return False - elif is_num_pair(v_min) and (v_max is None or isinstance(v_max, (int, float))): - if True in [True if not is_num(v[i], v_min=v_min[i], v_max=v_max) else False - for i in range(2)]: - return False - elif (v_min is None or isinstance(v_min, (int, float))) and is_num_pair(v_max): - if True in [True if not is_num(v[i], v_min=v_min, v_max=v_max[i]) else False - for i in range(2)]: - return False + return(_is_int_or_num_pair(v, 'num', ge, gt, le, lt, raise_error, log)) + +def _is_int_or_num_pair(v, type_str, ge=None, gt=None, le=None, lt=None, raise_error=False, + log=True): + if type_str == 'int': + if not (isinstance(v, (tuple, list)) and len(v) == 2 and isinstance(v[0], int) and + isinstance(v[1], int)): + illegal_value(v, 'v', '_is_int_or_num_pair', raise_error, log) + return(False) + func = is_int + elif type_str == 'num': + if not (isinstance(v, (tuple, list)) and len(v) == 2 and isinstance(v[0], (int, float)) and + isinstance(v[1], (int, float))): + illegal_value(v, 'v', '_is_int_or_num_pair', raise_error, log) + return(False) + func = is_num + else: + illegal_value(type_str, 'type_str', '_is_int_or_num_pair', raise_error, log) + return(False) + if ge is None and gt is None and le is None and lt is None: + return(True) + if ge is None or func(ge, log=True): + ge = 2*[ge] + elif not _is_int_or_num_pair(ge, type_str, raise_error=raise_error, log=log): + return(False) + if gt is None or func(gt, log=True): + gt = 2*[gt] + elif not _is_int_or_num_pair(gt, type_str, raise_error=raise_error, log=log): + return(False) + if le is None or func(le, log=True): + le = 2*[le] + elif not _is_int_or_num_pair(le, type_str, raise_error=raise_error, log=log): + return(False) + if lt is None or func(lt, log=True): + lt = 2*[lt] + elif not _is_int_or_num_pair(lt, type_str, raise_error=raise_error, log=log): + return(False) + if (not func(v[0], ge[0], gt[0], le[0], lt[0], raise_error, log) or + not func(v[1], ge[1], gt[1], le[1], lt[1], raise_error, log)): + return(False) + return(True) + +def is_int_series(l, ge=None, gt=None, le=None, lt=None, raise_error=False, log=True): + """Value is a tuple or list of integers, each in range ge <= l[i] <= le or + gt < l[i] < lt or some combination. + """ + if not test_ge_gt_le_lt(ge, gt, le, lt, is_int, 'is_int_series', raise_error, log): + return(False) + if not isinstance(l, (tuple, list)): + illegal_value(l, 'l', 'is_int_series', raise_error, log) + return(False) + if any(True if not is_int(v, ge, gt, le, lt, raise_error, log) else False for v in l): + return(False) + return(True) + +def is_num_series(l, ge=None, gt=None, le=None, lt=None, raise_error=False, log=True): + """Value is a tuple or list of numbers, each in range ge <= l[i] <= le or + gt < l[i] < lt or some combination. + """ + if not test_ge_gt_le_lt(ge, gt, le, lt, is_int, 'is_int_series', raise_error, log): + return(False) + if not isinstance(l, (tuple, list)): + illegal_value(l, 'l', 'is_num_series', raise_error, log) + return(False) + if any(True if not is_num(v, ge, gt, le, lt, raise_error, log) else False for v in l): + return(False) + return(True) + +def is_str_series(l, raise_error=False, log=True): + """Value is a tuple or list of strings. + """ + if (not isinstance(l, (tuple, list)) or + any(True if not isinstance(s, str) else False for s in l)): + illegal_value(l, 'l', 'is_str_series', raise_error, log) + return(False) + return(True) + +def is_dict_series(l, raise_error=False, log=True): + """Value is a tuple or list of dictionaries. + """ + if (not isinstance(l, (tuple, list)) or + any(True if not isinstance(d, dict) else False for d in l)): + illegal_value(l, 'l', 'is_dict_series', raise_error, log) + return(False) + return(True) + +def is_dict_nums(l, raise_error=False, log=True): + """Value is a dictionary with single number values + """ + if (not isinstance(l, dict) or + any(True if not is_num(v, log=False) else False for v in l.values())): + illegal_value(l, 'l', 'is_dict_nums', raise_error, log) + return(False) + return(True) + +def is_dict_strings(l, raise_error=False, log=True): + """Value is a dictionary with single string values + """ + if (not isinstance(l, dict) or + any(True if not isinstance(v, str) else False for v in l.values())): + illegal_value(l, 'l', 'is_dict_strings', raise_error, log) + return(False) + return(True) + +def is_index(v, ge=0, lt=None, raise_error=False, log=True): + """Value is an array index in range ge <= v < lt. + NOTE lt IS NOT included! + """ + if isinstance(lt, int): + if lt <= ge: + illegal_combination(ge, 'ge', lt, 'lt', 'is_index', raise_error, log) + return(False) + return(is_int(v, ge=ge, lt=lt, raise_error=raise_error, log=log)) + +def is_index_range(v, ge=0, le=None, lt=None, raise_error=False, log=True): + """Value is an array index range in range ge <= v[0] <= v[1] <= le or ge <= v[0] <= v[1] < lt. + NOTE le IS included! + """ + if not is_int_pair(v, raise_error=raise_error, log=log): + return(False) + if not test_ge_gt_le_lt(ge, None, le, lt, is_int, 'is_index_range', raise_error, log): + return(False) + if not ge <= v[0] <= v[1] or (le is not None and v[1] > le) or (lt is not None and v[1] >= lt): + if le is not None: + error_msg = f'Value {v} out of range: !({ge} <= {v[0]} <= {v[1]} <= {le})' else: - logging.error(f'Illegal v_min or v_max input ({v_min} {type(v_min)} and '+ - f'{v_max} {type(v_max)})') - return False - return True - -def is_num_series(l, v_min=None, v_max=None): - """Value is a tuple or list of numbers, each in range v_min <= l[i] <= v_max. - """ - if v_min is not None and not isinstance(v_min, (int, float)): - illegal_value(v_min, 'v_min', 'is_num_series') - return False - if v_max is not None and not isinstance(v_max, (int, float)): - illegal_value(v_max, 'v_max', 'is_num_series') - return False - if not isinstance(l, (tuple, list)): - return False - if True in [True if not is_num(v, v_min=v_min, v_max=v_max) else False for v in l]: - return False - return True - -def is_index(v, v_min=0, v_max=None): - """Value is an array index in range v_min <= v < v_max. - NOTE v_max IS NOT included! - """ - if isinstance(v_max, int): - if v_max <= v_min: - logging.error(f'Illegal v_min, v_max combination ({v_min}, {v_max})') - return False - v_max -= 1 - return is_int(v, v_min, v_max) - -def is_index_range(v, v_min=0, v_max=None): - """Value is an array index range in range v_min <= v[0] <= v[1] <= v_max. - NOTE v_max IS included! - """ - if not is_int_pair(v): - return False - if not isinstance(v_min, int): - illegal_value(v_min, 'v_min', 'is_index_range') - return False - if v_max is not None: - if not isinstance(v_max, int): - illegal_value(v_max, 'v_max', 'is_index_range') - return False - if v_max < v_min: - logging.error(f'Illegal v_min, v_max combination ({v_min}, {v_max})') - return False - if not v_min <= v[0] <= v[1] or (v_max is not None and v[1] > v_max): - return False - return True + error_msg = f'Value {v} out of range: !({ge} <= {v[0]} <= {v[1]} < {lt})' + if log: + logger.error(error_msg) + if raise_error: + raise ValueError(error_msg) + return(False) + return(True) def index_nearest(a, value): a = np.asarray(a) if a.ndim > 1: - logging.warning(f'Illegal input array ({a}, {type(a)})') + raise ValueError(f'Invalid array dimension for parameter a ({a.ndim}, {a})') # Round up for .5 value *= 1.0+sys.float_info.epsilon - return (int)(np.argmin(np.abs(a-value))) + return((int)(np.argmin(np.abs(a-value)))) def index_nearest_low(a, value): a = np.asarray(a) if a.ndim > 1: - logging.warning(f'Illegal input array ({a}, {type(a)})') + raise ValueError(f'Invalid array dimension for parameter a ({a.ndim}, {a})') index = int(np.argmin(np.abs(a-value))) if value < a[index] and index > 0: index -= 1 - return index + return(index) def index_nearest_upp(a, value): a = np.asarray(a) if a.ndim > 1: - logging.warning(f'Illegal input array ({a}, {type(a)})') + raise ValueError(f'Invalid array dimension for parameter a ({a.ndim}, {a})') index = int(np.argmin(np.abs(a-value))) if value > a[index] and index < a.size-1: index += 1 - return index + return(index) def round_to_n(x, n=1): if x == 0.0: - return 0 + return(0) else: - return round(x, n-1-int(np.floor(np.log10(abs(x))))) + return(type(x)(round(x, n-1-int(np.floor(np.log10(abs(x))))))) def round_up_to_n(x, n=1): xr = round_to_n(x, n) if abs(x/xr) > 1.0: xr += np.sign(x)*10**(np.floor(np.log10(abs(x)))+1-n) - return xr + return(type(x)(xr)) def trunc_to_n(x, n=1): xr = round_to_n(x, n) if abs(xr/x) > 1.0: xr -= np.sign(x)*10**(np.floor(np.log10(abs(x)))+1-n) - return xr + return(type(x)(xr)) -def string_to_list(s): +def almost_equal(a, b, sig_figs): + if is_num(a) and is_num(b): + return(abs(round_to_n(a-b, sig_figs)) < pow(10, -sig_figs+1)) + else: + raise ValueError(f'Invalid value for a or b in almost_equal (a: {a}, {type(a)}, '+ + f'b: {b}, {type(b)})') + return(False) + +def string_to_list(s, split_on_dash=True, remove_duplicates=True, sort=True): """Return a list of numbers by splitting/expanding a string on any combination of - dashes, commas, and/or whitespaces - e.g: '1, 3, 5-8,12 ' -> [1, 3, 5, 6, 7, 8, 12] + commas, whitespaces, or dashes (when split_on_dash=True) + e.g: '1, 3, 5-8, 12 ' -> [1, 3, 5, 6, 7, 8, 12] """ if not isinstance(s, str): illegal_value(s, location='string_to_list') - return None + return(None) if not len(s): - return [] - try: - list1 = [x for x in re.split('\s+,\s+|\s+,|,\s+|\s+|,', s.strip())] - except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError): - return None + return([]) try: - l = [] - for l1 in list1: - l2 = [literal_eval(x) for x in re.split('\s+-\s+|\s+-|-\s+|\s+|-', l1)] - if len(l2) == 1: - l += l2 - elif len(l2) == 2 and l2[1] > l2[0]: - l += [i for i in range(l2[0], l2[1]+1)] - else: - raise ValueError + ll = [x for x in re.split('\s+,\s+|\s+,|,\s+|\s+|,', s.strip())] except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError): - return None - return sorted(set(l)) + return(None) + if split_on_dash: + try: + l = [] + for l1 in ll: + l2 = [literal_eval(x) for x in re.split('\s+-\s+|\s+-|-\s+|\s+|-', l1)] + if len(l2) == 1: + l += l2 + elif len(l2) == 2 and l2[1] > l2[0]: + l += [i for i in range(l2[0], l2[1]+1)] + else: + raise ValueError + except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError): + return(None) + else: + l = [literal_eval(x) for x in ll] + if remove_duplicates: + l = list(dict.fromkeys(l)) + if sort: + l = sorted(l) + return(l) def get_trailing_int(string): indexRegex = re.compile(r'\d+$') mo = indexRegex.search(string) if mo is None: - return None + return(None) else: - return int(mo.group()) + return(int(mo.group())) + +def input_int(s=None, ge=None, gt=None, le=None, lt=None, default=None, inset=None, + raise_error=False, log=True): + return(_input_int_or_num('int', s, ge, gt, le, lt, default, inset, raise_error, log)) + +def input_num(s=None, ge=None, gt=None, le=None, lt=None, default=None, raise_error=False, + log=True): + return(_input_int_or_num('num', s, ge, gt, le, lt, default, None, raise_error,log)) -def input_int(s=None, v_min=None, v_max=None, default=None, inset=None): +def _input_int_or_num(type_str, s=None, ge=None, gt=None, le=None, lt=None, default=None, + inset=None, raise_error=False, log=True): + if type_str == 'int': + if not test_ge_gt_le_lt(ge, gt, le, lt, is_int, '_input_int_or_num', raise_error, log): + return(None) + elif type_str == 'num': + if not test_ge_gt_le_lt(ge, gt, le, lt, is_num, '_input_int_or_num', raise_error, log): + return(None) + else: + illegal_value(type_str, 'type_str', '_input_int_or_num', raise_error, log) + return(None) if default is not None: - if not isinstance(default, int): - illegal_value(default, 'default', 'input_int') - return None + if not _is_int_or_num(default, type_str, raise_error=raise_error, log=log): + return(None) + if ge is not None and default < ge: + illegal_combination(ge, 'ge', default, 'default', '_input_int_or_num', raise_error, + log) + return(None) + if gt is not None and default <= gt: + illegal_combination(gt, 'gt', default, 'default', '_input_int_or_num', raise_error, + log) + return(None) + if le is not None and default > le: + illegal_combination(le, 'le', default, 'default', '_input_int_or_num', raise_error, + log) + return(None) + if lt is not None and default >= lt: + illegal_combination(lt, 'lt', default, 'default', '_input_int_or_num', raise_error, + log) + return(None) default_string = f' [{default}]' else: default_string = '' - if v_min is not None: - if not isinstance(v_min, int): - illegal_value(v_min, 'v_min', 'input_int') - return None - if default is not None and default < v_min: - logging.error('Illegal v_min, default combination ({v_min}, {default})') - return None - if v_max is not None: - if not isinstance(v_max, int): - illegal_value(v_max, 'v_max', 'input_int') - return None - if v_min is not None and v_min > v_max: - logging.error(f'Illegal v_min, v_max combination ({v_min}, {v_max})') - return None - if default is not None and default > v_max: - logging.error('Illegal default, v_max combination ({default}, {v_max})') - return None if inset is not None: - if (not isinstance(inset, (tuple, list)) or False in [True if isinstance(i, int) else - False for i in inset]): - illegal_value(inset, 'inset', 'input_int') - return None - if v_min is not None and v_max is not None: - v_range = f' ({v_min}, {v_max})' - elif v_min is not None: - v_range = f' (>= {v_min})' - elif v_max is not None: - v_range = f' (<= {v_max})' - else: - v_range = '' + if (not isinstance(inset, (tuple, list)) or any(True if not isinstance(i, int) else + False for i in inset)): + illegal_value(inset, 'inset', '_input_int_or_num', raise_error, log) + return(None) + v_range = f'{range_string_ge_gt_le_lt(ge, gt, le, lt)}' + if len(v_range): + v_range = f' {v_range}' if s is None: - print(f'Enter an integer{v_range}{default_string}: ') + if type_str == 'int': + print(f'Enter an integer{v_range}{default_string}: ') + else: + print(f'Enter a number{v_range}{default_string}: ') else: print(f'{s}{v_range}{default_string}: ') try: @@ -342,116 +515,90 @@ except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError): v = None except: - print('Unexpected error') - raise - if not is_int(v, v_min, v_max): - print('Illegal input, enter a valid integer') - v = input_int(s, v_min, v_max, default) - return v + if log: + logger.error('Unexpected error') + if raise_error: + raise ValueError('Unexpected error') + if not _is_int_or_num(v, type_str, ge, gt, le, lt): + v = _input_int_or_num(type_str, s, ge, gt, le, lt, default, inset, raise_error, log) + return(v) + +def input_int_list(s=None, ge=None, le=None, split_on_dash=True, remove_duplicates=True, + sort=True, raise_error=False, log=True): + """Prompt the user to input a list of interger and split the entered string on any combination + of commas, whitespaces, or dashes (when split_on_dash is True) + e.g: '1 3,5-8 , 12 ' -> [1, 3, 5, 6, 7, 8, 12] + remove_duplicates: removes duplicates if True (may also change the order) + sort: sort in ascending order if True + return None upon an illegal input + """ + return(_input_int_or_num_list('int', s, ge, le, split_on_dash, remove_duplicates, sort, + raise_error, log)) -def input_num(s=None, v_min=None, v_max=None, default=None): - if default is not None: - if not isinstance(default, (int, float)): - illegal_value(default, 'default', 'input_num') - return None - default_string = f' [{default}]' - else: - default_string = '' - if v_min is not None: - if not isinstance(v_min, (int, float)): - illegal_value(vmin, 'vmin', 'input_num') - return None - if default is not None and default < v_min: - logging.error('Illegal v_min, default combination ({v_min}, {default})') - return None - if v_max is not None: - if not isinstance(v_max, (int, float)): - illegal_value(vmax, 'vmax', 'input_num') - return None - if v_min is not None and v_max < v_min: - logging.error(f'Illegal v_min, v_max combination ({v_min}, {v_max})') - return None - if default is not None and default > v_max: - logging.error('Illegal default, v_max combination ({default}, {v_max})') - return None - if v_min is not None and v_max is not None: - v_range = f' ({v_min}, {v_max})' - elif v_min is not None: - v_range = f' (>= {v_min})' - elif v_max is not None: - v_range = f' (<= {v_max})' +def input_num_list(s=None, ge=None, le=None, remove_duplicates=True, sort=True, raise_error=False, + log=True): + """Prompt the user to input a list of numbers and split the entered string on any combination + of commas or whitespaces + e.g: '1.0, 3, 5.8, 12 ' -> [1.0, 3.0, 5.8, 12.0] + remove_duplicates: removes duplicates if True (may also change the order) + sort: sort in ascending order if True + return None upon an illegal input + """ + return(_input_int_or_num_list('num', s, ge, le, False, remove_duplicates, sort, raise_error, + log)) + +def _input_int_or_num_list(type_str, s=None, ge=None, le=None, split_on_dash=True, + remove_duplicates=True, sort=True, raise_error=False, log=True): + #FIX do we want a limit on max dimension? + if type_str == 'int': + if not test_ge_gt_le_lt(ge, None, le, None, is_int, 'input_int_or_num_list', raise_error, + log): + return(None) + elif type_str == 'num': + if not test_ge_gt_le_lt(ge, None, le, None, is_num, 'input_int_or_num_list', raise_error, + log): + return(None) else: - v_range = '' - if s is None: - print(f'Enter a number{v_range}{default_string}: ') - else: - print(f'{s}{v_range}{default_string}: ') - try: - i = input() - if isinstance(i, str) and not len(i): - v = default - print(f'{v}') - else: - v = literal_eval(i) - except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError): - v = None - except: - print('Unexpected error') - raise - if not is_num(v, v_min, v_max): - print('Illegal input, enter a valid number') - v = input_num(s, v_min, v_max, default) - return v - -def input_int_list(s=None, v_min=None, v_max=None): - if v_min is not None and not isinstance(v_min, int): - illegal_value(vmin, 'vmin', 'input_int_list') - return None - if v_max is not None: - if not isinstance(v_max, int): - illegal_value(vmax, 'vmax', 'input_int_list') - return None - if v_max < v_min: - logging.error(f'Illegal v_min, v_max combination ({v_min}, {v_max})') - return None - if v_min is not None and v_max is not None: - v_range = f' (each value in ({v_min}, {v_max}))' - elif v_min is not None: - v_range = f' (each value >= {v_min})' - elif v_max is not None: - v_range = f' (each value <= {v_max})' - else: - v_range = '' + illegal_value(type_str, 'type_str', '_input_int_or_num_list') + return(None) + v_range = f'{range_string_ge_gt_le_lt(ge=ge, le=le)}' + if len(v_range): + v_range = f' (each value in {v_range})' if s is None: print(f'Enter a series of integers{v_range}: ') else: print(f'{s}{v_range}: ') try: - l = string_to_list(input()) + l = string_to_list(input(), split_on_dash, remove_duplicates, sort) except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError): l = None except: print('Unexpected error') raise if (not isinstance(l, list) or - True in [True if not is_int(v, v_min, v_max) else False for v in l]): - print('Illegal input: enter a valid set of dash/comma/whitespace separated integers '+ - 'e.g. 2,3,5-8,10') - l = input_int_list(s, v_min, v_max) - return l + any(True if not _is_int_or_num(v, type_str, ge=ge, le=le) else False for v in l)): + if split_on_dash: + print('Invalid input: enter a valid set of dash/comma/whitespace separated integers '+ + 'e.g. 1 3,5-8 , 12') + else: + print('Invalid input: enter a valid set of comma/whitespace separated integers '+ + 'e.g. 1 3,5 8 , 12') + l = _input_int_or_num_list(type_str, s, ge, le, split_on_dash, remove_duplicates, sort, + raise_error, log) + return(l) def input_yesno(s=None, default=None): if default is not None: if not isinstance(default, str): illegal_value(default, 'default', 'input_yesno') - return None + return(None) if default.lower() in 'yes': default = 'y' elif default.lower() in 'no': default = 'n' else: illegal_value(default, 'default', 'input_yesno') - return None + return(None) default_string = f' [{default}]' else: default_string = '' @@ -468,19 +615,19 @@ elif i is not None and i.lower() in 'no': v = False else: - print('Illegal input, enter yes or no') + print('Invalid input, enter yes or no') v = input_yesno(s, default) - return v + return(v) def input_menu(items, default=None, header=None): - if not isinstance(items, (tuple, list)) or False in [True if isinstance(i, str) else False - for i in items]: + if not isinstance(items, (tuple, list)) or any(True if not isinstance(i, str) else False + for i in items): illegal_value(items, 'items', 'input_menu') - return None + return(None) if default is not None: if not (isinstance(default, str) and default in items): - logging.error(f'Illegal value for default ({default}), must be in {items}') - return None + logger.error(f'Invalid value for default ({default}), must be in {items}') + return(None) default_string = f' [{items.index(default)+1}]' else: default_string = '' @@ -507,38 +654,283 @@ print('Unexpected error') raise if choice is None: - print(f'Illegal choice, enter a number between 1 and {len(items)}') + print(f'Invalid choice, enter a number between 1 and {len(items)}') choice = input_menu(items, default) - return choice + return(choice) + +def assert_no_duplicates_in_list_of_dicts(l: list, raise_error=False) -> list: + if not isinstance(l, list): + illegal_value(l, 'l', 'assert_no_duplicates_in_list_of_dicts', raise_error) + return(None) + if any(True if not isinstance(d, dict) else False for d in l): + illegal_value(l, 'l', 'assert_no_duplicates_in_list_of_dicts', raise_error) + return(None) + if len(l) != len([dict(t) for t in {tuple(sorted(d.items())) for d in l}]): + if raise_error: + raise ValueError(f'Duplicate items found in {l}') + else: + logger.error(f'Duplicate items found in {l}') + return(None) + else: + return(l) -def create_mask(x, bounds=None, reverse_mask=False, current_mask=None): +def assert_no_duplicate_key_in_list_of_dicts(l: list, key: str, raise_error=False) -> list: + if not isinstance(key, str): + illegal_value(key, 'key', 'assert_no_duplicate_key_in_list_of_dicts', raise_error) + return(None) + if not isinstance(l, list): + illegal_value(l, 'l', 'assert_no_duplicate_key_in_list_of_dicts', raise_error) + return(None) + if any(True if not isinstance(d, dict) else False for d in l): + illegal_value(l, 'l', 'assert_no_duplicates_in_list_of_dicts', raise_error) + return(None) + keys = [d.get(key, None) for d in l] + if None in keys or len(set(keys)) != len(l): + if raise_error: + raise ValueError(f'Duplicate or missing key ({key}) found in {l}') + else: + logger.error(f'Duplicate or missing key ({key}) found in {l}') + return(None) + else: + return(l) + +def assert_no_duplicate_attr_in_list_of_objs(l: list, attr: str, raise_error=False) -> list: + if not isinstance(attr, str): + illegal_value(attr, 'attr', 'assert_no_duplicate_attr_in_list_of_objs', raise_error) + return(None) + if not isinstance(l, list): + illegal_value(l, 'l', 'assert_no_duplicate_key_in_list_of_objs', raise_error) + return(None) + attrs = [getattr(obj, attr, None) for obj in l] + if None in attrs or len(set(attrs)) != len(l): + if raise_error: + raise ValueError(f'Duplicate or missing attr ({attr}) found in {l}') + else: + logger.error(f'Duplicate or missing attr ({attr}) found in {l}') + return(None) + else: + return(l) + +def file_exists_and_readable(path): + if not os.path.isfile(path): + raise ValueError(f'{path} is not a valid file') + elif not os.access(path, os.R_OK): + raise ValueError(f'{path} is not accessible for reading') + else: + return(path) + +def create_mask(x, bounds=None, exclude_bounds=False, current_mask=None): # bounds is a pair of number in the same units a x if not isinstance(x, (tuple, list, np.ndarray)) or not len(x): - logging.warning(f'Illegal input array ({x}, {type(x)})') - return None + logger.warning(f'Invalid input array ({x}, {type(x)})') + return(None) if bounds is not None and not is_num_pair(bounds): - logging.warning(f'Illegal bounds parameter ({bounds} {type(bounds)}, input ignored') + logger.warning(f'Invalid bounds parameter ({bounds} {type(bounds)}, input ignored') bounds = None if bounds is not None: - if not reverse_mask: + if exclude_bounds: + mask = np.logical_or(x < min(bounds), x > max(bounds)) + else: mask = np.logical_and(x > min(bounds), x < max(bounds)) - else: - mask = np.logical_or(x < min(bounds), x > max(bounds)) else: mask = np.ones(len(x), dtype=bool) if current_mask is not None: if not isinstance(current_mask, (tuple, list, np.ndarray)) or len(current_mask) != len(x): - logging.warning(f'Illegal current_mask ({current_mask}, {type(current_mask)}), '+ + logger.warning(f'Invalid current_mask ({current_mask}, {type(current_mask)}), '+ 'input ignored') else: - mask = np.logical_and(mask, current_mask) + mask = np.logical_or(mask, current_mask) if not True in mask: - logging.warning('Entire data array is masked') - return mask + logger.warning('Entire data array is masked') + return(mask) + +def eval_expr(name, expr, expr_variables, user_variables=None, max_depth=10, raise_error=False, + log=True, **kwargs): + """Evaluate an expression of expressions + """ + if not isinstance(name, str): + illegal_value(name, 'name', 'eval_expr', raise_error, log) + return(None) + if not isinstance(expr, str): + illegal_value(expr, 'expr', 'eval_expr', raise_error, log) + return(None) + if not is_dict_strings(expr_variables, log=False): + illegal_value(expr_variables, 'expr_variables', 'eval_expr', raise_error, log) + return(None) + if user_variables is not None and not is_dict_nums(user_variables, log=False): + illegal_value(user_variables, 'user_variables', 'eval_expr', raise_error, log) + return(None) + if not is_int(max_depth, gt=1, log=False): + illegal_value(max_depth, 'max_depth', 'eval_expr', raise_error, log) + return(None) + if not isinstance(raise_error, bool): + illegal_value(raise_error, 'raise_error', 'eval_expr', raise_error, log) + return(None) + if not isinstance(log, bool): + illegal_value(log, 'log', 'eval_expr', raise_error, log) + return(None) +# print(f'\nEvaluate the full expression for {expr}') + if 'chain' in kwargs: + chain = kwargs.pop('chain') + if not is_str_series(chain): + illegal_value(chain, 'chain', 'eval_expr', raise_error, log) + return(None) + else: + chain = [] + if len(chain) > max_depth: + error_msg = 'Exceeded maximum depth ({max_depth}) in eval_expr' + if log: + logger.error(error_msg) + if raise_error: + raise ValueError(error_msg) + return(None) + if name not in chain: + chain.append(name) +# print(f'start: chain = {chain}') + if 'ast' in kwargs: + ast = kwargs.pop('ast') + else: + ast = Interpreter() + if user_variables is not None: + ast.symtable.update(user_variables) + chain_vars = [var for var in get_ast_names(ast.parse(expr)) + if var in expr_variables and var not in ast.symtable] +# print(f'chain_vars: {chain_vars}') + save_chain = chain.copy() + for var in chain_vars: +# print(f'\n\tname = {name}, var = {var}:\n\t\t{expr_variables[var]}') +# print(f'\tchain = {chain}') + if var in chain: + error_msg = f'Circular variable {var} in eval_expr' + if log: + logger.error(error_msg) + if raise_error: + raise ValueError(error_msg) + return(None) +# print(f'\tknown symbols:\n\t\t{ast.user_defined_symbols()}\n') + if var in ast.user_defined_symbols(): + val = ast.symtable[var] + else: + #val = eval_expr(var, expr_variables[var], expr_variables, user_variables=user_variables, + val = eval_expr(var, expr_variables[var], expr_variables, max_depth=max_depth, + raise_error=raise_error, log=log, chain=chain, ast=ast) + if val is None: + return(None) + ast.symtable[var] = val +# print(f'\tval = {val}') +# print(f'\t{var} = {ast.symtable[var]}') + chain = save_chain.copy() +# print(f'\treset loop for {var}: chain = {chain}') + val = ast.eval(expr) +# print(f'return val for {expr} = {val}\n') + return(val) + +def full_gradient(expr, x, expr_name=None, expr_variables=None, valid_variables=None, max_depth=10, + raise_error=False, log=True, **kwargs): + """Compute the full gradient dexpr/dx + """ + if not isinstance(x, str): + illegal_value(x, 'x', 'full_gradient', raise_error, log) + return(None) + if expr_name is not None and not isinstance(expr_name, str): + illegal_value(expr_name, 'expr_name', 'eval_expr', raise_error, log) + return(None) + if expr_variables is not None and not is_dict_strings(expr_variables, log=False): + illegal_value(expr_variables, 'expr_variables', 'full_gradient', raise_error, log) + return(None) + if valid_variables is not None and not is_str_series(valid_variables, log=False): + illegal_value(valid_variables, 'valid_variables', 'full_gradient', raise_error, log) + if not is_int(max_depth, gt=1, log=False): + illegal_value(max_depth, 'max_depth', 'eval_expr', raise_error, log) + return(None) + if not isinstance(raise_error, bool): + illegal_value(raise_error, 'raise_error', 'eval_expr', raise_error, log) + return(None) + if not isinstance(log, bool): + illegal_value(log, 'log', 'eval_expr', raise_error, log) + return(None) +# print(f'\nGet full gradient of {expr_name} = {expr} with respect to {x}') + if expr_name is not None and expr_name == x: + return(1.0) + if 'chain' in kwargs: + chain = kwargs.pop('chain') + if not is_str_series(chain): + illegal_value(chain, 'chain', 'eval_expr', raise_error, log) + return(None) + else: + chain = [] + if len(chain) > max_depth: + error_msg = 'Exceeded maximum depth ({max_depth}) in eval_expr' + if log: + logger.error(error_msg) + if raise_error: + raise ValueError(error_msg) + return(None) + if expr_name is not None and expr_name not in chain: + chain.append(expr_name) +# print(f'start ({x}): chain = {chain}') + ast = Interpreter() + if expr_variables is None: + chain_vars = [] + else: + chain_vars = [var for var in get_ast_names(ast.parse(f'{expr}')) + if var in expr_variables and var != x and var not in ast.symtable] +# print(f'chain_vars: {chain_vars}') + if valid_variables is not None: + unknown_vars = [var for var in chain_vars if var not in valid_variables] + if len(unknown_vars): + error_msg = f'Unknown variable {unknown_vars} in {expr}' + if log: + logger.error(error_msg) + if raise_error: + raise ValueError(error_msg) + return(None) + dexpr_dx = diff(expr, x) +# print(f'direct gradient: d({expr})/d({x}) = {dexpr_dx} ({type(dexpr_dx)})') + save_chain = chain.copy() + for var in chain_vars: +# print(f'\n\texpr_name = {expr_name}, var = {var}:\n\t\t{expr}') +# print(f'\tchain = {chain}') + if var in chain: + error_msg = f'Circular variable {var} in full_gradient' + if log: + logger.error(error_msg) + if raise_error: + raise ValueError(error_msg) + return(None) + dexpr_dvar = diff(expr, var) +# print(f'\td({expr})/d({var}) = {dexpr_dvar}') + if dexpr_dvar: + dvar_dx = full_gradient(expr_variables[var], x, expr_name=var, + expr_variables=expr_variables, valid_variables=valid_variables, + max_depth=max_depth, raise_error=raise_error, log=log, chain=chain) +# print(f'\t\td({var})/d({x}) = {dvar_dx}') + if dvar_dx: + dexpr_dx = f'{dexpr_dx}+({dexpr_dvar})*({dvar_dx})' +# print(f'\t\t2: chain = {chain}') + chain = save_chain.copy() +# print(f'\treset loop for {var}: chain = {chain}') +# print(f'full gradient: d({expr})/d({x}) = {dexpr_dx} ({type(dexpr_dx)})') +# print(f'reset end: chain = {chain}\n\n') + return(simplify(dexpr_dx)) + +def bounds_from_mask(mask, return_include_bounds:bool=True): + bounds = [] + for i, m in enumerate(mask): + if m == return_include_bounds: + if len(bounds) == 0 or type(bounds[-1]) == tuple: + bounds.append(i) + else: + if len(bounds) > 0 and isinstance(bounds[-1], int): + bounds[-1] = (bounds[-1], i-1) + if len(bounds) > 0 and isinstance(bounds[-1], int): + bounds[-1] = (bounds[-1], mask.size-1) + return(bounds) def draw_mask_1d(ydata, xdata=None, current_index_ranges=None, current_mask=None, select_mask=True, num_index_ranges_max=None, title=None, legend=None, test_mode=False): - def draw_selections(ax): + #FIX make color blind friendly + def draw_selections(ax, current_include, current_exclude, selected_index_ranges): ax.clear() ax.set_title(title) ax.legend([legend]) @@ -570,26 +962,32 @@ selected_index_ranges[-1] = (selected_index_ranges[-1], event.xdata) else: selected_index_ranges[-1] = (event.xdata, selected_index_ranges[-1]) - draw_selections(event.inaxes) + draw_selections(event.inaxes, current_include, current_exclude, selected_index_ranges) else: selected_index_ranges.pop(-1) def confirm_selection(event): plt.close() - + def clear_last_selection(event): if len(selected_index_ranges): selected_index_ranges.pop(-1) - draw_selections(ax) + else: + while len(current_include): + current_include.pop() + while len(current_exclude): + current_exclude.pop() + selected_mask.fill(False) + draw_selections(ax, current_include, current_exclude, selected_index_ranges) - def update_mask(mask): + def update_mask(mask, selected_index_ranges, unselected_index_ranges): for (low, upp) in selected_index_ranges: selected_mask = np.logical_and(xdata >= xdata[low], xdata <= xdata[upp]) mask = np.logical_or(mask, selected_mask) for (low, upp) in unselected_index_ranges: unselected_mask = np.logical_and(xdata >= xdata[low], xdata <= xdata[upp]) mask[unselected_mask] = False - return mask + return(mask) def update_index_ranges(mask): # Update the currently included index ranges (where mask is True) @@ -603,34 +1001,34 @@ current_include[-1] = (current_include[-1], i-1) if len(current_include) > 0 and isinstance(current_include[-1], int): current_include[-1] = (current_include[-1], num_data-1) - return current_include + return(current_include) - # Check for valid inputs + # Check inputs ydata = np.asarray(ydata) if ydata.ndim > 1: - logging.warning(f'Illegal ydata dimension ({ydata.ndim})') - return None, None + logger.warning(f'Invalid ydata dimension ({ydata.ndim})') + return(None, None) num_data = ydata.size if xdata is None: xdata = np.arange(num_data) else: xdata = np.asarray(xdata, dtype=np.float64) if xdata.ndim > 1 or xdata.size != num_data: - logging.warning(f'Illegal xdata shape ({xdata.shape})') - return None, None + logger.warning(f'Invalid xdata shape ({xdata.shape})') + return(None, None) if not np.all(xdata[:-1] < xdata[1:]): - logging.warning('Illegal xdata: must be monotonically increasing') - return None, None + logger.warning('Invalid xdata: must be monotonically increasing') + return(None, None) if current_index_ranges is not None: if not isinstance(current_index_ranges, (tuple, list)): - logging.warning('Illegal current_index_ranges parameter ({current_index_ranges}, '+ + logger.warning('Invalid current_index_ranges parameter ({current_index_ranges}, '+ f'{type(current_index_ranges)})') - return None, None + return(None, None) if not isinstance(select_mask, bool): - logging.warning('Illegal select_mask parameter ({select_mask}, {type(select_mask)})') - return None, None + logger.warning('Invalid select_mask parameter ({select_mask}, {type(select_mask)})') + return(None, None) if num_index_ranges_max is not None: - logging.warning('num_index_ranges_max input not yet implemented in draw_mask_1d') + logger.warning('num_index_ranges_max input not yet implemented in draw_mask_1d') if title is None: title = 'select ranges of data' elif not isinstance(title, str): @@ -668,7 +1066,7 @@ if upp >= num_data: upp = num_data-1 selected_index_ranges.append((low, upp)) - selected_mask = update_mask(selected_mask) + selected_mask = update_mask(selected_mask, selected_index_ranges, unselected_index_ranges) if current_index_ranges is not None and current_mask is not None: selected_mask = np.logical_and(current_mask, selected_mask) if current_mask is not None: @@ -697,7 +1095,7 @@ plt.close('all') fig, ax = plt.subplots() plt.subplots_adjust(bottom=0.2) - draw_selections(ax) + draw_selections(ax, current_include, current_exclude, selected_index_ranges) # Set up event handling for click-and-drag range selection cid_click = fig.canvas.mpl_connect('button_press_event', onclick) @@ -724,251 +1122,364 @@ selected_index_ranges # Update the mask with the currently selected/unselected x-ranges - selected_mask = update_mask(selected_mask) + selected_mask = update_mask(selected_mask, selected_index_ranges, unselected_index_ranges) # Update the currently included index ranges (where mask is True) current_include = update_index_ranges(selected_mask) + + return(selected_mask, current_include) - return selected_mask, current_include +def select_peaks(ydata:np.ndarray, x_values:np.ndarray=None, x_mask:np.ndarray=None, + peak_x_values:np.ndarray=np.array([]), peak_x_indices:np.ndarray=np.array([]), + return_peak_x_values:bool=False, return_peak_x_indices:bool=False, + return_peak_input_indices:bool=False, return_sorted:bool=False, + title:str=None, xlabel:str=None, ylabel:str=None) -> list : + + # Check arguments + if (len(peak_x_values) > 0 or return_peak_x_values) and not len(x_values) > 0: + raise RuntimeError('Cannot use peak_x_values or return_peak_x_values without x_values') + if not ((len(peak_x_values) > 0) ^ (len(peak_x_indices) > 0)): + raise RuntimeError('Use exactly one of peak_x_values or peak_x_indices') + return_format_iter = iter((return_peak_x_values, return_peak_x_indices, return_peak_input_indices)) + if not (any(return_format_iter) and not any(return_format_iter)): + raise RuntimeError('Exactly one of return_peak_x_values, return_peak_x_indices, or '+ + 'return_peak_input_indices must be True') + + EXCLUDE_PEAK_PROPERTIES = {'color': 'black', 'linestyle': '--','linewidth': 1, + 'marker': 10, 'markersize': 5, 'fillstyle': 'none'} + INCLUDE_PEAK_PROPERTIES = {'color': 'green', 'linestyle': '-', 'linewidth': 2, + 'marker': 10, 'markersize': 10, 'fillstyle': 'full'} + MASKED_PEAK_PROPERTIES = {'color': 'gray', 'linestyle': ':', 'linewidth': 1} + + # Setup reference data & plot + x_indices = np.arange(len(ydata)) + if x_values is None: + x_values = x_indices + if x_mask is None: + x_mask = np.full(x_values.shape, True, dtype=bool) + fig, ax = plt.subplots() + handles = ax.plot(x_values, ydata, label='Reference data') + handles.append(mlines.Line2D([], [], label='Excluded / unselected HKL', **EXCLUDE_PEAK_PROPERTIES)) + handles.append(mlines.Line2D([], [], label='Included / selected HKL', **INCLUDE_PEAK_PROPERTIES)) + handles.append(mlines.Line2D([], [], label='HKL in masked region (unselectable)', **MASKED_PEAK_PROPERTIES)) + ax.legend(handles=handles, loc='upper right') + ax.set(title=title, xlabel=xlabel, ylabel=ylabel) + + + # Plot vertical line at each peak + value_to_index = lambda x_value: int(np.argmin(abs(x_values - x_value))) + if len(peak_x_indices) > 0: + peak_x_values = x_values[peak_x_indices] + else: + peak_x_indices = np.array(list(map(value_to_index, peak_x_values))) + peak_vlines = [] + for loc in peak_x_values: + nearest_index = value_to_index(loc) + if nearest_index in x_indices[x_mask]: + peak_vline = ax.axvline(loc, **EXCLUDE_PEAK_PROPERTIES) + peak_vline.set_picker(5) + else: + peak_vline = ax.axvline(loc, **MASKED_PEAK_PROPERTIES) + peak_vlines.append(peak_vline) -def findImageFiles(path, filetype, name=None): + # Indicate masked regions by gray-ing out the axes facecolor + mask_exclude_bounds = bounds_from_mask(x_mask, return_include_bounds=False) + for (low, upp) in mask_exclude_bounds: + xlow = x_values[low] + xupp = x_values[upp] + ax.axvspan(xlow, xupp, facecolor='gray', alpha=0.5) + + # Setup peak picking + selected_peak_input_indices = [] + def onpick(event): + try: + peak_index = peak_vlines.index(event.artist) + except: + pass + else: + peak_vline = event.artist + if peak_index in selected_peak_input_indices: + peak_vline.set(**EXCLUDE_PEAK_PROPERTIES) + selected_peak_input_indices.remove(peak_index) + else: + peak_vline.set(**INCLUDE_PEAK_PROPERTIES) + selected_peak_input_indices.append(peak_index) + plt.draw() + cid_pick_peak = fig.canvas.mpl_connect('pick_event', onpick) + + # Setup "Confirm" button + def confirm_selection(event): + plt.close() + plt.subplots_adjust(bottom=0.2) + confirm_b = Button(plt.axes([0.75, 0.05, 0.15, 0.075]), 'Confirm') + cid_confirm = confirm_b.on_clicked(confirm_selection) + + # Show figure for user interaction + plt.show() + + # Disconnect callbacks when figure is closed + fig.canvas.mpl_disconnect(cid_pick_peak) + confirm_b.disconnect(cid_confirm) + + if return_peak_input_indices: + selected_peaks = np.array(selected_peak_input_indices) + if return_peak_x_values: + selected_peaks = peak_x_values[selected_peak_input_indices] + if return_peak_x_indices: + selected_peaks = peak_x_indices[selected_peak_input_indices] + + if return_sorted: + selected_peaks.sort() + + return(selected_peaks) + +def find_image_files(path, filetype, name=None): if isinstance(name, str): - name = f' {name} ' + name = f'{name.strip()} ' else: - name = ' ' + name = '' # Find available index range if filetype == 'tif': if not isinstance(path, str) or not os.path.isdir(path): - illegal_value(path, 'path', 'findImageFiles') - return -1, 0, [] + illegal_value(path, 'path', 'find_image_files') + return(-1, 0, []) indexRegex = re.compile(r'\d+') # At this point only tiffs files = sorted([f for f in os.listdir(path) if os.path.isfile(os.path.join(path, f)) and f.endswith('.tif') and indexRegex.search(f)]) - num_imgs = len(files) - if num_imgs < 1: - logging.warning('No available'+name+'files') - return -1, 0, [] + num_img = len(files) + if num_img < 1: + logger.warning(f'No available {name}files') + return(-1, 0, []) first_index = indexRegex.search(files[0]).group() last_index = indexRegex.search(files[-1]).group() if first_index is None or last_index is None: - logging.error('Unable to find correctly indexed'+name+'images') - return -1, 0, [] + logger.error(f'Unable to find correctly indexed {name}images') + return(-1, 0, []) first_index = int(first_index) last_index = int(last_index) - if num_imgs != last_index-first_index+1: - logging.error('Non-consecutive set of indices for'+name+'images') - return -1, 0, [] + if num_img != last_index-first_index+1: + logger.error(f'Non-consecutive set of indices for {name}images') + return(-1, 0, []) paths = [os.path.join(path, f) for f in files] elif filetype == 'h5': if not isinstance(path, str) or not os.path.isfile(path): - illegal_value(path, 'path', 'findImageFiles') - return -1, 0, [] + illegal_value(path, 'path', 'find_image_files') + return(-1, 0, []) # At this point only h5 in alamo2 detector style first_index = 0 with h5py.File(path, 'r') as f: - num_imgs = f['entry/instrument/detector/data'].shape[0] - last_index = num_imgs-1 + num_img = f['entry/instrument/detector/data'].shape[0] + last_index = num_img-1 paths = [path] else: - illegal_value(filetype, 'filetype', 'findImageFiles') - return -1, 0, [] - logging.debug('\nNumber of available'+name+f'images: {num_imgs}') - logging.debug('Index range of available'+name+f'images: [{first_index}, '+ + illegal_value(filetype, 'filetype', 'find_image_files') + return(-1, 0, []) + logger.info(f'Number of available {name}images: {num_img}') + logger.info(f'Index range of available {name}images: [{first_index}, '+ f'{last_index}]') - return first_index, num_imgs, paths + return(first_index, num_img, paths) -def selectImageRange(first_index, offset, num_imgs, name=None, num_required=None): +def select_image_range(first_index, offset, num_available, num_img=None, name=None, + num_required=None): if isinstance(name, str): - name = f' {name} ' + name = f'{name.strip()} ' else: - name = ' ' + name = '' # Check existing values - use_input = False - if (is_int(first_index, 0) and is_int(offset, 0) and is_int(num_imgs, 1)): - if offset < 0: - use_input = input_yesno(f'\nCurrent{name}first index = {first_index}, '+ - 'use this value (y/n)?', 'y') + if not is_int(num_available, gt=0): + logger.warning(f'No available {name}images') + return(0, 0, 0) + if num_img is not None and not is_int(num_img, ge=0): + illegal_value(num_img, 'num_img', 'select_image_range') + return(0, 0, 0) + if is_int(first_index, ge=0) and is_int(offset, ge=0): + if num_required is None: + if input_yesno(f'\nCurrent {name}first image index/offset = {first_index}/{offset},'+ + 'use these values (y/n)?', 'y'): + if num_img is not None: + if input_yesno(f'Current number of {name}images = {num_img}, '+ + 'use this value (y/n)? ', 'y'): + return(first_index, offset, num_img) + else: + if input_yesno(f'Number of available {name}images = {num_available}, '+ + 'use all (y/n)? ', 'y'): + return(first_index, offset, num_available) else: - use_input = input_yesno(f'\nCurrent{name}first index/offset = '+ - f'{first_index}/{offset}, use these values (y/n)?', 'y') - if num_required is None: - if use_input: - use_input = input_yesno(f'Current number of{name}images = '+ - f'{num_imgs}, use this value (y/n)? ', 'y') - if use_input: - return first_index, offset, num_imgs + if input_yesno(f'\nCurrent {name}first image offset = {offset}, '+ + f'use this values (y/n)?', 'y'): + return(first_index, offset, num_required) # Check range against requirements - if num_imgs < 1: - logging.warning('No available'+name+'images') - return -1, -1, 0 if num_required is None: - if num_imgs == 1: - return first_index, 0, 1 + if num_available == 1: + return(first_index, 0, 1) else: - if not is_int(num_required, 1): - illegal_value(num_required, 'num_required', 'selectImageRange') - return -1, -1, 0 - if num_imgs < num_required: - logging.error('Unable to find the required'+name+ - f'images ({num_imgs} out of {num_required})') - return -1, -1, 0 + if not is_int(num_required, ge=1): + illegal_value(num_required, 'num_required', 'select_image_range') + return(0, 0, 0) + if num_available < num_required: + logger.error(f'Unable to find the required {name}images ({num_available} out of '+ + f'{num_required})') + return(0, 0, 0) # Select index range - print('\nThe number of available'+name+f'images is {num_imgs}') + print(f'\nThe number of available {name}images is {num_available}') if num_required is None: - last_index = first_index+num_imgs + last_index = first_index+num_available use_all = f'Use all ([{first_index}, {last_index}])' - pick_offset = 'Pick a first index offset and a number of images' - pick_bounds = 'Pick the first and last index' + pick_offset = 'Pick the first image index offset and the number of images' + pick_bounds = 'Pick the first and last image index' choice = input_menu([use_all, pick_offset, pick_bounds], default=pick_offset) if not choice: offset = 0 + num_img = num_available elif choice == 1: - offset = input_int('Enter the first index offset', 0, last_index-first_index) - first_index += offset - if first_index == last_index: - num_imgs = 1 + offset = input_int('Enter the first index offset', ge=0, le=last_index-first_index) + if first_index+offset == last_index: + num_img = 1 else: - num_imgs = input_int('Enter the number of images', 1, num_imgs-offset) + num_img = input_int('Enter the number of images', ge=1, le=num_available-offset) else: - offset = input_int('Enter the first index', first_index, last_index) - first_index += offset - num_imgs = input_int('Enter the last index', first_index, last_index)-first_index+1 + offset = input_int('Enter the first index', ge=first_index, le=last_index) + num_img = 1-offset+input_int('Enter the last index', ge=offset, le=last_index) + offset -= first_index else: use_all = f'Use ([{first_index}, {first_index+num_required-1}])' pick_offset = 'Pick the first index offset' choice = input_menu([use_all, pick_offset], pick_offset) offset = 0 if choice == 1: - offset = input_int('Enter the first index offset', 0, num_imgs-num_required) - first_index += offset - num_imgs = num_required + offset = input_int('Enter the first index offset', ge=0, le=num_available-num_required) + num_img = num_required - return first_index, offset, num_imgs + return(first_index, offset, num_img) -def loadImage(f, img_x_bounds=None, img_y_bounds=None): +def load_image(f, img_x_bounds=None, img_y_bounds=None): """Load a single image from file. """ if not os.path.isfile(f): - logging.error(f'Unable to load {f}') - return None + logger.error(f'Unable to load {f}') + return(None) img_read = plt.imread(f) if not img_x_bounds: img_x_bounds = (0, img_read.shape[0]) else: if (not isinstance(img_x_bounds, (tuple, list)) or len(img_x_bounds) != 2 or not (0 <= img_x_bounds[0] < img_x_bounds[1] <= img_read.shape[0])): - logging.error(f'inconsistent row dimension in {f}') - return None + logger.error(f'inconsistent row dimension in {f}') + return(None) if not img_y_bounds: img_y_bounds = (0, img_read.shape[1]) else: if (not isinstance(img_y_bounds, list) or len(img_y_bounds) != 2 or not (0 <= img_y_bounds[0] < img_y_bounds[1] <= img_read.shape[1])): - logging.error(f'inconsistent column dimension in {f}') - return None - return img_read[img_x_bounds[0]:img_x_bounds[1],img_y_bounds[0]:img_y_bounds[1]] + logger.error(f'inconsistent column dimension in {f}') + return(None) + return(img_read[img_x_bounds[0]:img_x_bounds[1],img_y_bounds[0]:img_y_bounds[1]]) -def loadImageStack(files, filetype, img_offset, num_imgs, num_img_skip=0, +def load_image_stack(files, filetype, img_offset, num_img, num_img_skip=0, img_x_bounds=None, img_y_bounds=None): """Load a set of images and return them as a stack. """ - logging.debug(f'img_offset = {img_offset}') - logging.debug(f'num_imgs = {num_imgs}') - logging.debug(f'num_img_skip = {num_img_skip}') - logging.debug(f'\nfiles:\n{files}\n') + logger.debug(f'img_offset = {img_offset}') + logger.debug(f'num_img = {num_img}') + logger.debug(f'num_img_skip = {num_img_skip}') + logger.debug(f'\nfiles:\n{files}\n') img_stack = np.array([]) if filetype == 'tif': img_read_stack = [] i = 1 t0 = time() - for f in files[img_offset:img_offset+num_imgs:num_img_skip+1]: + for f in files[img_offset:img_offset+num_img:num_img_skip+1]: if not i%20: - logging.info(f' loading {i}/{num_imgs}: {f}') + logger.info(f' loading {i}/{num_img}: {f}') else: - logging.debug(f' loading {i}/{num_imgs}: {f}') - img_read = loadImage(f, img_x_bounds, img_y_bounds) + logger.debug(f' loading {i}/{num_img}: {f}') + img_read = load_image(f, img_x_bounds, img_y_bounds) img_read_stack.append(img_read) i += num_img_skip+1 img_stack = np.stack([img_read for img_read in img_read_stack]) - logging.info(f'... done in {time()-t0:.2f} seconds!') - logging.debug(f'img_stack shape = {np.shape(img_stack)}') + logger.info(f'... done in {time()-t0:.2f} seconds!') + logger.debug(f'img_stack shape = {np.shape(img_stack)}') del img_read_stack, img_read elif filetype == 'h5': if not isinstance(files[0], str) and not os.path.isfile(files[0]): - illegal_value(files[0], 'files[0]', 'loadImageStack') - return img_stack + illegal_value(files[0], 'files[0]', 'load_image_stack') + return(img_stack) t0 = time() - logging.info(f'Loading {files[0]}') + logger.info(f'Loading {files[0]}') with h5py.File(files[0], 'r') as f: shape = f['entry/instrument/detector/data'].shape if len(shape) != 3: - logging.error(f'inconsistent dimensions in {files[0]}') + logger.error(f'inconsistent dimensions in {files[0]}') if not img_x_bounds: img_x_bounds = (0, shape[1]) else: if (not isinstance(img_x_bounds, (tuple, list)) or len(img_x_bounds) != 2 or not (0 <= img_x_bounds[0] < img_x_bounds[1] <= shape[1])): - logging.error(f'inconsistent row dimension in {files[0]} {img_x_bounds} '+ + logger.error(f'inconsistent row dimension in {files[0]} {img_x_bounds} '+ f'{shape[1]}') if not img_y_bounds: img_y_bounds = (0, shape[2]) else: if (not isinstance(img_y_bounds, list) or len(img_y_bounds) != 2 or not (0 <= img_y_bounds[0] < img_y_bounds[1] <= shape[2])): - logging.error(f'inconsistent column dimension in {files[0]}') + logger.error(f'inconsistent column dimension in {files[0]}') img_stack = f.get('entry/instrument/detector/data')[ - img_offset:img_offset+num_imgs:num_img_skip+1, + img_offset:img_offset+num_img:num_img_skip+1, img_x_bounds[0]:img_x_bounds[1],img_y_bounds[0]:img_y_bounds[1]] - logging.info(f'... done in {time()-t0:.2f} seconds!') + logger.info(f'... done in {time()-t0:.2f} seconds!') else: - illegal_value(filetype, 'filetype', 'loadImageStack') - return img_stack + illegal_value(filetype, 'filetype', 'load_image_stack') + return(img_stack) -def combine_tiffs_in_h5(files, num_imgs, h5_filename): - img_stack = loadImageStack(files, 'tif', 0, num_imgs) +def combine_tiffs_in_h5(files, num_img, h5_filename): + img_stack = load_image_stack(files, 'tif', 0, num_img) with h5py.File(h5_filename, 'w') as f: f.create_dataset('entry/instrument/detector/data', data=img_stack) del img_stack - return [h5_filename] + return([h5_filename]) -def clearImshow(title=None): +def clear_imshow(title=None): plt.ioff() if title is None: title = 'quick imshow' elif not isinstance(title, str): - illegal_value(title, 'title', 'clearImshow') + illegal_value(title, 'title', 'clear_imshow') return plt.close(fig=title) -def clearPlot(title=None): +def clear_plot(title=None): plt.ioff() if title is None: title = 'quick plot' elif not isinstance(title, str): - illegal_value(title, 'title', 'clearPlot') + illegal_value(title, 'title', 'clear_plot') return plt.close(fig=title) -def quickImshow(a, title=None, path=None, name=None, save_fig=False, save_only=False, - clear=True, extent=None, show_grid=False, grid_color='w', grid_linewidth=1, **kwargs): +def quick_imshow(a, title=None, path=None, name=None, save_fig=False, save_only=False, + clear=True, extent=None, show_grid=False, grid_color='w', grid_linewidth=1, + block=False, **kwargs): if title is not None and not isinstance(title, str): - illegal_value(title, 'title', 'quickImshow') + illegal_value(title, 'title', 'quick_imshow') return if path is not None and not isinstance(path, str): - illegal_value(path, 'path', 'quickImshow') + illegal_value(path, 'path', 'quick_imshow') return if not isinstance(save_fig, bool): - illegal_value(save_fig, 'save_fig', 'quickImshow') + illegal_value(save_fig, 'save_fig', 'quick_imshow') return if not isinstance(save_only, bool): - illegal_value(save_only, 'save_only', 'quickImshow') + illegal_value(save_only, 'save_only', 'quick_imshow') return if not isinstance(clear, bool): - illegal_value(clear, 'clear', 'quickImshow') + illegal_value(clear, 'clear', 'quick_imshow') + return + if not isinstance(block, bool): + illegal_value(block, 'block', 'quick_imshow') return if not title: title='quick imshow' @@ -985,12 +1496,30 @@ path = name else: path = f'{path}/{name}' + if 'cmap' in kwargs and a.ndim == 3 and (a.shape[2] == 3 or a.shape[2] == 4): + use_cmap = True + if a.shape[2] == 4 and a[:,:,-1].min() != a[:,:,-1].max(): + use_cmap = False + if any(True if a[i,j,0] != a[i,j,1] and a[i,j,0] != a[i,j,2] else False + for i in range(a.shape[0]) for j in range(a.shape[1])): + use_cmap = False + if use_cmap: + a = a[:,:,0] + else: + logger.warning('Image incompatible with cmap option, ignore cmap') + kwargs.pop('cmap') if extent is None: extent = (0, a.shape[1], a.shape[0], 0) if clear: - plt.close(fig=title) + try: + plt.close(fig=title) + except: + pass if not save_only: - plt.ion() + if block: + plt.ioff() + else: + plt.ion() plt.figure(title) plt.imshow(a, extent=extent, **kwargs) if show_grid: @@ -1004,45 +1533,47 @@ else: if save_fig: plt.savefig(path) + if block: + plt.show(block=block) -def quickPlot(*args, xerr=None, yerr=None, vlines=None, title=None, xlim=None, ylim=None, - xlabel=None, ylabel=None, legend=None, path=None, name=None, show_grid=False, +def quick_plot(*args, xerr=None, yerr=None, vlines=None, title=None, xlim=None, ylim=None, + xlabel=None, ylabel=None, legend=None, path=None, name=None, show_grid=False, save_fig=False, save_only=False, clear=True, block=False, **kwargs): if title is not None and not isinstance(title, str): - illegal_value(title, 'title', 'quickPlot') + illegal_value(title, 'title', 'quick_plot') title = None if xlim is not None and not isinstance(xlim, (tuple, list)) and len(xlim) != 2: - illegal_value(xlim, 'xlim', 'quickPlot') + illegal_value(xlim, 'xlim', 'quick_plot') xlim = None if ylim is not None and not isinstance(ylim, (tuple, list)) and len(ylim) != 2: - illegal_value(ylim, 'ylim', 'quickPlot') + illegal_value(ylim, 'ylim', 'quick_plot') ylim = None if xlabel is not None and not isinstance(xlabel, str): - illegal_value(xlabel, 'xlabel', 'quickPlot') + illegal_value(xlabel, 'xlabel', 'quick_plot') xlabel = None if ylabel is not None and not isinstance(ylabel, str): - illegal_value(ylabel, 'ylabel', 'quickPlot') + illegal_value(ylabel, 'ylabel', 'quick_plot') ylabel = None if legend is not None and not isinstance(legend, (tuple, list)): - illegal_value(legend, 'legend', 'quickPlot') + illegal_value(legend, 'legend', 'quick_plot') legend = None if path is not None and not isinstance(path, str): - illegal_value(path, 'path', 'quickPlot') + illegal_value(path, 'path', 'quick_plot') return if not isinstance(show_grid, bool): - illegal_value(show_grid, 'show_grid', 'quickPlot') + illegal_value(show_grid, 'show_grid', 'quick_plot') return if not isinstance(save_fig, bool): - illegal_value(save_fig, 'save_fig', 'quickPlot') + illegal_value(save_fig, 'save_fig', 'quick_plot') return if not isinstance(save_only, bool): - illegal_value(save_only, 'save_only', 'quickPlot') + illegal_value(save_only, 'save_only', 'quick_plot') return if not isinstance(clear, bool): - illegal_value(clear, 'clear', 'quickPlot') + illegal_value(clear, 'clear', 'quick_plot') return if not isinstance(block, bool): - illegal_value(block, 'block', 'quickPlot') + illegal_value(block, 'block', 'quick_plot') return if title is None: title = 'quick plot' @@ -1060,10 +1591,13 @@ else: path = f'{path}/{name}' if clear: - plt.close(fig=title) + try: + plt.close(fig=title) + except: + pass args = unwrap_tuple(args) if depth_tuple(args) > 1 and (xerr is not None or yerr is not None): - logging.warning('Error bars ignored form multiple curves') + logger.warning('Error bars ignored form multiple curves') if not save_only: if block: plt.ioff() @@ -1079,6 +1613,8 @@ else: plt.errorbar(*args, xerr=xerr, yerr=yerr, **kwargs) if vlines is not None: + if isinstance(vlines, (int, float)): + vlines = [vlines] for v in vlines: plt.axvline(v, color='r', linestyle='--', **kwargs) # if vlines is not None: @@ -1106,108 +1642,97 @@ if block: plt.show(block=block) -def selectArrayBounds(a, x_low=None, x_upp=None, num_x_min=None, ask_bounds=False, +def select_array_bounds(a, x_low=None, x_upp=None, num_x_min=None, ask_bounds=False, title='select array bounds'): """Interactively select the lower and upper data bounds for a numpy array. """ if isinstance(a, (tuple, list)): a = np.array(a) if not isinstance(a, np.ndarray) or a.ndim != 1: - illegal_value(a.ndim, 'array type or dimension', 'selectArrayBounds') - return None + illegal_value(a.ndim, 'array type or dimension', 'select_array_bounds') + return(None) len_a = len(a) if num_x_min is None: num_x_min = 1 else: if num_x_min < 2 or num_x_min > len_a: - logging.warning('Illegal value for num_x_min in selectArrayBounds, input ignored') + logger.warning('Invalid value for num_x_min in select_array_bounds, input ignored') num_x_min = 1 # Ask to use current bounds if ask_bounds and (x_low is not None or x_upp is not None): if x_low is None: x_low = 0 - if not is_int(x_low, 0, len_a-num_x_min): - illegal_value(x_low, 'x_low', 'selectArrayBounds') - return None + if not is_int(x_low, ge=0, le=len_a-num_x_min): + illegal_value(x_low, 'x_low', 'select_array_bounds') + return(None) if x_upp is None: x_upp = len_a - if not is_int(x_upp, x_low+num_x_min, len_a): - illegal_value(x_upp, 'x_upp', 'selectArrayBounds') - return None - quickPlot((range(len_a), a), vlines=(x_low,x_upp), title=title) + if not is_int(x_upp, ge=x_low+num_x_min, le=len_a): + illegal_value(x_upp, 'x_upp', 'select_array_bounds') + return(None) + quick_plot((range(len_a), a), vlines=(x_low,x_upp), title=title) if not input_yesno(f'\nCurrent array bounds: [{x_low}, {x_upp}] '+ 'use these values (y/n)?', 'y'): x_low = None x_upp = None else: - clearPlot(title) - return x_low, x_upp + clear_plot(title) + return(x_low, x_upp) if x_low is None: x_min = 0 x_max = len_a x_low_max = len_a-num_x_min while True: - quickPlot(range(x_min, x_max), a[x_min:x_max], title=title) + quick_plot(range(x_min, x_max), a[x_min:x_max], title=title) zoom_flag = input_yesno('Set lower data bound (y) or zoom in (n)?', 'y') if zoom_flag: - x_low = input_int(' Set lower data bound', 0, x_low_max) + x_low = input_int(' Set lower data bound', ge=0, le=x_low_max) break else: - x_min = input_int(' Set lower zoom index', 0, x_low_max) - x_max = input_int(' Set upper zoom index', x_min+1, x_low_max+1) + x_min = input_int(' Set lower zoom index', ge=0, le=x_low_max) + x_max = input_int(' Set upper zoom index', ge=x_min+1, le=x_low_max+1) else: - if not is_int(x_low, 0, len_a-num_x_min): - illegal_value(x_low, 'x_low', 'selectArrayBounds') - return None + if not is_int(x_low, ge=0, le=len_a-num_x_min): + illegal_value(x_low, 'x_low', 'select_array_bounds') + return(None) if x_upp is None: x_min = x_low+num_x_min x_max = len_a x_upp_min = x_min while True: - quickPlot(range(x_min, x_max), a[x_min:x_max], title=title) + quick_plot(range(x_min, x_max), a[x_min:x_max], title=title) zoom_flag = input_yesno('Set upper data bound (y) or zoom in (n)?', 'y') if zoom_flag: - x_upp = input_int(' Set upper data bound', x_upp_min, len_a) + x_upp = input_int(' Set upper data bound', ge=x_upp_min, le=len_a) break else: - x_min = input_int(' Set upper zoom index', x_upp_min, len_a-1) - x_max = input_int(' Set upper zoom index', x_min+1, len_a) + x_min = input_int(' Set upper zoom index', ge=x_upp_min, le=len_a-1) + x_max = input_int(' Set upper zoom index', ge=x_min+1, le=len_a) else: - if not is_int(x_upp, x_low+num_x_min, len_a): - illegal_value(x_upp, 'x_upp', 'selectArrayBounds') - return None + if not is_int(x_upp, ge=x_low+num_x_min, le=len_a): + illegal_value(x_upp, 'x_upp', 'select_array_bounds') + return(None) print(f'lower bound = {x_low} (inclusive)\nupper bound = {x_upp} (exclusive)]') - quickPlot((range(len_a), a), vlines=(x_low,x_upp), title=title) + quick_plot((range(len_a), a), vlines=(x_low,x_upp), title=title) if not input_yesno('Accept these bounds (y/n)?', 'y'): - x_low, x_upp = selectArrayBounds(a, None, None, num_x_min, title=title) - clearPlot(title) - return x_low, x_upp + x_low, x_upp = select_array_bounds(a, None, None, num_x_min, title=title) + clear_plot(title) + return(x_low, x_upp) -def selectImageBounds(a, axis, low=None, upp=None, num_min=None, - title='select array bounds'): +def select_image_bounds(a, axis, low=None, upp=None, num_min=None, title='select array bounds', + raise_error=False): """Interactively select the lower and upper data bounds for a 2D numpy array. """ - if isinstance(a, np.ndarray): - if a.ndim != 2: - illegal_value(a.ndim, 'array dimension', 'selectImageBounds') - return None - elif isinstance(a, (tuple, list)): - if len(a) != 2: - illegal_value(len(a), 'array dimension', 'selectImageBounds') - return None - if len(a[0]) != len(a[1]) or not (isinstance(a[0], (tuple, list, np.ndarray)) and - isinstance(a[1], (tuple, list, np.ndarray))): - logging.error(f'Illegal array type in selectImageBounds ({type(a[0])} {type(a[1])})') - return None - a = np.array(a) - else: - illegal_value(a, 'array type', 'selectImageBounds') - return None + a = np.asarray(a) + if a.ndim != 2: + illegal_value(a.ndim, 'array dimension', location='select_image_bounds', + raise_error=raise_error) + return(None) if axis < 0 or axis >= a.ndim: - illegal_value(axis, 'axis', 'selectImageBounds') - return None + illegal_value(axis, 'axis', location='select_image_bounds', raise_error=raise_error) + return(None) low_save = low upp_save = upp num_min_save = num_min @@ -1215,7 +1740,7 @@ num_min = 1 else: if num_min < 2 or num_min > a.shape[axis]: - logging.warning('Illegal input for num_min in selectImageBounds, input ignored') + logger.warning('Invalid input for num_min in select_image_bounds, input ignored') num_min = 1 if low is None: min_ = 0 @@ -1223,44 +1748,44 @@ low_max = a.shape[axis]-num_min while True: if axis: - quickImshow(a[:,min_:max_], title=title, aspect='auto', + quick_imshow(a[:,min_:max_], title=title, aspect='auto', extent=[min_,max_,a.shape[0],0]) else: - quickImshow(a[min_:max_,:], title=title, aspect='auto', + quick_imshow(a[min_:max_,:], title=title, aspect='auto', extent=[0,a.shape[1], max_,min_]) zoom_flag = input_yesno('Set lower data bound (y) or zoom in (n)?', 'y') if zoom_flag: - low = input_int(' Set lower data bound', 0, low_max) + low = input_int(' Set lower data bound', ge=0, le=low_max) break else: - min_ = input_int(' Set lower zoom index', 0, low_max) - max_ = input_int(' Set upper zoom index', min_+1, low_max+1) + min_ = input_int(' Set lower zoom index', ge=0, le=low_max) + max_ = input_int(' Set upper zoom index', ge=min_+1, le=low_max+1) else: - if not is_int(low, 0, a.shape[axis]-num_min): - illegal_value(low, 'low', 'selectImageBounds') - return None + if not is_int(low, ge=0, le=a.shape[axis]-num_min): + illegal_value(low, 'low', location='select_image_bounds', raise_error=raise_error) + return(None) if upp is None: min_ = low+num_min max_ = a.shape[axis] upp_min = min_ while True: if axis: - quickImshow(a[:,min_:max_], title=title, aspect='auto', + quick_imshow(a[:,min_:max_], title=title, aspect='auto', extent=[min_,max_,a.shape[0],0]) else: - quickImshow(a[min_:max_,:], title=title, aspect='auto', + quick_imshow(a[min_:max_,:], title=title, aspect='auto', extent=[0,a.shape[1], max_,min_]) zoom_flag = input_yesno('Set upper data bound (y) or zoom in (n)?', 'y') if zoom_flag: - upp = input_int(' Set upper data bound', upp_min, a.shape[axis]) + upp = input_int(' Set upper data bound', ge=upp_min, le=a.shape[axis]) break else: - min_ = input_int(' Set upper zoom index', upp_min, a.shape[axis]-1) - max_ = input_int(' Set upper zoom index', min_+1, a.shape[axis]) + min_ = input_int(' Set upper zoom index', ge=upp_min, le=a.shape[axis]-1) + max_ = input_int(' Set upper zoom index', ge=min_+1, le=a.shape[axis]) else: - if not is_int(upp, low+num_min, a.shape[axis]): - illegal_value(upp, 'upp', 'selectImageBounds') - return None + if not is_int(upp, ge=low+num_min, le=a.shape[axis]): + illegal_value(upp, 'upp', location='select_image_bounds', raise_error=raise_error) + return(None) bounds = (low, upp) a_tmp = np.copy(a) a_tmp_max = a.max() @@ -1271,12 +1796,64 @@ a_tmp[bounds[0],:] = a_tmp_max a_tmp[bounds[1]-1,:] = a_tmp_max print(f'lower bound = {low} (inclusive)\nupper bound = {upp} (exclusive)') - quickImshow(a_tmp, title=title) + quick_imshow(a_tmp, title=title, aspect='auto') del a_tmp if not input_yesno('Accept these bounds (y/n)?', 'y'): - bounds = selectImageBounds(a, axis, low=low_save, upp=upp_save, num_min=num_min_save, + bounds = select_image_bounds(a, axis, low=low_save, upp=upp_save, num_min=num_min_save, title=title) - return bounds + return(bounds) + +def select_one_image_bound(a, axis, bound=None, bound_name=None, title='select array bounds', + default='y', raise_error=False): + """Interactively select a data boundary for a 2D numpy array. + """ + a = np.asarray(a) + if a.ndim != 2: + illegal_value(a.ndim, 'array dimension', location='select_one_image_bound', + raise_error=raise_error) + return(None) + if axis < 0 or axis >= a.ndim: + illegal_value(axis, 'axis', location='select_one_image_bound', raise_error=raise_error) + return(None) + if bound_name is None: + bound_name = 'data bound' + if bound is None: + min_ = 0 + max_ = a.shape[axis] + bound_max = a.shape[axis]-1 + while True: + if axis: + quick_imshow(a[:,min_:max_], title=title, aspect='auto', + extent=[min_,max_,a.shape[0],0]) + else: + quick_imshow(a[min_:max_,:], title=title, aspect='auto', + extent=[0,a.shape[1], max_,min_]) + zoom_flag = input_yesno(f'Set {bound_name} (y) or zoom in (n)?', 'y') + if zoom_flag: + bound = input_int(f' Set {bound_name}', ge=0, le=bound_max) + clear_imshow(title) + break + else: + min_ = input_int(' Set lower zoom index', ge=0, le=bound_max) + max_ = input_int(' Set upper zoom index', ge=min_+1, le=bound_max+1) + + elif not is_int(bound, ge=0, le=a.shape[axis]-1): + illegal_value(bound, 'bound', location='select_one_image_bound', raise_error=raise_error) + return(None) + else: + print(f'Current {bound_name} = {bound}') + a_tmp = np.copy(a) + a_tmp_max = a.max() + if axis: + a_tmp[:,bound] = a_tmp_max + else: + a_tmp[bound,:] = a_tmp_max + quick_imshow(a_tmp, title=title, aspect='auto') + del a_tmp + if not input_yesno(f'Accept this {bound_name} (y/n)?', default): + bound = select_one_image_bound(a, axis, bound_name=bound_name, title=title) + clear_imshow(title) + return(bound) class Config: @@ -1289,79 +1866,79 @@ # Load config file if config_file is not None and config_dict is not None: - logging.warning('Ignoring config_dict (both config_file and config_dict are specified)') + logger.warning('Ignoring config_dict (both config_file and config_dict are specified)') if config_file is not None: - self.loadFile(config_file) + self.load_file(config_file) elif config_dict is not None: - self.loadDict(config_dict) + self.load_dict(config_dict) - def loadFile(self, config_file): + def load_file(self, config_file): """Load a config file. """ if self.load_flag: - logging.warning('Overwriting any previously loaded config file') + logger.warning('Overwriting any previously loaded config file') self.config = {} # Ensure config file exists if not os.path.isfile(config_file): - logging.error(f'Unable to load {config_file}') + logger.error(f'Unable to load {config_file}') return # Load config file (for now for Galaxy, allow .dat extension) self.suffix = os.path.splitext(config_file)[1] if self.suffix == '.yml' or self.suffix == '.yaml' or self.suffix == '.dat': with open(config_file, 'r') as f: - self.config = yaml.safe_load(f) + self.config = safe_load(f) elif self.suffix == '.txt': with open(config_file, 'r') as f: lines = f.read().splitlines() self.config = {item[0].strip():literal_eval(item[1].strip()) for item in [line.split('#')[0].split('=') for line in lines if '=' in line.split('#')[0]]} else: - illegal_value(self.suffix, 'config file extension', 'Config.loadFile') + illegal_value(self.suffix, 'config file extension', 'Config.load_file') # Make sure config file was correctly loaded if isinstance(self.config, dict): self.load_flag = True else: - logging.error(f'Unable to load dictionary from config file: {config_file}') + logger.error(f'Unable to load dictionary from config file: {config_file}') self.config = {} - def loadDict(self, config_dict): + def load_dict(self, config_dict): """Takes a dictionary and places it into self.config. """ if self.load_flag: - logging.warning('Overwriting the previously loaded config file') + logger.warning('Overwriting the previously loaded config file') if isinstance(config_dict, dict): self.config = config_dict self.load_flag = True else: - illegal_value(config_dict, 'dictionary config object', 'Config.loadDict') + illegal_value(config_dict, 'dictionary config object', 'Config.load_dict') self.config = {} - def saveFile(self, config_file): + def save_file(self, config_file): """Save the config file (as a yaml file only right now). """ suffix = os.path.splitext(config_file)[1] if suffix != '.yml' and suffix != '.yaml': - illegal_value(suffix, 'config file extension', 'Config.saveFile') + illegal_value(suffix, 'config file extension', 'Config.save_file') # Check if config file exists if os.path.isfile(config_file): - logging.info(f'Updating {config_file}') + logger.info(f'Updating {config_file}') else: - logging.info(f'Saving {config_file}') + logger.info(f'Saving {config_file}') # Save config file with open(config_file, 'w') as f: - yaml.safe_dump(self.config, f) + safe_dump(self.config, f) def validate(self, pars_required, pars_missing=None): """Returns False if any required keys are missing. """ if not self.load_flag: - logging.error('Load a config file prior to calling Config.validate') + logger.error('Load a config file prior to calling Config.validate') def validate_nested_pars(config, par): par_levels = par.split(':') @@ -1374,15 +1951,15 @@ next_level_config = config[first_level_par] if len(par_levels) > 1: next_level_par = ':'.join(par_levels[1:]) - return validate_nested_pars(next_level_config, next_level_par) + return(validate_nested_pars(next_level_config, next_level_par)) else: - return True + return(True) except: - return False + return(False) pars_missing = [p for p in pars_required if not validate_nested_pars(self.config, p)] if len(pars_missing) > 0: - logging.error(f'Missing item(s) in configuration: {", ".join(pars_missing)}') - return False + logger.error(f'Missing item(s) in configuration: {", ".join(pars_missing)}') + return(False) else: - return True + return(True)