import collections
import collections.abc as cabc
import copy
import enum
import inspect
import types
import typing as typ
import warnings

import numpy as np
import scipy.sparse.linalg as spsl

import pyxu.info.deps as pxd
import pyxu.info.ptype as pxt
import pyxu.info.warning as pxw
import pyxu.runtime as pxrt
import pyxu.util as pxu


class Property(enum.Enum):
21 """
22 Mathematical property.
23
24 See Also
25 --------
26 :py:class:`~pyxu.abc.Operator`
27 """
28
29 CAN_EVAL = enum.auto()
30 FUNCTIONAL = enum.auto()
31 PROXIMABLE = enum.auto()
32 DIFFERENTIABLE = enum.auto()
33 DIFFERENTIABLE_FUNCTION = enum.auto()
34 LINEAR = enum.auto()
35 LINEAR_SQUARE = enum.auto()
36 LINEAR_NORMAL = enum.auto()
37 LINEAR_IDEMPOTENT = enum.auto()
38 LINEAR_SELF_ADJOINT = enum.auto()
39 LINEAR_POSITIVE_DEFINITE = enum.auto()
40 LINEAR_UNITARY = enum.auto()
41 QUADRATIC = enum.auto()
42
    def arithmetic_methods(self) -> cabc.Set[str]:
        "Instance methods affected by arithmetic operations."
        data = collections.defaultdict(list)
        data[self.CAN_EVAL].extend(
            [
                "apply",
                "__call__",
                "estimate_lipschitz",
                "_expr",
            ]
        )
        data[self.PROXIMABLE].append("prox")
        data[self.DIFFERENTIABLE].extend(
            [
                "jacobian",
                "estimate_diff_lipschitz",
            ]
        )
        data[self.DIFFERENTIABLE_FUNCTION].append("grad")
        data[self.LINEAR].extend(
            [
                "adjoint",
                "asarray",
                "svdvals",
                "pinv",
                "gram",
                "cogram",
            ]
        )
        data[self.LINEAR_SQUARE].append("trace")
        data[self.QUADRATIC].append("_quad_spec")

        meth = frozenset(data[self])
        return meth


class Operator:
    """
    Abstract Base Class for Pyxu operators.

    Goals:

    * enable operator arithmetic.
    * cast operators to specialized forms.
    * attach :py:class:`~pyxu.abc.Property` tags encoding certain mathematical properties. Each core sub-class **must**
      have a unique set of properties to be distinguishable from its peers.
    """

    # For `(size-1 ndarray) * OpT` to work, we need to force NumPy's hand and call OpT.__rmul__() in
    # place of ndarray.__mul__() to determine how scaling should be performed.
    # This is achieved by increasing __array_priority__ for all operators.
    __array_priority__ = np.inf

    def __init__(
        self,
        dim_shape: pxt.NDArrayShape,
        codim_shape: pxt.NDArrayShape,
    ):
        r"""
        Parameters
        ----------
        dim_shape: NDArrayShape
            (M1,...,MD) operator input shape.
        codim_shape: NDArrayShape
            (N1,...,NK) operator output shape.
        """
        dim_shape = pxu.as_canonical_shape(dim_shape)
        assert all(ax >= 1 for ax in dim_shape)

        codim_shape = pxu.as_canonical_shape(codim_shape)
        assert all(ax >= 1 for ax in codim_shape)

        self._dim_shape = dim_shape
        self._codim_shape = codim_shape
        self._name = self.__class__.__name__

    # Public Interface --------------------------------------------------------
    @property
    def dim_shape(self) -> pxt.NDArrayShape:
        r"""
        Return shape of operator's domain. (M1,...,MD)
        """
        return self._dim_shape

    @property
    def dim_size(self) -> pxt.Integer:
        r"""
        Return size of operator's domain. (M1*...*MD)
        """
        return np.prod(self.dim_shape)

    @property
    def dim_rank(self) -> pxt.Integer:
        r"""
        Return rank of operator's domain. (D)
        """
        return len(self.dim_shape)

    @property
    def codim_shape(self) -> pxt.NDArrayShape:
        r"""
        Return shape of operator's co-domain. (N1,...,NK)
        """
        return self._codim_shape

    @property
    def codim_size(self) -> pxt.Integer:
        r"""
        Return size of operator's co-domain. (N1*...*NK)
        """
        return np.prod(self.codim_shape)

    @property
    def codim_rank(self) -> pxt.Integer:
        r"""
        Return rank of operator's co-domain. (K)
        """
        return len(self.codim_shape)

    @classmethod
    def properties(cls) -> cabc.Set[Property]:
        "Mathematical properties of the operator."
        return frozenset()

    @classmethod
    def has(cls, prop: typ.Union[Property, cabc.Collection[Property]]) -> bool:
        """
        Verify if operator possesses supplied properties.
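
        Example
        -------
        A short usage sketch (class/property names as defined in this module):

        .. code-block:: python3

           from pyxu.abc import LinOp, Property

           LinOp.has(Property.LINEAR)  # => True
           LinOp.has((Property.LINEAR, Property.LINEAR_UNITARY))  # => False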
171 """
172 if isinstance(prop, Property):
173 prop = (prop,)
174 return frozenset(prop) <= cls.properties()
175
    def asop(self, cast_to: pxt.OpC) -> pxt.OpT:
        r"""
        Recast an :py:class:`~pyxu.abc.Operator` (or subclass thereof) to another :py:class:`~pyxu.abc.Operator`.

        Users may call this method if the arithmetic API yields sub-optimal return types.

        This method is a no-op if `cast_to` is a parent class of ``self``.

        Parameters
        ----------
        cast_to: OpC
            Target type for the recast.

        Returns
        -------
        op: OpT
            Operator with the new interface.

            Fails when cast is forbidden.
            (Ex: :py:class:`~pyxu.abc.Map` -> :py:class:`~pyxu.abc.Func` if codim.size > 1)

        Notes
        -----
        * The interface of `cast_to` is provided via encapsulation + forwarding.
        * If ``self`` does not implement all methods from `cast_to`, then unimplemented methods will raise
          :py:class:`NotImplementedError` when called.
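
        Example
        -------
        A minimal sketch, assuming :py:class:`~pyxu.abc.SquareOp` is the desired target interface:

        .. code-block:: python3

           import pyxu.abc as pxa

           op = pxa.LinOp(dim_shape=5, codim_shape=5)  # generic linear interface
           sq = op.asop(pxa.SquareOp)  # recast to gain SquareOp methods, e.g. trace()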
202 """
203 if cast_to not in _core_operators():
204 raise ValueError(f"cast_to: expected a core base-class, got {cast_to}.")
205
206 p_core = frozenset(self.properties())
207 p_shell = frozenset(cast_to.properties())
208 if p_shell <= p_core:
209 # Trying to cast `self` to it's own class or a parent class.
210 # Inheritance rules mean the target object already satisfies the intended interface.
211 return self
212 else:
213 # (p_shell > p_core) -> specializing to a sub-class of ``self``
214 # OR
215 # len(p_shell ^ p_core) > 0 -> specializing to another branch of the class hierarchy.
216 op = cast_to(
217 dim_shape=self.dim_shape,
218 codim_shape=self.codim_shape,
219 )
220 op._core = self # for debugging
221
222 # Forward shared arithmetic fields from core to shell.
223 for p in p_shell & p_core:
224 for m in p.arithmetic_methods():
225 m_core = getattr(self, m)
226 setattr(op, m, m_core)
227
228 # [diff_]lipschitz are not arithmetic methods, hence are not propagated.
229 # We propagate them manually to avoid un-warranted re-evaluations.
230 # Important: we write to _[diff_]lipschitz to not overwrite estimate_[diff_]lipschitz() methods.
231 if cast_to.has(Property.CAN_EVAL) and self.has(Property.CAN_EVAL):
232 op._lipschitz = self.lipschitz
233 if cast_to.has(Property.DIFFERENTIABLE) and self.has(Property.DIFFERENTIABLE):
234 op._diff_lipschitz = self.diff_lipschitz
235
236 return op
237
    # Operator Arithmetic -----------------------------------------------------
    def __add__(self, other: pxt.OpT) -> pxt.OpT:
        """
        Add two operators.

        Parameters
        ----------
        self: OpT
            Left operand.
        other: OpT
            Right operand.

        Returns
        -------
        op: OpT
            Composite operator ``self + other``

        Notes
        -----
        Operand shapes must be consistent, i.e.:

        * have `same dimensions`, and
        * have `compatible co-dimensions` (after broadcasting).

        See Also
        --------
        :py:class:`~pyxu.abc.arithmetic.AddRule`
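
        Example
        -------
        A small sketch, mirroring the classes used in :py:meth:`~pyxu.abc.Operator.expr`'s example:

        .. code-block:: python3

           import pyxu.abc as pxa

           kwargs = dict(dim_shape=5, codim_shape=5)
           op1 = pxa.LinOp(**kwargs)
           op2 = pxa.DiffMap(**kwargs)
           op = op1 + op2  # DiffMap(dim=(5,), codim=(5,))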
265 """
266 import pyxu.abc.arithmetic as arithmetic
267
268 if isinstance(other, Operator):
269 return arithmetic.AddRule(lhs=self, rhs=other).op()
270 else:
271 return NotImplemented
272
[docs]
273 def __sub__(self, other: pxt.OpT) -> pxt.OpT:
274 """
275 Subtract two operators.
276
277 Parameters
278 ----------
279 self: OpT
280 Left operand.
281 other: OpT
282 Right operand.
283
284 Returns
285 -------
286 op: OpT
287 Composite operator ``self - other``
288 """
289 import pyxu.abc.arithmetic as arithmetic
290
291 if isinstance(other, Operator):
292 return arithmetic.AddRule(lhs=self, rhs=-other).op()
293 else:
294 return NotImplemented
295
    def __neg__(self) -> pxt.OpT:
        """
        Negate an operator.

        Returns
        -------
        op: OpT
            Composite operator ``-1 * self``.
        """
        import pyxu.abc.arithmetic as arithmetic

        return arithmetic.ScaleRule(op=self, cst=-1).op()

    def __mul__(self, other: typ.Union[pxt.Real, pxt.OpT]) -> pxt.OpT:
        """
        Compose two operators, or scale an operator by a constant.

        Parameters
        ----------
        self: OpT
            Left operand.
        other: Real, OpT
            Scalar or right operand.

        Returns
        -------
        op: OpT
            Scaled operator or composed operator ``self * other``.

        Notes
        -----
        If called with two operators, their shapes must be `consistent`, i.e. ``self.dim_shape == other.codim_shape``.

        See Also
        --------
        :py:class:`~pyxu.abc.arithmetic.ScaleRule`,
        :py:class:`~pyxu.abc.arithmetic.ChainRule`
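
        Example
        -------
        A short sketch of both behaviours (composition and scaling):

        .. code-block:: python3

           import pyxu.abc as pxa

           op1 = pxa.LinOp(dim_shape=5, codim_shape=3)
           op2 = pxa.DiffMap(dim_shape=4, codim_shape=5)
           op = op1 * op2  # composition: dim=(4,), codim=(3,)
           op3 = 2 * op1  # scaling, via __rmul__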
333 """
334 import pyxu.abc.arithmetic as arithmetic
335
336 if isinstance(other, Operator):
337 return arithmetic.ChainRule(lhs=self, rhs=other).op()
338 elif _is_real(other):
339 return arithmetic.ScaleRule(op=self, cst=float(other)).op()
340 else:
341 return NotImplemented
342
343 def __rmul__(self, other: pxt.Real) -> pxt.OpT:
344 import pyxu.abc.arithmetic as arithmetic
345
346 if _is_real(other):
347 return arithmetic.ScaleRule(op=self, cst=float(other)).op()
348 else:
349 return NotImplemented
350
351 def __truediv__(self, other: pxt.Real) -> pxt.OpT:
352 import pyxu.abc.arithmetic as arithmetic
353
354 if _is_real(other):
355 return arithmetic.ScaleRule(op=self, cst=float(1 / other)).op()
356 else:
357 return NotImplemented
358
[docs]
359 def __pow__(self, k: pxt.Integer) -> pxt.OpT:
360 # (op ** k) unsupported
361 return NotImplemented
362
363 def __matmul__(self, other) -> pxt.OpT:
364 # (op @ NDArray) unsupported
365 return NotImplemented
366
367 def __rmatmul__(self, other) -> pxt.OpT:
368 # (NDArray @ op) unsupported
369 return NotImplemented
370
    def argscale(self, scalar: pxt.Real) -> pxt.OpT:
        """
        Scale operator's domain.

        Parameters
        ----------
        scalar: Real

        Returns
        -------
        op: OpT
            Domain-scaled operator.

        See Also
        --------
        :py:class:`~pyxu.abc.arithmetic.ArgScaleRule`
        """
        import pyxu.abc.arithmetic as arithmetic

        assert _is_real(scalar)
        return arithmetic.ArgScaleRule(op=self, cst=float(scalar)).op()

    def argshift(self, shift: pxt.NDArray) -> pxt.OpT:
        r"""
        Shift operator's domain.

        Parameters
        ----------
        shift: NDArray
            Shift value :math:`c \in \mathbb{R}^{M_{1} \times\cdots\times M_{D}}`.

            `shift` must be broadcastable with operator's dimension.

        Returns
        -------
        op: OpT
            Domain-shifted operator :math:`g(x) = f(x + c)`.

        See Also
        --------
        :py:class:`~pyxu.abc.arithmetic.ArgShiftRule`
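
        Example
        -------
        A short sketch; the shift broadcasts against ``dim_shape``, as in
        :py:meth:`~pyxu.abc.Operator.expr`'s example:

        .. code-block:: python3

           import numpy as np
           import pyxu.abc as pxa

           op = pxa.DiffMap(dim_shape=5, codim_shape=5)
           op_shifted = op.argshift(np.r_[1])  # x -> f(x + 1)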
412 """
413 import pyxu.abc.arithmetic as arithmetic
414
415 return arithmetic.ArgShiftRule(op=self, cst=shift).op()
416
    # Internal Helpers --------------------------------------------------------
    @staticmethod
    def _infer_operator_type(prop: cabc.Collection[Property]) -> pxt.OpC:
        prop = frozenset(prop)
        for op in _core_operators():
            if op.properties() == prop:
                return op
        else:
            raise ValueError(f"No operator found with properties {prop}.")

    def __repr__(self) -> str:
        klass = self._name
        return f"{klass}(dim={self.dim_shape}, codim={self.codim_shape})"

    def _expr(self) -> tuple:
        r"""
        Show the expression-representation of the operator.

        If overridden, must return a tuple of the form

            (head, \*tail),

        where `head` is the operator (ex: +/\*), and `tail` denotes all the expression's terms. If an operator cannot
        be expanded further, then this method should return (self,).
        """
        return (self,)

    def _meta(self):
        # When using DASK inputs, it is sometimes necessary to pass extra information to Dask functions. This function
        # serves this purpose: it lets class writers encode any such information and re-use it when processing DASK
        # inputs. The action and return types of _meta() are at the sole discretion of the implementer.
        raise NotImplementedError

    def expr(self, level: int = 0, strip: bool = True) -> str:
        """
        Pretty-Print the expression representation of the operator.

        Useful for debugging arithmetic-induced expressions.

        Example
        -------

        .. code-block:: python3

           import numpy as np
           import pyxu.abc as pxa

           kwargs = dict(dim_shape=5, codim_shape=5)
           op1 = pxa.LinOp(**kwargs)
           op2 = pxa.DiffMap(**kwargs)
           op = ((2 * op1) + (op1 * op2)).argshift(np.r_[1])

           print(op.expr())
           # [argshift, ==> DiffMap(dim=(5,), codim=(5,))
           # .[add, ==> DiffMap(dim=(5,), codim=(5,))
           # ..[scale, ==> LinOp(dim=(5,), codim=(5,))
           # ...LinOp(dim=(5,), codim=(5,)),
           # ...2.0],
           # ..[compose, ==> DiffMap(dim=(5,), codim=(5,))
           # ...LinOp(dim=(5,), codim=(5,)),
           # ...DiffMap(dim=(5,), codim=(5,))]],
           # .(1,)]
        """
        fmt = lambda obj, lvl: ("." * lvl) + str(obj)
        lines = []

        head, *tail = self._expr()
        if len(tail) == 0:
            head = f"{repr(head)},"
        else:
            head = f"[{head}, ==> {repr(self)}"
        lines.append(fmt(head, level))

        for t in tail:
            if isinstance(t, Operator):
                lines += t.expr(level=level + 1, strip=False).split("\n")
            else:
                t = f"{t},"
                lines.append(fmt(t, level + 1))
        if len(tail) > 0:
            # Drop comma for last tail item, then close the sub-expression.
            lines[-1] = lines[-1][:-1]
            lines[-1] += "],"

        out = "\n".join(lines)
        if strip:
            out = out.strip(",")  # drop comma at top-level tail.
        return out

    # Short-hands for commonly-used operators ---------------------------------
    def squeeze(self, axes: pxt.NDArrayAxis = None) -> pxt.OpT:
        """
        Drop size-1 axes from co-dimension.

        See Also
        --------
        :py:class:`~pyxu.operator.SqueezeAxes`
        """

        from pyxu.operator import SqueezeAxes

        sq = SqueezeAxes(
            dim_shape=self.codim_shape,
            axes=axes,
        )
        op = sq * self
        return op

    def transpose(self, axes: pxt.NDArrayAxis = None) -> pxt.OpT:
        """
        Permute co-dimension axes.

        See Also
        --------
        :py:class:`~pyxu.operator.TransposeAxes`
        """

        from pyxu.operator import TransposeAxes

        tr = TransposeAxes(
            dim_shape=self.codim_shape,
            axes=axes,
        )
        op = tr * self
        return op

    def reshape(self, codim_shape: pxt.NDArrayShape) -> pxt.OpT:
        """
        Reshape co-dimension shape.

        See Also
        --------
        :py:class:`~pyxu.operator.ReshapeAxes`
        """

        from pyxu.operator import ReshapeAxes

        rsh = ReshapeAxes(
            dim_shape=self.codim_shape,
            codim_shape=codim_shape,
        )
        op = rsh * self
        return op

    def broadcast_to(self, codim_shape: pxt.NDArrayShape) -> pxt.OpT:
        """
        Broadcast co-dimension shape.

        See Also
        --------
        :py:class:`~pyxu.operator.BroadcastAxes`
        """

        from pyxu.operator import BroadcastAxes

        bcast = BroadcastAxes(
            dim_shape=self.codim_shape,
            codim_shape=codim_shape,
        )
        op = bcast * self
        return op

    def subsample(self, *indices) -> pxt.OpT:
        """
        Sub-sample co-dimension.

        See Also
        --------
        :py:class:`~pyxu.operator.SubSample`
        """

        from pyxu.operator import SubSample

        sub = SubSample(self.codim_shape, *indices)
        op = sub * self
        return op

    def rechunk(self, chunks: dict) -> pxt.OpT:
        """
        Re-chunk core dimensions to new chunk size.

        See Also
        --------
        :py:func:`~pyxu.operator.RechunkAxes`
        """
        from pyxu.operator import RechunkAxes

        chk = RechunkAxes(
            dim_shape=self.codim_shape,
            chunks=chunks,
        )
        op = chk * self
        return op


class Map(Operator):
    r"""
    Base class for real-valued maps :math:`\mathbf{f}: \mathbb{R}^{M_{1} \times\cdots\times M_{D}} \to \mathbb{R}^{N_{1}
    \times\cdots\times N_{K}}`.

    Instances of this class must implement :py:meth:`~pyxu.abc.Map.apply`.

    If :math:`\mathbf{f}` is Lipschitz-continuous with known Lipschitz constant :math:`L`, the latter should be stored
    in the :py:attr:`~pyxu.abc.Map.lipschitz` property.
    """

    @classmethod
    def properties(cls) -> cabc.Set[Property]:
        p = set(super().properties())
        p.add(Property.CAN_EVAL)
        return frozenset(p)

    def __init__(
        self,
        dim_shape: pxt.NDArrayShape,
        codim_shape: pxt.NDArrayShape,
    ):
        super().__init__(
            dim_shape=dim_shape,
            codim_shape=codim_shape,
        )
        self.lipschitz = np.inf

    def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
        """
        Evaluate operator at specified point(s).

        Parameters
        ----------
        arr: NDArray
            (..., M1,...,MD) input points.

        Returns
        -------
        out: NDArray
            (..., N1,...,NK) output points.
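
        Example
        -------
        A minimal concrete subclass (a sketch; any shape-preserving rule works the same way):

        .. code-block:: python3

           import numpy as np
           from pyxu.abc import Map

           class Scale2(Map):
               def __init__(self, dim_shape):
                   super().__init__(dim_shape=dim_shape, codim_shape=dim_shape)
                   self.lipschitz = 2  # known a priori
               def apply(self, arr):
                   return 2 * arr

           op = Scale2(dim_shape=3)
           op(np.r_[1.0, 2, 3])  # => array([2., 4., 6.])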
653 """
654 raise NotImplementedError
655
    def __call__(self, arr: pxt.NDArray) -> pxt.NDArray:
        """
        Alias for :py:meth:`~pyxu.abc.Map.apply`.
        """
        return self.apply(arr)

    @property
    def lipschitz(self) -> pxt.Real:
        r"""
        Return the last computed Lipschitz constant of :math:`\mathbf{f}`.

        Notes
        -----
        * If a Lipschitz constant is known a priori, it can be stored in the instance as follows:

          .. code-block:: python3

             class TestOp(Map):
                 def __init__(self, dim_shape, codim_shape):
                     super().__init__(dim_shape, codim_shape)
                     self.lipschitz = 2

             op = TestOp(2, 3)
             op.lipschitz  # => 2


          Alternatively the Lipschitz constant can be set manually after initialization:

          .. code-block:: python3

             class TestOp(Map):
                 def __init__(self, dim_shape, codim_shape):
                     super().__init__(dim_shape, codim_shape)

             op = TestOp(2, 3)
             op.lipschitz  # => inf, since unknown a priori

             op.lipschitz = 2  # post-init specification
             op.lipschitz  # => 2

        * :py:meth:`~pyxu.abc.Map.lipschitz` **never** computes anything:
          call :py:meth:`~pyxu.abc.Map.estimate_lipschitz` manually to *compute* a new Lipschitz estimate:

          .. code-block:: python3

             op.lipschitz = op.estimate_lipschitz()
        """
        if not hasattr(self, "_lipschitz"):
            self._lipschitz = self.estimate_lipschitz()

        return self._lipschitz

    @lipschitz.setter
    def lipschitz(self, L: pxt.Real):
        assert L >= 0
        self._lipschitz = float(L)

        # If no algorithm available to auto-determine estimate_lipschitz(), then enforce user's choice.
        if not self.has(Property.LINEAR):

            def op_estimate_lipschitz(_, **kwargs) -> pxt.Real:
                return _._lipschitz

            self.estimate_lipschitz = types.MethodType(op_estimate_lipschitz, self)

    def estimate_lipschitz(self, **kwargs) -> pxt.Real:
        r"""
        Compute a Lipschitz constant of the operator.

        Parameters
        ----------
        kwargs: ~collections.abc.Mapping
            Class-specific kwargs to configure Lipschitz estimation.

        Notes
        -----
        * This method should always be callable without specifying any kwargs.

        * A constant :math:`L_{\mathbf{f}} > 0` is said to be a *Lipschitz constant* for a map :math:`\mathbf{f}:
          \mathbb{R}^{M_{1} \times\cdots\times M_{D}} \to \mathbb{R}^{N_{1} \times\cdots\times N_{K}}` if:

          .. math::

             \|\mathbf{f}(\mathbf{x}) - \mathbf{f}(\mathbf{y})\|_{\mathbb{R}^{N_{1} \times\cdots\times N_{K}}}
             \leq
             L_{\mathbf{f}} \|\mathbf{x} - \mathbf{y}\|_{\mathbb{R}^{M_{1} \times\cdots\times M_{D}}},
             \qquad
             \forall \mathbf{x}, \mathbf{y}\in \mathbb{R}^{M_{1} \times\cdots\times M_{D}},

          where :math:`\|\cdot\|_{\mathbb{R}^{M_{1} \times\cdots\times M_{D}}}` and :math:`\|\cdot\|_{\mathbb{R}^{N_{1}
          \times\cdots\times N_{K}}}` are the canonical norms on their respective spaces.

          The smallest Lipschitz constant of a map is called the *optimal Lipschitz constant*.
        """
        raise NotImplementedError


class Func(Map):
    r"""
    Base class for real-valued functionals :math:`f: \mathbb{R}^{M_{1} \times\cdots\times M_{D}} \to
    \mathbb{R}\cup\{+\infty\}`.

    Instances of this class must implement :py:meth:`~pyxu.abc.Map.apply`.

    If :math:`f` is Lipschitz-continuous with known Lipschitz constant :math:`L`, the latter should be stored in the
    :py:attr:`~pyxu.abc.Map.lipschitz` property.
    """

    @classmethod
    def properties(cls) -> cabc.Set[Property]:
        p = set(super().properties())
        p.add(Property.FUNCTIONAL)
        return frozenset(p)

    def __init__(
        self,
        dim_shape: pxt.NDArrayShape,
        codim_shape: pxt.NDArrayShape,
    ):
        super().__init__(
            dim_shape=dim_shape,
            codim_shape=codim_shape,
        )
        assert (self.codim_size == 1) and (self.codim_rank == 1)


class DiffMap(Map):
    r"""
    Base class for real-valued differentiable maps :math:`\mathbf{f}: \mathbb{R}^{M_{1} \times\cdots\times M_{D}} \to
    \mathbb{R}^{N_{1} \times\cdots\times N_{K}}`.

    Instances of this class must implement :py:meth:`~pyxu.abc.Map.apply` and :py:meth:`~pyxu.abc.DiffMap.jacobian`.

    If :math:`\mathbf{f}` is Lipschitz-continuous with known Lipschitz constant :math:`L`, the latter should be stored
    in the :py:attr:`~pyxu.abc.Map.lipschitz` property.

    If :math:`\mathbf{J}_{\mathbf{f}}` is Lipschitz-continuous with known Lipschitz constant :math:`\partial L`, the
    latter should be stored in the :py:attr:`~pyxu.abc.DiffMap.diff_lipschitz` property.
    """

    @classmethod
    def properties(cls) -> cabc.Set[Property]:
        p = set(super().properties())
        p.add(Property.DIFFERENTIABLE)
        return frozenset(p)

    def __init__(
        self,
        dim_shape: pxt.NDArrayShape,
        codim_shape: pxt.NDArrayShape,
    ):
        super().__init__(
            dim_shape=dim_shape,
            codim_shape=codim_shape,
        )
        self.diff_lipschitz = np.inf

    def jacobian(self, arr: pxt.NDArray) -> pxt.OpT:
        r"""
        Evaluate the Jacobian of :math:`\mathbf{f}` at the specified point.

        Parameters
        ----------
        arr: NDArray
            (M1,...,MD) evaluation point.

        Returns
        -------
        op: OpT
            Jacobian operator at point `arr`.

        Notes
        -----
        Let :math:`\mathbf{f}: \mathbb{R}^{M_{1} \times\cdots\times M_{D}} \to \mathbb{R}^{N_{1} \times\cdots\times
        N_{K}}` be a differentiable multi-dimensional map. The *Jacobian* (or *differential*) of :math:`\mathbf{f}` at
        :math:`\mathbf{z} \in \mathbb{R}^{M_{1} \times\cdots\times M_{D}}` is defined as the best linear approximator of
        :math:`\mathbf{f}` near :math:`\mathbf{z}`, in the following sense:

        .. math::

           \mathbf{f}(\mathbf{x}) - \mathbf{f}(\mathbf{z}) = \mathbf{J}_{\mathbf{f}}(\mathbf{z}) (\mathbf{x} -
           \mathbf{z}) + o(\| \mathbf{x} - \mathbf{z} \|) \quad \text{as} \quad \mathbf{x} \to \mathbf{z}.

        The Jacobian admits the following matrix representation:

        .. math::

           [\mathbf{J}_{\mathbf{f}}(\mathbf{x})]_{ij} := \frac{\partial f_{i}}{\partial x_{j}}(\mathbf{x}), \qquad
           \forall (i,j) \in \{1,\ldots,N_{1}\cdots N_{K}\} \times \{1,\ldots,M_{1}\cdots M_{D}\}.
        """
        raise NotImplementedError

    @property
    def diff_lipschitz(self) -> pxt.Real:
        r"""
        Return the last computed Lipschitz constant of :math:`\mathbf{J}_{\mathbf{f}}`.

        Notes
        -----
        * If a diff-Lipschitz constant is known a priori, it can be stored in the instance as follows:

          .. code-block:: python3

             class TestOp(DiffMap):
                 def __init__(self, dim_shape, codim_shape):
                     super().__init__(dim_shape, codim_shape)
                     self.diff_lipschitz = 2

             op = TestOp(2, 3)
             op.diff_lipschitz  # => 2


          Alternatively the diff-Lipschitz constant can be set manually after initialization:

          .. code-block:: python3

             class TestOp(DiffMap):
                 def __init__(self, dim_shape, codim_shape):
                     super().__init__(dim_shape, codim_shape)

             op = TestOp(2, 3)
             op.diff_lipschitz  # => inf, since unknown a priori

             op.diff_lipschitz = 2  # post-init specification
             op.diff_lipschitz  # => 2

        * :py:meth:`~pyxu.abc.DiffMap.diff_lipschitz` **never** computes anything:
          call :py:meth:`~pyxu.abc.DiffMap.estimate_diff_lipschitz` manually to *compute* a new diff-Lipschitz estimate:

          .. code-block:: python3

             op.diff_lipschitz = op.estimate_diff_lipschitz()
        """
        if not hasattr(self, "_diff_lipschitz"):
            self._diff_lipschitz = self.estimate_diff_lipschitz()

        return self._diff_lipschitz

    @diff_lipschitz.setter
    def diff_lipschitz(self, dL: pxt.Real):
        assert dL >= 0
        self._diff_lipschitz = float(dL)

        # If no algorithm available to auto-determine estimate_diff_lipschitz(), then enforce user's choice.
        if not self.has(Property.QUADRATIC):

            def op_estimate_diff_lipschitz(_, **kwargs) -> pxt.Real:
                return _._diff_lipschitz

            self.estimate_diff_lipschitz = types.MethodType(op_estimate_diff_lipschitz, self)

    def estimate_diff_lipschitz(self, **kwargs) -> pxt.Real:
        r"""
        Compute a Lipschitz constant of :py:meth:`~pyxu.abc.DiffMap.jacobian`.

        Parameters
        ----------
        kwargs: ~collections.abc.Mapping
            Class-specific kwargs to configure diff-Lipschitz estimation.

        Notes
        -----
        * This method should always be callable without specifying any kwargs.

        * A Lipschitz constant :math:`L_{\mathbf{J}_{\mathbf{f}}} > 0` of the Jacobian map
          :math:`\mathbf{J}_{\mathbf{f}}: \mathbb{R}^{M_{1} \times\cdots\times M_{D}} \to \mathbb{R}^{(N_{1}
          \times\cdots\times N_{K}) \times (M_{1} \times\cdots\times M_{D})}` is such that:

          .. math::

             \|\mathbf{J}_{\mathbf{f}}(\mathbf{x}) - \mathbf{J}_{\mathbf{f}}(\mathbf{y})\|_{\mathbb{R}^{(N_{1}
             \times\cdots\times N_{K}) \times (M_{1} \times\cdots\times M_{D})}}
             \leq
             L_{\mathbf{J}_{\mathbf{f}}} \|\mathbf{x} - \mathbf{y}\|_{\mathbb{R}^{M_{1} \times\cdots\times M_{D}}},
             \qquad
             \forall \mathbf{x}, \mathbf{y} \in \mathbb{R}^{M_{1} \times\cdots\times M_{D}},

          where :math:`\|\cdot\|_{\mathbb{R}^{(N_{1} \times\cdots\times N_{K}) \times (M_{1} \times\cdots\times
          M_{D})}}` and :math:`\|\cdot\|_{\mathbb{R}^{M_{1} \times\cdots\times M_{D}}}` are the canonical norms on their
          respective spaces.

          The smallest Lipschitz constant of the Jacobian is called the *optimal diff-Lipschitz constant*.
        """
        raise NotImplementedError


class ProxFunc(Func):
    r"""
    Base class for real-valued proximable functionals :math:`f: \mathbb{R}^{M_{1} \times\cdots\times M_{D}} \to
    \mathbb{R} \cup \{+\infty\}`.

    A functional :math:`f: \mathbb{R}^{M_{1} \times\cdots\times M_{D}} \to \mathbb{R} \cup \{+\infty\}` is said to be
    *proximable* if its **proximity operator** (see :py:meth:`~pyxu.abc.ProxFunc.prox` for a definition) admits a
    *simple closed-form expression* **or** can be evaluated *efficiently* and with *high accuracy*.

    Instances of this class must implement :py:meth:`~pyxu.abc.Map.apply` and :py:meth:`~pyxu.abc.ProxFunc.prox`.

    If :math:`f` is Lipschitz-continuous with known Lipschitz constant :math:`L`, the latter should be stored in the
    :py:attr:`~pyxu.abc.Map.lipschitz` property.
    """

    @classmethod
    def properties(cls) -> cabc.Set[Property]:
        p = set(super().properties())
        p.add(Property.PROXIMABLE)
        return frozenset(p)

    def __init__(
        self,
        dim_shape: pxt.NDArrayShape,
        codim_shape: pxt.NDArrayShape,
    ):
        super().__init__(
            dim_shape=dim_shape,
            codim_shape=codim_shape,
        )

    def prox(self, arr: pxt.NDArray, tau: pxt.Real) -> pxt.NDArray:
        r"""
        Evaluate proximity operator of :math:`\tau f` at specified point(s).

        Parameters
        ----------
        arr: NDArray
            (..., M1,...,MD) input points.
        tau: Real
            Positive scale factor.

        Returns
        -------
        out: NDArray
            (..., M1,...,MD) proximal evaluations.

        Notes
        -----
        For :math:`\tau > 0`, the *proximity operator* of a scaled functional :math:`f: \mathbb{R}^{M_{1}
        \times\cdots\times M_{D}} \to \mathbb{R}` is defined as:

        .. math::

           \mathbf{\text{prox}}_{\tau f}(\mathbf{z})
           :=
           \arg\min_{\mathbf{x}\in\mathbb{R}^{M_{1} \times\cdots\times M_{D}}}
           f(\mathbf{x}) + \frac{1}{2\tau} \|\mathbf{x}-\mathbf{z}\|_{2}^{2},
           \quad
           \forall \mathbf{z} \in \mathbb{R}^{M_{1} \times\cdots\times M_{D}}.
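
        Example
        -------
        A sketch of the :math:`\ell_{1}`-norm's proximity operator (soft-thresholding), mirroring the
        :py:meth:`~pyxu.abc.ProxFunc.moreau_envelope` example below:

        .. code-block:: python3

           import numpy as np
           from pyxu.abc import ProxFunc

           class L1Norm(ProxFunc):
               def __init__(self, dim: int):
                   super().__init__(dim_shape=dim, codim_shape=1)
               def apply(self, arr):
                   return np.linalg.norm(arr, axis=-1, keepdims=True, ord=1)
               def prox(self, arr, tau):
                   return np.clip(np.abs(arr) - tau, a_min=0, a_max=None) * np.sign(arr)

           f = L1Norm(dim=4)
           f.prox(np.r_[-2.0, -0.5, 0.5, 2], tau=1)  # => array([-1., 0., 0., 1.])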
1002 """
1003 raise NotImplementedError
1004
    def fenchel_prox(self, arr: pxt.NDArray, sigma: pxt.Real) -> pxt.NDArray:
        r"""
        Evaluate proximity operator of :math:`\sigma f^{\ast}`, the scaled Fenchel conjugate of :math:`f`, at specified
        point(s).

        Parameters
        ----------
        arr: NDArray
            (..., M1,...,MD) input points.
        sigma: Real
            Positive scale factor.

        Returns
        -------
        out: NDArray
            (..., M1,...,MD) proximal evaluations.

        Notes
        -----
        For :math:`\sigma > 0`, the *Fenchel conjugate* is defined as:

        .. math::

           f^{\ast}(\mathbf{z})
           :=
           \max_{\mathbf{x}\in\mathbb{R}^{M_{1} \times\cdots\times M_{D}}}
           \langle \mathbf{x},\mathbf{z} \rangle - f(\mathbf{x}).

        From *Moreau's identity*, its proximal operator is given by:

        .. math::

           \mathbf{\text{prox}}_{\sigma f^{\ast}}(\mathbf{z})
           =
           \mathbf{z} - \sigma \mathbf{\text{prox}}_{f/\sigma}(\mathbf{z}/\sigma).
        """
        out = arr - sigma * self.prox(arr=arr / sigma, tau=1 / sigma)
        return out

    def moreau_envelope(self, mu: pxt.Real) -> pxt.OpT:
        r"""
        Approximate proximable functional :math:`f` by its *Moreau envelope* :math:`f^{\mu}`.

        Parameters
        ----------
        mu: Real
            Positive regularization parameter.

        Returns
        -------
        op: OpT
            Differential Moreau envelope.

        Notes
        -----
        Consider a convex non-smooth proximable functional :math:`f: \mathbb{R}^{M_{1} \times\cdots\times M_{D}} \to
        \mathbb{R} \cup \{+\infty\}` and a regularization parameter :math:`\mu > 0`. The :math:`\mu`-*Moreau envelope*
        (or *Moreau-Yoshida envelope*) of :math:`f` is given by

        .. math::

           f^{\mu}(\mathbf{x})
           =
           \min_{\mathbf{z} \in \mathbb{R}^{M_{1} \times\cdots\times M_{D}}}
           f(\mathbf{z})
           \quad + \quad
           \frac{1}{2\mu} \|\mathbf{x} - \mathbf{z}\|^{2}.

        The parameter :math:`\mu` controls the trade-off between the regularity properties of :math:`f^{\mu}` and the
        approximation error incurred by the Moreau-Yoshida regularization.

        The Moreau envelope inherits the convexity of :math:`f` and is gradient-Lipschitz (with Lipschitz constant
        :math:`\mu^{-1}`), even if :math:`f` is non-smooth. Its gradient is moreover given by:

        .. math::

           \nabla f^{\mu}(\mathbf{x})
           =
           \mu^{-1} \left(\mathbf{x} - \text{prox}_{\mu f}(\mathbf{x})\right).

        In addition, :math:`f^{\mu}` envelopes :math:`f` from below: :math:`f^{\mu}(\mathbf{x}) \leq f(\mathbf{x})`.
        This envelope becomes tighter as :math:`\mu \to 0`:

        .. math::

           \lim_{\mu \to 0} f^{\mu}(\mathbf{x}) = f(\mathbf{x}).

        Finally, it can be shown that the minimizers of :math:`f` and :math:`f^{\mu}` coincide, and that the Fenchel
        conjugate of :math:`f^{\mu}` is strongly-convex.

        Example
        -------
        Construct and plot the Moreau envelope of the :math:`\ell_{1}`-norm:

        .. plot::

           import numpy as np
           import matplotlib.pyplot as plt
           from pyxu.abc import ProxFunc

           class L1Norm(ProxFunc):
               def __init__(self, dim: int):
                   super().__init__(dim_shape=dim, codim_shape=1)
               def apply(self, arr):
                   return np.linalg.norm(arr, axis=-1, keepdims=True, ord=1)
               def prox(self, arr, tau):
                   return np.clip(np.abs(arr) - tau, a_min=0, a_max=None) * np.sign(arr)

           mu = [0.1, 0.5, 1]
           f = [L1Norm(dim=1).moreau_envelope(_mu) for _mu in mu]
           x = np.linspace(-1, 1, 512).reshape(-1, 1)  # evaluation points

           fig, ax = plt.subplots(ncols=2)
           for _mu, _f in zip(mu, f):
               ax[0].plot(x, _f(x), label=f"mu={_mu}")
               ax[1].plot(x, _f.grad(x), label=f"mu={_mu}")
           ax[0].set_title("Moreau Envelope")
           ax[1].set_title("Derivative of Moreau Envelope")
           for _ax in ax:
               _ax.legend()
               _ax.set_aspect("equal")
           fig.tight_layout()
        """
        from pyxu.operator.interop import from_source

        assert mu > 0, f"mu: expected positive, got {mu}"

        def op_apply(_, arr: pxt.NDArray) -> pxt.NDArray:
            xp = pxu.get_array_module(arr)

            x = self.prox(arr, tau=_._mu)
            y = xp.sum(
                (arr - x) ** 2,
                axis=tuple(range(-self.dim_rank, 0)),
            )[..., np.newaxis]
            y *= 0.5 / _._mu

            out = self.apply(x) + y
            return out

        def op_grad(_, arr):
            out = arr - self.prox(arr, tau=_._mu)
            out /= _._mu
            return out

        op = from_source(
            cls=DiffFunc,
            dim_shape=self.dim_shape,
            codim_shape=self.codim_shape,
            embed=dict(
                _name="moreau_envelope",
                _mu=mu,
                _diff_lipschitz=float(1 / mu),
            ),
            apply=op_apply,
            grad=op_grad,
            _expr=lambda _: ("moreau_envelope", _, _._mu),
        )
        return op


class DiffFunc(DiffMap, Func):
    r"""
    Base class for real-valued differentiable functionals :math:`f: \mathbb{R}^{M_{1} \times\cdots\times M_{D}} \to
    \mathbb{R}`.

    Instances of this class must implement :py:meth:`~pyxu.abc.Map.apply` and :py:meth:`~pyxu.abc.DiffFunc.grad`.

    If :math:`f` and/or its derivative :math:`f'` are Lipschitz-continuous with known Lipschitz constants :math:`L` and
    :math:`\partial L`, the latter should be stored in the :py:attr:`~pyxu.abc.Map.lipschitz` and
    :py:attr:`~pyxu.abc.DiffMap.diff_lipschitz` properties.
    """

    @classmethod
    def properties(cls) -> cabc.Set[Property]:
        p = set()
        for klass in cls.__bases__:
            p |= klass.properties()
        p.add(Property.DIFFERENTIABLE_FUNCTION)
        return frozenset(p)

    def __init__(
        self,
        dim_shape: pxt.NDArrayShape,
        codim_shape: pxt.NDArrayShape,
    ):
        for klass in [DiffMap, Func]:
            klass.__init__(
                self,
                dim_shape=dim_shape,
                codim_shape=codim_shape,
            )

    def jacobian(self, arr: pxt.NDArray) -> pxt.OpT:
        op = LinFunc.from_array(
            A=self.grad(arr)[np.newaxis, ...],
            dim_rank=self.dim_rank,
        )
        return op

    def grad(self, arr: pxt.NDArray) -> pxt.NDArray:
        r"""
        Evaluate operator gradient at specified point(s).

        Parameters
        ----------
        arr: NDArray
            (..., M1,...,MD) input points.

        Returns
        -------
        out: NDArray
            (..., M1,...,MD) gradients.

        Notes
        -----
        The gradient of a functional :math:`f: \mathbb{R}^{M_{1} \times\cdots\times M_{D}} \to \mathbb{R}` is given, for
        every :math:`\mathbf{x} \in \mathbb{R}^{M_{1} \times\cdots\times M_{D}}`, by

        .. math::

           \nabla f(\mathbf{x})
           :=
           \left[\begin{array}{c}
           \frac{\partial f}{\partial x_{1}}(\mathbf{x}) \\
           \vdots \\
           \frac{\partial f}{\partial x_{M}}(\mathbf{x})
           \end{array}\right].
        """
        raise NotImplementedError


class ProxDiffFunc(ProxFunc, DiffFunc):
    r"""
    Base class for real-valued differentiable *and* proximable functionals :math:`f:\mathbb{R}^{M_{1} \times\cdots\times
    M_{D}} \to \mathbb{R}`.

    Instances of this class must implement :py:meth:`~pyxu.abc.Map.apply`, :py:meth:`~pyxu.abc.DiffFunc.grad`, and
    :py:meth:`~pyxu.abc.ProxFunc.prox`.

    If :math:`f` and/or its derivative :math:`f'` are Lipschitz-continuous with known Lipschitz constants :math:`L` and
    :math:`\partial L`, the latter should be stored in the :py:attr:`~pyxu.abc.Map.lipschitz` and
    :py:attr:`~pyxu.abc.DiffMap.diff_lipschitz` properties.
    """

    @classmethod
    def properties(cls) -> cabc.Set[Property]:
        p = set()
        for klass in cls.__bases__:
            p |= klass.properties()
        return frozenset(p)

    def __init__(
        self,
        dim_shape: pxt.NDArrayShape,
        codim_shape: pxt.NDArrayShape,
    ):
        for klass in [ProxFunc, DiffFunc]:
            klass.__init__(
                self,
                dim_shape=dim_shape,
                codim_shape=codim_shape,
            )


class QuadraticFunc(ProxDiffFunc):
    # This is a special abstract base class with more __init__ parameters than `dim/codim_shape`.
    r"""
    Base class for quadratic functionals :math:`f: \mathbb{R}^{M_{1} \times\cdots\times M_{D}} \to \mathbb{R} \cup
    \{+\infty\}`.

    The quadratic functional is defined as:

    .. math::

       f(\mathbf{x})
       =
       \frac{1}{2} \langle\mathbf{x}, \mathbf{Q}\mathbf{x}\rangle
       +
       \langle\mathbf{c},\mathbf{x}\rangle
       +
       t,
       \qquad \forall \mathbf{x} \in \mathbb{R}^{M_{1} \times\cdots\times M_{D}},

    where :math:`\mathbf{Q}` is a positive-definite operator :math:`\mathbf{Q}:\mathbb{R}^{M_{1} \times\cdots\times
    M_{D}} \to \mathbb{R}^{M_{1} \times\cdots\times M_{D}}`, :math:`\mathbf{c} \in \mathbb{R}^{M_{1} \times\cdots\times
    M_{D}}`, and :math:`t \in \mathbb{R}`.

    Its gradient is given by:

    .. math::

       \nabla f(\mathbf{x}) = \mathbf{Q}\mathbf{x} + \mathbf{c}.

    Its proximity operator is given by:

    .. math::

       \text{prox}_{\tau f}(\mathbf{x})
       =
       \left(
           \mathbf{Q} + \tau^{-1} \mathbf{Id}
       \right)^{-1}
       \left(
           \tau^{-1}\mathbf{x} - \mathbf{c}
       \right).

    In practice the proximity operator is evaluated via :py:class:`~pyxu.opt.solver.CG`.

    The Lipschitz constant :math:`L` of a quadratic functional on an unbounded domain is unbounded. The Lipschitz
    constant :math:`\partial L` of :math:`\nabla f` is given by the spectral norm of :math:`\mathbf{Q}`.
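
    Example
    -------
    A minimal sketch with the default parameterization (:math:`\mathbf{Q} = \mathbf{Id}`, :math:`\mathbf{c} =
    \mathbf{0}`, :math:`t = 0`), i.e. :math:`f(\mathbf{x}) = \frac{1}{2}\|\mathbf{x}\|_{2}^{2}`:

    .. code-block:: python3

       import numpy as np
       from pyxu.abc import QuadraticFunc

       f = QuadraticFunc(dim_shape=3, codim_shape=1)
       x = np.r_[1.0, 2, 3]
       f(x)  # => array([7.]), i.e. 0.5 * ||x||^2
       f.grad(x)  # => array([1., 2., 3.])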
1316 """
1317
1318 @classmethod
1319 def properties(cls) -> cabc.Set[Property]:
1320 p = set(super().properties())
1321 p.add(Property.QUADRATIC)
1322 return frozenset(p)
1323
    def __init__(
        self,
        dim_shape: pxt.NDArrayShape,
        codim_shape: pxt.NDArrayShape,
        # required in place of `dim` to have uniform interface with Operator hierarchy.
        Q: "PosDefOp" = None,
        c: "LinFunc" = None,
        t: pxt.Real = 0,
    ):
        r"""
        Parameters
        ----------
        Q: ~pyxu.abc.PosDefOp
            Positive-definite operator. (Default: Identity)
        c: ~pyxu.abc.LinFunc
            Linear functional. (Default: NullFunc)
        t: Real
            Offset. (Default: 0)
        """
        from pyxu.operator import IdentityOp, NullFunc

        super().__init__(
            dim_shape=dim_shape,
            codim_shape=codim_shape,
        )

        # Do NOT access (_Q, _c, _t) directly through `self`:
        # their values may not reflect the true (Q, c, t) parameterization.
        # (Reason: arithmetic propagation.)
        # Always access (Q, c, t) by querying the arithmetic method `_quad_spec()`.
        self._Q = IdentityOp(dim_shape=self.dim_shape) if (Q is None) else Q
        self._c = NullFunc(dim_shape=self.dim_shape) if (c is None) else c
        self._t = t

        # ensure dimensions are consistent if None-initialized
        assert self._Q.dim_shape == self.dim_shape
        assert self._Q.codim_shape == self.dim_shape
        assert self._c.dim_shape == self.dim_shape
        assert self._c.codim_shape == self.codim_shape

        self.diff_lipschitz = self._Q.lipschitz

    def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
        Q, c, t = self._quad_spec()
        out = (arr * Q.apply(arr)).sum(axis=tuple(range(-self.dim_rank, 0)))[..., np.newaxis]
        out /= 2
        out += c.apply(arr)
        out += t
        return out

    def grad(self, arr: pxt.NDArray) -> pxt.NDArray:
        Q, c, _ = self._quad_spec()
        out = Q.apply(arr) + c.grad(arr)
        return out

    def prox(self, arr: pxt.NDArray, tau: pxt.Real) -> pxt.NDArray:
        from pyxu.operator import HomothetyOp
        from pyxu.opt.solver import CG
        from pyxu.opt.stop import MaxIter

        Q, c, _ = self._quad_spec()
        A = Q + HomothetyOp(cst=1 / tau, dim_shape=Q.dim_shape)
        b = arr.copy()
        b /= tau
        b -= c.grad(arr)

        slvr = CG(A=A, show_progress=False)

        sentinel = MaxIter(n=2 * A.dim_size)
        stop_crit = slvr.default_stop_crit() | sentinel

        slvr.fit(b=b, stop_crit=stop_crit)
        return slvr.solution()

    def estimate_diff_lipschitz(self, **kwargs) -> pxt.Real:
        Q, *_ = self._quad_spec()
        dL = Q.estimate_lipschitz(**kwargs)
        return dL

    def _quad_spec(self):
        """
        Canonical quadratic parameterization.

        Useful for some internal methods, and overloaded during operator arithmetic.
        """
        return (self._Q, self._c, self._t)


class LinOp(DiffMap):
    r"""
    Base class for real-valued linear operators :math:`\mathbf{A}: \mathbb{R}^{M_{1} \times\cdots\times M_{D}} \to
    \mathbb{R}^{N_{1} \times\cdots\times N_{K}}`.

    Instances of this class must implement :py:meth:`~pyxu.abc.Map.apply` and :py:meth:`~pyxu.abc.LinOp.adjoint`.

    If known, the Lipschitz constant :math:`L` should be stored in the :py:attr:`~pyxu.abc.Map.lipschitz` property.

    The Jacobian of a linear map :math:`\mathbf{A}` is constant.
    """

    # Internal Helpers ------------------------------------
    @staticmethod
    def _warn_vals_sparse_gpu():
        msg = "\n".join(
            [
                "Potential Error:",
                "Sparse GPU-evaluation of svdvals() is known to produce incorrect results. (CuPy-specific + matrix-dependent.)",
                "It is advised to cross-check results with CPU-computed results.",
            ]
        )
        warnings.warn(msg, pxw.BackendWarning)

    # -----------------------------------------------------

    @classmethod
    def properties(cls) -> cabc.Set[Property]:
        p = set(super().properties())
        p.add(Property.LINEAR)
        return frozenset(p)

    def __init__(
        self,
        dim_shape: pxt.NDArrayShape,
        codim_shape: pxt.NDArrayShape,
    ):
        super().__init__(
            dim_shape=dim_shape,
            codim_shape=codim_shape,
        )
        self.diff_lipschitz = 0

    def adjoint(self, arr: pxt.NDArray) -> pxt.NDArray:
        r"""
        Evaluate operator adjoint at specified point(s).

        Parameters
        ----------
        arr: NDArray
            (..., N1,...,NK) input points.

        Returns
        -------
        out: NDArray
            (..., M1,...,MD) adjoint evaluations.

        Notes
        -----
        The *adjoint* :math:`\mathbf{A}^{\ast}: \mathbb{R}^{N_{1} \times\cdots\times N_{K}} \to \mathbb{R}^{M_{1}
        \times\cdots\times M_{D}}` of :math:`\mathbf{A}: \mathbb{R}^{M_{1} \times\cdots\times M_{D}} \to
        \mathbb{R}^{N_{1} \times\cdots\times N_{K}}` is defined as:

        .. math::

           \langle \mathbf{x}, \mathbf{A}^{\ast}\mathbf{y}\rangle_{\mathbb{R}^{M_{1} \times\cdots\times M_{D}}}
           :=
           \langle \mathbf{A}\mathbf{x}, \mathbf{y}\rangle_{\mathbb{R}^{N_{1} \times\cdots\times N_{K}}},
           \qquad
           \forall (\mathbf{x},\mathbf{y})\in \mathbb{R}^{M_{1} \times\cdots\times M_{D}} \times \mathbb{R}^{N_{1}
           \times\cdots\times N_{K}}.
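
        Example
        -------
        A quick numerical check of the adjoint identity above (a sketch; the dense operator is built via
        :py:meth:`~pyxu.abc.LinOp.from_array`):

        .. code-block:: python3

           import numpy as np
           import pyxu.abc as pxa

           rng = np.random.default_rng(0)
           op = pxa.LinOp.from_array(rng.standard_normal((3, 4)))

           x = rng.standard_normal(op.dim_shape)
           y = rng.standard_normal(op.codim_shape)
           np.isclose(
               np.sum(op.apply(x) * y),  # <Ax, y>
               np.sum(x * op.adjoint(y)),  # <x, A^* y>
           )  # => True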
1483 """
1484 raise NotImplementedError
1485
1486 def jacobian(self, arr: pxt.NDArray) -> pxt.OpT:
1487 return self
1488
1489 @property
1490 def T(self) -> pxt.OpT:
1491 r"""
1492 Return the adjoint of :math:`\mathbf{A}`.
1493 """
1494 import pyxu.abc.arithmetic as arithmetic
1495
1496 return arithmetic.TransposeRule(op=self).op()
1497
    def estimate_lipschitz(self, **kwargs) -> pxt.Real:
        r"""
        Compute a Lipschitz constant of the operator.

        Parameters
        ----------
        method: "svd" | "trace"

            * If `svd`, compute the optimal Lipschitz constant.
            * If `trace`, compute an upper bound. (Default)

        kwargs:
            Optional kwargs passed on to:

            * `svd`: :py:func:`~pyxu.abc.LinOp.svdvals`
            * `trace`: :py:func:`~pyxu.math.hutchpp`

        Notes
        -----
        * The tightest Lipschitz constant is given by the spectral norm of the operator :math:`\mathbf{A}`:
          :math:`\|\mathbf{A}\|_{2}`. It can be computed via the SVD, which is a compute-intensive task for large
          operators. In this setting, it may be advantageous to overestimate the Lipschitz constant with the Frobenius
          norm of :math:`\mathbf{A}` since :math:`\|\mathbf{A}\|_{F} \geq \|\mathbf{A}\|_{2}`.

          :math:`\|\mathbf{A}\|_{F}` can be efficiently approximated by computing the trace of :math:`\mathbf{A}^{\ast}
          \mathbf{A}` (or :math:`\mathbf{A}\mathbf{A}^{\ast}`) via the `Hutch++ stochastic algorithm
          <https://arxiv.org/abs/2010.09649>`_.

        * :math:`\|\mathbf{A}\|_{F}` is upper-bounded by :math:`\|\mathbf{A}\|_{F} \leq \sqrt{n} \|\mathbf{A}\|_{2}`,
          where the equality is reached (worst-case scenario) when the eigenspectrum of the linear operator is flat.
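
        Example
        -------
        A short sketch contrasting the two estimation methods (``method`` names as documented above):

        .. code-block:: python3

           import numpy as np
           import pyxu.abc as pxa

           op = pxa.LinOp.from_array(np.diag(np.r_[1.0, 2, 3, 4]))
           L_opt = op.estimate_lipschitz(method="svd")  # optimal: spectral norm (= 4)
           L_ub = op.estimate_lipschitz(method="trace")  # stochastic upper bound (~ Frobenius norm)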
1528 """
1529 method = kwargs.get("method", "trace").lower().strip()
1530
1531 if method == "svd":
1532 # svdvals() may have alternative signature in specialized classes, but we must always use
1533 # the LinOp.svdvals() interface below for kwargs-filtering.
1534 func, sig_func = self.__class__.svdvals, LinOp.svdvals
1535 kwargs.update(k=1)
1536 estimate = lambda: func(self, **kwargs).item()
1537 elif method == "trace":
1538 from pyxu.math import hutchpp as func
1539
1540 sig_func = func
1541 kwargs.update(
1542 op=self.gram() if (self.codim_size >= self.dim_size) else self.cogram(),
1543 m=kwargs.get("m", 126),
1544 )
1545 estimate = lambda: np.sqrt(func(**kwargs)).item()
1546 else:
1547 raise NotImplementedError
1548
1549 # Filter unsupported kwargs
1550 sig = inspect.Signature.from_callable(sig_func)
1551 kwargs = {k: v for (k, v) in kwargs.items() if (k in sig.parameters)}
1552
1553 L = estimate()
1554 return L
1555
    def svdvals(
        self,
        k: pxt.Integer,
        gpu: bool = False,
        dtype: pxt.DType = None,
        **kwargs,
    ) -> pxt.NDArray:
        r"""
        Compute leading singular values of the linear operator.

        Parameters
        ----------
        k: Integer
            Number of singular values to compute.
        gpu: bool
            If ``True`` the singular value decomposition is performed on the GPU.
        dtype: DType
            Working precision of the linear operator.
        kwargs: ~collections.abc.Mapping
            Additional kwargs accepted by :py:func:`~scipy.sparse.linalg.svds`.

        Returns
        -------
        D: NDArray
            (k,) singular values in ascending order.
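
        Example
        -------
        A small sketch on a diagonal operator, whose singular values are known in advance:

        .. code-block:: python3

           import numpy as np
           import pyxu.abc as pxa

           op = pxa.LinOp.from_array(np.diag(np.r_[1.0, 2, 3, 4]))
           op.svdvals(k=2)  # => array([3., 4.])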
1581 """
1582 if dtype is None:
1583 dtype = pxrt.Width.DOUBLE.value
1584
1585 def _dense_eval():
1586 if gpu:
1587 assert pxd.CUPY_ENABLED
1588 import cupy as xp
1589 import cupy.linalg as spx
1590 else:
1591 import numpy as xp
1592 import scipy.linalg as spx
1593 A = self.asarray(xp=xp, dtype=dtype)
1594 B = A.reshape(self.codim_size, self.dim_size)
1595 return spx.svd(B, compute_uv=False)
1596
1597 def _sparse_eval():
1598 if gpu:
1599 assert pxd.CUPY_ENABLED
1600 import cupyx.scipy.sparse.linalg as spx
1601
1602 self._warn_vals_sparse_gpu()
1603 else:
1604 spx = spsl
1605 from pyxu.operator import ReshapeAxes
1606 from pyxu.operator.interop import to_sciop
1607
1608 # SciPy's LinearOperator only understands 2D linear operators.
1609 # -> wrap `self` into 2D form for SVD computations.
1610 lhs = ReshapeAxes(dim_shape=self.codim_shape, codim_shape=self.codim_size)
1611 rhs = ReshapeAxes(dim_shape=self.dim_size, codim_shape=self.dim_shape)
1612 op = to_sciop(
1613 op=lhs * self * rhs,
1614 gpu=gpu,
1615 dtype=dtype,
1616 )
1617
1618 which = kwargs.get("which", "LM")
1619 assert which.upper() == "LM", "Only computing leading singular values is supported."
1620 kwargs.update(
1621 k=k,
1622 which=which,
1623 return_singular_vectors=False,
1624 # random_state=0, # unsupported by CuPy
1625 )
1626 return spx.svds(op, **kwargs)
1627
1628 if k >= min(self.dim_size, self.codim_size) // 2:
1629 msg = "Too many svdvals wanted: using matrix-based ops."
1630 warnings.warn(msg, pxw.DenseWarning)
1631 D = _dense_eval()
1632 else:
1633 D = _sparse_eval()
1634
1635 # Filter to k largest magnitude + sorted
1636 xp = pxu.get_array_module(D)
1637 return D[xp.argsort(D)][-k:]
1638
    def asarray(
        self,
        xp: pxt.ArrayModule = None,
        dtype: pxt.DType = None,
    ) -> pxt.NDArray:
        r"""
        Matrix representation of the linear operator.

        Parameters
        ----------
        xp: ArrayModule
            Which array module to use to represent the output. (Default: NumPy.)
        dtype: DType
            Precision of the array. (Default: current runtime precision.)

        Returns
        -------
        A: NDArray
            (*codim_shape, *dim_shape) array-representation of the operator.

        Note
        ----
        This generic implementation assumes the operator is backend-agnostic. Thus, when defining a new
        backend-specific operator, :py:meth:`~pyxu.abc.LinOp.asarray` may need to be overridden.
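
        Example
        -------
        A round-trip sketch: the dense representation of an operator built from an array recovers that array.

        .. code-block:: python3

           import numpy as np
           import pyxu.abc as pxa

           A = np.arange(6.0).reshape(2, 3)
           op = pxa.LinOp.from_array(A)
           np.allclose(op.asarray(), A)  # => True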
1663 """
1664 if xp is None:
1665 xp = pxd.NDArrayInfo.default().module()
1666 if dtype is None:
1667 dtype = pxrt.Width.DOUBLE.value
1668
1669 E = xp.eye(self.dim_size, dtype=dtype).reshape(*self.dim_shape, *self.dim_shape)
1670 A = self.apply(E) # (*dim_shape, *codim_shape)
1671
1672 axes = tuple(range(-self.codim_rank, 0)) + tuple(range(self.dim_rank))
1673 B = A.transpose(axes) # (*codim_shape, *dim_shape)
1674 return B
1675
    def gram(self) -> pxt.OpT:
        r"""
        Gram operator :math:`\mathbf{A}^{\ast} \mathbf{A}: \mathbb{R}^{M_{1} \times\cdots\times M_{D}} \to
        \mathbb{R}^{M_{1} \times\cdots\times M_{D}}`.

        Returns
        -------
        op: OpT
            Gram operator.

        Note
        ----
        By default the Gram is computed by the composition ``self.T * self``. This may not be the fastest way to
        compute the Gram operator. If the Gram can be computed more efficiently (e.g. with a convolution), the user
        should re-define this method.
        """

        def op_expr(_) -> tuple:
            return ("gram", self)

        op = self.T * self
        op._expr = types.MethodType(op_expr, op)
        return op.asop(SelfAdjointOp)

    def cogram(self) -> pxt.OpT:
        r"""
        Co-Gram operator :math:`\mathbf{A}\mathbf{A}^{\ast}:\mathbb{R}^{N_{1} \times\cdots\times N_{K}} \to
        \mathbb{R}^{N_{1} \times\cdots\times N_{K}}`.

        Returns
        -------
        op: OpT
            Co-Gram operator.

        Note
        ----
        By default the co-Gram is computed by the composition ``self * self.T``. This may not be the fastest way to
        compute the co-Gram operator. If the co-Gram can be computed more efficiently (e.g. with a convolution), the
        user should re-define this method.
        """

        def op_expr(_) -> tuple:
            return ("cogram", self)

        op = self * self.T
        op._expr = types.MethodType(op_expr, op)
        return op.asop(SelfAdjointOp)

    def pinv(
        self,
        arr: pxt.NDArray,
        damp: pxt.Real,
        kwargs_init=None,
        kwargs_fit=None,
    ) -> pxt.NDArray:
        r"""
        Evaluate the Moore-Penrose pseudo-inverse :math:`\mathbf{A}^{\dagger}` at specified point(s).

        Parameters
        ----------
        arr: NDArray
            (..., N1,...,NK) input points.
        damp: Real
            Positive dampening factor regularizing the pseudo-inverse.
        kwargs_init: ~collections.abc.Mapping
            Optional kwargs to be passed to :py:meth:`~pyxu.opt.solver.CG`'s ``__init__()`` method.
        kwargs_fit: ~collections.abc.Mapping
            Optional kwargs to be passed to :py:meth:`~pyxu.opt.solver.CG`'s ``fit()`` method.

        Returns
        -------
        out: NDArray
            (..., M1,...,MD) pseudo-inverse(s).

        Notes
        -----
        The Moore-Penrose pseudo-inverse of an operator :math:`\mathbf{A}: \mathbb{R}^{M_{1} \times\cdots\times M_{D}}
        \to \mathbb{R}^{N_{1} \times\cdots\times N_{K}}` is defined as the operator :math:`\mathbf{A}^{\dagger}:
        \mathbb{R}^{N_{1} \times\cdots\times N_{K}} \to \mathbb{R}^{M_{1} \times\cdots\times M_{D}}` verifying the
        Moore-Penrose conditions:

        1. :math:`\mathbf{A} \mathbf{A}^{\dagger} \mathbf{A} = \mathbf{A}`,
        2. :math:`\mathbf{A}^{\dagger} \mathbf{A} \mathbf{A}^{\dagger} = \mathbf{A}^{\dagger}`,
        3. :math:`(\mathbf{A}^{\dagger} \mathbf{A})^{\ast} = \mathbf{A}^{\dagger} \mathbf{A}`,
        4. :math:`(\mathbf{A} \mathbf{A}^{\dagger})^{\ast} = \mathbf{A} \mathbf{A}^{\dagger}`.

        This operator exists and is unique for any finite-dimensional linear operator. The action of the pseudo-inverse
        :math:`\mathbf{A}^{\dagger} \mathbf{y}` for every :math:`\mathbf{y} \in \mathbb{R}^{N_{1} \times\cdots\times
        N_{K}}` can be computed in matrix-free fashion by solving the *normal equations*:

        .. math::

           \mathbf{A}^{\ast} \mathbf{A} \mathbf{x} = \mathbf{A}^{\ast} \mathbf{y}
           \quad\Leftrightarrow\quad
           \mathbf{x} = \mathbf{A}^{\dagger} \mathbf{y},
           \quad
           \forall (\mathbf{x}, \mathbf{y}) \in \mathbb{R}^{M_{1} \times\cdots\times M_{D}} \times \mathbb{R}^{N_{1}
           \times\cdots\times N_{K}}.

        In the case of severe ill-conditioning, it is possible to consider the dampened normal equations for a
        numerically-stabler approximation of :math:`\mathbf{A}^{\dagger} \mathbf{y}`:

        .. math::

           (\mathbf{A}^{\ast} \mathbf{A} + \tau I) \mathbf{x} = \mathbf{A}^{\ast} \mathbf{y},

        where :math:`\tau > 0` corresponds to the `damp` parameter.
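
        Example
        -------
        A small sketch on an invertible diagonal operator, where the pseudo-inverse equals the inverse:

        .. code-block:: python3

           import numpy as np
           import pyxu.abc as pxa

           op = pxa.LinOp.from_array(np.diag(np.r_[1.0, 2, 4]))
           op.pinv(np.r_[1.0, 2, 4], damp=0)  # => approx. array([1., 1., 1.])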
1783 """
1784 from pyxu.operator import HomothetyOp
1785 from pyxu.opt.solver import CG
1786 from pyxu.opt.stop import MaxIter
1787
1788 kwargs_fit = dict() if kwargs_fit is None else kwargs_fit
1789 kwargs_init = dict() if kwargs_init is None else kwargs_init
1790 kwargs_init.update(show_progress=kwargs_init.get("show_progress", False))
1791
1792 if np.isclose(damp, 0):
1793 A = self.gram()
1794 else:
1795 A = self.gram() + HomothetyOp(cst=damp, dim_shape=self.dim_shape)
1796
1797 cg = CG(A, **kwargs_init)
1798 if "stop_crit" not in kwargs_fit:
1799 # .pinv() may not have sufficiently converged given the default CG stopping criteria.
1800 # To avoid infinite loops, CG iterations are thresholded.
1801 sentinel = MaxIter(n=20 * A.dim_size)
1802 kwargs_fit["stop_crit"] = cg.default_stop_crit() | sentinel
1803
1804 b = self.adjoint(arr)
1805 cg.fit(b=b, **kwargs_fit)
1806 return cg.solution()
1807
    def dagger(
        self,
        damp: pxt.Real,
        kwargs_init=None,
        kwargs_fit=None,
    ) -> pxt.OpT:
        r"""
        Return the Moore-Penrose pseudo-inverse operator :math:`\mathbf{A}^\dagger`.

        Parameters
        ----------
        damp: Real
            Positive dampening factor regularizing the pseudo-inverse.
        kwargs_init: ~collections.abc.Mapping
            Optional kwargs to be passed to :py:meth:`~pyxu.opt.solver.CG`'s ``__init__()`` method.
        kwargs_fit: ~collections.abc.Mapping
            Optional kwargs to be passed to :py:meth:`~pyxu.opt.solver.CG`'s ``fit()`` method.

        Returns
        -------
        op: OpT
            Moore-Penrose pseudo-inverse operator.
        """
        from pyxu.operator.interop import from_source

        def op_apply(_, arr: pxt.NDArray) -> pxt.NDArray:
            return self.pinv(
                arr,
                damp=_._damp,
                kwargs_init=_._kwargs_init,
                kwargs_fit=_._kwargs_fit,
            )

        def op_adjoint(_, arr: pxt.NDArray) -> pxt.NDArray:
            return self.T.pinv(
                arr,
                damp=_._damp,
                kwargs_init=_._kwargs_init,
                kwargs_fit=_._kwargs_fit,
            )

        kwargs_fit = dict() if kwargs_fit is None else kwargs_fit
        kwargs_init = dict() if kwargs_init is None else kwargs_init

        dagger = from_source(
            cls=SquareOp if (self.dim_size == self.codim_size) else LinOp,
            dim_shape=self.codim_shape,
            codim_shape=self.dim_shape,
            embed=dict(
                _name="dagger",
                _damp=damp,
                _kwargs_init=copy.copy(kwargs_init),
                _kwargs_fit=copy.copy(kwargs_fit),
            ),
            apply=op_apply,
            adjoint=op_adjoint,
            _expr=lambda _: (_._name, _, _._damp),
        )
        return dagger

[docs]
    @classmethod
    def from_array(
        cls,
        A: typ.Union[pxt.NDArray, pxt.SparseArray],
        dim_rank=None,
        enable_warnings: bool = True,
    ) -> pxt.OpT:
        r"""
        Instantiate a :py:class:`~pyxu.abc.LinOp` from its array representation.

        Parameters
        ----------
        A: NDArray, SparseArray
            (*codim_shape, *dim_shape) array.
        dim_rank: Integer
            Dimension rank :math:`D`. (Can be omitted if `A` is 2D.)
        enable_warnings: bool
            If ``True``, emit a warning in case of precision mismatch issues.

        Returns
        -------
        op: OpT
            Linear operator.
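
        Example
        -------
        A minimal sketch of the assumed usage:

        .. code-block:: python

           import numpy as np
           import pyxu.abc as pxa

           mat = np.arange(6.0).reshape(2, 3)  # (codim, dim) = (2, 3)
           op = pxa.LinOp.from_array(mat)
           np.allclose(op.apply(np.ones(3)), mat.sum(axis=-1))  # True: apply() computes row sums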
1891 """
1892 from pyxu.operator.linop.base import _ExplicitLinOp
1893
1894 op = _ExplicitLinOp(
1895 cls,
1896 mat=A,
1897 dim_rank=dim_rank,
1898 enable_warnings=enable_warnings,
1899 )
1900 return op
1901
1902
class SquareOp(LinOp):
    r"""
    Base class for *square* linear operators, i.e. :math:`\mathbf{A}: \mathbb{R}^{M_{1} \times\cdots\times M_{D}} \to
    \mathbb{R}^{M_{1} \times\cdots\times M_{D}}` (endomorphisms).
    """

    @classmethod
    def properties(cls) -> cabc.Set[Property]:
        p = set(super().properties())
        p.add(Property.LINEAR_SQUARE)
        return frozenset(p)

    def __init__(
        self,
        dim_shape: pxt.NDArrayShape,
        codim_shape: pxt.NDArrayShape,
    ):
        super().__init__(
            dim_shape=dim_shape,
            codim_shape=codim_shape,
        )
        assert self.dim_size == self.codim_size

    def trace(self, **kwargs) -> pxt.Real:
        """
        Compute the trace of the operator.

        Parameters
        ----------
        method: "explicit" | "hutchpp"

            * If `explicit`, compute the exact trace.
            * If `hutchpp`, compute a stochastic estimate via the Hutch++ algorithm. (Default)

        kwargs: ~collections.abc.Mapping
            Optional kwargs passed to:

            * `explicit`: :py:func:`~pyxu.math.trace`
            * `hutchpp`: :py:func:`~pyxu.math.hutchpp`

        Returns
        -------
        tr: Real
            Trace estimate.
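
        Example
        -------
        A minimal sketch of the assumed usage:

        .. code-block:: python

           from pyxu.operator import IdentityOp

           op = IdentityOp(dim_shape=(10,))
           op.trace(method="explicit")  # exact trace: 10.0
           op.trace()                   # stochastic Hutch++ estimate, close to 10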
1947 """
1948 from pyxu.math import hutchpp, trace
1949
1950 method = kwargs.get("method", "hutchpp").lower().strip()
1951
1952 if method == "explicit":
1953 func = sig_func = trace
1954 estimate = lambda: func(op=self, **kwargs)
1955 elif method == "hutchpp":
1956 func = sig_func = hutchpp
1957 estimate = lambda: func(op=self, **kwargs)
1958 else:
1959 raise NotImplementedError
1960
1961 # Filter unsupported kwargs
1962 sig = inspect.Signature.from_callable(sig_func)
1963 kwargs = {k: v for (k, v) in kwargs.items() if (k in sig.parameters)}
1964
1965 tr = estimate()
1966 return tr
1967
1968
class NormalOp(SquareOp):
    r"""
    Base class for *normal* operators.

    Normal operators satisfy the relation :math:`\mathbf{A} \mathbf{A}^{\ast} = \mathbf{A}^{\ast} \mathbf{A}`. It can
    be `shown <https://www.wikiwand.com/en/Spectral_theorem#/Normal_matrices>`_ that an operator is normal iff it is
    *unitarily diagonalizable*, i.e. :math:`\mathbf{A} = \mathbf{U} \mathbf{D} \mathbf{U}^{\ast}`.
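
    In particular, the Gram :math:`\mathbf{A}^{\ast} \mathbf{A}` and co-Gram :math:`\mathbf{A} \mathbf{A}^{\ast}` of
    a normal operator coincide, which is why :py:meth:`~pyxu.abc.NormalOp.cogram` below simply delegates to
    :py:meth:`~pyxu.abc.LinOp.gram`.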
1976 """
1977
1978 @classmethod
1979 def properties(cls) -> cabc.Set[Property]:
1980 p = set(super().properties())
1981 p.add(Property.LINEAR_NORMAL)
1982 return frozenset(p)
1983
1984 def cogram(self) -> pxt.OpT:
1985 return self.gram()
1986
1987
class SelfAdjointOp(NormalOp):
    r"""
    Base class for *self-adjoint* operators.

    Self-adjoint operators satisfy the relation :math:`\mathbf{A}^{\ast} = \mathbf{A}`.
    """

    @classmethod
    def properties(cls) -> cabc.Set[Property]:
        p = set(super().properties())
        p.add(Property.LINEAR_SELF_ADJOINT)
        return frozenset(p)

    def adjoint(self, arr: pxt.NDArray) -> pxt.NDArray:
        return self.apply(arr)


class UnitOp(NormalOp):
    r"""
    Base class for *unitary* operators.

    Unitary operators satisfy the relation :math:`\mathbf{A} \mathbf{A}^{\ast} = \mathbf{A}^{\ast} \mathbf{A} = I`.
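
    Since :math:`\mathbf{A}^{\ast} \mathbf{A} = I`, the dampened normal equations
    :math:`(\mathbf{A}^{\ast} \mathbf{A} + \tau I) \mathbf{x} = \mathbf{A}^{\ast} \mathbf{y}` reduce to the closed
    form

    .. math::

       \mathbf{x} = \frac{\mathbf{A}^{\ast} \mathbf{y}}{1 + \tau},

    which the :py:meth:`~pyxu.abc.UnitOp.pinv` and :py:meth:`~pyxu.abc.UnitOp.dagger` overrides below use to bypass
    the CG solve of :py:meth:`~pyxu.abc.LinOp.pinv`.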
2010 """
2011
2012 @classmethod
2013 def properties(cls) -> cabc.Set[Property]:
2014 p = set(super().properties())
2015 p.add(Property.LINEAR_UNITARY)
2016 return frozenset(p)
2017
2018 def __init__(
2019 self,
2020 dim_shape: pxt.NDArrayShape,
2021 codim_shape: pxt.NDArrayShape,
2022 ):
2023 super().__init__(
2024 dim_shape=dim_shape,
2025 codim_shape=codim_shape,
2026 )
2027 self.lipschitz = UnitOp.estimate_lipschitz(self)
2028
2029 def estimate_lipschitz(self, **kwargs) -> pxt.Real:
2030 return 1
2031
2032 def pinv(self, arr: pxt.NDArray, damp: pxt.Real, **kwargs) -> pxt.NDArray:
2033 out = self.adjoint(arr)
2034 if not np.isclose(damp, 0):
2035 out = pxu.copy_if_unsafe(out)
2036 out /= 1 + damp
2037 return out
2038
2039 def dagger(self, damp: pxt.Real, **kwargs) -> pxt.OpT:
2040 op = self.T / (1 + damp)
2041 return op
2042
2043 def gram(self) -> pxt.OpT:
2044 from pyxu.operator import IdentityOp
2045
2046 return IdentityOp(dim_shape=self.dim_shape)
2047
2048 def svdvals(self, **kwargs) -> pxt.NDArray:
2049 gpu = kwargs.get("gpu", False)
2050 xp = pxd.NDArrayInfo.from_flag(gpu).module()
2051 dtype = kwargs.get("dtype", pxrt.Width.DOUBLE.value)
2052 D = xp.ones(kwargs["k"], dtype=dtype)
2053 return D
2054
2055
class ProjOp(SquareOp):
    r"""
    Base class for *projection* operators.

    Projection operators are *idempotent*, i.e. :math:`\mathbf{A}^{2} = \mathbf{A}`.
    """

    @classmethod
    def properties(cls) -> cabc.Set[Property]:
        p = set(super().properties())
        p.add(Property.LINEAR_IDEMPOTENT)
        return frozenset(p)


class OrthProjOp(ProjOp, SelfAdjointOp):
    r"""
    Base class for *orthogonal projection* operators.

    Orthogonal projection operators are *idempotent* and *self-adjoint*, i.e. :math:`\mathbf{A}^{2} = \mathbf{A}` and
    :math:`\mathbf{A}^{\ast} = \mathbf{A}`.
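
    Since :math:`\mathbf{A}^{\ast} \mathbf{A} = \mathbf{A}^{2} = \mathbf{A}`, the dampened normal equations reduce
    to :math:`(\mathbf{A} + \tau I) \mathbf{x} = \mathbf{A} \mathbf{y}`, solved in closed form by

    .. math::

       \mathbf{x} = \frac{\mathbf{A} \mathbf{y}}{1 + \tau},

    which the :py:meth:`~pyxu.abc.OrthProjOp.pinv` and :py:meth:`~pyxu.abc.OrthProjOp.dagger` overrides below
    exploit.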
2076 """
2077
2078 @classmethod
2079 def properties(cls) -> cabc.Set[Property]:
2080 p = set()
2081 for klass in cls.__bases__:
2082 p |= klass.properties()
2083 return frozenset(p)
2084
2085 def __init__(
2086 self,
2087 dim_shape: pxt.NDArrayShape,
2088 codim_shape: pxt.NDArrayShape,
2089 ):
2090 super().__init__(
2091 dim_shape=dim_shape,
2092 codim_shape=codim_shape,
2093 )
2094 self.lipschitz = OrthProjOp.estimate_lipschitz(self)
2095
2096 def estimate_lipschitz(self, **kwargs) -> pxt.Real:
2097 return 1
2098
2099 def gram(self) -> pxt.OpT:
2100 return self
2101
2102 def cogram(self) -> pxt.OpT:
2103 return self
2104
2105 def pinv(self, arr: pxt.NDArray, damp: pxt.Real, **kwargs) -> pxt.NDArray:
2106 out = self.apply(arr)
2107 if not np.isclose(damp, 0):
2108 out = pxu.copy_if_unsafe(out)
2109 out /= 1 + damp
2110 return out
2111
2112 def dagger(self, damp: pxt.Real, **kwargs) -> pxt.OpT:
2113 op = self / (1 + damp)
2114 return op
2115
2116
class PosDefOp(SelfAdjointOp):
    r"""
    Base class for *positive-definite* operators.

    Positive-definite operators are self-adjoint and satisfy
    :math:`\langle \mathbf{A} \mathbf{x}, \mathbf{x} \rangle > 0` for all :math:`\mathbf{x} \neq \mathbf{0}`.
    """

    @classmethod
    def properties(cls) -> cabc.Set[Property]:
        p = set(super().properties())
        p.add(Property.LINEAR_POSITIVE_DEFINITE)
        return frozenset(p)


class LinFunc(ProxDiffFunc, LinOp):
    r"""
    Base class for real-valued linear functionals :math:`f: \mathbb{R}^{M_{1} \times\cdots\times M_{D}} \to
    \mathbb{R}`.

    Instances of this class must implement :py:meth:`~pyxu.abc.Map.apply` and :py:meth:`~pyxu.abc.LinOp.adjoint`.

    If known, the Lipschitz constant :math:`L` should be stored in the :py:attr:`~pyxu.abc.Map.lipschitz` property.
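
    Example
    -------
    A minimal sketch of the assumed usage:

    .. code-block:: python

       import numpy as np
       import pyxu.abc as pxa

       f = pxa.LinFunc.from_array(np.ones((1, 3)))  # f(x) = x.sum()
       f.apply(np.arange(3.0))                      # [3.]
       f.grad(np.arange(3.0))                       # constant gradient: [1., 1., 1.]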
2136 """
2137
2138 @classmethod
2139 def properties(cls) -> cabc.Set[Property]:
2140 p = set()
2141 for klass in cls.__bases__:
2142 p |= klass.properties()
2143 return frozenset(p)
2144
2145 def __init__(
2146 self,
2147 dim_shape: pxt.NDArrayShape,
2148 codim_shape: pxt.NDArrayShape,
2149 ):
2150 for klass in [ProxDiffFunc, LinOp]:
2151 klass.__init__(
2152 self,
2153 dim_shape=dim_shape,
2154 codim_shape=codim_shape,
2155 )
2156
2157 def jacobian(self, arr: pxt.NDArray) -> pxt.OpT:
2158 return LinOp.jacobian(self, arr)
2159
    def estimate_lipschitz(self, **kwargs) -> pxt.Real:
        # Try all backends until one works.
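        # For a linear functional f(x) = <g, x>, the Lipschitz constant is ||g||_2,
        # i.e. the Euclidean norm of the (constant) gradient computed below.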
        for ndi in pxd.NDArrayInfo:
            try:
                xp = ndi.module()
                g = self.grad(xp.ones(self.dim_shape))
                L = float(xp.sqrt(xp.sum(g**2)))
                return L
            except Exception:
                pass
        # All backends failed: fail loudly rather than silently returning None.
        raise RuntimeError("estimate_lipschitz(): no available backend could evaluate grad().")

    def grad(self, arr: pxt.NDArray) -> pxt.NDArray:
        ndi = pxd.NDArrayInfo.from_obj(arr)
        xp = ndi.module()

        sh = arr.shape[: -self.dim_rank]
        x = xp.ones((*sh, 1), dtype=arr.dtype)
        g = self.adjoint(x)

        if ndi == pxd.NDArrayInfo.DASK:
            # LinFuncs auto-determine [grad,prox,fenchel_prox]() via the user-specified adjoint().
            # Problem: cannot forward any core-chunk info to adjoint(), hence grad's core-chunks
            # may differ from `arr`. This is problematic since [grad,prox,fenchel_prox]() should
            # preserve core-chunks by default.
            if g.chunks != arr.chunks:
                g = g.rechunk(arr.chunks)
        return g

    def prox(self, arr: pxt.NDArray, tau: pxt.Real) -> pxt.NDArray:
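        # For linear f, prox_{tau f}(x) = argmin_z 0.5 * ||z - x||^2 + tau * f(z)
        # has the closed form x - tau * grad(f), since grad(f) is constant.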
        out = arr - tau * self.grad(arr)
        return out

    def fenchel_prox(self, arr: pxt.NDArray, sigma: pxt.Real) -> pxt.NDArray:
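        # The Fenchel conjugate of a linear functional is the indicator of {grad(f)},
        # so its proximal operator maps every point onto that single gradient vector.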
        return self.grad(arr)

    def cogram(self) -> pxt.OpT:
        from pyxu.operator import HomothetyOp

        L = self.estimate_lipschitz()
        return HomothetyOp(cst=L**2, dim_shape=1)

    def svdvals(self, **kwargs) -> pxt.NDArray:
        gpu = kwargs.get("gpu", False)
        xp = pxd.NDArrayInfo.from_flag(gpu).module()
        dtype = kwargs.get("dtype", pxrt.Width.DOUBLE.value)

        L = self.estimate_lipschitz()
        D = xp.array([L], dtype=dtype)
        return D

    def asarray(self, **kwargs) -> pxt.NDArray:
        xp = kwargs.get("xp", pxd.NDArrayInfo.default().module())
        dtype = kwargs.get("dtype", pxrt.Width.DOUBLE.value)

        E = xp.ones((1, 1), dtype=dtype)
        A = self.adjoint(E)  # (1, *dim_shape)
        return A


def _core_operators() -> cabc.Set[pxt.OpC]:
    # Operators which can be sub-classed by end-users and participate in arithmetic rules.
    ops = set()
    for _ in globals().values():
        if inspect.isclass(_) and issubclass(_, Operator):
            ops.add(_)
    ops.remove(Operator)
    return ops


def _is_real(x) -> bool:
    if isinstance(x, pxt.Real):
        return True
    elif isinstance(x, pxd.supported_array_types()) and (x.size == 1):
        return True
    else:
        return False


__all__ = [
    "Operator",
    "Property",
    *map(lambda _: _.__name__, _core_operators()),
]