Source code for pyxu.abc.operator

import collections
import collections.abc as cabc
import copy
import enum
import inspect
import types
import typing as typ
import warnings

import numpy as np
import scipy.sparse.linalg as spsl

import pyxu.info.deps as pxd
import pyxu.info.ptype as pxt
import pyxu.info.warning as pxw
import pyxu.runtime as pxrt
import pyxu.util as pxu


class Property(enum.Enum):
    """
    Mathematical property.

    See Also
    --------
    :py:class:`~pyxu.abc.Operator`
    """

    CAN_EVAL = enum.auto()
    FUNCTIONAL = enum.auto()
    PROXIMABLE = enum.auto()
    DIFFERENTIABLE = enum.auto()
    DIFFERENTIABLE_FUNCTION = enum.auto()
    LINEAR = enum.auto()
    LINEAR_SQUARE = enum.auto()
    LINEAR_NORMAL = enum.auto()
    LINEAR_IDEMPOTENT = enum.auto()
    LINEAR_SELF_ADJOINT = enum.auto()
    LINEAR_POSITIVE_DEFINITE = enum.auto()
    LINEAR_UNITARY = enum.auto()
    QUADRATIC = enum.auto()

    def arithmetic_methods(self) -> cabc.Set[str]:
        "Instance methods affected by arithmetic operations."
        data = collections.defaultdict(list)
        data[self.CAN_EVAL].extend(
            [
                "apply",
                "__call__",
                "estimate_lipschitz",
                "_expr",
            ]
        )
        data[self.PROXIMABLE].append("prox")
        data[self.DIFFERENTIABLE].extend(
            [
                "jacobian",
                "estimate_diff_lipschitz",
            ]
        )
        data[self.DIFFERENTIABLE_FUNCTION].append("grad")
        data[self.LINEAR].extend(
            [
                "adjoint",
                "asarray",
                "svdvals",
                "pinv",
                "gram",
                "cogram",
            ]
        )
        data[self.LINEAR_SQUARE].append("trace")
        data[self.QUADRATIC].append("_quad_spec")

        meth = frozenset(data[self])
        return meth
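

# Illustrative sketch (not part of the module): how `Property` tags are queried
# in practice. The output below is derived from `arithmetic_methods()` above.
#
#   >>> from pyxu.abc import Property
#   >>> sorted(Property.LINEAR.arithmetic_methods())
#   ['adjoint', 'asarray', 'cogram', 'gram', 'pinv', 'svdvals']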


class Operator:
    """
    Abstract Base Class for Pyxu operators.

    Goals:

    * enable operator arithmetic.
    * cast operators to specialized forms.
    * attach :py:class:`~pyxu.abc.Property` tags encoding certain mathematical properties. Each core sub-class
      **must** have a unique set of properties to be distinguishable from its peers.
    """

    # For `(size-1 ndarray) * OpT` to work, we need to force NumPy's hand and call OpT.__rmul__() in
    # place of ndarray.__mul__() to determine how scaling should be performed.
    # This is achieved by increasing __array_priority__ for all operators.
    __array_priority__ = np.inf

    def __init__(
        self,
        dim_shape: pxt.NDArrayShape,
        codim_shape: pxt.NDArrayShape,
    ):
        r"""
        Parameters
        ----------
        dim_shape: NDArrayShape
            (M1,...,MD) operator input shape.
        codim_shape: NDArrayShape
            (N1,...,NK) operator output shape.
        """
        dim_shape = pxu.as_canonical_shape(dim_shape)
        assert all(ax >= 1 for ax in dim_shape)

        codim_shape = pxu.as_canonical_shape(codim_shape)
        assert all(ax >= 1 for ax in codim_shape)

        self._dim_shape = dim_shape
        self._codim_shape = codim_shape
        self._name = self.__class__.__name__

    # Public Interface --------------------------------------------------------
    @property
    def dim_shape(self) -> pxt.NDArrayShape:
        r"""
        Return shape of operator's domain. (M1,...,MD)
        """
        return self._dim_shape

    @property
    def dim_size(self) -> pxt.Integer:
        r"""
        Return size of operator's domain. (M1*...*MD)
        """
        return np.prod(self.dim_shape)

    @property
    def dim_rank(self) -> pxt.Integer:
        r"""
        Return rank of operator's domain. (D)
        """
        return len(self.dim_shape)

    @property
    def codim_shape(self) -> pxt.NDArrayShape:
        r"""
        Return shape of operator's co-domain. (N1,...,NK)
        """
        return self._codim_shape

    @property
    def codim_size(self) -> pxt.Integer:
        r"""
        Return size of operator's co-domain. (N1*...*NK)
        """
        return np.prod(self.codim_shape)

    @property
    def codim_rank(self) -> pxt.Integer:
        r"""
        Return rank of operator's co-domain. (K)
        """
        return len(self.codim_shape)

    @classmethod
    def properties(cls) -> cabc.Set[Property]:
        "Mathematical properties of the operator."
        return frozenset()

    @classmethod
    def has(cls, prop: typ.Union[Property, cabc.Collection[Property]]) -> bool:
        """
        Verify if operator possesses supplied properties.
        """
        if isinstance(prop, Property):
            prop = (prop,)
        return frozenset(prop) <= cls.properties()
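
    # Illustrative sketch (assumes a concrete operator `op`, e.g. any LinOp):
    #
    #   >>> op.has(Property.LINEAR)                       # single property
    #   True
    #   >>> op.has([Property.CAN_EVAL, Property.LINEAR])  # all supplied properties must hold
    #   True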

    def asop(self, cast_to: pxt.OpC) -> pxt.OpT:
        r"""
        Recast an :py:class:`~pyxu.abc.Operator` (or subclass thereof) to another :py:class:`~pyxu.abc.Operator`.

        Users may call this method if the arithmetic API yields sub-optimal return types.

        This method is a no-op if `cast_to` is a parent class of ``self``.

        Parameters
        ----------
        cast_to: OpC
            Target type for the recast.

        Returns
        -------
        op: OpT
            Operator with the new interface.

            Fails when cast is forbidden.
            (Ex: :py:class:`~pyxu.abc.Map` -> :py:class:`~pyxu.abc.Func` if codim.size > 1)

        Notes
        -----
        * The interface of `cast_to` is provided via encapsulation + forwarding.
        * If ``self`` does not implement all methods from `cast_to`, then unimplemented methods will raise
          :py:class:`NotImplementedError` when called.
        """
        if cast_to not in _core_operators():
            raise ValueError(f"cast_to: expected a core base-class, got {cast_to}.")

        p_core = frozenset(self.properties())
        p_shell = frozenset(cast_to.properties())
        if p_shell <= p_core:
            # Trying to cast `self` to its own class or a parent class.
            # Inheritance rules mean the target object already satisfies the intended interface.
            return self
        else:
            # (p_shell > p_core) -> specializing to a sub-class of ``self``
            # OR
            # len(p_shell ^ p_core) > 0 -> specializing to another branch of the class hierarchy.
            op = cast_to(
                dim_shape=self.dim_shape,
                codim_shape=self.codim_shape,
            )
            op._core = self  # for debugging

            # Forward shared arithmetic fields from core to shell.
            for p in p_shell & p_core:
                for m in p.arithmetic_methods():
                    m_core = getattr(self, m)
                    setattr(op, m, m_core)

            # [diff_]lipschitz are not arithmetic methods, hence are not propagated.
            # We propagate them manually to avoid un-warranted re-evaluations.
            # Important: we write to _[diff_]lipschitz to not overwrite estimate_[diff_]lipschitz() methods.
            if cast_to.has(Property.CAN_EVAL) and self.has(Property.CAN_EVAL):
                op._lipschitz = self.lipschitz
            if cast_to.has(Property.DIFFERENTIABLE) and self.has(Property.DIFFERENTIABLE):
                op._diff_lipschitz = self.diff_lipschitz

            return op
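
    # Illustrative sketch (assumes `A` was produced by arithmetic that returned a
    # sub-optimal type, although the underlying operator implements the richer API):
    #
    #   >>> from pyxu.abc import LinOp
    #   >>> B = A.asop(LinOp)  # hypothetical `A`; `B` now exposes the LinOp interface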

    # Operator Arithmetic ------------------------------------------------------

    def __add__(self, other: pxt.OpT) -> pxt.OpT:
        """
        Add two operators.

        Parameters
        ----------
        self: OpT
            Left operand.
        other: OpT
            Right operand.

        Returns
        -------
        op: OpT
            Composite operator ``self + other``.

        Notes
        -----
        Operand shapes must be consistent, i.e.:

        * have `same dimensions`, and
        * have `compatible co-dimensions` (after broadcasting).

        See Also
        --------
        :py:class:`~pyxu.abc.arithmetic.AddRule`
        """
        import pyxu.abc.arithmetic as arithmetic

        if isinstance(other, Operator):
            return arithmetic.AddRule(lhs=self, rhs=other).op()
        else:
            return NotImplemented

    def __sub__(self, other: pxt.OpT) -> pxt.OpT:
        """
        Subtract two operators.

        Parameters
        ----------
        self: OpT
            Left operand.
        other: OpT
            Right operand.

        Returns
        -------
        op: OpT
            Composite operator ``self - other``.
        """
        import pyxu.abc.arithmetic as arithmetic

        if isinstance(other, Operator):
            return arithmetic.AddRule(lhs=self, rhs=-other).op()
        else:
            return NotImplemented

    def __neg__(self) -> pxt.OpT:
        """
        Negate an operator.

        Returns
        -------
        op: OpT
            Composite operator ``-1 * self``.
        """
        import pyxu.abc.arithmetic as arithmetic

        return arithmetic.ScaleRule(op=self, cst=-1).op()

    def __mul__(self, other: typ.Union[pxt.Real, pxt.OpT]) -> pxt.OpT:
        """
        Compose two operators, or scale an operator by a constant.

        Parameters
        ----------
        self: OpT
            Left operand.
        other: Real, OpT
            Scalar or right operand.

        Returns
        -------
        op: OpT
            Scaled operator or composed operator ``self * other``.

        Notes
        -----
        If called with two operators, their shapes must be `consistent`, i.e. ``self.dim_shape ==
        other.codim_shape``.

        See Also
        --------
        :py:class:`~pyxu.abc.arithmetic.ScaleRule`,
        :py:class:`~pyxu.abc.arithmetic.ChainRule`
        """
        import pyxu.abc.arithmetic as arithmetic

        if isinstance(other, Operator):
            return arithmetic.ChainRule(lhs=self, rhs=other).op()
        elif _is_real(other):
            return arithmetic.ScaleRule(op=self, cst=float(other)).op()
        else:
            return NotImplemented

    def __rmul__(self, other: pxt.Real) -> pxt.OpT:
        import pyxu.abc.arithmetic as arithmetic

        if _is_real(other):
            return arithmetic.ScaleRule(op=self, cst=float(other)).op()
        else:
            return NotImplemented

    def __truediv__(self, other: pxt.Real) -> pxt.OpT:
        import pyxu.abc.arithmetic as arithmetic

        if _is_real(other):
            return arithmetic.ScaleRule(op=self, cst=float(1 / other)).op()
        else:
            return NotImplemented
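
    # Illustrative sketch of the arithmetic API (assumes two shape-compatible
    # operators `A` and `B`):
    #
    #   >>> C = A + B     # AddRule
    #   >>> D = 2.0 * A   # ScaleRule (dispatched via __rmul__)
    #   >>> E = A * B     # ChainRule: E(x) = A(B(x))
    #   >>> F = A / 4.0   # ScaleRule with constant 1/4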

    def __pow__(self, k: pxt.Integer) -> pxt.OpT:
        # (op ** k) unsupported
        return NotImplemented

    def __matmul__(self, other) -> pxt.OpT:
        # (op @ NDArray) unsupported
        return NotImplemented

    def __rmatmul__(self, other) -> pxt.OpT:
        # (NDArray @ op) unsupported
        return NotImplemented

    def argscale(self, scalar: pxt.Real) -> pxt.OpT:
        """
        Scale operator's domain.

        Parameters
        ----------
        scalar: Real

        Returns
        -------
        op: OpT
            Domain-scaled operator.

        See Also
        --------
        :py:class:`~pyxu.abc.arithmetic.ArgScaleRule`
        """
        import pyxu.abc.arithmetic as arithmetic

        assert _is_real(scalar)
        return arithmetic.ArgScaleRule(op=self, cst=float(scalar)).op()

    def argshift(self, shift: pxt.NDArray) -> pxt.OpT:
        r"""
        Shift operator's domain.

        Parameters
        ----------
        shift: NDArray
            Shift value :math:`c \in \mathbb{R}^{M_{1} \times\cdots\times M_{D}}`.

            `shift` must be broadcastable with operator's dimension.

        Returns
        -------
        op: OpT
            Domain-shifted operator :math:`g(x) = f(x + c)`.

        See Also
        --------
        :py:class:`~pyxu.abc.arithmetic.ArgShiftRule`
        """
        import pyxu.abc.arithmetic as arithmetic

        return arithmetic.ArgShiftRule(op=self, cst=shift).op()
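
    # Illustrative sketch (assumes an operator `f` with dim_shape (5,)):
    #
    #   >>> import numpy as np
    #   >>> g = f.argscale(2.0)          # g(x) = f(2 x)
    #   >>> h = f.argshift(np.r_[1.0])   # h(x) = f(x + 1); shift broadcast to (5,)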

    # Internal Helpers ---------------------------------------------------------
    @staticmethod
    def _infer_operator_type(prop: cabc.Collection[Property]) -> pxt.OpC:
        prop = frozenset(prop)
        for op in _core_operators():
            if op.properties() == prop:
                return op
        else:
            raise ValueError(f"No operator found with properties {prop}.")

    def __repr__(self) -> str:
        klass = self._name
        return f"{klass}(dim={self.dim_shape}, codim={self.codim_shape})"

    def _expr(self) -> tuple:
        r"""
        Show the expression-representation of the operator.

        If overridden, must return a tuple of the form

            (head, \*tail),

        where `head` is the operator (ex: +/\*), and `tail` denotes all the expression's terms. If an operator
        cannot be expanded further, then this method should return (self,).
        """
        return (self,)

    def _meta(self):
        # When using DASK inputs, it is sometimes necessary to pass extra information to Dask functions. This
        # function serves this purpose: it lets class writers encode any such information and re-use it when
        # processing DASK inputs. The action and return types of _meta() are at the sole discretion of the
        # implementer.
        raise NotImplementedError

    def expr(self, level: int = 0, strip: bool = True) -> str:
        """
        Pretty-print the expression representation of the operator.

        Useful for debugging arithmetic-induced expressions.

        Example
        -------

        .. code-block:: python3

           import numpy as np
           import pyxu.abc as pxa

           kwargs = dict(dim_shape=5, codim_shape=5)
           op1 = pxa.LinOp(**kwargs)
           op2 = pxa.DiffMap(**kwargs)
           op = ((2 * op1) + (op1 * op2)).argshift(np.r_[1])

           print(op.expr())
           # [argshift, ==> DiffMap(dim=(5,), codim=(5,))
           # .[add, ==> DiffMap(dim=(5,), codim=(5,))
           # ..[scale, ==> LinOp(dim=(5,), codim=(5,))
           # ...LinOp(dim=(5,), codim=(5,)),
           # ...2.0],
           # ..[compose, ==> DiffMap(dim=(5,), codim=(5,))
           # ...LinOp(dim=(5,), codim=(5,)),
           # ...DiffMap(dim=(5,), codim=(5,))]],
           # .(1,)]
        """
        fmt = lambda obj, lvl: ("." * lvl) + str(obj)
        lines = []

        head, *tail = self._expr()
        if len(tail) == 0:
            head = f"{repr(head)},"
        else:
            head = f"[{head}, ==> {repr(self)}"
        lines.append(fmt(head, level))

        for t in tail:
            if isinstance(t, Operator):
                lines += t.expr(level=level + 1, strip=False).split("\n")
            else:
                t = f"{t},"
                lines.append(fmt(t, level + 1))
        if len(tail) > 0:
            # Drop comma for last tail item, then close the sub-expression.
            lines[-1] = lines[-1][:-1]
            lines[-1] += "],"

        out = "\n".join(lines)
        if strip:
            out = out.strip(",")  # drop comma at top-level tail.
        return out

    # Short-hands for commonly-used operators ----------------------------------

    def squeeze(self, axes: pxt.NDArrayAxis = None) -> pxt.OpT:
        """
        Drop size-1 axes from co-dimension.

        See Also
        --------
        :py:class:`~pyxu.operator.SqueezeAxes`
        """
        from pyxu.operator import SqueezeAxes

        sq = SqueezeAxes(
            dim_shape=self.codim_shape,
            axes=axes,
        )
        op = sq * self
        return op

    def transpose(self, axes: pxt.NDArrayAxis = None) -> pxt.OpT:
        """
        Permute co-dimension axes.

        See Also
        --------
        :py:class:`~pyxu.operator.TransposeAxes`
        """
        from pyxu.operator import TransposeAxes

        tr = TransposeAxes(
            dim_shape=self.codim_shape,
            axes=axes,
        )
        op = tr * self
        return op

    def reshape(self, codim_shape: pxt.NDArrayShape) -> pxt.OpT:
        """
        Reshape co-dimension shape.

        See Also
        --------
        :py:class:`~pyxu.operator.ReshapeAxes`
        """
        from pyxu.operator import ReshapeAxes

        rsh = ReshapeAxes(
            dim_shape=self.codim_shape,
            codim_shape=codim_shape,
        )
        op = rsh * self
        return op

    def broadcast_to(self, codim_shape: pxt.NDArrayShape) -> pxt.OpT:
        """
        Broadcast co-dimension shape.

        See Also
        --------
        :py:class:`~pyxu.operator.BroadcastAxes`
        """
        from pyxu.operator import BroadcastAxes

        bcast = BroadcastAxes(
            dim_shape=self.codim_shape,
            codim_shape=codim_shape,
        )
        op = bcast * self
        return op

    def subsample(self, *indices) -> pxt.OpT:
        """
        Sub-sample co-dimension.

        See Also
        --------
        :py:class:`~pyxu.operator.SubSample`
        """
        from pyxu.operator import SubSample

        sub = SubSample(self.codim_shape, *indices)
        op = sub * self
        return op

    def rechunk(self, chunks: dict) -> pxt.OpT:
        """
        Re-chunk core dimensions to new chunk size.

        See Also
        --------
        :py:func:`~pyxu.operator.RechunkAxes`
        """
        from pyxu.operator import RechunkAxes

        chk = RechunkAxes(
            dim_shape=self.codim_shape,
            chunks=chunks,
        )
        op = chk * self
        return op
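

# Illustrative sketch of the co-dimension short-hands (assumes an operator `op`
# with codim_shape (1, 5)):
#
#   >>> op.squeeze().codim_shape                             # (5,)
#   >>> op.reshape(codim_shape=(5, 1)).codim_shape           # (5, 1)
#   >>> op.broadcast_to(codim_shape=(3, 1, 5)).codim_shape   # (3, 1, 5)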


class Map(Operator):
    r"""
    Base class for real-valued maps :math:`\mathbf{f}: \mathbb{R}^{M_{1} \times\cdots\times M_{D}} \to
    \mathbb{R}^{N_{1} \times\cdots\times N_{K}}`.

    Instances of this class must implement :py:meth:`~pyxu.abc.Map.apply`.

    If :math:`\mathbf{f}` is Lipschitz-continuous with known Lipschitz constant :math:`L`, the latter should be
    stored in the :py:attr:`~pyxu.abc.Map.lipschitz` property.
    """

    @classmethod
    def properties(cls) -> cabc.Set[Property]:
        p = set(super().properties())
        p.add(Property.CAN_EVAL)
        return frozenset(p)

    def __init__(
        self,
        dim_shape: pxt.NDArrayShape,
        codim_shape: pxt.NDArrayShape,
    ):
        super().__init__(
            dim_shape=dim_shape,
            codim_shape=codim_shape,
        )
        self.lipschitz = np.inf

    def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
        """
        Evaluate operator at specified point(s).

        Parameters
        ----------
        arr: NDArray
            (..., M1,...,MD) input points.

        Returns
        -------
        out: NDArray
            (..., N1,...,NK) output points.
        """
        raise NotImplementedError

    def __call__(self, arr: pxt.NDArray) -> pxt.NDArray:
        """
        Alias for :py:meth:`~pyxu.abc.Map.apply`.
        """
        return self.apply(arr)

    @property
    def lipschitz(self) -> pxt.Real:
        r"""
        Return the last computed Lipschitz constant of :math:`\mathbf{f}`.

        Notes
        -----
        * If a Lipschitz constant is known a priori, it can be stored in the instance as follows:

          .. code-block:: python3

             class TestOp(Map):
                 def __init__(self, dim_shape, codim_shape):
                     super().__init__(dim_shape, codim_shape)
                     self.lipschitz = 2

             op = TestOp(2, 3)
             op.lipschitz  # => 2

          Alternatively the Lipschitz constant can be set manually after initialization:

          .. code-block:: python3

             class TestOp(Map):
                 def __init__(self, dim_shape, codim_shape):
                     super().__init__(dim_shape, codim_shape)

             op = TestOp(2, 3)
             op.lipschitz  # => inf, since unknown a priori

             op.lipschitz = 2  # post-init specification
             op.lipschitz  # => 2

        * :py:meth:`~pyxu.abc.Map.lipschitz` **never** computes anything:
          call :py:meth:`~pyxu.abc.Map.estimate_lipschitz` manually to *compute* a new Lipschitz estimate:

          .. code-block:: python3

             op.lipschitz = op.estimate_lipschitz()
        """
        if not hasattr(self, "_lipschitz"):
            self._lipschitz = self.estimate_lipschitz()

        return self._lipschitz

    @lipschitz.setter
    def lipschitz(self, L: pxt.Real):
        assert L >= 0
        self._lipschitz = float(L)

        # If no algorithm available to auto-determine estimate_lipschitz(), then enforce user's choice.
        if not self.has(Property.LINEAR):

            def op_estimate_lipschitz(_, **kwargs) -> pxt.Real:
                return _._lipschitz

            self.estimate_lipschitz = types.MethodType(op_estimate_lipschitz, self)

    def estimate_lipschitz(self, **kwargs) -> pxt.Real:
        r"""
        Compute a Lipschitz constant of the operator.

        Parameters
        ----------
        kwargs: ~collections.abc.Mapping
            Class-specific kwargs to configure Lipschitz estimation.

        Notes
        -----
        * This method should always be callable without specifying any kwargs.

        * A constant :math:`L_{\mathbf{f}} > 0` is said to be a *Lipschitz constant* for a map
          :math:`\mathbf{f}: \mathbb{R}^{M_{1} \times\cdots\times M_{D}} \to \mathbb{R}^{N_{1}
          \times\cdots\times N_{K}}` if:

          .. math::

             \|\mathbf{f}(\mathbf{x}) - \mathbf{f}(\mathbf{y})\|_{\mathbb{R}^{N_{1} \times\cdots\times N_{K}}}
             \leq
             L_{\mathbf{f}} \|\mathbf{x} - \mathbf{y}\|_{\mathbb{R}^{M_{1} \times\cdots\times M_{D}}},
             \qquad
             \forall \mathbf{x}, \mathbf{y} \in \mathbb{R}^{M_{1} \times\cdots\times M_{D}},

          where :math:`\|\cdot\|_{\mathbb{R}^{M_{1} \times\cdots\times M_{D}}}` and
          :math:`\|\cdot\|_{\mathbb{R}^{N_{1} \times\cdots\times N_{K}}}` are the canonical norms on their
          respective spaces.

          The smallest Lipschitz constant of a map is called the *optimal Lipschitz constant*.
        """
        raise NotImplementedError
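

# Illustrative sketch (assumes NumPy inputs only): a minimal concrete Map.
#
#   >>> import numpy as np
#   >>> from pyxu.abc import Map
#   >>>
#   >>> class Sin(Map):
#   ...     def __init__(self, dim_shape):
#   ...         super().__init__(dim_shape=dim_shape, codim_shape=dim_shape)
#   ...         self.lipschitz = 1  # |sin'| <= 1
#   ...     def apply(self, arr):
#   ...         return np.sin(arr)
#   >>>
#   >>> op = Sin(dim_shape=(5,))
#   >>> op(np.zeros(5))  # __call__ is an alias of apply()
#   array([0., 0., 0., 0., 0.])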


class Func(Map):
    r"""
    Base class for real-valued functionals :math:`f: \mathbb{R}^{M_{1} \times\cdots\times M_{D}} \to
    \mathbb{R}\cup\{+\infty\}`.

    Instances of this class must implement :py:meth:`~pyxu.abc.Map.apply`.

    If :math:`f` is Lipschitz-continuous with known Lipschitz constant :math:`L`, the latter should be stored in
    the :py:attr:`~pyxu.abc.Map.lipschitz` property.
    """

    @classmethod
    def properties(cls) -> cabc.Set[Property]:
        p = set(super().properties())
        p.add(Property.FUNCTIONAL)
        return frozenset(p)

    def __init__(
        self,
        dim_shape: pxt.NDArrayShape,
        codim_shape: pxt.NDArrayShape,
    ):
        super().__init__(
            dim_shape=dim_shape,
            codim_shape=codim_shape,
        )
        assert (self.codim_size == 1) and (self.codim_rank == 1)


class DiffMap(Map):
    r"""
    Base class for real-valued differentiable maps :math:`\mathbf{f}: \mathbb{R}^{M_{1} \times\cdots\times M_{D}}
    \to \mathbb{R}^{N_{1} \times\cdots\times N_{K}}`.

    Instances of this class must implement :py:meth:`~pyxu.abc.Map.apply` and
    :py:meth:`~pyxu.abc.DiffMap.jacobian`.

    If :math:`\mathbf{f}` is Lipschitz-continuous with known Lipschitz constant :math:`L`, the latter should be
    stored in the :py:attr:`~pyxu.abc.Map.lipschitz` property.

    If :math:`\mathbf{J}_{\mathbf{f}}` is Lipschitz-continuous with known Lipschitz constant :math:`\partial L`,
    the latter should be stored in the :py:attr:`~pyxu.abc.DiffMap.diff_lipschitz` property.
    """

    @classmethod
    def properties(cls) -> cabc.Set[Property]:
        p = set(super().properties())
        p.add(Property.DIFFERENTIABLE)
        return frozenset(p)

    def __init__(
        self,
        dim_shape: pxt.NDArrayShape,
        codim_shape: pxt.NDArrayShape,
    ):
        super().__init__(
            dim_shape=dim_shape,
            codim_shape=codim_shape,
        )
        self.diff_lipschitz = np.inf

    def jacobian(self, arr: pxt.NDArray) -> pxt.OpT:
        r"""
        Evaluate the Jacobian of :math:`\mathbf{f}` at the specified point.

        Parameters
        ----------
        arr: NDArray
            (M1,...,MD) evaluation point.

        Returns
        -------
        op: OpT
            Jacobian operator at point `arr`.

        Notes
        -----
        Let :math:`\mathbf{f}: \mathbb{R}^{M_{1} \times\cdots\times M_{D}} \to \mathbb{R}^{N_{1}
        \times\cdots\times N_{K}}` be a differentiable multi-dimensional map. The *Jacobian* (or *differential*)
        of :math:`\mathbf{f}` at :math:`\mathbf{z} \in \mathbb{R}^{M_{1} \times\cdots\times M_{D}}` is defined as
        the best linear approximator of :math:`\mathbf{f}` near :math:`\mathbf{z}`, in the following sense:

        .. math::

           \mathbf{f}(\mathbf{x}) - \mathbf{f}(\mathbf{z}) = \mathbf{J}_{\mathbf{f}}(\mathbf{z}) (\mathbf{x} -
           \mathbf{z}) + o(\| \mathbf{x} - \mathbf{z} \|) \quad \text{as} \quad \mathbf{x} \to \mathbf{z}.

        The Jacobian admits the following matrix representation:

        .. math::

           [\mathbf{J}_{\mathbf{f}}(\mathbf{x})]_{ij} := \frac{\partial f_{i}}{\partial x_{j}}(\mathbf{x}),
           \qquad
           \forall (i,j) \in \{1,\ldots,N_{1}\cdots N_{K}\} \times \{1,\ldots,M_{1}\cdots M_{D}\}.
        """
        raise NotImplementedError

    @property
    def diff_lipschitz(self) -> pxt.Real:
        r"""
        Return the last computed Lipschitz constant of :math:`\mathbf{J}_{\mathbf{f}}`.

        Notes
        -----
        * If a diff-Lipschitz constant is known a priori, it can be stored in the instance as follows:

          .. code-block:: python3

             class TestOp(DiffMap):
                 def __init__(self, dim_shape, codim_shape):
                     super().__init__(dim_shape, codim_shape)
                     self.diff_lipschitz = 2

             op = TestOp(2, 3)
             op.diff_lipschitz  # => 2

          Alternatively the diff-Lipschitz constant can be set manually after initialization:

          .. code-block:: python3

             class TestOp(DiffMap):
                 def __init__(self, dim_shape, codim_shape):
                     super().__init__(dim_shape, codim_shape)

             op = TestOp(2, 3)
             op.diff_lipschitz  # => inf, since unknown a priori

             op.diff_lipschitz = 2  # post-init specification
             op.diff_lipschitz  # => 2

        * :py:meth:`~pyxu.abc.DiffMap.diff_lipschitz` **never** computes anything:
          call :py:meth:`~pyxu.abc.DiffMap.estimate_diff_lipschitz` manually to *compute* a new diff-Lipschitz
          estimate:

          .. code-block:: python3

             op.diff_lipschitz = op.estimate_diff_lipschitz()
        """
        if not hasattr(self, "_diff_lipschitz"):
            self._diff_lipschitz = self.estimate_diff_lipschitz()

        return self._diff_lipschitz

    @diff_lipschitz.setter
    def diff_lipschitz(self, dL: pxt.Real):
        assert dL >= 0
        self._diff_lipschitz = float(dL)

        # If no algorithm available to auto-determine estimate_diff_lipschitz(), then enforce user's choice.
        if not self.has(Property.QUADRATIC):

            def op_estimate_diff_lipschitz(_, **kwargs) -> pxt.Real:
                return _._diff_lipschitz

            self.estimate_diff_lipschitz = types.MethodType(op_estimate_diff_lipschitz, self)

    def estimate_diff_lipschitz(self, **kwargs) -> pxt.Real:
        r"""
        Compute a Lipschitz constant of :py:meth:`~pyxu.abc.DiffMap.jacobian`.

        Parameters
        ----------
        kwargs: ~collections.abc.Mapping
            Class-specific kwargs to configure diff-Lipschitz estimation.

        Notes
        -----
        * This method should always be callable without specifying any kwargs.

        * A Lipschitz constant :math:`L_{\mathbf{J}_{\mathbf{f}}} > 0` of the Jacobian map
          :math:`\mathbf{J}_{\mathbf{f}}: \mathbb{R}^{M_{1} \times\cdots\times M_{D}} \to \mathbb{R}^{(N_{1}
          \times\cdots\times N_{K}) \times (M_{1} \times\cdots\times M_{D})}` is such that:

          .. math::

             \|\mathbf{J}_{\mathbf{f}}(\mathbf{x}) - \mathbf{J}_{\mathbf{f}}(\mathbf{y})\|_{\mathbb{R}^{(N_{1}
             \times\cdots\times N_{K}) \times (M_{1} \times\cdots\times M_{D})}}
             \leq
             L_{\mathbf{J}_{\mathbf{f}}} \|\mathbf{x} - \mathbf{y}\|_{\mathbb{R}^{M_{1} \times\cdots\times M_{D}}},
             \qquad
             \forall \mathbf{x}, \mathbf{y} \in \mathbb{R}^{M_{1} \times\cdots\times M_{D}},

          where :math:`\|\cdot\|_{\mathbb{R}^{(N_{1} \times\cdots\times N_{K}) \times (M_{1} \times\cdots\times
          M_{D})}}` and :math:`\|\cdot\|_{\mathbb{R}^{M_{1} \times\cdots\times M_{D}}}` are the canonical norms
          on their respective spaces.

          The smallest Lipschitz constant of the Jacobian is called the *optimal diff-Lipschitz constant*.
        """
        raise NotImplementedError


class ProxFunc(Func):
    r"""
    Base class for real-valued proximable functionals :math:`f: \mathbb{R}^{M_{1} \times\cdots\times M_{D}} \to
    \mathbb{R} \cup \{+\infty\}`.

    A functional :math:`f: \mathbb{R}^{M_{1} \times\cdots\times M_{D}} \to \mathbb{R} \cup \{+\infty\}` is said
    *proximable* if its **proximity operator** (see :py:meth:`~pyxu.abc.ProxFunc.prox` for a definition) admits a
    *simple closed-form expression* **or** can be evaluated *efficiently* and with *high accuracy*.

    Instances of this class must implement :py:meth:`~pyxu.abc.Map.apply` and :py:meth:`~pyxu.abc.ProxFunc.prox`.

    If :math:`f` is Lipschitz-continuous with known Lipschitz constant :math:`L`, the latter should be stored in
    the :py:attr:`~pyxu.abc.Map.lipschitz` property.
    """

    @classmethod
    def properties(cls) -> cabc.Set[Property]:
        p = set(super().properties())
        p.add(Property.PROXIMABLE)
        return frozenset(p)

    def __init__(
        self,
        dim_shape: pxt.NDArrayShape,
        codim_shape: pxt.NDArrayShape,
    ):
        super().__init__(
            dim_shape=dim_shape,
            codim_shape=codim_shape,
        )

    def prox(self, arr: pxt.NDArray, tau: pxt.Real) -> pxt.NDArray:
        r"""
        Evaluate proximity operator of :math:`\tau f` at specified point(s).

        Parameters
        ----------
        arr: NDArray
            (..., M1,...,MD) input points.
        tau: Real
            Positive scale factor.

        Returns
        -------
        out: NDArray
            (..., M1,...,MD) proximal evaluations.

        Notes
        -----
        For :math:`\tau > 0`, the *proximity operator* of a scaled functional :math:`f: \mathbb{R}^{M_{1}
        \times\cdots\times M_{D}} \to \mathbb{R}` is defined as:

        .. math::

           \mathbf{\text{prox}}_{\tau f}(\mathbf{z})
           :=
           \arg\min_{\mathbf{x} \in \mathbb{R}^{M_{1} \times\cdots\times M_{D}}}
           f(\mathbf{x}) + \frac{1}{2\tau} \|\mathbf{x} - \mathbf{z}\|_{2}^{2},
           \quad
           \forall \mathbf{z} \in \mathbb{R}^{M_{1} \times\cdots\times M_{D}}.
        """
        raise NotImplementedError

    def fenchel_prox(self, arr: pxt.NDArray, sigma: pxt.Real) -> pxt.NDArray:
        r"""
        Evaluate proximity operator of :math:`\sigma f^{\ast}`, the scaled Fenchel conjugate of :math:`f`, at
        specified point(s).

        Parameters
        ----------
        arr: NDArray
            (..., M1,...,MD) input points.
        sigma: Real
            Positive scale factor.

        Returns
        -------
        out: NDArray
            (..., M1,...,MD) proximal evaluations.

        Notes
        -----
        For :math:`\sigma > 0`, the *Fenchel conjugate* is defined as:

        .. math::

           f^{\ast}(\mathbf{z})
           :=
           \max_{\mathbf{x} \in \mathbb{R}^{M_{1} \times\cdots\times M_{D}}}
           \langle \mathbf{x}, \mathbf{z} \rangle - f(\mathbf{x}).

        From *Moreau's identity*, its proximal operator is given by:

        .. math::

           \mathbf{\text{prox}}_{\sigma f^{\ast}}(\mathbf{z})
           =
           \mathbf{z} - \sigma \mathbf{\text{prox}}_{f/\sigma}(\mathbf{z}/\sigma).
        """
        out = arr - sigma * self.prox(arr=arr / sigma, tau=1 / sigma)
        return out
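
    # Illustrative numeric check of Moreau's identity (assumes a concrete
    # ProxFunc `f`, e.g. the L1Norm sketched in moreau_envelope() below):
    #
    #   >>> sigma = 2.0
    #   >>> x = np.r_[1.5, -0.3]
    #   >>> lhs = f.fenchel_prox(x, sigma=sigma)
    #   >>> rhs = x - sigma * f.prox(x / sigma, tau=1 / sigma)
    #   >>> np.allclose(lhs, rhs)
    #   True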

    def moreau_envelope(self, mu: pxt.Real) -> pxt.OpT:
        r"""
        Approximate proximable functional :math:`f` by its *Moreau envelope* :math:`f^{\mu}`.

        Parameters
        ----------
        mu: Real
            Positive regularization parameter.

        Returns
        -------
        op: OpT
            Differentiable Moreau envelope.

        Notes
        -----
        Consider a convex non-smooth proximable functional :math:`f: \mathbb{R}^{M_{1} \times\cdots\times M_{D}}
        \to \mathbb{R} \cup \{+\infty\}` and a regularization parameter :math:`\mu > 0`. The :math:`\mu`-*Moreau
        envelope* (or *Moreau-Yosida envelope*) of :math:`f` is given by

        .. math::

           f^{\mu}(\mathbf{x})
           =
           \min_{\mathbf{z} \in \mathbb{R}^{M_{1} \times\cdots\times M_{D}}}
           f(\mathbf{z})
           \quad + \quad
           \frac{1}{2\mu} \|\mathbf{x} - \mathbf{z}\|^{2}.

        The parameter :math:`\mu` controls the trade-off between the regularity properties of :math:`f^{\mu}` and
        the approximation error incurred by the Moreau-Yosida regularization.

        The Moreau envelope inherits the convexity of :math:`f` and is gradient-Lipschitz (with Lipschitz
        constant :math:`\mu^{-1}`), even if :math:`f` is non-smooth. Its gradient is moreover given by:

        .. math::

           \nabla f^{\mu}(\mathbf{x})
           =
           \mu^{-1} \left(\mathbf{x} - \text{prox}_{\mu f}(\mathbf{x})\right).

        In addition, :math:`f^{\mu}` envelopes :math:`f` from below: :math:`f^{\mu}(\mathbf{x}) \leq
        f(\mathbf{x})`. This envelope becomes tighter as :math:`\mu \to 0`:

        .. math::

           \lim_{\mu \to 0} f^{\mu}(\mathbf{x}) = f(\mathbf{x}).

        Finally, it can be shown that the minimizers of :math:`f` and :math:`f^{\mu}` coincide, and that the
        Fenchel conjugate of :math:`f^{\mu}` is strongly-convex.

        Example
        -------
        Construct and plot the Moreau envelope of the :math:`\ell_{1}`-norm:

        .. plot::

           import numpy as np
           import matplotlib.pyplot as plt
           from pyxu.abc import ProxFunc

           class L1Norm(ProxFunc):
               def __init__(self, dim: int):
                   super().__init__(dim_shape=dim, codim_shape=1)
               def apply(self, arr):
                   return np.linalg.norm(arr, axis=-1, keepdims=True, ord=1)
               def prox(self, arr, tau):
                   return np.clip(np.abs(arr) - tau, a_min=0, a_max=None) * np.sign(arr)

           mu = [0.1, 0.5, 1]
           f = [L1Norm(dim=1).moreau_envelope(_mu) for _mu in mu]
           x = np.linspace(-1, 1, 512).reshape(-1, 1)  # evaluation points

           fig, ax = plt.subplots(ncols=2)
           for _mu, _f in zip(mu, f):
               ax[0].plot(x, _f(x), label=f"mu={_mu}")
               ax[1].plot(x, _f.grad(x), label=f"mu={_mu}")
           ax[0].set_title("Moreau Envelope")
           ax[1].set_title("Derivative of Moreau Envelope")
           for _ax in ax:
               _ax.legend()
               _ax.set_aspect("equal")
           fig.tight_layout()
        """
        from pyxu.operator.interop import from_source

        assert mu > 0, f"mu: expected positive, got {mu}"

        def op_apply(_, arr: pxt.NDArray) -> pxt.NDArray:
            xp = pxu.get_array_module(arr)

            x = self.prox(arr, tau=_._mu)
            y = xp.sum(
                (arr - x) ** 2,
                axis=tuple(range(-self.dim_rank, 0)),
            )[..., np.newaxis]
            y *= 0.5 / _._mu

            out = self.apply(x) + y
            return out

        def op_grad(_, arr):
            out = arr - self.prox(arr, tau=_._mu)
            out /= _._mu
            return out

        op = from_source(
            cls=DiffFunc,
            dim_shape=self.dim_shape,
            codim_shape=self.codim_shape,
            embed=dict(
                _name="moreau_envelope",
                _mu=mu,
                _diff_lipschitz=float(1 / mu),
            ),
            apply=op_apply,
            grad=op_grad,
            _expr=lambda _: ("moreau_envelope", _, _._mu),
        )
        return op


class DiffFunc(DiffMap, Func):
    r"""
    Base class for real-valued differentiable functionals :math:`f: \mathbb{R}^{M_{1} \times\cdots\times M_{D}}
    \to \mathbb{R}`.

    Instances of this class must implement :py:meth:`~pyxu.abc.Map.apply` and :py:meth:`~pyxu.abc.DiffFunc.grad`.

    If :math:`f` and/or its derivative :math:`f'` are Lipschitz-continuous with known Lipschitz constants
    :math:`L` and :math:`\partial L`, the latter should be stored in the :py:attr:`~pyxu.abc.Map.lipschitz` and
    :py:attr:`~pyxu.abc.DiffMap.diff_lipschitz` properties.
    """

    @classmethod
    def properties(cls) -> cabc.Set[Property]:
        p = set()
        for klass in cls.__bases__:
            p |= klass.properties()
        p.add(Property.DIFFERENTIABLE_FUNCTION)
        return frozenset(p)

    def __init__(
        self,
        dim_shape: pxt.NDArrayShape,
        codim_shape: pxt.NDArrayShape,
    ):
        for klass in [DiffMap, Func]:
            klass.__init__(
                self,
                dim_shape=dim_shape,
                codim_shape=codim_shape,
            )

    def jacobian(self, arr: pxt.NDArray) -> pxt.OpT:
        op = LinFunc.from_array(
            A=self.grad(arr)[np.newaxis, ...],
            dim_rank=self.dim_rank,
        )
        return op

    def grad(self, arr: pxt.NDArray) -> pxt.NDArray:
        r"""
        Evaluate operator gradient at specified point(s).

        Parameters
        ----------
        arr: NDArray
            (..., M1,...,MD) input points.

        Returns
        -------
        out: NDArray
            (..., M1,...,MD) gradients.

        Notes
        -----
        The gradient of a functional :math:`f: \mathbb{R}^{M_{1} \times\cdots\times M_{D}} \to \mathbb{R}` is
        given, for every :math:`\mathbf{x} \in \mathbb{R}^{M_{1} \times\cdots\times M_{D}}`, by

        .. math::

           \nabla f(\mathbf{x})
           :=
           \left[\begin{array}{c}
               \frac{\partial f}{\partial x_{1}}(\mathbf{x}) \\
               \vdots \\
               \frac{\partial f}{\partial x_{M}}(\mathbf{x})
           \end{array}\right].
        """
        raise NotImplementedError
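

# Illustrative sketch (assumes NumPy inputs only): a minimal concrete DiffFunc,
# the halved squared L2-norm f(x) = 0.5 * ||x||^2, whose gradient is x itself.
#
#   >>> import numpy as np
#   >>> from pyxu.abc import DiffFunc
#   >>>
#   >>> class SquaredL2(DiffFunc):
#   ...     def __init__(self, dim_shape):
#   ...         super().__init__(dim_shape=dim_shape, codim_shape=1)
#   ...         self.diff_lipschitz = 1  # grad f(x) = x is 1-Lipschitz
#   ...     def apply(self, arr):
#   ...         return 0.5 * (arr**2).sum(axis=-1, keepdims=True)
#   ...     def grad(self, arr):
#   ...         return arr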


class ProxDiffFunc(ProxFunc, DiffFunc):
    r"""
    Base class for real-valued differentiable *and* proximable functionals :math:`f: \mathbb{R}^{M_{1}
    \times\cdots\times M_{D}} \to \mathbb{R}`.

    Instances of this class must implement :py:meth:`~pyxu.abc.Map.apply`, :py:meth:`~pyxu.abc.DiffFunc.grad`,
    and :py:meth:`~pyxu.abc.ProxFunc.prox`.

    If :math:`f` and/or its derivative :math:`f'` are Lipschitz-continuous with known Lipschitz constants
    :math:`L` and :math:`\partial L`, the latter should be stored in the :py:attr:`~pyxu.abc.Map.lipschitz` and
    :py:attr:`~pyxu.abc.DiffMap.diff_lipschitz` properties.
    """

    @classmethod
    def properties(cls) -> cabc.Set[Property]:
        p = set()
        for klass in cls.__bases__:
            p |= klass.properties()
        return frozenset(p)

    def __init__(
        self,
        dim_shape: pxt.NDArrayShape,
        codim_shape: pxt.NDArrayShape,
    ):
        for klass in [ProxFunc, DiffFunc]:
            klass.__init__(
                self,
                dim_shape=dim_shape,
                codim_shape=codim_shape,
            )


class QuadraticFunc(ProxDiffFunc):
    # This is a special abstract base class with more __init__ parameters than `dim/codim_shape`.
    r"""
    Base class for quadratic functionals :math:`f: \mathbb{R}^{M_{1} \times\cdots\times M_{D}} \to \mathbb{R}
    \cup \{+\infty\}`.

    The quadratic functional is defined as:

    .. math::

       f(\mathbf{x})
       =
       \frac{1}{2} \langle\mathbf{x}, \mathbf{Q}\mathbf{x}\rangle
       +
       \langle\mathbf{c},\mathbf{x}\rangle
       +
       t,
       \qquad \forall \mathbf{x} \in \mathbb{R}^{M_{1} \times\cdots\times M_{D}},

    where :math:`\mathbf{Q}` is a positive-definite operator :math:`\mathbf{Q}: \mathbb{R}^{M_{1}
    \times\cdots\times M_{D}} \to \mathbb{R}^{M_{1} \times\cdots\times M_{D}}`, :math:`\mathbf{c} \in
    \mathbb{R}^{M_{1} \times\cdots\times M_{D}}`, and :math:`t \in \mathbb{R}`.

    Its gradient is given by:

    .. math::

       \nabla f(\mathbf{x}) = \mathbf{Q}\mathbf{x} + \mathbf{c}.

    Its proximity operator is given by:

    .. math::

       \text{prox}_{\tau f}(\mathbf{x})
       =
       \left(
           \mathbf{Q} + \tau^{-1} \mathbf{Id}
       \right)^{-1}
       \left(
           \tau^{-1}\mathbf{x} - \mathbf{c}
       \right).

    In practice the proximity operator is evaluated via :py:class:`~pyxu.opt.solver.CG`.

    The Lipschitz constant :math:`L` of a quadratic functional on an unbounded domain is unbounded. The Lipschitz
    constant :math:`\partial L` of :math:`\nabla f` is given by the spectral norm of :math:`\mathbf{Q}`.
    """

    @classmethod
    def properties(cls) -> cabc.Set[Property]:
        p = set(super().properties())
        p.add(Property.QUADRATIC)
        return frozenset(p)

    def __init__(
        self,
        dim_shape: pxt.NDArrayShape,
        codim_shape: pxt.NDArrayShape,
        # required in place of `dim` to have uniform interface with Operator hierarchy.
        Q: "PosDefOp" = None,
        c: "LinFunc" = None,
        t: pxt.Real = 0,
    ):
        r"""
        Parameters
        ----------
        Q: ~pyxu.abc.PosDefOp
            Positive-definite operator. (Default: Identity)
        c: ~pyxu.abc.LinFunc
            Linear functional. (Default: NullFunc)
        t: Real
            Offset. (Default: 0)
        """
        from pyxu.operator import IdentityOp, NullFunc

        super().__init__(
            dim_shape=dim_shape,
            codim_shape=codim_shape,
        )

        # Do NOT access (_Q, _c, _t) directly through `self`:
        # their values may not reflect the true (Q, c, t) parameterization.
        # (Reason: arithmetic propagation.)
        # Always access (Q, c, t) by querying the arithmetic method `_quad_spec()`.
        self._Q = IdentityOp(dim_shape=self.dim_shape) if (Q is None) else Q
        self._c = NullFunc(dim_shape=self.dim_shape) if (c is None) else c
        self._t = t

        # ensure dimensions are consistent if None-initialized
        assert self._Q.dim_shape == self.dim_shape
        assert self._Q.codim_shape == self.dim_shape
        assert self._c.dim_shape == self.dim_shape
        assert self._c.codim_shape == self.codim_shape

        self.diff_lipschitz = self._Q.lipschitz

    def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
        Q, c, t = self._quad_spec()
        out = (arr * Q.apply(arr)).sum(axis=tuple(range(-self.dim_rank, 0)))[..., np.newaxis]
        out /= 2
        out += c.apply(arr)
        out += t
        return out

    def grad(self, arr: pxt.NDArray) -> pxt.NDArray:
        Q, c, _ = self._quad_spec()
        out = Q.apply(arr) + c.grad(arr)
        return out

    def prox(self, arr: pxt.NDArray, tau: pxt.Real) -> pxt.NDArray:
        from pyxu.operator import HomothetyOp
        from pyxu.opt.solver import CG
        from pyxu.opt.stop import MaxIter

        Q, c, _ = self._quad_spec()
        A = Q + HomothetyOp(cst=1 / tau, dim_shape=Q.dim_shape)
        b = arr.copy()
        b /= tau
        b -= c.grad(arr)

        slvr = CG(A=A, show_progress=False)

        sentinel = MaxIter(n=2 * A.dim_size)
        stop_crit = slvr.default_stop_crit() | sentinel

        slvr.fit(b=b, stop_crit=stop_crit)
        return slvr.solution()

    def estimate_diff_lipschitz(self, **kwargs) -> pxt.Real:
        Q, *_ = self._quad_spec()
        dL = Q.estimate_lipschitz(**kwargs)
        return dL

    def _quad_spec(self):
        """
        Canonical quadratic parameterization.

        Useful for some internal methods, and overloaded during operator arithmetic.
        """
        return (self._Q, self._c, self._t)
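

# Illustrative sketch: a concrete quadratic f(x) = 0.5 <x, Qx> + <c, x> + t.
# Assumes `HomothetyOp` (cst > 0) qualifies as the positive-definite `Q`.
#
#   >>> import numpy as np
#   >>> from pyxu.abc import LinFunc, QuadraticFunc
#   >>> from pyxu.operator import HomothetyOp
#   >>>
#   >>> Q = HomothetyOp(cst=2.0, dim_shape=(3,))  # Q = 2 * Id
#   >>> c = LinFunc.from_array(np.ones((1, 3)))   # c(x) = <1, x>
#   >>> f = QuadraticFunc(dim_shape=(3,), codim_shape=1, Q=Q, c=c, t=1.0)
#   >>> f.grad(np.zeros(3))  # = Q 0 + c = [1, 1, 1]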


class LinOp(DiffMap):
    r"""
    Base class for real-valued linear operators :math:`\mathbf{A}: \mathbb{R}^{M_{1} \times\cdots\times M_{D}}
    \to \mathbb{R}^{N_{1} \times\cdots\times N_{K}}`.

    Instances of this class must implement :py:meth:`~pyxu.abc.Map.apply` and :py:meth:`~pyxu.abc.LinOp.adjoint`.

    If known, the Lipschitz constant :math:`L` should be stored in the :py:attr:`~pyxu.abc.Map.lipschitz`
    property.

    The Jacobian of a linear map :math:`\mathbf{A}` is constant.
    """

    # Internal Helpers ------------------------------------
    @staticmethod
    def _warn_vals_sparse_gpu():
        msg = "\n".join(
            [
                "Potential Error:",
                "Sparse GPU-evaluation of svdvals() is known to produce incorrect results. (CuPy-specific + Matrix-Dependent.)",
                "It is advised to cross-check results with CPU-computed results.",
            ]
        )
        warnings.warn(msg, pxw.BackendWarning)

    # -----------------------------------------------------

    @classmethod
    def properties(cls) -> cabc.Set[Property]:
        p = set(super().properties())
        p.add(Property.LINEAR)
        return frozenset(p)

    def __init__(
        self,
        dim_shape: pxt.NDArrayShape,
        codim_shape: pxt.NDArrayShape,
    ):
        super().__init__(
            dim_shape=dim_shape,
            codim_shape=codim_shape,
        )
        self.diff_lipschitz = 0

    def adjoint(self, arr: pxt.NDArray) -> pxt.NDArray:
        r"""
        Evaluate operator adjoint at specified point(s).

        Parameters
        ----------
        arr: NDArray
            (..., N1,...,NK) input points.

        Returns
        -------
        out: NDArray
            (..., M1,...,MD) adjoint evaluations.

        Notes
        -----
        The *adjoint* :math:`\mathbf{A}^{\ast}: \mathbb{R}^{N_{1} \times\cdots\times N_{K}} \to
        \mathbb{R}^{M_{1} \times\cdots\times M_{D}}` of :math:`\mathbf{A}: \mathbb{R}^{M_{1} \times\cdots\times
        M_{D}} \to \mathbb{R}^{N_{1} \times\cdots\times N_{K}}` is defined as:

        .. math::

           \langle \mathbf{x}, \mathbf{A}^{\ast}\mathbf{y}\rangle_{\mathbb{R}^{M_{1} \times\cdots\times M_{D}}}
           :=
           \langle \mathbf{A}\mathbf{x}, \mathbf{y}\rangle_{\mathbb{R}^{N_{1} \times\cdots\times N_{K}}},
           \qquad
           \forall (\mathbf{x}, \mathbf{y}) \in \mathbb{R}^{M_{1} \times\cdots\times M_{D}} \times
           \mathbb{R}^{N_{1} \times\cdots\times N_{K}}.
        """
        raise NotImplementedError

    def jacobian(self, arr: pxt.NDArray) -> pxt.OpT:
        return self

    @property
    def T(self) -> pxt.OpT:
        r"""
        Return the adjoint of :math:`\mathbf{A}`.
        """
        import pyxu.abc.arithmetic as arithmetic

        return arithmetic.TransposeRule(op=self).op()

    def estimate_lipschitz(self, **kwargs) -> pxt.Real:
        r"""
        Compute a Lipschitz constant of the operator.

        Parameters
        ----------
        method: "svd" | "trace"

            * If `svd`, compute the optimal Lipschitz constant.
            * If `trace`, compute an upper bound. (Default)

        kwargs:
            Optional kwargs passed on to:

            * `svd`: :py:func:`~pyxu.abc.LinOp.svdvals`
            * `trace`: :py:func:`~pyxu.math.hutchpp`

        Notes
        -----
        * The tightest Lipschitz constant is given by the spectral norm of the operator :math:`\mathbf{A}`:
          :math:`\|\mathbf{A}\|_{2}`. It can be computed via the SVD, which is a compute-intensive task for large
          operators. In this setting, it may be advantageous to overestimate the Lipschitz constant with the
          Frobenius norm of :math:`\mathbf{A}` since :math:`\|\mathbf{A}\|_{F} \geq \|\mathbf{A}\|_{2}`.

          :math:`\|\mathbf{A}\|_{F}` can be efficiently approximated by computing the trace of
          :math:`\mathbf{A}^{\ast} \mathbf{A}` (or :math:`\mathbf{A}\mathbf{A}^{\ast}`) via the `Hutch++
          stochastic algorithm <https://arxiv.org/abs/2010.09649>`_.

        * :math:`\|\mathbf{A}\|_{F}` is upper-bounded by :math:`\|\mathbf{A}\|_{F} \leq \sqrt{n}
          \|\mathbf{A}\|_{2}`, where the equality is reached (worst-case scenario) when the eigenspectrum of the
          linear operator is flat.
        """
        method = kwargs.get("method", "trace").lower().strip()

        if method == "svd":
            # svdvals() may have alternative signature in specialized classes, but we must always use
            # the LinOp.svdvals() interface below for kwargs-filtering.
            func, sig_func = self.__class__.svdvals, LinOp.svdvals
            kwargs.update(k=1)
            estimate = lambda: func(self, **kwargs).item()
        elif method == "trace":
            from pyxu.math import hutchpp as func

            sig_func = func
            kwargs.update(
                op=self.gram() if (self.codim_size >= self.dim_size) else self.cogram(),
                m=kwargs.get("m", 126),
            )
            estimate = lambda: np.sqrt(func(**kwargs)).item()
        else:
            raise NotImplementedError

        # Filter unsupported kwargs
        sig = inspect.Signature.from_callable(sig_func)
        kwargs = {k: v for (k, v) in kwargs.items() if (k in sig.parameters)}

        L = estimate()
        return L
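
    # Illustrative sketch (assumes a concrete LinOp `A`):
    #
    #   >>> L_ub = A.estimate_lipschitz(method="trace")  # Frobenius upper bound (stochastic)
    #   >>> L = A.estimate_lipschitz(method="svd")       # tight: spectral norm of A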

    def svdvals(
        self,
        k: pxt.Integer,
        gpu: bool = False,
        dtype: pxt.DType = None,
        **kwargs,
    ) -> pxt.NDArray:
        r"""
        Compute leading singular values of the linear operator.

        Parameters
        ----------
        k: Integer
            Number of singular values to compute.
        gpu: bool
            If ``True`` the singular value decomposition is performed on the GPU.
        dtype: DType
            Working precision of the linear operator.
        kwargs: ~collections.abc.Mapping
            Additional kwargs accepted by :py:func:`~scipy.sparse.linalg.svds`.

        Returns
        -------
        D: NDArray
            (k,) singular values in ascending order.
        """
        if dtype is None:
            dtype = pxrt.Width.DOUBLE.value

        def _dense_eval():
            if gpu:
                assert pxd.CUPY_ENABLED
                import cupy as xp
                import cupy.linalg as spx
            else:
                import numpy as xp
                import scipy.linalg as spx
            A = self.asarray(xp=xp, dtype=dtype)
            B = A.reshape(self.codim_size, self.dim_size)
            return spx.svd(B, compute_uv=False)

        def _sparse_eval():
            if gpu:
                assert pxd.CUPY_ENABLED
                import cupyx.scipy.sparse.linalg as spx

                self._warn_vals_sparse_gpu()
            else:
                spx = spsl
            from pyxu.operator import ReshapeAxes
            from pyxu.operator.interop import to_sciop

            # SciPy's LinearOperator only understands 2D linear operators.
            # -> wrap `self` into 2D form for SVD computations.
            lhs = ReshapeAxes(dim_shape=self.codim_shape, codim_shape=self.codim_size)
            rhs = ReshapeAxes(dim_shape=self.dim_size, codim_shape=self.dim_shape)
            op = to_sciop(
                op=lhs * self * rhs,
                gpu=gpu,
                dtype=dtype,
            )

            which = kwargs.get("which", "LM")
            assert which.upper() == "LM", "Only computing leading singular values is supported."
            kwargs.update(
                k=k,
                which=which,
                return_singular_vectors=False,
                # random_state=0,  # unsupported by CuPy
            )
            return spx.svds(op, **kwargs)

        if k >= min(self.dim_size, self.codim_size) // 2:
            msg = "Too many svdvals wanted: using matrix-based ops."
            warnings.warn(msg, pxw.DenseWarning)
            D = _dense_eval()
        else:
            D = _sparse_eval()

        # Filter to k largest magnitude + sorted
        xp = pxu.get_array_module(D)
        return D[xp.argsort(D)][-k:]

    def asarray(
        self,
        xp: pxt.ArrayModule = None,
        dtype: pxt.DType = None,
    ) -> pxt.NDArray:
        r"""
        Matrix representation of the linear operator.

        Parameters
        ----------
        xp: ArrayModule
            Which array module to use to represent the output. (Default: NumPy.)
        dtype: DType
            Precision of the array. (Default: current runtime precision.)

        Returns
        -------
        A: NDArray
            (*codim_shape, *dim_shape) array-representation of the operator.

        Note
        ----
        This generic implementation assumes the operator is backend-agnostic. Thus, when defining a new
        backend-specific operator, :py:meth:`~pyxu.abc.LinOp.asarray` may need to be overridden.
        """
        if xp is None:
            xp = pxd.NDArrayInfo.default().module()
        if dtype is None:
            dtype = pxrt.Width.DOUBLE.value

        E = xp.eye(self.dim_size, dtype=dtype).reshape(*self.dim_shape, *self.dim_shape)
        A = self.apply(E)  # (*dim_shape, *codim_shape)

        axes = tuple(range(-self.codim_rank, 0)) + tuple(range(self.dim_rank))
        B = A.transpose(axes)  # (*codim_shape, *dim_shape)
        return B
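
    # Illustrative round-trip sketch (assumes a concrete LinOp `A`, NumPy backend):
    #
    #   >>> mat = A.asarray()  # (*codim_shape, *dim_shape)
    #   >>> x = np.random.default_rng(0).standard_normal(A.dim_shape)
    #   >>> np.allclose(A.apply(x), np.tensordot(mat, x, axes=A.dim_rank))
    #   True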

    def gram(self) -> pxt.OpT:
        r"""
        Gram operator :math:`\mathbf{A}^{\ast} \mathbf{A}: \mathbb{R}^{M_{1} \times\cdots\times M_{D}} \to
        \mathbb{R}^{M_{1} \times\cdots\times M_{D}}`.

        Returns
        -------
        op: OpT
            Gram operator.

        Note
        ----
        By default the Gram is computed by the composition ``self.T * self``. This may not be the fastest way to
        compute the Gram operator. If the Gram can be computed more efficiently (e.g. with a convolution), the
        user should re-define this method.
        """

        def op_expr(_) -> tuple:
            return ("gram", self)

        op = self.T * self
        op._expr = types.MethodType(op_expr, op)
        return op.asop(SelfAdjointOp)

    def cogram(self) -> pxt.OpT:
        r"""
        Co-Gram operator :math:`\mathbf{A}\mathbf{A}^{\ast}: \mathbb{R}^{N_{1} \times\cdots\times N_{K}} \to
        \mathbb{R}^{N_{1} \times\cdots\times N_{K}}`.

        Returns
        -------
        op: OpT
            Co-Gram operator.

        Note
        ----
        By default the co-Gram is computed by the composition ``self * self.T``. This may not be the fastest way
        to compute the co-Gram operator. If the co-Gram can be computed more efficiently (e.g. with a
        convolution), the user should re-define this method.
        """

        def op_expr(_) -> tuple:
            return ("cogram", self)

        op = self * self.T
        op._expr = types.MethodType(op_expr, op)
        return op.asop(SelfAdjointOp)

    def pinv(
        self,
        arr: pxt.NDArray,
        damp: pxt.Real,
        kwargs_init=None,
        kwargs_fit=None,
    ) -> pxt.NDArray:
        r"""
        Evaluate the Moore-Penrose pseudo-inverse :math:`\mathbf{A}^{\dagger}` at specified point(s).

        Parameters
        ----------
        arr: NDArray
            (..., N1,...,NK) input points.
        damp: Real
            Positive dampening factor regularizing the pseudo-inverse.
        kwargs_init: ~collections.abc.Mapping
            Optional kwargs to be passed to :py:meth:`~pyxu.opt.solver.CG`'s ``__init__()`` method.
        kwargs_fit: ~collections.abc.Mapping
            Optional kwargs to be passed to :py:meth:`~pyxu.opt.solver.CG`'s ``fit()`` method.

        Returns
        -------
        out: NDArray
            (..., M1,...,MD) pseudo-inverse(s).

        Notes
        -----
        The Moore-Penrose pseudo-inverse of an operator :math:`\mathbf{A}: \mathbb{R}^{M_{1} \times\cdots\times
        M_{D}} \to \mathbb{R}^{N_{1} \times\cdots\times N_{K}}` is defined as the operator
        :math:`\mathbf{A}^{\dagger}: \mathbb{R}^{N_{1} \times\cdots\times N_{K}} \to \mathbb{R}^{M_{1}
        \times\cdots\times M_{D}}` verifying the Moore-Penrose conditions:

        1. :math:`\mathbf{A} \mathbf{A}^{\dagger} \mathbf{A} = \mathbf{A}`,
        2. :math:`\mathbf{A}^{\dagger} \mathbf{A} \mathbf{A}^{\dagger} = \mathbf{A}^{\dagger}`,
        3. :math:`(\mathbf{A}^{\dagger} \mathbf{A})^{\ast} = \mathbf{A}^{\dagger} \mathbf{A}`,
        4. :math:`(\mathbf{A} \mathbf{A}^{\dagger})^{\ast} = \mathbf{A} \mathbf{A}^{\dagger}`.

        This operator exists and is unique for any finite-dimensional linear operator. The action of the
        pseudo-inverse :math:`\mathbf{A}^{\dagger} \mathbf{y}` for every :math:`\mathbf{y} \in \mathbb{R}^{N_{1}
        \times\cdots\times N_{K}}` can be computed in matrix-free fashion by solving the *normal equations*:

        .. math::

           \mathbf{A}^{\ast} \mathbf{A} \mathbf{x} = \mathbf{A}^{\ast} \mathbf{y}
           \quad\Leftrightarrow\quad
           \mathbf{x} = \mathbf{A}^{\dagger} \mathbf{y},
           \quad
           \forall (\mathbf{x}, \mathbf{y}) \in \mathbb{R}^{M_{1} \times\cdots\times M_{D}} \times
           \mathbb{R}^{N_{1} \times\cdots\times N_{K}}.

        In the case of severe ill-conditioning, it is possible to consider the dampened normal equations for a
        numerically-stabler approximation of :math:`\mathbf{A}^{\dagger} \mathbf{y}`:

        .. math::

           (\mathbf{A}^{\ast} \mathbf{A} + \tau I) \mathbf{x} = \mathbf{A}^{\ast} \mathbf{y},

        where :math:`\tau > 0` corresponds to the `damp` parameter.
        """
        from pyxu.operator import HomothetyOp
        from pyxu.opt.solver import CG
        from pyxu.opt.stop import MaxIter

        kwargs_fit = dict() if kwargs_fit is None else kwargs_fit
        kwargs_init = dict() if kwargs_init is None else kwargs_init
        kwargs_init.update(show_progress=kwargs_init.get("show_progress", False))

        if np.isclose(damp, 0):
            A = self.gram()
        else:
            A = self.gram() + HomothetyOp(cst=damp, dim_shape=self.dim_shape)

        cg = CG(A, **kwargs_init)
        if "stop_crit" not in kwargs_fit:
            # .pinv() may not have sufficiently converged given the default CG stopping criteria.
            # To avoid infinite loops, CG iterations are thresholded.
            sentinel = MaxIter(n=20 * A.dim_size)
            kwargs_fit["stop_crit"] = cg.default_stop_crit() | sentinel

        b = self.adjoint(arr)
        cg.fit(b=b, **kwargs_fit)
        return cg.solution()
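
    # Illustrative sketch: damped least-squares solve via the pseudo-inverse
    # (assumes a concrete LinOp `A` and data `y` of shape codim_shape):
    #
    #   >>> x_hat = A.pinv(y, damp=1e-3)  # solves (A* A + damp * I) x = A* y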

    def dagger(
        self,
        damp: pxt.Real,
        kwargs_init=None,
        kwargs_fit=None,
    ) -> pxt.OpT:
        r"""
        Return the Moore-Penrose pseudo-inverse operator :math:`\mathbf{A}^\dagger`.

        Parameters
        ----------
        damp: Real
            Positive dampening factor regularizing the pseudo-inverse.
        kwargs_init: ~collections.abc.Mapping
            Optional kwargs to be passed to :py:meth:`~pyxu.opt.solver.CG`'s ``__init__()`` method.
        kwargs_fit: ~collections.abc.Mapping
            Optional kwargs to be passed to :py:meth:`~pyxu.opt.solver.CG`'s ``fit()`` method.

        Returns
        -------
        op: OpT
            Moore-Penrose pseudo-inverse operator.
        """
        from pyxu.operator.interop import from_source

        def op_apply(_, arr: pxt.NDArray) -> pxt.NDArray:
            return self.pinv(
                arr,
                damp=_._damp,
                kwargs_init=_._kwargs_init,
                kwargs_fit=_._kwargs_fit,
            )

        def op_adjoint(_, arr: pxt.NDArray) -> pxt.NDArray:
            return self.T.pinv(
                arr,
                damp=_._damp,
                kwargs_init=_._kwargs_init,
                kwargs_fit=_._kwargs_fit,
            )

        kwargs_fit = dict() if kwargs_fit is None else kwargs_fit
        kwargs_init = dict() if kwargs_init is None else kwargs_init

        dagger = from_source(
            cls=SquareOp if (self.dim_size == self.codim_size) else LinOp,
            dim_shape=self.codim_shape,
            codim_shape=self.dim_shape,
            embed=dict(
                _name="dagger",
                _damp=damp,
                _kwargs_init=copy.copy(kwargs_init),
                _kwargs_fit=copy.copy(kwargs_fit),
            ),
            apply=op_apply,
            adjoint=op_adjoint,
            _expr=lambda _: (_._name, _, _._damp),
        )
        return dagger
[docs] 1868 @classmethod 1869 def from_array( 1870 cls, 1871 A: typ.Union[pxt.NDArray, pxt.SparseArray], 1872 dim_rank=None, 1873 enable_warnings: bool = True, 1874 ) -> pxt.OpT: 1875 r""" 1876 Instantiate a :py:class:`~pyxu.abc.LinOp` from its array representation. 1877 1878 Parameters 1879 ---------- 1880 A: NDArray 1881 (*codim_shape, *dim_shape) array. 1882 dim_rank: Integer 1883 Dimension rank :math:`D`. (Can be omitted if `A` is 2D.) 1884 enable_warnings: bool 1885 If ``True``, emit a warning in case of precision mis-match issues. 1886 1887 Returns 1888 ------- 1889 op: OpT 1890 Linear operator 1891 """ 1892 from pyxu.operator.linop.base import _ExplicitLinOp 1893 1894 op = _ExplicitLinOp( 1895 cls, 1896 mat=A, 1897 dim_rank=dim_rank, 1898 enable_warnings=enable_warnings, 1899 ) 1900 return op
1901 1902
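# Example (editor's sketch, not part of the Pyxu API): round-tripping through a
# small explicit operator. `from_array()`, `pinv()` and `dagger()` are the
# methods defined above; array sizes and the damping value are illustrative.
def _example_pinv():
    import numpy as np

    rng = np.random.default_rng(0)
    A = rng.standard_normal((5, 3))  # dim_shape=(3,), codim_shape=(5,)
    op = LinOp.from_array(A)

    x = rng.standard_normal(3)
    y = op.apply(x)

    # One-shot evaluation of A^{dagger} y, and the same computation through the
    # operator returned by dagger(). Both solve the same lightly-damped normal
    # equations with CG, so they agree up to solver tolerance.
    x1 = op.pinv(y, damp=1e-3)
    x2 = op.dagger(damp=1e-3).apply(y)
    assert np.allclose(x1, x2, atol=1e-4)
    return x1
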
class SquareOp(LinOp):
    r"""
    Base class for *square* linear operators, i.e. :math:`\mathbf{A}: \mathbb{R}^{M_{1} \times\cdots\times M_{D}} \to
    \mathbb{R}^{M_{1} \times\cdots\times M_{D}}` (endomorphisms).
    """

    @classmethod
    def properties(cls) -> cabc.Set[Property]:
        p = set(super().properties())
        p.add(Property.LINEAR_SQUARE)
        return frozenset(p)

    def __init__(
        self,
        dim_shape: pxt.NDArrayShape,
        codim_shape: pxt.NDArrayShape,
    ):
        super().__init__(
            dim_shape=dim_shape,
            codim_shape=codim_shape,
        )
        assert self.dim_size == self.codim_size

    def trace(self, **kwargs) -> pxt.Real:
        """
        Compute the trace of the operator.

        Parameters
        ----------
        method: "explicit" | "hutchpp"

            * If `explicit`, compute the exact trace.
            * If `hutchpp`, compute an approximation. (Default)

        kwargs: ~collections.abc.Mapping
            Optional kwargs passed to:

            * `explicit`: :py:func:`~pyxu.math.trace`
            * `hutchpp`: :py:func:`~pyxu.math.hutchpp`

        Returns
        -------
        tr: Real
            Trace estimate.
        """
        from pyxu.math import hutchpp, trace

        method = kwargs.get("method", "hutchpp").lower().strip()

        if method == "explicit":
            func = trace
        elif method == "hutchpp":
            func = hutchpp
        else:
            raise NotImplementedError(f"Unsupported trace method: {method!r}.")

        # Filter out kwargs not understood by the chosen estimator.
        sig = inspect.Signature.from_callable(func)
        kwargs = {k: v for (k, v) in kwargs.items() if (k in sig.parameters)}

        tr = func(op=self, **kwargs)
        return tr

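# Example (editor's sketch, not part of the Pyxu API): exact vs. stochastic
# trace estimation on the identity over R^{10}, whose trace is exactly 10.
# `IdentityOp` is assumed to accept `method` through trace()'s kwargs.
def _example_trace():
    from pyxu.operator import IdentityOp

    op = IdentityOp(dim_shape=10)
    tr_exact = op.trace(method="explicit")  # == 10
    tr_approx = op.trace(method="hutchpp")  # ~10, up to Monte-Carlo error
    return tr_exact, tr_approx
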
class NormalOp(SquareOp):
    r"""
    Base class for *normal* operators.

    Normal operators satisfy the relation :math:`\mathbf{A} \mathbf{A}^{\ast} = \mathbf{A}^{\ast} \mathbf{A}`. It can
    be `shown <https://www.wikiwand.com/en/Spectral_theorem#/Normal_matrices>`_ that an operator is normal iff it is
    *unitarily diagonalizable*, i.e. :math:`\mathbf{A} = \mathbf{U} \mathbf{D} \mathbf{U}^{\ast}`.
    """

    @classmethod
    def properties(cls) -> cabc.Set[Property]:
        p = set(super().properties())
        p.add(Property.LINEAR_NORMAL)
        return frozenset(p)

    def cogram(self) -> pxt.OpT:
        # A A^* = A^* A by normality, hence the co-gram coincides with the gram.
        return self.gram()


class SelfAdjointOp(NormalOp):
    r"""
    Base class for *self-adjoint* operators.

    Self-adjoint operators satisfy the relation :math:`\mathbf{A}^{\ast} = \mathbf{A}`.
    """

    @classmethod
    def properties(cls) -> cabc.Set[Property]:
        p = set(super().properties())
        p.add(Property.LINEAR_SELF_ADJOINT)
        return frozenset(p)

    def adjoint(self, arr: pxt.NDArray) -> pxt.NDArray:
        # A^* = A: delegate to apply().
        return self.apply(arr)

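# Example (editor's sketch, not part of the Pyxu API): for self-adjoint
# operators, adjoint() simply delegates to apply(). `HomothetyOp` (x -> cst*x)
# with cst > 0, already used by pinv() above, serves as a readily-available
# self-adjoint operator.
def _example_self_adjoint():
    import numpy as np
    from pyxu.operator import HomothetyOp

    op = HomothetyOp(cst=2.5, dim_shape=4)
    x = np.arange(4, dtype=float)
    assert np.allclose(op.adjoint(x), op.apply(x))  # A^* = A
    return op.apply(x)
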
class UnitOp(NormalOp):
    r"""
    Base class for *unitary* operators.

    Unitary operators satisfy the relation :math:`\mathbf{A} \mathbf{A}^{\ast} = \mathbf{A}^{\ast} \mathbf{A} = I`.
    """

    @classmethod
    def properties(cls) -> cabc.Set[Property]:
        p = set(super().properties())
        p.add(Property.LINEAR_UNITARY)
        return frozenset(p)

    def __init__(
        self,
        dim_shape: pxt.NDArrayShape,
        codim_shape: pxt.NDArrayShape,
    ):
        super().__init__(
            dim_shape=dim_shape,
            codim_shape=codim_shape,
        )
        self.lipschitz = UnitOp.estimate_lipschitz(self)

    def estimate_lipschitz(self, **kwargs) -> pxt.Real:
        # Unitary operators are isometries.
        return 1

    def pinv(self, arr: pxt.NDArray, damp: pxt.Real, **kwargs) -> pxt.NDArray:
        # Since A^* A = I, the damped normal equations solve to A^* arr / (1 + damp):
        # no CG iterations are required.
        out = self.adjoint(arr)
        if not np.isclose(damp, 0):
            out = pxu.copy_if_unsafe(out)
            out /= 1 + damp
        return out

    def dagger(self, damp: pxt.Real, **kwargs) -> pxt.OpT:
        op = self.T / (1 + damp)
        return op

    def gram(self) -> pxt.OpT:
        from pyxu.operator import IdentityOp

        return IdentityOp(dim_shape=self.dim_shape)

    def svdvals(self, **kwargs) -> pxt.NDArray:
        # All singular values of a unitary operator equal 1.
        gpu = kwargs.get("gpu", False)
        xp = pxd.NDArrayInfo.from_flag(gpu).module()
        dtype = kwargs.get("dtype", pxrt.Width.DOUBLE.value)
        D = xp.ones(kwargs["k"], dtype=dtype)
        return D

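# Example (editor's sketch, not part of the Pyxu API): a permutation is
# unitary, so the specialized pinv() above reduces to a scaled adjoint.
# `Flip` is a hypothetical operator defined here for illustration only.
def _example_unitary():
    import numpy as np

    class Flip(UnitOp):
        # Reversal along the last axis; its own inverse, hence unitary.
        def __init__(self, dim_shape):
            super().__init__(dim_shape=dim_shape, codim_shape=dim_shape)

        def apply(self, arr):
            return arr[..., ::-1]

        def adjoint(self, arr):
            return arr[..., ::-1]

    op = Flip(dim_shape=4)
    y = np.arange(4, dtype=float)
    # pinv() for unitary operators: A^{dagger} y = A^* y / (1 + damp).
    assert np.allclose(op.pinv(y, damp=1.0), op.adjoint(y) / 2)
    return op.apply(y)
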
class ProjOp(SquareOp):
    r"""
    Base class for *projection* operators.

    Projection operators are *idempotent*, i.e. :math:`\mathbf{A}^{2} = \mathbf{A}`.
    """

    @classmethod
    def properties(cls) -> cabc.Set[Property]:
        p = set(super().properties())
        p.add(Property.LINEAR_IDEMPOTENT)
        return frozenset(p)


class OrthProjOp(ProjOp, SelfAdjointOp):
    r"""
    Base class for *orthogonal projection* operators.

    Orthogonal projection operators are *idempotent* and *self-adjoint*, i.e. :math:`\mathbf{A}^{2} = \mathbf{A}` and
    :math:`\mathbf{A}^{\ast} = \mathbf{A}`.
    """

    @classmethod
    def properties(cls) -> cabc.Set[Property]:
        # Merge the property sets of both parent classes.
        p = set()
        for klass in cls.__bases__:
            p |= klass.properties()
        return frozenset(p)

    def __init__(
        self,
        dim_shape: pxt.NDArrayShape,
        codim_shape: pxt.NDArrayShape,
    ):
        super().__init__(
            dim_shape=dim_shape,
            codim_shape=codim_shape,
        )
        self.lipschitz = OrthProjOp.estimate_lipschitz(self)

    def estimate_lipschitz(self, **kwargs) -> pxt.Real:
        return 1

    def gram(self) -> pxt.OpT:
        # A^* A = A^2 = A.
        return self

    def cogram(self) -> pxt.OpT:
        return self

    def pinv(self, arr: pxt.NDArray, damp: pxt.Real, **kwargs) -> pxt.NDArray:
        # For orthogonal projectors, (A^* A + damp * I) x = A^* y solves to A y / (1 + damp).
        out = self.apply(arr)
        if not np.isclose(damp, 0):
            out = pxu.copy_if_unsafe(out)
            out /= 1 + damp
        return out

    def dagger(self, damp: pxt.Real, **kwargs) -> pxt.OpT:
        op = self / (1 + damp)
        return op

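# Example (editor's sketch, not part of the Pyxu API): orthogonal projection
# onto the first `k` canonical coordinates. adjoint() is inherited from
# SelfAdjointOp, and gram()/pinv() use the shortcuts defined above.
# `KeepFirst` is a hypothetical operator defined here for illustration only.
def _example_orth_proj():
    import numpy as np

    class KeepFirst(OrthProjOp):
        def __init__(self, dim_shape, k):
            super().__init__(dim_shape=dim_shape, codim_shape=dim_shape)
            self._k = k

        def apply(self, arr):
            out = arr.copy()
            out[..., self._k :] = 0
            return out

    op = KeepFirst(dim_shape=5, k=2)
    x = np.arange(5, dtype=float)
    assert np.allclose(op.apply(op.apply(x)), op.apply(x))  # idempotent: A^2 = A
    assert op.gram() is op  # A^* A = A for orthogonal projectors
    return op.apply(x)
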
class PosDefOp(SelfAdjointOp):
    r"""
    Base class for *positive-definite* operators.
    """

    @classmethod
    def properties(cls) -> cabc.Set[Property]:
        p = set(super().properties())
        p.add(Property.LINEAR_POSITIVE_DEFINITE)
        return frozenset(p)


class LinFunc(ProxDiffFunc, LinOp):
    r"""
    Base class for real-valued linear functionals :math:`f: \mathbb{R}^{M_{1} \times\cdots\times M_{D}} \to \mathbb{R}`.

    Instances of this class must implement :py:meth:`~pyxu.abc.Map.apply` and :py:meth:`~pyxu.abc.LinOp.adjoint`.

    If known, the Lipschitz constant :math:`L` should be stored in the :py:attr:`~pyxu.abc.Map.lipschitz` property.
    """

    @classmethod
    def properties(cls) -> cabc.Set[Property]:
        # Merge the property sets of both parent classes.
        p = set()
        for klass in cls.__bases__:
            p |= klass.properties()
        return frozenset(p)

    def __init__(
        self,
        dim_shape: pxt.NDArrayShape,
        codim_shape: pxt.NDArrayShape,
    ):
        for klass in [ProxDiffFunc, LinOp]:
            klass.__init__(
                self,
                dim_shape=dim_shape,
                codim_shape=codim_shape,
            )

    def jacobian(self, arr: pxt.NDArray) -> pxt.OpT:
        # The Jacobian of a linear map is the map itself.
        return LinOp.jacobian(self, arr)

    def estimate_lipschitz(self, **kwargs) -> pxt.Real:
        # L = ||grad f||_2, where grad f is constant. Try all backends until one works.
        for ndi in pxd.NDArrayInfo:
            try:
                xp = ndi.module()
                g = self.grad(xp.ones(self.dim_shape))
                L = float(xp.sqrt(xp.sum(g**2)))
                return L
            except Exception:
                pass
        # Do not fail silently if no backend could evaluate grad().
        raise RuntimeError("estimate_lipschitz(): no available array backend could evaluate grad().")

    def grad(self, arr: pxt.NDArray) -> pxt.NDArray:
        ndi = pxd.NDArrayInfo.from_obj(arr)
        xp = ndi.module()

        # grad f = adjoint(1), broadcast over the stacking dimensions of `arr`.
        sh = arr.shape[: -self.dim_rank]
        x = xp.ones((*sh, 1), dtype=arr.dtype)
        g = self.adjoint(x)

        if ndi == pxd.NDArrayInfo.DASK:
            # LinFuncs auto-determine [grad,prox,fenchel_prox]() via the user-specified adjoint().
            # Problem: cannot forward any core-chunk info to adjoint(), hence grad's core-chunks
            # may differ from `arr`. This is problematic since [grad,prox,fenchel_prox]() should
            # preserve core-chunks by default.
            if g.chunks != arr.chunks:
                g = g.rechunk(arr.chunks)
        return g

    def prox(self, arr: pxt.NDArray, tau: pxt.Real) -> pxt.NDArray:
        out = arr - tau * self.grad(arr)
        return out

    def fenchel_prox(self, arr: pxt.NDArray, sigma: pxt.Real) -> pxt.NDArray:
        return self.grad(arr)

    def cogram(self) -> pxt.OpT:
        from pyxu.operator import HomothetyOp

        # A A^* is the (1x1) scaling by ||grad f||^2 = L^2.
        L = self.estimate_lipschitz()
        return HomothetyOp(cst=L**2, dim_shape=1)

    def svdvals(self, **kwargs) -> pxt.NDArray:
        gpu = kwargs.get("gpu", False)
        xp = pxd.NDArrayInfo.from_flag(gpu).module()
        dtype = kwargs.get("dtype", pxrt.Width.DOUBLE.value)

        # The only non-zero singular value of a linear functional is ||grad f||_2 = L.
        L = self.estimate_lipschitz()
        D = xp.array([L], dtype=dtype)
        return D

    def asarray(self, **kwargs) -> pxt.NDArray:
        xp = kwargs.get("xp", pxd.NDArrayInfo.default().module())
        dtype = kwargs.get("dtype", pxrt.Width.DOUBLE.value)

        E = xp.ones((1, 1), dtype=dtype)
        A = self.adjoint(E)  # (1, *dim_shape)
        return A

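# Example (editor's sketch, not part of the Pyxu API): the "sum" functional
# f(x) = sum(x) implemented as a LinFunc. Only apply() and adjoint() are
# specified; grad(), prox() and fenchel_prox() come from the base class.
# `Sum` is a hypothetical functional defined here for illustration only.
def _example_linfunc():
    import numpy as np

    class Sum(LinFunc):
        def __init__(self, dim_shape):
            super().__init__(dim_shape=dim_shape, codim_shape=1)

        def apply(self, arr):
            return arr.sum(axis=-1, keepdims=True)

        def adjoint(self, arr):
            # (..., 1) -> (..., M): broadcast against the all-ones gradient.
            return arr * np.ones(self.dim_shape, dtype=arr.dtype)

    f = Sum(dim_shape=3)
    x = np.arange(3, dtype=float)
    assert np.allclose(f.grad(x), np.ones(3))  # grad f = 1 everywhere
    assert np.allclose(f.prox(x, tau=0.5), x - 0.5)  # prox = x - tau * grad f
    return f.apply(x)
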
def _core_operators() -> cabc.Set[pxt.OpC]:
    # Operators which can be sub-classed by end-users and participate in arithmetic rules.
    ops = set()
    for obj in globals().values():
        if inspect.isclass(obj) and issubclass(obj, Operator):
            ops.add(obj)
    ops.remove(Operator)
    return ops


def _is_real(x) -> bool:
    if isinstance(x, pxt.Real):
        return True
    elif isinstance(x, pxd.supported_array_types()) and (x.size == 1):
        return True
    else:
        return False


__all__ = [
    "Operator",
    "Property",
    *map(lambda _: _.__name__, _core_operators()),
]