Source code for pyxu.util.array_module

  1import collections.abc as cabc
  2import functools
  3
  4import dask
  5
  6import pyxu.info.deps as pxd
  7import pyxu.info.ptype as pxt
  8import pyxu.util.misc as pxm
  9
 10__all__ = [
 11    "compute",
 12    "get_array_module",
 13    "redirect",
 14    "to_NUMPY",
 15]
 16
 17
def get_array_module(x, fallback: pxt.ArrayModule = None) -> pxt.ArrayModule:
    """
    Get the array namespace corresponding to a given object.

    Parameters
    ----------
    x: object
        Any object compatible with the interface of NumPy arrays.
    fallback: ArrayModule
        Module returned when no namespace can be inferred from `x`. If omitted (None), failing to infer a namespace
        is an error.

    Returns
    -------
    namespace: ArrayModule
        The namespace to use to manipulate `x`, or `fallback` if inference failed and `fallback` was provided.

    Raises
    ------
    ValueError
        If no namespace can be inferred from `x` and no `fallback` is given.
    """
    # Both the backend lookup and the module lookup may fail; either failure means "unknown".
    try:
        xp = pxd.NDArrayInfo.from_obj(x).module()
    except ValueError:
        xp = None

    if xp is None:
        if fallback is None:
            raise ValueError(f"Could not infer array module for {type(x)}.")
        xp = fallback
    return xp
47 48
def redirect(
    i: pxt.VarName,
    **kwargs: cabc.Callable,
) -> cabc.Callable:
    """
    Change codepath for supplied array backends.

    Some functions/methods cannot be written in module-agnostic fashion. The action of this decorator is summarized
    below:

    * Analyze an array-valued parameter (`x`) of the wrapped function/method (`f`).
    * If `x` lies in one of the supplied array namespaces: re-route execution to the specified function.
    * If `x` lies in none of the supplied array namespaces: execute `f`.

    Parameters
    ----------
    i: VarName
        name of the array-like variable in `f` to base dispatch on.
    kwargs: ~collections.abc.Mapping

        * key[:py:class:`str`]: array backend short-name as defined in :py:class:`~pyxu.info.deps.NDArrayInfo`.
        * value[:py:class:`collections.abc.Callable`]: function/method to dispatch to.

    Raises
    ------
    ValueError
        * (at decoration time) if a key of `kwargs` is not a member name of :py:class:`~pyxu.info.deps.NDArrayInfo`.
        * (at call time) if the call cannot be bound to `f`'s signature, or if `i` is not one of `f`'s parameters.

    Notes
    -----
    Auto-dispatch via :py:func:`redirect` assumes the dispatcher/dispatchee have the same parameterization, i.e.:

    * if `f` is a function -> dispatch possible to another callable with identical signature (i.e., function or
      staticmethod)
    * if `f` is a staticmethod -> dispatch possible to another callable with identical signature (i.e., function or
      staticmethod)
    * if `f` is an instance-method -> dispatch to another instance-method of the class with identical signature.

    Example
    -------
    .. code-block:: python3

       def f(x, y): return "f"

       @redirect('x', NUMPY=f)    # if 'x' is of type NDArrayInfo.NUMPY, i.e. has
       def g(x, y): return "g"    # short-name 'NUMPY' -> reroute execution to `f`

       x1 = np.arange(5)
       x2 = da.array(x1)
       y = 1
       g(x1, y), g(x2, y)  # 'f', 'g'
    """
    # Reject unknown backend names up-front: dispatch compares against `NDArrayInfo.name`, so a misspelled key
    # would otherwise be silently ignored forever.
    known = {ndi.name for ndi in pxd.NDArrayInfo}
    if unknown := sorted(set(kwargs) - known):
        raise ValueError(f"Unknown array backend(s) {unknown}: expected names among {sorted(known)}.")

    def decorator(func: cabc.Callable) -> cabc.Callable:
        @functools.wraps(func)
        def wrapper(*ARGS, **KWARGS):
            # Normalize positional/keyword arguments into a name -> value mapping so both `func` and the
            # dispatchee can be called with keywords only.
            try:
                func_args = pxm.parse_params(func, *ARGS, **KWARGS)
            except Exception as e:
                raise ValueError(f"Could not parameterize {func.__qualname__}().") from e

            if i not in func_args:
                raise ValueError(f"Parameter[{i}] not part of {func.__qualname__}() parameter list.")

            ndi = pxd.NDArrayInfo.from_obj(func_args[i])
            alt_func = kwargs.get(ndi.name)
            target = func if alt_func is None else alt_func
            return target(**func_args)

        return wrapper

    return decorator
121 122
def compute(*args, mode: str = "compute", **kwargs):
    r"""
    Force computation of Dask collections.

    Parameters
    ----------
    \*args: object, list
        Any number of objects. If it is a dask object, it is evaluated and the result is returned. Non-dask arguments
        are passed through unchanged. Python collections are traversed to find/evaluate dask objects within. (Use
        `traverse` =False to disable this behavior.)
    mode: str
        Dask evaluation strategy: compute or persist. Leading/trailing whitespace and letter case are ignored.
    \*\*kwargs: dict
        Extra keyword parameters forwarded to :py:func:`dask.compute` or :py:func:`dask.persist`.

    Returns
    -------
    \*cargs: object, list
        Evaluated objects. Non-dask arguments are passed through unchanged. If exactly one object was supplied, it
        is returned directly instead of inside a 1-tuple.

    Raises
    ------
    ValueError
        If `mode` is not a string naming one of the supported strategies.
    """
    try:
        key = mode.strip().lower()
    except AttributeError:
        key = None  # non-string `mode`: rejected below
    # Validate outside any `except` clause so the error carries no unrelated chained context.
    if key not in ("compute", "persist"):
        raise ValueError(f"mode: expected compute/persist, got {mode}.")

    func = getattr(dask, key)  # dask.compute or dask.persist
    cargs = func(*args, **kwargs)
    if len(args) == 1:
        # dask.compute/persist always return a tuple; unwrap the single-argument case.
        cargs = cargs[0]
    return cargs
153 154
def to_NUMPY(x: pxt.NDArray) -> pxt.NDArray:
    """
    Convert an array from a specific backend to NUMPY.

    Parameters
    ----------
    x: NDArray
        Array to be converted.

    Returns
    -------
    y: NDArray
        Array with NumPy backend.

    Raises
    ------
    ValueError
        If the backend of `x` cannot be identified, or no conversion rule is defined for it.

    Notes
    -----
    This function is a no-op if the array is already a NumPy array.
    Dask arrays are evaluated first; if the evaluated result is a CuPy array (Dask array with CuPy chunks), it is
    then transferred to the host as well.
    """
    N = pxd.NDArrayInfo
    ndi = N.from_obj(x)
    if ndi == N.NUMPY:
        y = x
    elif ndi == N.DASK:
        y = compute(x)
        # compute() yields whatever the chunks are made of: CuPy-backed Dask arrays evaluate to CuPy arrays, which
        # still have to be moved to the host to honour the NumPy-output contract.
        try:
            on_device = N.from_obj(y) == N.CUPY
        except ValueError:
            # Result is not a recognized array object (presumably e.g. a NumPy scalar from a reduction):
            # nothing further to convert.
            on_device = False
        if on_device:
            y = y.get()
    elif ndi == N.CUPY:
        y = x.get()
    else:
        msg = f"Dev-action required: define behaviour for {ndi}."
        raise ValueError(msg)
    return y