pyxu.operator.interop

General

from_source(cls, dim_shape, codim_shape, embed=None, vectorize=frozenset({}), **kwargs)

Define an Operator from low-level constructs.

Parameters:
  • cls (OpC) – Operator sub-class to instantiate.

  • dim_shape (NDArrayShape) – Operator domain shape (M1,…,MD).

  • codim_shape (NDArrayShape) – Operator co-domain shape (N1,…,NK).

  • embed (dict) –

    (k[str], v[value]) pairs to embed into the created operator.

    embed is useful to attach extra information to the synthesized Operator for use by its arithmetic methods.

  • kwargs (dict) –

    (k[str], v[callable]) pairs to use as arithmetic methods.

    Keys must be entries from arithmetic_methods().

    Omitted arithmetic attributes/methods default to those provided by cls.

  • vectorize (VarName) –

    Arithmetic methods to vectorize.

    vectorize is useful if an arithmetic method provided to kwargs (e.g., apply()) does not support stacking dimensions.

Returns:

op – Pyxu-compliant operator \(A: \mathbb{R}^{M_{1} \times\cdots\times M_{D}} \to \mathbb{R}^{N_{1} \times\cdots\times N_{K}}\).

Return type:

OpT

Notes

  • If provided, arithmetic methods must abide exactly by the Pyxu interface. In particular, the following arithmetic methods, if supplied, must have these signatures:

    def apply(self, arr: pxt.NDArray) -> pxt.NDArray                   # (..., M1,...,MD) -> (..., N1,...,NK)
    def grad(self, arr: pxt.NDArray) -> pxt.NDArray                    # (..., M1,...,MD) -> (..., M1,...,MD)
    def adjoint(self, arr: pxt.NDArray) -> pxt.NDArray                 # (..., N1,...,NK) -> (..., M1,...,MD)
    def prox(self, arr: pxt.NDArray, tau: pxt.Real) -> pxt.NDArray     # (..., M1,...,MD) -> (..., M1,...,MD)
    def pinv(self, arr: pxt.NDArray, damp: pxt.Real) -> pxt.NDArray    # (..., N1,...,NK) -> (..., M1,...,MD)
    

    Moreover, the methods above must accept stacking dimensions in arr. If this does not hold, consider populating vectorize.

  • Auto-vectorization consists of decorating kwargs-specified arithmetic methods with vectorize(). Auto-vectorization may be less efficient than explicitly providing a vectorized implementation.
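To make the efficiency remark concrete, here is a NumPy-only sketch. The helper `loop_vectorize` is hypothetical and only mimics the idea of auto-vectorization (it is not Pyxu's vectorize()): it flattens the stacking dimensions of an input of shape (..., M1,...,MD) and calls a single-sample function once per sample, whereas a natively vectorized function handles all samples in a single call.

```python
import numpy as np


def loop_vectorize(func, dim_shape, codim_shape):
    """Hypothetical helper: lift a single-sample function to stacked inputs.

    `func` maps an array of shape `dim_shape` to one of shape `codim_shape`.
    The returned function accepts inputs of shape (..., *dim_shape).
    """
    D = len(dim_shape)

    def wrapped(arr):
        sh = arr.shape[: arr.ndim - D]              # stacking dimensions
        flat = arr.reshape(-1, *dim_shape)          # (S, M1,...,MD)
        out = np.stack([func(a) for a in flat])     # one Python call per sample
        return out.reshape(*sh, *codim_shape)       # (..., N1,...,NK)

    return wrapped


def cumsum_single(x):
    # Only valid for a single 1D sample of shape (M,).
    assert x.ndim == 1
    return np.cumsum(x)


def cumsum_batched(x):
    # Natively vectorized: works along the last axis, any stacking dims allowed.
    return np.cumsum(x, axis=-1)


M = 4
apply_loop = loop_vectorize(cumsum_single, dim_shape=(M,), codim_shape=(M,))

x = np.arange(2 * 3 * M, dtype=float).reshape(2, 3, M)
y_loop = apply_loop(x)         # (2, 3, 4), computed via 6 separate calls
y_native = cumsum_batched(x)   # (2, 3, 4), computed via a single call
```

Both paths give the same result; the loop-based one simply pays a Python-level call per sample, which is the overhead the note above refers to.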

Examples

Creation of the custom element-wise differentiable map \(f(\mathbf{x}) = \mathbf{x}^{2}\).

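import numpy as np
import pyxu.abc
from pyxu.operator.interop import from_source

N = 5
f = from_source(
    cls=pyxu.abc.DiffMap,
    dim_shape=N,
    codim_shape=N,
    apply=lambda _, arr: arr**2,  # element-wise square; `_` is the unused `self`
)
x = np.arange(N)
y = f(x)  # [0, 1, 4, 9, 16]
dL = f.diff_lipschitz  # inf (default value provided by the DiffMap class)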

In practice we know that \(f\) has a finite-valued diff-Lipschitz constant. It is thus recommended to set it too when instantiating via from_source:

import numpy as np
import pyxu.abc
from pyxu.operator.interop import from_source

N = 5
f = from_source(
    cls=pyxu.abc.DiffMap,
    dim_shape=N,
    codim_shape=N,
    embed=dict(
        # special form to set (diff-)Lipschitz attributes via from_source()
        _diff_lipschitz=2,
    ),
    apply=lambda _, arr: arr**2,
)
x = np.arange(N)
y = f(x)  # [0, 1, 4, 9, 16]
dL = f.diff_lipschitz  # 2  <- instead of inf

SciPy

from_sciop(cls, sp_op)

Wrap a LinearOperator as a 2D LinOp (or sub-class thereof).

Parameters:
  • sp_op (LinearOperator) – (N, M) Linear CPU/GPU operator compliant with SciPy’s interface.

  • cls (Type[OpT]) – Operator sub-class to instantiate.

Returns:

op – Pyxu-compliant linear operator with:

  • dim_shape: (M,)

  • codim_shape: (N,)

Return type:

OpT
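For reference, a SciPy-compliant input can be built with `scipy.sparse.linalg.LinearOperator` by supplying `matvec` and `rmatvec` (the adjoint). The sketch below only constructs and checks the SciPy side; the final from_sciop() call is left as a comment and assumes Pyxu is installed. The operator's `shape=(N, M)` is what corresponds to dim_shape (M,) and codim_shape (N,) above.

```python
import numpy as np
from scipy.sparse.linalg import LinearOperator

N, M = 3, 2
A = np.arange(N * M, dtype=float).reshape(N, M)  # dense stand-in for any (N, M) operator

sp_op = LinearOperator(
    shape=(N, M),
    matvec=lambda x: A @ x,     # forward: (M,) -> (N,)
    rmatvec=lambda y: A.T @ y,  # adjoint: (N,) -> (M,)
    dtype=A.dtype,
)

x = np.ones(M)
y = sp_op.matvec(x)   # shape (N,)
z = sp_op.rmatvec(y)  # shape (M,)

# Wrapping in Pyxu (requires Pyxu; not executed here):
#   import pyxu.abc as pxa
#   from pyxu.operator.interop import from_sciop
#   op = from_sciop(cls=pxa.LinOp, sp_op=sp_op)  # dim_shape=(M,), codim_shape=(N,)
```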

JAX

from_jax(cls, dim_shape, codim_shape, vectorize=frozenset({}), jit=False, enable_warnings=True, **kwargs)

Define an Operator from JAX functions.

Parameters:
  • cls (OpC) – Operator sub-class to instantiate.

  • dim_shape (NDArrayShape) – Operator domain shape (M1,…,MD).

  • codim_shape (NDArrayShape) – Operator co-domain shape (N1,…,NK).

  • kwargs (dict) –

    (k[str], v[callable]) pairs to use as arithmetic methods.

    Keys are restricted to the following arithmetic methods:

    apply(), grad(), prox(), pinv(), adjoint()
    

    Omitted arithmetic methods default to those provided by cls, or are auto-inferred via auto-diff rules.

  • vectorize (VarName) –

    Arithmetic methods to vectorize.

    vectorize is useful if an arithmetic method provided to kwargs does not support stacking dimensions.

  • jit (bool) – If True, JIT-compile JAX-backed arithmetic methods for better performance.

  • enable_warnings (bool) – If True, emit warnings in case of precision/zero-copy issues.

Returns:

op – Pyxu-compliant operator \(A: \mathbb{R}^{M_{1} \times\cdots\times M_{D}} \to \mathbb{R}^{N_{1} \times\cdots\times N_{K}}\).

Return type:

OpT

Notes

  • If provided, arithmetic methods must abide exactly by the interface below:

    def apply(arr: jax.Array) -> jax.Array                  # (..., M1,...,MD) -> (..., N1,...,NK)
    def grad(arr: jax.Array) -> jax.Array                   # (..., M1,...,MD) -> (..., M1,...,MD)
    def adjoint(arr: jax.Array) -> jax.Array                # (..., N1,...,NK) -> (..., M1,...,MD)
    def prox(arr: jax.Array, tau: pxt.Real) -> jax.Array    # (..., M1,...,MD) -> (..., M1,...,MD)
    def pinv(arr: jax.Array, damp: pxt.Real) -> jax.Array   # (..., N1,...,NK) -> (..., M1,...,MD)
    

    Moreover, the methods above must accept stacking dimensions in arr. If this does not hold, consider populating vectorize.

  • Auto-vectorization consists of decorating kwargs-specified arithmetic methods with jax.numpy.vectorize().

  • Note that JAX enforces 32-bit arithmetic by default, and this constraint cannot be changed at runtime. As such, to allow zero-copy transfers between JAX and NumPy/CuPy arrays, it is advised to perform computations in single-precision mode.

  • Inferred arithmetic methods are not JIT-ed by default since the operation is error-prone depending on how apply() is defined. If apply() supplied to from_jax() is JIT-friendly, then consider enabling jit.
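The precision remark can be checked with NumPy alone. Assuming that a conversion can only share memory when no dtype change is needed, a float64 array must be copied into a new float32 buffer, whereas an array already in single precision can be passed through as-is. Here `np.asarray` and `np.shares_memory` stand in for the actual JAX transfer, which is not shown.

```python
import numpy as np

x64 = np.linspace(0.0, 1.0, 5)  # float64: NumPy's default precision
x32 = x64.astype(np.float32)    # single precision, JAX's default

# Requesting float32 from a float64 array forces a new buffer (a copy) ...
as32_from64 = np.asarray(x64, dtype=np.float32)
# ... while a float32 array is returned unchanged (no copy).
as32_from32 = np.asarray(x32, dtype=np.float32)

copied = not np.shares_memory(x64, as32_from64)
shared = np.shares_memory(x32, as32_from32)
```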

Examples

Create the custom differentiable map \(f: \mathbb{R}^{2} \to \mathbb{R}^{3}\):

\[f(x, y) = \left[ \sin(x) + \cos(y), \cos(x) - \sin(y), \sin(x) + \cos(x) \right]\]

import pyxu.abc as pxa
import pyxu.runtime as pxrt
import pyxu.operator.interop as pxi
import jax, jax.numpy as jnp
import numpy as np

@jax.jit
def j_apply(arr: jax.Array) -> jax.Array:
    x, y = arr[0], arr[1]
    o1 = jnp.sin(x) + jnp.cos(y)
    o2 = jnp.cos(x) - jnp.sin(y)
    o3 = jnp.sin(x) + jnp.cos(x)
    out = jnp.r_[o1, o2, o3]
    return out

op = pxi.from_jax(
    cls=pxa.DiffMap,
    dim_shape=2,
    codim_shape=3,
    vectorize="apply",  # j_apply() does not work on stacked inputs
                        # --> let JAX figure it out automatically.
    apply=j_apply,
)

rng = np.random.default_rng(0)
x = rng.normal(size=(5,3,4,2))
y1 = op.apply(x)  # (5,3,4,3)

x = rng.normal(size=(2,))
opJ = op.jacobian(x)  # JAX auto-infers the Jacobian for you.

v = rng.normal(size=(5,2))
w = rng.normal(size=(4,3))
y2 = opJ.apply(v)  # (5,3)
y3 = opJ.adjoint(w)  # (4,2)
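jax.numpy.vectorize() is JAX's implementation of numpy.vectorize(), so the effect of vectorize="apply" in the example above can be previewed in plain NumPy. The sketch below ports the same formula; the gufunc signature "(m)->(n)" declares that each call consumes one 1D sample along the last axis and returns one 1D output, so all leading axes are treated as stacking dimensions. How Pyxu builds the signature internally is not shown here.

```python
import numpy as np


def n_apply(arr: np.ndarray) -> np.ndarray:
    # Single-sample NumPy port of j_apply(): (2,) -> (3,)
    x, y = arr[0], arr[1]
    o1 = np.sin(x) + np.cos(y)
    o2 = np.cos(x) - np.sin(y)
    o3 = np.sin(x) + np.cos(x)
    return np.array([o1, o2, o3])


# Core dims: one input vector of length m -> one output vector of length n.
n_apply_vec = np.vectorize(n_apply, signature="(m)->(n)")

rng = np.random.default_rng(0)
x = rng.normal(size=(5, 3, 4, 2))
y = n_apply_vec(x)  # (5, 3, 4, 3)
```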
_from_jax(x, xp=None)

JAX -> NumPy/CuPy conversion.

The transform is always zero-copy, although this condition is hard to verify for all array types (contiguous, views, etc.) and backends (NUMPY, CUPY).

[More info] google/jax#1961

Parameters:
  • x (JaxArray) – Input JAX array.

  • xp – Optional array module; defaults to None.

Return type:

NDArray

_to_jax(x, enable_warnings=True)

NumPy/CuPy -> JAX conversion.

Conversion is zero-copy when possible, i.e., when the input is 16-byte aligned, resides on the right device, etc.

[More info] google/jax#4486

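Parameters:
  • x (NDArray) – Input NumPy/CuPy array.

  • enable_warnings (bool) – If True, emit warnings in case of precision/zero-copy issues.
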
Return type:

JaxArray
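As an illustration of the alignment condition mentioned above, the hypothetical helper `is_aligned` below inspects the address of an array's first element via `ndarray.ctypes.data`. It only models the 16-byte alignment part; device placement and any other conditions JAX applies are not covered. Slicing off one float64 element shifts the start address by 8 bytes, so a view can lose 16-byte alignment even when its base array has it (and vice versa).

```python
import numpy as np


def is_aligned(arr: np.ndarray, alignment: int = 16) -> bool:
    # Hypothetical helper: True if the first element's address is a multiple of `alignment`.
    return arr.ctypes.data % alignment == 0


base = np.zeros(9, dtype=np.float64)
view = base[1:]  # shares memory with `base`, but starts 8 bytes later

print(is_aligned(base), is_aligned(view))
```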

PyTorch

from_torch(cls, dim_shape, codim_shape, vectorize=frozenset({}), jit=False, enable_warnings=True, **kwargs)

Define an Operator from PyTorch functions.

Parameters:
  • cls (OpC) – Operator sub-class to instantiate.

  • dim_shape (NDArrayShape) – Operator domain shape (M1,…,MD).

  • codim_shape (NDArrayShape) – Operator co-domain shape (N1,…,NK).

  • kwargs (dict) –

    (k[str], v[callable]) pairs to use as arithmetic methods.

    Keys are restricted to the following arithmetic methods:

    apply(), grad(), prox(), pinv(), adjoint()
    

    Omitted arithmetic methods default to those provided by cls, or are auto-inferred via auto-diff rules.

  • vectorize (VarName) –

    Arithmetic methods to vectorize.

    vectorize is useful if an arithmetic method provided to kwargs does not support stacking dimensions.

  • jit (bool) – Currently has no effect (for future-compatibility only). In the future, if True, then Torch-backed arithmetic methods will be JIT-compiled for better performance.

  • enable_warnings (bool) – If True, emit warnings in case of precision/zero-copy issues.

Returns:

op – Pyxu-compliant operator \(A: \mathbb{R}^{M_{1} \times\cdots\times M_{D}} \to \mathbb{R}^{N_{1} \times\cdots\times N_{K}}\).

Return type:

OpT

Notes

  • If provided, arithmetic methods must abide exactly by the interface below:

    def apply(arr: torch.Tensor) -> torch.Tensor                  # (..., M1,...,MD) -> (..., N1,...,NK)
    def grad(arr: torch.Tensor) -> torch.Tensor                   # (..., M1,...,MD) -> (..., M1,...,MD)
    def adjoint(arr: torch.Tensor) -> torch.Tensor                # (..., N1,...,NK) -> (..., M1,...,MD)
    def prox(arr: torch.Tensor, tau: pxt.Real) -> torch.Tensor    # (..., M1,...,MD) -> (..., M1,...,MD)
    def pinv(arr: torch.Tensor, damp: pxt.Real) -> torch.Tensor   # (..., N1,...,NK) -> (..., M1,...,MD)
    

    Moreover, the methods above must accept stacking dimensions in arr. If this does not hold, consider populating vectorize.

  • Auto-vectorization consists of decorating kwargs-specified arithmetic methods with torch.vmap(). See the PyTorch documentation for known limitations.

  • Arithmetic methods are not currently JIT-ed, even if jit is set to True. This is due to the undocumented and currently poor interaction between torch.func transforms and torch.compile(). See this issue for additional details.

  • For DiffMap (or subclasses thereof), the methods jacobian(), grad() and adjoint() are defined implicitly if not provided, using the auto-differentiation transforms from torch.func. As detailed on this page, such transforms work well on pure functions (that is, functions where the output is completely determined by the input and that do not involve side effects like mutation), but may fail on more complex functions. Moreover, torch.func does not yet have full coverage over PyTorch operations. For functions that call a torch.nn.Module, see here for some utilities.

  • Warning: operators created with this wrapper do not support Dask inputs for now.
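Whether a function already accepts stacking dimensions, and hence whether populating vectorize is needed, can be checked empirically before wrapping it. The helper `supports_stacking` below is hypothetical and written in NumPy for portability: it compares one call on a stacked input against per-sample calls. The same idea would apply to a torch.Tensor function before passing it to from_torch().

```python
import numpy as np


def supports_stacking(func, dim_shape, n_samples=3, seed=0):
    """Hypothetical check: does `func` treat leading axes as stacking dims?

    Compares func(stacked) with [func(sample) for sample in stacked].
    """
    rng = np.random.default_rng(seed)
    x = rng.normal(size=(n_samples, *dim_shape))
    try:
        batched = func(x)
    except Exception:
        return False  # e.g. the function rejects extra leading axes
    per_sample = np.stack([func(s) for s in x])
    return batched.shape == per_sample.shape and np.allclose(batched, per_sample)


def norm_last_axis(x):
    # (..., M) -> (..., 1): reduces only the core axis
    return np.linalg.norm(x, axis=-1, keepdims=True)


def norm_flat(x):
    # Reduces over every axis, so stacking dimensions are mixed together.
    return np.atleast_1d(np.linalg.norm(x))


ok_stacked = supports_stacking(norm_last_axis, dim_shape=(4,))  # True
ok_flat = supports_stacking(norm_flat, dim_shape=(4,))          # False
```

A function failing this check is a candidate for vectorize, at the cost of the per-sample overhead discussed earlier.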

_from_torch(tensor)

PyTorch -> NumPy/CuPy conversion.

Convert a PyTorch tensor into a NumPy-like array, sharing data, dtype and device.

Parameters:

tensor (torch.Tensor) – Input tensor.

Returns:

arr – Output array.

Return type:

NDArray

Notes

The returned array and input tensor share the same memory. Modifications to the tensor will be reflected in the ndarray and vice versa.
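These shared-memory semantics are those of a zero-copy exchange. As a NumPy-only analogue (it does not show which mechanism Pyxu uses internally), `np.from_dlpack` imports an array through the DLPack protocol, which PyTorch, JAX and CuPy also implement; the result aliases the producer's buffer, so later writes to one are visible through the other. This requires NumPy 1.22 or newer.

```python
import numpy as np

src = np.arange(4, dtype=np.float32)
alias = np.from_dlpack(src)  # zero-copy import via the DLPack protocol

src[0] = 42.0    # write through the producer ...
seen = alias[0]  # ... and read it back through the alias
```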

_to_torch(arr, requires_grad=False)

NumPy/CuPy -> PyTorch conversion.

Convert a NumPy-like array into a PyTorch tensor, sharing data, dtype and device.

Parameters:
  • arr (NDArray) – Input array.

  • requires_grad (bool) – Whether autograd should record operations on the returned tensor.

Returns:

tensor – Output tensor.

Return type:

torch.Tensor

Notes

The returned tensor and input array share the same memory. Modifications to the tensor will be reflected in the ndarray and vice versa. The returned tensor is not resizable.