1import math
2import types
3import typing as typ
4
5import pyxu.info.ptype as pxt
6import pyxu.util as pxu
7
__all__ = [
    # Stateful estimator classes.
    "OnlineMoment",
    "OnlineCenteredMoment",
    # Factory functions returning (possibly composite) estimators.
    "OnlineVariance",
    "OnlineStd",
    "OnlineSkewness",
    "OnlineKurtosis",
]
16
17
class _OnlineStat:
    """
    Abstract base class to compute online statistics based on outputs of a Sampler object.

    An _OnlineStat object should be paired with a single Sampler object.

    Composite _OnlineStat objects can be implemented via the overloaded +, -, *, / and ** operators. Scalars are
    accepted as the second operand of every operator, and as the first operand of +, - and *. A composite's
    :py:meth:`update` forwards each new sample exactly once to each of its operands and records the combined value, so
    that :py:meth:`stat` returns the latest composite estimate.

    Notes
    -----
    Operands are updated by the composite itself. An _OnlineStat used inside a composite should therefore not also be
    updated directly, nor appear twice in the same expression (e.g. ``s + s``); otherwise it would receive the same
    sample several times.
    """

    def __init__(self):
        self._num_samples = 0
        self._stat = None

    def update(self, x: pxt.NDArray) -> pxt.NDArray:
        """
        Update the online statistic based on a new sample. Should update `_num_samples` and `_stat` attributes.

        Parameters
        ----------
        x: NDArray
            New sample.

        Returns
        -------
        stat: NDArray
            Updated statistic.
        """
        raise NotImplementedError

    def stat(self) -> pxt.NDArray:
        """Get current online statistic estimate (``None`` before the first update)."""
        return self._stat

    @staticmethod
    def _is_scalar(value) -> bool:
        # Scalar operands accepted by the arithmetic operators below.
        return isinstance(value, pxt.Real) or isinstance(value, pxt.Integer)

    @staticmethod
    def _compose(combine: typ.Callable[[pxt.NDArray], pxt.NDArray]) -> "_OnlineStat":
        """
        Build a composite statistic whose ``update(x)`` returns ``combine(x)``.

        ``combine`` is responsible for updating the operands (once each). The composite stores the combined result in
        ``_stat`` and counts samples in ``_num_samples``, thereby honoring the :py:meth:`update` contract.
        """
        composite = _OnlineStat()

        def update(obj: "_OnlineStat", x: pxt.NDArray) -> pxt.NDArray:
            obj._stat = combine(x)
            obj._num_samples += 1
            return obj._stat

        composite.update = types.MethodType(update, composite)
        return composite

    def __add__(self, other: typ.Union[pxt.Real, pxt.Integer, "_OnlineStat"]):
        if isinstance(other, _OnlineStat):
            return self._compose(lambda x: self.update(x) + other.update(x))
        if self._is_scalar(other):
            return self._compose(lambda x: self.update(x) + other)
        return NotImplemented

    def __radd__(self, other: typ.Union[pxt.Real, pxt.Integer]):
        return self.__add__(other)

    def __sub__(self, other: typ.Union[pxt.Real, pxt.Integer, "_OnlineStat"]):
        if isinstance(other, _OnlineStat):
            return self._compose(lambda x: self.update(x) - other.update(x))
        if self._is_scalar(other):
            return self._compose(lambda x: self.update(x) - other)
        return NotImplemented

    def __rsub__(self, other: typ.Union[pxt.Real, pxt.Integer]):
        if not self._is_scalar(other):
            return NotImplemented
        return self._compose(lambda x: other - self.update(x))

    def __mul__(self, other: typ.Union[pxt.Real, pxt.Integer, "_OnlineStat"]):
        if isinstance(other, _OnlineStat):
            return self._compose(lambda x: self.update(x) * other.update(x))
        if self._is_scalar(other):
            return self._compose(lambda x: self.update(x) * other)
        return NotImplemented

    def __rmul__(self, other: typ.Union[pxt.Real, pxt.Integer]):
        return self.__mul__(other)

    def __truediv__(self, other: typ.Union[pxt.Real, pxt.Integer, "_OnlineStat"]):
        if isinstance(other, _OnlineStat):

            def ratio(x: pxt.NDArray) -> pxt.NDArray:
                # Each operand must be updated exactly once per sample: calling `other.update` repeatedly would feed
                # it the same sample several times and corrupt its running state.
                num = self.update(x)
                den = other.update(x)
                xp = pxu.get_array_module(x)
                nonzero = den != 0
                # Divide by 1 where den == 0 (no division-by-zero warnings), then mark those entries as NaN. `where`
                # is used instead of boolean item assignment so that 0-d results are supported as well.
                return xp.where(nonzero, num / xp.where(nonzero, den, 1), xp.nan)

            return self._compose(ratio)
        if self._is_scalar(other):
            return self._compose(lambda x: self.update(x) / other)
        return NotImplemented

    def __pow__(self, expo: typ.Union[pxt.Real, pxt.Integer]):
        if not self._is_scalar(expo):
            return NotImplemented
        return self._compose(lambda x: self.update(x) ** expo)
101
102
[docs]
class OnlineMoment(_OnlineStat):
    r"""
    Pointwise online moment.

    For :math:`d \geq 1`, the :math:`d`-th order (raw, i.e. non-centered) moment of the :math:`K` samples
    :math:`(\mathbf{x}_k)_{1 \leq k \leq K}` is given by :math:`\frac{1}{K}\sum_{k=1}^K \mathbf{x}_k^d`.

    Parameters
    ----------
    order: Real
        Order :math:`d` of the moment.
    """

    def __init__(self, order: pxt.Real = 1):
        super().__init__()
        self._order = order

    def update(self, x: pxt.NDArray) -> pxt.NDArray:
        r"""
        Incorporate a new sample into the running moment.

        Parameters
        ----------
        x: NDArray
            New sample.

        Returns
        -------
        stat: NDArray
            Updated moment estimate :math:`\frac{1}{K}\sum_{k=1}^K \mathbf{x}_k^d`.
        """
        powered = x**self._order
        self._num_samples += 1
        # Running mean m_K = m_{K-1} + (y_K - m_{K-1}) / K, starting from m_0 = 0. The arithmetic is out-of-place so
        # that integer-valued samples are promoted to floating point (in-place true division on an integer array would
        # raise a casting error).
        previous = 0 if self._num_samples == 1 else self._stat
        self._stat = previous + (powered - previous) / self._num_samples
        return self._stat
124
125
[docs]
class OnlineCenteredMoment(_OnlineStat):
    r"""
    Pointwise online centered moment.

    For :math:`d \geq 2`, the :math:`d`-th order centered moment of the :math:`K` samples :math:`(\mathbf{x}_k)_{1 \leq
    k \leq K}` is given by :math:`\boldsymbol{\mu}_d=\frac{1}{K} \sum_{k=1}^K (\mathbf{x}_k-\boldsymbol{\mu})^d`, where
    :math:`\boldsymbol{\mu}` is the sample mean.

    Parameters
    ----------
    order: int
        Order :math:`d` of the centered moment; must be an integer :math:`\geq 2`.

    Raises
    ------
    ValueError
        If ``order`` is not an integer greater than or equal to 2.

    Notes
    -----
    This class implements the *Welford algorithm* described in [WelfordAlg]_, which is a numerically stable algorithm
    for computing online centered moments. In particular, it avoids `catastrophic cancellation
    <https://en.wikipedia.org/wiki/Catastrophic_cancellation>`_ issues that may arise when naively computing online
    centered moments, which would lead to a loss of numerical precision.

    Note that this class internally stores the online centered sums of all orders :math:`d'` for :math:`2 \leq d' \leq
    d` in the attribute ``_corrected_sums``, as well as the online mean (``_mean`` attribute). More precisely, the
    array ``_corrected_sums[i]`` corresponds to the online sum :math:`\mathbf{M}_{i+2} = \sum_{k=1}^K
    (\mathbf{x}_k-\boldsymbol{\mu})^{i+2}` for :math:`0 \leq i \leq d-2`, so that :math:`\boldsymbol{\mu}_{i+2} =
    \mathbf{M}_{i+2} / K`.

    Suppose a new sample :math:`\mathbf{x}` is added to :math:`K` samples with mean :math:`\boldsymbol{\mu}`. Let
    :math:`\boldsymbol{\delta} = (\mathbf{x} - \boldsymbol{\mu}) / (K+1)`, so the new mean is
    :math:`\boldsymbol{\mu} + \boldsymbol{\delta}`. Expand :math:`(\mathbf{x}_k - \boldsymbol{\mu} -
    \boldsymbol{\delta})^r` with the binomial theorem, using :math:`\mathbf{M}_0 = K` and :math:`\mathbf{M}_1 = 0`. Then
    add the contribution of the new sample, :math:`(\mathbf{x} - \boldsymbol{\mu} - \boldsymbol{\delta})^r = (K
    \boldsymbol{\delta})^r`. This gives

    .. math::

       \mathbf{M}_r \leftarrow \mathbf{M}_r + \sum_{s=2}^{r-1} \binom{r}{s} \mathbf{M}_s (-\boldsymbol{\delta})^{r-s}
       + K (-\boldsymbol{\delta})^r + (K \boldsymbol{\delta})^r.

    The update is applied for :math:`r = d, \ldots, 2` in decreasing order, so that every :math:`\mathbf{M}_s` on the
    right-hand side still holds its previous value.
    """

    def __init__(self, order: pxt.Real = 2):
        super().__init__()
        # The recursion below loops with `range` and sizes `_corrected_sums` from the order, so it must be an integer;
        # order 1 would leave `_corrected_sums` empty.
        if order < 2 or int(order) != order:
            raise ValueError(f"order must be an integer >= 2, got {order}.")
        self._order = int(order)
        self._corrected_sums = None
        self._mean = None

    def update(self, x: pxt.NDArray) -> pxt.NDArray:
        r"""
        Incorporate a new sample into the running centered moment.

        Parameters
        ----------
        x: NDArray
            New sample.

        Returns
        -------
        stat: NDArray
            Updated estimate of the :math:`d`-th order centered moment.
        """
        n = self._num_samples
        if n == 0:
            xp = pxu.get_array_module(x)
            self._corrected_sums = xp.zeros((self._order - 1,) + x.shape)
            self._mean = x.copy()
        else:
            sums = self._corrected_sums  # alias: the updates below are applied in place
            delta = (x - self._mean) / (n + 1)
            neg_delta = -delta
            # Descending order: M_r needs the *previous* values of M_s for s < r.
            for r in range(self._order, 1, -1):
                for s in range(2, r):  # s = r is M_r itself (already in place); s = 1 vanishes since M_1 = 0
                    sums[r - 2] += math.comb(r, s) * sums[s - 2] * neg_delta ** (r - s)
                sums[r - 2] += n * neg_delta**r  # s = 0 term (M_0 = n)
                sums[r - 2] += (n * delta) ** r  # contribution of the new sample
            # Welford mean update; out-of-place so that integer-valued samples are promoted to floating point.
            self._mean = self._mean + delta
        self._num_samples = n + 1
        self._stat = self._corrected_sums[-1] / self._num_samples
        return self._stat
172
173
[docs]
def OnlineVariance():
    r"""
    Pointwise online variance.

    Given :math:`K` samples :math:`(\mathbf{x}_k)_{1 \leq k \leq K}` with sample mean :math:`\boldsymbol{\mu}`, the
    pointwise online variance is :math:`\boldsymbol{\sigma}^2 = \frac{1}{K}\sum_{k=1}^K
    (\mathbf{x}_k-\boldsymbol{\mu})^2`. This is the second-order centered moment, normalized by :math:`K` (not
    :math:`K-1`).
    """
    # Variance is simply the centered moment of order 2.
    return OnlineCenteredMoment(2)
183
184
[docs]
def OnlineStd():
    r"""
    Pointwise online standard deviation.

    Square root of the pointwise online variance: for :math:`K` samples :math:`(\mathbf{x}_k)_{1 \leq k \leq K}` with
    sample mean :math:`\boldsymbol{\mu}`, it is :math:`\boldsymbol{\sigma}=\sqrt{\frac{1}{K}\sum_{k=1}^K
    (\mathbf{x}_k - \boldsymbol{\mu})^2}`.
    """
    variance = OnlineVariance()
    return variance**0.5
194
195
[docs]
def OnlineSkewness():
    r"""
    Pointwise online skewness.

    For :math:`K` samples :math:`(\mathbf{x}_k)_{1 \leq k \leq K}` with sample mean :math:`\boldsymbol{\mu}` and
    standard deviation :math:`\boldsymbol{\sigma}`, the pointwise online skewness is
    :math:`\frac{1}{K}\sum_{k=1}^K \left( \frac{\mathbf{x}_k-\boldsymbol{\mu}}{\boldsymbol{\sigma}}\right)^3`. It is
    computed as the third-order centered moment divided by the cubed standard deviation.

    `Skewness <https://en.wikipedia.org/wiki/Skewness>`_ quantifies how asymmetric a distribution is around its mean.
    A negative value means a heavier left tail, a positive value means a heavier right tail, and values close to zero
    indicate a symmetric distribution.
    """
    third_moment = OnlineCenteredMoment(order=3)
    std_cubed = OnlineStd() ** 3
    return third_moment / std_cubed
209
210
[docs]
def OnlineKurtosis():
    r"""
    Pointwise online kurtosis.

    The pointwise online kurtosis of the :math:`K` samples :math:`(\mathbf{x}_k)_{1 \leq k \leq K}` is given by
    :math:`\frac{1}{K}\sum_{k=1}^K \left( \frac{\mathbf{x}_k-\boldsymbol{\mu}}{\boldsymbol{\sigma}}\right)^4`, where
    :math:`\boldsymbol{\mu}` is the sample mean and :math:`\boldsymbol{\sigma}` its standard deviation.

    `Kurtosis <https://en.wikipedia.org/wiki/Kurtosis>`_ is a measure of the heavy-tailedness of a distribution. In
    particular, the kurtosis of a Gaussian distribution is always 3. Note that this is the plain (not the *excess*)
    kurtosis: no constant 3 is subtracted.
    """
    # Fourth-order centered moment divided by the fourth power of the standard deviation.
    return OnlineCenteredMoment(order=4) / OnlineStd() ** 4