Source code for pyxu.util.operator

  1import collections.abc as cabc
  2import concurrent.futures as cf
  3import copy
  4import functools
  5import inspect
  6import itertools
  7
  8import pyxu.info.deps as pxd
  9import pyxu.info.ptype as pxt
 10import pyxu.util.misc as pxm
 11
 12__all__ = [
 13    "as_canonical_axes",
 14    "as_canonical_shape",
 15    "vectorize",
 16]
 17
 18
def as_canonical_shape(x: pxt.NDArrayShape) -> pxt.NDArrayShape:
    """
    Normalize a shape specifier into a tuple of built-in ``int``.

    A non-iterable `x` (i.e. a lone integer) is wrapped into the 1-tuple ``(x,)``; an iterable `x` is materialized
    as a tuple. Every entry must be an instance of ``pxt.Integer``; this is enforced with an ``assert``.
    """
    entries = tuple(x) if isinstance(x, cabc.Iterable) else (x,)
    assert all(isinstance(entry, pxt.Integer) for entry in entries)
    return tuple(int(entry) for entry in entries)
31 32
def as_canonical_axes(
    axes: pxt.NDArrayAxis,
    rank: pxt.Integer,
) -> pxt.NDArrayAxis:
    """
    Express NDArray axes as a tuple of non-negative indices.

    Parameters
    ----------
    axes: NDArrayAxis
        Axis (or axes) to normalize. Each entry must lie in ``[-rank, rank)``.
    rank: Integer
        Rank of the NDArray. (Needed to wrap negative entries.) Must be at least 1.

    Returns
    -------
    axes: NDArrayAxis
        Tuple where every negative entry has been replaced by its non-negative equivalent.
    """
    assert rank >= 1

    canonical = as_canonical_shape(axes)
    # Reject any axis falling outside the valid range [-rank, rank).
    assert not any(ax < -rank or ax >= rank for ax in canonical)
    # Shifting by `rank` before the modulo maps negative axes onto their non-negative counterpart.
    return tuple((ax + rank) % rank for ax in canonical)
51 52
def vectorize(
    i: pxt.VarName,
    dim_shape: pxt.NDArrayShape,
    codim_shape: pxt.NDArrayShape,
) -> cabc.Callable:
    r"""
    Decorator to auto-vectorize a function :math:`\mathbf{f}: \mathbb{R}^{M_{1} \times\cdots\times M_{D}} \to
    \mathbb{R}^{N_{1} \times\cdots\times N_{K}}` to accept stacking dimensions.

    Parameters
    ----------
    i: VarName
        Function/method parameter to vectorize. This variable must hold an object with a NumPy API.
    dim_shape: NDArrayShape
        (M1,...,MD) shape of operator's domain.
    codim_shape: NDArrayShape
        (N1,...,NK) shape of operator's co-domain.

    Returns
    -------
    g: ~collections.abc.Callable
        Function/Method with signature ``(..., M1,...,MD) -> (..., N1,...,NK)`` in parameter `i`.

    Raises
    ------
    ValueError
        At decoration time, if `i` is not a parameter of the decorated function.
        At call time, if the array held by `i` belongs to an unknown NDArray category.

    Example
    -------
    .. code-block:: python3

       import numpy as np
       import pyxu.util as pxu

       def f(x):
           return x.sum(keepdims=True)

       N = 5
       g = pxu.vectorize("x", N, 1)(f)

       x = np.arange(2*N).reshape((2, N))
       g(x[0]), g(x[1])  # [10], [35]
       g(x)              # [[10],
                         #  [35]]

    Notes
    -----
    * :py:func:`~pyxu.util.vectorize` assumes the function being vectorized is **thread-safe** and can be evaluated in
      parallel. Using it on thread-unsafe code may lead to incorrect outputs.
    * Consistent with Pyxu's :py:class:`~pyxu.abc.Operator` API:

      - The dtype of the vectorized function is assumed to match `x.dtype`.
      - The array backend of the vectorized function is assumed to match that of `x`.
    * If a stacking dimension has length 0, the function is never evaluated: an empty array of shape
      ``(..., N1,...,NK)`` and dtype `x.dtype` is returned.
    * An empty `dim_shape` (scalar domain) is supported: all dimensions of `x` are then stacking dimensions.
    """
    N = pxd.NDArrayInfo  # short-hand
    dim_shape = as_canonical_shape(dim_shape)
    dim_rank = len(dim_shape)
    codim_shape = as_canonical_shape(codim_shape)

    def decorator(func: cabc.Callable) -> cabc.Callable:
        sig = inspect.Signature.from_callable(func)
        if i not in sig.parameters:
            error_msg = f"Parameter[{i}] not part of {func.__qualname__}() parameter list."
            raise ValueError(error_msg)

        @functools.wraps(func)
        def wrapper(*ARGS, **KWARGS):
            func_args = pxm.parse_params(func, *ARGS, **KWARGS)

            x = func_args.pop(i)
            ndi = N.from_obj(x)
            xp = ndi.module()

            if ndi not in [N.NUMPY, N.CUPY, N.DASK]:
                # Define custom behavior
                raise ValueError("Unknown NDArray category.")

            # Count stacking dimensions explicitly: slicing with `[:-dim_rank]` silently yields `()` when
            # `dim_rank == 0` since `-0 == 0`.
            n_stack = max(len(x.shape) - dim_rank, 0)
            sh_stack = x.shape[:n_stack]
            sh_out = (*sh_stack, *codim_shape)

            if 0 in sh_stack:
                # Nothing to evaluate: stacking zero outputs (NumPy/CuPy) or probing `x[0,...,0]` (Dask) would fail.
                # The output dtype follows the documented assumption that it matches `x.dtype`.
                return xp.zeros(sh_out, dtype=x.dtype)

            if ndi == N.DASK:
                # Find out codim chunk structure ...
                idx = (0,) * len(sh_stack)
                func_args[i] = x[idx]
                codim_chunks = func(**func_args).chunks  # no compute; only extract chunk info

                # ... then process all inputs.
                y = xp.zeros(
                    sh_out,
                    dtype=x.dtype,
                    chunks=x.chunks[:n_stack] + codim_chunks,
                )
                for idx in itertools.product(*map(range, sh_stack)):
                    func_args[i] = x[idx]
                    y[idx] = func(**func_args)
            else:  # NUMPY/CUPY
                # Each task gets its own deep copy of the remaining arguments so that concurrent evaluations of
                # `func` do not share mutable state.
                task_kwargs = [
                    {**copy.deepcopy(func_args), i: x[idx]} for idx in itertools.product(*map(range, sh_stack))
                ]

                with cf.ThreadPoolExecutor() as executor:
                    res = list(executor.map(lambda kw: func(**kw), task_kwargs))
                y = xp.stack(res, axis=0).reshape(sh_out)

            return y

        return wrapper

    return decorator