from functools import wraps
import numpy as np
def masked_to_nan(arg):
    """
    Return `arg` as a plain float ndarray, with nan for masked values.

    A masked array whose dtype is already floating keeps that dtype and
    is filled with nan at its masked positions; any other masked array is
    cast to float before filling.  Everything else (lists, scalars,
    ndarrays) goes through ``np.asarray(arg, dtype=float)``.
    """
    if not np.ma.isMaskedArray(arg):
        return np.asarray(arg, dtype=float)
    # Only non-floating masked arrays need a cast; nan needs a float dtype.
    floating = arg if arg.dtype.kind == 'f' else arg.astype(float)
    return floating.filled(np.nan)
def match_args_return(f):
    """
    Decorator for most functions that operate on profile data.

    Arguments are normalized before *f* is called:

    - masked arrays are converted with ``masked_to_nan`` (float ndarray,
      nan where masked);
    - "duck" arrays -- objects that define ``__array_ufunc__`` but are not
      ``np.ndarray`` instances -- are passed through unchanged;
    - everything else is converted with ``np.asarray(arg, dtype=float)``.

    A keyword argument ``p`` that is present and not None is normalized
    the same way; other keyword arguments are passed through untouched.

    The return value (or each element, if *f* returns a tuple) is then
    post-processed:

    - if any argument was a duck array, it is returned unchanged;
    - if any argument was a masked array, invalid values are masked with
      ``np.ma.masked_invalid``;
    - if no argument was iterable and the result is an ndarray holding a
      single element, that element is returned as a scalar.
    """
    @wraps(f)
    def wrapper(*args, **kw):
        p = kw.get('p', None)
        if p is not None:
            # Classify and convert 'p' together with the positional
            # arguments; it is moved back into kw just before the call.
            args = list(args)
            args.append(p)

        isarray = [hasattr(a, '__iter__') for a in args]
        ismasked = [np.ma.isMaskedArray(a) for a in args]
        isduck = [hasattr(a, '__array_ufunc__')
                  and not isinstance(a, np.ndarray) for a in args]

        hasarray = np.any(isarray)
        hasmasked = np.any(ismasked)
        hasduck = np.any(isduck)

        def fixup(ret):
            if hasduck:
                return ret
            if hasmasked:
                ret = np.ma.masked_invalid(ret)
            if not hasarray and isinstance(ret, np.ndarray) and ret.size == 1:
                # All inputs were scalars, so hand back a scalar.  ravel()
                # reduces 0-d and multi-dimensional single-element arrays
                # alike; plain ret[0] fails on 0-d arrays and only strips
                # one dimension from e.g. shape (1, 1).
                ret = ret.ravel()[0]
            return ret

        newargs = []
        for a, masked, duck in zip(args, ismasked, isduck):
            if masked:
                newargs.append(masked_to_nan(a))
            elif duck:
                newargs.append(a)
            else:
                newargs.append(np.asarray(a, dtype=float))

        if p is not None:
            kw['p'] = newargs.pop()

        ret = f(*newargs, **kw)
        if isinstance(ret, tuple):
            return tuple(fixup(r) for r in ret)
        return fixup(ret)

    wrapper.__wrapped__ = f
    return wrapper
def axis_slicer(n, sl, axis):
    """
    Build an index tuple of length `n` that selects everything
    (``slice(None)``) on every dimension except `axis`, where `sl` is used.

    `axis` follows normal list indexing, so negative values count from the
    end and an out-of-range value raises IndexError.
    """
    slices = [slice(None) for _ in range(n)]
    slices[axis] = sl
    return tuple(slices)
def indexer(shape, axis, order='C'):
    """
    Generator of indexing tuples for "apply_along_axis" usage.

    Each yielded tuple has one entry per dimension of `shape`:
    ``slice(None)`` at position `axis` and an integer index at every other
    position.  Together the tuples visit every combination of the
    non-`axis` indices exactly once.  With ``order='C'`` the last
    dimension varies fastest; any other value of `order` makes the first
    dimension vary fastest.

    The numpy np.apply_along_axis function only works with functions
    of a single array; this generator allows us to work with a function
    of more than one array.
    """
    counts = list(shape)
    counts[axis] = 1  # the sliced axis is never stepped through
    c_order = order == 'C'
    # np.ndindex always advances its last index fastest, so for the
    # first-fastest ordering walk the reversed shape and flip each tuple.
    walk = counts if c_order else counts[::-1]
    for position in np.ndindex(*walk):
        index = list(position) if c_order else list(reversed(position))
        index[axis] = slice(None)
        yield tuple(index)
# This is straight from pycurrents.system. We can trim out
# the parts we don't need, but there is no rush to do so.
class Bunch(dict):
    """
    A dictionary that also provides access via attributes.

    Additional methods update_values and update_None provide
    control over whether new keys are added to the dictionary
    when updating, and whether an attempt to add a new key is
    ignored or raises a KeyError.

    The Bunch also prints differently than a normal
    dictionary, using str() instead of repr() for its
    keys and values, and in key-sorted order. The printing
    format can be customized by subclassing with a different
    str_fmt class attribute. Do not set it through an instance
    (``b.str_fmt = ...``): attribute assignment on a Bunch stores
    a dictionary key, so the value would be reported by keys()
    and would not change the format actually used.

    To output a string representation with
    a particular format, without subclassing, use the
    formatted() method.
    """

    # Default format for formatted()/__str__: key left-aligned to klen,
    # value right-aligned to vlen, one entry per line.
    str_fmt = "{0!s:<{klen}} : {1!s:>{vlen}}\n"

    def __init__(self, *args, **kwargs):
        """
        *args* can be dictionaries, bunches, or sequences of
        key,value tuples. *kwargs* can be used to initialize
        or add key, value pairs.
        """
        dict.__init__(self)
        for arg in args:
            self.update(arg)
        self.update(kwargs)

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails, so methods and
        # class attributes take precedence over dictionary keys.
        try:
            return self[name]
        except KeyError:
            # Drop the KeyError context: hasattr() and getattr(obj, name,
            # default) only need the AttributeError.
            raise AttributeError(
                "'Bunch' object has no attribute '%s'" % name) from None

    def __setattr__(self, name, value):
        # Every attribute assignment becomes a dictionary entry.
        self[name] = value

    def __str__(self):
        return self.formatted()

    def formatted(self, fmt=None, types=False):
        """
        Return a string with keys and/or values or types.

        *fmt* is a format string as used in the str.format() method.
        The str.format() method is called with key, value as positional
        arguments, and klen, vlen as kwargs. The latter are the maxima
        of the string lengths for the keys and values, respectively,
        up to respective maxima of 20 and 40.

        If *types* is True, each value is replaced by the name of its
        type. Entries are emitted in sorted order; an empty Bunch
        yields an empty string.
        """
        if fmt is None:
            fmt = self.str_fmt
        items = sorted(self.items())
        if types:
            items = [(k, type(v).__name__) for k, v in items]
        # default=0 keeps an empty Bunch from raising ValueError in max().
        klen = min(20, max((len(str(k)) for k, _ in items), default=0))
        vlen = min(40, max((len(str(v)) for _, v in items), default=0))
        return ''.join(fmt.format(k, v, klen=klen, vlen=vlen)
                       for k, v in items)

    def from_pyfile(self, filename):
        """
        Read in variables from a python code file.

        The local variables defined by the file are added to this Bunch,
        which is then returned.

        WARNING: the file's contents are executed with exec(); only use
        this on trusted files.
        """
        # We can't simply exec the code directly, because in
        # Python 3 the scoping for list comprehensions would
        # lead to a NameError. Wrapping the code in a function
        # fixes this.
        d = dict()
        lines = ["def _temp_func():\n"]
        with open(filename) as f:
            lines.extend([" " + line for line in f])
        lines.extend(["\n return(locals())\n",
                      "_temp_out = _temp_func()\n",
                      "del(_temp_func)\n"])
        codetext = "".join(lines)
        code = compile(codetext, filename, 'exec')
        exec(code, globals(), d)
        self.update(d["_temp_out"])
        return self

    def update_values(self, *args, **kw):
        """
        Update only keys that already exist in the Bunch.

        Arguments are dictionary-like; if present, they act as
        additional sources of kwargs, with the actual kwargs
        taking precedence. Keys not already present are ignored.

        One reserved optional kwarg is "strict". If present and
        True, then any attempt to update with keys that are not
        already in the Bunch instance will raise a KeyError.
        """
        strict = kw.pop("strict", False)
        newkw = dict()
        for d in args:
            newkw.update(d)
        newkw.update(kw)
        self._check_strict(strict, newkw)
        dsub = {k: v for k, v in newkw.items() if k in self}
        self.update(dsub)

    def update_None(self, *args, **kw):
        """
        Similar to update_values, except that an existing value
        will be updated only if it is None.
        """
        strict = kw.pop("strict", False)
        newkw = dict()
        for d in args:
            newkw.update(d)
        newkw.update(kw)
        self._check_strict(strict, newkw)
        dsub = {k: v for k, v in newkw.items()
                if k in self and self[k] is None}
        self.update(dsub)

    def _check_strict(self, strict, kw):
        """Raise KeyError if *strict* and *kw* has keys not in the Bunch."""
        if not strict:
            return
        bad = set(kw) - set(self)
        if bad:
            raise KeyError(
                "Update keys %s don't match existing keys %s"
                % (sorted(bad), sorted(self)))