import math

import pyxu.abc as pxa
import pyxu.info.ptype as pxt
import pyxu.info.warning as pxw
import pyxu.util as pxu
  8__all__ = [
  9    "Adam",
class Adam(pxa.Solver):
    r"""
    Adam solver [ProxAdam]_.

    Adam minimizes

    .. math::

       {\min_{\mathbf{x}\in\mathbb{R}^N} \;\mathcal{F}(\mathbf{x})},

    where:

    * :math:`\mathcal{F}:\mathbb{R}^N\rightarrow \mathbb{R}` is *convex* and *differentiable*, with
      :math:`\beta`-*Lipschitz continuous* gradient, for some :math:`\beta\in[0,+\infty[`.

    Adam is a suitable alternative to Proximal Gradient Descent (:py:class:`~pyxu.opt.solver.PGD`) when:

    * the cost function is differentiable,
    * computing :math:`\beta` to optimally choose the step size is infeasible,
    * line-search methods to estimate step sizes are too expensive.

    Compared to PGD, Adam auto-tunes gradient updates based on stochastic estimates of :math:`\phi_{t} =
    \mathbb{E}[\nabla\mathcal{F}]` and :math:`\psi_{t} = \mathbb{E}[\nabla\mathcal{F}^{2}]` respectively. Each
    iteration performs the update

    .. math::

       \mathbf{x}_t = \mathbf{x}_{t-1} - \alpha \frac{\phi_t}{\psi_t},

    where :math:`t = 1, 2, \ldots` is the iteration counter and the division is element-wise.

    Adam has many named variants for particular choices of :math:`\phi` and :math:`\psi`:

    * Adam:

      .. math::

         \phi_t
         =
         \frac{
             \mathbf{m}_t
         }{
             1-\beta_1^t
         }
         \qquad
         \psi_t
         =
         \sqrt{
             \frac{
                 \mathbf{v}_t
             }{
                 1-\beta_2^t
             }
         } + \epsilon,

    * AMSGrad:

      .. math::

         \phi_t = \mathbf{m}_t
         \qquad
         \psi_t = \sqrt{\hat{\mathbf{v}}_t},

    * PAdam:

      .. math::

         \phi_t = \mathbf{m}_t
         \qquad
         \psi_t = \hat{\mathbf{v}}_t^p,

    where in all cases:

    .. math::

       \mathbf{m}_t
       =
       \beta_1\mathbf{m}_{t-1}
       +
       (1-\beta_1)\mathbf{g}_t \\
       \mathbf{v}_t
       =
       \beta_2\mathbf{v}_{t-1}
       +
       (1-\beta_2)\mathbf{g}_t^2\\
       \hat{\mathbf{v}}_t
       =
       \max(\hat{\mathbf{v}}_{t-1}, \mathbf{v}_t),

    with :math:`\mathbf{g}_t = \nabla\mathcal{F}(\mathbf{x}_{t-1})`, :math:`\mathbf{m}_0 = \mathbf{v}_0 = \mathbf{0}`
    (unless `m0` / `v0` are provided) and :math:`\hat{\mathbf{v}}_0 = \mathbf{v}_0`.

    Remarks
    -------
    * The convergence is guaranteed for step sizes :math:`\alpha\leq 2/\beta`.

    * After every update, :math:`\mathbf{v}_t` is clipped from below by `eps_var` to avoid divisions by zero.

    * The default stopping criterion is the relative norm change of the primal variable. By default, the algorithm
      stops when the relative norm change between two consecutive iterates
      :math:`\{\mathbf{x}_n\}_{n\in\mathbb{N}}` is smaller than 1e-4. Different stopping criteria can be used. It is
      recommended to change the stopping criterion when using the PAdam and AMSGrad variants to avoid premature
      stops.

    Parameters (``__init__()``)
    ---------------------------
    * **f** (:py:class:`~pyxu.abc.DiffFunc`)
      --
      Differentiable function :math:`\mathcal{F}`.
    * **\*\*kwargs** (:py:class:`~collections.abc.Mapping`)
      --
      Other keyword parameters passed on to :py:meth:`pyxu.abc.Solver.__init__`.

    Parameters (``fit()``)
    ----------------------
    * **x0** (:py:attr:`~pyxu.info.ptype.NDArray`)
      --
      (..., N) initial point(s).
    * **variant** ("adam", "amsgrad", "padam")
      --
      Name of the Adam variant to use. Defaults to "adam".
    * **a** (:py:attr:`~pyxu.info.ptype.Real`, :py:obj:`None`)
      --
      Max normalized gradient step size. Defaults to :math:`1 / \beta` if unspecified.
    * **b1** (:py:attr:`~pyxu.info.ptype.Real`)
      --
      1st-order gradient exponential decay :math:`\beta_{1} \in [0, 1)`. Defaults to 0.9.
    * **b2** (:py:attr:`~pyxu.info.ptype.Real`)
      --
      2nd-order gradient exponential decay :math:`\beta_{2} \in [0, 1)`. Defaults to 0.999.
    * **m0** (:py:attr:`~pyxu.info.ptype.NDArray`, :py:obj:`None`)
      --
      (..., N) initial 1st-order gradient estimate corresponding to each initial point. Defaults to the null vector if
      unspecified.
    * **v0** (:py:attr:`~pyxu.info.ptype.NDArray`, :py:obj:`None`)
      --
      (..., N) initial 2nd-order gradient estimate corresponding to each initial point. Defaults to the null vector if
      unspecified.
    * **p** (:py:attr:`~pyxu.info.ptype.Real`)
      --
      PAdam power parameter :math:`p \in (0, 0.5]`. Only used by PAdam. Defaults to 0.5.
    * **eps_adam** (:py:attr:`~pyxu.info.ptype.Real`)
      --
      Adam noise parameter :math:`\epsilon`. This term is used exclusively if `variant="adam"`. Defaults to 1e-6.
    * **eps_var** (:py:attr:`~pyxu.info.ptype.Real`)
      --
      Avoids division by zero if estimated gradient variance is too small. Defaults to 1e-6.
    * **\*\*kwargs** (:py:class:`~collections.abc.Mapping`)
      --
      Other keyword parameters passed on to :py:meth:`pyxu.abc.Solver.fit`.

    Note
    ----
    If provided, `m0` and `v0` must be broadcastable with `x0`. They are copied internally, hence the caller's arrays
    are never modified.

    Example
    -------
    Consider the following optimization problem:

    .. math::

       \min_{\mathbf{x}\in\mathbb{R}^N} \Vert{\mathbf{x}-\mathbf{1}}\Vert_2^2

    .. code-block:: python3

       import numpy as np

       from pyxu.operator import SquaredL2Norm
       from pyxu.opt.solver import Adam

       N = 3
       f = SquaredL2Norm(dim=N).asloss(1)

       slvr = Adam(f)
       slvr.fit(
           x0=np.zeros((N,)),
           variant="padam",
           p=0.25,
       )
       x_opt = slvr.solution()
       np.allclose(x_opt, 1, rtol=1e-4)  # True
    """

    def __init__(
        self,
        f: pxa.DiffFunc,
        **kwargs,
    ):
        # Log the iterates unless the caller chose otherwise: solution() reads "x" back from stats().
        kwargs.setdefault("log_var", ("x",))
        super().__init__(**kwargs)

        self._f = f

    def m_init(
        # NOTE(review): defaults presumably follow the reference implementation of [ProxAdam]_
        # (the original source link was lost) -- confirm.
        self,
        x0: pxt.NDArray,
        variant: str = "adam",
        a: pxt.Real = None,
        b1: pxt.Real = 0.9,
        b2: pxt.Real = 0.999,
        m0: pxt.NDArray = None,
        v0: pxt.NDArray = None,
        p: pxt.Real = 0.5,
        eps_adam: pxt.Real = 1e-6,
        eps_var: pxt.Real = 1e-6,
    ):
        """
        Initialize the solver state from ``fit()`` parameters (see the class docstring for their meaning).

        Raises
        ------
        ValueError
            If `variant` is not one of "adam", "amsgrad", "padam".
        AutoInferenceWarning
            If `a` is unspecified and the gradient's Lipschitz constant is 0 or infinite.
        AssertionError
            If a numerical parameter lies outside its admissible range.
        """
        mst = self._mstate  # shorthand
        xp = pxu.get_array_module(x0)
        mst["x"] = x0

        mst["variant"] = self.__parse_variant(variant)

        if a is None:
            beta = self._f.diff_lipschitz
            if math.isclose(beta, 0) or math.isclose(beta, math.inf):
                error_msg = "Cannot auto-infer step size: choose it manually."
                raise pxw.AutoInferenceWarning(error_msg)
            mst["a"] = 1.0 / beta
        else:
            assert a > 0, f"Parameter[a] must be positive, got {a}."
            mst["a"] = a

        assert 0 <= b1 < 1, f"Parameter[b1]: expected value in [0, 1), got {b1}."
        mst["b1"] = b1

        assert 0 <= b2 < 1, f"Parameter[b2]: expected value in [0, 1), got {b2}."
        mst["b2"] = b2

        # m_step() updates the moment estimates in-place, hence:
        # * user-provided m0/v0 are copied so that the caller's arrays are never mutated;
        # * mean/variance both take the joint broadcast shape of (x0, m0, v0), so that phi and psi
        #   always have matching shapes (in-place `phi /= psi` would fail otherwise).
        x0_b, m0_b, v0_b = xp.broadcast_arrays(
            x0,
            x0 if m0 is None else m0,
            x0 if v0 is None else v0,
        )
        mst["mean"] = xp.zeros_like(x0_b) if m0 is None else m0_b.copy()
        mst["variance"] = xp.zeros_like(x0_b) if v0 is None else v0_b.copy()
        # Distinct buffer: `variance` is modified in-place by m_step(), which must not clobber v_hat_{t-1}.
        mst["variance_hat"] = mst["variance"].copy()

        assert 0 < p <= 0.5, f"Parameter[p]: expected value in (0, 0.5], got {p}."
        mst["padam_p"] = p

        assert eps_adam > 0, f"Parameter[eps_adam]: expected positive value, got {eps_adam}."
        mst["eps_adam"] = eps_adam

        assert eps_var > 0, f"Parameter[eps_var]: expected positive value, got {eps_var}."
        mst["eps_variance"] = eps_var

        # Iteration counter t used in the bias corrections; incremented at the start of each m_step().
        mst["t"] = 0

    def m_step(self):
        """Perform one Adam update: refresh the moment estimates, then step along -a * phi / psi."""
        mst = self._mstate  # shorthand

        # Own counter (t = 1 on the first update) so that 1 - b**t never vanishes, independently of how
        # the base class numbers its iterations.
        mst["t"] += 1
        t = mst["t"]

        x, a = mst["x"], mst["a"]
        xp = pxu.get_array_module(x)
        gm = pxu.copy_if_unsafe(self._f.grad(x))  # consumed in-place below
        gv = gm.copy()

        ## Part 1: evaluate phi/psi ============================
        m, b1 = mst["mean"], mst["b1"]
        # In-place implementation of -----------------
        #   m = b1 * m + (1 - b1) * g
        m *= b1
        gm *= 1 - b1
        m += gm
        # --------------------------------------------
        mst["mean"] = m

        v, b2 = mst["variance"], mst["b2"]
        # In-place implementation of -----------------
        #   v = b2 * v + (1 - b2) * (g ** 2)
        v *= b2
        gv **= 2
        gv *= 1 - b2
        v += gv
        # --------------------------------------------
        mst["variance"] = v.clip(mst["eps_variance"], None)  # avoid division-by-zero
        mst["variance_hat"] = xp.maximum(mst["variance_hat"], mst["variance"])

        phi = self.__phi(t=t)
        psi = self.__psi(t=t)
        ## =====================================================

        ## Part 2: take a step in the gradient's direction =====
        # In-place implementation of -----------------
        #   x = x - a * (phi / psi)
        phi /= psi
        phi *= a
        mst["x"] = x - phi

    def default_stop_crit(self) -> pxa.StoppingCriterion:
        """Relative-change criterion on the iterates (tolerance 1e-4), as described in [ProxAdam]_."""
        from pyxu.opt.stop import RelError

        rel_error = RelError(
            eps=1e-4,
            var="x",
            rank=self._f.dim_rank,
            f=None,
            norm=2,
            satisfy_all=True,
        )
        return rel_error

    def objective_func(self) -> pxt.NDArray:
        """Evaluate the cost function :math:`\\mathcal{F}` at the current iterate."""
        return self._f.apply(self._mstate["x"])

    def solution(self) -> pxt.NDArray:
        """
        Returns
        -------
        x: NDArray
            (..., N) solution.
        """
        data, _ = self.stats()
        return data.get("x")

    def __phi(self, t: int):
        """Return phi_t as a fresh array, so that callers may update it in-place."""
        mst = self._mstate
        var = mst["variant"]
        if var == "adam":
            out = mst["mean"].copy()
            out /= 1 - (mst["b1"] ** t)  # bias correction; t >= 1 so the divisor is > 0
        elif var in ["amsgrad", "padam"]:
            out = mst["mean"].copy()  # to allow in-place updates outside __phi()
        return out

    def __psi(self, t: int):
        """Return psi_t as a fresh array (variant-dependent 2nd-order normalization)."""
        mst = self._mstate
        xp = pxu.get_array_module(mst["x"])
        var = mst["variant"]
        if var == "adam":
            out = xp.sqrt(mst["variance"])
            out /= math.sqrt(1 - mst["b2"] ** t)  # bias correction; t >= 1 so the divisor is > 0
            out += mst["eps_adam"]
        elif var == "amsgrad":
            out = xp.sqrt(mst["variance_hat"])
        elif var == "padam":
            out = mst["variance_hat"] ** mst["padam_p"]
        return out

    def __parse_variant(self, variant: str) -> str:
        """Normalize `variant` (case/whitespace-insensitive) and check that it is supported."""
        supported_variants = {"adam", "amsgrad", "padam"}
        if (v := variant.lower().strip()) not in supported_variants:
            raise ValueError(f"Unsupported variant '{variant}'.")
        return v