Source code for pyxu.operator.linop.reduce
1import numpy as np
2
3import pyxu.abc as pxa
4import pyxu.info.ptype as pxt
5import pyxu.util as pxu
6
# Names exported by ``from pyxu.operator.linop.reduce import *``.
__all__ = ["Sum"]
10
11
[docs]
class Sum(pxa.LinOp):
    r"""
    Multi-dimensional sum reduction :math:`\mathbf{A}: \mathbb{R}^{M_{1} \times\cdots\times M_{D}} \to \mathbb{R}^{N_{1}
    \times\cdots\times N_{D}}`, where :math:`N_{k} = 1` if axis :math:`k` is summed over and :math:`N_{k} = M_{k}`
    otherwise.

    Notes
    -----
    * The co-dimension rank **always** matches the dimension rank, i.e. summed-over dimensions are not dropped.
      Single-element dimensions can be removed by composing :py:class:`~pyxu.operator.Sum` with
      :py:class:`~pyxu.operator.SqueezeAxes`.

    * The matrix operator of a 1D reduction applied to :math:`\mathbf{x} \in \mathbb{R}^{M}` is given by

      .. math::

         \mathbf{A}(\mathbf{x}) = \mathbf{1}^{T} \mathbf{x},

      where :math:`\sigma_{\max}(\mathbf{A}) = \sqrt{M}`. An ND reduction is a chain of 1D reductions in orthogonal
      dimensions. Hence the Lipschitz constant of an ND reduction is the product of Lipschitz constants of all 1D
      reductions involved, i.e.:

      .. math::

         L = \sqrt{\prod_{i_{k}} M_{i_{k}}},

      where :math:`\{i_{k}\}_{k}` denotes the axes being summed over.
    """

    def __init__(
        self,
        dim_shape: pxt.NDArrayShape,
        axis: pxt.NDArrayAxis = None,
    ):
        r"""
        Multi-dimensional sum reduction.

        Parameters
        ----------
        dim_shape: NDArrayShape
            (M1,...,MD) domain dimensions.
        axis: NDArrayAxis
            Axis or axes along which a sum is performed. The default, axis=None, will sum all the elements of the input
            array.
        """
        # codim_shape is provisional: the base-class call is used here to canonicalize `dim_shape`;
        # the real co-dimension shape is computed below once the summed axes are known.
        super().__init__(
            dim_shape=dim_shape,
            codim_shape=dim_shape,
        )

        if axis is None:
            axis = np.arange(self.dim_rank)
        axis = pxu.as_canonical_axes(axis, rank=self.dim_rank)

        summed = set(axis)  # drop duplicates (if any)
        # Store the summed axes in ascending order so the attribute is deterministic
        # (set iteration order is an implementation detail).
        self._axis = tuple(sorted(summed))

        # Summed-over axes are kept with length 1 (keepdims semantics); other axes are unchanged.
        self._codim_shape = tuple(1 if i in summed else m for (i, m) in enumerate(self.dim_shape))

        self.lipschitz = self.estimate_lipschitz()

    def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
        r"""
        Sum `arr` over the reduced axes, keeping them as length-1 dimensions.

        Parameters
        ----------
        arr: NDArray
            (..., M1,...,MD) input points.

        Returns
        -------
        out: NDArray
            (..., N1,...,ND) output points.
        """
        # Any leading dimensions precede the core (M1,...,MD) dimensions: shift the stored axes past them.
        sh = arr.shape[: -self.dim_rank]
        axis = tuple(ax + len(sh) for ax in self._axis)
        out = arr.sum(axis=axis, keepdims=True)
        return out

    def adjoint(self, arr: pxt.NDArray) -> pxt.NDArray:
        r"""
        Replicate `arr` along the reduced axes.

        Parameters
        ----------
        arr: NDArray
            (..., N1,...,ND) input points.

        Returns
        -------
        out: NDArray
            (..., M1,...,MD) output points.

        Notes
        -----
        The output is produced with ``broadcast_to``. For NumPy inputs this is a read-only view of `arr` rather than a
        freshly-allocated array, so callers that need to write into the result should copy it first. (Behavior of other
        array backends returned by ``get_array_module`` may differ.)
        """
        sh = arr.shape[: -self.codim_rank]
        xp = pxu.get_array_module(arr)
        out = xp.broadcast_to(arr, sh + self.dim_shape)
        return out

    def estimate_lipschitz(self, **kwargs) -> pxt.Real:
        r"""
        Exact Lipschitz constant :math:`L = \sqrt{\prod_{i_{k}} M_{i_{k}}}` over the summed axes :math:`\{i_{k}\}_{k}`.

        Keyword arguments are accepted for interface compatibility and ignored: the value is computed in closed form.
        """
        # Multiply the lengths of the summed axes directly instead of dividing total sizes:
        # prod(dim)/prod(codim) degenerates to 0/0 = nan when a *non-summed* axis has length 0.
        # An empty product (no summed axes) is 1.0, i.e. the identity operator's constant.
        M = np.prod([self.dim_shape[i] for i in self._axis])
        L = np.sqrt(M)
        return L