Source code for pyxu.operator.linop.reduce

 1import numpy as np
 2
 3import pyxu.abc as pxa
 4import pyxu.info.ptype as pxt
 5import pyxu.util as pxu
 6
 7__all__ = [
 8    "Sum",
 9]
10
11
class Sum(pxa.LinOp):
    r"""
    Multi-dimensional sum reduction :math:`\mathbf{A}: \mathbb{R}^{M_{1} \times\cdots\times M_{D}} \to \mathbb{R}^{N_{1}
    \times\cdots\times N_{D}}`, where :math:`N_{k} = 1` if axis :math:`k` is summed over and :math:`N_{k} = M_{k}`
    otherwise.

    Notes
    -----
    * The co-dimension rank **always** matches the dimension rank, i.e. summed-over dimensions are not dropped.
      Single-element dimensions can be removed by composing :py:class:`~pyxu.operator.Sum` with
      :py:class:`~pyxu.operator.SqueezeAxes`.

    * The matrix operator of a 1D reduction applied to :math:`\mathbf{x} \in \mathbb{R}^{M}` is given by

      .. math::

         \mathbf{A}(\mathbf{x}) = \mathbf{1}^{T} \mathbf{x},

      where :math:`\mathbf{1} \in \mathbb{R}^{M}` is the all-ones vector, hence
      :math:`\sigma_{\max}(\mathbf{A}) = \|\mathbf{1}\|_{2} = \sqrt{M}`. An ND reduction is a chain of 1D reductions
      in orthogonal dimensions. Hence the Lipschitz constant of an ND reduction is the product of Lipschitz constants
      of all 1D reductions involved, i.e.:

      .. math::

         L = \sqrt{\prod_{i_{k}} M_{i_{k}}},

      where :math:`\{i_{k}\}_{k}` denotes the axes being summed over.
    """

    def __init__(
        self,
        dim_shape: pxt.NDArrayShape,
        axis: pxt.NDArrayAxis = None,
    ):
        r"""
        Multi-dimensional sum reduction.

        Parameters
        ----------
        dim_shape: NDArrayShape
            (M1,...,MD) domain dimensions.
        axis: NDArrayAxis
            Axis or axes along which a sum is performed. The default, axis=None, will sum all the elements of the input
            array. Duplicate axes are ignored.
        """
        super().__init__(
            dim_shape=dim_shape,
            codim_shape=dim_shape,  # temporary; to canonicalize dim_shape.
        )

        if axis is None:
            axis = np.arange(self.dim_rank)
        # Canonicalize (e.g. negative -> non-negative indices), then drop duplicates.
        reduced = frozenset(pxu.as_canonical_axes(axis, rank=self.dim_rank))

        # Reduced axes are kept with length 1 (keepdims semantics); all other axes keep their size.
        self._codim_shape = tuple(1 if i in reduced else m for i, m in enumerate(self.dim_shape))

        # Sorted so the stored axis order is deterministic (set iteration order is unspecified).
        self._axis = tuple(sorted(reduced))
        self.lipschitz = self.estimate_lipschitz()

    def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
        r"""
        Sum `arr` over the reduced axes, keeping each reduced axis as a length-1 dimension.

        The trailing ``dim_rank`` dimensions of `arr` are the operator's domain; any leading dimensions are left
        untouched (presumably pyxu's stacking dimensions).
        """
        sh = arr.shape[: -self.dim_rank]
        # self._axis indexes the domain dimensions only: shift past the leading dimensions.
        axis = tuple(ax + len(sh) for ax in self._axis)
        out = arr.sum(axis=axis, keepdims=True)
        return out

    def adjoint(self, arr: pxt.NDArray) -> pxt.NDArray:
        r"""
        Replicate `arr` along the reduced (length-1) axes back to ``dim_shape``.

        Leading dimensions before the trailing ``codim_rank`` ones are preserved. The result comes from the array
        module's ``broadcast_to``; with NumPy this is a read-only view, so callers must copy before writing into it.
        """
        sh = arr.shape[: -self.codim_rank]
        xp = pxu.get_array_module(arr)
        out = xp.broadcast_to(arr, sh + self.dim_shape)
        return out

    def estimate_lipschitz(self, **kwargs) -> pxt.Real:
        r"""
        Exact Lipschitz constant :math:`L = \sqrt{\prod_{i_{k}} M_{i_{k}}}`.

        `kwargs` are accepted for interface compatibility and ignored: the value is computed in closed form.
        """
        # Number of input elements summed into each output entry. Computed from the reduced axes directly rather than
        # as prod(dim_shape) / prod(codim_shape), which degenerates to 0/0 when a non-reduced axis has length 0.
        M = np.prod([self.dim_shape[ax] for ax in self._axis])
        L = np.sqrt(M)
        return L