Source code for pyxu.operator.interop.sciop
  1import warnings
  2
  3import scipy.sparse.linalg as spsl
  4
  5import pyxu.abc as pxa
  6import pyxu.info.deps as pxd
  7import pyxu.info.ptype as pxt
  8import pyxu.info.warning as pxw
  9import pyxu.operator.interop.source as px_src
 10import pyxu.runtime as pxrt
 11import pyxu.util as pxu
 12
 13__all__ = [
 14    "from_sciop",
 15    "to_sciop",
 16]
 17
 18
[docs]
 19def from_sciop(cls: pxt.OpC, sp_op: spsl.LinearOperator) -> pxt.OpT:
 20    r"""
 21    Wrap a :py:class:`~scipy.sparse.linalg.LinearOperator` as a 2D :py:class:`~pyxu.abc.LinOp` (or sub-class thereof).
 22
 23    Parameters
 24    ----------
 25    sp_op: ~scipy.sparse.linalg.LinearOperator
 26        (N, M) Linear CPU/GPU operator compliant with SciPy's interface.
 27
 28    Returns
 29    -------
 30    op: OpT
 31        Pyxu-compliant linear operator with:
 32
 33        * dim_shape: (M,)
 34        * codim_shape: (N,)
 35    """
 36    assert cls.has(pxa.Property.LINEAR)
 37
 38    if sp_op.dtype not in [_.value for _ in pxrt.Width]:
 39        warnings.warn(
 40            "Computation may not be performed at the requested precision.",
 41            pxw.PrecisionWarning,
 42        )
 43
 44    # [r]matmat only accepts 2D inputs -> reshape apply|adjoint inputs as needed.
 45
 46    def op_apply(_, arr: pxt.NDArray) -> pxt.NDArray:
 47        sh = arr.shape[:-1]
 48        arr = arr.reshape(-1, _.dim_size)
 49        out = _._sp_op.matmat(arr.T).T
 50        out = out.reshape(*sh, _.codim_size)
 51        return out
 52
 53    def op_adjoint(_, arr: pxt.NDArray) -> pxt.NDArray:
 54        sh = arr.shape[:-1]
 55        arr = arr.reshape(-1, _.codim_size)
 56        out = _._sp_op.rmatmat(arr.T).T
 57        out = out.reshape(*sh, _.dim_size)
 58        return out
 59
 60    def op_asarray(_, **kwargs) -> pxt.NDArray:
 61        # Determine XP-module accepted by sci_op, then compute array-representation.
 62        for ndi in [
 63            pxd.NDArrayInfo.NUMPY,
 64            pxd.NDArrayInfo.CUPY,
 65        ]:
 66            try:
 67                cls = _.__class__
 68                _A = cls.asarray(_, xp=ndi.module(), dtype=_._sp_op.dtype)
 69                break
 70            except Exception:
 71                pass
 72
 73        # Cast to user specs.
 74        xp = kwargs.get("xp", pxd.NDArrayInfo.NUMPY.module())
 75        dtype = kwargs.get("dtype", pxrt.Width.DOUBLE.value)
 76        A = xp.array(pxu.to_NUMPY(_A), dtype=dtype)
 77        return A
 78
 79    def op_expr(_) -> tuple:
 80        return ("from_sciop", _._sp_op)
 81
 82    op = px_src.from_source(
 83        cls=cls,
 84        dim_shape=sp_op.shape[1],
 85        codim_shape=sp_op.shape[0],
 86        apply=op_apply,
 87        adjoint=op_adjoint,
 88        asarray=op_asarray,
 89        _expr=op_expr,
 90    )
 91    op._sp_op = sp_op
 92
 93    return op 
 94
 95
 96def to_sciop(
 97    op: pxt.OpT,
 98    dtype: pxt.DType = None,
 99    gpu: bool = False,
100) -> spsl.LinearOperator:
101    r"""
102    Cast a :py:class:`~pyxu.abc.LinOp` to a CPU/GPU :py:class:`~scipy.sparse.linalg.LinearOperator`, compatible with
103    the matrix-free linear algebra routines of :py:mod:`scipy.sparse.linalg`.
104
105    Parameters
106    ----------
107    dtype: DType
108        Working precision of the linear operator.
109    gpu: bool
110        Operate on CuPy inputs (True) vs. NumPy inputs (False).
111
112    Returns
113    -------
114    op: ~scipy.sparse.linalg.LinearOperator
115        Linear operator object compliant with SciPy's interface.
116    """
117    if not (op.dim_rank == op.codim_rank == 1):
118        msg = "SciPy LinOps are limited to 1D -> 1D maps."
119        raise ValueError(msg)
120
121    def matmat(arr):
122        return op.apply(arr.T).T
123
124    def rmatmat(arr):
125        return op.adjoint(arr.T).T
126
127    if dtype is None:
128        dtype = pxrt.Width.DOUBLE.value
129
130    if gpu:
131        assert pxd.CUPY_ENABLED
132        spx = pxu.import_module("cupyx.scipy.sparse.linalg")
133    else:
134        spx = spsl
135    return spx.LinearOperator(
136        shape=(op.codim_size, op.dim_size),
137        matvec=matmat,
138        rmatvec=rmatmat,
139        matmat=matmat,
140        rmatmat=rmatmat,
141        dtype=dtype,
142    )