Source code for pyxu.operator.interop.sciop

  1import warnings
  2
  3import scipy.sparse.linalg as spsl
  4
  5import pyxu.abc as pxa
  6import pyxu.info.deps as pxd
  7import pyxu.info.ptype as pxt
  8import pyxu.info.warning as pxw
  9import pyxu.operator.interop.source as px_src
 10import pyxu.runtime as pxrt
 11import pyxu.util as pxu
 12
 13__all__ = [
 14    "from_sciop",
 15    "to_sciop",
 16]
 17
 18
def from_sciop(cls: pxt.OpC, sp_op: spsl.LinearOperator) -> pxt.OpT:
    r"""
    Wrap a :py:class:`~scipy.sparse.linalg.LinearOperator` as a 2D :py:class:`~pyxu.abc.LinOp` (or sub-class thereof).

    Parameters
    ----------
    cls: OpC
        Pyxu operator class to instantiate. Must have the :py:attr:`~pyxu.abc.Property.LINEAR` property.
    sp_op: ~scipy.sparse.linalg.LinearOperator
        (N, M) Linear CPU/GPU operator compliant with SciPy's interface.

    Returns
    -------
    op: OpT
        Pyxu-compliant linear operator with:

        * dim_shape: (M,)
        * codim_shape: (N,)

    Notes
    -----
    * An :py:class:`AssertionError` is raised if `cls` is not linear.
    * A :py:class:`~pyxu.info.warning.PrecisionWarning` is emitted if `sp_op.dtype` is not one of the
      :py:class:`~pyxu.runtime.Width` precisions.
    * ``op.asarray()`` raises :py:class:`RuntimeError` if `sp_op` cannot be evaluated with any supported
      array backend (NumPy, CuPy).
    """
    assert cls.has(pxa.Property.LINEAR)

    if sp_op.dtype not in [_.value for _ in pxrt.Width]:
        warnings.warn(
            "Computation may not be performed at the requested precision.",
            pxw.PrecisionWarning,
        )

    # [r]matmat only accepts 2D inputs -> reshape apply|adjoint inputs as needed.
    # Pyxu stacks inputs along leading axes (..., dim); SciPy expects column stacks (dim, K).

    def op_apply(_, arr: pxt.NDArray) -> pxt.NDArray:
        sh = arr.shape[:-1]
        arr = arr.reshape(-1, _.dim_size)
        out = _._sp_op.matmat(arr.T).T
        out = out.reshape(*sh, _.codim_size)
        return out

    def op_adjoint(_, arr: pxt.NDArray) -> pxt.NDArray:
        sh = arr.shape[:-1]
        arr = arr.reshape(-1, _.codim_size)
        out = _._sp_op.rmatmat(arr.T).T
        out = out.reshape(*sh, _.dim_size)
        return out

    def op_asarray(_, **kwargs) -> pxt.NDArray:
        # `sp_op` may accept only NumPy or only CuPy inputs: probe each backend in turn and compute the
        # array representation with the class-level `asarray`, in `sp_op`'s native dtype.
        # NOTE(review): going through the class presumably bypasses this instance-level override
        # (avoiding infinite recursion) -- confirm against `px_src.from_source`.
        last_err = None
        for ndi in [
            pxd.NDArrayInfo.NUMPY,
            pxd.NDArrayInfo.CUPY,
        ]:
            try:
                _A = type(_).asarray(_, xp=ndi.module(), dtype=_._sp_op.dtype)
                break
            except Exception as e:  # backend not installed, or its arrays are rejected by `sp_op`
                last_err = e
        else:
            # No backend worked: fail loudly instead of leaving `_A` unbound below.
            msg = "Could not evaluate the wrapped SciPy operator with any supported array backend (NumPy, CuPy)."
            raise RuntimeError(msg) from last_err

        # Cast to user specs.
        xp = kwargs.get("xp", pxd.NDArrayInfo.NUMPY.module())
        dtype = kwargs.get("dtype", pxrt.Width.DOUBLE.value)
        A = xp.array(pxu.to_NUMPY(_A), dtype=dtype)
        return A

    def op_expr(_) -> tuple:
        return ("from_sciop", _._sp_op)

    op = px_src.from_source(
        cls=cls,
        dim_shape=sp_op.shape[1],
        codim_shape=sp_op.shape[0],
        apply=op_apply,
        adjoint=op_adjoint,
        asarray=op_asarray,
        _expr=op_expr,
    )
    op._sp_op = sp_op

    return op
def to_sciop(
    op: pxt.OpT,
    dtype: pxt.DType = None,
    gpu: bool = False,
) -> spsl.LinearOperator:
    r"""
    Cast a :py:class:`~pyxu.abc.LinOp` to a CPU/GPU :py:class:`~scipy.sparse.linalg.LinearOperator`, compatible with
    the matrix-free linear algebra routines of :py:mod:`scipy.sparse.linalg`.

    Parameters
    ----------
    op: OpT
        Linear operator to expose. Both its domain and co-domain must be 1D.
    dtype: DType
        Working precision of the linear operator. Defaults to double precision when omitted.
    gpu: bool
        Operate on CuPy inputs (True) vs. NumPy inputs (False). Requires CuPy when True.

    Returns
    -------
    op: ~scipy.sparse.linalg.LinearOperator
        Linear operator object of shape (op.codim_size, op.dim_size), compliant with SciPy's interface.

    Raises
    ------
    ValueError
        If `op` is not a 1D -> 1D map.
    """
    is_1d_to_1d = op.dim_rank == op.codim_rank == 1
    if not is_1d_to_1d:
        raise ValueError("SciPy LinOps are limited to 1D -> 1D maps.")

    if dtype is None:
        dtype = pxrt.Width.DOUBLE.value

    # SciPy provides either a single vector or column-stacked vectors (stack index last), whereas
    # Pyxu stacks along leading axes: transposing on the way in and out reconciles both conventions.
    def _forward(x):
        return op.apply(x.T).T

    def _backward(y):
        return op.adjoint(y.T).T

    if gpu:
        assert pxd.CUPY_ENABLED
        linalg = pxu.import_module("cupyx.scipy.sparse.linalg")
    else:
        linalg = spsl

    return linalg.LinearOperator(
        shape=(op.codim_size, op.dim_size),
        dtype=dtype,
        matvec=_forward,
        matmat=_forward,
        rmatvec=_backward,
        rmatmat=_backward,
    )