import collections.abc as cabc
import types

import numpy as np

import pyxu.abc as pxa
import pyxu.info.ptype as pxt
import pyxu.operator.interop.source as px_src
import pyxu.util as pxu

__all__ = [
    "stack",
    "block_diag",
]


def stack(ops: cabc.Sequence[pxt.OpT]) -> pxt.OpT:
    r"""
    Map operators over the same input.

    A stacked operator :math:`S: \mathbb{R}^{M_{1} \times\cdots\times M_{D}} \to \mathbb{R}^{Q \times N_{1}
    \times\cdots\times N_{K}}` is an operator containing (vertically) :math:`Q` blocks of smaller operators :math:`\{
    O_{q}: \mathbb{R}^{M_{1} \times\cdots\times M_{D}} \to \mathbb{R}^{N_{1} \times\cdots\times N_{K}} \}_{q=1}^{Q}`:

    .. math::

        S
        =
        \left[
            \begin{array}{c}
                O_{1} \\
                \vdots \\
                O_{Q} \\
            \end{array}
        \right]

    Each sub-operator :math:`O_{q}` acts on the same input; the parallel outputs are stacked along the zero-th axis.

    Parameters
    ----------
    ops: :py:class:`~collections.abc.Sequence` ( :py:attr:`~pyxu.info.ptype.OpT` )
        (Q,) identically-shaped operators to map over inputs.

    Returns
    -------
    op: OpT
        Stacked (M1,...,MD) -> (Q, N1,...,NK) operator.

    Examples
    --------

    .. code-block:: python3

        import pyxu.operator as pxo
        import numpy as np

        op = pxo.Sum((3, 4), axis=-1)  # (3,4) -> (3,1)
        A = pxo.stack([op, 2*op])      # (3,4) -> (2,3,1)

        x = np.arange(A.dim_size).reshape(A.dim_shape)  # [[ 0  1  2  3]
                                                        #  [ 4  5  6  7]
                                                        #  [ 8  9 10 11]]
        y = A.apply(x)  # [[[ 6.]
                        #   [22.]
                        #   [38.]]
                        #
                        #  [[12.]
                        #   [44.]
                        #   [76.]]]


    See Also
    --------
    :py:func:`~pyxu.operator.block_diag`
    """
    op = _Stack(ops).op()
    return op


def block_diag(ops: cabc.Sequence[pxt.OpT]) -> pxt.OpT:
    r"""
    Zip operators over parallel inputs.

    A block-diagonal operator :math:`B: \mathbb{R}^{Q \times M_{1} \times\cdots\times M_{D}} \to \mathbb{R}^{Q \times
    N_{1} \times\cdots\times N_{K}}` is an operator containing (diagonally) :math:`Q` blocks of smaller operators
    :math:`\{ O_{q}: \mathbb{R}^{M_{1} \times\cdots\times M_{D}} \to \mathbb{R}^{N_{1} \times\cdots\times N_{K}}
    \}_{q=1}^{Q}`:

    .. math::

        B
        =
        \left[
            \begin{array}{ccc}
                O_{1} & & \\
                & \ddots & \\
                & & O_{Q} \\
            \end{array}
        \right]

    Each sub-operator :math:`O_{q}` acts on the :math:`q`-th slice of the inputs along the zero-th axis.

    Parameters
    ----------
    ops: :py:class:`~collections.abc.Sequence` ( :py:attr:`~pyxu.info.ptype.OpT` )
        (Q,) identically-shaped operators to zip over inputs.

    Returns
    -------
    op: OpT
        Block-diagonal (Q, M1,...,MD) -> (Q, N1,...,NK) operator.

    Examples
    --------

    .. code-block:: python3

        import pyxu.operator as pxo
        import numpy as np

        op = pxo.Sum((3, 4), axis=-1)   # (3,4) -> (3,1)
        A = pxo.block_diag([op, 2*op])  # (2,3,4) -> (2,3,1)

        x = np.arange(A.dim_size).reshape(A.dim_shape)  # [[[ 0  1  2  3]
                                                        #   [ 4  5  6  7]
                                                        #   [ 8  9 10 11]]
                                                        #
                                                        #  [[12 13 14 15]
                                                        #   [16 17 18 19]
                                                        #   [20 21 22 23]]]
        y = A.apply(x)  # [[[  6.]
                        #   [ 22.]
                        #   [ 38.]]
                        #
                        #  [[108.]
                        #   [140.]
                        #   [172.]]]


    See Also
    --------
    :py:func:`~pyxu.operator.stack`
    """
    op = _BlockDiag(ops).op()
    return op


class _BlockDiag:
    # See block_diag() docstrings.
    def __init__(self, ops: cabc.Sequence[pxt.OpT]):
        dim_shape = ops[0].dim_shape
        codim_shape = ops[0].codim_shape

        shape_msg = "All operators must have same dim/codim."
        assert all(_op.dim_shape == dim_shape for _op in ops), shape_msg
        assert all(_op.codim_shape == codim_shape for _op in ops), shape_msg

        self._ops = list(ops)

    def op(self) -> pxt.OpT:
        klass = self._infer_op_klass()
        N_op = len(self._ops)
        dim_shape = self._ops[0].dim_shape
        codim_shape = self._ops[0].codim_shape
        op = klass(
            dim_shape=(N_op, *dim_shape),
            codim_shape=(N_op, *codim_shape),
        )
        op._ops = self._ops  # embed for introspection
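        # Expose _BlockDiag's arithmetic methods (apply, adjoint, ...) on the generated operator by binding them as instance methods.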
        for p in op.properties():
            for name in p.arithmetic_methods():
                func = getattr(self.__class__, name)
                setattr(op, name, types.MethodType(func, op))
        self._propagate_constants(op)
        return op

    def _infer_op_klass(self) -> pxt.OpC:
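        # Keep only the structural properties (within `base`) shared by every block.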
        base = {
            pxa.Property.CAN_EVAL,
            pxa.Property.DIFFERENTIABLE,
            pxa.Property.LINEAR,
            pxa.Property.LINEAR_SQUARE,
            pxa.Property.LINEAR_NORMAL,
            pxa.Property.LINEAR_IDEMPOTENT,
            pxa.Property.LINEAR_SELF_ADJOINT,
            pxa.Property.LINEAR_POSITIVE_DEFINITE,
            pxa.Property.LINEAR_UNITARY,
        }
        properties = set.intersection(
            base,
            *[_op.properties() for _op in self._ops],
        )
        klass = pxa.Operator._infer_operator_type(properties)
        return klass

    @staticmethod
    def _propagate_constants(op: pxt.OpT):
        # Propagate (diff-)Lipschitz constants forward via special call to
        # Rule()-overridden `estimate_[diff_]lipschitz()` methods.

        # Important: we write to _[diff_]lipschitz to not overwrite estimate_[diff_]lipschitz() methods.
        if op.has(pxa.Property.CAN_EVAL):
            op._lipschitz = op.estimate_lipschitz(__rule=True)
        if op.has(pxa.Property.DIFFERENTIABLE):
            op._diff_lipschitz = op.estimate_diff_lipschitz(__rule=True)

    def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
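        # `arr` has shape (..., Q, M1,...,MD): dispatch the q-th slice along the stacking axis to the q-th block, keeping any leading batch axes intact.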
        N_stack = len(arr.shape[: -self.dim_rank])
        select = lambda i: (slice(None),) * N_stack + (i,)
        parts = [_op.apply(arr[select(i)]) for (i, _op) in enumerate(self._ops)]

        xp = pxu.get_array_module(arr)
        out = xp.stack(parts, axis=-self.codim_rank)
        return out

    def __call__(self, arr: pxt.NDArray) -> pxt.NDArray:
        return self.apply(arr)

    def adjoint(self, arr: pxt.NDArray) -> pxt.NDArray:
        N_stack = len(arr.shape[: -self.codim_rank])
        select = lambda i: (slice(None),) * N_stack + (i,)
        parts = [_op.adjoint(arr[select(i)]) for (i, _op) in enumerate(self._ops)]

        xp = pxu.get_array_module(arr)
        out = xp.stack(parts, axis=-self.dim_rank)
        return out

    def pinv(self, arr: pxt.NDArray, damp: pxt.Real, **kwargs) -> pxt.NDArray:
        # op.pinv(y, damp) = stack([op1.pinv(y1, damp), ..., opN.pinv(yN, damp)], axis=0)
        N_stack = len(arr.shape[: -self.codim_rank])
        select = lambda i: (slice(None),) * N_stack + (i,)
        parts = [_op.pinv(arr[select(i)], damp) for (i, _op) in enumerate(self._ops)]

        xp = pxu.get_array_module(arr)
        out = xp.stack(parts, axis=-self.dim_rank)
        return out

    def svdvals(self, **kwargs) -> pxt.NDArray:
        # op.svdvals(**kwargs) = top_k([op1.svdvals(**kwargs), ..., opN.svdvals(**kwargs)])
        parts = [_op.svdvals(**kwargs) for _op in self._ops]

        k = kwargs.get("k")
        xp = pxu.get_array_module(parts[0])
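        # Singular values of a block-diagonal operator are the union of its blocks' singular values: merge them and keep the k largest.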
        D = xp.sort(xp.concatenate(parts))[-k:]
        return D

    def trace(self, **kwargs) -> pxt.Real:
        # op.trace(**kwargs) = sum([op1.trace(**kwargs), ..., opN.trace(**kwargs)])
        parts = [_op.trace(**kwargs) for _op in self._ops]
        tr = sum(parts)
        return tr

    def jacobian(self, arr: pxt.NDArray) -> pxt.OpT:
        if self.has(pxa.Property.LINEAR):
            J = self
        else:
            parts = [_op.jacobian(_arr) for (_op, _arr) in zip(self._ops, arr)]
            J = _BlockDiag(ops=parts).op()
        return J

    def asarray(self, **kwargs) -> pxt.NDArray:
        parts = [_op.asarray(**kwargs) for _op in self._ops]

        xp = pxu.get_array_module(parts[0])
        dtype = parts[0].dtype
        A = xp.zeros((*self.codim_shape, *self.dim_shape), dtype=dtype)

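        # Write each block's dense array into its own diagonal block: A[q, ..., q, ...] = O_q.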
        select = (slice(None),) * (self.codim_rank - 1)
        for i, _A in enumerate(parts):
            A[(i,) + select + (i,)] = _A
        return A

    def gram(self) -> pxt.OpT:
        parts = [_op.gram() for _op in self._ops]
        G = _BlockDiag(ops=parts).op()
        return G

    def cogram(self) -> pxt.OpT:
        parts = [_op.cogram() for _op in self._ops]
        CG = _BlockDiag(ops=parts).op()
        return CG

    def estimate_lipschitz(self, **kwargs) -> pxt.Real:
        no_eval = "__rule" in kwargs
        if no_eval:
            L_parts = [_op.lipschitz for _op in self._ops]
        elif self.has(pxa.Property.LINEAR):
            L = self.__class__.estimate_lipschitz(self, **kwargs)
            return L
        else:
            L_parts = [_op.estimate_lipschitz(**kwargs) for _op in self._ops]

        # [non-linear case] Upper bound: L <= max(L_k)
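        # (Blocks act on disjoint slices, so ||B(x) - B(y)||^2 = sum_q ||O_q(x_q) - O_q(y_q)||^2 <= max_q L_q^2 * ||x - y||^2.)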
        L = max(L_parts)
        return L

    def estimate_diff_lipschitz(self, **kwargs) -> pxt.Real:
        no_eval = "__rule" in kwargs
        if no_eval:
            dL_parts = [_op.diff_lipschitz for _op in self._ops]
        elif self.has(pxa.Property.LINEAR):
            dL = 0
            return dL
        else:
            dL_parts = [_op.estimate_diff_lipschitz(**kwargs) for _op in self._ops]

        # [non-linear case] Upper bound: dL <= max(dL_k)
        dL = max(dL_parts)
        return dL

    def _expr(self) -> tuple:
        return ("block_diag", *self._ops)


class _Stack:
    # See stack() docstrings.
    def __init__(self, ops: cabc.Sequence[pxt.OpT]):
        dim_shape = ops[0].dim_shape
        codim_shape = ops[0].codim_shape

        shape_msg = "All operators must have same dim/codim."
        assert all(_op.dim_shape == dim_shape for _op in ops), shape_msg
        assert all(_op.codim_shape == codim_shape for _op in ops), shape_msg

        self._ops = list(ops)

    def op(self) -> pxt.OpT:
        klass = self._infer_op_klass()
        N_op = len(self._ops)
        dim_shape = self._ops[0].dim_shape
        codim_shape = self._ops[0].codim_shape
        op = klass(
            dim_shape=dim_shape,
            codim_shape=(N_op, *codim_shape),
        )
        op._ops = self._ops  # embed for introspection
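        # Bind only the arithmetic methods _Stack actually defines; the others keep the behaviour provided by `klass`.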
        for p in op.properties():
            for name in p.arithmetic_methods():
                func = getattr(self.__class__, name, None)
                if func is not None:
                    setattr(op, name, types.MethodType(func, op))
        self._propagate_constants(op)
        return op

    def _infer_op_klass(self) -> pxt.OpC:
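        # Only CAN_EVAL / DIFFERENTIABLE / LINEAR are considered here: square/normal/self-adjoint/unitary structure does not survive vertical stacking since the codomain gains a leading stacking axis.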
        base = {
            pxa.Property.CAN_EVAL,
            pxa.Property.DIFFERENTIABLE,
            pxa.Property.LINEAR,
        }
        properties = set.intersection(
            base,
            *[_op.properties() for _op in self._ops],
        )
        klass = pxa.Operator._infer_operator_type(properties)
        return klass

    @staticmethod
    def _propagate_constants(op: pxt.OpT):
        # Propagate (diff-)Lipschitz constants forward via special call to
        # Rule()-overridden `estimate_[diff_]lipschitz()` methods.

        # Important: we write to _[diff_]lipschitz to not overwrite estimate_[diff_]lipschitz() methods.
        if op.has(pxa.Property.CAN_EVAL):
            op._lipschitz = op.estimate_lipschitz(__rule=True)
        if op.has(pxa.Property.DIFFERENTIABLE):
            op._diff_lipschitz = op.estimate_diff_lipschitz(__rule=True)

    def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
        parts = [_op.apply(arr) for _op in self._ops]

        xp = pxu.get_array_module(arr)
        out = xp.stack(parts, axis=-self.codim_rank)
        return out

    def __call__(self, arr: pxt.NDArray) -> pxt.NDArray:
        return self.apply(arr)

    def adjoint(self, arr: pxt.NDArray) -> pxt.NDArray:
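        # Adjoint of a vertical stack: S^*(y) = sum_q O_q^*(y_q), where y_q is the q-th slice of y along the stacking axis.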
        N_stack = len(arr.shape[: -self.codim_rank])
        select = lambda i: (slice(None),) * N_stack + (i,)
        parts = [_op.adjoint(arr[select(i)]) for (i, _op) in enumerate(self._ops)]

        out = sum(parts)
        return out

    def jacobian(self, arr: pxt.NDArray) -> pxt.OpT:
        if self.has(pxa.Property.LINEAR):
            J = self
        else:
            parts = [_op.jacobian(arr) for _op in self._ops]
            J = _Stack(ops=parts).op()
        return J

    def asarray(self, **kwargs) -> pxt.NDArray:
        parts = [_op.asarray(**kwargs) for _op in self._ops]
        xp = pxu.get_array_module(parts[0])
        A = xp.stack(parts, axis=0)
        return A

    def gram(self) -> pxt.OpT:
        # [_ops.gram()] should be reduced (via +) to form a single operator.
        # It is inefficient however to chain so many operators together via AddRule().
        # apply() is thus redefined to improve performance.
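        # Mathematically: gram(S) = S^* S = sum_q O_q^* O_q, i.e. the sum of the per-block Gram operators.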

        def op_apply(_, arr: pxt.NDArray) -> pxt.NDArray:
            G = [_op.gram() for _op in _._ops]
            parts = [_G.apply(arr) for _G in G]

            out = sum(parts)
            return out

        def op_expr(_) -> tuple:
            return ("gram", self)

        G = px_src.from_source(
            cls=pxa.SelfAdjointOp,
            dim_shape=self.dim_shape,
            codim_shape=self.dim_shape,
            embed=dict(_ops=self._ops),
            apply=op_apply,
            _expr=op_expr,
        )
        return G

    def estimate_lipschitz(self, **kwargs) -> pxt.Real:
        no_eval = "__rule" in kwargs
        if no_eval:
            L_parts = [_op.lipschitz for _op in self._ops]
        elif self.has(pxa.Property.LINEAR):
            L = self.__class__.estimate_lipschitz(self, **kwargs)
            return L
        else:
            L_parts = [_op.estimate_lipschitz(**kwargs) for _op in self._ops]

        # [non-linear case] Upper bound: L**2 <= sum(L_k**2)
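        # (since ||S(x) - S(y)||^2 = sum_q ||O_q(x) - O_q(y)||^2 <= (sum_q L_q^2) * ||x - y||^2)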
        L2 = np.r_[L_parts] ** 2
        L = np.sqrt(L2.sum())
        return L

    def estimate_diff_lipschitz(self, **kwargs) -> pxt.Real:
        no_eval = "__rule" in kwargs
        if no_eval:
            dL_parts = [_op.diff_lipschitz for _op in self._ops]
        elif self.has(pxa.Property.LINEAR):
            dL = 0
            return dL
        else:
            dL_parts = [_op.estimate_diff_lipschitz(**kwargs) for _op in self._ops]

        # [non-linear case] Upper bound: dL**2 <= sum(dL_k**2)
        dL2 = np.r_[dL_parts] ** 2
        dL = np.sqrt(dL2.sum())
        return dL

    def _expr(self) -> tuple:
        return ("stack", *self._ops)