Source code for pyxu.abc.arithmetic

# Arithmetic Rules

import types

import numpy as np

import pyxu.abc.operator as pxo
import pyxu.info.deps as pxd
import pyxu.info.ptype as pxt
import pyxu.util as pxu

class Rule:
    """
    General arithmetic rule.

    This class defines default arithmetic rules applicable unless re-defined by sub-classes.
    """

    def op(self) -> pxt.OpT:
        """
        Returns
        -------
        op: OpT
            Synthesized operator.
        """
        raise NotImplementedError

    # Helper Methods ----------------------------------------------------------

    @staticmethod
    def _propagate_constants(op: pxt.OpT):
        # Propagate (diff-)Lipschitz constants forward via special call to
        # Rule()-overridden `estimate_[diff_]lipschitz()` methods.

        # Important: write to _[diff_]lipschitz so as not to overwrite the
        # estimate_[diff_]lipschitz() methods.
        if op.has(pxo.Property.CAN_EVAL):
            op._lipschitz = op.estimate_lipschitz(__rule=True)
        if op.has(pxo.Property.DIFFERENTIABLE):
            op._diff_lipschitz = op.estimate_diff_lipschitz(__rule=True)

    # Default Arithmetic Methods ----------------------------------------------
    # Fall back on these when no simple form in terms of Rule.__init__() args is known.
    # If a method from Property.arithmetic_methods() is not listed here, then all Rule subclasses
    # provide an overloaded implementation.

    def __call__(self, arr: pxt.NDArray) -> pxt.NDArray:
        return self.apply(arr)

    def svdvals(self, **kwargs) -> pxt.NDArray:
        # These defaults are bound onto the synthesized operator instance, whose class is a
        # pyxu Operator subclass: `self.__class__` hence resolves to that class' own
        # implementation, not to this (instance-level) override.
        D = self.__class__.svdvals(self, **kwargs)
        return D

    def pinv(self, arr: pxt.NDArray, damp: pxt.Real, **kwargs) -> pxt.NDArray:
        out = self.__class__.pinv(self, arr=arr, damp=damp, **kwargs)
        return out

    def trace(self, **kwargs) -> pxt.Real:
        tr = self.__class__.trace(self, **kwargs)
        return tr
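
# A minimal sketch of how Rule subclasses are consumed (`A`, `A1`, `A2` stand
# for arbitrary pyxu operators; they are not defined in this module):
#
#     B = ScaleRule(op=A, cst=2.0).op()    # B(x) = 2 * A(x)
#     C = AddRule(lhs=A1, rhs=A2).op()     # C(x) = A1(x) + A2(x)
#
# In practice these rules are invoked indirectly through operator arithmetic
# (`2.0 * A`, `A1 + A2`, ...) rather than instantiated by hand.
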
class ScaleRule(Rule):
    r"""
    Arithmetic rules for element-wise scaling: :math:`B(x) = \alpha A(x)`.

    Special Cases::

        \alpha = 0 => NullOp/NullFunc
        \alpha = 1 => self

    Else::

        |--------------------------|-------------|--------------------------------------------------------------------|
        | Property                 | Preserved?  | Arithmetic Update Rule(s)                                          |
        |--------------------------|-------------|--------------------------------------------------------------------|
        | CAN_EVAL                 | yes         | op_new.apply(arr) = op_old.apply(arr) * \alpha                     |
        |                          |             | op_new.lipschitz = op_old.lipschitz * abs(\alpha)                  |
        |--------------------------|-------------|--------------------------------------------------------------------|
        | FUNCTIONAL               | yes         |                                                                    |
        |--------------------------|-------------|--------------------------------------------------------------------|
        | PROXIMABLE               | \alpha > 0  | op_new.prox(arr, tau) = op_old.prox(arr, tau * \alpha)             |
        |--------------------------|-------------|--------------------------------------------------------------------|
        | DIFFERENTIABLE           | yes         | op_new.jacobian(arr) = op_old.jacobian(arr) * \alpha               |
        |                          |             | op_new.diff_lipschitz = op_old.diff_lipschitz * abs(\alpha)        |
        |--------------------------|-------------|--------------------------------------------------------------------|
        | DIFFERENTIABLE_FUNCTION  | yes         | op_new.grad(arr) = op_old.grad(arr) * \alpha                       |
        |--------------------------|-------------|--------------------------------------------------------------------|
        | QUADRATIC                | \alpha > 0  | Q, c, t = op_old._quad_spec()                                      |
        |                          |             | op_new._quad_spec() = (\alpha * Q, \alpha * c, \alpha * t)         |
        |--------------------------|-------------|--------------------------------------------------------------------|
        | LINEAR                   | yes         | op_new.adjoint(arr) = op_old.adjoint(arr) * \alpha                 |
        |                          |             | op_new.asarray() = op_old.asarray() * \alpha                       |
        |                          |             | op_new.svdvals() = op_old.svdvals() * abs(\alpha)                  |
        |                          |             | op_new.pinv(x, damp) = op_old.pinv(x, damp / (\alpha**2)) / \alpha |
        |                          |             | op_new.gram() = op_old.gram() * (\alpha**2)                        |
        |                          |             | op_new.cogram() = op_old.cogram() * (\alpha**2)                    |
        |--------------------------|-------------|--------------------------------------------------------------------|
        | LINEAR_SQUARE            | yes         | op_new.trace() = op_old.trace() * \alpha                           |
        |--------------------------|-------------|--------------------------------------------------------------------|
        | LINEAR_NORMAL            | yes         |                                                                    |
        |--------------------------|-------------|--------------------------------------------------------------------|
        | LINEAR_UNITARY           | \alpha = -1 |                                                                    |
        |--------------------------|-------------|--------------------------------------------------------------------|
        | LINEAR_SELF_ADJOINT      | yes         |                                                                    |
        |--------------------------|-------------|--------------------------------------------------------------------|
        | LINEAR_POSITIVE_DEFINITE | \alpha > 0  |                                                                    |
        |--------------------------|-------------|--------------------------------------------------------------------|
        | LINEAR_IDEMPOTENT        | no          |                                                                    |
        |--------------------------|-------------|--------------------------------------------------------------------|
    """

    def __init__(self, op: pxt.OpT, cst: pxt.Real):
        super().__init__()
        self._op = op
        self._cst = float(cst)

    def op(self) -> pxt.OpT:
        if np.isclose(self._cst, 0):
            from pyxu.operator import NullOp

            op = NullOp(
                dim_shape=self._op.dim_shape,
                codim_shape=self._op.codim_shape,
            )
        elif np.isclose(self._cst, 1):
            op = self._op
        else:
            klass = self._infer_op_klass()
            op = klass(
                dim_shape=self._op.dim_shape,
                codim_shape=self._op.codim_shape,
            )
            op._op = self._op  # embed for introspection
            op._cst = self._cst  # embed for introspection
            for p in op.properties():
                for name in p.arithmetic_methods():
                    func = getattr(self.__class__, name)
                    setattr(op, name, types.MethodType(func, op))
            self._propagate_constants(op)
        return op

    def _expr(self) -> tuple:
        return ("scale", self._op, self._cst)

    def _infer_op_klass(self) -> pxt.OpC:
        preserved = {
            pxo.Property.CAN_EVAL,
            pxo.Property.FUNCTIONAL,
            pxo.Property.DIFFERENTIABLE,
            pxo.Property.DIFFERENTIABLE_FUNCTION,
            pxo.Property.LINEAR,
            pxo.Property.LINEAR_SQUARE,
            pxo.Property.LINEAR_NORMAL,
            pxo.Property.LINEAR_SELF_ADJOINT,
        }
        if self._cst > 0:
            preserved |= {
                pxo.Property.LINEAR_POSITIVE_DEFINITE,
                pxo.Property.QUADRATIC,
                pxo.Property.PROXIMABLE,
            }
        if self._op.has(pxo.Property.LINEAR):
            preserved.add(pxo.Property.PROXIMABLE)
        if np.isclose(self._cst, -1):
            preserved.add(pxo.Property.LINEAR_UNITARY)

        properties = self._op.properties() & preserved
        klass = pxo.Operator._infer_operator_type(properties)
        return klass

    def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
        out = pxu.copy_if_unsafe(self._op.apply(arr))
        out *= self._cst
        return out

    def estimate_lipschitz(self, **kwargs) -> pxt.Real:
        no_eval = "__rule" in kwargs
        if no_eval:
            L = float(self._op.lipschitz)
        else:
            L = self._op.estimate_lipschitz(**kwargs)
        L *= abs(self._cst)
        return L

    def prox(self, arr: pxt.NDArray, tau: pxt.Real) -> pxt.NDArray:
        return self._op.prox(arr, tau * self._cst)

    def _quad_spec(self):
        Q1, c1, t1 = self._op._quad_spec()
        Q2 = ScaleRule(op=Q1, cst=self._cst).op()
        c2 = ScaleRule(op=c1, cst=self._cst).op()
        t2 = t1 * self._cst
        return (Q2, c2, t2)

    def jacobian(self, arr: pxt.NDArray) -> pxt.OpT:
        if self.has(pxo.Property.LINEAR):
            op = self
        else:
            op = self._op.jacobian(arr) * self._cst
        return op

    def estimate_diff_lipschitz(self, **kwargs) -> pxt.Real:
        no_eval = "__rule" in kwargs
        if no_eval:
            dL = float(self._op.diff_lipschitz)
        else:
            dL = self._op.estimate_diff_lipschitz(**kwargs)
        dL *= abs(self._cst)
        return dL

    def grad(self, arr: pxt.NDArray) -> pxt.NDArray:
        out = pxu.copy_if_unsafe(self._op.grad(arr))
        out *= self._cst
        return out

    def adjoint(self, arr: pxt.NDArray) -> pxt.NDArray:
        out = pxu.copy_if_unsafe(self._op.adjoint(arr))
        out *= self._cst
        return out

    def asarray(self, **kwargs) -> pxt.NDArray:
        A = pxu.copy_if_unsafe(self._op.asarray(**kwargs))
        A *= self._cst
        return A

    def svdvals(self, **kwargs) -> pxt.NDArray:
        D = pxu.copy_if_unsafe(self._op.svdvals(**kwargs))
        D *= abs(self._cst)
        return D

    def pinv(self, arr: pxt.NDArray, damp: pxt.Real, **kwargs) -> pxt.NDArray:
        scale = damp / (self._cst**2)
        out = pxu.copy_if_unsafe(self._op.pinv(arr, damp=scale, **kwargs))
        out /= self._cst
        return out

    def gram(self) -> pxt.OpT:
        op = self._op.gram() * (self._cst**2)
        return op

    def cogram(self) -> pxt.OpT:
        op = self._op.cogram() * (self._cst**2)
        return op

    def trace(self, **kwargs) -> pxt.Real:
        tr = self._op.trace(**kwargs) * self._cst
        return tr
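
# Example (sketch): scalar multiplication dispatches to ScaleRule. This assumes
# `SquaredL2Norm` is available in `pyxu.operator`; any functional works alike.
#
#     import numpy as np
#     from pyxu.operator import SquaredL2Norm
#
#     f = SquaredL2Norm(dim_shape=(5,))     # f(x) = \|x\|_{2}^{2}
#     g = 3.0 * f                           # ScaleRule(op=f, cst=3.0).op()
#     x = np.arange(5.0)
#     assert np.allclose(g.apply(x), 3.0 * f.apply(x))
#     assert np.allclose(g.grad(x), 3.0 * f.grad(x))
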
class ArgScaleRule(Rule):
    r"""
    Arithmetic rules for element-wise parameter scaling: :math:`B(x) = A(\alpha x)`.

    Special Cases::

        \alpha = 0 => ConstantValued (w/ potential vector-valued output)
        \alpha = 1 => self

    Else::

        |--------------------------|-------------|-----------------------------------------------------------------------------|
        | Property                 | Preserved?  | Arithmetic Update Rule(s)                                                   |
        |--------------------------|-------------|-----------------------------------------------------------------------------|
        | CAN_EVAL                 | yes         | op_new.apply(arr) = op_old.apply(arr * \alpha)                              |
        |                          |             | op_new.lipschitz = op_old.lipschitz * abs(\alpha)                           |
        |--------------------------|-------------|-----------------------------------------------------------------------------|
        | FUNCTIONAL               | yes         |                                                                             |
        |--------------------------|-------------|-----------------------------------------------------------------------------|
        | PROXIMABLE               | yes         | op_new.prox(arr, tau) = op_old.prox(\alpha * arr, \alpha**2 * tau) / \alpha |
        |--------------------------|-------------|-----------------------------------------------------------------------------|
        | DIFFERENTIABLE           | yes         | op_new.diff_lipschitz = op_old.diff_lipschitz * (\alpha**2)                 |
        |                          |             | op_new.jacobian(arr) = op_old.jacobian(arr * \alpha) * \alpha               |
        |--------------------------|-------------|-----------------------------------------------------------------------------|
        | DIFFERENTIABLE_FUNCTION  | yes         | op_new.grad(arr) = op_old.grad(\alpha * arr) * \alpha                       |
        |--------------------------|-------------|-----------------------------------------------------------------------------|
        | QUADRATIC                | yes         | Q, c, t = op_old._quad_spec()                                               |
        |                          |             | op_new._quad_spec() = (\alpha**2 * Q, \alpha * c, t)                        |
        |--------------------------|-------------|-----------------------------------------------------------------------------|
        | LINEAR                   | yes         | op_new.adjoint(arr) = op_old.adjoint(arr) * \alpha                          |
        |                          |             | op_new.asarray() = op_old.asarray() * \alpha                                |
        |                          |             | op_new.svdvals() = op_old.svdvals() * abs(\alpha)                           |
        |                          |             | op_new.pinv(x, damp) = op_old.pinv(x, damp / (\alpha**2)) / \alpha          |
        |                          |             | op_new.gram() = op_old.gram() * (\alpha**2)                                 |
        |                          |             | op_new.cogram() = op_old.cogram() * (\alpha**2)                             |
        |--------------------------|-------------|-----------------------------------------------------------------------------|
        | LINEAR_SQUARE            | yes         | op_new.trace() = op_old.trace() * \alpha                                    |
        |--------------------------|-------------|-----------------------------------------------------------------------------|
        | LINEAR_NORMAL            | yes         |                                                                             |
        |--------------------------|-------------|-----------------------------------------------------------------------------|
        | LINEAR_UNITARY           | \alpha = -1 |                                                                             |
        |--------------------------|-------------|-----------------------------------------------------------------------------|
        | LINEAR_SELF_ADJOINT      | yes         |                                                                             |
        |--------------------------|-------------|-----------------------------------------------------------------------------|
        | LINEAR_POSITIVE_DEFINITE | \alpha > 0  |                                                                             |
        |--------------------------|-------------|-----------------------------------------------------------------------------|
        | LINEAR_IDEMPOTENT        | no          |                                                                             |
        |--------------------------|-------------|-----------------------------------------------------------------------------|
    """

    def __init__(self, op: pxt.OpT, cst: pxt.Real):
        super().__init__()
        self._op = op
        self._cst = float(cst)

    def op(self) -> pxt.OpT:
        if np.isclose(self._cst, 0):
            # Constant-VECTOR output: modify ConstantValued to work.
            from pyxu.operator import ConstantValued

            def op_apply(_, arr: pxt.NDArray) -> pxt.NDArray:
                xp = pxu.get_array_module(arr)
                arr = xp.zeros_like(arr)
                out = self._op.apply(arr)
                return out

            op = ConstantValued(
                dim_shape=self._op.dim_shape,
                codim_shape=self._op.codim_shape,
                cst=self._cst,
            )
            op.apply = types.MethodType(op_apply, op)
            op._name = "ConstantVector"
        elif np.isclose(self._cst, 1):
            op = self._op
        else:
            klass = self._infer_op_klass()
            op = klass(
                dim_shape=self._op.dim_shape,
                codim_shape=self._op.codim_shape,
            )
            op._op = self._op  # embed for introspection
            op._cst = self._cst  # embed for introspection
            for p in op.properties():
                for name in p.arithmetic_methods():
                    func = getattr(self.__class__, name)
                    setattr(op, name, types.MethodType(func, op))
            self._propagate_constants(op)
        return op

    def _expr(self) -> tuple:
        return ("argscale", self._op, self._cst)

    def _infer_op_klass(self) -> pxt.OpC:
        preserved = {
            pxo.Property.CAN_EVAL,
            pxo.Property.FUNCTIONAL,
            pxo.Property.PROXIMABLE,
            pxo.Property.DIFFERENTIABLE,
            pxo.Property.DIFFERENTIABLE_FUNCTION,
            pxo.Property.LINEAR,
            pxo.Property.LINEAR_SQUARE,
            pxo.Property.LINEAR_NORMAL,
            pxo.Property.LINEAR_SELF_ADJOINT,
            pxo.Property.QUADRATIC,
        }
        if self._cst > 0:
            preserved.add(pxo.Property.LINEAR_POSITIVE_DEFINITE)
        if np.isclose(self._cst, -1):
            preserved.add(pxo.Property.LINEAR_UNITARY)

        properties = self._op.properties() & preserved
        klass = pxo.Operator._infer_operator_type(properties)
        return klass

    def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
        x = arr.copy()
        x *= self._cst
        out = self._op.apply(x)
        return out

    def estimate_lipschitz(self, **kwargs) -> pxt.Real:
        no_eval = "__rule" in kwargs
        if no_eval:
            L = float(self._op.lipschitz)
        else:
            L = self._op.estimate_lipschitz(**kwargs)
        L *= abs(self._cst)
        return L

    def prox(self, arr: pxt.NDArray, tau: pxt.Real) -> pxt.NDArray:
        x = arr.copy()
        x *= self._cst
        y = self._op.prox(x, (self._cst**2) * tau)
        out = pxu.copy_if_unsafe(y)
        out /= self._cst
        return out

    def _quad_spec(self):
        Q1, c1, t1 = self._op._quad_spec()
        Q2 = ScaleRule(op=Q1, cst=self._cst**2).op()
        c2 = ScaleRule(op=c1, cst=self._cst).op()
        t2 = t1
        return (Q2, c2, t2)

    def jacobian(self, arr: pxt.NDArray) -> pxt.OpT:
        if self.has(pxo.Property.LINEAR):
            op = self
        else:
            x = arr.copy()
            x *= self._cst
            op = self._op.jacobian(x) * self._cst
        return op

    def estimate_diff_lipschitz(self, **kwargs) -> pxt.Real:
        no_eval = "__rule" in kwargs
        if no_eval:
            dL = self._op.diff_lipschitz
        else:
            dL = self._op.estimate_diff_lipschitz(**kwargs)
        dL *= self._cst**2
        return dL

    def grad(self, arr: pxt.NDArray) -> pxt.NDArray:
        x = arr.copy()
        x *= self._cst
        out = pxu.copy_if_unsafe(self._op.grad(x))
        out *= self._cst
        return out

    def adjoint(self, arr: pxt.NDArray) -> pxt.NDArray:
        out = pxu.copy_if_unsafe(self._op.adjoint(arr))
        out *= self._cst
        return out

    def asarray(self, **kwargs) -> pxt.NDArray:
        A = pxu.copy_if_unsafe(self._op.asarray(**kwargs))
        A *= self._cst
        return A

    def svdvals(self, **kwargs) -> pxt.NDArray:
        D = pxu.copy_if_unsafe(self._op.svdvals(**kwargs))
        D *= abs(self._cst)
        return D

    def pinv(self, arr: pxt.NDArray, damp: pxt.Real, **kwargs) -> pxt.NDArray:
        scale = damp / (self._cst**2)
        out = pxu.copy_if_unsafe(self._op.pinv(arr, damp=scale, **kwargs))
        out /= self._cst
        return out

    def gram(self) -> pxt.OpT:
        op = self._op.gram() * (self._cst**2)
        return op

    def cogram(self) -> pxt.OpT:
        op = self._op.cogram() * (self._cst**2)
        return op

    def trace(self, **kwargs) -> pxt.Real:
        tr = self._op.trace(**kwargs) * self._cst
        return tr
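
# Example (sketch): `argscale()` is the user-facing entry point of this rule
# (reusing `f` and `x` from the ScaleRule sketch above):
#
#     h = f.argscale(2.0)                   # h(x) = f(2 x)
#     assert np.allclose(h.apply(x), f.apply(2.0 * x))
#     assert np.allclose(h.grad(x), 2.0 * f.grad(2.0 * x))
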
class ArgShiftRule(Rule):
    r"""
    Arithmetic rules for parameter shifting: :math:`B(x) = A(x + c)`.

    Special Cases::

        [NUMPY,CUPY] \shift = 0 => self
        [DASK]       \shift = 0 => rules below apply ...
                     ... because we don't force evaluation of \shift for performance reasons.

    Else::

        |--------------------------|------------|-----------------------------------------------------------------|
        | Property                 | Preserved? | Arithmetic Update Rule(s)                                       |
        |--------------------------|------------|-----------------------------------------------------------------|
        | CAN_EVAL                 | yes        | op_new.apply(arr) = op_old.apply(arr + \shift)                  |
        |                          |            | op_new.lipschitz = op_old.lipschitz                             |
        |--------------------------|------------|-----------------------------------------------------------------|
        | FUNCTIONAL               | yes        |                                                                 |
        |--------------------------|------------|-----------------------------------------------------------------|
        | PROXIMABLE               | yes        | op_new.prox(arr, tau) = op_old.prox(arr + \shift, tau) - \shift |
        |--------------------------|------------|-----------------------------------------------------------------|
        | DIFFERENTIABLE           | yes        | op_new.diff_lipschitz = op_old.diff_lipschitz                   |
        |                          |            | op_new.jacobian(arr) = op_old.jacobian(arr + \shift)            |
        |--------------------------|------------|-----------------------------------------------------------------|
        | DIFFERENTIABLE_FUNCTION  | yes        | op_new.grad(arr) = op_old.grad(arr + \shift)                    |
        |--------------------------|------------|-----------------------------------------------------------------|
        | QUADRATIC                | yes        | Q, c, t = op_old._quad_spec()                                   |
        |                          |            | op_new._quad_spec() = (Q, c + Q @ \shift, op_old.apply(\shift)) |
        |--------------------------|------------|-----------------------------------------------------------------|
        | LINEAR                   | no         |                                                                 |
        |--------------------------|------------|-----------------------------------------------------------------|
        | LINEAR_SQUARE            | no         |                                                                 |
        |--------------------------|------------|-----------------------------------------------------------------|
        | LINEAR_NORMAL            | no         |                                                                 |
        |--------------------------|------------|-----------------------------------------------------------------|
        | LINEAR_UNITARY           | no         |                                                                 |
        |--------------------------|------------|-----------------------------------------------------------------|
        | LINEAR_SELF_ADJOINT      | no         |                                                                 |
        |--------------------------|------------|-----------------------------------------------------------------|
        | LINEAR_POSITIVE_DEFINITE | no         |                                                                 |
        |--------------------------|------------|-----------------------------------------------------------------|
        | LINEAR_IDEMPOTENT        | no         |                                                                 |
        |--------------------------|------------|-----------------------------------------------------------------|
    """

    def __init__(self, op: pxt.OpT, cst: pxt.NDArray):
        super().__init__()
        self._op = op

        xp = pxu.get_array_module(cst)
        try:
            xp.broadcast_to(cst, op.dim_shape)
        except ValueError:
            error_msg = "`cst` must be broadcastable with operator dimensions: "
            error_msg += f"expected broadcastable-to {op.dim_shape}, got {cst.shape}."
            raise ValueError(error_msg)
        self._cst = cst

    def op(self) -> pxt.OpT:
        N = pxd.NDArrayInfo  # short-hand
        ndi = N.from_obj(self._cst)
        if ndi == N.DASK:
            no_op = False
        else:  # NUMPY/CUPY
            xp = ndi.module()
            # Squared norm of the shift: zero iff the shift is identically zero.
            norm = xp.sum(self._cst**2)
            no_op = xp.allclose(norm, 0)

        if no_op:
            op = self._op
        else:
            klass = self._infer_op_klass()
            op = klass(
                dim_shape=self._op.dim_shape,
                codim_shape=self._op.codim_shape,
            )
            op._op = self._op  # embed for introspection
            op._cst = self._cst  # embed for introspection
            for p in op.properties():
                for name in p.arithmetic_methods():
                    func = getattr(self.__class__, name)
                    setattr(op, name, types.MethodType(func, op))
            self._propagate_constants(op)
        return op

    def _expr(self) -> tuple:
        return ("argshift", self._op, self._cst.shape)

    def _infer_op_klass(self) -> pxt.OpC:
        preserved = {
            pxo.Property.CAN_EVAL,
            pxo.Property.FUNCTIONAL,
            pxo.Property.PROXIMABLE,
            pxo.Property.DIFFERENTIABLE,
            pxo.Property.DIFFERENTIABLE_FUNCTION,
            pxo.Property.QUADRATIC,
        }

        properties = self._op.properties() & preserved
        klass = pxo.Operator._infer_operator_type(properties)
        return klass

    def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
        x = arr.copy()
        x += self._cst
        out = self._op.apply(x)
        return out

    def estimate_lipschitz(self, **kwargs) -> pxt.Real:
        no_eval = "__rule" in kwargs
        if no_eval:
            L = self._op.lipschitz
        else:
            L = self._op.estimate_lipschitz(**kwargs)
        return L

    def prox(self, arr: pxt.NDArray, tau: pxt.Real) -> pxt.NDArray:
        x = arr.copy()
        x += self._cst
        out = self._op.prox(x, tau)
        out -= self._cst
        return out

    def _quad_spec(self):
        Q1, c1, t1 = self._op._quad_spec()

        xp = pxu.get_array_module(self._cst)
        cst = xp.broadcast_to(self._cst, self._op.dim_shape)

        Q2 = Q1
        c2 = c1 + pxo.LinFunc.from_array(
            A=Q1.apply(cst)[np.newaxis, ...],
            dim_rank=self._op.dim_rank,
            enable_warnings=False,
            # [enable_warnings] API users have no reason to call _quad_spec().
            # If they choose to use `c2`, then we assume they know what they are doing.
        )
        t2 = float(self._op.apply(cst)[0])

        return (Q2, c2, t2)

    def jacobian(self, arr: pxt.NDArray) -> pxt.OpT:
        x = arr.copy()
        x += self._cst
        op = self._op.jacobian(x)
        return op

    def estimate_diff_lipschitz(self, **kwargs) -> pxt.Real:
        no_eval = "__rule" in kwargs
        if no_eval:
            dL = self._op.diff_lipschitz
        else:
            dL = self._op.estimate_diff_lipschitz(**kwargs)
        return dL

    def grad(self, arr: pxt.NDArray) -> pxt.NDArray:
        x = arr.copy()
        x += self._cst
        out = self._op.grad(x)
        return out
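
# Example (sketch): `argshift()` takes an NDArray shift broadcastable to
# `dim_shape` (reusing `f` and `x` from the sketches above):
#
#     c = np.full(5, 0.5)
#     h = f.argshift(c)                     # h(x) = f(x + c)
#     assert np.allclose(h.apply(x), f.apply(x + c))
#     assert np.allclose(h.prox(x, tau=1.0), f.prox(x + c, tau=1.0) - c)
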
class AddRule(Rule):
    r"""
    Arithmetic rules for operator addition: :math:`C(x) = A(x) + B(x)`.

    The output type of ``AddRule(A, B)`` is summarized in the table below (LHS/RHS commute)::

        |---------------|-----|------|---------|----------|----------|--------------|-----------|---------|--------------|------------|------------|------------|---------------|---------------|------------|---------------|
        | LHS / RHS     | Map | Func | DiffMap | DiffFunc | ProxFunc | ProxDiffFunc | Quadratic | LinOp   | LinFunc      | SquareOp   | NormalOp   | UnitOp     | SelfAdjointOp | PosDefOp      | ProjOp     | OrthProjOp    |
        |---------------|-----|------|---------|----------|----------|--------------|-----------|---------|--------------|------------|------------|------------|---------------|---------------|------------|---------------|
        | Map           | Map | Map  | Map     | Map      | Map      | Map          | Map       | Map     | Map          | Map        | Map        | Map        | Map           | Map           | Map        | Map           |
        | Func          |     | Func | Map     | Func     | Func     | Func         | Func      | Map     | Func         | Map        | Map        | Map        | Map           | Map           | Map        | Map           |
        | DiffMap       |     |      | DiffMap | DiffMap  | Map      | DiffMap      | DiffMap   | DiffMap | DiffMap      | DiffMap    | DiffMap    | DiffMap    | DiffMap       | DiffMap       | DiffMap    | DiffMap       |
        | DiffFunc      |     |      |         | DiffFunc | Func     | DiffFunc     | DiffFunc  | DiffMap | DiffFunc     | DiffMap    | DiffMap    | DiffMap    | DiffMap       | DiffMap       | DiffMap    | DiffMap       |
        | ProxFunc      |     |      |         |          | Func     | Func         | Func      | Map     | ProxFunc     | Map        | Map        | Map        | Map           | Map           | Map        | Map           |
        | ProxDiffFunc  |     |      |         |          |          | DiffFunc     | DiffFunc  | DiffMap | ProxDiffFunc | DiffMap    | DiffMap    | DiffMap    | DiffMap       | DiffMap       | DiffMap    | DiffMap       |
        | Quadratic     |     |      |         |          |          |              | Quadratic | DiffMap | Quadratic    | DiffMap    | DiffMap    | DiffMap    | DiffMap       | DiffMap       | DiffMap    | DiffMap       |
        | LinOp         |     |      |         |          |          |              |           | LinOp   | LinOp        | IMPOSSIBLE | IMPOSSIBLE | IMPOSSIBLE | IMPOSSIBLE    | IMPOSSIBLE    | IMPOSSIBLE | IMPOSSIBLE    |
        | LinFunc       |     |      |         |          |          |              |           |         | LinFunc      | SquareOp   | SquareOp   | SquareOp   | SquareOp      | SquareOp      | SquareOp   | SquareOp      |
        | SquareOp      |     |      |         |          |          |              |           |         |              | SquareOp   | SquareOp   | SquareOp   | SquareOp      | SquareOp      | SquareOp   | SquareOp      |
        | NormalOp      |     |      |         |          |          |              |           |         |              |            | SquareOp   | SquareOp   | SquareOp      | SquareOp      | SquareOp   | SquareOp      |
        | UnitOp        |     |      |         |          |          |              |           |         |              |            |            | SquareOp   | SquareOp      | SquareOp      | SquareOp   | SquareOp      |
        | SelfAdjointOp |     |      |         |          |          |              |           |         |              |            |            |            | SelfAdjointOp | SelfAdjointOp | SquareOp   | SelfAdjointOp |
        | PosDefOp      |     |      |         |          |          |              |           |         |              |            |            |            |               | PosDefOp      | SquareOp   | PosDefOp      |
        | ProjOp        |     |      |         |          |          |              |           |         |              |            |            |            |               |               | SquareOp   | SquareOp      |
        | OrthProjOp    |     |      |         |          |          |              |           |         |              |            |            |            |               |               |            | SelfAdjointOp |
        |---------------|-----|------|---------|----------|----------|--------------|-----------|---------|--------------|------------|------------|------------|---------------|---------------|------------|---------------|

    Arithmetic Update Rule(s)::

        * CAN_EVAL
            op.apply(arr) = _lhs.apply(arr) + _rhs.apply(arr)
            op.lipschitz = _lhs.lipschitz + _rhs.lipschitz
            IMPORTANT: if range-broadcasting takes place (ex: LHS(1,) + RHS(M,)), then the broadcasted
                       operand's Lipschitz constant must be magnified by \sqrt{M}.

        * PROXIMABLE
            op.prox(arr, tau) = _lhs.prox(arr - tau * _rhs.grad(arr), tau)
                           OR = _rhs.prox(arr - tau * _lhs.grad(arr), tau)
            IMPORTANT: .grad() must be taken on whichever of (lhs, rhs) has the LINEAR property.

        * DIFFERENTIABLE
            op.jacobian(arr) = _lhs.jacobian(arr) + _rhs.jacobian(arr)
            op.diff_lipschitz = _lhs.diff_lipschitz + _rhs.diff_lipschitz
            IMPORTANT: if range-broadcasting takes place (ex: LHS(1,) + RHS(M,)), then the broadcasted
                       operand's diff-Lipschitz constant must be magnified by \sqrt{M}.

        * DIFFERENTIABLE_FUNCTION
            op.grad(arr) = _lhs.grad(arr) + _rhs.grad(arr)

        * LINEAR
            op.adjoint(arr) = _lhs.adjoint(arr) + _rhs.adjoint(arr)
            IMPORTANT: if range-broadcasting takes place (ex: LHS(1,) + RHS(M,)), then the broadcasted
                       operand's adjoint-input must be averaged.
            op.asarray() = _lhs.asarray() + _rhs.asarray()
            op.gram() = _lhs.gram() + _rhs.gram() + (_lhs.T * _rhs) + (_rhs.T * _lhs)
            op.cogram() = _lhs.cogram() + _rhs.cogram() + (_lhs * _rhs.T) + (_rhs * _lhs.T)

        * LINEAR_SQUARE
            op.trace() = _lhs.trace() + _rhs.trace()

        * QUADRATIC
            lhs = rhs = quadratic
                Q_l, c_l, t_l = lhs._quad_spec()
                Q_r, c_r, t_r = rhs._quad_spec()
                op._quad_spec() = (Q_l + Q_r, c_l + c_r, t_l + t_r)
            lhs, rhs = quadratic, linear
                Q, c, t = lhs._quad_spec()
                op._quad_spec() = (Q, c + rhs, t)
    """

    def __init__(self, lhs: pxt.OpT, rhs: pxt.OpT):
        assert lhs.dim_shape == rhs.dim_shape, "Operator dimensions are not compatible."
        try:
            codim_bcast = np.broadcast_shapes(lhs.codim_shape, rhs.codim_shape)
        except ValueError:
            error_msg = "`lhs/rhs` codims must be broadcastable: "
            error_msg += f"got {lhs.codim_shape}, {rhs.codim_shape}."
            raise ValueError(error_msg)

        if codim_bcast != lhs.codim_shape:
            from pyxu.operator import BroadcastAxes

            bcast = BroadcastAxes(
                dim_shape=lhs.codim_shape,
                codim_shape=codim_bcast,
            )
            lhs = bcast * lhs
        if codim_bcast != rhs.codim_shape:
            from pyxu.operator import BroadcastAxes

            bcast = BroadcastAxes(
                dim_shape=rhs.codim_shape,
                codim_shape=codim_bcast,
            )
            rhs = bcast * rhs

        super().__init__()
        self._lhs = lhs
        self._rhs = rhs

    def op(self) -> pxt.OpT:
        # LHS/RHS have the same dim/codim following __init__()
        dim_shape = self._rhs.dim_shape
        codim_shape = self._rhs.codim_shape
        klass = self._infer_op_klass(dim_shape, codim_shape)

        if klass.has(pxo.Property.QUADRATIC):
            # Quadratic additive arithmetic differs substantially from other arithmetic operations.
            # To avoid tedious redefinitions of arithmetic methods to handle QuadraticFunc
            # specifically, the code-path below delegates additive arithmetic directly to
            # QuadraticFunc.
            lin = lambda _: _.has(pxo.Property.LINEAR)
            quad = lambda _: _.has(pxo.Property.QUADRATIC)

            if quad(self._lhs) and quad(self._rhs):
                lQ, lc, lt = self._lhs._quad_spec()
                rQ, rc, rt = self._rhs._quad_spec()
                op = klass(
                    dim_shape=dim_shape,
                    codim_shape=1,
                    Q=lQ + rQ,
                    c=lc + rc,
                    t=lt + rt,
                )
            elif quad(self._lhs) and lin(self._rhs):
                lQ, lc, lt = self._lhs._quad_spec()
                op = klass(
                    dim_shape=dim_shape,
                    codim_shape=1,
                    Q=lQ,
                    c=lc + self._rhs,
                    t=lt,
                )
            elif lin(self._lhs) and quad(self._rhs):
                rQ, rc, rt = self._rhs._quad_spec()
                op = klass(
                    dim_shape=dim_shape,
                    codim_shape=1,
                    Q=rQ,
                    c=self._lhs + rc,
                    t=rt,
                )
            else:
                raise ValueError("Impossible scenario: something went wrong during klass inference.")
        else:
            op = klass(
                dim_shape=dim_shape,
                codim_shape=codim_shape,
            )
        op._lhs = self._lhs  # embed for introspection
        op._rhs = self._rhs  # embed for introspection
        for p in op.properties():
            for name in p.arithmetic_methods():
                func = getattr(self.__class__, name)
                setattr(op, name, types.MethodType(func, op))
        self._propagate_constants(op)
        return op

    def _expr(self) -> tuple:
        return ("add", self._lhs, self._rhs)

    def _infer_op_klass(
        self,
        dim_shape: pxt.NDArrayShape,
        codim_shape: pxt.NDArrayShape,
    ) -> pxt.OpC:
        P = pxo.Property
        lhs_p = self._lhs.properties()
        rhs_p = self._rhs.properties()
        base = set(lhs_p & rhs_p)
        base.discard(pxo.Property.LINEAR_NORMAL)
        base.discard(pxo.Property.LINEAR_UNITARY)
        base.discard(pxo.Property.LINEAR_IDEMPOTENT)
        base.discard(pxo.Property.PROXIMABLE)

        # Exceptions ----------------------------------------------------------
        # normality preserved for self-adjoint addition
        if P.LINEAR_SELF_ADJOINT in base:
            base.add(P.LINEAR_NORMAL)

        # orth-proj + pos-def => pos-def
        if (({P.LINEAR_IDEMPOTENT, P.LINEAR_SELF_ADJOINT} < lhs_p) and (P.LINEAR_POSITIVE_DEFINITE in rhs_p)) or (
            ({P.LINEAR_IDEMPOTENT, P.LINEAR_SELF_ADJOINT} < rhs_p) and (P.LINEAR_POSITIVE_DEFINITE in lhs_p)
        ):
            base.add(P.LINEAR_SQUARE)
            base.add(P.LINEAR_NORMAL)
            base.add(P.LINEAR_SELF_ADJOINT)
            base.add(P.LINEAR_POSITIVE_DEFINITE)

        # linfunc + (square-shape) => square
        if P.LINEAR in base:
            dim_size = np.prod(dim_shape)
            codim_size = np.prod(codim_shape)
            if (dim_size == codim_size) and (codim_shape != (1,)):
                base.add(P.LINEAR_SQUARE)

        # quadratic + quadratic => quadratic
        if P.QUADRATIC in base:
            base.add(P.PROXIMABLE)

        # quadratic + linfunc => quadratic
        if (P.PROXIMABLE in (lhs_p & rhs_p)) and ({P.QUADRATIC, P.LINEAR} < (lhs_p | rhs_p)):
            base.add(P.QUADRATIC)

        # prox(-diff) + linfunc => prox(-diff)
        if (P.PROXIMABLE in (lhs_p & rhs_p)) and (P.LINEAR in (lhs_p | rhs_p)):
            base.add(P.PROXIMABLE)

        klass = pxo.Operator._infer_operator_type(base)
        return klass

    def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
        out = pxu.copy_if_unsafe(self._lhs.apply(arr))
        out += self._rhs.apply(arr)
        return out

    def estimate_lipschitz(self, **kwargs) -> pxt.Real:
        no_eval = "__rule" in kwargs
        if no_eval:
            L_lhs = self._lhs.lipschitz
            L_rhs = self._rhs.lipschitz
        elif self.has(pxo.Property.LINEAR):
            L = self.__class__.estimate_lipschitz(self, **kwargs)
            return L
        else:
            L_lhs = self._lhs.estimate_lipschitz(**kwargs)
            L_rhs = self._rhs.estimate_lipschitz(**kwargs)

        L = L_lhs + L_rhs
        return L

    def prox(self, arr: pxt.NDArray, tau: pxt.Real) -> pxt.NDArray:
        P_LHS = self._lhs.properties()
        P_RHS = self._rhs.properties()
        if pxo.Property.LINEAR in (P_LHS | P_RHS):
            # linear + proximable
            if pxo.Property.LINEAR in P_LHS:
                P, G = self._rhs, self._lhs
            elif pxo.Property.LINEAR in P_RHS:
                P, G = self._lhs, self._rhs
            x = pxu.copy_if_unsafe(G.grad(arr))
            x *= -tau
            x += arr
            out = P.prox(x, tau)
        else:
            raise NotImplementedError
        return out

    def jacobian(self, arr: pxt.NDArray) -> pxt.OpT:
        if self.has(pxo.Property.LINEAR):
            op = self
        else:
            op_lhs = self._lhs.jacobian(arr)
            op_rhs = self._rhs.jacobian(arr)
            op = op_lhs + op_rhs
        return op

    def estimate_diff_lipschitz(self, **kwargs) -> pxt.Real:
        no_eval = "__rule" in kwargs
        if no_eval:
            dL_lhs = self._lhs.diff_lipschitz
            dL_rhs = self._rhs.diff_lipschitz
        elif self.has(pxo.Property.LINEAR):
            dL_lhs = 0
            dL_rhs = 0
        else:
            dL_lhs = self._lhs.estimate_diff_lipschitz(**kwargs)
            dL_rhs = self._rhs.estimate_diff_lipschitz(**kwargs)

        dL = dL_lhs + dL_rhs
        return dL

    def grad(self, arr: pxt.NDArray) -> pxt.NDArray:
        out = pxu.copy_if_unsafe(self._lhs.grad(arr))
        out += self._rhs.grad(arr)
        return out

    def adjoint(self, arr: pxt.NDArray) -> pxt.NDArray:
        out = pxu.copy_if_unsafe(self._lhs.adjoint(arr))
        out += self._rhs.adjoint(arr)
        return out

    def asarray(self, **kwargs) -> pxt.NDArray:
        A = pxu.copy_if_unsafe(self._lhs.asarray(**kwargs))
        A += self._rhs.asarray(**kwargs)
        return A

    def gram(self) -> pxt.OpT:
        op1 = self._lhs.gram()
        op2 = self._rhs.gram()
        op3 = self._lhs.T * self._rhs
        op4 = self._rhs.T * self._lhs
        op = op1 + op2 + (op3 + op4).asop(pxo.SelfAdjointOp)
        return op

    def cogram(self) -> pxt.OpT:
        op1 = self._lhs.cogram()
        op2 = self._rhs.cogram()
        op3 = self._lhs * self._rhs.T
        op4 = self._rhs * self._lhs.T
        op = op1 + op2 + (op3 + op4).asop(pxo.SelfAdjointOp)
        return op

    def trace(self, **kwargs) -> pxt.Real:
        tr = 0
        for side in (self._lhs, self._rhs):
            tr += side.trace(**kwargs)
        return float(tr)
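
# Example (sketch): `+` between operators dispatches to AddRule (reusing
# `SquaredL2Norm` and `x` from the sketches above):
#
#     f1 = SquaredL2Norm(dim_shape=(5,))
#     f2 = SquaredL2Norm(dim_shape=(5,)).argshift(np.ones(5))
#     s = f1 + f2                           # AddRule(lhs=f1, rhs=f2).op()
#     assert np.allclose(s.apply(x), f1.apply(x) + f2.apply(x))
#     assert np.allclose(s.grad(x), f1.grad(x) + f2.grad(x))
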
class ChainRule(Rule):
    r"""
    Arithmetic rules for operator composition: :math:`C(x) = (A \circ B)(x)`.

    The output type of ``ChainRule(A, B)`` is summarized in the table below::

        |---------------|------|------------|----------|------------|------------|----------------|----------------------|------------------|------------|-----------|-----------|--------------|---------------|-----------|-----------|------------|
        | LHS / RHS     | Map  | Func       | DiffMap  | DiffFunc   | ProxFunc   | ProxDiffFunc   | Quadratic            | LinOp            | LinFunc    | SquareOp  | NormalOp  | UnitOp       | SelfAdjointOp | PosDefOp  | ProjOp    | OrthProjOp |
        |---------------|------|------------|----------|------------|------------|----------------|----------------------|------------------|------------|-----------|-----------|--------------|---------------|-----------|-----------|------------|
        | Map           | Map  | Map        | Map      | Map        | Map        | Map            | Map                  | Map              | Map        | Map       | Map       | Map          | Map           | Map       | Map       | Map        |
        | Func          | Func | Func       | Func     | Func       | Func       | Func           | Func                 | Func             | Func       | Func      | Func      | Func         | Func          | Func      | Func      | Func       |
        | DiffMap       | Map  | Map        | DiffMap  | DiffMap    | Map        | DiffMap        | DiffMap              | DiffMap          | DiffMap    | DiffMap   | DiffMap   | DiffMap      | DiffMap       | DiffMap   | DiffMap   | DiffMap    |
        | DiffFunc      | Func | Func       | DiffFunc | DiffFunc   | Func       | DiffFunc       | DiffFunc             | DiffFunc         | DiffFunc   | DiffFunc  | DiffFunc  | DiffFunc     | DiffFunc      | DiffFunc  | DiffFunc  | DiffFunc   |
        | ProxFunc      | Func | Func       | Func     | Func       | Func       | Func           | Func                 | Func             | Func       | Func      | Func      | ProxFunc     | Func          | Func      | Func      | Func       |
        | ProxDiffFunc  | Func | Func       | DiffFunc | DiffFunc   | Func       | DiffFunc       | DiffFunc             | DiffFunc         | DiffFunc   | DiffFunc  | DiffFunc  | ProxDiffFunc | DiffFunc      | DiffFunc  | DiffFunc  | DiffFunc   |
        | Quadratic     | Func | Func       | DiffFunc | DiffFunc   | Func       | DiffFunc       | DiffFunc             | Quadratic        | Quadratic  | Quadratic | Quadratic | Quadratic    | Quadratic     | Quadratic | Quadratic | Quadratic  |
        | LinOp         | Map  | Func       | DiffMap  | DiffMap    | Map        | DiffMap        | DiffMap              | LinOp / SquareOp | LinOp      | LinOp     | LinOp     | LinOp        | LinOp         | LinOp     | LinOp     | LinOp      |
        | LinFunc       | Func | Func       | DiffFunc | DiffFunc   | [Prox]Func | [Prox]DiffFunc | DiffFunc / Quadratic | LinFunc          | LinFunc    | LinFunc   | LinFunc   | LinFunc      | LinFunc       | LinFunc   | LinFunc   | LinFunc    |
        | SquareOp      | Map  | IMPOSSIBLE | DiffMap  | IMPOSSIBLE | IMPOSSIBLE | IMPOSSIBLE     | IMPOSSIBLE           | LinOp            | IMPOSSIBLE | SquareOp  | SquareOp  | SquareOp     | SquareOp      | SquareOp  | SquareOp  | SquareOp   |
        | NormalOp      | Map  | IMPOSSIBLE | DiffMap  | IMPOSSIBLE | IMPOSSIBLE | IMPOSSIBLE     | IMPOSSIBLE           | LinOp            | IMPOSSIBLE | SquareOp  | SquareOp  | SquareOp     | SquareOp      | SquareOp  | SquareOp  | SquareOp   |
        | UnitOp        | Map  | IMPOSSIBLE | DiffMap  | IMPOSSIBLE | IMPOSSIBLE | IMPOSSIBLE     | IMPOSSIBLE           | LinOp            | IMPOSSIBLE | SquareOp  | SquareOp  | UnitOp       | SquareOp      | SquareOp  | SquareOp  | SquareOp   |
        | SelfAdjointOp | Map  | IMPOSSIBLE | DiffMap  | IMPOSSIBLE | IMPOSSIBLE | IMPOSSIBLE     | IMPOSSIBLE           | LinOp            | IMPOSSIBLE | SquareOp  | SquareOp  | SquareOp     | SquareOp      | SquareOp  | SquareOp  | SquareOp   |
        | PosDefOp      | Map  | IMPOSSIBLE | DiffMap  | IMPOSSIBLE | IMPOSSIBLE | IMPOSSIBLE     | IMPOSSIBLE           | LinOp            | IMPOSSIBLE | SquareOp  | SquareOp  | SquareOp     | SquareOp      | SquareOp  | SquareOp  | SquareOp   |
        | ProjOp        | Map  | IMPOSSIBLE | DiffMap  | IMPOSSIBLE | IMPOSSIBLE | IMPOSSIBLE     | IMPOSSIBLE           | LinOp            | IMPOSSIBLE | SquareOp  | SquareOp  | SquareOp     | SquareOp      | SquareOp  | SquareOp  | SquareOp   |
        | OrthProjOp    | Map  | IMPOSSIBLE | DiffMap  | IMPOSSIBLE | IMPOSSIBLE | IMPOSSIBLE     | IMPOSSIBLE           | LinOp            | IMPOSSIBLE | SquareOp  | SquareOp  | SquareOp     | SquareOp      | SquareOp  | SquareOp  | SquareOp   |
        |---------------|------|------------|----------|------------|------------|----------------|----------------------|------------------|------------|-----------|-----------|--------------|---------------|-----------|-----------|------------|

    Arithmetic Update Rule(s)::

        * CAN_EVAL
            op.apply(arr) = _lhs.apply(_rhs.apply(arr))
            op.lipschitz = _lhs.lipschitz * _rhs.lipschitz

        * PROXIMABLE (RHS Unitary only)
            op.prox(arr, tau) = _rhs.adjoint(_lhs.prox(_rhs.apply(arr), tau))

        * DIFFERENTIABLE
            op.jacobian(arr) = _lhs.jacobian(_rhs.apply(arr)) * _rhs.jacobian(arr)
            op.diff_lipschitz =
                quadratic           => _quad_spec().Q.lipschitz
                linear \comp linear => 0
                linear \comp diff   => _lhs.lipschitz * _rhs.diff_lipschitz
                diff \comp linear   => _lhs.diff_lipschitz * (_rhs.lipschitz ** 2)
                diff \comp diff     => \infty

        * DIFFERENTIABLE_FUNCTION (1D input)
            op.grad(arr) = _lhs.grad(_rhs.apply(arr)) @ _rhs.jacobian(arr).asarray()

        * LINEAR
            op.adjoint(arr) = _rhs.adjoint(_lhs.adjoint(arr))
            op.asarray() = _lhs.asarray() @ _rhs.asarray()
            op.gram() = _rhs.T @ _lhs.gram() @ _rhs
            op.cogram() = _lhs @ _rhs.cogram() @ _lhs.T

        * QUADRATIC
            Q, c, t = _lhs._quad_spec()
            op._quad_spec() = (_rhs.T * Q * _rhs, _rhs.T * c, t)
    """

    def __init__(self, lhs: pxt.OpT, rhs: pxt.OpT):
        assert lhs.dim_shape == rhs.codim_shape, "Operator dimensions are not compatible."

        super().__init__()
        self._lhs = lhs
        self._rhs = rhs

    def op(self) -> pxt.OpT:
        klass = self._infer_op_klass()
        op = klass(
            dim_shape=self._rhs.dim_shape,
            codim_shape=self._lhs.codim_shape,
        )
        op._lhs = self._lhs  # embed for introspection
        op._rhs = self._rhs  # embed for introspection
        for p in op.properties():
            for name in p.arithmetic_methods():
                func = getattr(self.__class__, name)
                setattr(op, name, types.MethodType(func, op))
        self._propagate_constants(op)
        return op

    def _expr(self) -> tuple:
        return ("compose", self._lhs, self._rhs)

    def _infer_op_klass(self) -> pxt.OpC:
        # |--------------------------|------------------------------------------------------|
        # | Property                 | Preserved?                                           |
        # |--------------------------|------------------------------------------------------|
        # | CAN_EVAL                 | (LHS CAN_EVAL) & (RHS CAN_EVAL)                      |
        # |--------------------------|------------------------------------------------------|
        # | FUNCTIONAL               | LHS FUNCTIONAL                                       |
        # |--------------------------|------------------------------------------------------|
        # | PROXIMABLE               | * (LHS PROXIMABLE) & (RHS LINEAR_UNITARY)            |
        # |                          | * (LHS LINEAR) & (RHS LINEAR)                        |
        # |                          | * (LHS LINEAR FUNCTIONAL [> 0]) & (RHS PROXIMABLE)   |
        # |--------------------------|------------------------------------------------------|
        # | DIFFERENTIABLE           | (LHS DIFFERENTIABLE) & (RHS DIFFERENTIABLE)          |
        # |--------------------------|------------------------------------------------------|
        # | DIFFERENTIABLE_FUNCTION  | (LHS DIFFERENTIABLE_FUNCTION) & (RHS DIFFERENTIABLE) |
        # |--------------------------|------------------------------------------------------|
        # | QUADRATIC                | * (LHS QUADRATIC) & (RHS LINEAR)                     |
        # |                          | * (LHS LINEAR FUNCTIONAL [> 0]) & (RHS QUADRATIC)    |
        # |--------------------------|------------------------------------------------------|
        # | LINEAR                   | (LHS LINEAR) & (RHS LINEAR)                          |
        # |--------------------------|------------------------------------------------------|
        # | LINEAR_SQUARE            | (Shape[LHS * RHS] square) & (LHS.codim > 1)          |
        # |--------------------------|------------------------------------------------------|
        # | LINEAR_NORMAL            | no                                                   |
        # |--------------------------|------------------------------------------------------|
        # | LINEAR_UNITARY           | (LHS LINEAR_UNITARY) & (RHS LINEAR_UNITARY)          |
        # |--------------------------|------------------------------------------------------|
        # | LINEAR_SELF_ADJOINT      | no                                                   |
        # |--------------------------|------------------------------------------------------|
        # | LINEAR_POSITIVE_DEFINITE | no                                                   |
        # |--------------------------|------------------------------------------------------|
        # | LINEAR_IDEMPOTENT        | no                                                   |
        # |--------------------------|------------------------------------------------------|
        lhs_p = self._lhs.properties()
        rhs_p = self._rhs.properties()
        P = pxo.Property
        properties = {P.CAN_EVAL}
        if P.FUNCTIONAL in lhs_p:
            properties.add(P.FUNCTIONAL)
        # Proximal ------------------------------
        if (P.PROXIMABLE in lhs_p) and (P.LINEAR_UNITARY in rhs_p):
            properties.add(P.PROXIMABLE)
        elif ({P.LINEAR, P.FUNCTIONAL} < lhs_p) and (P.PROXIMABLE in rhs_p):
            cst = self._lhs.asarray().item()
            if cst > 0:
                properties.add(P.PROXIMABLE)
                if P.QUADRATIC in rhs_p:
                    properties.add(P.QUADRATIC)
        # ---------------------------------------
        if P.DIFFERENTIABLE in (lhs_p & rhs_p):
            properties.add(P.DIFFERENTIABLE)
        if (P.DIFFERENTIABLE_FUNCTION in lhs_p) and (P.DIFFERENTIABLE in rhs_p):
            properties.add(P.DIFFERENTIABLE_FUNCTION)
        if (P.QUADRATIC in lhs_p) and (P.LINEAR in rhs_p):
            properties.add(P.PROXIMABLE)
            properties.add(P.QUADRATIC)
        if P.LINEAR in (lhs_p & rhs_p):
            properties.add(P.LINEAR)
            if self._lhs.codim_shape == (1,):
                for p in pxo.LinFunc.properties():
                    properties.add(p)
            if (self._lhs.codim_size == self._rhs.dim_size) and (self._rhs.dim_size > 1):
                properties.add(P.LINEAR_SQUARE)
        if P.LINEAR_UNITARY in (lhs_p & rhs_p):
            properties.add(P.LINEAR_NORMAL)
            properties.add(P.LINEAR_UNITARY)

        klass = pxo.Operator._infer_operator_type(properties)
        return klass

    def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
        x = self._rhs.apply(arr)
        out = self._lhs.apply(x)
        return out

    def estimate_lipschitz(self, **kwargs) -> pxt.Real:
        no_eval = "__rule" in kwargs
        if no_eval:
            L_lhs = self._lhs.lipschitz
            L_rhs = self._rhs.lipschitz
        elif self.has(pxo.Property.LINEAR):
            L = self.__class__.estimate_lipschitz(self, **kwargs)
            return L
        else:
            L_lhs = self._lhs.estimate_lipschitz(**kwargs)
            L_rhs = self._rhs.estimate_lipschitz(**kwargs)

        zeroQ = lambda _: np.isclose(_, 0)
        if zeroQ(L_lhs) or zeroQ(L_rhs):
            L = 0
        else:
            L = L_lhs * L_rhs
        return L

    def prox(self, arr: pxt.NDArray, tau: pxt.Real) -> pxt.NDArray:
        if self.has(pxo.Property.PROXIMABLE):
            out = None
            if self._lhs.has(pxo.Property.PROXIMABLE) and self._rhs.has(pxo.Property.LINEAR_UNITARY):
                # prox[diff]func() \comp unitop() => prox[diff]func()
                x = self._rhs.apply(arr)
                y = self._lhs.prox(x, tau)
                out = self._rhs.adjoint(y)
            elif self._lhs.has(pxo.Property.QUADRATIC) and self._rhs.has(pxo.Property.LINEAR):
                # quadratic \comp linop => quadratic
                Q, c, t = self._quad_spec()
                op = pxo.QuadraticFunc(
                    dim_shape=self.dim_shape,
                    codim_shape=self.codim_shape,
                    Q=Q,
                    c=c,
                    t=t,
                )
                out = op.prox(arr, tau)
            elif self._lhs.has(pxo.Property.LINEAR) and self._rhs.has(pxo.Property.PROXIMABLE):
                # linfunc() \comp prox[diff]func() => prox[diff]func()
                #         = (\alpha * prox[diff]func())
                op = ScaleRule(op=self._rhs, cst=self._lhs.asarray().item()).op()
                out = op.prox(arr, tau)
            elif pxo.Property.LINEAR in (self._lhs.properties() & self._rhs.properties()):
                # linfunc() \comp linop() => linfunc()
                out = pxo.LinFunc.prox(self, arr, tau)

            if out is not None:
                return out
        raise NotImplementedError

    def _quad_spec(self):
        if self.has(pxo.Property.QUADRATIC):
            if self._lhs.has(pxo.Property.LINEAR):
                # linfunc (scalar) \comp quadratic
                op = ScaleRule(op=self._rhs, cst=self._lhs.asarray().item()).op()
                Q2, c2, t2 = op._quad_spec()
            elif self._rhs.has(pxo.Property.LINEAR):
                # quadratic \comp linop
                Q1, c1, t1 = self._lhs._quad_spec()
                Q2 = (self._rhs.T * Q1 * self._rhs).asop(pxo.PosDefOp)
                c2 = c1 * self._rhs
                t2 = t1
            return (Q2, c2, t2)
        else:
            raise NotImplementedError

    def jacobian(self, arr: pxt.NDArray) -> pxt.OpT:
        if self.has(pxo.Property.LINEAR):
            op = self
        else:
            J_rhs = self._rhs.jacobian(arr)
            J_lhs = self._lhs.jacobian(self._rhs.apply(arr))
            op = J_lhs * J_rhs
        return op

    def estimate_diff_lipschitz(self, **kwargs) -> pxt.Real:
        no_eval = "__rule" in kwargs
        if self.has(pxo.Property.QUADRATIC):
            Q, c, t = self._quad_spec()
            op = pxo.QuadraticFunc(
                dim_shape=self.dim_shape,
                codim_shape=self.codim_shape,
                Q=Q,
                c=c,
                t=t,
            )
            if no_eval:
                dL = op.diff_lipschitz
            else:
                dL = op.estimate_diff_lipschitz(**kwargs)
        elif self._lhs.has(pxo.Property.LINEAR) and self._rhs.has(pxo.Property.LINEAR):
            dL = 0
        elif self._lhs.has(pxo.Property.LINEAR) and self._rhs.has(pxo.Property.DIFFERENTIABLE):
            if no_eval:
                L_lhs = self._lhs.lipschitz
                dL_rhs = self._rhs.diff_lipschitz
            else:
                L_lhs = self._lhs.estimate_lipschitz(**kwargs)
                dL_rhs = self._rhs.estimate_diff_lipschitz(**kwargs)
            dL = L_lhs * dL_rhs
        elif self._lhs.has(pxo.Property.DIFFERENTIABLE) and self._rhs.has(pxo.Property.LINEAR):
            if no_eval:
                dL_lhs = self._lhs.diff_lipschitz
                L_rhs = self._rhs.lipschitz
            else:
                dL_lhs = self._lhs.estimate_diff_lipschitz(**kwargs)
                L_rhs = self._rhs.estimate_lipschitz(**kwargs)
            dL = dL_lhs * (L_rhs**2)
        else:
            dL = np.inf
        return dL

    def grad(self, arr: pxt.NDArray) -> pxt.NDArray:
        sh = arr.shape[: -self.dim_rank]
        if (len(sh) == 0) or self._rhs.has(pxo.Property.LINEAR):
            x = self._lhs.grad(self._rhs.apply(arr))
            out = self._rhs.jacobian(arr).adjoint(x)

            # RHS.adjoint() may change core-chunks if (codim->dim) changes are involved.
            # This is problematic since grad() should preserve core-chunks by default.
            ndi = pxd.NDArrayInfo.from_obj(arr)
            if ndi == pxd.NDArrayInfo.DASK:
                if out.chunks != arr.chunks:
                    out = out.rechunk(arr.chunks)
        else:
            # We need to evaluate the Jacobian separately per stacked input.

            @pxu.vectorize(
                i="arr",
                dim_shape=self.dim_shape,
                codim_shape=self.dim_shape,
            )
            def f(arr: pxt.NDArray) -> pxt.NDArray:
                x = self._lhs.grad(self._rhs.apply(arr))
                out = self._rhs.jacobian(arr).adjoint(x)

                # RHS.adjoint() may change core-chunks if (codim->dim) changes are involved.
                # This is problematic since grad() should preserve core-chunks by default.
                ndi = pxd.NDArrayInfo.from_obj(arr)
                if ndi == pxd.NDArrayInfo.DASK:
                    if out.chunks != arr.chunks:
                        out = out.rechunk(arr.chunks)
                return out

            out = f(arr)
        return out

    def adjoint(self, arr: pxt.NDArray) -> pxt.NDArray:
        x = self._lhs.adjoint(arr)
        out = self._rhs.adjoint(x)
        return out

    def asarray(self, **kwargs) -> pxt.NDArray:
        A_lhs = self._lhs.asarray(**kwargs)
        A_rhs = self._rhs.asarray(**kwargs)

        xp = pxu.get_array_module(A_lhs)
        A = xp.tensordot(A_lhs, A_rhs, axes=self._lhs.dim_rank)
        return A

    def gram(self) -> pxt.OpT:
        op = self._rhs.T * self._lhs.gram() * self._rhs
        return op.asop(pxo.SelfAdjointOp)

    def cogram(self) -> pxt.OpT:
        op = self._lhs * self._rhs.cogram() * self._lhs.T
        return op.asop(pxo.SelfAdjointOp)
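
# Example (sketch): `*` between two operators means composition and dispatches
# to ChainRule. Assumes `IdentityOp` is available in `pyxu.operator` (reusing
# `f` and `x` from the sketches above):
#
#     from pyxu.operator import IdentityOp
#
#     A = IdentityOp(dim_shape=(5,))
#     h = f * A                             # h(x) = f(A(x))
#     assert np.allclose(h.apply(x), f.apply(A.apply(x)))
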
class TransposeRule(Rule):
    # Not strictly-speaking an arithmetic method, but the logic behind constructing transposed
    # operators is identical to arithmetic methods.
    # LinOp.T() rules are hence summarized here.
    r"""
    Arithmetic rules for :py:class:`~pyxu.abc.LinOp` transposition: :math:`B(x) = A^{T}(x)`.

    Arithmetic Update Rule(s)::

        * CAN_EVAL
            opT.apply(arr) = op.adjoint(arr)
            opT.lipschitz = op.lipschitz

        * PROXIMABLE
            opT.prox(arr, tau) = LinFunc.prox(arr, tau)

        * DIFFERENTIABLE
            opT.jacobian(arr) = opT
            opT.diff_lipschitz = 0

        * DIFFERENTIABLE_FUNCTION
            opT.grad(arr) = LinFunc.grad(arr)

        * LINEAR
            opT.adjoint(arr) = op.apply(arr)
            opT.asarray() = op.asarray().T [block-reorder dim/codim]
            opT.gram() = op.cogram()
            opT.cogram() = op.gram()
            opT.svdvals() = op.svdvals()

        * LINEAR_SQUARE
            opT.trace() = op.trace()
    """

    def __init__(self, op: pxt.OpT):
        super().__init__()
        self._op = op

    def op(self) -> pxt.OpT:
        klass = self._infer_op_klass()
        op = klass(
            dim_shape=self._op.codim_shape,
            codim_shape=self._op.dim_shape,
        )
        op._op = self._op  # embed for introspection
        for p in op.properties():
            for name in p.arithmetic_methods():
                func = getattr(self.__class__, name)
                setattr(op, name, types.MethodType(func, op))
        self._propagate_constants(op)
        return op

    def _expr(self) -> tuple:
        return ("transpose", self._op)

    def _infer_op_klass(self) -> pxt.OpC:
        # |--------------------------------|--------------------------------|
        # | op_klass(codim; dim)           | opT_klass(codim; dim)          |
        # |--------------------------------|--------------------------------|
        # | LINEAR(1; 1)                   | LinFunc(1; 1)                  |
        # | LinFunc(1; M1,...,MD)          | LinOp(M1,...,MD; 1)            |
        # | LinOp(N1,...,ND; 1)            | LinFunc(1; N1,...,ND)          |
        # | op_klass(N1,...,ND; M1,...,MD) | op_klass(M1,...,MD; N1,...,ND) |
        # |--------------------------------|--------------------------------|
        single_dim = self._op.dim_shape == (1,)
        single_codim = self._op.codim_shape == (1,)

        if single_dim and single_codim:
            klass = pxo.LinFunc
        elif single_codim:
            klass = pxo.LinOp
        elif single_dim:
            klass = pxo.LinFunc
        else:
            prop = self._op.properties()
            klass = pxo.Operator._infer_operator_type(prop)
        return klass

    def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
        out = self._op.adjoint(arr)
        return out

    def estimate_lipschitz(self, **kwargs) -> pxt.Real:
        no_eval = "__rule" in kwargs
        if no_eval:
            L = self._op.lipschitz
        else:
            L = self._op.estimate_lipschitz(**kwargs)
        return L

    def prox(self, arr: pxt.NDArray, tau: pxt.Real) -> pxt.NDArray:
        out = pxo.LinFunc.prox(self, arr, tau)
        return out

    def jacobian(self, arr: pxt.NDArray) -> pxt.OpT:
        return self

    def estimate_diff_lipschitz(self, **kwargs) -> pxt.Real:
        return 0

    def grad(self, arr: pxt.NDArray) -> pxt.NDArray:
        out = pxo.LinFunc.grad(self, arr)
        return out

    def adjoint(self, arr: pxt.NDArray) -> pxt.NDArray:
        out = self._op.apply(arr)
        return out

    def asarray(self, **kwargs) -> pxt.NDArray:
        A = self._op.asarray(**kwargs)
        B = A.transpose(
            *range(-self._op.dim_rank, 0),
            *range(self._op.codim_rank),
        )
        return B

    def gram(self) -> pxt.OpT:
        op = self._op.cogram()
        return op

    def cogram(self) -> pxt.OpT:
        op = self._op.gram()
        return op

    def svdvals(self, **kwargs) -> pxt.NDArray:
        D = self._op.svdvals(**kwargs)
        return D

    def trace(self, **kwargs) -> pxt.Real:
        tr = self._op.trace(**kwargs)
        return tr
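
# Example (sketch): transposition of a LinOp dispatches to TransposeRule
# (reusing `A` and `x` from the ChainRule sketch above; `.T` is assumed to be
# the LinOp transposition entry point, as referenced in the docstring):
#
#     At = A.T                              # TransposeRule(op=A).op()
#     assert np.allclose(At.apply(x), A.adjoint(x))
#     assert np.allclose(At.asarray(), A.asarray().T)
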