pyxu.operator.interop#
General#
- from_source(cls, dim_shape, codim_shape, embed=None, vectorize=frozenset({}), **kwargs)[source]#
Define an Operator from low-level constructs.
- Parameters:
  cls (OpC) – Operator sub-class to instantiate.
  dim_shape (NDArrayShape) – Operator domain shape (M1,…,MD).
  codim_shape (NDArrayShape) – Operator co-domain shape (N1,…,NK).
  embed (dict) – (k[str], v[value]) pairs to embed into the created operator. embed is useful to attach extra information to the synthesized Operator for use by arithmetic methods.
  kwargs (dict) – (k[str], v[callable]) pairs to use as arithmetic methods. Keys must be entries from arithmetic_methods(). Omitted arithmetic attributes/methods default to those provided by cls.
  vectorize (VarName) – Arithmetic methods to vectorize. vectorize is useful if an arithmetic method provided to kwargs (ex: apply()) does not support stacking dimensions.
- Returns:
  op – Pyxu-compliant operator \(A: \mathbb{R}^{M_{1} \times\cdots\times M_{D}} \to \mathbb{R}^{N_{1} \times\cdots\times N_{K}}\).
- Return type:
  OpT
Notes
If provided, arithmetic methods must conform exactly to the Pyxu interface. In particular, the arithmetic methods below, if supplied, must have the following signatures:
  def apply(self, arr: pxt.NDArray) -> pxt.NDArray                  # (..., M1,...,MD) -> (..., N1,...,NK)
  def grad(self, arr: pxt.NDArray) -> pxt.NDArray                   # (..., M1,...,MD) -> (..., M1,...,MD)
  def adjoint(self, arr: pxt.NDArray) -> pxt.NDArray                # (..., N1,...,NK) -> (..., M1,...,MD)
  def prox(self, arr: pxt.NDArray, tau: pxt.Real) -> pxt.NDArray    # (..., M1,...,MD) -> (..., M1,...,MD)
  def pinv(self, arr: pxt.NDArray, damp: pxt.Real) -> pxt.NDArray   # (..., N1,...,NK) -> (..., M1,...,MD)
Moreover, the methods above must accept stacking dimensions in arr. If this does not hold, consider populating vectorize.
Auto-vectorization consists in decorating kwargs-specified arithmetic methods with vectorize(). Auto-vectorization may be less efficient than explicitly providing a vectorized implementation.
Examples
Creation of the custom element-wise differential operator \(f(\mathbf{x}) = \mathbf{x}^{2}\).
N = 5
f = from_source(
    cls=pyxu.abc.DiffMap,
    dim_shape=N,
    codim_shape=N,
    apply=lambda _, arr: arr**2,
)
x = np.arange(N)
y = f(x)                 # [0, 1, 4, 9, 16]
dL = f.diff_lipschitz    # inf (default value provided by the DiffMap class.)
In practice we know that \(f\) has a finite-valued diff-Lipschitz constant. It is thus recommended to set it too when instantiating via from_source:
N = 5
f = from_source(
    cls=pyxu.abc.DiffMap,
    dim_shape=N,
    codim_shape=N,
    embed=dict(
        # special form to set (diff-)Lipschitz attributes via from_source()
        _diff_lipschitz=2,
    ),
    apply=lambda _, arr: arr**2,
)
x = np.arange(N)
y = f(x)                 # [0, 1, 4, 9, 16]
dL = f.diff_lipschitz    # 2 <- instead of inf
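The examples above use an apply() that already broadcasts over stacking dimensions. The snippet below is a minimal sketch (an illustration, not taken from the official examples) of the vectorize mechanism described in the Notes: np.cumsum() without an axis argument only makes sense for a single (N,) input, so apply() is auto-vectorized here.
import numpy as np
import pyxu.abc
from pyxu.operator.interop import from_source

N = 5
f = from_source(
    cls=pyxu.abc.Map,
    dim_shape=N,
    codim_shape=N,
    vectorize="apply",                      # wrap apply() with vectorize()
    apply=lambda _, arr: np.cumsum(arr),    # valid for a single (N,) input only
)
x = np.arange(2 * N).reshape(2, N)          # one stacking dimension
y = f.apply(x)                              # (2, N): cumsum applied per core block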
SciPy#
- from_sciop(cls, sp_op)[source]#
Wrap a LinearOperator as a 2D LinOp (or sub-class thereof).
- Parameters:
  cls (OpC) – Operator sub-class to instantiate.
  sp_op (LinearOperator) – (N, M) linear CPU/GPU operator compliant with SciPy's interface.
- Returns:
  op – Pyxu-compliant linear operator with:
  dim_shape: (M,)
  codim_shape: (N,)
- Return type:
  OpT
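A hedged usage sketch follows (the sparse matrix, the aslinearoperator() call, and the choice of cls=LinOp are illustrative assumptions, not code from the official documentation):
import numpy as np
import scipy.sparse as sp
import scipy.sparse.linalg as spla
import pyxu.abc as pxa
from pyxu.operator.interop import from_sciop

A = sp.random(4, 3, density=0.5, random_state=0)   # (N=4, M=3) sparse matrix
sp_op = spla.aslinearoperator(A)                   # SciPy LinearOperator

op = from_sciop(cls=pxa.LinOp, sp_op=sp_op)        # dim_shape=(3,), codim_shape=(4,)
x = np.ones(3)
y = op.apply(x)                                    # same as A @ x
z = op.adjoint(np.ones(4))                         # same as A.T @ np.ones(4)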
JAX#
- from_jax(cls, dim_shape, codim_shape, vectorize=frozenset({}), jit=False, enable_warnings=True, **kwargs)[source]#
Define an Operator from JAX functions.
- Parameters:
  cls (OpC) – Operator sub-class to instantiate.
  dim_shape (NDArrayShape) – Operator domain shape (M1,…,MD).
  codim_shape (NDArrayShape) – Operator co-domain shape (N1,…,NK).
  kwargs (dict) – (k[str], v[callable]) pairs to use as arithmetic methods. Keys are restricted to the following arithmetic methods: apply(), grad(), prox(), pinv(), adjoint(). Omitted arithmetic methods default to those provided by cls, or are auto-inferred via auto-diff rules.
  vectorize (VarName) – Arithmetic methods to vectorize. vectorize is useful if an arithmetic method provided to kwargs does not support stacking dimensions.
  jit (bool) – If True, JIT-compile JAX-backed arithmetic methods for better performance.
  enable_warnings (bool) – If True, emit warnings in case of precision/zero-copy issues.
- Returns:
  op – Pyxu-compliant operator \(A: \mathbb{R}^{M_{1} \times\cdots\times M_{D}} \to \mathbb{R}^{N_{1} \times\cdots\times N_{K}}\).
- Return type:
  OpT
Notes
If provided, arithmetic methods must conform exactly to the interface below:
  def apply(arr: jax.Array) -> jax.Array                  # (..., M1,...,MD) -> (..., N1,...,NK)
  def grad(arr: jax.Array) -> jax.Array                   # (..., M1,...,MD) -> (..., M1,...,MD)
  def adjoint(arr: jax.Array) -> jax.Array                # (..., N1,...,NK) -> (..., M1,...,MD)
  def prox(arr: jax.Array, tau: pxt.Real) -> jax.Array    # (..., M1,...,MD) -> (..., M1,...,MD)
  def pinv(arr: jax.Array, damp: pxt.Real) -> jax.Array   # (..., N1,...,NK) -> (..., M1,...,MD)
Moreover, the methods above must accept stacking dimensions in arr. If this does not hold, consider populating vectorize.
Auto-vectorization consists in decorating kwargs-specified arithmetic methods with jax.numpy.vectorize().
Note that JAX enforces 32-bit arithmetic by default, and this constraint cannot be changed at runtime. As such, to allow zero-copy transfers between JAX and NumPy/CuPy arrays, it is advised to perform computations in single-precision mode.
Inferred arithmetic methods are not JIT-ed by default since the operation is error-prone depending on how apply() is defined. If the apply() supplied to from_jax() is JIT-friendly, consider enabling jit.
Examples
Create the custom differential map \(f: \mathbb{R}^{2} \to \mathbb{R}^{3}\):
\[f(x, y) = \left[ \sin(x) + \cos(y), \cos(x) - \sin(y), \sin(x) + \cos(x) \right]\]
import pyxu.abc as pxa
import pyxu.runtime as pxrt
import pyxu.operator.interop as pxi
import jax, jax.numpy as jnp
import numpy as np

@jax.jit
def j_apply(arr: jax.Array) -> jax.Array:
    x, y = arr[0], arr[1]
    o1 = jnp.sin(x) + jnp.cos(y)
    o2 = jnp.cos(x) - jnp.sin(y)
    o3 = jnp.sin(x) + jnp.cos(x)
    out = jnp.r_[o1, o2, o3]
    return out

op = pxi.from_jax(
    cls=pxa.DiffMap,
    dim_shape=2,
    codim_shape=3,
    vectorize="apply",  # j_apply() does not work on stacked inputs
                        # --> let JAX figure it out automatically.
    apply=j_apply,
)

rng = np.random.default_rng(0)
x = rng.normal(size=(5, 3, 4, 2))
y1 = op.apply(x)          # (5, 3, 4, 3)

x = rng.normal(size=(2,))
opJ = op.jacobian(x)      # JAX auto-infers the Jacobian for you.

v = rng.normal(size=(5, 2))
w = rng.normal(size=(4, 3))
y2 = opJ.apply(v)         # (5, 3)
y3 = opJ.adjoint(w)       # (4, 2)
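As a complementary sketch of the grad() auto-inference mentioned in the kwargs description (an assumption-based illustration, not an example from the official documentation), a differentiable functional can be created from its apply() alone:
import numpy as np
import jax.numpy as jnp
import pyxu.abc as pxa
import pyxu.operator.interop as pxi

# f(x) = 0.5 * ||x||^2, written as a JAX function acting on a single (3,) input.
op = pxi.from_jax(
    cls=pxa.DiffFunc,
    dim_shape=3,
    codim_shape=1,
    vectorize="apply",
    apply=lambda arr: 0.5 * jnp.sum(arr**2, keepdims=True),
)
x = np.array([1.0, 2.0, 3.0], dtype=np.float32)
y = op.apply(x)   # [7.]
g = op.grad(x)    # grad() auto-inferred via JAX auto-diff: [1., 2., 3.]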
- _from_jax(x, xp=None)[source]#
JAX -> NumPy/CuPy conversion.
The transform is always zero-copy, but it is not easy to check this condition for all array types (contiguous, views, etc.) and backends (NUMPY, CUPY).
[More info] google/jax#1961
- Parameters:
x (JaxArray)
xp (ArrayModule)
- Return type:
  NDArray
- _to_jax(x, enable_warnings=True)[source]#
NumPy/CuPy -> JAX conversion.
Conversion is zero-copy when possible, i.e. 16-byte alignment, on the right device, etc.
[More info] google/jax#4486
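A minimal sketch of the round trip described above; the import path of these private helpers is an assumption, and they are shown here only to illustrate the conversion direction:
import numpy as np
from pyxu.operator.interop.jax import _to_jax, _from_jax   # assumed module path

x = np.arange(6, dtype=np.float32)   # single precision eases zero-copy (see the Notes above)
jx = _to_jax(x)                      # NumPy -> jax.Array
y = _from_jax(jx)                    # jax.Array -> NumPy
assert np.allclose(x, y)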
PyTorch#
- from_torch(cls, dim_shape, codim_shape, vectorize=frozenset({}), jit=False, enable_warnings=True, **kwargs)[source]#
Define an Operator from PyTorch functions.
- Parameters:
  cls (OpC) – Operator sub-class to instantiate.
  dim_shape (NDArrayShape) – Operator domain shape (M1,…,MD).
  codim_shape (NDArrayShape) – Operator co-domain shape (N1,…,NK).
  kwargs (dict) – (k[str], v[callable]) pairs to use as arithmetic methods. Keys are restricted to the following arithmetic methods: apply(), grad(), prox(), pinv(), adjoint(). Omitted arithmetic methods default to those provided by cls, or are auto-inferred via auto-diff rules.
  vectorize (VarName) – Arithmetic methods to vectorize. vectorize is useful if an arithmetic method provided to kwargs does not support stacking dimensions.
  jit (bool) – Currently has no effect (for future-compatibility only). In the future, if True, Torch-backed arithmetic methods will be JIT-compiled for better performance.
  enable_warnings (bool) – If True, emit warnings in case of precision/zero-copy issues.
- Returns:
  op – Pyxu-compliant operator \(A: \mathbb{R}^{M_{1} \times\cdots\times M_{D}} \to \mathbb{R}^{N_{1} \times\cdots\times N_{K}}\).
- Return type:
  OpT
Notes
If provided, arithmetic methods must conform exactly to the interface below:
  def apply(arr: torch.Tensor) -> torch.Tensor                    # (..., M1,...,MD) -> (..., N1,...,NK)
  def grad(arr: torch.Tensor) -> torch.Tensor                     # (..., M1,...,MD) -> (..., M1,...,MD)
  def adjoint(arr: torch.Tensor) -> torch.Tensor                  # (..., N1,...,NK) -> (..., M1,...,MD)
  def prox(arr: torch.Tensor, tau: pxt.Real) -> torch.Tensor      # (..., M1,...,MD) -> (..., M1,...,MD)
  def pinv(arr: torch.Tensor, damp: pxt.Real) -> torch.Tensor     # (..., N1,...,NK) -> (..., M1,...,MD)
Moreover, the methods above must accept stacking dimensions in arr. If this does not hold, consider populating vectorize.
Auto-vectorization consists in decorating kwargs-specified arithmetic methods with torch.vmap(). See the PyTorch documentation for known limitations.
Arithmetic methods are not currently JIT-ed even if jit is set to True. This is because of the undocumented and currently poor interaction between torch.func transforms and torch.compile(). See this issue for additional details.
For DiffMap (or subclasses thereof), the methods jacobian(), grad() and adjoint() are defined implicitly, if not provided, using the auto-differentiation transforms from torch.func. As detailed on this page, such transforms work well on pure functions (that is, functions where the output is completely determined by the input and that do not involve side effects like mutation), but may fail on more complex functions. Moreover, torch.func does not yet have full coverage over PyTorch operations. For functions that call a torch.nn.Module, see here for some utilities.
Warning: Operators created with this wrapper do not support Dask inputs for now.
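As a hedged sketch of from_torch() usage (assumptions: the from_source() squaring example translated to a PyTorch apply(); the auto-inference behaviour is as described in the Notes, not verified output):
import numpy as np
import torch
import pyxu.abc as pxa
import pyxu.operator.interop as pxi

def t_apply(arr: torch.Tensor) -> torch.Tensor:
    return arr**2          # broadcasts over stacking dimensions, so vectorize is not needed

N = 5
f = pxi.from_torch(
    cls=pxa.DiffMap,
    dim_shape=N,
    codim_shape=N,
    apply=t_apply,
)
x = np.arange(N, dtype=np.float32)
y = f.apply(x)             # [0, 1, 4, 9, 16]
J = f.jacobian(x)          # jacobian()/adjoint() inferred via torch.func if not provided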
- _from_torch(tensor)[source]#
PyTorch -> NumPy/CuPy conversion.
Convert a PyTorch tensor into a NumPy-like array, sharing data, dtype and device.
- Parameters:
  tensor (torch.Tensor) – Input tensor.
- Returns:
  arr – Output array.
- Return type:
  NDArray
Notes
The returned array and input tensor share the same memory. Modifications to the tensor will be reflected in the ndarray and vice versa.
- _to_torch(arr, requires_grad=False)[source]#
NumPy/CuPy -> PyTorch conversion.
Convert a NumPy-like array into a PyTorch tensor, sharing data, dtype and device.
- Parameters:
  arr (NDArray) – Input array.
  requires_grad (bool) – If True, track operations on the returned tensor for autograd.
- Returns:
  tensor – Output tensor.
- Return type:
  torch.Tensor
Notes
The returned tensor and input array share the same memory. Modifications to the tensor will be reflected in the ndarray and vice versa. The returned tensor is not resizable.
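A minimal sketch of the memory-sharing behaviour described in these notes; the import path of the private helpers is an assumption, shown only for illustration:
import numpy as np
from pyxu.operator.interop.torch import _to_torch, _from_torch   # assumed module path

a = np.arange(4, dtype=np.float32)
t = _to_torch(a)          # shares memory with `a`
t[0] = -1.0
print(a[0])               # -1.0 : the modification is visible on the NumPy side

b = _from_torch(t)        # shares memory with `t` (and hence with `a`)
b[1] = 7.0
print(float(t[1]))        # 7.0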