1import types
2import typing as typ
3import warnings
4
5import numpy as np
6
7import pyxu.abc as pxa
8import pyxu.info.deps as pxd
9import pyxu.info.ptype as pxt
10import pyxu.info.warning as pxw
11import pyxu.operator.interop.source as px_src
12import pyxu.runtime as pxrt
13import pyxu.util as pxu
14
# Public API of this module: the elementary linear operators defined below.
__all__ = ["IdentityOp", "NullOp", "NullFunc", "HomothetyOp", "DiagonalOp"]
22
23
[docs]
class IdentityOp(pxa.OrthProjOp):
    r"""
    Identity operator.

    Maps every input :math:`\mathbf{x}` onto itself: :math:`\mathbf{I}\mathbf{x} = \mathbf{x}`.
    """

    def __init__(self, dim_shape: pxt.NDArrayShape):
        # Domain and co-domain coincide.
        super().__init__(dim_shape=dim_shape, codim_shape=dim_shape)
        self.lipschitz = 1

    def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
        # No computation: hand back the input, flagged read-only via pxu.read_only().
        return pxu.read_only(arr)

    def adjoint(self, arr: pxt.NDArray) -> pxt.NDArray:
        # The identity is self-adjoint: same as apply().
        return pxu.read_only(arr)

    def svdvals(self, **kwargs) -> pxt.NDArray:
        # Every singular value of the identity is 1: reuse the UnitOp implementation.
        return pxa.UnitOp.svdvals(self, **kwargs)

    def asarray(self, **kwargs) -> pxt.NDArray:
        # Dense (*codim_shape, *dim_shape) representation in the requested array module / dtype.
        xp = kwargs.get("xp", pxd.NDArrayInfo.default().module())
        dtype = kwargs.get("dtype", pxrt.Width.DOUBLE.value)
        eye = xp.eye(N=self.dim_size, dtype=dtype)
        return eye.reshape(*self.codim_shape, *self.dim_shape)

    def pinv(self, arr: pxt.NDArray, damp: pxt.Real, **kwargs) -> pxt.NDArray:
        # Damped pseudo-inverse: (1 + damp) x = arr  =>  x = arr / (1 + damp).
        denom = 1 + damp
        x = arr.copy()
        x /= denom
        return x

    def dagger(self, damp: pxt.Real, **kwargs) -> pxt.OpT:
        # The damped pseudo-inverse of I is a homothety of factor 1 / (1 + damp).
        return HomothetyOp(cst=1 / (1 + damp), dim_shape=self.dim_shape)

    def trace(self, **kwargs) -> pxt.Real:
        # Sum of N ones.
        return self.dim_size
66
67
[docs]
class NullOp(pxa.LinOp):
    """
    Null operator.

    Maps every input onto the zero element of the co-domain, for both the forward and adjoint passes.
    """

    def __init__(
        self,
        dim_shape: pxt.NDArrayShape,
        codim_shape: pxt.NDArrayShape,
    ):
        super().__init__(dim_shape=dim_shape, codim_shape=codim_shape)
        self.lipschitz = 0

    def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
        N = pxd.NDArrayInfo
        ndi = N.from_obj(arr)
        xp = ndi.module()
        stack_shape = arr.shape[: -self.dim_rank]

        # DASK: keep the stacking-dimension chunks of `arr`; let DASK choose the core chunks.
        if ndi == N.DASK:
            extra = dict(chunks=arr.chunks[: -self.dim_rank] + ("auto",) * self.codim_rank)
        else:
            extra = dict()

        zero = xp.array(0, arr.dtype)
        return xp.broadcast_to(zero, (*stack_shape, *self.codim_shape), **extra)

    def adjoint(self, arr: pxt.NDArray) -> pxt.NDArray:
        N = pxd.NDArrayInfo
        ndi = N.from_obj(arr)
        xp = ndi.module()
        stack_shape = arr.shape[: -self.codim_rank]

        # DASK: keep the stacking-dimension chunks of `arr`; let DASK choose the core chunks.
        if ndi == N.DASK:
            extra = dict(chunks=arr.chunks[: -self.codim_rank] + ("auto",) * self.dim_rank)
        else:
            extra = dict()

        zero = xp.array(0, arr.dtype)
        return xp.broadcast_to(zero, (*stack_shape, *self.dim_shape), **extra)

    def svdvals(self, **kwargs) -> pxt.NDArray:
        # All singular values are 0.
        xp = pxd.NDArrayInfo.from_flag(kwargs.get("gpu", False)).module()
        return xp.zeros(
            kwargs["k"],
            dtype=kwargs.get("dtype", pxrt.Width.DOUBLE.value),
        )

    def gram(self) -> pxt.OpT:
        # A^T A is the null map on the domain.
        return NullOp(
            dim_shape=self.dim_shape,
            codim_shape=self.dim_shape,
        ).asop(pxa.SelfAdjointOp)

    def cogram(self) -> pxt.OpT:
        # A A^T is the null map on the co-domain.
        return NullOp(
            dim_shape=self.codim_shape,
            codim_shape=self.codim_shape,
        ).asop(pxa.SelfAdjointOp)

    def asarray(self, **kwargs) -> pxt.NDArray:
        # Dense (*codim_shape, *dim_shape) zero array, obtained by broadcasting a 0-valued scalar.
        xp = kwargs.get("xp", pxd.NDArrayInfo.default().module())
        dtype = kwargs.get("dtype", pxrt.Width.DOUBLE.value)
        zero = xp.array(0, dtype=dtype)
        return xp.broadcast_to(zero, (*self.codim_shape, *self.dim_shape))

    def trace(self, **kwargs) -> pxt.Real:
        return 0
148 return 0
149
150
[docs]
def NullFunc(dim_shape: pxt.NDArrayShape) -> pxt.OpT:
    """
    Null functional.

    Maps every input onto the scalar 0.

    Parameters
    ----------
    dim_shape: NDArrayShape
        (M1,...,MD) shape of the functional's domain.

    Returns
    -------
    op: OpT
        :py:class:`~pyxu.operator.NullOp` with a co-domain of shape 1, viewed as a :py:class:`~pyxu.abc.LinFunc`.
    """
    null_op = NullOp(dim_shape=dim_shape, codim_shape=1)
    func = null_op.asop(pxa.LinFunc)
    func._name = "NullFunc"
    return func
163
164
[docs]
def HomothetyOp(dim_shape: pxt.NDArrayShape, cst: pxt.Real) -> pxt.OpT:
    r"""
    Constant scaling operator.

    Builds :math:`\mathbf{x} \mapsto \alpha \mathbf{x}` for a real scalar :math:`\alpha`.

    Parameters
    ----------
    dim_shape: NDArrayShape
        (M1,...,MD) shape of operator's domain (identical to its co-domain).
    cst: Real
        Scaling factor.

    Returns
    -------
    op: OpT
        Scaling operator.
        If `cst` is close to 0 (per :py:func:`numpy.isclose`), a :py:class:`~pyxu.operator.NullOp` is returned.
        If `cst` is close to 1, an :py:class:`~pyxu.operator.IdentityOp` is returned.
        Otherwise a :py:class:`~pyxu.abc.PosDefOp` (`cst` > 0) or :py:class:`~pyxu.abc.SelfAdjointOp` (`cst` < 0).

    Note
    ----
    This operator is not defined in terms of :py:func:`~pyxu.operator.DiagonalOp` since it is array-backend-agnostic.
    """
    assert isinstance(cst, pxt.Real), f"cst: expected real, got {cst}."

    # Degenerate factors are delegated to dedicated operators.
    if np.isclose(cst, 0):
        return NullOp(dim_shape=dim_shape, codim_shape=dim_shape)
    if np.isclose(cst, 1):
        return IdentityOp(dim_shape=dim_shape)

    # Generic case: the functions below become methods of the created operator.
    # Each receives the operator as `self` and reads the scale factor from `self._cst`.

    def op_apply(self, arr: pxt.NDArray) -> pxt.NDArray:
        scaled = arr.copy()
        scaled *= self._cst
        return scaled

    def op_svdvals(self, **kwargs) -> pxt.NDArray:
        gpu = kwargs.get("gpu", False)
        xp = pxd.NDArrayInfo.from_flag(gpu).module()
        dtype = kwargs.get("dtype", pxrt.Width.DOUBLE.value)
        # Every singular value equals |cst|.
        return xp.full(shape=kwargs["k"], fill_value=abs(self._cst), dtype=dtype)

    def op_pinv(self, arr: pxt.NDArray, damp: pxt.Real, **kwargs) -> pxt.NDArray:
        # Solve (cst**2 + damp) x = cst * arr element-wise.
        factor = self._cst / (self._cst**2 + damp)
        scaled = arr.copy()
        scaled *= factor
        return scaled

    def op_dagger(self, damp: pxt.Real, **kwargs) -> pxt.OpT:
        return HomothetyOp(cst=self._cst / (self._cst**2 + damp), dim_shape=self.dim_shape)

    def op_gram(self):
        # Shared by gram() and cogram(): both equal cst**2 * I.
        return HomothetyOp(cst=self._cst**2, dim_shape=self.dim_shape)

    def op_estimate_lipschitz(self, **kwargs) -> pxt.Real:
        return abs(self._cst)

    def op_trace(self, **kwargs):
        return self._cst * self.dim_size

    klass = pxa.PosDefOp if (cst > 0) else pxa.SelfAdjointOp
    op = px_src.from_source(
        cls=klass,
        dim_shape=dim_shape,
        codim_shape=dim_shape,
        embed=dict(
            _name="HomothetyOp",
            _cst=cst,
            _lipschitz=float(abs(cst)),
        ),
        apply=op_apply,
        svdvals=op_svdvals,
        pinv=op_pinv,
        gram=op_gram,
        cogram=op_gram,
        trace=op_trace,
        estimate_lipschitz=op_estimate_lipschitz,
    )
    # dagger() is attached separately as a bound method on the instance.
    op.dagger = types.MethodType(op_dagger, op)
    return op
257
258
[docs]
def DiagonalOp(
    vec: pxt.NDArray,
    dim_shape: pxt.NDArrayShape = None,
    enable_warnings: bool = True,
) -> pxt.OpT:
    r"""
    Element-wise scaling operator.

    Note
    ----
    * :py:func:`~pyxu.operator.DiagonalOp` instances are **not arraymodule-agnostic**: they will only work with NDArrays
      belonging to the same array module as `vec`. Moreover, inner computations may cast input arrays when the
      precision of `vec` does not match the user-requested precision. If such a situation occurs, a warning is raised.
    * If `vec` is a DASK array, the operator will be a :py:class:`~pyxu.abc.SelfAdjointOp`. If `vec` is a NUMPY/CUPY
      array, the created operator specializes to :py:class:`~pyxu.abc.PosDefOp` when possible. Specialization is not
      automatic for DASK inputs because operators should be quick to build under all circumstances, and this is not
      guaranteed if we have to check that all entries are positive for out-of-core arrays. Users who know that all
      `vec` entries are positive can manually cast to :py:class:`~pyxu.abc.PosDefOp` afterwards if required.

    Parameters
    ----------
    vec: NDArray
        Scale factors. If `dim_shape` is provided, then `vec` must broadcast to `dim_shape`. Extra leading axes of
        `vec` are only accepted if they are singletons; they are dropped.
    dim_shape: NDArrayShape
        (M1,...,MD) shape of operator's domain.
        Defaults to the shape of `vec` when omitted.
    enable_warnings: bool
        If ``True``, emit a warning in case of precision mis-match issues.

    Returns
    -------
    op: OpT
        Self-adjoint (possibly positive-definite) operator :math:`\mathbf{x} \mapsto \mathbf{v} \odot \mathbf{x}`.
    """
    if dim_shape is None:
        dim_shape = vec.shape
    else:
        dim_shape = tuple(pxu.as_canonical_shape(dim_shape))
        sh = np.broadcast_shapes(vec.shape, dim_shape)  # raises if not broadcastable at all

        # Getting here means `vec` and `dim_shape` are broadcastable. `vec` must additionally broadcast *to*
        # `dim_shape`: the trailing axes of the broadcast shape must equal `dim_shape`, and any extra leading axes
        # (only `vec` can contribute them) must be singletons.
        n_extra = len(sh) - len(dim_shape)
        compatible = (tuple(sh[n_extra:]) == dim_shape) and all(s == 1 for s in sh[:n_extra])
        assert compatible, "vec and dim_shape are incompatible."
        if n_extra > 0:
            # Drop the leading singleton axes so that `vec` also broadcasts against un-stacked inputs.
            vec = vec.reshape(vec.shape[n_extra:])

    def damped_inverse(v, damp: pxt.Real):
        # Element-wise v / (v**2 + damp).
        # Entries where v == 0 and damp == 0 evaluate to 0/0 = NaN; they are mapped to 0.
        xp = pxu.get_array_module(v)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")  # silence 0/0 warnings
            scale = v / (v**2 + damp)
        return xp.where(xp.isnan(scale), xp.zeros_like(scale), scale)

    def op_apply(_, arr):
        if (_._vec.dtype != arr.dtype) and _._enable_warnings:
            msg = "Computation may not be performed at the requested precision."
            warnings.warn(msg, pxw.PrecisionWarning)
        out = arr.copy()
        out *= _._vec
        return out

    def op_asarray(_, **kwargs) -> pxt.NDArray:
        xp = pxu.get_array_module(_._vec)
        vec = xp.broadcast_to(_._vec, _.dim_shape)
        A = xp.diag(vec.reshape(-1)).reshape((*_.codim_shape, *_.dim_shape))

        xp = kwargs.get("xp", pxd.NDArrayInfo.default().module())
        dtype = kwargs.get("dtype", pxrt.Width.DOUBLE.value)
        B = xp.array(pxu.to_NUMPY(A), dtype=dtype)
        return B

    def op_gram(_):
        return DiagonalOp(
            vec=_._vec**2,
            dim_shape=_.dim_shape,
            enable_warnings=_._enable_warnings,
        )

    def op_svdvals(_, **kwargs):
        k = kwargs["k"]
        dtype = kwargs.get("dtype", pxrt.Width.DOUBLE.value)

        # Compute with the array module backing `vec`: topk() only exists in dask.array.
        N = pxd.NDArrayInfo
        ndi_vec = N.from_obj(_._vec)
        xp = ndi_vec.module()
        vec = xp.broadcast_to(xp.abs(_._vec), _.dim_shape).reshape(-1)
        if ndi_vec == N.DASK:
            # topk() yields the k largest entries in descending order -> flip to ascending.
            D = xp.topk(vec, k)[::-1]
        else:
            D = xp.sort(vec)[-k:]  # k largest entries, ascending
        return D.astype(dtype)

    def op_pinv(_, arr: pxt.NDArray, damp: pxt.Real, **kwargs) -> pxt.NDArray:
        scale = damped_inverse(_._vec, damp)
        out = arr.copy()
        out *= scale
        return out

    def op_dagger(_, damp: pxt.Real, **kwargs) -> pxt.OpT:
        return DiagonalOp(
            vec=damped_inverse(_._vec, damp),
            dim_shape=_.dim_shape,
            enable_warnings=_._enable_warnings,
        )

    def op_trace(_, **kwargs):
        xp = pxu.get_array_module(_._vec)
        vec = xp.broadcast_to(_._vec, _.dim_shape)
        return float(vec.sum())

    def op_estimate_lipschitz(_, **kwargs):
        # Operator norm of a diagonal operator = max |v_i|.
        xp = pxu.get_array_module(_._vec)
        _.lipschitz = float(xp.fabs(_._vec).max())
        return _.lipschitz

    ndi = pxd.NDArrayInfo.from_obj(vec)
    if ndi == pxd.NDArrayInfo.DASK:
        klass = pxa.SelfAdjointOp
    else:
        positive = (vec > 0).all()
        klass = pxa.PosDefOp if positive else pxa.SelfAdjointOp
    op = px_src.from_source(
        cls=klass,
        dim_shape=dim_shape,
        codim_shape=dim_shape,
        embed=dict(
            _name="DiagonalOp",
            _vec=vec,
            _enable_warnings=bool(enable_warnings),
        ),
        apply=op_apply,
        estimate_lipschitz=op_estimate_lipschitz,
        asarray=op_asarray,
        gram=op_gram,
        cogram=op_gram,
        svdvals=op_svdvals,
        pinv=op_pinv,
        trace=op_trace,
    )
    op.dagger = types.MethodType(op_dagger, op)
    return op
398
399
def _ExplicitLinOp(
    cls: pxt.OpC,
    mat: typ.Union[pxt.NDArray, pxt.SparseArray],
    dim_rank: pxt.Integer = None,
    enable_warnings: bool = True,
) -> pxt.OpT:
    r"""
    Build a linear operator from its matrix representation.

    Given an array :math:`\mathbf{A} \in \mathbb{R}^{N_{1} \times\cdots\times N_{K} \times M_{1} \times\cdots\times
    M_{D}}`, the *explicit linear operator* associated to :math:`\mathbf{A}` is defined as

    .. math::

       [\mathbf{A}\mathbf{x}]_{n_{1},\ldots,n_{K}}
       =
       \langle \mathbf{A}[n_{1},\ldots,n_{K},\ldots], \mathbf{x} \rangle
       \qquad
       \forall \mathbf{x} \in \mathbb{R}^{M_{1} \times\cdots\times M_{D}}.

    Parameters
    ----------
    cls: OpC
        LinOp sub-class to instantiate.
    mat: NDArray, SparseArray
        (N1,...,NK, M1,...,MD) matrix generator.
        The input array can be *dense* or *sparse*.
        Accepted 2D sparse arrays are:

        * CPU: COO/CSC/CSR/BSR
        * GPU: COO/CSC/CSR
    dim_rank: Integer
        Rank of operator's domain. (D)
        It can be omitted if `mat` is 2D since auto-inferred to 1.
    enable_warnings: bool
        If ``True``, emit a warning in case of precision mis-match issues.

    Returns
    -------
    op: OpT
        Instance of `cls` with domain (M1,...,MD) and co-domain (N1,...,NK).

    Notes
    -----
    * :py:class:`~pyxu.operator.linop.base._ExplicitLinOp` instances are **not arraymodule-agnostic**:
      they will only work with NDArrays belonging to the same (dense) array module as `mat`. Moreover, inner
      computations may cast input arrays when the precision of `mat` does not match the user-requested precision. If
      such a situation occurs, a warning is raised.

    * The matrix provided to :py:func:`~pyxu.operator.linop.base._ExplicitLinOp` is used as-is and can be accessed via
      ``.mat``.
    """

    def is_dense(A) -> bool:
        # Ensure `A` is a supported array format, then return
        #   `True`  if `A` is dense
        #   `False` if `A` is sparse
        # Each probe signals an unsupported format by raising, so the failures are deliberately swallowed.
        try:
            pxd.NDArrayInfo.from_obj(A)
            return True
        except Exception:
            pass

        try:
            pxd.SparseArrayInfo.from_obj(A)
            return False
        except Exception:
            pass

        raise ValueError("mat: format could not be inferred.")

    def tensordot(A, b, dim_rank, warn: bool):
        # Parameters
        # ----------
        # A: (N1,...,NK, M1,...,MD) dense or sparse (2D)
        # b: (S1,...,SL, M1,...,MD) dense
        # dim_rank: D
        # warn: bool
        #
        # Returns
        # -------
        # out: (S1,...,SL, N1,...,NK) dense
        if (A.dtype != b.dtype) and warn:
            msg = "Computation may not be performed at the requested precision."
            warnings.warn(msg, pxw.PrecisionWarning)

        dim_shape = A.shape[-dim_rank:]
        dim_size = np.prod(dim_shape)
        codim_shape = A.shape[:-dim_rank]
        codim_size = np.prod(codim_shape)

        sh = b.shape[:-dim_rank]
        if not is_dense(A):  # sparse matrix -> necessarily 2D
            b = b.reshape(-1, dim_size)
            out = A.dot(b.T).T  # (prod(sh), codim_size)
            out = out.reshape(*sh, codim_size)
        else:  # ND dense array
            N = pxd.NDArrayInfo  # short-hand
            ndi = N.from_obj(A)
            xp = ndi.module()

            if ndi != N.DASK:
                # NUMPY/CUPY.tensordot() works -> use it.
                out = xp.tensordot(  # (S1,...,SL, N1,...,NK)
                    b,  # (S1,...,SL, M1,...,MD)
                    A,  # (N1,...,NK, M1,...,MD)
                    axes=[
                        list(range(-dim_rank, 0)),
                        list(range(-dim_rank, 0)),
                    ],
                )
            else:  # DASK-backed `mat`
                # DASK.tensordot() broken -> use 2D-ops instead
                msg = "[2023.12] DASK's tensordot() is broken. -> fallback onto 2D-shaped ops."
                pxw.warn_dask_perf(msg)

                A_2D = A.reshape(codim_size, dim_size)
                b = b.reshape(-1, dim_size)
                out = A_2D.dot(b.T).T  # (prod(sh), codim_size)
                out = out.reshape(*sh, *codim_shape)
        return out

    def op_apply(_, arr: pxt.NDArray) -> pxt.NDArray:
        out = tensordot(
            A=_.mat,
            b=arr,
            dim_rank=_.dim_rank,
            warn=_._enable_warnings,
        )
        return out

    def op_adjoint(_, arr: pxt.NDArray) -> pxt.NDArray:
        if is_dense(_.mat):
            # Move the domain axes (M1,...,MD) in front of the co-domain axes (N1,...,NK).
            axes = (
                *tuple(range(-_.dim_rank, 0)),
                *tuple(range(_.codim_rank)),
            )
        else:
            axes = None  # transposes all axes for 2D sparse arrays
        out = tensordot(
            A=_.mat.transpose(axes),
            b=arr,
            dim_rank=_.codim_rank,
            warn=_._enable_warnings,
        )
        return out

    def op_estimate_lipschitz(_, **kwargs) -> pxt.Real:
        # Run the class' generic estimator, but force the array module / dtype matching `mat`.
        N = pxd.NDArrayInfo
        S = pxd.SparseArrayInfo

        if is_dense(_.mat):
            ndi = N.from_obj(_.mat)
        else:
            sdi = S.from_obj(_.mat)
            if sdi == S.SCIPY_SPARSE:
                ndi = N.NUMPY
            elif sdi == S.CUPY_SPARSE:
                ndi = N.CUPY
            else:
                raise NotImplementedError

        kwargs.update(
            xp=ndi.module(),
            gpu=ndi == N.CUPY,
            dtype=_.mat.dtype,
        )
        # NOTE(review): relies on from_source() binding this override on the instance, so the class attribute
        # below is the default estimator (not this function) -- confirm against px_src.from_source.
        klass = _.__class__
        return klass.estimate_lipschitz(_, **kwargs)

    def op_asarray(_, **kwargs) -> pxt.NDArray:
        N = pxd.NDArrayInfo
        xp = kwargs.get("xp", N.default().module())
        dtype = kwargs.get("dtype", pxrt.Width.DOUBLE.value)

        if is_dense(_.mat):
            A = _.mat.astype(dtype, copy=False)
        else:
            A = _.mat.astype(dtype).toarray()  # `copy field not ubiquitous`
        B = xp.array(pxu.to_NUMPY(A), dtype=dtype)
        return B

    def op_trace(_, **kwargs) -> pxt.Real:
        if _.dim_size != _.codim_size:
            raise NotImplementedError
        elif len(_.mat.shape) == 2:  # dense or sparse
            try:
                tr = _.mat.trace()
            except AttributeError:
                # Not all sparse types have a .trace() method ...
                tr = _.mat.diagonal().sum()
        elif is_dense(_.mat) and (pxd.NDArrayInfo.from_obj(_.mat) != pxd.NDArrayInfo.DASK):
            # NUMPY/CUPY: reshaping is cheap, so take the trace of the (codim_size, dim_size) matrix.
            # The row-major reshape places mat[unravel(i, codim_shape), unravel(i, dim_shape)] at entry (i, i).
            tr = _.mat.reshape(_.codim_size, _.dim_size).trace()
        else:
            # We don't want to reshape `mat` if DASK-backed for performance reasons, so the trace is built by
            # indexing the "diagonal" manually.
            tr = 0
            for idx in range(_.dim_size):
                dim_idx = np.unravel_index(idx, _.dim_shape)
                codim_idx = np.unravel_index(idx, _.codim_shape)
                tr += _.mat[(*codim_idx, *dim_idx)]
        return float(tr)

    is_dense(mat)  # We were given a dense/sparse array ...
    # ... but is dim_rank correctly specified?
    assert len(mat.shape) >= 2, "Only 2D+ arrays are supported."
    if len(mat.shape) == 2:
        dim_rank = 1  # doesn't matter what the user specified.
    else:  # rank > 2
        # if ND -> mandatory supplied & (1 <= dim_rank < mat.ndim)
        assert dim_rank is not None, "Dimension rank must be specified for ND operators."
        assert 1 <= dim_rank < len(mat.shape)

    op = px_src.from_source(
        cls=cls,
        dim_shape=mat.shape[-dim_rank:],
        codim_shape=mat.shape[:-dim_rank],
        embed=dict(
            _name="_ExplicitLinOp",
            mat=mat,
            _enable_warnings=bool(enable_warnings),
        ),
        apply=op_apply,
        adjoint=op_adjoint,
        asarray=op_asarray,
        trace=op_trace,
        estimate_lipschitz=op_estimate_lipschitz,
    )
    return op