Source code for pyxu.experimental.sampler.statistics

  1import math
  2import types
  3import typing as typ
  4
  5import pyxu.info.ptype as pxt
  6import pyxu.util as pxu
  7
  8__all__ = [
  9    "OnlineMoment",
 10    "OnlineCenteredMoment",
 11    "OnlineVariance",
 12    "OnlineStd",
 13    "OnlineSkewness",
 14    "OnlineKurtosis",
 15]
 16
 17
 18class _OnlineStat:
 19    """
 20    Abstract base class to compute online statistics based on outputs of a Sampler object.
 21
 22    An _OnlineStat object should be paired with a single Sampler object.
 23
 24    Composite _OnlineStat objects can be implemented via the overloaded +, -, *, and ** operators.
 25    """
 26
 27    def __init__(self):
 28        self._num_samples = 0
 29        self._stat = None
 30
 31    def update(self, x: pxt.NDArray) -> pxt.NDArray:
 32        """
 33        Update the online statistic based on a new sample. Should update `_num_samples` and `_stat` attributes.
 34
 35        Parameters
 36        ----------
 37        x: NDArray
 38            New sample.
 39
 40        Returns
 41        -------
 42        stat: NDArray
 43            Updated statistic.
 44        """
 45        raise NotImplementedError
 46
 47    def stat(self) -> pxt.NDArray:
 48        """Get current online statistic estimate."""
 49        return self._stat
 50
 51    def __add__(self, other: "_OnlineStat"):
 52        stat = _OnlineStat()
 53        stat_update = lambda _, x: self.update(x) + other.update(x)
 54        stat.update = types.MethodType(stat_update, stat)
 55        return stat
 56
 57    def __sub__(self, other: "_OnlineStat"):
 58        stat = _OnlineStat()
 59        stat_update = lambda _, x: self.update(x) - other.update(x)
 60        stat.update = types.MethodType(stat_update, stat)
 61        return stat
 62
 63    def __mul__(self, other: typ.Union[pxt.Real, pxt.Integer, "_OnlineStat"]):
 64        stat = _OnlineStat()
 65        if isinstance(other, _OnlineStat):
 66            stat_update = lambda _, x: self.update(x) * other.update(x)
 67        elif isinstance(other, pxt.Real) or isinstance(other, pxt.Integer):
 68            stat_update = lambda _, x: self.update(x) * other
 69        else:
 70            return NotImplemented
 71        stat.update = types.MethodType(stat_update, stat)
 72        return stat
 73
 74    def __rmul__(self, other: typ.Union[pxt.Real, pxt.Integer]):
 75        return self.__mul__(other)
 76
 77    def __truediv__(self, other: typ.Union[pxt.Real, pxt.Integer, "_OnlineStat"]):
 78        stat = _OnlineStat()
 79        if isinstance(other, _OnlineStat):
 80
 81            def stat_update(stat, x):
 82                xp = pxu.get_array_module(x)
 83                out = xp.divide(self.update(x), other.update(x), where=(other.update(x) != 0))
 84                out[other.update(x) == 0] = xp.nan
 85                return out
 86
 87        elif isinstance(other, pxt.Real) or isinstance(other, pxt.Integer):
 88            stat_update = lambda _, x: self.update(x) / other
 89        else:
 90            return NotImplemented
 91        stat.update = types.MethodType(stat_update, stat)
 92        return stat
 93
 94    def __pow__(self, expo: typ.Union[pxt.Real, pxt.Integer]):
 95        if not (isinstance(expo, pxt.Real) or isinstance(expo, pxt.Integer)):
 96            return NotImplemented
 97        stat = _OnlineStat()
 98        stat_update = lambda _, x: self.update(x) ** expo
 99        stat.update = types.MethodType(stat_update, stat)
100        return stat
101
102
class OnlineMoment(_OnlineStat):
    r"""
    Pointwise online moment.

    For :math:`d \geq 1`, the :math:`d`-th order (raw, i.e. non-centered) moment of the :math:`K` samples
    :math:`(\mathbf{x}_k)_{1 \leq k \leq K}` is given by :math:`\frac{1}{K}\sum_{k=1}^K \mathbf{x}_k^d`.

    Parameters
    ----------
    order: Real
        Order :math:`d` of the moment (default: 1, i.e. the sample mean).
    """

    def __init__(self, order: pxt.Real = 1):
        super().__init__()
        self._order = order

    def update(self, x: pxt.NDArray) -> pxt.NDArray:
        """
        Incorporate the sample `x` and return the updated moment estimate.

        The estimate is accumulated in place, so the returned array is the internal accumulator itself.
        """
        term = x**self._order
        self._num_samples += 1
        if self._num_samples == 1:
            # True division allocates a fresh floating-point accumulator, so integer-valued samples do not break the
            # in-place updates performed on subsequent calls.
            self._stat = term / 1
        else:
            # Running-mean increment: mean_K = mean_{K-1} + (term - mean_{K-1}) / K.
            self._stat += (term - self._stat) / self._num_samples
        return self._stat
124 125
class OnlineCenteredMoment(_OnlineStat):
    r"""
    Pointwise online centered moment.

    For :math:`d \geq 2`, the :math:`d`-th order centered moment of the :math:`K` samples :math:`(\mathbf{x}_k)_{1 \leq
    k \leq K}` is given by :math:`\boldsymbol{\mu}_d=\frac{1}{K} \sum_{k=1}^K (\mathbf{x}_k-\boldsymbol{\mu})^d`, where
    :math:`\boldsymbol{\mu}` is the sample mean.

    Parameters
    ----------
    order: Integer
        Order :math:`d \geq 2` of the centered moment (default: 2, i.e. the variance).

    Raises
    ------
    ValueError
        If `order` is not an integer greater than or equal to 2.

    Notes
    -----
    This class implements the *Welford algorithm* described in [WelfordAlg]_, which is a numerically stable algorithm
    for computing online centered moments. In particular, it avoids `catastrophic cancellation
    <https://en.wikipedia.org/wiki/Catastrophic_cancellation>`_ issues that may arise when naively computing online
    centered moments, which would lead to a loss of numerical precision.

    Note that this class internally stores the (un-normalized) centered sums of all orders :math:`d'` for :math:`2
    \leq d' \leq d` in the attribute ``_corrected_sums`` as well as the online mean (``_mean`` attribute). More
    precisely, the array ``_corrected_sums[i]`` corresponds to the online sum :math:`\mathbf{M}_{i+2}=
    \sum_{k=1}^K (\mathbf{x}_k-\boldsymbol{\mu})^{i+2}` for :math:`0 \leq i \leq d-2`; the moment is obtained as
    :math:`\boldsymbol{\mu}_{i+2} = \mathbf{M}_{i+2} / K`.
    """

    def __init__(self, order: pxt.Real = 2):
        super().__init__()
        # Fail early with a clear message: an invalid order would otherwise only surface as an obscure indexing or
        # shape error on the first call to update().
        if int(order) != order or order < 2:
            raise ValueError(f"order must be an integer >= 2, got {order}.")
        self._order = int(order)
        self._corrected_sums = None
        self._mean = None

    def update(self, x: pxt.NDArray) -> pxt.NDArray:
        """
        Incorporate the sample `x` and return the updated centered moment of order `order`.
        """
        if self._num_samples == 0:
            xp = pxu.get_array_module(x)
            self._corrected_sums = xp.zeros((self._order - 1,) + x.shape)
            # Keep the running mean in the (floating-point) dtype of the sums so that in-place updates also work for
            # integer-valued samples.
            self._mean = x.astype(self._corrected_sums.dtype)
        else:
            n = self._num_samples
            temp = (x - self._mean) / (n + 1)
            neg_temp = -temp
            for r in range(self._order, 1, -1):  # Update in descending order because updates depend on lower orders
                for s in range(2, r):  # s = r term excluded since it corresponds to previous iterate
                    self._corrected_sums[r - 2] += math.comb(r, s) * self._corrected_sums[s - 2] * neg_temp ** (r - s)
                self._corrected_sums[r - 2] += n * neg_temp**r  # Contribution of s = 0 term
                self._corrected_sums[r - 2] += (n * temp) ** r
            # Welford mean update: mean_{n+1} = mean_n + (x - mean_n) / (n + 1).
            self._mean += temp
        self._num_samples += 1
        self._stat = self._corrected_sums[-1] / self._num_samples
        return self._stat
172 173
def OnlineVariance():
    r"""
    Pointwise online variance.

    Given :math:`K` samples :math:`(\mathbf{x}_k)_{1 \leq k \leq K}` with sample mean :math:`\boldsymbol{\mu}`, the
    pointwise online variance is :math:`\boldsymbol{\sigma}^2 = \frac{1}{K}\sum_{k=1}^K
    (\mathbf{x}_k-\boldsymbol{\mu})^2`, i.e. the centered moment of order 2.
    """
    second_centered_moment = OnlineCenteredMoment(order=2)
    return second_centered_moment
183 184
def OnlineStd():
    r"""
    Pointwise online standard deviation.

    Given :math:`K` samples :math:`(\mathbf{x}_k)_{1 \leq k \leq K}` with sample mean :math:`\boldsymbol{\mu}`, the
    pointwise online standard deviation is :math:`\boldsymbol{\sigma}=\sqrt{\frac{1}{K}\sum_{k=1}^K (\mathbf{x}_k -
    \boldsymbol{\mu})^2}`, i.e. the square root of the online variance.
    """
    variance = OnlineVariance()
    return variance**0.5
194 195
def OnlineSkewness():
    r"""
    Pointwise online skewness.

    Given :math:`K` samples :math:`(\mathbf{x}_k)_{1 \leq k \leq K}` with sample mean :math:`\boldsymbol{\mu}` and
    standard deviation :math:`\boldsymbol{\sigma}`, the pointwise online skewness is :math:`\frac{1}{K}\sum_{k=1}^K
    \left( \frac{\mathbf{x}_k-\boldsymbol{\mu}}{\boldsymbol{\sigma}}\right)^3`.

    `Skewness <https://en.wikipedia.org/wiki/Skewness>`_ measures the asymmetry of a distribution around its mean: a
    negative value indicates a heavier tail on the left side than on the right side, a positive value indicates the
    opposite, and a value close to zero indicates a symmetric distribution.
    """
    third_centered_moment = OnlineCenteredMoment(order=3)
    std_cubed = OnlineStd() ** 3
    return third_centered_moment / std_cubed
209 210
def OnlineKurtosis():
    r"""
    Pointwise online kurtosis.

    The pointwise online kurtosis of the :math:`K` samples :math:`(\mathbf{x}_k)_{1 \leq k \leq K}` is given by
    :math:`\frac{1}{K}\sum_{k=1}^K \left( \frac{\mathbf{x}_k-\boldsymbol{\mu}}{\boldsymbol{\sigma}}\right)^4`, where
    :math:`\boldsymbol{\mu}` is the sample mean and :math:`\boldsymbol{\sigma}` its standard deviation.

    `Kurtosis <https://en.wikipedia.org/wiki/Kurtosis>`_ is a measure of the heavy-tailedness of a distribution. In
    particular, the kurtosis of a Gaussian distribution is always 3.
    """
    # sigma^4 is the squared variance: squaring the variance directly avoids taking a square root only to raise it to
    # the fourth power afterwards.
    return OnlineCenteredMoment(order=4) / OnlineVariance() ** 2