# Arithmetic Rules

import types

import numpy as np

import pyxu.abc.operator as pxo
import pyxu.info.deps as pxd
import pyxu.info.ptype as pxt
import pyxu.util as pxu

class Rule:
    """
    General arithmetic rule.

    This class defines default arithmetic rules applicable unless re-defined by sub-classes.
    """

    def op(self) -> pxt.OpT:
        """
        Returns
        -------
        op: OpT
            Synthesized operator.
        """
        raise NotImplementedError

    # Helper Methods ----------------------------------------------------------

    @staticmethod
    def _propagate_constants(op: pxt.OpT):
        # Propagate (diff-)Lipschitz constants forward via special call to
        # Rule()-overridden `estimate_[diff_]lipschitz()` methods.

        # Important: write to _[diff_]lipschitz so as not to overwrite the
        # estimate_[diff_]lipschitz() methods themselves.
        if op.has(pxo.Property.CAN_EVAL):
            op._lipschitz = op.estimate_lipschitz(__rule=True)
        if op.has(pxo.Property.DIFFERENTIABLE):
            op._diff_lipschitz = op.estimate_diff_lipschitz(__rule=True)

    # Default Arithmetic Methods ----------------------------------------------
    # Fall back on these when no simple form in terms of Rule.__init__() args is known.
    # If a method from Property.arithmetic_methods() is not listed here, then all Rule
    # subclasses provide an overloaded implementation.

    def __call__(self, arr: pxt.NDArray) -> pxt.NDArray:
        return self.apply(arr)

    def svdvals(self, **kwargs) -> pxt.NDArray:
        D = self.__class__.svdvals(self, **kwargs)
        return D

    def pinv(self, arr: pxt.NDArray, damp: pxt.Real, **kwargs) -> pxt.NDArray:
        out = self.__class__.pinv(self, arr=arr, damp=damp, **kwargs)
        return out

    def trace(self, **kwargs) -> pxt.Real:
        tr = self.__class__.trace(self, **kwargs)
        return tr


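# Example (hedged sketch, not part of the module): every Rule subclass below follows
# the same usage pattern -- instantiate it with its operands, then call .op() to
# synthesize the resulting operator. `A` and `x` are placeholders for a pre-built
# pyxu operator and a compatible NDArray:
#
#   rule = ScaleRule(op=A, cst=2.0)
#   B = rule.op()          # B(x) = 2 * A(x); arithmetic methods are re-bound on B
#   y = B.apply(x)
#
# .op() also propagates (diff-)Lipschitz constants via Rule._propagate_constants().
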
class ScaleRule(Rule):
    r"""
    Arithmetic rules for element-wise scaling: :math:`B(x) = \alpha A(x)`.

    Special Cases::

        \alpha = 0 => NullOp/NullFunc
        \alpha = 1 => self

    Else::

        |--------------------------|-------------|--------------------------------------------------------------------|
        | Property                 | Preserved?  | Arithmetic Update Rule(s)                                          |
        |--------------------------|-------------|--------------------------------------------------------------------|
        | CAN_EVAL                 | yes         | op_new.apply(arr) = op_old.apply(arr) * \alpha                     |
        |                          |             | op_new.lipschitz = op_old.lipschitz * abs(\alpha)                  |
        |--------------------------|-------------|--------------------------------------------------------------------|
        | FUNCTIONAL               | yes         |                                                                    |
        |--------------------------|-------------|--------------------------------------------------------------------|
        | PROXIMABLE               | \alpha > 0  | op_new.prox(arr, tau) = op_old.prox(arr, tau * \alpha)             |
        |--------------------------|-------------|--------------------------------------------------------------------|
        | DIFFERENTIABLE           | yes         | op_new.jacobian(arr) = op_old.jacobian(arr) * \alpha               |
        |                          |             | op_new.diff_lipschitz = op_old.diff_lipschitz * abs(\alpha)        |
        |--------------------------|-------------|--------------------------------------------------------------------|
        | DIFFERENTIABLE_FUNCTION  | yes         | op_new.grad(arr) = op_old.grad(arr) * \alpha                       |
        |--------------------------|-------------|--------------------------------------------------------------------|
        | QUADRATIC                | \alpha > 0  | Q, c, t = op_old._quad_spec()                                      |
        |                          |             | op_new._quad_spec() = (\alpha * Q, \alpha * c, \alpha * t)         |
        |--------------------------|-------------|--------------------------------------------------------------------|
        | LINEAR                   | yes         | op_new.adjoint(arr) = op_old.adjoint(arr) * \alpha                 |
        |                          |             | op_new.asarray() = op_old.asarray() * \alpha                       |
        |                          |             | op_new.svdvals() = op_old.svdvals() * abs(\alpha)                  |
        |                          |             | op_new.pinv(x, damp) = op_old.pinv(x, damp / (\alpha**2)) / \alpha |
        |                          |             | op_new.gram() = op_old.gram() * (\alpha**2)                        |
        |                          |             | op_new.cogram() = op_old.cogram() * (\alpha**2)                    |
        |--------------------------|-------------|--------------------------------------------------------------------|
        | LINEAR_SQUARE            | yes         | op_new.trace() = op_old.trace() * \alpha                           |
        |--------------------------|-------------|--------------------------------------------------------------------|
        | LINEAR_NORMAL            | yes         |                                                                    |
        |--------------------------|-------------|--------------------------------------------------------------------|
        | LINEAR_UNITARY           | \alpha = -1 |                                                                    |
        |--------------------------|-------------|--------------------------------------------------------------------|
        | LINEAR_SELF_ADJOINT      | yes         |                                                                    |
        |--------------------------|-------------|--------------------------------------------------------------------|
        | LINEAR_POSITIVE_DEFINITE | \alpha > 0  |                                                                    |
        |--------------------------|-------------|--------------------------------------------------------------------|
        | LINEAR_IDEMPOTENT        | no          |                                                                    |
        |--------------------------|-------------|--------------------------------------------------------------------|
    """

    def __init__(self, op: pxt.OpT, cst: pxt.Real):
        super().__init__()
        self._op = op
        self._cst = float(cst)

    def op(self) -> pxt.OpT:
        if np.isclose(self._cst, 0):
            from pyxu.operator import NullOp

            op = NullOp(
                dim_shape=self._op.dim_shape,
                codim_shape=self._op.codim_shape,
            )
        elif np.isclose(self._cst, 1):
            op = self._op
        else:
            klass = self._infer_op_klass()
            op = klass(
                dim_shape=self._op.dim_shape,
                codim_shape=self._op.codim_shape,
            )
            op._op = self._op  # embed for introspection
            op._cst = self._cst  # embed for introspection
            for p in op.properties():
                for name in p.arithmetic_methods():
                    func = getattr(self.__class__, name)
                    setattr(op, name, types.MethodType(func, op))
            self._propagate_constants(op)
        return op

    def _expr(self) -> tuple:
        return ("scale", self._op, self._cst)

    def _infer_op_klass(self) -> pxt.OpC:
        preserved = {
            pxo.Property.CAN_EVAL,
            pxo.Property.FUNCTIONAL,
            pxo.Property.DIFFERENTIABLE,
            pxo.Property.DIFFERENTIABLE_FUNCTION,
            pxo.Property.LINEAR,
            pxo.Property.LINEAR_SQUARE,
            pxo.Property.LINEAR_NORMAL,
            pxo.Property.LINEAR_SELF_ADJOINT,
        }
        if self._cst > 0:
            preserved |= {
                pxo.Property.LINEAR_POSITIVE_DEFINITE,
                pxo.Property.QUADRATIC,
                pxo.Property.PROXIMABLE,
            }
        if self._op.has(pxo.Property.LINEAR):
            preserved.add(pxo.Property.PROXIMABLE)
        if np.isclose(self._cst, -1):
            preserved.add(pxo.Property.LINEAR_UNITARY)

        properties = self._op.properties() & preserved
        klass = pxo.Operator._infer_operator_type(properties)
        return klass

    def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
        out = pxu.copy_if_unsafe(self._op.apply(arr))
        out *= self._cst
        return out

    def estimate_lipschitz(self, **kwargs) -> pxt.Real:
        no_eval = "__rule" in kwargs
        if no_eval:
            L = float(self._op.lipschitz)
        else:
            L = self._op.estimate_lipschitz(**kwargs)
        L *= abs(self._cst)
        return L

    def prox(self, arr: pxt.NDArray, tau: pxt.Real) -> pxt.NDArray:
        return self._op.prox(arr, tau * self._cst)

    def _quad_spec(self):
        Q1, c1, t1 = self._op._quad_spec()
        Q2 = ScaleRule(op=Q1, cst=self._cst).op()
        c2 = ScaleRule(op=c1, cst=self._cst).op()
        t2 = t1 * self._cst
        return (Q2, c2, t2)

    def jacobian(self, arr: pxt.NDArray) -> pxt.OpT:
        if self.has(pxo.Property.LINEAR):
            op = self
        else:
            op = self._op.jacobian(arr) * self._cst
        return op

    def estimate_diff_lipschitz(self, **kwargs) -> pxt.Real:
        no_eval = "__rule" in kwargs
        if no_eval:
            dL = float(self._op.diff_lipschitz)
        else:
            dL = self._op.estimate_diff_lipschitz(**kwargs)
        dL *= abs(self._cst)
        return dL

    def grad(self, arr: pxt.NDArray) -> pxt.NDArray:
        out = pxu.copy_if_unsafe(self._op.grad(arr))
        out *= self._cst
        return out

    def adjoint(self, arr: pxt.NDArray) -> pxt.NDArray:
        out = pxu.copy_if_unsafe(self._op.adjoint(arr))
        out *= self._cst
        return out

    def asarray(self, **kwargs) -> pxt.NDArray:
        A = pxu.copy_if_unsafe(self._op.asarray(**kwargs))
        A *= self._cst
        return A

    def svdvals(self, **kwargs) -> pxt.NDArray:
        D = pxu.copy_if_unsafe(self._op.svdvals(**kwargs))
        D *= abs(self._cst)
        return D

    def pinv(self, arr: pxt.NDArray, damp: pxt.Real, **kwargs) -> pxt.NDArray:
        scale = damp / (self._cst**2)
        out = pxu.copy_if_unsafe(self._op.pinv(arr, damp=scale, **kwargs))
        out /= self._cst
        return out

    def gram(self) -> pxt.OpT:
        op = self._op.gram() * (self._cst**2)
        return op

    def cogram(self) -> pxt.OpT:
        op = self._op.cogram() * (self._cst**2)
        return op

    def trace(self, **kwargs) -> pxt.Real:
        tr = self._op.trace(**kwargs) * self._cst
        return tr


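# Example (hedged sketch): for a linear operator, scaling interacts with the full
# linear API as per the table above. `A` and `x` are placeholders for a pre-built
# pyxu LinOp and a compatible NDArray:
#
#   B = ScaleRule(op=A, cst=-3.0).op()
#   B.apply(x)     # == -3 * A.apply(x)
#   B.asarray()    # == -3 * A.asarray()
#   B.gram()       # ==  9 * A.gram()      (Gram scales by alpha**2)
#   B.lipschitz    # ==  3 * A.lipschitz   (scales by abs(alpha))
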
class ArgScaleRule(Rule):
    r"""
    Arithmetic rules for element-wise parameter scaling: :math:`B(x) = A(\alpha x)`.

    Special Cases::

        \alpha = 0 => ConstantValued (w/ potential vector-valued output)
        \alpha = 1 => self

    Else::

        |--------------------------|-------------|-----------------------------------------------------------------------------|
        | Property                 | Preserved?  | Arithmetic Update Rule(s)                                                   |
        |--------------------------|-------------|-----------------------------------------------------------------------------|
        | CAN_EVAL                 | yes         | op_new.apply(arr) = op_old.apply(arr * \alpha)                              |
        |                          |             | op_new.lipschitz = op_old.lipschitz * abs(\alpha)                           |
        |--------------------------|-------------|-----------------------------------------------------------------------------|
        | FUNCTIONAL               | yes         |                                                                             |
        |--------------------------|-------------|-----------------------------------------------------------------------------|
        | PROXIMABLE               | yes         | op_new.prox(arr, tau) = op_old.prox(\alpha * arr, \alpha**2 * tau) / \alpha |
        |--------------------------|-------------|-----------------------------------------------------------------------------|
        | DIFFERENTIABLE           | yes         | op_new.diff_lipschitz = op_old.diff_lipschitz * (\alpha**2)                 |
        |                          |             | op_new.jacobian(arr) = op_old.jacobian(arr * \alpha) * \alpha               |
        |--------------------------|-------------|-----------------------------------------------------------------------------|
        | DIFFERENTIABLE_FUNCTION  | yes         | op_new.grad(arr) = op_old.grad(\alpha * arr) * \alpha                       |
        |--------------------------|-------------|-----------------------------------------------------------------------------|
        | QUADRATIC                | yes         | Q, c, t = op_old._quad_spec()                                               |
        |                          |             | op_new._quad_spec() = (\alpha**2 * Q, \alpha * c, t)                        |
        |--------------------------|-------------|-----------------------------------------------------------------------------|
        | LINEAR                   | yes         | op_new.adjoint(arr) = op_old.adjoint(arr) * \alpha                          |
        |                          |             | op_new.asarray() = op_old.asarray() * \alpha                                |
        |                          |             | op_new.svdvals() = op_old.svdvals() * abs(\alpha)                           |
        |                          |             | op_new.pinv(x, damp) = op_old.pinv(x, damp / (\alpha**2)) / \alpha          |
        |                          |             | op_new.gram() = op_old.gram() * (\alpha**2)                                 |
        |                          |             | op_new.cogram() = op_old.cogram() * (\alpha**2)                             |
        |--------------------------|-------------|-----------------------------------------------------------------------------|
        | LINEAR_SQUARE            | yes         | op_new.trace() = op_old.trace() * \alpha                                    |
        |--------------------------|-------------|-----------------------------------------------------------------------------|
        | LINEAR_NORMAL            | yes         |                                                                             |
        |--------------------------|-------------|-----------------------------------------------------------------------------|
        | LINEAR_UNITARY           | \alpha = -1 |                                                                             |
        |--------------------------|-------------|-----------------------------------------------------------------------------|
        | LINEAR_SELF_ADJOINT      | yes         |                                                                             |
        |--------------------------|-------------|-----------------------------------------------------------------------------|
        | LINEAR_POSITIVE_DEFINITE | \alpha > 0  |                                                                             |
        |--------------------------|-------------|-----------------------------------------------------------------------------|
        | LINEAR_IDEMPOTENT        | no          |                                                                             |
        |--------------------------|-------------|-----------------------------------------------------------------------------|
    """

    def __init__(self, op: pxt.OpT, cst: pxt.Real):
        super().__init__()
        self._op = op
        self._cst = float(cst)

    def op(self) -> pxt.OpT:
        if np.isclose(self._cst, 0):
            # Constant-vector output: modify ConstantValued to handle it.
            from pyxu.operator import ConstantValued

            def op_apply(_, arr: pxt.NDArray) -> pxt.NDArray:
                xp = pxu.get_array_module(arr)
                arr = xp.zeros_like(arr)
                out = self._op.apply(arr)
                return out

            op = ConstantValued(
                dim_shape=self._op.dim_shape,
                codim_shape=self._op.codim_shape,
                cst=self._cst,
            )
            op.apply = types.MethodType(op_apply, op)
            op._name = "ConstantVector"
        elif np.isclose(self._cst, 1):
            op = self._op
        else:
            klass = self._infer_op_klass()
            op = klass(
                dim_shape=self._op.dim_shape,
                codim_shape=self._op.codim_shape,
            )
            op._op = self._op  # embed for introspection
            op._cst = self._cst  # embed for introspection
            for p in op.properties():
                for name in p.arithmetic_methods():
                    func = getattr(self.__class__, name)
                    setattr(op, name, types.MethodType(func, op))
            self._propagate_constants(op)
        return op

    def _expr(self) -> tuple:
        return ("argscale", self._op, self._cst)

    def _infer_op_klass(self) -> pxt.OpC:
        preserved = {
            pxo.Property.CAN_EVAL,
            pxo.Property.FUNCTIONAL,
            pxo.Property.PROXIMABLE,
            pxo.Property.DIFFERENTIABLE,
            pxo.Property.DIFFERENTIABLE_FUNCTION,
            pxo.Property.LINEAR,
            pxo.Property.LINEAR_SQUARE,
            pxo.Property.LINEAR_NORMAL,
            pxo.Property.LINEAR_SELF_ADJOINT,
            pxo.Property.QUADRATIC,
        }
        if self._cst > 0:
            preserved.add(pxo.Property.LINEAR_POSITIVE_DEFINITE)
        if np.isclose(self._cst, -1):
            preserved.add(pxo.Property.LINEAR_UNITARY)

        properties = self._op.properties() & preserved
        klass = pxo.Operator._infer_operator_type(properties)
        return klass

    def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
        x = arr.copy()
        x *= self._cst
        out = self._op.apply(x)
        return out

    def estimate_lipschitz(self, **kwargs) -> pxt.Real:
        no_eval = "__rule" in kwargs
        if no_eval:
            L = float(self._op.lipschitz)
        else:
            L = self._op.estimate_lipschitz(**kwargs)
        L *= abs(self._cst)
        return L

    def prox(self, arr: pxt.NDArray, tau: pxt.Real) -> pxt.NDArray:
        x = arr.copy()
        x *= self._cst
        y = self._op.prox(x, (self._cst**2) * tau)
        out = pxu.copy_if_unsafe(y)
        out /= self._cst
        return out

    def _quad_spec(self):
        Q1, c1, t1 = self._op._quad_spec()
        Q2 = ScaleRule(op=Q1, cst=self._cst**2).op()
        c2 = ScaleRule(op=c1, cst=self._cst).op()
        t2 = t1
        return (Q2, c2, t2)

    def jacobian(self, arr: pxt.NDArray) -> pxt.OpT:
        if self.has(pxo.Property.LINEAR):
            op = self
        else:
            x = arr.copy()
            x *= self._cst
            op = self._op.jacobian(x) * self._cst
        return op

    def estimate_diff_lipschitz(self, **kwargs) -> pxt.Real:
        no_eval = "__rule" in kwargs
        if no_eval:
            dL = self._op.diff_lipschitz
        else:
            dL = self._op.estimate_diff_lipschitz(**kwargs)
        dL *= self._cst**2
        return dL

    def grad(self, arr: pxt.NDArray) -> pxt.NDArray:
        x = arr.copy()
        x *= self._cst
        out = pxu.copy_if_unsafe(self._op.grad(x))
        out *= self._cst
        return out

    def adjoint(self, arr: pxt.NDArray) -> pxt.NDArray:
        out = pxu.copy_if_unsafe(self._op.adjoint(arr))
        out *= self._cst
        return out

    def asarray(self, **kwargs) -> pxt.NDArray:
        A = pxu.copy_if_unsafe(self._op.asarray(**kwargs))
        A *= self._cst
        return A

    def svdvals(self, **kwargs) -> pxt.NDArray:
        D = pxu.copy_if_unsafe(self._op.svdvals(**kwargs))
        D *= abs(self._cst)
        return D

    def pinv(self, arr: pxt.NDArray, damp: pxt.Real, **kwargs) -> pxt.NDArray:
        scale = damp / (self._cst**2)
        out = pxu.copy_if_unsafe(self._op.pinv(arr, damp=scale, **kwargs))
        out /= self._cst
        return out

    def gram(self) -> pxt.OpT:
        op = self._op.gram() * (self._cst**2)
        return op

    def cogram(self) -> pxt.OpT:
        op = self._op.cogram() * (self._cst**2)
        return op

    def trace(self, **kwargs) -> pxt.Real:
        tr = self._op.trace(**kwargs) * self._cst
        return tr


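# Example (hedged sketch): argument scaling composes A with x -> alpha * x. `A` and
# `x` are placeholders for a pre-built pyxu differentiable functional and a
# compatible NDArray:
#
#   B = ArgScaleRule(op=A, cst=0.5).op()
#   B.apply(x)          # == A.apply(0.5 * x)
#   B.grad(x)           # == 0.5 * A.grad(0.5 * x)        (chain rule)
#   B.diff_lipschitz    # == 0.25 * A.diff_lipschitz      (scales by alpha**2)
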
class ArgShiftRule(Rule):
    r"""
    Arithmetic rules for parameter shifting: :math:`B(x) = A(x + c)`.

    Special Cases::

        [NUMPY,CUPY] \shift = 0 => self
        [DASK]       \shift = 0 => rules below apply ...
                     ... because we don't force evaluation of \shift for performance reasons.

    Else::

        |--------------------------|------------|-----------------------------------------------------------------|
        | Property                 | Preserved? | Arithmetic Update Rule(s)                                       |
        |--------------------------|------------|-----------------------------------------------------------------|
        | CAN_EVAL                 | yes        | op_new.apply(arr) = op_old.apply(arr + \shift)                  |
        |                          |            | op_new.lipschitz = op_old.lipschitz                             |
        |--------------------------|------------|-----------------------------------------------------------------|
        | FUNCTIONAL               | yes        |                                                                 |
        |--------------------------|------------|-----------------------------------------------------------------|
        | PROXIMABLE               | yes        | op_new.prox(arr, tau) = op_old.prox(arr + \shift, tau) - \shift |
        |--------------------------|------------|-----------------------------------------------------------------|
        | DIFFERENTIABLE           | yes        | op_new.diff_lipschitz = op_old.diff_lipschitz                   |
        |                          |            | op_new.jacobian(arr) = op_old.jacobian(arr + \shift)            |
        |--------------------------|------------|-----------------------------------------------------------------|
        | DIFFERENTIABLE_FUNCTION  | yes        | op_new.grad(arr) = op_old.grad(arr + \shift)                    |
        |--------------------------|------------|-----------------------------------------------------------------|
        | QUADRATIC                | yes        | Q, c, t = op_old._quad_spec()                                   |
        |                          |            | op_new._quad_spec() = (Q, c + Q @ \shift, op_old.apply(\shift)) |
        |--------------------------|------------|-----------------------------------------------------------------|
        | LINEAR                   | no         |                                                                 |
        |--------------------------|------------|-----------------------------------------------------------------|
        | LINEAR_SQUARE            | no         |                                                                 |
        |--------------------------|------------|-----------------------------------------------------------------|
        | LINEAR_NORMAL            | no         |                                                                 |
        |--------------------------|------------|-----------------------------------------------------------------|
        | LINEAR_UNITARY           | no         |                                                                 |
        |--------------------------|------------|-----------------------------------------------------------------|
        | LINEAR_SELF_ADJOINT      | no         |                                                                 |
        |--------------------------|------------|-----------------------------------------------------------------|
        | LINEAR_POSITIVE_DEFINITE | no         |                                                                 |
        |--------------------------|------------|-----------------------------------------------------------------|
        | LINEAR_IDEMPOTENT        | no         |                                                                 |
        |--------------------------|------------|-----------------------------------------------------------------|
    """

    def __init__(self, op: pxt.OpT, cst: pxt.NDArray):
        super().__init__()
        self._op = op

        xp = pxu.get_array_module(cst)
        try:
            xp.broadcast_to(cst, op.dim_shape)
        except ValueError:
            error_msg = "`cst` must be broadcastable with operator dimensions: "
            error_msg += f"expected broadcastable-to {op.dim_shape}, got {cst.shape}."
            raise ValueError(error_msg)
        self._cst = cst

    def op(self) -> pxt.OpT:
        N = pxd.NDArrayInfo  # short-hand
        ndi = N.from_obj(self._cst)
        if ndi == N.DASK:
            no_op = False
        else:  # NUMPY/CUPY
            xp = ndi.module()
            norm = xp.sum(self._cst**2)  # sum of squares: zero iff the shift is zero
            no_op = xp.allclose(norm, 0)

        if no_op:
            op = self._op
        else:
            klass = self._infer_op_klass()
            op = klass(
                dim_shape=self._op.dim_shape,
                codim_shape=self._op.codim_shape,
            )
            op._op = self._op  # embed for introspection
            op._cst = self._cst  # embed for introspection
            for p in op.properties():
                for name in p.arithmetic_methods():
                    func = getattr(self.__class__, name)
                    setattr(op, name, types.MethodType(func, op))
            self._propagate_constants(op)
        return op

    def _expr(self) -> tuple:
        return ("argshift", self._op, self._cst.shape)

    def _infer_op_klass(self) -> pxt.OpC:
        preserved = {
            pxo.Property.CAN_EVAL,
            pxo.Property.FUNCTIONAL,
            pxo.Property.PROXIMABLE,
            pxo.Property.DIFFERENTIABLE,
            pxo.Property.DIFFERENTIABLE_FUNCTION,
            pxo.Property.QUADRATIC,
        }

        properties = self._op.properties() & preserved
        klass = pxo.Operator._infer_operator_type(properties)
        return klass

    def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
        x = arr.copy()
        x += self._cst
        out = self._op.apply(x)
        return out

    def estimate_lipschitz(self, **kwargs) -> pxt.Real:
        no_eval = "__rule" in kwargs
        if no_eval:
            L = self._op.lipschitz
        else:
            L = self._op.estimate_lipschitz(**kwargs)
        return L

    def prox(self, arr: pxt.NDArray, tau: pxt.Real) -> pxt.NDArray:
        x = arr.copy()
        x += self._cst
        out = self._op.prox(x, tau)
        out -= self._cst
        return out

    def _quad_spec(self):
        Q1, c1, t1 = self._op._quad_spec()

        xp = pxu.get_array_module(self._cst)
        cst = xp.broadcast_to(self._cst, self._op.dim_shape)

        Q2 = Q1
        c2 = c1 + pxo.LinFunc.from_array(
            A=Q1.apply(cst)[np.newaxis, ...],
            dim_rank=self._op.dim_rank,
            enable_warnings=False,
            # [enable_warnings] API users have no reason to call _quad_spec().
            # If they choose to use `c2`, then we assume they know what they are doing.
        )
        t2 = float(self._op.apply(cst)[0])

        return (Q2, c2, t2)

    def jacobian(self, arr: pxt.NDArray) -> pxt.OpT:
        x = arr.copy()
        x += self._cst
        op = self._op.jacobian(x)
        return op

    def estimate_diff_lipschitz(self, **kwargs) -> pxt.Real:
        no_eval = "__rule" in kwargs
        if no_eval:
            dL = self._op.diff_lipschitz
        else:
            dL = self._op.estimate_diff_lipschitz(**kwargs)
        return dL

    def grad(self, arr: pxt.NDArray) -> pxt.NDArray:
        x = arr.copy()
        x += self._cst
        out = self._op.grad(x)
        return out


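# Example (hedged sketch): argument shifting composes A with x -> x + c. `A` is a
# placeholder for a pre-built pyxu proximable functional, `x`/`c` for compatible
# NDArrays:
#
#   B = ArgShiftRule(op=A, cst=c).op()
#   B.apply(x)        # == A.apply(x + c)
#   B.prox(x, tau)    # == A.prox(x + c, tau) - c
#   B.lipschitz       # == A.lipschitz  (a shift preserves Lipschitz constants)
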
class AddRule(Rule):
    r"""
    Arithmetic rules for operator addition: :math:`C(x) = A(x) + B(x)`.

    The output type of ``AddRule(A, B)`` is summarized in the table below (LHS/RHS commute)::

        |---------------|-----|------|---------|----------|----------|--------------|-----------|---------|--------------|------------|------------|------------|---------------|---------------|------------|---------------|
        | LHS / RHS     | Map | Func | DiffMap | DiffFunc | ProxFunc | ProxDiffFunc | Quadratic | LinOp   | LinFunc      | SquareOp   | NormalOp   | UnitOp     | SelfAdjointOp | PosDefOp      | ProjOp     | OrthProjOp    |
        |---------------|-----|------|---------|----------|----------|--------------|-----------|---------|--------------|------------|------------|------------|---------------|---------------|------------|---------------|
        | Map           | Map | Map  | Map     | Map      | Map      | Map          | Map       | Map     | Map          | Map        | Map        | Map        | Map           | Map           | Map        | Map           |
        | Func          |     | Func | Map     | Func     | Func     | Func         | Func      | Map     | Func         | Map        | Map        | Map        | Map           | Map           | Map        | Map           |
        | DiffMap       |     |      | DiffMap | DiffMap  | Map      | DiffMap      | DiffMap   | DiffMap | DiffMap      | DiffMap    | DiffMap    | DiffMap    | DiffMap       | DiffMap       | DiffMap    | DiffMap       |
        | DiffFunc      |     |      |         | DiffFunc | Func     | DiffFunc     | DiffFunc  | DiffMap | DiffFunc     | DiffMap    | DiffMap    | DiffMap    | DiffMap       | DiffMap       | DiffMap    | DiffMap       |
        | ProxFunc      |     |      |         |          | Func     | Func         | Func      | Map     | ProxFunc     | Map        | Map        | Map        | Map           | Map           | Map        | Map           |
        | ProxDiffFunc  |     |      |         |          |          | DiffFunc     | DiffFunc  | DiffMap | ProxDiffFunc | DiffMap    | DiffMap    | DiffMap    | DiffMap       | DiffMap       | DiffMap    | DiffMap       |
        | Quadratic     |     |      |         |          |          |              | Quadratic | DiffMap | Quadratic    | DiffMap    | DiffMap    | DiffMap    | DiffMap       | DiffMap       | DiffMap    | DiffMap       |
        | LinOp         |     |      |         |          |          |              |           | LinOp   | LinOp        | IMPOSSIBLE | IMPOSSIBLE | IMPOSSIBLE | IMPOSSIBLE    | IMPOSSIBLE    | IMPOSSIBLE | IMPOSSIBLE    |
        | LinFunc       |     |      |         |          |          |              |           |         | LinFunc      | SquareOp   | SquareOp   | SquareOp   | SquareOp      | SquareOp      | SquareOp   | SquareOp      |
        | SquareOp      |     |      |         |          |          |              |           |         |              | SquareOp   | SquareOp   | SquareOp   | SquareOp      | SquareOp      | SquareOp   | SquareOp      |
        | NormalOp      |     |      |         |          |          |              |           |         |              |            | SquareOp   | SquareOp   | SquareOp      | SquareOp      | SquareOp   | SquareOp      |
        | UnitOp        |     |      |         |          |          |              |           |         |              |            |            | SquareOp   | SquareOp      | SquareOp      | SquareOp   | SquareOp      |
        | SelfAdjointOp |     |      |         |          |          |              |           |         |              |            |            |            | SelfAdjointOp | SelfAdjointOp | SquareOp   | SelfAdjointOp |
        | PosDefOp      |     |      |         |          |          |              |           |         |              |            |            |            |               | PosDefOp      | SquareOp   | PosDefOp      |
        | ProjOp        |     |      |         |          |          |              |           |         |              |            |            |            |               |               | SquareOp   | SquareOp      |
        | OrthProjOp    |     |      |         |          |          |              |           |         |              |            |            |            |               |               |            | SelfAdjointOp |
        |---------------|-----|------|---------|----------|----------|--------------|-----------|---------|--------------|------------|------------|------------|---------------|---------------|------------|---------------|

    Arithmetic Update Rule(s)::

        * CAN_EVAL
            op.apply(arr) = _lhs.apply(arr) + _rhs.apply(arr)
            op.lipschitz = _lhs.lipschitz + _rhs.lipschitz
            IMPORTANT: if range-broadcasting takes place (ex: LHS(1,) + RHS(M,)), then the
            broadcasted operand's Lipschitz constant must be magnified by \sqrt{M}.

        * PROXIMABLE
            op.prox(arr, tau) = _lhs.prox(arr - tau * _rhs.grad(arr), tau)
                           OR = _rhs.prox(arr - tau * _lhs.grad(arr), tau)
            IMPORTANT: .grad() must be called on whichever of (lhs, rhs) has the LINEAR
            property.

        * DIFFERENTIABLE
            op.jacobian(arr) = _lhs.jacobian(arr) + _rhs.jacobian(arr)
            op.diff_lipschitz = _lhs.diff_lipschitz + _rhs.diff_lipschitz
            IMPORTANT: if range-broadcasting takes place (ex: LHS(1,) + RHS(M,)), then the
            broadcasted operand's diff-Lipschitz constant must be magnified by \sqrt{M}.

        * DIFFERENTIABLE_FUNCTION
            op.grad(arr) = _lhs.grad(arr) + _rhs.grad(arr)

        * LINEAR
            op.adjoint(arr) = _lhs.adjoint(arr) + _rhs.adjoint(arr)
            IMPORTANT: if range-broadcasting takes place (ex: LHS(1,) + RHS(M,)), then the
            broadcasted operand's adjoint-input must be averaged.
            op.asarray() = _lhs.asarray() + _rhs.asarray()
            op.gram() = _lhs.gram() + _rhs.gram() + (_lhs.T * _rhs) + (_rhs.T * _lhs)
            op.cogram() = _lhs.cogram() + _rhs.cogram() + (_lhs * _rhs.T) + (_rhs * _lhs.T)

        * LINEAR_SQUARE
            op.trace() = _lhs.trace() + _rhs.trace()

        * QUADRATIC
            lhs = rhs = quadratic
                Q_l, c_l, t_l = lhs._quad_spec()
                Q_r, c_r, t_r = rhs._quad_spec()
                op._quad_spec() = (Q_l + Q_r, c_l + c_r, t_l + t_r)
            lhs, rhs = quadratic, linear
                Q, c, t = lhs._quad_spec()
                op._quad_spec() = (Q, c + rhs, t)
    """

    def __init__(self, lhs: pxt.OpT, rhs: pxt.OpT):
        assert lhs.dim_shape == rhs.dim_shape, "Operator dimensions are not compatible."
        try:
            codim_bcast = np.broadcast_shapes(lhs.codim_shape, rhs.codim_shape)
        except ValueError:
            error_msg = "`lhs/rhs` codims must be broadcastable: "
            error_msg += f"got {lhs.codim_shape}, {rhs.codim_shape}."
            raise ValueError(error_msg)

        if codim_bcast != lhs.codim_shape:
            from pyxu.operator import BroadcastAxes

            bcast = BroadcastAxes(
                dim_shape=lhs.codim_shape,
                codim_shape=codim_bcast,
            )
            lhs = bcast * lhs
        if codim_bcast != rhs.codim_shape:
            from pyxu.operator import BroadcastAxes

            bcast = BroadcastAxes(
                dim_shape=rhs.codim_shape,
                codim_shape=codim_bcast,
            )
            rhs = bcast * rhs

        super().__init__()
        self._lhs = lhs
        self._rhs = rhs

    def op(self) -> pxt.OpT:
        # LHS/RHS have same dim/codim following __init__()
        dim_shape = self._rhs.dim_shape
        codim_shape = self._rhs.codim_shape
        klass = self._infer_op_klass(dim_shape, codim_shape)

        if klass.has(pxo.Property.QUADRATIC):
            # Quadratic additive arithmetic differs substantially from other arithmetic
            # operations. To avoid tedious redefinitions of arithmetic methods to handle
            # QuadraticFunc specifically, the code-path below delegates additive arithmetic
            # directly to QuadraticFunc.
            lin = lambda _: _.has(pxo.Property.LINEAR)
            quad = lambda _: _.has(pxo.Property.QUADRATIC)

            if quad(self._lhs) and quad(self._rhs):
                lQ, lc, lt = self._lhs._quad_spec()
                rQ, rc, rt = self._rhs._quad_spec()
                op = klass(
                    dim_shape=dim_shape,
                    codim_shape=1,
                    Q=lQ + rQ,
                    c=lc + rc,
                    t=lt + rt,
                )
            elif quad(self._lhs) and lin(self._rhs):
                lQ, lc, lt = self._lhs._quad_spec()
                op = klass(
                    dim_shape=dim_shape,
                    codim_shape=1,
                    Q=lQ,
                    c=lc + self._rhs,
                    t=lt,
                )
            elif lin(self._lhs) and quad(self._rhs):
                rQ, rc, rt = self._rhs._quad_spec()
                op = klass(
                    dim_shape=dim_shape,
                    codim_shape=1,
                    Q=rQ,
                    c=self._lhs + rc,
                    t=rt,
                )
            else:
                raise ValueError("Impossible scenario: something went wrong during klass inference.")
        else:
            op = klass(
                dim_shape=dim_shape,
                codim_shape=codim_shape,
            )
        op._lhs = self._lhs  # embed for introspection
        op._rhs = self._rhs  # embed for introspection
        for p in op.properties():
            for name in p.arithmetic_methods():
                func = getattr(self.__class__, name)
                setattr(op, name, types.MethodType(func, op))
        self._propagate_constants(op)
        return op

    def _expr(self) -> tuple:
        return ("add", self._lhs, self._rhs)

    def _infer_op_klass(
        self,
        dim_shape: pxt.NDArrayShape,
        codim_shape: pxt.NDArrayShape,
    ) -> pxt.OpC:
        P = pxo.Property
        lhs_p = self._lhs.properties()
        rhs_p = self._rhs.properties()
        base = set(lhs_p & rhs_p)
        base.discard(pxo.Property.LINEAR_NORMAL)
        base.discard(pxo.Property.LINEAR_UNITARY)
        base.discard(pxo.Property.LINEAR_IDEMPOTENT)
        base.discard(pxo.Property.PROXIMABLE)

        # Exceptions ----------------------------------------------------------
        # normality preserved for self-adjoint addition
        if P.LINEAR_SELF_ADJOINT in base:
            base.add(P.LINEAR_NORMAL)

        # orth-proj + pos-def => pos-def
        if (({P.LINEAR_IDEMPOTENT, P.LINEAR_SELF_ADJOINT} < lhs_p) and (P.LINEAR_POSITIVE_DEFINITE in rhs_p)) or (
            ({P.LINEAR_IDEMPOTENT, P.LINEAR_SELF_ADJOINT} < rhs_p) and (P.LINEAR_POSITIVE_DEFINITE in lhs_p)
        ):
            base.add(P.LINEAR_SQUARE)
            base.add(P.LINEAR_NORMAL)
            base.add(P.LINEAR_SELF_ADJOINT)
            base.add(P.LINEAR_POSITIVE_DEFINITE)

        # linfunc + (square-shape) => square
        if P.LINEAR in base:
            dim_size = np.prod(dim_shape)
            codim_size = np.prod(codim_shape)
            if (dim_size == codim_size) and (codim_shape != (1,)):
                base.add(P.LINEAR_SQUARE)

        # quadratic + quadratic => quadratic
        if P.QUADRATIC in base:
            base.add(P.PROXIMABLE)

        # quadratic + linfunc => quadratic
        if (P.PROXIMABLE in (lhs_p & rhs_p)) and ({P.QUADRATIC, P.LINEAR} < (lhs_p | rhs_p)):
            base.add(P.QUADRATIC)

        # prox(-diff) + linfunc => prox(-diff)
        if (P.PROXIMABLE in (lhs_p & rhs_p)) and (P.LINEAR in (lhs_p | rhs_p)):
            base.add(P.PROXIMABLE)

        klass = pxo.Operator._infer_operator_type(base)
        return klass

    def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
        out = pxu.copy_if_unsafe(self._lhs.apply(arr))
        out += self._rhs.apply(arr)
        return out

    def estimate_lipschitz(self, **kwargs) -> pxt.Real:
        no_eval = "__rule" in kwargs
        if no_eval:
            L_lhs = self._lhs.lipschitz
            L_rhs = self._rhs.lipschitz
        elif self.has(pxo.Property.LINEAR):
            L = self.__class__.estimate_lipschitz(self, **kwargs)
            return L
        else:
            L_lhs = self._lhs.estimate_lipschitz(**kwargs)
            L_rhs = self._rhs.estimate_lipschitz(**kwargs)

        L = L_lhs + L_rhs
        return L

    def prox(self, arr: pxt.NDArray, tau: pxt.Real) -> pxt.NDArray:
        P_LHS = self._lhs.properties()
        P_RHS = self._rhs.properties()
        if pxo.Property.LINEAR in (P_LHS | P_RHS):
            # linear + proximable
            if pxo.Property.LINEAR in P_LHS:
                P, G = self._rhs, self._lhs
            elif pxo.Property.LINEAR in P_RHS:
                P, G = self._lhs, self._rhs
            x = pxu.copy_if_unsafe(G.grad(arr))
            x *= -tau
            x += arr
            out = P.prox(x, tau)
        else:
            raise NotImplementedError
        return out

    def jacobian(self, arr: pxt.NDArray) -> pxt.OpT:
        if self.has(pxo.Property.LINEAR):
            op = self
        else:
            op_lhs = self._lhs.jacobian(arr)
            op_rhs = self._rhs.jacobian(arr)
            op = op_lhs + op_rhs
        return op

    def estimate_diff_lipschitz(self, **kwargs) -> pxt.Real:
        no_eval = "__rule" in kwargs
        if no_eval:
            dL_lhs = self._lhs.diff_lipschitz
            dL_rhs = self._rhs.diff_lipschitz
        elif self.has(pxo.Property.LINEAR):
            dL_lhs = 0
            dL_rhs = 0
        else:
            dL_lhs = self._lhs.estimate_diff_lipschitz(**kwargs)
            dL_rhs = self._rhs.estimate_diff_lipschitz(**kwargs)

        dL = dL_lhs + dL_rhs
        return dL

    def grad(self, arr: pxt.NDArray) -> pxt.NDArray:
        out = self._lhs.grad(arr)
        out = pxu.copy_if_unsafe(out)
        out += self._rhs.grad(arr)
        return out

    def adjoint(self, arr: pxt.NDArray) -> pxt.NDArray:
        out = self._lhs.adjoint(arr)
        out = pxu.copy_if_unsafe(out)
        out += self._rhs.adjoint(arr)
        return out

    def asarray(self, **kwargs) -> pxt.NDArray:
        A = self._lhs.asarray(**kwargs)
        A = pxu.copy_if_unsafe(A)
        A += self._rhs.asarray(**kwargs)
        return A

    def gram(self) -> pxt.OpT:
        op1 = self._lhs.gram()
        op2 = self._rhs.gram()
        op3 = self._lhs.T * self._rhs
        op4 = self._rhs.T * self._lhs
        op = op1 + op2 + (op3 + op4).asop(pxo.SelfAdjointOp)
        return op

    def cogram(self) -> pxt.OpT:
        op1 = self._lhs.cogram()
        op2 = self._rhs.cogram()
        op3 = self._lhs * self._rhs.T
        op4 = self._rhs * self._lhs.T
        op = op1 + op2 + (op3 + op4).asop(pxo.SelfAdjointOp)
        return op

    def trace(self, **kwargs) -> pxt.Real:
        tr = 0
        for side in (self._lhs, self._rhs):
            tr += side.trace(**kwargs)
        return float(tr)


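# Example (hedged sketch): adding a differentiable proximable functional to a linear
# functional keeps the proximal property. `F` (ProxDiffFunc) and `L` (LinFunc) are
# placeholders for pre-built pyxu operators, `x` for a compatible NDArray:
#
#   S = AddRule(lhs=F, rhs=L).op()
#   S.apply(x)        # == F.apply(x) + L.apply(x)
#   S.prox(x, tau)    # == F.prox(x - tau * L.grad(x), tau)
#                     #    (.grad() is called on the LINEAR operand, per the rule above)
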
class ChainRule(Rule):
    r"""
    Arithmetic rules for operator composition: :math:`C(x) = (A \circ B)(x)`.

    The output type of ``ChainRule(A, B)`` is summarized in the table below::

        |---------------|------|------------|----------|------------|------------|----------------|----------------------|------------------|------------|-----------|-----------|--------------|---------------|-----------|-----------|------------|
        | LHS / RHS     | Map  | Func       | DiffMap  | DiffFunc   | ProxFunc   | ProxDiffFunc   | Quadratic            | LinOp            | LinFunc    | SquareOp  | NormalOp  | UnitOp       | SelfAdjointOp | PosDefOp  | ProjOp    | OrthProjOp |
        |---------------|------|------------|----------|------------|------------|----------------|----------------------|------------------|------------|-----------|-----------|--------------|---------------|-----------|-----------|------------|
        | Map           | Map  | Map        | Map      | Map        | Map        | Map            | Map                  | Map              | Map        | Map       | Map       | Map          | Map           | Map       | Map       | Map        |
        | Func          | Func | Func       | Func     | Func       | Func       | Func           | Func                 | Func             | Func       | Func      | Func      | Func         | Func          | Func      | Func      | Func       |
        | DiffMap       | Map  | Map        | DiffMap  | DiffMap    | Map        | DiffMap        | DiffMap              | DiffMap          | DiffMap    | DiffMap   | DiffMap   | DiffMap      | DiffMap       | DiffMap   | DiffMap   | DiffMap    |
        | DiffFunc      | Func | Func       | DiffFunc | DiffFunc   | Func       | DiffFunc       | DiffFunc             | DiffFunc         | DiffFunc   | DiffFunc  | DiffFunc  | DiffFunc     | DiffFunc      | DiffFunc  | DiffFunc  | DiffFunc   |
        | ProxFunc      | Func | Func       | Func     | Func       | Func       | Func           | Func                 | Func             | Func       | Func      | Func      | ProxFunc     | Func          | Func      | Func      | Func       |
        | ProxDiffFunc  | Func | Func       | DiffFunc | DiffFunc   | Func       | DiffFunc       | DiffFunc             | DiffFunc         | DiffFunc   | DiffFunc  | DiffFunc  | ProxDiffFunc | DiffFunc      | DiffFunc  | DiffFunc  | DiffFunc   |
        | Quadratic     | Func | Func       | DiffFunc | DiffFunc   | Func       | DiffFunc       | DiffFunc             | Quadratic        | Quadratic  | Quadratic | Quadratic | Quadratic    | Quadratic     | Quadratic | Quadratic | Quadratic  |
        | LinOp         | Map  | Func       | DiffMap  | DiffMap    | Map        | DiffMap        | DiffMap              | LinOp / SquareOp | LinOp      | LinOp     | LinOp     | LinOp        | LinOp         | LinOp     | LinOp     | LinOp      |
        | LinFunc       | Func | Func       | DiffFunc | DiffFunc   | [Prox]Func | [Prox]DiffFunc | DiffFunc / Quadratic | LinFunc          | LinFunc    | LinFunc   | LinFunc   | LinFunc      | LinFunc       | LinFunc   | LinFunc   | LinFunc    |
        | SquareOp      | Map  | IMPOSSIBLE | DiffMap  | IMPOSSIBLE | IMPOSSIBLE | IMPOSSIBLE     | IMPOSSIBLE           | LinOp            | IMPOSSIBLE | SquareOp  | SquareOp  | SquareOp     | SquareOp      | SquareOp  | SquareOp  | SquareOp   |
        | NormalOp      | Map  | IMPOSSIBLE | DiffMap  | IMPOSSIBLE | IMPOSSIBLE | IMPOSSIBLE     | IMPOSSIBLE           | LinOp            | IMPOSSIBLE | SquareOp  | SquareOp  | SquareOp     | SquareOp      | SquareOp  | SquareOp  | SquareOp   |
        | UnitOp        | Map  | IMPOSSIBLE | DiffMap  | IMPOSSIBLE | IMPOSSIBLE | IMPOSSIBLE     | IMPOSSIBLE           | LinOp            | IMPOSSIBLE | SquareOp  | SquareOp  | UnitOp       | SquareOp      | SquareOp  | SquareOp  | SquareOp   |
        | SelfAdjointOp | Map  | IMPOSSIBLE | DiffMap  | IMPOSSIBLE | IMPOSSIBLE | IMPOSSIBLE     | IMPOSSIBLE           | LinOp            | IMPOSSIBLE | SquareOp  | SquareOp  | SquareOp     | SquareOp      | SquareOp  | SquareOp  | SquareOp   |
        | PosDefOp      | Map  | IMPOSSIBLE | DiffMap  | IMPOSSIBLE | IMPOSSIBLE | IMPOSSIBLE     | IMPOSSIBLE           | LinOp            | IMPOSSIBLE | SquareOp  | SquareOp  | SquareOp     | SquareOp      | SquareOp  | SquareOp  | SquareOp   |
        | ProjOp        | Map  | IMPOSSIBLE | DiffMap  | IMPOSSIBLE | IMPOSSIBLE | IMPOSSIBLE     | IMPOSSIBLE           | LinOp            | IMPOSSIBLE | SquareOp  | SquareOp  | SquareOp     | SquareOp      | SquareOp  | SquareOp  | SquareOp   |
        | OrthProjOp    | Map  | IMPOSSIBLE | DiffMap  | IMPOSSIBLE | IMPOSSIBLE | IMPOSSIBLE     | IMPOSSIBLE           | LinOp            | IMPOSSIBLE | SquareOp  | SquareOp  | SquareOp     | SquareOp      | SquareOp  | SquareOp  | SquareOp   |
        |---------------|------|------------|----------|------------|------------|----------------|----------------------|------------------|------------|-----------|-----------|--------------|---------------|-----------|-----------|------------|

    Arithmetic Update Rule(s)::

        * CAN_EVAL
            op.apply(arr) = _lhs.apply(_rhs.apply(arr))
            op.lipschitz = _lhs.lipschitz * _rhs.lipschitz

        * PROXIMABLE (RHS Unitary only)
            op.prox(arr, tau) = _rhs.adjoint(_lhs.prox(_rhs.apply(arr), tau))

        * DIFFERENTIABLE
            op.jacobian(arr) = _lhs.jacobian(_rhs.apply(arr)) * _rhs.jacobian(arr)
            op.diff_lipschitz =
                quadratic            => _quad_spec().Q.lipschitz
                linear \comp linear  => 0
                linear \comp diff    => _lhs.lipschitz * _rhs.diff_lipschitz
                diff \comp linear    => _lhs.diff_lipschitz * (_rhs.lipschitz ** 2)
                diff \comp diff      => \infty

        * DIFFERENTIABLE_FUNCTION (1D input)
            op.grad(arr) = _lhs.grad(_rhs.apply(arr)) @ _rhs.jacobian(arr).asarray()

        * LINEAR
            op.adjoint(arr) = _rhs.adjoint(_lhs.adjoint(arr))
            op.asarray() = _lhs.asarray() @ _rhs.asarray()
            op.gram() = _rhs.T @ _lhs.gram() @ _rhs
            op.cogram() = _lhs @ _rhs.cogram() @ _lhs.T

        * QUADRATIC
            Q, c, t = _lhs._quad_spec()
            op._quad_spec() = (_rhs.T * Q * _rhs, _rhs.T * c, t)
    """

    def __init__(self, lhs: pxt.OpT, rhs: pxt.OpT):
        assert lhs.dim_shape == rhs.codim_shape, "Operator dimensions are not compatible."

        super().__init__()
        self._lhs = lhs
        self._rhs = rhs

    def op(self) -> pxt.OpT:
        klass = self._infer_op_klass()
        op = klass(
            dim_shape=self._rhs.dim_shape,
            codim_shape=self._lhs.codim_shape,
        )
        op._lhs = self._lhs  # embed for introspection
        op._rhs = self._rhs  # embed for introspection
        for p in op.properties():
            for name in p.arithmetic_methods():
                func = getattr(self.__class__, name)
                setattr(op, name, types.MethodType(func, op))
        self._propagate_constants(op)
        return op

    def _expr(self) -> tuple:
        return ("compose", self._lhs, self._rhs)

    def _infer_op_klass(self) -> pxt.OpC:
        # |--------------------------|------------------------------------------------------|
        # | Property                 | Preserved?                                           |
        # |--------------------------|------------------------------------------------------|
        # | CAN_EVAL                 | (LHS CAN_EVAL) & (RHS CAN_EVAL)                      |
        # |--------------------------|------------------------------------------------------|
        # | FUNCTIONAL               | LHS FUNCTIONAL                                       |
        # |--------------------------|------------------------------------------------------|
        # | PROXIMABLE               | * (LHS PROXIMABLE) & (RHS LINEAR_UNITARY)            |
        # |                          | * (LHS LINEAR) & (RHS LINEAR)                        |
        # |                          | * (LHS LINEAR FUNCTIONAL [> 0]) & (RHS PROXIMABLE)   |
        # |--------------------------|------------------------------------------------------|
        # | DIFFERENTIABLE           | (LHS DIFFERENTIABLE) & (RHS DIFFERENTIABLE)          |
        # |--------------------------|------------------------------------------------------|
        # | DIFFERENTIABLE_FUNCTION  | (LHS DIFFERENTIABLE_FUNCTION) & (RHS DIFFERENTIABLE) |
        # |--------------------------|------------------------------------------------------|
        # | QUADRATIC                | * (LHS QUADRATIC) & (RHS LINEAR)                     |
        # |                          | * (LHS LINEAR FUNCTIONAL [> 0]) & (RHS QUADRATIC)    |
        # |--------------------------|------------------------------------------------------|
        # | LINEAR                   | (LHS LINEAR) & (RHS LINEAR)                          |
        # |--------------------------|------------------------------------------------------|
        # | LINEAR_SQUARE            | (Shape[LHS * RHS] square) & (LHS.codim > 1)          |
        # |--------------------------|------------------------------------------------------|
        # | LINEAR_NORMAL            | no                                                   |
        # |--------------------------|------------------------------------------------------|
        # | LINEAR_UNITARY           | (LHS LINEAR_UNITARY) & (RHS LINEAR_UNITARY)          |
        # |--------------------------|------------------------------------------------------|
        # | LINEAR_SELF_ADJOINT      | no                                                   |
        # |--------------------------|------------------------------------------------------|
        # | LINEAR_POSITIVE_DEFINITE | no                                                   |
        # |--------------------------|------------------------------------------------------|
        # | LINEAR_IDEMPOTENT        | no                                                   |
        # |--------------------------|------------------------------------------------------|
        lhs_p = self._lhs.properties()
        rhs_p = self._rhs.properties()
        P = pxo.Property
        properties = {P.CAN_EVAL}
        if P.FUNCTIONAL in lhs_p:
            properties.add(P.FUNCTIONAL)
        # Proximal ------------------------------
        if (P.PROXIMABLE in lhs_p) and (P.LINEAR_UNITARY in rhs_p):
            properties.add(P.PROXIMABLE)
        elif ({P.LINEAR, P.FUNCTIONAL} < lhs_p) and (P.PROXIMABLE in rhs_p):
            cst = self._lhs.asarray().item()
            if cst > 0:
                properties.add(P.PROXIMABLE)
                if P.QUADRATIC in rhs_p:
                    properties.add(P.QUADRATIC)
        # ---------------------------------------
        if P.DIFFERENTIABLE in (lhs_p & rhs_p):
            properties.add(P.DIFFERENTIABLE)
        if (P.DIFFERENTIABLE_FUNCTION in lhs_p) and (P.DIFFERENTIABLE in rhs_p):
            properties.add(P.DIFFERENTIABLE_FUNCTION)
        if (P.QUADRATIC in lhs_p) and (P.LINEAR in rhs_p):
            properties.add(P.PROXIMABLE)
            properties.add(P.QUADRATIC)
        if P.LINEAR in (lhs_p & rhs_p):
            properties.add(P.LINEAR)
            if self._lhs.codim_shape == (1,):
                for p in pxo.LinFunc.properties():
                    properties.add(p)
            if (self._lhs.codim_size == self._rhs.dim_size) and (self._rhs.dim_size > 1):
                properties.add(P.LINEAR_SQUARE)
        if P.LINEAR_UNITARY in (lhs_p & rhs_p):
            properties.add(P.LINEAR_NORMAL)
            properties.add(P.LINEAR_UNITARY)

        klass = pxo.Operator._infer_operator_type(properties)
        return klass

    def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
        x = self._rhs.apply(arr)
        out = self._lhs.apply(x)
        return out

    def estimate_lipschitz(self, **kwargs) -> pxt.Real:
        no_eval = "__rule" in kwargs
        if no_eval:
            L_lhs = self._lhs.lipschitz
            L_rhs = self._rhs.lipschitz
        elif self.has(pxo.Property.LINEAR):
            L = self.__class__.estimate_lipschitz(self, **kwargs)
            return L
        else:
            L_lhs = self._lhs.estimate_lipschitz(**kwargs)
            L_rhs = self._rhs.estimate_lipschitz(**kwargs)

        zeroQ = lambda _: np.isclose(_, 0)
        if zeroQ(L_lhs) or zeroQ(L_rhs):
            L = 0
        else:
            L = L_lhs * L_rhs
        return L

    def prox(self, arr: pxt.NDArray, tau: pxt.Real) -> pxt.NDArray:
        if self.has(pxo.Property.PROXIMABLE):
            out = None
            if self._lhs.has(pxo.Property.PROXIMABLE) and self._rhs.has(pxo.Property.LINEAR_UNITARY):
                # prox[diff]func() \comp unitop() => prox[diff]func()
                x = self._rhs.apply(arr)
                y = self._lhs.prox(x, tau)
                out = self._rhs.adjoint(y)
            elif self._lhs.has(pxo.Property.QUADRATIC) and self._rhs.has(pxo.Property.LINEAR):
                # quadratic \comp linop => quadratic
                Q, c, t = self._quad_spec()
                op = pxo.QuadraticFunc(
                    dim_shape=self.dim_shape,
                    codim_shape=self.codim_shape,
                    Q=Q,
                    c=c,
                    t=t,
                )
                out = op.prox(arr, tau)
            elif self._lhs.has(pxo.Property.LINEAR) and self._rhs.has(pxo.Property.PROXIMABLE):
                # linfunc() \comp prox[diff]func() => prox[diff]func()
                #         = (\alpha * prox[diff]func())
                op = ScaleRule(op=self._rhs, cst=self._lhs.asarray().item()).op()
                out = op.prox(arr, tau)
            elif pxo.Property.LINEAR in (self._lhs.properties() & self._rhs.properties()):
                # linfunc() \comp linop() => linfunc()
                out = pxo.LinFunc.prox(self, arr, tau)

            if out is not None:
                return out
        raise NotImplementedError

    def _quad_spec(self):
        if self.has(pxo.Property.QUADRATIC):
            if self._lhs.has(pxo.Property.LINEAR):
                # linfunc (scalar) \comp quadratic
                op = ScaleRule(op=self._rhs, cst=self._lhs.asarray().item()).op()
                Q2, c2, t2 = op._quad_spec()
            elif self._rhs.has(pxo.Property.LINEAR):
                # quadratic \comp linop
                Q1, c1, t1 = self._lhs._quad_spec()
                Q2 = (self._rhs.T * Q1 * self._rhs).asop(pxo.PosDefOp)
                c2 = c1 * self._rhs
                t2 = t1
            return (Q2, c2, t2)
        else:
            raise NotImplementedError

    def jacobian(self, arr: pxt.NDArray) -> pxt.OpT:
        if self.has(pxo.Property.LINEAR):
            op = self
        else:
            J_rhs = self._rhs.jacobian(arr)
            J_lhs = self._lhs.jacobian(self._rhs.apply(arr))
            op = J_lhs * J_rhs
        return op

    def estimate_diff_lipschitz(self, **kwargs) -> pxt.Real:
        no_eval = "__rule" in kwargs
        if self.has(pxo.Property.QUADRATIC):
            Q, c, t = self._quad_spec()
            op = pxo.QuadraticFunc(
                dim_shape=self.dim_shape,
                codim_shape=self.codim_shape,
                Q=Q,
                c=c,
                t=t,
            )
            if no_eval:
                dL = op.diff_lipschitz
            else:
                dL = op.estimate_diff_lipschitz(**kwargs)
        elif self._lhs.has(pxo.Property.LINEAR) and self._rhs.has(pxo.Property.LINEAR):
            dL = 0
        elif self._lhs.has(pxo.Property.LINEAR) and self._rhs.has(pxo.Property.DIFFERENTIABLE):
            if no_eval:
                L_lhs = self._lhs.lipschitz
                dL_rhs = self._rhs.diff_lipschitz
            else:
                L_lhs = self._lhs.estimate_lipschitz(**kwargs)
                dL_rhs = self._rhs.estimate_diff_lipschitz(**kwargs)
            dL = L_lhs * dL_rhs
        elif self._lhs.has(pxo.Property.DIFFERENTIABLE) and self._rhs.has(pxo.Property.LINEAR):
            if no_eval:
                dL_lhs = self._lhs.diff_lipschitz
                L_rhs = self._rhs.lipschitz
            else:
                dL_lhs = self._lhs.estimate_diff_lipschitz(**kwargs)
                L_rhs = self._rhs.estimate_lipschitz(**kwargs)
            dL = dL_lhs * (L_rhs**2)
        else:
            dL = np.inf
        return dL

    def grad(self, arr: pxt.NDArray) -> pxt.NDArray:
        sh = arr.shape[: -self.dim_rank]
        if (len(sh) == 0) or self._rhs.has(pxo.Property.LINEAR):
            x = self._lhs.grad(self._rhs.apply(arr))
            out = self._rhs.jacobian(arr).adjoint(x)

            # RHS.adjoint() may change core-chunks if (codim->dim) changes are involved.
            # This is problematic since grad() should preserve core-chunks by default.
            ndi = pxd.NDArrayInfo.from_obj(arr)
            if ndi == pxd.NDArrayInfo.DASK:
                if out.chunks != arr.chunks:
                    out = out.rechunk(arr.chunks)
        else:
            # We need to evaluate the Jacobian separately per stacked input.

            @pxu.vectorize(
                i="arr",
                dim_shape=self.dim_shape,
                codim_shape=self.dim_shape,
            )
            def f(arr: pxt.NDArray) -> pxt.NDArray:
                x = self._lhs.grad(self._rhs.apply(arr))
                out = self._rhs.jacobian(arr).adjoint(x)

                # RHS.adjoint() may change core-chunks if (codim->dim) changes are involved.
                # This is problematic since grad() should preserve core-chunks by default.
                ndi = pxd.NDArrayInfo.from_obj(arr)
                if ndi == pxd.NDArrayInfo.DASK:
                    if out.chunks != arr.chunks:
                        out = out.rechunk(arr.chunks)
                return out

            out = f(arr)
        return out

    def adjoint(self, arr: pxt.NDArray) -> pxt.NDArray:
        x = self._lhs.adjoint(arr)
        out = self._rhs.adjoint(x)
        return out

    def asarray(self, **kwargs) -> pxt.NDArray:
        A_lhs = self._lhs.asarray(**kwargs)
        A_rhs = self._rhs.asarray(**kwargs)

        xp = pxu.get_array_module(A_lhs)
        A = xp.tensordot(A_lhs, A_rhs, axes=self._lhs.dim_rank)
        return A

    def gram(self) -> pxt.OpT:
        op = self._rhs.T * self._lhs.gram() * self._rhs
        return op.asop(pxo.SelfAdjointOp)

    def cogram(self) -> pxt.OpT:
        op = self._lhs * self._rhs.cogram() * self._lhs.T
        return op.asop(pxo.SelfAdjointOp)


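# Example (hedged sketch): composing a quadratic functional with a linear operator
# stays quadratic. `F` (QuadraticFunc) and `K` (LinOp) are placeholders for
# pre-built pyxu operators, `x` for a compatible NDArray:
#
#   C = ChainRule(lhs=F, rhs=K).op()
#   C.apply(x)                 # == F.apply(K.apply(x))
#   Q, c, t = C._quad_spec()   # Q -> K.T * Q * K, c -> c * K, t unchanged
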
class TransposeRule(Rule):
    # Not strictly speaking an arithmetic method, but the logic behind constructing transposed
    # operators is identical to arithmetic methods.
    # LinOp.T() rules are hence summarized here.
    r"""
    Arithmetic rules for :py:class:`~pyxu.abc.LinOp` transposition: :math:`B(x) = A^{T}(x)`.

    Arithmetic Update Rule(s)::

        * CAN_EVAL
            opT.apply(arr) = op.adjoint(arr)
            opT.lipschitz = op.lipschitz

        * PROXIMABLE
            opT.prox(arr, tau) = LinFunc.prox(arr, tau)

        * DIFFERENTIABLE
            opT.jacobian(arr) = opT
            opT.diff_lipschitz = 0

        * DIFFERENTIABLE_FUNCTION
            opT.grad(arr) = LinFunc.grad(arr)

        * LINEAR
            opT.adjoint(arr) = op.apply(arr)
            opT.asarray() = op.asarray().T [block-reorder dim/codim]
            opT.gram() = op.cogram()
            opT.cogram() = op.gram()
            opT.svdvals() = op.svdvals()

        * LINEAR_SQUARE
            opT.trace() = op.trace()
    """

    def __init__(self, op: pxt.OpT):
        super().__init__()
        self._op = op

    def op(self) -> pxt.OpT:
        klass = self._infer_op_klass()
        op = klass(
            dim_shape=self._op.codim_shape,
            codim_shape=self._op.dim_shape,
        )
        op._op = self._op  # embed for introspection
        for p in op.properties():
            for name in p.arithmetic_methods():
                func = getattr(self.__class__, name)
                setattr(op, name, types.MethodType(func, op))
        self._propagate_constants(op)
        return op

    def _expr(self) -> tuple:
        return ("transpose", self._op)

    def _infer_op_klass(self) -> pxt.OpC:
        # |--------------------------------|--------------------------------|
        # | op_klass(codim; dim)           | opT_klass(codim; dim)          |
        # |--------------------------------|--------------------------------|
        # | LINEAR(1; 1)                   | LinFunc(1; 1)                  |
        # | LinFunc(1; M1,...,MD)          | LinOp(M1,...,MD; 1)            |
        # | LinOp(N1,...,ND; 1)            | LinFunc(1; N1,...,ND)          |
        # | op_klass(N1,...,ND; M1,...,MD) | op_klass(M1,...,MD; N1,...,ND) |
        # |--------------------------------|--------------------------------|
        single_dim = self._op.dim_shape == (1,)
        single_codim = self._op.codim_shape == (1,)

        if single_dim and single_codim:
            klass = pxo.LinFunc
        elif single_codim:
            klass = pxo.LinOp
        elif single_dim:
            klass = pxo.LinFunc
        else:
            prop = self._op.properties()
            klass = pxo.Operator._infer_operator_type(prop)
        return klass

    def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
        out = self._op.adjoint(arr)
        return out

    def estimate_lipschitz(self, **kwargs) -> pxt.Real:
        no_eval = "__rule" in kwargs
        if no_eval:
            L = self._op.lipschitz
        else:
            L = self._op.estimate_lipschitz(**kwargs)
        return L

    def prox(self, arr: pxt.NDArray, tau: pxt.Real) -> pxt.NDArray:
        out = pxo.LinFunc.prox(self, arr, tau)
        return out

    def jacobian(self, arr: pxt.NDArray) -> pxt.OpT:
        return self

    def estimate_diff_lipschitz(self, **kwargs) -> pxt.Real:
        return 0

    def grad(self, arr: pxt.NDArray) -> pxt.NDArray:
        out = pxo.LinFunc.grad(self, arr)
        return out

    def adjoint(self, arr: pxt.NDArray) -> pxt.NDArray:
        out = self._op.apply(arr)
        return out

    def asarray(self, **kwargs) -> pxt.NDArray:
        A = self._op.asarray(**kwargs)
        B = A.transpose(
            *range(-self._op.dim_rank, 0),
            *range(self._op.codim_rank),
        )
        return B

    def gram(self) -> pxt.OpT:
        op = self._op.cogram()
        return op

    def cogram(self) -> pxt.OpT:
        op = self._op.gram()
        return op

    def svdvals(self, **kwargs) -> pxt.NDArray:
        D = self._op.svdvals(**kwargs)
        return D

    def trace(self, **kwargs) -> pxt.Real:
        tr = self._op.trace(**kwargs)
        return tr
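

# Example (hedged sketch): transposition swaps apply/adjoint and gram/cogram. `A` is
# a placeholder for a pre-built pyxu LinOp, `x`/`y` for NDArrays of matching shapes:
#
#   AT = TransposeRule(op=A).op()
#   AT.apply(y)      # == A.adjoint(y)
#   AT.adjoint(x)    # == A.apply(x)
#   AT.gram()        # == A.cogram()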