Source code for pyxu.util.operator
1import collections.abc as cabc
2import concurrent.futures as cf
3import copy
4import functools
5import inspect
6import itertools
7
8import pyxu.info.deps as pxd
9import pyxu.info.ptype as pxt
10import pyxu.util.misc as pxm
11
# Public names exported by this module.
__all__ = ["as_canonical_axes", "as_canonical_shape", "vectorize"]
17
18
[docs]
def as_canonical_shape(x: pxt.NDArrayShape) -> pxt.NDArrayShape:
    """
    Normalize a shape specifier into a tuple of built-in ``int``.

    Parameters
    ----------
    x: NDArrayShape
        Either a lone integer or an iterable of integers.

    Returns
    -------
    shape: NDArrayShape
        ``(x,)`` if `x` is not iterable; otherwise the entries of `x` in iteration order.
        Every entry is converted with ``int()``.

    Raises
    ------
    AssertionError
        If an entry is not an instance of ``pxt.Integer``. (Check is skipped when Python runs with ``-O``.)
    """
    # A non-iterable input is treated as a rank-1 shape.
    entries = tuple(x) if isinstance(x, cabc.Iterable) else (x,)
    assert all(isinstance(entry, pxt.Integer) for entry in entries)

    return tuple(int(entry) for entry in entries)
31
32
[docs]
def as_canonical_axes(
    axes: pxt.NDArrayAxis,
    rank: pxt.Integer,
) -> pxt.NDArrayAxis:
    """
    Express NDArray axes as a tuple of non-negative indices.

    Parameters
    ----------
    axes: NDArrayAxis
        Axis (lone integer) or axes (iterable of integers) to canonicalize.
        Each entry must lie in ``[-rank, rank)``.
    rank: Integer
        Rank of the NDArray; must be >= 1. (Required to make all entries positive.)

    Returns
    -------
    axes: NDArrayAxis
        Tuple holding, in input order, the equivalent of each axis in ``[0, rank)``.

    Raises
    ------
    AssertionError
        If ``rank < 1`` or an axis lies outside ``[-rank, rank)``.
    """
    assert rank >= 1

    candidates = as_canonical_shape(axes)  # tuple of built-in ints
    in_range = (-rank <= ax < rank for ax in candidates)
    assert all(in_range)  # all axes in valid range

    # Shifting by `rank` before the modulo maps negative axes to their positive counterpart.
    return tuple((ax + rank) % rank for ax in candidates)
51
52
[docs]
def vectorize(
    i: pxt.VarName,
    dim_shape: pxt.NDArrayShape,
    codim_shape: pxt.NDArrayShape,
) -> cabc.Callable:
    r"""
    Decorator to auto-vectorize a function :math:`\mathbf{f}: \mathbb{R}^{M_{1} \times\cdots\times M_{D}} \to
    \mathbb{R}^{N_{1} \times\cdots\times N_{K}}` to accept stacking dimensions.

    Parameters
    ----------
    i: VarName
        Function/method parameter to vectorize. This variable must hold an object with a NumPy API.
    dim_shape: NDArrayShape
        (M1,...,MD) shape of operator's domain.
    codim_shape: NDArrayShape
        (N1,...,NK) shape of operator's co-domain.

    Returns
    -------
    g: ~collections.abc.Callable
        Function/Method with signature ``(..., M1,...,MD) -> (..., N1,...,NK)`` in parameter `i`.

    Raises
    ------
    ValueError
        At decoration time if `i` is not a parameter of the decorated function.
        At call time if the array held by parameter `i` is not a NumPy, CuPy or Dask array.

    Example
    -------
    .. code-block:: python3

       import numpy as np
       import pyxu.util as pxu

       def f(x):
           return x.sum(keepdims=True)

       N = 5
       g = pxu.vectorize("x", N, 1)(f)

       x = np.arange(2*N).reshape((2, N))
       g(x[0]), g(x[1])  # [10], [35]
       g(x)              # [[10],
                         #  [35]]

    Notes
    -----
    * :py:func:`~pyxu.util.vectorize` assumes the function being vectorized is **thread-safe** and can be evaluated in
      parallel. Using it on thread-unsafe code may lead to incorrect outputs.
    * As specified by Pyxu's :py:class:`~pyxu.abc.Operator` API:

      - The dtype of the vectorized function is assumed to match `x.dtype`.
      - The array backend of the vectorized function is assumed to match that of `x`.
    * If a stacking dimension has size 0, the wrapped function is never called: an empty array of shape
      ``(..., N1,...,NK)`` and dtype `x.dtype` is returned.
    """
    N = pxd.NDArrayInfo  # short-hand
    dim_shape = as_canonical_shape(dim_shape)
    dim_rank = len(dim_shape)
    codim_shape = as_canonical_shape(codim_shape)

    def decorator(func: cabc.Callable) -> cabc.Callable:
        # Fail early (at decoration time) if the parameter to vectorize does not exist.
        sig = inspect.Signature.from_callable(func)
        if i not in sig.parameters:
            error_msg = f"Parameter[{i}] not part of {func.__qualname__}() parameter list."
            raise ValueError(error_msg)

        @functools.wraps(func)
        def wrapper(*ARGS, **KWARGS):
            # Name -> value mapping of all call arguments; used below to call `func` by keyword only.
            func_args = pxm.parse_params(func, *ARGS, **KWARGS)

            x = func_args.pop(i)
            ndi = N.from_obj(x)
            xp = ndi.module()

            # Leading dimensions not consumed by `dim_shape` are the stacking dimensions.
            # An explicit count is used because `x.shape[:-dim_rank]` degenerates to () when dim_rank == 0 (-0 == 0).
            n_stack = max(len(x.shape) - dim_rank, 0)
            sh_stack = x.shape[:n_stack]
            sh_out = (*sh_stack, *codim_shape)

            if ndi not in (N.NUMPY, N.CUPY, N.DASK):
                # Define custom behavior
                raise ValueError("Unknown NDArray category.")
            elif 0 in sh_stack:
                # Empty stack: there is no slice to evaluate `func` on. (Stacking an empty list of results, or
                # indexing element 0 of an empty axis, would raise.)
                y = xp.zeros(sh_out, dtype=x.dtype)
            elif ndi == N.DASK:
                # Find out codim chunk structure ...
                func_args[i] = x[(0,) * n_stack]
                codim_chunks = func(**func_args).chunks  # no compute; only extract chunk info

                # ... then process all inputs.
                y = xp.zeros(
                    sh_out,
                    dtype=x.dtype,
                    chunks=x.chunks[:n_stack] + codim_chunks,
                )
                for idx in itertools.product(*map(range, sh_stack)):
                    func_args[i] = x[idx]
                    y[idx] = func(**func_args)
            else:  # NUMPY / CUPY
                # One task per stacking index; each task receives its own deep copy of the remaining arguments.
                task_kwargs = []
                for idx in itertools.product(*map(range, sh_stack)):
                    kwargs = copy.deepcopy(func_args)
                    kwargs[i] = x[idx]
                    task_kwargs.append(kwargs)

                with cf.ThreadPoolExecutor() as executor:
                    # executor.map() preserves task order, hence results are stacked in C-order of `sh_stack`.
                    res = executor.map(lambda kw: func(**kw), task_kwargs)
                    y = xp.stack(list(res), axis=0).reshape(sh_out)

            return y

        return wrapper

    return decorator
155 return decorator