1r"""
2This sampler module implements state-of-the-art algorithms that generate samples from probability distributions. These
3algorithms are particularly well-suited for high-dimensional distributions such as posterior distributions of inverse
4problems in imaging, which is a notoriously difficult task. The ability to sample from the posterior distribution is
5extremely valuable, as it allows to explore the landscape of the objective function instead of having a single point
6estimate (the maximum a posteriori solution, i.e. the mode of the posterior). This is useful for uncertainty
7quantification (UQ) purposes [UQ_MCMC]_, and it allows compute Monte Carlo estimates of expected values with respect to
8the posterior. For example, the mean of samples from the posterior is an approximation of the minimum mean-square error
9(MMSE) estimator that can be used for image reconstruction. Higher-order pixel-wise statistics (e.g., the variance) can
10also be computed in an online fashion (see :py:mod:`~pyxu.experimental.sampler.statistics`) and provide useful
11diagnostic tools for uncertainty quantification.
12
13In the following example, we showcase the unajusted Langevin algorithm (:py:class:`~pyxu.experimental.sampler.ULA`)
14applied to a total-variation denoising problem. We show the MMSE estimator as well as the pixelwise variance of the
15samples. As expected, the variance is higher around edges than in the smooth regions, indicating that there is higher
16uncertainty in these regions.
17
.. code-block:: python3

    import matplotlib.pyplot as plt
    import numpy as np
    import pyxu.experimental.sampler as pxe_sampler
    import pyxu.operator as pxo
    import pyxu.opt.solver as pxsl
    import skimage as skim

    sh_im = (128,) * 2
    gt = skim.transform.resize(skim.data.shepp_logan_phantom(), sh_im)  # Ground-truth image
    N = np.prod(sh_im)  # Number of pixels

    # Noisy data
    rng = np.random.default_rng(seed=0)
    sigma = 1e-1  # Noise standard deviation
    y = gt + sigma * rng.standard_normal(sh_im)  # Noisy image
    f = 1 / 2 * pxo.SquaredL2Norm(dim_shape=sh_im).argshift(-y) / sigma**2  # Data fidelity loss

    # Smoothed TV regularization
    g = pxo.L21Norm(dim_shape=(2, *sh_im)).moreau_envelope(1e-2) * pxo.Gradient(dim_shape=sh_im)
    theta = 10  # Regularization parameter

    # Compute MAP estimator
    pgd = pxsl.PGD(f=f + theta * g)
    pgd.fit(x0=y)
    im_MAP = pgd.solution()

    fig, ax = plt.subplots(1, 3)
    ax[0].imshow(gt)
    ax[0].set_title("Ground truth")
    ax[0].axis("off")
    ax[1].imshow(y, vmin=0, vmax=1)
    ax[1].set_title("Noisy image")
    ax[1].axis("off")
    ax[2].imshow(im_MAP, vmin=0, vmax=1)
    ax[2].set_title("MAP reconstruction")
    ax[2].axis("off")

    ula = pxe_sampler.ULA(f=f + theta * g)  # ULA sampler

    n = int(1e4)  # Number of samples
    burn_in = int(1e3)  # Number of burn-in iterations
    gen = ula.samples(x0=np.zeros(sh_im), rng=rng)  # Generator for ULA samples
    # Objects for computing online statistics based on samples
    online_mean = pxe_sampler.OnlineMoment(order=1)
    online_var = pxe_sampler.OnlineVariance()

    i = 0  # Number of samples
    for sample in gen:  # Draw ULA sample
        i += 1
        if i > burn_in + n:
            break
        if i > burn_in:
            mean = online_mean.update(sample)  # Update online mean
            var = online_var.update(sample)  # Update online variance

    fig, ax = plt.subplots(1, 2)
    mean_im = ax[0].imshow(mean, vmin=0, vmax=1)
    fig.colorbar(mean_im, fraction=0.05, ax=ax[0])
    ax[0].set_title("Mean (MMSE estimator)")
    ax[0].axis("off")
    var_im = ax[1].imshow(var)
    fig.colorbar(var_im, fraction=0.05, ax=ax[1])
    ax[1].set_title("Variance")
    ax[1].axis("off")
    fig.suptitle("Pixel-wise statistics of ULA samples")
"""

import collections.abc as cabc
import math

import pyxu.abc as pxa
import pyxu.info.ptype as pxt
import pyxu.operator as pxo
import pyxu.util as pxu

__all__ = [
    "_Sampler",
    "ULA",
    "MYULA",
]


class _Sampler:
    """Abstract base class for samplers."""

    def samples(self, rng=None, **kwargs) -> cabc.Generator:
        """Returns a generator; samples are drawn by calling next(generator)."""
        self._sample_init(rng, **kwargs)

        def _generator():
            while True:
                yield self._sample()

        return _generator()

    def _sample_init(self, rng, **kwargs):
        """Optional method to set initial state of the sampler (e.g., a starting point)."""
        pass

    def _sample(self) -> pxt.NDArray:
        """Method to be implemented by subclasses that returns the next sample."""
        raise NotImplementedError


class ULA(_Sampler):
    r"""
    Unadjusted Langevin algorithm (ULA).

    Generates samples from the distribution

    .. math::

        p(\mathbf{x})
        =
        \frac{\exp(-\mathcal{F}(\mathbf{x}))}{\int_{\mathbb{R}^N} \exp(-\mathcal{F}(\tilde{\mathbf{x}}))
        \mathrm{d} \tilde{\mathbf{x}} },

    where :math:`\mathcal{F}: \mathbb{R}^N \to \mathbb{R}` is *differentiable* with :math:`\beta`-*Lipschitz continuous*
    gradient.

    Notes
    -----
    ULA is a Markov chain Monte Carlo (MCMC) method that derives from the discretization of overdamped Langevin
    diffusions. More specifically, it relies on the Langevin stochastic differential equation (SDE):

    .. math::

        \mathrm{d} \mathbf{X}_t
        =
        - \nabla \mathcal{F}(\mathbf{X}_t) \mathrm{d}t + \sqrt{2} \mathrm{d} \mathbf{B}_t,

    where :math:`(\mathbf{B}_t)_{t \geq 0}` is an :math:`N`-dimensional Brownian motion. It is well known that under
    mild technical assumptions, this SDE has a unique strong solution whose invariant distribution is
    :math:`p(\mathbf{x}) \propto \exp(-\mathcal{F}(\mathbf{x}))`. The discrete-time Euler-Maruyama discretization of
    this SDE then yields the ULA Markov chain

    .. math::

        \mathbf{X}_{k+1} = \mathbf{X}_{k} - \gamma \nabla \mathcal{F}(\mathbf{X}_k) + \sqrt{2 \gamma} \mathbf{Z}_{k+1}

    for all :math:`k \in \mathbb{Z}`, where :math:`\gamma` is the discretization step size and :math:`(\mathbf{Z}_k)_{k
    \in \mathbb{Z}}` is a sequence of independent and identically distributed :math:`N`-dimensional standard Gaussian
    random variables. When :math:`\mathcal{F}` is differentiable with :math:`\beta`-Lipschitz continuous gradient and
    :math:`\gamma \leq \frac{1}{\beta}`, the ULA Markov chain converges (see [ULA]_) to a unique stationary distribution
    :math:`p_\gamma` such that

    .. math::

        \lim_{\gamma \to 0} \Vert p_\gamma - p \Vert_{\mathrm{TV}} = 0.

    The discretization step :math:`\gamma` is subject to a bias-variance tradeoff: a larger step leads to faster
    convergence of the Markov chain at the expense of a larger bias in the approximation of the distribution :math:`p`.
    Setting :math:`\gamma` as large as possible (default behavior) is recommended for large-scale problems, since
    convergence speed (rather than approximation bias) is then typically the main bottleneck. See the `Example` section
    below for a concrete illustration of this tradeoff.
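
    In plain NumPy, a single ULA update (the recursion above) can be sketched as follows; ``grad_F`` is a placeholder
    for :math:`\nabla \mathcal{F}`, and this snippet is for illustration only, not the implementation used by this
    class.

    .. code-block:: python3

        import numpy as np

        rng = np.random.default_rng()

        def ula_step(x, grad_F, gamma):
            # One Euler-Maruyama step of the Langevin SDE.
            return x - gamma * grad_F(x) + np.sqrt(2 * gamma) * rng.standard_normal(x.shape)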

    Remarks
    -------
    Like all MCMC sampling methods, ULA comes with the following challenges:

    * The first few samples of the chain may not be adequate for computing statistics, as they might be located in
      low-probability regions. This challenge can be alleviated either by selecting a representative starting point
      for the chain, or by having a `burn-in` phase where the first few samples are discarded.

    * Consecutive samples are typically correlated, which can deteriorate the Monte-Carlo estimation of quantities of
      interest. This issue can be alleviated by `thinning` the chain, i.e., keeping only every :math:`k`-th sample (see
      the sketch below), at the expense of an increased computational cost. Useful diagnostic tools to quantify this
      correlation between samples include the pixel-wise autocorrelation function and the `effective sample size
      <https://mc-stan.org/docs/reference-manual/effective-sample-size.html>`_.
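
    A minimal sketch combining both remedies with the ``samples()`` generator (the burn-in length and thinning factor
    below are arbitrary, and ``x0`` denotes any valid starting point):

    .. code-block:: python3

        gen = ula.samples(x0=x0)  # `ula` is a ULA instance
        for _ in range(1_000):  # Burn-in phase: discard the first samples
            next(gen)
        thinned = []
        for i in range(10_000):
            sample = next(gen)
            if i % 10 == 0:  # Thinning: keep every 10th sample to reduce correlation
                thinned.append(sample)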

    Example
    -------
    We illustrate ULA on a 1D example (:math:`N = 1`) where :math:`\mathcal{F}(x) = \frac{x^2}{2}`; the target
    distribution :math:`p(x)` is thus the 1D standard Gaussian. In this toy example, the biased distribution
    :math:`p_\gamma(x)` can be computed in closed form. The ULA Markov chain is given by

    .. math::

        \mathbf{X}_{k+1} &= \mathbf{X}_{k} - \gamma \nabla\mathcal{F}(\mathbf{X}_k) + \sqrt{2\gamma}\mathbf{Z}_{k+1} \\
        &= \mathbf{X}_{k} (1 - \gamma) + \sqrt{2 \gamma} \mathbf{Z}_{k+1}.

    Assuming for simplicity that :math:`\mathbf{X}_0` is Gaussian with mean :math:`\mu_0` and variance
    :math:`\sigma_0^2`, :math:`\mathbf{X}_k` is Gaussian for any :math:`k \in \mathbb{Z}` as a linear combination of
    Gaussians. Taking the expected value of the recurrence relation yields

    .. math::

        \mu_k := \mathbb{E}(\mathbf{X}_{k}) = \mathbb{E}(\mathbf{X}_{k-1}) (1 - \gamma) = \mu_0 (1 - \gamma)^k

    (geometric sequence). Taking the expected value of the square of the recurrence relation yields

    .. math::

        \mu^{(2)}_k := \mathbb{E}(\mathbf{X}_{k}^2) = \mathbb{E}(\mathbf{X}_{k-1}^2) (1 - \gamma)^2 + 2 \gamma =
        (1 - \gamma)^{2k} (\mu_0^2 + \sigma_0^2 - b) + b

    with :math:`b = \frac{2 \gamma}{1 - (1 - \gamma)^{2}} = \frac{1}{1-\frac{\gamma}{2}}` (arithmetico-geometric
    sequence) due to the independence of :math:`\mathbf{X}_{k-1}` and :math:`\mathbf{Z}_{k}`. Hence,
    :math:`p_\gamma(x)` is a Gaussian with mean :math:`\mu_\gamma = \lim_{k \to \infty} \mu_k = 0` and variance
    :math:`\sigma_\gamma^2 = \lim_{k \to \infty} \mu^{(2)}_k - \mu_k^2 = \frac{1}{1-\frac{\gamma}{2}}`. As expected, we
    have :math:`\lim_{\gamma \to 0} \sigma_\gamma^2 = 1`, which is the variance of the target distribution :math:`p(x)`.
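
    The closed-form variance can be checked numerically by iterating the second-moment recursion directly (a quick
    sanity check, independent of the sampler API):

    .. code-block:: python3

        gamma, var = 0.1, 0.0  # Step size and initial second moment (X_0 = 0, hence zero mean and variance)
        for _ in range(1_000):
            var = var * (1 - gamma) ** 2 + 2 * gamma  # Second-moment recursion
        print(var, 1 / (1 - gamma / 2))  # Both are approximately 1.0526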

    We plot the distribution of the samples of ULA for one large (:math:`\gamma_1 \approx 1`, i.e.
    :math:`\sigma_{\gamma_1}^2 \approx 2`) and one small (:math:`\gamma_2 = 0.1`, i.e. :math:`\sigma_{\gamma_2}^2
    \approx 1.05`) step size. As expected, the larger step size :math:`\gamma_1` leads to a larger bias in the
    approximation of :math:`p(x)`. To quantify the speed of convergence of the Markov chains, we compute the
    `Cramér-von Mises <https://en.wikipedia.org/wiki/Cram%C3%A9r%E2%80%93von_Mises_criterion>`_ tests of goodness of fit
    of the empirical distributions to the stationary distributions of ULA :math:`p_{\gamma_1}(x)` and
    :math:`p_{\gamma_2}(x)`. We observe that the larger step :math:`\gamma_1` leads to a better fit (lower Cramér-von
    Mises criterion), which illustrates the aforementioned bias-variance tradeoff for the choice of the step size.

    .. plot::

        import matplotlib.pyplot as plt
        import numpy as np
        import pyxu.experimental.sampler as pxe_sampler
        import pyxu.operator as pxo
        import scipy as sp

        f = pxo.SquaredL2Norm(dim_shape=1) / 2  # To sample 1D normal distribution (mean 0, variance 1)
        ula = pxe_sampler.ULA(f=f)  # Sampler with maximum step size
        ula_lb = pxe_sampler.ULA(f=f, gamma=1e-1)  # Sampler with small step size

        gen_ula = ula.samples(x0=np.zeros(1))
        gen_ula_lb = ula_lb.samples(x0=np.zeros(1))
        n_burn_in = int(1e3)  # Number of burn-in iterations
        for i in range(n_burn_in):
            next(gen_ula)
            next(gen_ula_lb)

        # Online statistics objects
        mean_ula = pxe_sampler.OnlineMoment(order=1)
        mean_ula_lb = pxe_sampler.OnlineMoment(order=1)
        var_ula = pxe_sampler.OnlineVariance()
        var_ula_lb = pxe_sampler.OnlineVariance()

        n = int(1e4)  # Number of samples
        samples_ula = np.zeros(n)
        samples_ula_lb = np.zeros(n)
        for i in range(n):
            sample = next(gen_ula)
            sample_lb = next(gen_ula_lb)
            samples_ula[i] = sample
            samples_ula_lb[i] = sample_lb
            mean = float(mean_ula.update(sample))
            var = float(var_ula.update(sample))
            mean_lb = float(mean_ula_lb.update(sample_lb))
            var_lb = float(var_ula_lb.update(sample_lb))

        # Theoretical variances of biased stationary distributions of ULA
        biased_var = 1 / (1 - ula._gamma / 2)
        biased_var_lb = 1 / (1 - ula_lb._gamma / 2)

        # Quantify goodness of fit of empirical distribution with theoretical distribution (Cramér-von Mises test)
        cvm = sp.stats.cramervonmises(samples_ula, "norm", args=(0, np.sqrt(biased_var)))
        cvm_lb = sp.stats.cramervonmises(samples_ula_lb, "norm", args=(0, np.sqrt(biased_var_lb)))

        # Plots
        grid = np.linspace(-4, 4, 1000)

        plt.figure()
        plt.title(
            f"ULA samples (large step size) \n Empirical mean: {mean:.3f} (theoretical: 0) \n "
            f"Empirical variance: {var:.3f} (theoretical: {biased_var:.3f}) \n"
            f"Cramér-von Mises goodness of fit: {cvm.statistic:.3f}"
        )
        plt.hist(samples_ula, range=(min(grid), max(grid)), bins=100, density=True)
        plt.plot(grid, sp.stats.norm.pdf(grid), label=r"$p(x)$")
        plt.plot(grid, sp.stats.norm.pdf(grid, scale=np.sqrt(biased_var)), label=r"$p_{\gamma_1}(x)$")
        plt.legend()
        plt.show()

        plt.figure()
        plt.title(
            f"ULA samples (small step size) \n Empirical mean: {mean_lb:.3f} (theoretical: 0) \n "
            f"Empirical variance: {var_lb:.3f} (theoretical: {biased_var_lb:.3f}) \n"
            f"Cramér-von Mises goodness of fit: {cvm_lb.statistic:.3f}"
        )
        plt.hist(samples_ula_lb, range=(min(grid), max(grid)), bins=100, density=True)
        plt.plot(grid, sp.stats.norm.pdf(grid), label=r"$p(x)$")
        plt.plot(grid, sp.stats.norm.pdf(grid, scale=np.sqrt(biased_var_lb)), label=r"$p_{\gamma_2}(x)$")
        plt.legend()
        plt.show()
    """

    def __init__(self, f: pxa.DiffFunc, gamma: pxt.Real = None):
        r"""
        Parameters
        ----------
        f: :py:class:`~pyxu.abc.DiffFunc`
            Differentiable functional.
        gamma: Real
            Euler-Maruyama discretization step of the Langevin equation (see `Notes`).
        """
        self._f = f
        self._beta = f.diff_lipschitz
        self._gamma = self._set_gamma(gamma)
        self._rng = None
        self.x = None

    def _sample_init(self, rng, x0: pxt.NDArray):
        r"""
        Parameters
        ----------
        rng:
            Internal random generator.
        x0: NDArray
            Starting point of the Markov chain.
        """
        self.x = x0.copy()
        if rng is None:
            xp = pxu.get_array_module(x0)
            self._rng = xp.random.default_rng(None)
        else:
            self._rng = rng

    def _sample(self) -> pxt.NDArray:
        x = self.x.copy()
        x -= self._gamma * pxu.copy_if_unsafe(self._f.grad(self.x))
        x += math.sqrt(2 * self._gamma) * self._rng.standard_normal(size=self.x.shape, dtype=self.x.dtype)
        self.x = x
        return x

    def objective_func(self) -> pxt.Real:
        r"""
        Negative logarithm of the target distribution (up to an additive constant), evaluated at the current state of
        the Markov chain.

        Useful for diagnostic purposes to monitor whether the Markov chain is sufficiently warm-started. If so, the
        samples should accumulate around the modes of the target distribution, i.e., toward the minimum of
        :math:`\mathcal{F}`.
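
        A minimal monitoring sketch (``ula`` denotes a :py:class:`~pyxu.experimental.sampler.ULA` instance and ``gen``
        a generator obtained from its ``samples()`` method; the number of iterations is arbitrary):

        .. code-block:: python3

            obj = []
            for _ in range(1000):  # Burn-in iterations
                next(gen)  # Advance the chain by one sample
                obj.append(float(ula.objective_func()))  # Should stabilize once the chain is warm-started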
        """
        return pxu.copy_if_unsafe(self._f.apply(self.x))

    def _set_gamma(self, gamma: pxt.Real = None) -> pxt.Real:
        if gamma is None:
            if math.isfinite(self._beta):
                return 0.98 / self._beta
            else:
                msg = "If f has unbounded Lipschitz gradient, the gamma parameter must be provided."
                raise ValueError(msg)
        else:
            try:
                assert gamma > 0
            except Exception:
                raise ValueError(f"gamma must be positive, got {gamma}.")
            return gamma


class MYULA(ULA):
    r"""
    Moreau-Yosida unadjusted Langevin algorithm (MYULA).

    Generates samples from the distribution

    .. math::

        p(\mathbf{x}) = \frac{\exp(-\mathcal{F}(\mathbf{x}) - \mathcal{G}(\mathbf{x}))}{\int_{\mathbb{R}^N}
        \exp(-\mathcal{F}(\tilde{\mathbf{x}}) - \mathcal{G}(\tilde{\mathbf{x}})) \mathrm{d} \tilde{\mathbf{x}} },

    where :math:`\mathcal{F}: \mathbb{R}^N \to \mathbb{R}` is *convex* and *differentiable* with :math:`\beta`-
    *Lipschitz continuous* gradient, and :math:`\mathcal{G}: \mathbb{R}^N \to \mathbb{R}` is *proper*, *lower semi-
    continuous* and *convex* with *simple proximal operator*.

    Notes
    -----
    MYULA is an extension of :py:class:`~pyxu.experimental.sampler.ULA` to sample from distributions whose logarithm is
    nonsmooth. It consists in applying ULA to the differentiable functional :math:`\mathcal{U}^\lambda = \mathcal{F} +
    \mathcal{G}^\lambda` for some :math:`\lambda > 0`, where

    .. math::

        \mathcal{G}^\lambda (\mathbf{x}) = \inf_{\tilde{\mathbf{x}} \in \mathbb{R}^N} \frac{1}{2 \lambda} \Vert
        \tilde{\mathbf{x}} - \mathbf{x} \Vert_2^2 + \mathcal{G}(\tilde{\mathbf{x}})

    is the Moreau-Yosida envelope of :math:`\mathcal{G}` with parameter :math:`\lambda`. We then have

    .. math::

        \nabla \mathcal{U}^\lambda (\mathbf{x}) = \nabla \mathcal{F}(\mathbf{x}) + \frac{1}{\lambda} (\mathbf{x} -
        \mathrm{prox}_{\lambda \mathcal{G}}(\mathbf{x})),

    hence :math:`\nabla \mathcal{U}^\lambda` is :math:`(\beta + \frac{1}{\lambda})`-Lipschitz continuous, where
    :math:`\beta` is the Lipschitz constant of :math:`\nabla \mathcal{F}`. Note that the target distribution of the
    underlying ULA Markov chain is not exactly :math:`p(\mathbf{x})`, but the distribution

    .. math::

        p^\lambda(\mathbf{x}) \propto \exp(-\mathcal{F}(\mathbf{x})-\mathcal{G}^\lambda(\mathbf{x})),

    which introduces some additional bias on top of the bias of ULA related to the step size :math:`\gamma` (see
    `Notes` of the :py:class:`~pyxu.experimental.sampler.ULA` documentation). MYULA is guaranteed to converge when
    :math:`\gamma \leq \frac{1}{\beta + \frac{1}{\lambda}}`, in which case it converges toward the stationary
    distribution :math:`p^\lambda_\gamma(\mathbf{x})` that satisfies

    .. math::

        \lim_{\gamma, \lambda \to 0} \Vert p^\lambda_\gamma - p \Vert_{\mathrm{TV}} = 0

    (see [MYULA]_). The parameter :math:`\lambda` is subject to a similar bias-variance tradeoff as :math:`\gamma`. It
    is recommended to set it on the order of :math:`\frac{1}{\beta}`, so that the contributions of :math:`\mathcal{F}`
    and :math:`\mathcal{G}^\lambda` to the Lipschitz constant of :math:`\nabla \mathcal{U}^\lambda` are well balanced.
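
    An illustrative usage sketch: we draw samples from the 1D distribution
    :math:`p(x) \propto \exp(-10^{-2} x^{2} - \vert x \vert)`, whose nonsmooth part is handled through its proximal
    operator. The chosen functionals, :math:`\lambda`, and number of samples are arbitrary.

    .. code-block:: python3

        import numpy as np
        import pyxu.experimental.sampler as pxe_sampler
        import pyxu.operator as pxo

        f = 1e-2 * pxo.SquaredL2Norm(dim_shape=1)  # Smooth (differentiable) term
        g = pxo.L1Norm(dim_shape=1)  # Nonsmooth (proximable) term
        myula = pxe_sampler.MYULA(f=f, g=g, lamb=1e-1)  # Moreau-Yosida parameter lambda = 0.1

        gen = myula.samples(x0=np.zeros(1))  # Generator of MYULA samples
        samples = np.stack([next(gen) for _ in range(10_000)])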
    """

    def __init__(
        self,
        f: pxa.DiffFunc = None,
        g: pxa.ProxFunc = None,
        gamma: pxt.Real = None,
        lamb: pxt.Real = None,
    ):
        r"""
        Parameters
        ----------
        f: :py:class:`~pyxu.abc.DiffFunc`, None
            Differentiable functional.
        g: :py:class:`~pyxu.abc.ProxFunc`, None
            Proximable functional.
        gamma: Real
            Euler-Maruyama discretization step of the Langevin equation (see `Notes` of the
            :py:class:`~pyxu.experimental.sampler.ULA` documentation).
        lamb: Real
            Moreau-Yosida envelope parameter for `g`.
        """
        dim_shape = None
        if f is not None:
            dim_shape = f.dim_shape
        if g is not None:
            if dim_shape is None:
                dim_shape = g.dim_shape
            else:
                assert g.dim_shape == dim_shape
        if dim_shape is None:
            raise ValueError("At least one of f and g must be provided.")

        self._f_diff = pxo.NullFunc(dim_shape=dim_shape) if (f is None) else f
        self._g = pxo.NullFunc(dim_shape=dim_shape) if (g is None) else g

        self._lambda = self._set_lambda(lamb)
        f = self._f_diff + self._g.moreau_envelope(self._lambda)
        f.diff_lipschitz = f.estimate_diff_lipschitz()
        super().__init__(f, gamma)

    def _set_lambda(self, lamb: pxt.Real = None) -> pxt.Real:
        if lamb is None:
            if self._g._name == "NullFunc":
                return 1.0  # Lambda is irrelevant if g is a NullFunc, but it must be positive
            elif math.isfinite(dl := self._f_diff.diff_lipschitz):
                return 2.0 if dl == 0 else min(2.0, 1.0 / dl)
            else:
                msg = "If f has unbounded Lipschitz gradient, the lambda parameter must be provided."
                raise ValueError(msg)
        else:
            return lamb