Source code for pyxu.operator.blocks

import collections.abc as cabc
import types

import numpy as np

import pyxu.abc as pxa
import pyxu.info.ptype as pxt
import pyxu.operator.interop.source as px_src
import pyxu.util as pxu

__all__ = [
    "stack",
    "block_diag",
]

def stack(ops: cabc.Sequence[pxt.OpT]) -> pxt.OpT:
    r"""
    Map operators over the same input.

    A stacked operator :math:`S: \mathbb{R}^{M_{1} \times\cdots\times M_{D}} \to \mathbb{R}^{Q \times N_{1}
    \times\cdots\times N_{K}}` is an operator containing (vertically) :math:`Q` blocks of smaller operators :math:`\{
    O_{q}: \mathbb{R}^{M_{1} \times\cdots\times M_{D}} \to \mathbb{R}^{N_{1} \times\cdots\times N_{K}} \}_{q=1}^{Q}`:

    .. math::

       S
       =
       \left[
           \begin{array}{c}
               O_{1} \\
               \vdots \\
               O_{Q} \\
           \end{array}
       \right]

    Each sub-operator :math:`O_{q}` acts on the same input and returns parallel outputs which get stacked along the
    zero-th axis.

    Parameters
    ----------
    ops: :py:class:`~collections.abc.Sequence` ( :py:attr:`~pyxu.info.ptype.OpT` )
        (Q,) identically-shaped operators to map over inputs.

    Returns
    -------
    op: OpT
        Stacked (M1,...,MD) -> (Q, N1,...,NK) operator.

    Examples
    --------

    .. code-block:: python3

       import pyxu.operator as pxo
       import numpy as np

       op = pxo.Sum((3, 4), axis=-1)  # (3,4) -> (3,1)
       A = pxo.stack([op, 2*op])      # (3,4) -> (2,3,1)

       x = np.arange(A.dim_size).reshape(A.dim_shape)  # [[ 0  1  2  3]
                                                       #  [ 4  5  6  7]
                                                       #  [ 8  9 10 11]]
       y = A.apply(x)                                  # [[[ 6.]
                                                       #   [22.]
                                                       #   [38.]]
                                                       #
                                                       #  [[12.]
                                                       #   [44.]
                                                       #   [76.]]]

    See Also
    --------
    :py:func:`~pyxu.operator.block_diag`
    """
    op = _Stack(ops).op()
    return op
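
# Illustrative sketch for stack() above (an example, not part of the module API;
# assumes a NumPy input and mirrors the docstring example): each slice of the output
# along axis 0 is one block's output, and for linear blocks asarray() stacks the
# blocks' matrix forms along axis 0.
#
#   import numpy as np
#   import pyxu.operator as pxo
#
#   op = pxo.Sum((3, 4), axis=-1)                   # (3,4) -> (3,1)
#   A = pxo.stack([op, 2 * op])                     # (3,4) -> (2,3,1)
#   x = np.arange(A.dim_size).reshape(A.dim_shape)
#
#   np.allclose(A.apply(x)[0], op.apply(x))         # True
#   np.allclose(A.apply(x)[1], (2 * op).apply(x))   # True
#   A.asarray().shape                               # (2, 3, 1, 3, 4)
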
def block_diag(ops: cabc.Sequence[pxt.OpT]) -> pxt.OpT:
    r"""
    Zip operators over parallel inputs.

    A block-diagonal operator :math:`B: \mathbb{R}^{Q \times M_{1} \times\cdots\times M_{D}} \to \mathbb{R}^{Q \times
    N_{1} \times\cdots\times N_{K}}` is an operator containing (diagonally) :math:`Q` blocks of smaller operators
    :math:`\{ O_{q}: \mathbb{R}^{M_{1} \times\cdots\times M_{D}} \to \mathbb{R}^{N_{1} \times\cdots\times N_{K}}
    \}_{q=1}^{Q}`:

    .. math::

       B
       =
       \left[
           \begin{array}{ccc}
               O_{1} &        &       \\
                     & \ddots &       \\
                     &        & O_{Q} \\
           \end{array}
       \right]

    Each sub-operator :math:`O_{q}` acts on the :math:`q`-th slice of the inputs along the zero-th axis.

    Parameters
    ----------
    ops: :py:class:`~collections.abc.Sequence` ( :py:attr:`~pyxu.info.ptype.OpT` )
        (Q,) identically-shaped operators to zip over inputs.

    Returns
    -------
    op: OpT
        Block-diagonal (Q, M1,...,MD) -> (Q, N1,...,NK) operator.

    Examples
    --------

    .. code-block:: python3

       import pyxu.operator as pxo
       import numpy as np

       op = pxo.Sum((3, 4), axis=-1)   # (3,4) -> (3,1)
       A = pxo.block_diag([op, 2*op])  # (2,3,4) -> (2,3,1)

       x = np.arange(A.dim_size).reshape(A.dim_shape)  # [[[ 0  1  2  3]
                                                       #   [ 4  5  6  7]
                                                       #   [ 8  9 10 11]]
                                                       #
                                                       #  [[12 13 14 15]
                                                       #   [16 17 18 19]
                                                       #   [20 21 22 23]]]
       y = A.apply(x)                                  # [[[  6.]
                                                       #   [ 22.]
                                                       #   [ 38.]]
                                                       #
                                                       #  [[108.]
                                                       #   [140.]
                                                       #   [172.]]]

    See Also
    --------
    :py:func:`~pyxu.operator.stack`
    """
    op = _BlockDiag(ops).op()
    return op
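
# Illustrative sketch for block_diag() above (an example, not part of the module API;
# assumes a NumPy input and mirrors the docstring example): the q-th input slice is
# routed through the q-th block, and adjoint() routes each output slice back through
# the matching block's adjoint.
#
#   import numpy as np
#   import pyxu.operator as pxo
#
#   op = pxo.Sum((3, 4), axis=-1)                    # (3,4) -> (3,1)
#   B = pxo.block_diag([op, 2 * op])                 # (2,3,4) -> (2,3,1)
#   x = np.arange(B.dim_size).reshape(B.dim_shape)
#   y = B.apply(x)
#
#   np.allclose(y[1], (2 * op).apply(x[1]))          # True
#   np.allclose(B.adjoint(y)[0], op.adjoint(y[0]))   # True
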


class _BlockDiag:
    # See block_diag() docstrings.
    def __init__(self, ops: cabc.Sequence[pxt.OpT]):
        dim_shape = ops[0].dim_shape
        codim_shape = ops[0].codim_shape

        shape_msg = "All operators must have same dim/codim."
        assert all(_op.dim_shape == dim_shape for _op in ops), shape_msg
        assert all(_op.codim_shape == codim_shape for _op in ops), shape_msg

        self._ops = list(ops)

    def op(self) -> pxt.OpT:
        klass = self._infer_op_klass()
        N_op = len(self._ops)
        dim_shape = self._ops[0].dim_shape
        codim_shape = self._ops[0].codim_shape
        op = klass(
            dim_shape=(N_op, *dim_shape),
            codim_shape=(N_op, *codim_shape),
        )
        op._ops = self._ops  # embed for introspection
        for p in op.properties():
            for name in p.arithmetic_methods():
                func = getattr(self.__class__, name)
                setattr(op, name, types.MethodType(func, op))
        self._propagate_constants(op)
        return op

    def _infer_op_klass(self) -> pxt.OpC:
        base = {
            pxa.Property.CAN_EVAL,
            pxa.Property.DIFFERENTIABLE,
            pxa.Property.LINEAR,
            pxa.Property.LINEAR_SQUARE,
            pxa.Property.LINEAR_NORMAL,
            pxa.Property.LINEAR_IDEMPOTENT,
            pxa.Property.LINEAR_SELF_ADJOINT,
            pxa.Property.LINEAR_POSITIVE_DEFINITE,
            pxa.Property.LINEAR_UNITARY,
        }
        properties = set.intersection(
            base,
            *[_op.properties() for _op in self._ops],
        )
        klass = pxa.Operator._infer_operator_type(properties)
        return klass

    @staticmethod
    def _propagate_constants(op: pxt.OpT):
        # Propagate (diff-)Lipschitz constants forward via special call to
        # Rule()-overridden `estimate_[diff_]lipschitz()` methods.

        # Important: we write to _[diff_]lipschitz to not overwrite estimate_[diff_]lipschitz() methods.
        if op.has(pxa.Property.CAN_EVAL):
            op._lipschitz = op.estimate_lipschitz(__rule=True)
        if op.has(pxa.Property.DIFFERENTIABLE):
            op._diff_lipschitz = op.estimate_diff_lipschitz(__rule=True)

    def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
        N_stack = len(arr.shape[: -self.dim_rank])
        select = lambda i: (slice(None),) * N_stack + (i,)
        parts = [_op.apply(arr[select(i)]) for (i, _op) in enumerate(self._ops)]

        xp = pxu.get_array_module(arr)
        out = xp.stack(parts, axis=-self.codim_rank)
        return out

    def __call__(self, arr: pxt.NDArray) -> pxt.NDArray:
        return self.apply(arr)

    def adjoint(self, arr: pxt.NDArray) -> pxt.NDArray:
        N_stack = len(arr.shape[: -self.codim_rank])
        select = lambda i: (slice(None),) * N_stack + (i,)
        parts = [_op.adjoint(arr[select(i)]) for (i, _op) in enumerate(self._ops)]

        xp = pxu.get_array_module(arr)
        out = xp.stack(parts, axis=-self.dim_rank)
        return out

    def pinv(self, arr: pxt.NDArray, damp: pxt.Real, **kwargs) -> pxt.NDArray:
        # op.pinv(y, damp) = stack([op1.pinv(y1, damp), ..., opN.pinv(yN, damp)], axis=0)
        N_stack = len(arr.shape[: -self.codim_rank])
        select = lambda i: (slice(None),) * N_stack + (i,)
        parts = [_op.pinv(arr[select(i)], damp) for (i, _op) in enumerate(self._ops)]

        xp = pxu.get_array_module(arr)
        out = xp.stack(parts, axis=-self.dim_rank)
        return out

    def svdvals(self, **kwargs) -> pxt.NDArray:
        # op.svdvals(**kwargs) = top_k([op1.svdvals(**kwargs), ..., opN.svdvals(**kwargs)])
        parts = [_op.svdvals(**kwargs) for _op in self._ops]

        k = kwargs.get("k")
        xp = pxu.get_array_module(parts[0])
        D = xp.sort(xp.concatenate(parts))[-k:]
        return D

    def trace(self, **kwargs) -> pxt.Real:
        # op.trace(**kwargs) = sum([op1.trace(**kwargs), ..., opN.trace(**kwargs)])
        parts = [_op.trace(**kwargs) for _op in self._ops]
        tr = sum(parts)
        return tr

    def jacobian(self, arr: pxt.NDArray) -> pxt.OpT:
        if self.has(pxa.Property.LINEAR):
            J = self
        else:
            parts = [_op.jacobian(_arr) for (_op, _arr) in zip(self._ops, arr)]
            J = _BlockDiag(ops=parts).op()
        return J

    def asarray(self, **kwargs) -> pxt.NDArray:
        parts = [_op.asarray(**kwargs) for _op in self._ops]

        xp = pxu.get_array_module(parts[0])
        dtype = parts[0].dtype
        A = xp.zeros((*self.codim_shape, *self.dim_shape), dtype=dtype)

        select = (slice(None),) * (self.codim_rank - 1)
        for i, _A in enumerate(parts):
            A[(i,) + select + (i,)] = _A
        return A

    def gram(self) -> pxt.OpT:
        parts = [_op.gram() for _op in self._ops]
        G = _BlockDiag(ops=parts).op()
        return G

    def cogram(self) -> pxt.OpT:
        parts = [_op.cogram() for _op in self._ops]
        CG = _BlockDiag(ops=parts).op()
        return CG

    def estimate_lipschitz(self, **kwargs) -> pxt.Real:
        no_eval = "__rule" in kwargs
        if no_eval:
            L_parts = [_op.lipschitz for _op in self._ops]
        elif self.has(pxa.Property.LINEAR):
            L = self.__class__.estimate_lipschitz(self, **kwargs)
            return L
        else:
            L_parts = [_op.estimate_lipschitz(**kwargs) for _op in self._ops]

        # [non-linear case] Upper bound: L <= max(L_k)
        L = max(L_parts)
        return L

    def estimate_diff_lipschitz(self, **kwargs) -> pxt.Real:
        no_eval = "__rule" in kwargs
        if no_eval:
            dL_parts = [_op.diff_lipschitz for _op in self._ops]
        elif self.has(pxa.Property.LINEAR):
            dL = 0
            return dL
        else:
            dL_parts = [_op.estimate_diff_lipschitz(**kwargs) for _op in self._ops]

        # [non-linear case] Upper bound: dL <= max(dL_k)
        dL = max(dL_parts)
        return dL

    def _expr(self) -> tuple:
        return ("block_diag", *self._ops)


class _Stack:
    # See stack() docstrings.
    def __init__(self, ops: cabc.Sequence[pxt.OpT]):
        dim_shape = ops[0].dim_shape
        codim_shape = ops[0].codim_shape

        shape_msg = "All operators must have same dim/codim."
        assert all(_op.dim_shape == dim_shape for _op in ops), shape_msg
        assert all(_op.codim_shape == codim_shape for _op in ops), shape_msg

        self._ops = list(ops)

    def op(self) -> pxt.OpT:
        klass = self._infer_op_klass()
        N_op = len(self._ops)
        dim_shape = self._ops[0].dim_shape
        codim_shape = self._ops[0].codim_shape
        op = klass(
            dim_shape=dim_shape,
            codim_shape=(N_op, *codim_shape),
        )
        op._ops = self._ops  # embed for introspection
        for p in op.properties():
            for name in p.arithmetic_methods():
                func = getattr(self.__class__, name, None)
                if func is not None:
                    setattr(op, name, types.MethodType(func, op))
        self._propagate_constants(op)
        return op

    def _infer_op_klass(self) -> pxt.OpC:
        base = {
            pxa.Property.CAN_EVAL,
            pxa.Property.DIFFERENTIABLE,
            pxa.Property.LINEAR,
        }
        properties = set.intersection(
            base,
            *[_op.properties() for _op in self._ops],
        )
        klass = pxa.Operator._infer_operator_type(properties)
        return klass

    @staticmethod
    def _propagate_constants(op: pxt.OpT):
        # Propagate (diff-)Lipschitz constants forward via special call to
        # Rule()-overridden `estimate_[diff_]lipschitz()` methods.

        # Important: we write to _[diff_]lipschitz to not overwrite estimate_[diff_]lipschitz() methods.
        if op.has(pxa.Property.CAN_EVAL):
            op._lipschitz = op.estimate_lipschitz(__rule=True)
        if op.has(pxa.Property.DIFFERENTIABLE):
            op._diff_lipschitz = op.estimate_diff_lipschitz(__rule=True)

    def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
        parts = [_op.apply(arr) for _op in self._ops]

        xp = pxu.get_array_module(arr)
        out = xp.stack(parts, axis=-self.codim_rank)
        return out

    def __call__(self, arr: pxt.NDArray) -> pxt.NDArray:
        return self.apply(arr)

    def adjoint(self, arr: pxt.NDArray) -> pxt.NDArray:
        N_stack = len(arr.shape[: -self.codim_rank])
        select = lambda i: (slice(None),) * N_stack + (i,)
        parts = [_op.adjoint(arr[select(i)]) for (i, _op) in enumerate(self._ops)]

        out = sum(parts)
        return out

    def jacobian(self, arr: pxt.NDArray) -> pxt.OpT:
        if self.has(pxa.Property.LINEAR):
            J = self
        else:
            parts = [_op.jacobian(arr) for _op in self._ops]
            J = _Stack(ops=parts).op()
        return J

    def asarray(self, **kwargs) -> pxt.NDArray:
        parts = [_op.asarray(**kwargs) for _op in self._ops]
        xp = pxu.get_array_module(parts[0])
        A = xp.stack(parts, axis=0)
        return A

    def gram(self) -> pxt.OpT:
        # [_ops.gram()] should be reduced (via +) to form a single operator.
        # It is inefficient however to chain so many operators together via AddRule().
        # apply() is thus redefined to improve performance.
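        #
        # In math terms: for S = stack([O_1, ..., O_Q]) with linear blocks,
        #     S^{*} S = O_1^{*} O_1 + ... + O_Q^{*} O_Q,
        # so op_apply() below evaluates each block's gram() on `arr` and sums the results.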

        def op_apply(_, arr: pxt.NDArray) -> pxt.NDArray:
            G = [_op.gram() for _op in _._ops]
            parts = [_G.apply(arr) for _G in G]

            out = sum(parts)
            return out

        def op_expr(_) -> tuple:
            return ("gram", self)

        G = px_src.from_source(
            cls=pxa.SelfAdjointOp,
            dim_shape=self.dim_shape,
            codim_shape=self.dim_shape,
            embed=dict(_ops=self._ops),
            apply=op_apply,
            _expr=op_expr,
        )
        return G

    def estimate_lipschitz(self, **kwargs) -> pxt.Real:
        no_eval = "__rule" in kwargs
        if no_eval:
            L_parts = [_op.lipschitz for _op in self._ops]
        elif self.has(pxa.Property.LINEAR):
            L = self.__class__.estimate_lipschitz(self, **kwargs)
            return L
        else:
            L_parts = [_op.estimate_lipschitz(**kwargs) for _op in self._ops]

        # [non-linear case] Upper bound: L**2 <= sum(L_k**2)
        L2 = np.r_[L_parts] ** 2
        L = np.sqrt(L2.sum())
        return L

    def estimate_diff_lipschitz(self, **kwargs) -> pxt.Real:
        no_eval = "__rule" in kwargs
        if no_eval:
            dL_parts = [_op.diff_lipschitz for _op in self._ops]
        elif self.has(pxa.Property.LINEAR):
            dL = 0
            return dL
        else:
            dL_parts = [_op.estimate_diff_lipschitz(**kwargs) for _op in self._ops]

        # [non-linear case] Upper bound: dL**2 <= sum(dL_k**2)
        dL2 = np.r_[dL_parts] ** 2
        dL = np.sqrt(dL2.sum())
        return dL

    def _expr(self) -> tuple:
        return ("stack", *self._ops)
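
# Derivation sketch for the non-linear Lipschitz upper bounds used above:
#
#   stack:       ||S(x) - S(y)||^2  = sum_q ||O_q(x) - O_q(y)||^2
#                                   <= (sum_q L_q^2) ||x - y||^2
#                => L <= sqrt(sum_q L_q^2)   (cf. _Stack.estimate_lipschitz)
#
#   block_diag:  ||B(x) - B(y)||^2  = sum_q ||O_q(x_q) - O_q(y_q)||^2
#                                   <= (max_q L_q^2) ||x - y||^2
#                => L <= max_q L_q           (cf. _BlockDiag.estimate_lipschitz)
#
# An analogous argument on the Jacobians gives the diff-Lipschitz bounds.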