Source code for pyxu.experimental.sampler._sampler

r"""
This sampler module implements state-of-the-art algorithms that generate samples from probability distributions.  These
algorithms are particularly well-suited for high-dimensional distributions such as posterior distributions of inverse
problems in imaging, for which sampling is a notoriously difficult task.  The ability to sample from the posterior
distribution is extremely valuable, as it allows one to explore the landscape of the objective function instead of
having a single point estimate (the maximum a posteriori solution, i.e. the mode of the posterior).  This is useful for
uncertainty quantification (UQ) purposes [UQ_MCMC]_, and it allows computing Monte Carlo estimates of expected values
with respect to the posterior.  For example, the mean of samples from the posterior is an approximation of the minimum
mean-square error (MMSE) estimator that can be used for image reconstruction.  Higher-order pixel-wise statistics
(e.g., the variance) can also be computed in an online fashion (see :py:mod:`~pyxu.experimental.sampler.statistics`)
and provide useful diagnostic tools for uncertainty quantification.

In the following example, we showcase the unadjusted Langevin algorithm (:py:class:`~pyxu.experimental.sampler.ULA`)
applied to a total-variation denoising problem.  We show the MMSE estimator as well as the pixel-wise variance of the
samples.  As expected, the variance is higher around edges than in the smooth regions, indicating that there is higher
uncertainty in these regions.

.. code-block:: python3

   import matplotlib.pyplot as plt
   import numpy as np
   import pyxu.experimental.sampler as pxe_sampler
   import pyxu.operator as pxo
   import pyxu.opt.solver as pxsl
   import skimage as skim

   sh_im = (128,) * 2
   gt = skim.transform.resize(skim.data.shepp_logan_phantom(), sh_im)  # Ground-truth image
   N = np.prod(sh_im)  # Number of pixels

   # Noisy data
   rng = np.random.default_rng(seed=0)
   sigma = 1e-1  # Noise standard deviation
   y = gt + sigma * rng.standard_normal(sh_im)  # Noisy image
   f = 1 / 2 * pxo.SquaredL2Norm(dim_shape=sh_im).argshift(-y) / sigma**2  # Data fidelity loss

   # Smoothed TV regularization
   g = pxo.L21Norm(dim_shape=(2, *sh_im)).moreau_envelope(1e-2) * pxo.Gradient(dim_shape=sh_im)
   theta = 10  # Regularization parameter

   # Compute MAP estimator
   pgd = pxsl.PGD(f=f + theta * g)
   pgd.fit(x0=y)
   im_MAP = pgd.solution()

   fig, ax = plt.subplots(1, 3)
   ax[0].imshow(gt)
   ax[0].set_title("Ground truth")
   ax[0].axis("off")
   ax[1].imshow(y, vmin=0, vmax=1)
   ax[1].set_title("Noisy image")
   ax[1].axis("off")
   ax[2].imshow(im_MAP, vmin=0, vmax=1)
   ax[2].set_title("MAP reconstruction")
   ax[2].axis("off")

   ula = pxe_sampler.ULA(f=f + theta * g)  # ULA sampler

   n = int(1e4)  # Number of samples
   burn_in = int(1e3)  # Number of burn-in iterations
   gen = ula.samples(x0=np.zeros(sh_im), rng=rng)  # Generator for ULA samples
   # Objects for computing online statistics based on samples
   online_mean = pxe_sampler.OnlineMoment(order=1)
   online_var = pxe_sampler.OnlineVariance()

   i = 0  # Sample counter
   for sample in gen:  # Draw ULA sample
       i += 1
       if i > burn_in + n:
           break
       if i > burn_in:
           mean = online_mean.update(sample)  # Update online mean
           var = online_var.update(sample)  # Update online variance

   fig, ax = plt.subplots(1, 2)
   mean_im = ax[0].imshow(mean, vmin=0, vmax=1)
   fig.colorbar(mean_im, fraction=0.05, ax=ax[0])
   ax[0].set_title("Mean (MMSE estimator)")
   ax[0].axis("off")
   var_im = ax[1].imshow(var)
   fig.colorbar(var_im, fraction=0.05, ax=ax[1])
   ax[1].set_title("Variance")
   ax[1].axis("off")
   fig.suptitle("Pixel-wise statistics of ULA samples")
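
As mentioned in the :py:class:`~pyxu.experimental.sampler.ULA` documentation, consecutive samples of the chain are
correlated.  A simple mitigation is to `thin` the chain, i.e., update the online statistics only every few samples.  A
minimal variation of the sampling loop above, replacing it (the thinning period ``thin = 10`` is an arbitrary
illustrative choice; fewer samples then contribute to the statistics unless the chain is run longer):

.. code-block:: python3

   thin = 10  # Thinning period (illustrative value)
   i = 0  # Sample counter
   for sample in gen:  # Draw ULA sample
       i += 1
       if i > burn_in + n:
           break
       if (i > burn_in) and (i % thin == 0):  # Discard burn-in samples and thin the chain
           mean = online_mean.update(sample)  # Update online mean
           var = online_var.update(sample)  # Update online variance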
"""

import collections.abc as cabc
import math

import pyxu.abc as pxa
import pyxu.info.ptype as pxt
import pyxu.operator as pxo
import pyxu.util as pxu

__all__ = [
    "_Sampler",
    "ULA",
    "MYULA",
]


class _Sampler:
    """Abstract base class for samplers."""

    def samples(self, rng=None, **kwargs) -> cabc.Generator:
        """Returns a generator; samples are drawn by calling next(generator)."""
        self._sample_init(rng, **kwargs)

        def _generator():
            while True:
                yield self._sample()

        return _generator()

    def _sample_init(self, rng, **kwargs):
        """Optional method to set initial state of the sampler (e.g., a starting point)."""
        pass

    def _sample(self) -> pxt.NDArray:
        """Method to be implemented by subclasses that returns the next sample."""
        raise NotImplementedError


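# --------------------------------------------------------------------------------------------------
# Illustrative sketch (not part of the library API): a minimal `_Sampler` subclass showing the
# expected contract.  `_sample_init` stores the chain state and random generator, and `_sample`
# returns the next state.  The Gaussian random walk below is a hypothetical example; concrete
# samplers such as `ULA` follow the same pattern with more elaborate updates.
# --------------------------------------------------------------------------------------------------
class _ExampleRandomWalk(_Sampler):
    """Toy Gaussian random-walk sampler (for illustration only)."""

    def _sample_init(self, rng, x0: pxt.NDArray):
        self.x = x0.copy()
        if rng is None:
            xp = pxu.get_array_module(x0)
            self._rng = xp.random.default_rng(None)
        else:
            self._rng = rng

    def _sample(self) -> pxt.NDArray:
        # Random-walk update: add a standard Gaussian perturbation to the current state.
        self.x = self.x + self._rng.standard_normal(size=self.x.shape, dtype=self.x.dtype)
        return self.x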


class ULA(_Sampler):
    r"""
    Unadjusted Langevin algorithm (ULA).

    Generates samples from the distribution

    .. math::

       p(\mathbf{x})
       =
       \frac{\exp(-\mathcal{F}(\mathbf{x}))}{\int_{\mathbb{R}^N} \exp(-\mathcal{F}(\tilde{\mathbf{x}}))
       \mathrm{d} \tilde{\mathbf{x}} },

    where :math:`\mathcal{F}: \mathbb{R}^N \to \mathbb{R}` is *differentiable* with :math:`\beta`-*Lipschitz continuous*
    gradient.

    Notes
    -----
    ULA is a Markov chain Monte Carlo (MCMC) method that derives from the discretization of overdamped Langevin
    diffusions.  More specifically, it relies on the Langevin stochastic differential equation (SDE):

    .. math::

       \mathrm{d} \mathbf{X}_t
       =
       - \nabla \mathcal{F}(\mathbf{X}_t) \mathrm{d}t + \sqrt{2} \mathrm{d} \mathbf{B}_t,

    where :math:`(\mathbf{B}_t)_{t \geq 0}` is a :math:`N`-dimensional Brownian motion.  It is well known that under
    mild technical assumptions, this SDE has a unique strong solution whose invariant distribution is
    :math:`p(\mathbf{x}) \propto \exp(-\mathcal{F}(\mathbf{x}))`.  The discrete-time Euler-Maruyama discretization of
    this SDE then yields the ULA Markov chain

    .. math::

       \mathbf{X}_{k+1} = \mathbf{X}_{k} - \gamma \nabla \mathcal{F}(\mathbf{X}_k) + \sqrt{2 \gamma} \mathbf{Z}_{k+1}

    for all :math:`k \in \mathbb{Z}`, where :math:`\gamma` is the discretization step size and :math:`(\mathbf{Z}_k)_{k
    \in \mathbb{Z}}` is a sequence of independent and identically distributed :math:`N`-dimensional standard Gaussian
    random variables.  When :math:`\mathcal{F}` is differentiable with :math:`\beta`-Lipschitz continuous gradient and
    :math:`\gamma \leq \frac{1}{\beta}`, the ULA Markov chain converges (see [ULA]_) to a unique stationary distribution
    :math:`p_\gamma` such that

    .. math::

       \lim_{\gamma \to 0} \Vert p_\gamma - p \Vert_{\mathrm{TV}} = 0.

    The discretization step :math:`\gamma` is subject to a bias-variance tradeoff: a larger step leads to faster
    convergence of the Markov chain at the expense of a larger bias in the approximation of the distribution :math:`p`.
    Setting :math:`\gamma` as large as possible (default behavior) is recommended for large-scale problems, since
    convergence speed (rather than approximation bias) is then typically the main bottleneck.  See the `Example` section
    below for a concrete illustration of this tradeoff.

    Remarks
    -------
    Like all MCMC sampling methods, ULA comes with the following challenges:

    * The first few samples of the chain may not be adequate for computing statistics, as they might be located in low
      probability regions.  This challenge can be alleviated either by selecting a representative starting point for the
      chain, or by having a `burn-in` phase where the first few samples are discarded (see the sketch below).

    * Consecutive samples are typically correlated, which can deteriorate the Monte-Carlo estimation of quantities of
      interest.  This issue can be alleviated by `thinning` the chain, i.e., keeping only every :math:`k`-th sample, at
      the expense of an increased computational cost.  Useful diagnostic tools to quantify this correlation between
      samples include the pixel-wise autocorrelation function and the `effective sample size
      <https://mc-stan.org/docs/reference-manual/effective-sample-size.html>`_.
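
    A practical way to assess the length of the burn-in phase is to monitor the negative log-density of the chain via
    ``objective_func()``, which should stabilize once the chain has reached high-probability regions.  A minimal sketch
    of such a check (assuming a sampler ``ula``, its generator ``gen``, and a burn-in length ``burn_in`` defined as in
    the module-level example):

    .. code-block:: python3

       obj = []
       for _ in range(burn_in):
           next(gen)  # Draw ULA sample
           obj.append(ula.objective_func())  # Negative log-density at the current state
       # Plot or inspect `obj`; a flat tail suggests the chain is sufficiently warm-started.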

    Example
    -------
    We illustrate ULA on a 1D example (:math:`N = 1`) where :math:`\mathcal{F}(x) = \frac{x^2}{2}`; the target
    distribution :math:`p(x)` is thus the 1D standard Gaussian.  In this toy example, the biased distribution
    :math:`p_\gamma(x)` can be computed in closed form.  The ULA Markov chain is given by

    .. math::

       \mathbf{X}_{k+1} &= \mathbf{X}_{k} - \gamma \nabla\mathcal{F}(\mathbf{X}_k) + \sqrt{2\gamma}\mathbf{Z}_{k+1} \\
       &= \mathbf{X}_{k} (1 - \gamma) + \sqrt{2 \gamma} \mathbf{Z}_{k+1}.

    Assuming for simplicity that :math:`\mathbf{X}_0` is Gaussian with mean :math:`\mu_0` and variance
    :math:`\sigma_0^2`, :math:`\mathbf{X}_k` is Gaussian for any :math:`k \in \mathbb{Z}` as a linear combination of
    Gaussians.  Taking the expected value of the recurrence relation yields

    .. math::

       \mu_k := \mathbb{E}(\mathbf{X}_{k}) = \mathbb{E}(\mathbf{X}_{k-1}) (1 - \gamma) = \mu_0 (1 - \gamma)^k

    (geometric sequence).  Taking the expected value of the square of the recurrence relation yields

    .. math::

       \mu^{(2)}_k := \mathbb{E}(\mathbf{X}_{k}^2) = \mathbb{E}(\mathbf{X}_{k-1}^2) (1 - \gamma)^2 + 2 \gamma =
       (1 - \gamma)^{2k} (\sigma_0^2 - b) + b

    with :math:`b = \frac{2 \gamma}{1 - (1 - \gamma)^{2}} = \frac{1}{1-\frac{\gamma}{2}}` (arithmetico-geometric
    sequence) due to the independence of :math:`\mathbf{X}_{k-1}` and :math:`\mathbf{Z}_{k}`.  Hence,
    :math:`p_\gamma(x)` is a Gaussian with mean :math:`\mu_\gamma = \lim_{k \to \infty} \mu_k = 0` and variance
    :math:`\sigma_\gamma^2 = \lim_{k \to \infty} \mu^{(2)}_k - \mu_k^2 = \frac{1}{1-\frac{\gamma}{2}}`.  As expected, we
    have :math:`\lim_{\gamma \to 0} \sigma_\gamma^2 = 1`, which is the variance of the target distribution :math:`p(x)`.

    We plot the distribution of the samples of ULA for one large (:math:`\gamma_1 \approx 1`, i.e.
    :math:`\sigma_{\gamma_1}^2 \approx 2`) and one small (:math:`\gamma_2 = 0.1`, i.e. :math:`\sigma_{\gamma_2}^2
    \approx 1.05`) step size.  As expected, the larger step size :math:`\gamma_1` leads to a larger bias in the
    approximation of :math:`p(x)`.  To quantify the speed of convergence of the Markov chains, we compute the
    `Cramér-von Mises <https://en.wikipedia.org/wiki/Cram%C3%A9r%E2%80%93von_Mises_criterion>`_ tests of goodness of fit
    of the empirical distributions to the stationary distributions of ULA :math:`p_{\gamma_1}(x)` and
    :math:`p_{\gamma_2}(x)`.  We observe that the larger step :math:`\gamma_1` leads to a better fit (lower Cramér-von
    Mises criterion), which illustrates the aforementioned bias-variance tradeoff for the choice of the step size.

    .. plot::

       import matplotlib.pyplot as plt
       import numpy as np
       import pyxu.experimental.sampler as pxe_sampler
       import pyxu.operator as pxo
       import scipy as sp

       f = pxo.SquaredL2Norm(dim_shape=1) / 2  # To sample 1D normal distribution (mean 0, variance 1)
       ula = pxe_sampler.ULA(f=f)  # Sampler with maximum step size
       ula_lb = pxe_sampler.ULA(f=f, gamma=1e-1)  # Sampler with small step size

       gen_ula = ula.samples(x0=np.zeros(1))
       gen_ula_lb = ula_lb.samples(x0=np.zeros(1))
       n_burn_in = int(1e3)  # Number of burn-in iterations
       for i in range(n_burn_in):
           next(gen_ula)
           next(gen_ula_lb)

       # Online statistics objects
       mean_ula = pxe_sampler.OnlineMoment(order=1)
       mean_ula_lb = pxe_sampler.OnlineMoment(order=1)
       var_ula = pxe_sampler.OnlineVariance()
       var_ula_lb = pxe_sampler.OnlineVariance()

       n = int(1e4)  # Number of samples
       samples_ula = np.zeros(n)
       samples_ula_lb = np.zeros(n)
       for i in range(n):
           sample = next(gen_ula)
           sample_lb = next(gen_ula_lb)
           samples_ula[i] = sample
           samples_ula_lb[i] = sample_lb
           mean = float(mean_ula.update(sample))
           var = float(var_ula.update(sample))
           mean_lb = float(mean_ula_lb.update(sample_lb))
           var_lb = float(var_ula_lb.update(sample_lb))

       # Theoretical variances of biased stationary distributions of ULA
       biased_var = 1 / (1 - ula._gamma / 2)
       biased_var_lb = 1 / (1 - ula_lb._gamma / 2)

       # Quantify goodness of fit of empirical distribution with theoretical distribution (Cramér-von Mises test)
       cvm = sp.stats.cramervonmises(samples_ula, "norm", args=(0, np.sqrt(biased_var)))
       cvm_lb = sp.stats.cramervonmises(samples_ula_lb, "norm", args=(0, np.sqrt(biased_var_lb)))

       # Plots
       grid = np.linspace(-4, 4, 1000)

       plt.figure()
       plt.title(
           f"ULA samples (large step size) \n Empirical mean: {mean:.3f} (theoretical: 0) \n "
           f"Empirical variance: {var:.3f} (theoretical: {biased_var:.3f}) \n"
           f"Cramér-von Mises goodness of fit: {cvm.statistic:.3f}"
       )
       plt.hist(samples_ula, range=(min(grid), max(grid)), bins=100, density=True)
       plt.plot(grid, sp.stats.norm.pdf(grid), label=r"$p(x)$")
       plt.plot(grid, sp.stats.norm.pdf(grid, scale=np.sqrt(biased_var)), label=r"$p_{\gamma_1}(x)$")
       plt.legend()
       plt.show()

       plt.figure()
       plt.title(
           f"ULA samples (small step size) \n Empirical mean: {mean_lb:.3f} (theoretical: 0) \n "
           f"Empirical variance: {var_lb:.3f} (theoretical: {biased_var_lb:.3f}) \n"
           f"Cramér-von Mises goodness of fit: {cvm_lb.statistic:.3f}"
       )
       plt.hist(samples_ula_lb, range=(min(grid), max(grid)), bins=100, density=True)
       plt.plot(grid, sp.stats.norm.pdf(grid), label=r"$p(x)$")
       plt.plot(grid, sp.stats.norm.pdf(grid, scale=np.sqrt(biased_var_lb)), label=r"$p_{\gamma_2}(x)$")
       plt.legend()
       plt.show()
    """

    def __init__(self, f: pxa.DiffFunc, gamma: pxt.Real = None):
        r"""
        Parameters
        ----------
        f: :py:class:`~pyxu.abc.DiffFunc`
            Differentiable functional.
        gamma: Real
            Euler-Maruyama discretization step of the Langevin equation (see `Notes`).
        """
        self._f = f
        self._beta = f.diff_lipschitz
        self._gamma = self._set_gamma(gamma)
        self._rng = None
        self.x = None

    def _sample_init(self, rng, x0: pxt.NDArray):
        r"""
        Parameters
        ----------
        rng:
            Internal random generator.
        x0: NDArray
            Starting point of the Markov chain.
        """
        self.x = x0.copy()
        if rng is None:
            xp = pxu.get_array_module(x0)
            self._rng = xp.random.default_rng(None)
        else:
            self._rng = rng

    def _sample(self) -> pxt.NDArray:
        # Euler-Maruyama update: x_{k+1} = x_k - gamma * grad F(x_k) + sqrt(2 * gamma) * z_{k+1} (see class docstring).
        x = self.x.copy()
        x -= self._gamma * pxu.copy_if_unsafe(self._f.grad(self.x))
        x += math.sqrt(2 * self._gamma) * self._rng.standard_normal(size=self.x.shape, dtype=self.x.dtype)
        self.x = x
        return x

    def objective_func(self) -> pxt.Real:
        r"""
        Negative logarithm of the target distribution (up to an additive constant) evaluated at the current state of
        the Markov chain.

        Useful for diagnostic purposes to monitor whether the Markov chain is sufficiently warm-started.  If so, the
        samples should accumulate around the modes of the target distribution, i.e., toward the minimum of
        :math:`\mathcal{F}`.
        """
        return pxu.copy_if_unsafe(self._f.apply(self.x))

    def _set_gamma(self, gamma: pxt.Real = None) -> pxt.Real:
        if gamma is None:
            if math.isfinite(self._beta):
                return 0.98 / self._beta
            else:
                msg = "If f has unbounded Lipschitz gradient, the gamma parameter must be provided."
                raise ValueError(msg)
        else:
            try:
                assert gamma > 0
            except Exception:
                raise ValueError(f"gamma must be positive, got {gamma}.")
            return gamma


class MYULA(ULA):
    r"""
    Moreau-Yosida unadjusted Langevin algorithm (MYULA).

    Generates samples from the distribution

    .. math::

       p(\mathbf{x}) = \frac{\exp(-\mathcal{F}(\mathbf{x}) - \mathcal{G}(\mathbf{x}))}{\int_{\mathbb{R}^N}
       \exp(-\mathcal{F}(\tilde{\mathbf{x}}) - \mathcal{G}(\tilde{\mathbf{x}})) \mathrm{d} \tilde{\mathbf{x}} },

    where :math:`\mathcal{F}: \mathbb{R}^N \to \mathbb{R}` is *convex* and *differentiable* with :math:`\beta`-
    *Lipschitz continuous* gradient, and :math:`\mathcal{G}: \mathbb{R}^N \to \mathbb{R}` is *proper*, *lower
    semi-continuous* and *convex* with *simple proximal operator*.

    Notes
    -----
    MYULA is an extension of :py:class:`~pyxu.experimental.sampler.ULA` to sample from distributions whose logarithm is
    nonsmooth.  It consists in applying ULA to the differentiable functional :math:`\mathcal{U}^\lambda = \mathcal{F} +
    \mathcal{G}^\lambda` for some :math:`\lambda > 0`, where

    .. math::

       \mathcal{G}^\lambda (\mathbf{x}) = \inf_{\tilde{\mathbf{x}} \in \mathbb{R}^N} \frac{1}{2 \lambda} \Vert
       \tilde{\mathbf{x}} - \mathbf{x} \Vert_2^2 + \mathcal{G}(\tilde{\mathbf{x}})

    is the Moreau-Yosida envelope of :math:`\mathcal{G}` with parameter :math:`\lambda`.  We then have

    .. math::

       \nabla \mathcal{U}^\lambda (\mathbf{x}) = \nabla \mathcal{F}(\mathbf{x}) + \frac{1}{\lambda} (\mathbf{x} -
       \mathrm{prox}_{\lambda \mathcal{G}}(\mathbf{x})),

    hence :math:`\nabla \mathcal{U}^\lambda` is :math:`(\beta + \frac{1}{\lambda})`-Lipschitz continuous, where
    :math:`\beta` is the Lipschitz constant of :math:`\nabla \mathcal{F}`.  Note that the target distribution of the
    underlying ULA Markov chain is not exactly :math:`p(\mathbf{x})`, but the distribution

    .. math::

       p^\lambda(\mathbf{x}) \propto \exp(-\mathcal{F}(\mathbf{x})-\mathcal{G}^\lambda(\mathbf{x})),

    which introduces some additional bias on top of the bias of ULA related to the step size :math:`\gamma` (see
    `Notes` of :py:class:`~pyxu.experimental.sampler.ULA` documentation).  MYULA is guaranteed to converge when
    :math:`\gamma \leq \frac{1}{\beta + \frac{1}{\lambda}}`, in which case it converges toward the stationary
    distribution :math:`p^\lambda_\gamma(\mathbf{x})` that satisfies

    .. math::

       \lim_{\gamma, \lambda \to 0} \Vert p^\lambda_\gamma - p \Vert_{\mathrm{TV}} = 0

    (see [MYULA]_).  The parameter :math:`\lambda` is subject to a similar bias-variance tradeoff as :math:`\gamma`.
    It is recommended to set it on the order of :math:`\frac{1}{\beta}`, so that the contributions of
    :math:`\mathcal{F}` and :math:`\mathcal{G}^\lambda` to the Lipschitz constant of :math:`\nabla \mathcal{U}^\lambda`
    are well balanced.
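
    Example
    -------
    The following is a minimal usage sketch that samples a one-dimensional distribution with a smooth quadratic term
    :math:`\mathcal{F}` and a nonsmooth term :math:`\mathcal{G}`.  The choice of :math:`\mathcal{G}` as an
    :math:`\ell_1`-norm (via ``pxo.L1Norm``) and the omission of a burn-in phase are illustrative simplifications, not
    recommendations.

    .. code-block:: python3

       import numpy as np
       import pyxu.experimental.sampler as pxe_sampler
       import pyxu.operator as pxo

       f = pxo.SquaredL2Norm(dim_shape=1) / 2  # Smooth term (Gaussian factor)
       g = pxo.L1Norm(dim_shape=1)  # Nonsmooth term with simple proximal operator
       myula = pxe_sampler.MYULA(f=f, g=g)  # Default step size gamma and smoothing parameter lambda

       gen = myula.samples(x0=np.zeros(1), rng=np.random.default_rng(seed=0))
       online_mean = pxe_sampler.OnlineMoment(order=1)
       for _ in range(int(1e4)):
           mean = online_mean.update(next(gen))  # Online mean of the samples
    """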

    def __init__(
        self,
        f: pxa.DiffFunc = None,
        g: pxa.ProxFunc = None,
        gamma: pxt.Real = None,
        lamb: pxt.Real = None,
    ):
        r"""
        Parameters
        ----------
        f: :py:class:`~pyxu.abc.DiffFunc`, None
            Differentiable functional.
        g: :py:class:`~pyxu.abc.ProxFunc`, None
            Proximable functional.
        gamma: Real
            Euler-Maruyama discretization step of the Langevin equation (see `Notes` of
            :py:class:`~pyxu.experimental.sampler.ULA` documentation).
        lamb: Real
            Moreau-Yosida envelope parameter for `g`.
        """
        dim_shape = None
        if f is not None:
            dim_shape = f.dim_shape
        if g is not None:
            if dim_shape is None:
                dim_shape = g.dim_shape
            else:
                assert g.dim_shape == dim_shape
        if dim_shape is None:
            raise ValueError("One of f or g must be nonzero.")

        self._f_diff = pxo.NullFunc(dim_shape=dim_shape) if (f is None) else f
        self._g = pxo.NullFunc(dim_shape=dim_shape) if (g is None) else g

        self._lambda = self._set_lambda(lamb)
        f = self._f_diff + self._g.moreau_envelope(self._lambda)
        f.diff_lipschitz = f.estimate_diff_lipschitz()
        super().__init__(f, gamma)

    def _set_lambda(self, lamb: pxt.Real = None) -> pxt.Real:
        if lamb is None:
            if self._g._name == "NullFunc":
                return 1.0  # Lambda is irrelevant if g is a NullFunc, but it must be positive
            elif math.isfinite(dl := self._f_diff.diff_lipschitz):
                return 2.0 if dl == 0 else min(2.0, 1.0 / dl)
            else:
                msg = "If f has unbounded Lipschitz gradient, the lambda parameter must be provided."
                raise ValueError(msg)
        else:
            return lamb