Source code for pyxu.operator.interop.sciop
1import warnings
2
3import scipy.sparse.linalg as spsl
4
5import pyxu.abc as pxa
6import pyxu.info.deps as pxd
7import pyxu.info.ptype as pxt
8import pyxu.info.warning as pxw
9import pyxu.operator.interop.source as px_src
10import pyxu.runtime as pxrt
11import pyxu.util as pxu
12
# Public API of this module: the two SciPy <-> Pyxu conversion helpers.
__all__ = ["from_sciop", "to_sciop"]
17
18
[docs]
def from_sciop(cls: pxt.OpC, sp_op: spsl.LinearOperator) -> pxt.OpT:
    r"""
    Wrap a :py:class:`~scipy.sparse.linalg.LinearOperator` as a 2D :py:class:`~pyxu.abc.LinOp` (or sub-class thereof).

    Parameters
    ----------
    cls: OpC
        Operator class to instantiate. It must advertise the ``LINEAR`` property (checked via ``cls.has()``).
    sp_op: ~scipy.sparse.linalg.LinearOperator
        (N, M) Linear CPU/GPU operator compliant with SciPy's interface.

    Returns
    -------
    op: OpT
        Pyxu-compliant linear operator with:

        * dim_shape: (M,)
        * codim_shape: (N,)

    Warns
    -----
    PrecisionWarning
        If ``sp_op.dtype`` is not one of the precisions enumerated by ``pyxu.runtime.Width``.
    """
    assert cls.has(pxa.Property.LINEAR)

    if sp_op.dtype not in [_.value for _ in pxrt.Width]:
        warnings.warn(
            "Computation may not be performed at the requested precision.",
            pxw.PrecisionWarning,
        )

    # [r]matmat only accepts 2D inputs -> reshape apply|adjoint inputs as needed.

    def op_apply(_, arr: pxt.NDArray) -> pxt.NDArray:
        # Flatten all leading (stacking) dimensions into one, evaluate, then restore them.
        sh = arr.shape[:-1]
        arr = arr.reshape(-1, _.dim_size)
        out = _._sp_op.matmat(arr.T).T  # SciPy operates column-wise -> transpose in/out.
        out = out.reshape(*sh, _.codim_size)
        return out

    def op_adjoint(_, arr: pxt.NDArray) -> pxt.NDArray:
        # Same reshaping strategy as op_apply(), with dim/codim roles swapped.
        sh = arr.shape[:-1]
        arr = arr.reshape(-1, _.codim_size)
        out = _._sp_op.rmatmat(arr.T).T
        out = out.reshape(*sh, _.dim_size)
        return out

    def op_asarray(_, **kwargs) -> pxt.NDArray:
        # Determine XP-module accepted by sci_op, then compute array-representation.
        # The wrapped operator may only understand one array backend, so each candidate is probed in turn.
        # NOTE(review): `klass.asarray` is presumably the class-level default implementation -- confirm.
        klass = _.__class__
        last_err = None
        for ndi in [
            pxd.NDArrayInfo.NUMPY,
            pxd.NDArrayInfo.CUPY,
        ]:
            try:
                _A = klass.asarray(_, xp=ndi.module(), dtype=_._sp_op.dtype)
                break
            except Exception as exc:
                # Best-effort probing: remember the failure and try the next backend.
                last_err = exc
        else:
            # No backend worked: fail explicitly instead of hitting an unbound `_A` below.
            msg = "Could not compute the array representation of the wrapped SciPy operator with any backend."
            raise RuntimeError(msg) from last_err

        # Cast to user specs.
        xp = kwargs.get("xp", pxd.NDArrayInfo.NUMPY.module())
        dtype = kwargs.get("dtype", pxrt.Width.DOUBLE.value)
        A = xp.array(pxu.to_NUMPY(_A), dtype=dtype)
        return A

    def op_expr(_) -> tuple:
        return ("from_sciop", _._sp_op)

    op = px_src.from_source(
        cls=cls,
        dim_shape=sp_op.shape[1],
        codim_shape=sp_op.shape[0],
        apply=op_apply,
        adjoint=op_adjoint,
        asarray=op_asarray,
        _expr=op_expr,
    )
    # Keep a handle on the SciPy operator: the closures above reach it through `_._sp_op`.
    op._sp_op = sp_op

    return op
94
95
def to_sciop(
    op: pxt.OpT,
    dtype: pxt.DType = None,
    gpu: bool = False,
) -> spsl.LinearOperator:
    r"""
    Expose a :py:class:`~pyxu.abc.LinOp` as a CPU/GPU :py:class:`~scipy.sparse.linalg.LinearOperator`, so that it can
    be fed to the matrix-free linear algebra routines of :py:mod:`scipy.sparse.linalg`.

    Parameters
    ----------
    op: OpT
        Linear operator to cast. Both its ``dim_rank`` and ``codim_rank`` must equal 1.
    dtype: DType
        Working precision of the linear operator. Defaults to double precision when omitted.
    gpu: bool
        Operate on CuPy inputs (True) vs. NumPy inputs (False).

    Returns
    -------
    op: ~scipy.sparse.linalg.LinearOperator
        Linear operator object of shape (codim_size, dim_size) compliant with SciPy's interface.

    Raises
    ------
    ValueError
        If `op` is not a 1D -> 1D map.
    AssertionError
        If ``gpu=True`` but CuPy support is not enabled.
    """
    if (op.dim_rank != 1) or (op.codim_rank != 1):
        raise ValueError("SciPy LinOps are limited to 1D -> 1D maps.")

    # Inputs are transposed before calling the Pyxu operator and the result is transposed back.
    # The same callables serve as both the vector and the matrix variants.
    def _forward(x):
        y = op.apply(x.T)
        return y.T

    def _backward(y):
        x = op.adjoint(y.T)
        return x.T

    precision = pxrt.Width.DOUBLE.value if dtype is None else dtype

    # Select the LinearOperator implementation matching the requested array backend.
    if not gpu:
        backend = spsl
    else:
        assert pxd.CUPY_ENABLED
        backend = pxu.import_module("cupyx.scipy.sparse.linalg")

    return backend.LinearOperator(
        shape=(op.codim_size, op.dim_size),
        matvec=_forward,
        rmatvec=_backward,
        matmat=_forward,
        rmatmat=_backward,
        dtype=precision,
    )