import functools
import types
import warnings

import packaging.version as pkgv

import pyxu.abc as pxa
import pyxu.info.deps as pxd
import pyxu.info.ptype as pxt
import pyxu.info.warning as pxw
import pyxu.operator.interop.source as px_src
import pyxu.runtime as pxrt
import pyxu.util as pxu

jax = pxu.import_module("jax", fail_on_error=False)
if jax is None:
    import typing as typ

    JaxArray = typ.TypeVar("JaxArray", bound="jax.Array")
else:
    version = pkgv.Version(jax.__version__)
    supported = pxd.JAX_SUPPORT
    assert supported["min"] <= version < supported["max"]

    JaxArray = jax.Array
    import jax.numpy as jnp


__all__ = [
    "from_jax",
]


def _from_jax(
    x: JaxArray,
    xp: pxt.ArrayModule = None,
) -> pxt.NDArray:
    """
    JAX -> NumPy/CuPy conversion.

    The transform is always zero-copy, but it is not easy to check this condition for all array types (contiguous,
    views, etc.) and backends (NUMPY, CUPY).

    [More info] https://github.com/google/jax/issues/1961#issuecomment-875773326
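
    Example (illustrative sketch; ``jnp`` denotes :py:mod:`jax.numpy`, and a CPU-backed array with NumPy as the
    default backend is assumed)::

        x = jnp.arange(5.0)
        y = _from_jax(x)  # NumPy array wrapping the same buffer (no copy expected)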
45 """
46 N = pxd.NDArrayInfo # shorthand
47
48 if xp is None:
49 xp = N.default().module()
50
51 if xp not in (N.NUMPY.module(), N.CUPY.module()):
52 raise pxw.BackendWarning("Only NumPy/CuPy inputs are supported.")
53
54 y = xp.asarray(x)
55 return y
56
57
def _to_jax(x: pxt.NDArray, enable_warnings: bool = True) -> JaxArray:
    """
    NumPy/CuPy -> JAX conversion.

    Conversion is zero-copy when possible, i.e. when conditions such as 16-byte alignment and correct device
    placement are met.

    [More info] https://github.com/google/jax/issues/4486#issuecomment-735842976
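
    Example (illustrative sketch; ``np`` denotes NumPy, and JAX's default 32-bit mode is assumed)::

        x = np.arange(5, dtype=np.float32)
        y = _to_jax(x)  # float32 -> float32: typically zero-copy

        x = np.arange(5, dtype=np.float64)
        y = _to_jax(x)  # float64 -> float32 under 32-bit mode: a copy is made and a warning is emitted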
65 """
66 N = pxd.NDArrayInfo # shorthand
67 W, cW = pxrt.Width, pxrt.CWidth # shorthand
68
69 ndi = N.from_obj(x)
70 if ndi == N.DASK:
71 raise pxw.BackendWarning("DASK inputs are unsupported.")
72
73 supported_dtype = set(w.value for w in W) | set(w.value for w in cW)
74 if x.dtype not in supported_dtype:
75 msg = "For safety reasons, _to_jax() only accepts pyxu.runtime.[C]Width-supported dtypes."
76 raise pxw.PrecisionWarning(msg)
77
78 xp = ndi.module()
79 if ndi == N.NUMPY:
80 dev_type = "cpu"
81 f_wrap = jnp.asarray
82 elif ndi == N.CUPY:
83 dev_type = "gpu"
84 x = xp.require(x, requirements="C") # JAX-DLPACK only supports contiguous arrays [2023.04.05]
85 f_wrap = jnp.from_dlpack
86 else:
87 raise ValueError("Unknown NDArray category.")
88 dev = jax.devices(dev_type)[0]
89 with jax.default_device(dev):
90 y = f_wrap(x)
91
92 same_dtype = x.dtype == y.dtype
93 same_mem = xp.byte_bounds(x)[0] == y.addressable_data(0).unsafe_buffer_pointer()
94 if not (same_dtype and same_mem) and enable_warnings:
95 msg = "\n".join(
96 [
97 "_to_jax(): a zero-copy conversion did not take place.",
98 "[More info] https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision",
99 "[More info] https://github.com/google/jax/issues/4486#issuecomment-735842976",
100 ]
101 )
102 warnings.warn(msg, pxw.PrecisionWarning)
103 return y
104
105
def from_jax(
    cls: pxt.OpC,
    dim_shape: pxt.NDArrayShape,
    codim_shape: pxt.NDArrayShape,
    vectorize: pxt.VarName = frozenset(),
    jit: bool = False,
    enable_warnings: bool = True,
    **kwargs,
) -> pxt.OpT:
    r"""
    Define an :py:class:`~pyxu.abc.Operator` from JAX functions.

    Parameters
    ----------
    cls: OpC
        Operator sub-class to instantiate.
    dim_shape: NDArrayShape
        Operator domain shape (M1,...,MD).
    codim_shape: NDArrayShape
        Operator co-domain shape (N1,...,NK).
    kwargs: dict
        (k[str], v[callable]) pairs to use as arithmetic methods.

        Keys are restricted to the following arithmetic methods:

        .. code-block:: python3

           apply(), grad(), prox(), pinv(), adjoint()

        Omitted arithmetic methods default to those provided by `cls`, or are auto-inferred via auto-diff rules.
    vectorize: VarName
        Arithmetic methods to vectorize.

        `vectorize` is useful if an arithmetic method provided to `kwargs` does not support stacking dimensions.
    jit: bool
        If ``True``, JIT-compile JAX-backed arithmetic methods for better performance.
    enable_warnings: bool
        If ``True``, emit warnings in case of precision/zero-copy issues.

    Returns
    -------
    op: OpT
        Pyxu-compliant operator :math:`A: \mathbb{R}^{M_{1} \times\cdots\times M_{D}} \to \mathbb{R}^{N_{1}
        \times\cdots\times N_{K}}`.


    Notes
    -----
    * If provided, arithmetic methods must abide exactly to the interface below:

      .. code-block:: python3

         def apply(arr: jax.Array) -> jax.Array                 # (..., M1,...,MD) -> (..., N1,...,NK)
         def grad(arr: jax.Array) -> jax.Array                  # (..., M1,...,MD) -> (..., M1,...,MD)
         def adjoint(arr: jax.Array) -> jax.Array               # (..., N1,...,NK) -> (..., M1,...,MD)
         def prox(arr: jax.Array, tau: pxt.Real) -> jax.Array   # (..., M1,...,MD) -> (..., M1,...,MD)
         def pinv(arr: jax.Array, damp: pxt.Real) -> jax.Array  # (..., N1,...,NK) -> (..., M1,...,MD)

      Moreover, the methods above **must** accept stacking dimensions in ``arr``. If this does not hold, consider
      populating `vectorize`.

    * Auto-vectorization consists in decorating `kwargs`-specified arithmetic methods with
      :py:func:`jax.numpy.vectorize`.
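
      For instance (sketch, assuming a 1D domain/co-domain), vectorizing ``apply`` roughly amounts to:

      .. code-block:: python3

         apply = jax.numpy.vectorize(apply, signature="(m0)->(n0)")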

    * Note that JAX enforces `32-bit arithmetic
      <https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision>`_ by default,
      and this constraint cannot be changed at runtime. As such, to allow zero-copy transfers between JAX and
      NumPy/CuPy arrays, it is advised to perform computations in single-precision mode.
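
      For example (sketch; ``op`` denotes an operator created with :py:func:`~pyxu.operator.interop.from_jax`, and
      ``np`` denotes NumPy):

      .. code-block:: python3

         x = x.astype(np.float32)  # single-precision inputs avoid a float64 -> float32 copy at the JAX boundary
         y = op.apply(x)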

    * Inferred arithmetic methods are not JIT-ed by default since the operation is error-prone depending on how
      :py:meth:`~pyxu.abc.Map.apply` is defined. If the :py:meth:`~pyxu.abc.Map.apply` supplied to
      :py:func:`~pyxu.operator.interop.from_jax` is JIT-friendly, then consider enabling `jit`.
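
      For example (sketch, reusing ``j_apply`` and the setup from the example below):

      .. code-block:: python3

         op = pxi.from_jax(
             cls=pxa.DiffMap,
             dim_shape=2,
             codim_shape=3,
             vectorize="apply",
             jit=True,  # j_apply() is jax.jit-friendly -> JIT-compile all JAX-backed methods.
             apply=j_apply,
         )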

    Examples
    --------
    Create the custom differential map :math:`f: \mathbb{R}^{2} \to \mathbb{R}^{3}`:

    .. math::

       f(x, y) =
       \left[
           \sin(x) + \cos(y),
           \cos(x) - \sin(y),
           \sin(x) + \cos(x)
       \right]

    .. code-block:: python3

       import pyxu.abc as pxa
       import pyxu.runtime as pxrt
       import pyxu.operator.interop as pxi
       import jax, jax.numpy as jnp
       import numpy as np

       @jax.jit
       def j_apply(arr: jax.Array) -> jax.Array:
           x, y = arr[0], arr[1]
           o1 = jnp.sin(x) + jnp.cos(y)
           o2 = jnp.cos(x) - jnp.sin(y)
           o3 = jnp.sin(x) + jnp.cos(x)
           out = jnp.r_[o1, o2, o3]
           return out

       op = pxi.from_jax(
           cls=pxa.DiffMap,
           dim_shape=2,
           codim_shape=3,
           vectorize="apply",  # j_apply() does not work on stacked inputs
                               # --> let JAX figure it out automatically.
           apply=j_apply,
       )

       rng = np.random.default_rng(0)
       x = rng.normal(size=(5,3,4,2))
       y1 = op.apply(x)  # (5,3,4,3)

       x = rng.normal(size=(2,))
       opJ = op.jacobian(x)  # JAX auto-infers the Jacobian for you.

       v = rng.normal(size=(5,2))
       w = rng.normal(size=(4,3))
       y2 = opJ.apply(v)    # (5,3)
       y3 = opJ.adjoint(w)  # (4,2)
    """
    if isinstance(vectorize, str):
        vectorize = (vectorize,)
    vectorize = frozenset(vectorize)

    src = _FromJax(
        cls=cls,
        dim_shape=dim_shape,
        codim_shape=codim_shape,
        vectorize=vectorize,
        jit=bool(jit),
        enable_warnings=bool(enable_warnings),
        **kwargs,
    )
    op = src.op()
    return op


class _FromJax(px_src._FromSource):
    # supported methods in __init__(**kwargs)
    _meth = frozenset({"apply", "grad", "prox", "pinv", "adjoint"})

    def __init__(  # See from_jax() for a detailed description.
        self,
        cls: pxt.OpC,
        dim_shape: pxt.NDArrayShape,
        codim_shape: pxt.NDArrayShape,
        vectorize: frozenset[str],
        jit: bool,
        enable_warnings: bool,
        **kwargs,
    ):
        super().__init__(
            cls=cls,
            dim_shape=dim_shape,
            codim_shape=codim_shape,
            embed=dict(),  # jax-funcs are state-free.
            vectorize=vectorize,
            **kwargs,
        )

        self._jit = jit
        self._enable_warnings = enable_warnings

        # Only a subset of arithmetic methods allowed from from_jax().
        if not (set(self._kwargs) <= self._meth):
            msg_head = "Unsupported arithmetic methods:"
            unsupported = set(self._kwargs) - self._meth
            msg_tail = ", ".join([f"{name}()" for name in unsupported])
            raise ValueError(f"{msg_head} {msg_tail}")

    def op(self) -> pxt.OpT:
        # Idea: modify `**kwargs` from constructor to [when required]:
        #   1. auto-define omitted methods.        [_infer_missing()]
        #   2. auto-vectorize via vmap().          [_auto_vectorize()]
        #   3. JIT & JAX<>NumPy/CuPy conversions.  [_interface()]
        self._infer_missing()
        self._auto_vectorize()
        j_state, kwargs = self._interface()

        _op = px_src.from_source(
            cls=self._op.__class__,
            dim_shape=self._op.dim_shape,
            codim_shape=self._op.codim_shape,
            embed=dict(
                _jax=j_state,
                _enable_warnings=self._enable_warnings,
                _jit=self._jit,
            ),
            # vectorize=None,  # see top-level comment.
            **kwargs,
        )
        return _op

    def _infer_missing(self):
        # The following methods must be auto-inferred if missing from `kwargs`:
        #
        #     grad(), adjoint()
        #
        # Missing methods are auto-inferred via auto-diff rules and added to `_kwargs`.
        # At the end of _infer_missing(), all jax-funcs required for _interface() have been added to `_kwargs`.
        #
        # Notes
        # -----
        # This method does NOT produce vectorized implementations: _auto_vectorize() is responsible for this.
        self._vectorize = set(self._vectorize)  # to allow updates below

        nl_difffunc = all(  # non-linear diff-func
            [
                self._op.has(pxa.Property.DIFFERENTIABLE_FUNCTION),
                not self._op.has(pxa.Property.LINEAR),
            ]
        )
        if nl_difffunc and ("grad" not in self._kwargs):

            def f_grad(arr: JaxArray) -> JaxArray:
                # For a functional f with (1,)-valued output, the VJP evaluated at v = 1 gives \grad_{f}(arr).
                f = self._kwargs["apply"]
                y, f_vjp = jax.vjp(f, arr)
                v = jnp.ones_like(y)
                out = f_vjp(v)[0]  # f_vjp() returns a tuple
                return out

            self._vectorize.add("grad")
            self._kwargs["grad"] = f_grad

        non_selfadj = all(  # linear, but not self-adjoint
            [
                self._op.has(pxa.Property.LINEAR),
                not self._op.has(pxa.Property.LINEAR_SELF_ADJOINT),
            ]
        )
        if non_selfadj and ("adjoint" not in self._kwargs):

            def f_adjoint(arr: JaxArray) -> JaxArray:
                # For linear f, the VJP (evaluated at any base point) is the adjoint; apply it to `arr`.
                f = self._kwargs["apply"]
                x = jnp.zeros_like(arr, shape=self._op.dim_shape)
                _, f_vjp = jax.vjp(f, x)
                out = f_vjp(arr)[0]  # f_vjp() returns a tuple
                return out

            self._vectorize.add("adjoint")
            self._kwargs["adjoint"] = f_adjoint

        self._vectorize = frozenset(self._vectorize)

    def _auto_vectorize(self):
        # Vectorize user-specified [or _infer_missing()-added] arithmetic methods via jax.numpy.vectorize().
        #
        # Modifies `_kwargs` to hold vectorized jax-funcs.
        d_sh = ",".join([f"m{i}" for i in range(self._op.dim_rank)])  # dim_shape
        cd_sh = ",".join([f"n{i}" for i in range(self._op.codim_rank)])  # codim_shape
        vkwargs = dict(  # kwargs to jax.numpy.vectorize()
            apply=dict(signature=f"({d_sh})->({cd_sh})"),
            adjoint=dict(signature=f"({cd_sh})->({d_sh})"),
            grad=dict(signature=f"({d_sh})->({d_sh})"),
            prox=dict(signature=f"({d_sh})->({d_sh})", excluded={1}),
            pinv=dict(signature=f"({cd_sh})->({d_sh})", excluded={1}),
        )

        for name in self._kwargs:
            if name in self._vectorize:
                func = self._kwargs[name]  # necessarily jax_func
                vectorize = functools.partial(
                    jax.numpy.vectorize,
                    **vkwargs[name],
                )
                self._kwargs[name] = vectorize(func)

    def _interface(self):
        # Arithmetic methods supplied in `kwargs`:
        #
        #     * take `jax.Array` inputs
        #     * do not have the `self` parameter. (Reason: to be JIT-compatible.)
        #
        # This method creates modified arithmetic functions to match Pyxu's API, plus the `state` required for them
        # to work.
        #
        # Returns
        # -------
        # j_state: dict
        #     Jax functions referenced by the wrapper arithmetic methods. (See below.)
        #     These functions are JIT-compiled if specified.
        # kwargs: dict
        #     Pyxu-compatible functions which can be submitted to interop.from_source().
        j_state = dict()
        for name, obj in self._kwargs.items():
            if name in self._meth:
                if self._jit:
                    obj = jax.jit(obj)
                j_state[name] = obj  # necessarily a jax_func

        kwargs = dict()
        for name, obj in self._kwargs.items():
            func = getattr(self.__class__, name)
            kwargs[name] = func  # necessarily a pyxu_func

        # Special cases.
        # (Reason: not in `_kwargs` [c.f. from_jax() docstring], but need to be inferred.)
        for name in ("jacobian", "_quad_spec", "asarray", "_expr"):
            kwargs[name] = getattr(self.__class__, name)

        return j_state, kwargs

    # Wrapper arithmetic methods ----------------------------------------------
    def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
        j_arr = _to_jax(arr, enable_warnings=self._enable_warnings)
        func = self._jax["apply"]
        dev = j_arr.devices().pop()
        with jax.default_device(dev):
            j_out = func(j_arr)
        out = _from_jax(j_out, xp=pxu.get_array_module(arr))
        return out

    def grad(self, arr: pxt.NDArray) -> pxt.NDArray:
        j_arr = _to_jax(arr, enable_warnings=self._enable_warnings)
        func = self._jax["grad"]
        dev = j_arr.devices().pop()
        with jax.default_device(dev):
            j_out = func(j_arr)
        out = _from_jax(j_out, xp=pxu.get_array_module(arr))
        return out

    def adjoint(self, arr: pxt.NDArray) -> pxt.NDArray:
        j_arr = _to_jax(arr, enable_warnings=self._enable_warnings)
        func = self._jax["adjoint"]
        dev = j_arr.devices().pop()
        with jax.default_device(dev):
            j_out = func(j_arr)
        out = _from_jax(j_out, xp=pxu.get_array_module(arr))
        return out

    def prox(self, arr: pxt.NDArray, tau: pxt.Real) -> pxt.NDArray:
        j_arr = _to_jax(arr, enable_warnings=self._enable_warnings)
        func = self._jax["prox"]
        dev = j_arr.devices().pop()
        with jax.default_device(dev):
            j_out = func(j_arr, tau)  # positional args only if auto-vectorized.
        out = _from_jax(j_out, xp=pxu.get_array_module(arr))
        return out

    def pinv(self, arr: pxt.NDArray, damp: pxt.Real, **kwargs) -> pxt.NDArray:
        j_arr = _to_jax(arr, enable_warnings=self._enable_warnings)
        func = self._jax["pinv"]
        dev = j_arr.devices().pop()
        with jax.default_device(dev):
            j_out = func(j_arr, damp)  # positional args only if auto-vectorized.
        out = _from_jax(j_out, xp=pxu.get_array_module(arr))
        return out

    def jacobian(self, arr: pxt.NDArray) -> pxt.OpT:
        try:
            # Use the class' method if available ...
            klass = self.__class__
            op = klass.jacobian(self, arr)
        except NotImplementedError:
            # ... and fallback to auto-inference if undefined.
            f = self._jax["apply"]
            j_arr = _to_jax(arr, enable_warnings=self._enable_warnings)

            # define forward: [1] explains why jvp() is a better fit than linearize().
            # [1] https://jax.readthedocs.io/en/latest/_autosummary/jax.linearize.html
            _fwd = functools.partial(jax.jvp, f, (j_arr,))
            f_fwd = lambda arr: _fwd((arr,))[1]  # jax returns a tuple

            # define adjoint: [2] explains benefits of linear_transpose() over vjp().
            # [2] https://jax.readthedocs.io/en/latest/_autosummary/jax.linear_transpose.html
            hint = types.SimpleNamespace(shape=self.dim_shape, dtype=arr.dtype)
            _adj = jax.linear_transpose(f_fwd, hint)
            f_adj = lambda arr: _adj(arr)[0]  # jax returns a tuple

            klass = pxa.LinFunc if (self.codim_shape == (1,)) else pxa.LinOp
            op = from_jax(
                cls=klass,
                dim_shape=self.dim_shape,
                codim_shape=self.codim_shape,
                vectorize=("apply", "adjoint"),
                jit=self._jit,
                enable_warnings=self._enable_warnings,
                apply=f_fwd,
                adjoint=f_adj,
            )
        return op

    def _quad_spec(self):
        if self.has(pxa.Property.QUADRATIC):
            # auto-infer (Q, c, t)
            def _grad(arr: JaxArray) -> JaxArray:
                # Just like jax.grad(f)(arr), but works with (1,)-valued functions.
                # [jax.grad(f) expects scalar outputs.]
                f = self._jax["apply"]
                y, f_vjp = jax.vjp(f, arr)
                v = jnp.ones_like(y)
                out = f_vjp(v)[0]  # f_vjp() returns a tuple
                return out

            # vectorize & JIT internal function
            d_sh = ",".join([f"m{i}" for i in range(self.dim_rank)])  # dim_shape
            _grad = jnp.vectorize(_grad, signature=f"({d_sh})->({d_sh})")
            if self._jit:
                _grad = jax.jit(_grad)

            def Q_apply(arr: JaxArray) -> JaxArray:
                # \grad_{f}(x) = Q x + c = Q x + \grad_{f}(0)
                # ==> Q x = \grad_{f}(x) - \grad_{f}(0)
                z = jnp.zeros_like(arr)
                out = _grad(arr) - _grad(z)
                return out

            def c_apply(arr: JaxArray) -> JaxArray:
                # c = \grad_{f}(0), hence c(arr) = <\grad_{f}(0), arr>
                z = jnp.zeros_like(arr)
                c = _grad(z)
                out = jnp.sum(c * arr)[jnp.newaxis]
                return out

            Q = from_jax(
                cls=pxa.PosDefOp,
                dim_shape=self.dim_shape,
                codim_shape=self.dim_shape,
                vectorize="apply",
                jit=self._jit,
                enable_warnings=self._enable_warnings,
                apply=Q_apply,
            )
            c = from_jax(
                cls=pxa.LinFunc,
                dim_shape=self.dim_shape,
                codim_shape=1,
                vectorize="apply",
                jit=self._jit,
                enable_warnings=self._enable_warnings,
                apply=c_apply,
            )

            # `t` can be computed using any backend, so we compute it on CPU.
            f = self._jax["apply"]
            with jax.default_device(jax.devices("cpu")[0]):
                t = float(f(jnp.zeros(self.dim_shape)))

            return (Q, c, t)
        else:
            raise NotImplementedError

    def asarray(self, **kwargs) -> pxt.NDArray:
        if self.has(pxa.Property.LINEAR):
            # JAX operators don't accept DASK inputs: cannot call Lin[Op,Func].asarray() with user-specified `xp` value.
            # -> We arbitrarily perform computations using the NUMPY backend, then cast as needed.
            N = pxd.NDArrayInfo  # shorthand
            dtype = kwargs.get("dtype", pxrt.Width.DOUBLE.value)
            xp = kwargs.get("xp", N.default().module())

            klass = self.__class__
            A = klass.asarray(self, dtype=dtype, xp=N.NUMPY.module())

            # Not the most efficient method, but fail-proof
            B = xp.array(A, dtype=dtype)
            return B
        else:
            raise NotImplementedError

    def _expr(self) -> tuple:
        # show which arithmetic methods are backed by jax-funcs
        jax_funcs = tuple(self._jax.keys())
        return ("from_jax", *jax_funcs)