Source code for pyxu.operator.interop.jax

import functools
import types
import warnings

import packaging.version as pkgv

import pyxu.abc as pxa
import pyxu.info.deps as pxd
import pyxu.info.ptype as pxt
import pyxu.info.warning as pxw
import pyxu.operator.interop.source as px_src
import pyxu.runtime as pxrt
import pyxu.util as pxu

jax = pxu.import_module("jax", fail_on_error=False)
if jax is None:
    import typing as typ

    JaxArray = typ.TypeVar("JaxArray", bound="jax.Array")
else:
    version = pkgv.Version(jax.__version__)
    supported = pxd.JAX_SUPPORT
    assert supported["min"] <= version < supported["max"]

    JaxArray = jax.Array
    import jax.numpy as jnp


__all__ = [
    "from_jax",
]

def _from_jax(
    x: JaxArray,
    xp: pxt.ArrayModule = None,
) -> pxt.NDArray:
    """
    JAX -> NumPy/CuPy conversion.

    The transform is always zero-copy, but it is not easy to check this condition for all array types (contiguous,
    views, etc.) and backends (NUMPY, CUPY).

    [More info] https://github.com/google/jax/issues/1961#issuecomment-875773326
    """
    N = pxd.NDArrayInfo  # shorthand

    if xp is None:
        xp = N.default().module()

    if xp not in (N.NUMPY.module(), N.CUPY.module()):
        raise pxw.BackendWarning("Only NumPy/CuPy inputs are supported.")

    y = xp.asarray(x)
    return y

def _to_jax(x: pxt.NDArray, enable_warnings: bool = True) -> JaxArray:
    """
    NumPy/CuPy -> JAX conversion.

    Conversion is zero-copy when possible, i.e. 16-byte alignment, on the right device, etc.

    [More info] https://github.com/google/jax/issues/4486#issuecomment-735842976
    """
    N = pxd.NDArrayInfo  # shorthand
    W, cW = pxrt.Width, pxrt.CWidth  # shorthand

    ndi = N.from_obj(x)
    if ndi == N.DASK:
        raise pxw.BackendWarning("DASK inputs are unsupported.")

    supported_dtype = set(w.value for w in W) | set(w.value for w in cW)
    if x.dtype not in supported_dtype:
        msg = "For safety reasons, _to_jax() only accepts pyxu.runtime.[C]Width-supported dtypes."
        raise pxw.PrecisionWarning(msg)

    xp = ndi.module()
    if ndi == N.NUMPY:
        dev_type = "cpu"
        f_wrap = jnp.asarray
    elif ndi == N.CUPY:
        dev_type = "gpu"
        x = xp.require(x, requirements="C")  # JAX-DLPACK only supports contiguous arrays [2023.04.05]
        f_wrap = jnp.from_dlpack
    else:
        raise ValueError("Unknown NDArray category.")
    dev = jax.devices(dev_type)[0]
    with jax.default_device(dev):
        y = f_wrap(x)

    same_dtype = x.dtype == y.dtype
    same_mem = xp.byte_bounds(x)[0] == y.addressable_data(0).unsafe_buffer_pointer()
    if not (same_dtype and same_mem) and enable_warnings:
        msg = "\n".join(
            [
                "_to_jax(): a zero-copy conversion did not take place.",
                "[More info] https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision",
                "[More info] https://github.com/google/jax/issues/4486#issuecomment-735842976",
            ]
        )
        warnings.warn(msg, pxw.PrecisionWarning)
    return y

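# Illustrative sketch (hypothetical usage, not part of the pyxu module): a
# round-trip through the two helpers above, assuming a NumPy array on CPU.
# With a CuPy input, _to_jax() would instead go through the DLPACK path.
#
#     import numpy as np
#     x = np.arange(6, dtype=np.float32).reshape(2, 3)
#     j = _to_jax(x)           # NumPy -> JAX; zero-copy if alignment/device permit
#     y = _from_jax(j, xp=np)  # JAX -> NumPy; zero-copy
#     assert (x == y).all()
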
def from_jax(
    cls: pxt.OpC,
    dim_shape: pxt.NDArrayShape,
    codim_shape: pxt.NDArrayShape,
    vectorize: pxt.VarName = frozenset(),
    jit: bool = False,
    enable_warnings: bool = True,
    **kwargs,
) -> pxt.OpT:
    r"""
    Define an :py:class:`~pyxu.abc.Operator` from JAX functions.

    Parameters
    ----------
    cls: OpC
        Operator sub-class to instantiate.
    dim_shape: NDArrayShape
        Operator domain shape (M1,...,MD).
    codim_shape: NDArrayShape
        Operator co-domain shape (N1,...,NK).
    kwargs: dict
        (k[str], v[callable]) pairs to use as arithmetic methods.

        Keys are restricted to the following arithmetic methods:

        .. code-block:: python3

           apply(), grad(), prox(), pinv(), adjoint()

        Omitted arithmetic methods default to those provided by `cls`, or are auto-inferred via auto-diff rules.
    vectorize: VarName
        Arithmetic methods to vectorize.

        `vectorize` is useful if an arithmetic method provided to `kwargs` does not support stacking dimensions.
    jit: bool
        If ``True``, JIT-compile JAX-backed arithmetic methods for better performance.
    enable_warnings: bool
        If ``True``, emit warnings in case of precision/zero-copy issues.

    Returns
    -------
    op: OpT
        Pyxu-compliant operator :math:`A: \mathbb{R}^{M_{1} \times\cdots\times M_{D}} \to \mathbb{R}^{N_{1}
        \times\cdots\times N_{K}}`.


    Notes
    -----
    * If provided, arithmetic methods must abide exactly by the interface below:

      .. code-block:: python3

         def apply(arr: jax.Array) -> jax.Array                 # (..., M1,...,MD) -> (..., N1,...,NK)
         def grad(arr: jax.Array) -> jax.Array                  # (..., M1,...,MD) -> (..., M1,...,MD)
         def adjoint(arr: jax.Array) -> jax.Array               # (..., N1,...,NK) -> (..., M1,...,MD)
         def prox(arr: jax.Array, tau: pxt.Real) -> jax.Array   # (..., M1,...,MD) -> (..., M1,...,MD)
         def pinv(arr: jax.Array, damp: pxt.Real) -> jax.Array  # (..., N1,...,NK) -> (..., M1,...,MD)

      Moreover, the methods above **must** accept stacking dimensions in ``arr``. If this does not hold, consider
      populating `vectorize`.

    * Auto-vectorization consists in decorating `kwargs`-specified arithmetic methods with
      :py:func:`jax.numpy.vectorize`.

    * Note that JAX enforces `32-bit arithmetic
      <https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision>`_ by default,
      and this constraint cannot be changed at runtime. As such, to allow zero-copy transfers between JAX and
      NumPy/CuPy arrays, it is advised to perform computations in single-precision mode.

    * Inferred arithmetic methods are not JIT-ed by default since the operation is error-prone depending on how
      :py:meth:`~pyxu.abc.Map.apply` is defined. If the :py:meth:`~pyxu.abc.Map.apply` supplied to
      :py:func:`~pyxu.operator.interop.from_jax` is JIT-friendly, then consider enabling `jit`.

    Examples
    --------
    Create the custom differentiable map :math:`f: \mathbb{R}^{2} \to \mathbb{R}^{3}`:

    .. math::

       f(x, y) =
       \left[
           \sin(x) + \cos(y),
           \cos(x) - \sin(y),
           \sin(x) + \cos(x)
       \right]

    .. code-block:: python3

       import pyxu.abc as pxa
       import pyxu.runtime as pxrt
       import pyxu.operator.interop as pxi
       import jax, jax.numpy as jnp
       import numpy as np

       @jax.jit
       def j_apply(arr: jax.Array) -> jax.Array:
           x, y = arr[0], arr[1]
           o1 = jnp.sin(x) + jnp.cos(y)
           o2 = jnp.cos(x) - jnp.sin(y)
           o3 = jnp.sin(x) + jnp.cos(x)
           out = jnp.r_[o1, o2, o3]
           return out

       op = pxi.from_jax(
           cls=pxa.DiffMap,
           dim_shape=2,
           codim_shape=3,
           vectorize="apply",  # j_apply() does not work on stacked inputs
                               # --> let JAX figure it out automatically.
           apply=j_apply,
       )

       rng = np.random.default_rng(0)
       x = rng.normal(size=(5,3,4,2))
       y1 = op.apply(x)  # (5,3,4,3)

       x = rng.normal(size=(2,))
       opJ = op.jacobian(x)  # JAX auto-infers the Jacobian for you.

       v = rng.normal(size=(5,2))
       w = rng.normal(size=(4,3))
       y2 = opJ.apply(v)    # (5,3)
       y3 = opJ.adjoint(w)  # (4,2)
    """
    if isinstance(vectorize, str):
        vectorize = (vectorize,)
    vectorize = frozenset(vectorize)

    src = _FromJax(
        cls=cls,
        dim_shape=dim_shape,
        codim_shape=codim_shape,
        vectorize=vectorize,
        jit=bool(jit),
        enable_warnings=bool(enable_warnings),
        **kwargs,
    )
    op = src.op()
    return op

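# Illustrative sketch (hypothetical usage, not part of the pyxu module): for a
# LinOp sub-class where only apply() is supplied, the adjoint is auto-inferred
# via jax.vjp() (see _FromJax._infer_missing() below).
#
#     import jax.numpy as jnp
#     import numpy as np
#     import pyxu.abc as pxa
#     import pyxu.operator.interop as pxi
#
#     A = jnp.arange(12.0).reshape(3, 4)
#     op = pxi.from_jax(
#         cls=pxa.LinOp,
#         dim_shape=4,
#         codim_shape=3,
#         vectorize="apply",
#         apply=lambda arr: A @ arr,
#     )
#     x = np.ones(4, dtype=np.float32)
#     y = np.ones(3, dtype=np.float32)
#     np.allclose(op.apply(x), np.asarray(A) @ x)      # forward map
#     np.allclose(op.adjoint(y), np.asarray(A).T @ y)  # auto-inferred adjoint
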
class _FromJax(px_src._FromSource):
    # supported methods in __init__(**kwargs)
    _meth = frozenset({"apply", "grad", "prox", "pinv", "adjoint"})

    def __init__(  # See from_jax() for a detailed description.
        self,
        cls: pxt.OpC,
        dim_shape: pxt.NDArrayShape,
        codim_shape: pxt.NDArrayShape,
        vectorize: frozenset[str],
        jit: bool,
        enable_warnings: bool,
        **kwargs,
    ):
        super().__init__(
            cls=cls,
            dim_shape=dim_shape,
            codim_shape=codim_shape,
            embed=dict(),  # jax-funcs are state-free.
            vectorize=vectorize,
            **kwargs,
        )

        self._jit = jit
        self._enable_warnings = enable_warnings

        # Only a subset of arithmetic methods allowed from from_jax().
        if not (set(self._kwargs) <= self._meth):
            msg_head = "Unsupported arithmetic methods:"
            unsupported = set(self._kwargs) - self._meth
            msg_tail = ", ".join([f"{name}()" for name in unsupported])
            raise ValueError(f"{msg_head} {msg_tail}")

    def op(self) -> pxt.OpT:
        # Idea: modify `**kwargs` from constructor to [when required]:
        #   1. auto-define omitted methods.          [_infer_missing()]
        #   2. auto-vectorize via vmap().            [_auto_vectorize()]
        #   3. JIT & JAX<>NumPy/CuPy conversions.    [_interface()]
        self._infer_missing()
        self._auto_vectorize()
        j_state, kwargs = self._interface()

        _op = px_src.from_source(
            cls=self._op.__class__,
            dim_shape=self._op.dim_shape,
            codim_shape=self._op.codim_shape,
            embed=dict(
                _jax=j_state,
                _enable_warnings=self._enable_warnings,
                _jit=self._jit,
            ),
            # vectorize=None,  # see top-level comment.
            **kwargs,
        )
        return _op

    def _infer_missing(self):
        # The following methods must be auto-inferred if missing from `kwargs`:
        #
        #     grad(), adjoint()
        #
        # Missing methods are auto-inferred via auto-diff rules and added to `_kwargs`.
        # At the end of _infer_missing(), all jax-funcs required for _interface() have been added to `_kwargs`.
        #
        # Notes
        # -----
        # This method does NOT produce vectorized implementations: _auto_vectorize() is responsible for this.
        self._vectorize = set(self._vectorize)  # to allow updates below

        nl_difffunc = all(  # non-linear diff-func
            [
                self._op.has(pxa.Property.DIFFERENTIABLE_FUNCTION),
                not self._op.has(pxa.Property.LINEAR),
            ]
        )
        if nl_difffunc and ("grad" not in self._kwargs):

            def f_grad(arr: JaxArray) -> JaxArray:
                f = self._kwargs["apply"]
                y, f_vjp = jax.vjp(f, arr)
                v = jnp.ones_like(y)
                out = f_vjp(v)[0]  # f_vjp() returns a tuple
                return out

            self._vectorize.add("grad")
            self._kwargs["grad"] = f_grad

        non_selfadj = all(  # linear, but not self-adjoint
            [
                self._op.has(pxa.Property.LINEAR),
                not self._op.has(pxa.Property.LINEAR_SELF_ADJOINT),
            ]
        )
        if non_selfadj and ("adjoint" not in self._kwargs):

            def f_adjoint(arr: JaxArray) -> JaxArray:
                f = self._kwargs["apply"]
                x = jnp.zeros_like(arr, shape=self._op.dim_shape)
                _, f_vjp = jax.vjp(f, x)
                out = f_vjp(arr)[0]  # f_vjp() returns a tuple
                return out

            self._vectorize.add("adjoint")
            self._kwargs["adjoint"] = f_adjoint

        self._vectorize = frozenset(self._vectorize)

    def _auto_vectorize(self):
        # Vectorize user-specified [or _infer_missing()-added] arithmetic methods via jax.vmap().
        #
        # Modifies `_kwargs` to hold vectorized jax-funcs.
        d_sh = ",".join([f"m{i}" for i in range(self._op.dim_rank)])  # dim_shape
        cd_sh = ",".join([f"n{i}" for i in range(self._op.codim_rank)])  # codim_shape
        vkwargs = dict(  # kwargs to jax.numpy.vectorize()
            apply=dict(signature=f"({d_sh})->({cd_sh})"),
            adjoint=dict(signature=f"({cd_sh})->({d_sh})"),
            grad=dict(signature=f"({d_sh})->({d_sh})"),
            prox=dict(signature=f"({d_sh})->({d_sh})", excluded={1}),
            pinv=dict(signature=f"({cd_sh})->({d_sh})", excluded={1}),
        )

        for name in self._kwargs:
            if name in self._vectorize:
                func = self._kwargs[name]  # necessarily jax_func
                vectorize = functools.partial(
                    jax.numpy.vectorize,
                    **vkwargs[name],
                )
                self._kwargs[name] = vectorize(func)

    def _interface(self):
        # Arithmetic methods supplied in `kwargs`:
        #
        #   * take `jax.Array` inputs
        #   * do not have the `self` parameter. (Reason: to be JIT-compatible.)
        #
        # This method creates modified arithmetic functions to match Pyxu's API, and `state` required for them to work.
        #
        # Returns
        # -------
        # j_state: dict
        #     Jax functions referenced by wrapper arithmetic methods. (See below.)
        #     These functions are JIT-compiled if specified.
        # kwargs: dict
        #     Pyxu-compatible functions which can be submitted to interop.from_source().
        j_state = dict()
        for name, obj in self._kwargs.items():
            if name in self._meth:
                if self._jit:
                    obj = jax.jit(obj)
                j_state[name] = obj  # necessarily jax_func

        kwargs = dict()
        for name, obj in self._kwargs.items():
            func = getattr(self.__class__, name)
            kwargs[name] = func  # necessarily a pyxu_func

        # Special cases.
        # (Reason: not in `_kwargs` [c.f. from_jax() docstring], but need to be inferred.)
        for name in ("jacobian", "_quad_spec", "asarray", "_expr"):
            kwargs[name] = getattr(self.__class__, name)

        return j_state, kwargs

    # Wrapper arithmetic methods ----------------------------------------------
    def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
        j_arr = _to_jax(arr, enable_warnings=self._enable_warnings)
        func = self._jax["apply"]
        dev = j_arr.devices().pop()
        with jax.default_device(dev):
            j_out = func(j_arr)
        out = _from_jax(j_out, xp=pxu.get_array_module(arr))
        return out

    def grad(self, arr: pxt.NDArray) -> pxt.NDArray:
        j_arr = _to_jax(arr, enable_warnings=self._enable_warnings)
        func = self._jax["grad"]
        dev = j_arr.devices().pop()
        with jax.default_device(dev):
            j_out = func(j_arr)
        out = _from_jax(j_out, xp=pxu.get_array_module(arr))
        return out

    def adjoint(self, arr: pxt.NDArray) -> pxt.NDArray:
        j_arr = _to_jax(arr, enable_warnings=self._enable_warnings)
        func = self._jax["adjoint"]
        dev = j_arr.devices().pop()
        with jax.default_device(dev):
            j_out = func(j_arr)
        out = _from_jax(j_out, xp=pxu.get_array_module(arr))
        return out

    def prox(self, arr: pxt.NDArray, tau: pxt.Real) -> pxt.NDArray:
        j_arr = _to_jax(arr, enable_warnings=self._enable_warnings)
        func = self._jax["prox"]
        dev = j_arr.devices().pop()
        with jax.default_device(dev):
            j_out = func(j_arr, tau)  # positional args only if auto-vectorized.
        out = _from_jax(j_out, xp=pxu.get_array_module(arr))
        return out

    def pinv(self, arr: pxt.NDArray, damp: pxt.Real, **kwargs) -> pxt.NDArray:
        j_arr = _to_jax(arr, enable_warnings=self._enable_warnings)
        func = self._jax["pinv"]
        dev = j_arr.devices().pop()
        with jax.default_device(dev):
            j_out = func(j_arr, damp)  # positional args only if auto-vectorized.
        out = _from_jax(j_out, xp=pxu.get_array_module(arr))
        return out

    def jacobian(self, arr: pxt.NDArray) -> pxt.OpT:
        try:
            # Use the class' method if available ...
            klass = self.__class__
            op = klass.jacobian(self, arr)
        except NotImplementedError:
            # ... and fallback to auto-inference if undefined.
            f = self._jax["apply"]
            j_arr = _to_jax(arr, enable_warnings=self._enable_warnings)

            # define forward: [1] explains why jvp() is a better fit than linearize().
            # [1] https://jax.readthedocs.io/en/latest/_autosummary/jax.linearize.html
            _fwd = functools.partial(jax.jvp, f, (j_arr,))
            f_fwd = lambda arr: _fwd((arr,))[1]  # jax returns a tuple

            # define adjoint: [2] explains benefits of linear_transpose() over vjp().
            # [2] https://jax.readthedocs.io/en/latest/_autosummary/jax.linear_transpose.html
            hint = types.SimpleNamespace(shape=self.dim_shape, dtype=arr.dtype)
            _adj = jax.linear_transpose(f_fwd, hint)
            f_adj = lambda arr: _adj(arr)[0]  # jax returns a tuple

            klass = pxa.LinFunc if (self.codim_shape == (1,)) else pxa.LinOp
            op = from_jax(
                cls=klass,
                dim_shape=self.dim_shape,
                codim_shape=self.codim_shape,
                vectorize=("apply", "adjoint"),
                jit=self._jit,
                enable_warnings=self._enable_warnings,
                apply=f_fwd,
                adjoint=f_adj,
            )
        return op

    def _quad_spec(self):
        if self.has(pxa.Property.QUADRATIC):
            # auto-infer (Q, c, t)
            def _grad(arr: JaxArray) -> JaxArray:
                # Just like jax.grad(f)(arr), but works with (1,)-valued functions.
                # [jax.grad(f) expects scalar outputs.]
                f = self._jax["apply"]
                y, f_vjp = jax.vjp(f, arr)
                v = jnp.ones_like(y)
                out = f_vjp(v)[0]  # f_vjp() returns a tuple
                return out

            # vectorize & JIT internal function
            d_sh = ",".join([f"m{i}" for i in range(self._op.dim_rank)])  # dim_shape
            _grad = jnp.vectorize(_grad, signature=f"({d_sh})->({d_sh})")
            if self._jit:
                _grad = jax.jit(_grad)

            def Q_apply(arr: JaxArray) -> JaxArray:
                # \grad_{f}(x) = Q x + c = Q x + \grad_{f}(0)
                # ==> Q x = \grad_{f}(x) - \grad_{f}(0)
                z = jnp.zeros_like(arr)
                out = _grad(arr) - _grad(z)
                return out

            def c_apply(arr: JaxArray) -> JaxArray:
                z = jnp.zeros_like(arr)
                c = _grad(z)
                out = jnp.sum(c * arr)[jnp.newaxis]
                return out

            Q = from_jax(
                cls=pxa.PosDefOp,
                dim_shape=self.dim_shape,
                codim_shape=self.dim_shape,
                vectorize="apply",
                jit=self._jit,
                enable_warnings=self._enable_warnings,
                apply=Q_apply,
            )
            c = from_jax(
                cls=pxa.LinFunc,
                dim_shape=self.dim_shape,
                codim_shape=1,
                vectorize="apply",
                jit=self._jit,
                enable_warnings=self._enable_warnings,
                apply=c_apply,
            )

            # `t` can be computed using any backend, so we choose NUMPY.
            f = self._jax["apply"]
            with jax.default_device(jax.devices("cpu")[0]):
                t = float(f(jnp.zeros(self.dim_shape)))

            return (Q, c, t)
        else:
            raise NotImplementedError

    def asarray(self, **kwargs) -> pxt.NDArray:
        if self.has(pxa.Property.LINEAR):
            # JAX operators don't accept DASK inputs: cannot call Lin[Op,Func].asarray() with user-specified `xp` value.
            # -> We arbitrarily perform computations using the NUMPY backend, then cast as needed.
            N = pxd.NDArrayInfo  # shorthand
            dtype = kwargs.get("dtype", pxrt.Width.DOUBLE.value)
            xp = kwargs.get("xp", N.default().module())

            klass = self.__class__
            A = klass.asarray(self, dtype=dtype, xp=N.NUMPY.module())

            # Not the most efficient method, but fail-proof
            B = xp.array(A, dtype=dtype)
            return B
        else:
            raise NotImplementedError

    def _expr(self) -> tuple:
        # show which arithmetic methods are backed by jax-funcs
        jax_funcs = tuple(self._jax.keys())
        return ("from_jax", *jax_funcs)
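# Illustrative sketch (hypothetical, not part of the pyxu module): the signature
# strings built in _auto_vectorize() follow numpy.vectorize() conventions, so
# stacking dimensions broadcast automatically. For a (2,) -> (3,) core map:
#
#     import jax.numpy as jnp
#     f = lambda arr: jnp.r_[arr[0] + arr[1], arr[0] - arr[1], arr[0] * arr[1]]
#     vf = jnp.vectorize(f, signature="(m0)->(n0)")
#     vf(jnp.ones((5, 4, 2))).shape  # (5, 4, 3)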