Source code for pyxu.opt.solver.pgd

import itertools
import math
import warnings

import numpy as np

import pyxu.abc as pxa
import pyxu.info.ptype as pxt
import pyxu.info.warning as pxw
import pyxu.operator as pxo
import pyxu.util as pxu

__all__ = [
    "PGD",
]

class PGD(pxa.Solver):
    r"""
    Proximal Gradient Descent (PGD) solver.

    PGD solves minimization problems of the form

    .. math::

       {\min_{\mathbf{x}\in\mathbb{R}^{M_{1} \times\cdots\times M_{D}}} \;
       \mathcal{F}(\mathbf{x})\;\;+\;\;\mathcal{G}(\mathbf{x})},

    where:

    * :math:`\mathcal{F}:\mathbb{R}^{M_{1} \times\cdots\times M_{D}}\rightarrow \mathbb{R}` is *convex* and
      *differentiable*, with :math:`\beta`-*Lipschitz continuous* gradient, for some :math:`\beta\in[0,+\infty[`.
    * :math:`\mathcal{G}:\mathbb{R}^{M_{1} \times\cdots\times M_{D}}\rightarrow \mathbb{R}\cup\{+\infty\}` is a
      *proper*, *lower semicontinuous* and *convex function* with a *simple proximal operator*.

    Remarks
    -------
    * The problem is *feasible* -- i.e. there exists at least one solution.

    * The algorithm remains valid if either :math:`\mathcal{F}` or :math:`\mathcal{G}` is zero.

    * Convergence is guaranteed for step sizes :math:`\tau\leq 1/\beta`.

    * Various acceleration schemes are described in [APGD]_. PGD achieves the following (optimal) *convergence rate*
      with the implemented acceleration scheme from Chambolle & Dossal:

      .. math::

         \lim\limits_{n\rightarrow \infty} n^2\left\vert \mathcal{J}(\mathbf{x}^\star)- \mathcal{J}(\mathbf{x}_n)\right\vert=0
         \qquad\&\qquad
         \lim\limits_{n\rightarrow \infty} n^2\Vert \mathbf{x}_n-\mathbf{x}_{n-1}\Vert^2_\mathcal{X}=0,

      for *some minimiser* :math:`{\mathbf{x}^\star}\in\arg\min_{\mathbf{x}\in\mathbb{R}^{M_{1} \times\cdots\times
      M_{D}}} \;\left\{\mathcal{J}(\mathbf{x}):=\mathcal{F}(\mathbf{x})+\mathcal{G}(\mathbf{x})\right\}`. In other
      words, both the objective functional and the PGD iterates :math:`\{\mathbf{x}_n\}_{n\in\mathbb{N}}` converge at
      a rate :math:`o(1/n^2)`. Significant practical *speedup* can be achieved for values of :math:`d` in the range
      :math:`[50,100]` [APGD]_.

    * The relative norm change of the primal variable is used as the default stopping criterion: the algorithm stops
      when the relative norm of the difference between two consecutive PGD iterates
      :math:`\{\mathbf{x}_n\}_{n\in\mathbb{N}}` drops below 1e-4. Different stopping criteria can be used.

    Parameters (``__init__()``)
    ---------------------------
    * **f** (:py:class:`~pyxu.abc.DiffFunc`, :py:obj:`None`)
      --
      Differentiable function :math:`\mathcal{F}`.
    * **g** (:py:class:`~pyxu.abc.ProxFunc`, :py:obj:`None`)
      --
      Proximable function :math:`\mathcal{G}`.
    * **\*\*kwargs** (:py:class:`~collections.abc.Mapping`)
      --
      Other keyword parameters passed on to :py:meth:`pyxu.abc.Solver.__init__`.

    Parameters (``fit()``)
    ----------------------
    * **x0** (:py:attr:`~pyxu.info.ptype.NDArray`)
      --
      (..., M1,...,MD) initial point(s).
    * **tau** (:py:attr:`~pyxu.info.ptype.Real`, :py:obj:`None`)
      --
      Gradient step size. Defaults to :math:`1 / \beta` if unspecified.
    * **acceleration** (:py:obj:`bool`)
      --
      If ``True`` (default), use the Chambolle & Dossal acceleration scheme.
    * **d** (:py:attr:`~pyxu.info.ptype.Real`)
      --
      Chambolle & Dossal acceleration parameter :math:`d`. Should be greater than 2. Only meaningful if
      `acceleration` is ``True``. Defaults to 75 if unspecified.
    * **\*\*kwargs** (:py:class:`~collections.abc.Mapping`)
      --
      Other keyword parameters passed on to :py:meth:`pyxu.abc.Solver.fit`.
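
    Example
    -------
    A minimal usage sketch for a LASSO-type denoising problem. This example is illustrative only: it assumes
    :py:class:`~pyxu.operator.SquaredL2Norm` and :py:class:`~pyxu.operator.L1Norm` accept a ``dim_shape``
    argument and that functionals support ``.argshift()`` and scalar scaling, as in recent Pyxu releases;
    adapt the names to your installed version.

    .. code-block:: python

       import numpy as np

       import pyxu.operator as pxo
       from pyxu.opt.solver import PGD

       y = np.random.default_rng(0).normal(size=100)      # noisy observation
       f = pxo.SquaredL2Norm(dim_shape=100).argshift(-y)  # F(x) = ||x - y||_2^2
       g = 0.1 * pxo.L1Norm(dim_shape=100)                # G(x) = 0.1 * ||x||_1

       solver = PGD(f=f, g=g)
       solver.fit(x0=np.zeros(100))  # tau auto-set to 1/beta from f.diff_lipschitz
       x_hat = solver.solution()     # (100,) soft-thresholded estimate of y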
93 """ 94 95 def __init__( 96 self, 97 f: pxa.DiffFunc = None, 98 g: pxa.ProxFunc = None, 99 **kwargs, 100 ): 101 kwargs.update( 102 log_var=kwargs.get("log_var", ("x",)), 103 ) 104 super().__init__(**kwargs) 105 106 if (f is None) and (g is None): 107 msg = " ".join( 108 [ 109 "Cannot minimize always-0 functional.", 110 "At least one of Parameter[f, g] must be specified.", 111 ] 112 ) 113 raise NotImplementedError(msg) 114 elif f is None: 115 self._f = pxo.NullFunc(dim_shape=g.dim_shape) 116 self._g = g 117 elif g is None: 118 self._f = f 119 self._g = pxo.NullFunc(dim_shape=f.dim_shape) 120 else: 121 self._f = f 122 self._g = g 123 124 def m_init( 125 self, 126 x0: pxt.NDArray, 127 tau: pxt.Real = None, 128 acceleration: bool = True, 129 d: pxt.Real = 75, 130 ): 131 mst = self._mstate # shorthand 132 mst["x"] = mst["x_prev"] = x0 133 134 if tau is None: 135 mst["tau"] = 1 / np.array(self._f.diff_lipschitz) 136 if math.isclose(mst["tau"], 0): 137 # _f does not provide any "useful" diff_lipschitz constant. 138 msg = "\n".join( 139 [ 140 "No useful step size could be auto-determined from Parameter[f].", 141 "Consider initializing Parameter[tau] directly, or set (an estimate of) the diff-Lipschitz constant of Parameter[f] before calling fit().", 142 "Solver iterations as-is may stagnate.", 143 ] 144 ) 145 warnings.warn(msg, pxw.AutoInferenceWarning) 146 if math.isinf(mst["tau"]): 147 # _f is constant-valued: \tau is a free parameter. 148 mst["tau"] = 1 149 msg = "\n".join( 150 [ 151 rf"The gradient/proximal step size \tau is auto-set to {mst['tau']}.", 152 r"Choosing \tau manually may lead to faster convergence.", 153 ] 154 ) 155 warnings.warn(msg, pxw.AutoInferenceWarning) 156 else: 157 try: 158 assert tau > 0 159 mst["tau"] = tau 160 except Exception: 161 raise ValueError(f"tau must be positive, got {tau}.") 162 163 if acceleration: 164 try: 165 assert d > 2 166 mst["a"] = (k / (k + 1 + d) for k in itertools.count(start=0)) 167 except Exception: 168 raise ValueError(f"Expected d > 2, got {d}.") 169 else: 170 mst["a"] = itertools.repeat(0.0) 171 172 def m_step(self): 173 mst = self._mstate # shorthand 174 a = next(mst["a"]) 175 176 # In-place implementation of ----------------- 177 # y = (1 + a) * mst["x"] - a * mst["x_prev"] 178 y = mst["x"] - mst["x_prev"] 179 y *= a 180 y += mst["x"] 181 # -------------------------------------------- 182 183 # In-place implementation of ----------------- 184 # z = y - mst["tau"] * self._f.grad(y) 185 z = pxu.copy_if_unsafe(self._f.grad(y)) 186 z *= -mst["tau"] 187 z += y 188 # -------------------------------------------- 189 190 mst["x_prev"], mst["x"] = mst["x"], self._g.prox(z, mst["tau"]) 191 192 def default_stop_crit(self) -> pxa.StoppingCriterion: 193 from pyxu.opt.stop import RelError 194 195 stop_crit = RelError( 196 eps=1e-4, 197 var="x", 198 rank=self._f.dim_rank, 199 f=None, 200 norm=2, 201 satisfy_all=True, 202 ) 203 return stop_crit 204 205 def objective_func(self) -> pxt.NDArray: 206 func = lambda x: self._f.apply(x) + self._g.apply(x) 207 208 y = func(self._mstate["x"]) 209 return y 210 211 def solution(self) -> pxt.NDArray: 212 """ 213 Returns 214 ------- 215 x: NDArray 216 (..., M1,...,MD) solution. 217 """ 218 data, _ = self.stats() 219 return data.get("x")