Source code for pyxu.util.array_module
1import collections.abc as cabc
2import functools
3
4import dask
5
6import pyxu.info.deps as pxd
7import pyxu.info.ptype as pxt
8import pyxu.util.misc as pxm
9
# Public API of this module (names exported by `from ... import *`).
__all__ = ["compute", "get_array_module", "redirect", "to_NUMPY"]
16
17
[docs]
def get_array_module(x, fallback: pxt.ArrayModule = None) -> pxt.ArrayModule:
    """
    Get the array namespace corresponding to a given object.

    Parameters
    ----------
    x: object
        Any object compatible with the interface of NumPy arrays.
    fallback: ArrayModule
        Module returned when the namespace of `x` cannot be inferred. If omitted (default), an error is raised
        instead.

    Returns
    -------
    namespace: ArrayModule
        The namespace to use to manipulate `x`, or `fallback` if the former cannot be inferred.

    Raises
    ------
    ValueError
        If the namespace of `x` cannot be inferred and no `fallback` was provided.
    """
    # `NDArrayInfo.from_obj()` signals an unrecognized array type via ValueError.
    try:
        namespace = pxd.NDArrayInfo.from_obj(x).module()
    except ValueError:
        namespace = None

    if namespace is None:
        if fallback is None:
            raise ValueError(f"Could not infer array module for {type(x)}.")
        namespace = fallback
    return namespace
47
48
[docs]
def redirect(
    i: pxt.VarName,
    **kwargs: cabc.Callable,
) -> cabc.Callable:
    """
    Change codepath for supplied array backends.

    Some functions/methods cannot be written in module-agnostic fashion. The action of this decorator is summarized
    below:

    * Analyze an array-valued parameter (`x`) of the wrapped function/method (`f`).
    * If `x` lies in one of the supplied array namespaces: re-route execution to the specified function.
    * If `x` lies in none of the supplied array namespaces: execute `f`.

    Parameters
    ----------
    i: VarName
        name of the array-like variable in `f` to base dispatch on.
    kwargs: ~collections.abc.Mapping

        * key[:py:class:`str`]: array backend short-name as defined in :py:class:`~pyxu.info.deps.NDArrayInfo`.
        * value[:py:class:`collections.abc.Callable`]: function/method to dispatch to.

    Raises
    ------
    ValueError
        (At call time of the decorated function.) If the call arguments cannot be bound to `f`'s signature, if `i`
        is not a parameter of `f`, or (presumably, via :py:meth:`~pyxu.info.deps.NDArrayInfo.from_obj`) if the
        value of `i` is not a recognized array type.

    Notes
    -----
    Auto-dispatch via :py:func:`redirect` assumes the dispatcher/dispatchee have the same parameterization, i.e.:

    * if `f` is a function -> dispatch possible to another callable with identical signature (i.e., function or
      staticmethod)
    * if `f` is a staticmethod -> dispatch possible to another callable with identical signature (i.e., function or
      staticmethod)
    * if `f` is an instance-method -> dispatch to another instance-method of the class with identical signature.

    When no dispatch takes place, `f` is called with the exact positional/keyword arguments it received.

    Example
    -------
    .. code-block:: python3

       def f(x, y): return "f"

       @redirect('x', NUMPY=f)    # if 'x' is of type NDArrayInfo.NUMPY, i.e. has
       def g(x, y): return "g"    # short-name 'NUMPY' -> reroute execution to `f`

       x1 = np.arange(5)
       x2 = da.array(x1)
       y = 1
       g(x1, y), g(x2, y)  # 'f', 'g'
    """

    def decorator(func: cabc.Callable) -> cabc.Callable:
        @functools.wraps(func)
        def wrapper(*ARGS, **KWARGS):
            # Bind the call arguments to parameter names so `i` can be looked up by name.
            try:
                func_args = pxm.parse_params(func, *ARGS, **KWARGS)
            except Exception as e:
                error_msg = f"Could not parameterize {func}()."
                raise ValueError(error_msg) from e

            if i not in func_args:
                error_msg = f"Parameter[{i}] not part of {func.__qualname__}() parameter list."
                raise ValueError(error_msg)

            ndi = pxd.NDArrayInfo.from_obj(func_args[i])
            if (alt_func := kwargs.get(ndi.name)) is not None:
                # Dispatchee shares `func`'s parameterization: forward the bound arguments by name.
                return alt_func(**func_args)

            # No alternative registered for this backend: forward the call verbatim.
            # Re-passing the bound arguments by keyword would break positional-only and
            # variadic (*args / **kwargs) parameters of `func`.
            return func(*ARGS, **KWARGS)

        return wrapper

    return decorator
121
122
[docs]
def compute(*args, mode: str = "compute", **kwargs):
    r"""
    Force computation of Dask collections.

    Parameters
    ----------
    \*args: object, list
        Any number of objects. If it is a dask object, it is evaluated and the result is returned. Non-dask arguments
        are passed through unchanged. Python collections are traversed to find/evaluate dask objects within. (Use
        `traverse` =False to disable this behavior.)
    mode: str
        Dask evaluation strategy: compute or persist. Case-insensitive; surrounding whitespace is ignored.
    \*\*kwargs: dict
        Extra keyword parameters forwarded to :py:func:`dask.compute` or :py:func:`dask.persist`.

    Returns
    -------
    \*cargs: object, list
        Evaluated objects. Non-dask arguments are passed through unchanged. If a single positional argument is given,
        its evaluated form is returned directly (not wrapped in a sequence).

    Raises
    ------
    ValueError
        If `mode` is not one of {compute, persist}.
    """
    strategies = dict(compute=dask.compute, persist=dask.persist)
    try:
        func = strategies[mode.strip().lower()]
    except (AttributeError, KeyError, TypeError):
        # Non-string or unknown `mode`: report a single explicit error instead of the
        # internal lookup failure (suppress the implicit exception context).
        raise ValueError(f"mode: expected compute/persist, got {mode}.") from None

    cargs = func(*args, **kwargs)
    # One result is produced per positional argument: unwrap when only one object was passed.
    if len(args) == 1:
        cargs = cargs[0]
    return cargs
153
154
[docs]
def to_NUMPY(x: pxt.NDArray) -> pxt.NDArray:
    """
    Convert an array from a specific backend to NUMPY.

    Parameters
    ----------
    x: NDArray
        Array to be converted.

    Returns
    -------
    y: NDArray
        Array with NumPy backend.

    Raises
    ------
    ValueError
        If the backend of `x` is not recognized, or has no conversion rule defined here.

    Notes
    -----
    * This function is a no-op if the array is already a NumPy array.
    * Dask arrays are evaluated via :py:func:`compute`. If the evaluated result is itself a CuPy array (i.e. the Dask
      array was backed by CuPy chunks), it is further converted with ``.get()``, as for CuPy inputs.
    """
    N = pxd.NDArrayInfo
    ndi = N.from_obj(x)
    if ndi == N.NUMPY:
        y = x
    elif ndi == N.DASK:
        y = compute(x)
        try:
            # Dask arrays with CuPy chunks compute to CuPy arrays: bring them to NumPy too.
            if N.from_obj(y) == N.CUPY:
                y = y.get()
        except ValueError:
            # Result is not an array type NDArrayInfo recognizes (e.g. a NumPy scalar
            # produced by a 0-d computation): already host-side, keep as-is.
            pass
    elif ndi == N.CUPY:
        y = x.get()
    else:
        msg = f"Dev-action required: define behaviour for {ndi}."
        raise ValueError(msg)
    return y