import math

import pyxu.abc as pxa
import pyxu.info.ptype as pxt
import pyxu.info.warning as pxw
import pyxu.util as pxu

__all__ = [
    "Adam",
]

class Adam(pxa.Solver):
    r"""
    Adam solver [ProxAdam]_.

    Adam minimizes

    .. math::

       {\min_{\mathbf{x}\in\mathbb{R}^N} \;\mathcal{F}(\mathbf{x})},

    where:

    * :math:`\mathcal{F}:\mathbb{R}^N\rightarrow \mathbb{R}` is *convex* and *differentiable*, with
      :math:`\beta`-*Lipschitz continuous* gradient, for some :math:`\beta\in[0,+\infty[`.

    Adam is a suitable alternative to Proximal Gradient Descent (:py:class:`~pyxu.opt.solver.PGD`) when:

    * the cost function is differentiable,
    * computing :math:`\beta` to optimally choose the step size is infeasible,
    * line-search methods to estimate step sizes are too expensive.

    Compared to PGD, Adam auto-tunes its gradient updates based on stochastic estimates of the first and
    second gradient moments, :math:`\phi_{t} = \mathbb{E}[\nabla\mathcal{F}]` and
    :math:`\psi_{t} = \mathbb{E}[(\nabla\mathcal{F})^{2}]` respectively.

    Adam has many named variants for particular choices of :math:`\phi` and :math:`\psi`:

    * Adam:

      .. math::

         \phi_t = \frac{\mathbf{m}_t}{1 - \beta_1^t}
         \qquad
         \psi_t = \sqrt{\frac{\mathbf{v}_t}{1 - \beta_2^t}} + \epsilon,

    * AMSGrad:

      .. math::

         \phi_t = \mathbf{m}_t
         \qquad
         \psi_t = \sqrt{\hat{\mathbf{v}}_t},

    * PAdam:

      .. math::

         \phi_t = \mathbf{m}_t
         \qquad
         \psi_t = \hat{\mathbf{v}}_t^p,

    where in all cases:

    .. math::

       \mathbf{m}_t = \beta_1 \mathbf{m}_{t-1} + (1-\beta_1) \mathbf{g}_t, \\
       \mathbf{v}_t = \beta_2 \mathbf{v}_{t-1} + (1-\beta_2) \mathbf{g}_t^2, \\
       \hat{\mathbf{v}}_t = \max(\hat{\mathbf{v}}_{t-1}, \mathbf{v}_t),

    with :math:`\mathbf{g}_t = \nabla\mathcal{F}(\mathbf{x}_t)` the gradient at iteration :math:`t`, and
    :math:`\mathbf{m}_0 = \mathbf{v}_0 = \mathbf{0}`.
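
    For reference, the recursions above map onto the following NumPy-style sketch (an illustration of the
    formulas only, not the solver's internal implementation; ``g``, ``m``, ``v``, ``v_hat``, ``b1``, ``b2``,
    ``t``, ``p``, ``eps``, ``a`` and ``x`` mirror the symbols defined above):

    .. code-block:: python3

       m = b1 * m + (1 - b1) * g        # 1st-moment estimate
       v = b2 * v + (1 - b2) * g**2     # 2nd-moment estimate
       v_hat = np.maximum(v_hat, v)     # running max, used by AMSGrad/PAdam

       # Variant-specific scaling terms:
       phi, psi = m / (1 - b1**t), np.sqrt(v / (1 - b2**t)) + eps  # Adam
       # phi, psi = m, np.sqrt(v_hat)                              # AMSGrad
       # phi, psi = m, v_hat**p                                    # PAdam

       x = x - a * phi / psi            # scaled gradient step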

    Remarks
    -------
    * Convergence is guaranteed for step sizes :math:`\alpha \leq 2 / \beta`.

    * The default stopping criterion is the relative norm change of the primal variable: the algorithm stops
      when the relative change between two consecutive iterates :math:`\{\mathbf{x}_n\}_{n\in\mathbb{N}}`,
      measured in the :math:`\ell_2`-norm, drops below 1e-4. Different stopping criteria can be used; changing
      the criterion is recommended for the PAdam and AMSGrad variants to avoid premature stops (see the
      snippet below).
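
    For example, a tighter relative-error criterion can be supplied at fit time (a sketch; it assumes the
    generic ``stop_crit`` keyword of :py:meth:`~pyxu.abc.Solver.fit` and the
    :py:class:`~pyxu.opt.stop.RelError` criterion used by the default):

    .. code-block:: python3

       from pyxu.opt.stop import RelError

       slvr.fit(
           x0=x0,
           variant="amsgrad",
           stop_crit=RelError(eps=1e-6, var="x"),  # stricter than the 1e-4 default
       )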

    Parameters (``__init__()``)
    ---------------------------
    * **f** (:py:class:`~pyxu.abc.DiffFunc`)
      --
      Differentiable function :math:`\mathcal{F}`.
    * **\*\*kwargs** (:py:class:`~collections.abc.Mapping`)
      --
      Other keyword parameters passed on to :py:meth:`pyxu.abc.Solver.__init__`.

    Parameters (``fit()``)
    ----------------------
    * **x0** (:py:attr:`~pyxu.info.ptype.NDArray`)
      --
      (..., N) initial point(s).
    * **variant** ("adam", "amsgrad", "padam")
      --
      Name of the Adam variant to use. Defaults to "adam".
    * **a** (:py:attr:`~pyxu.info.ptype.Real`, :py:obj:`None`)
      --
      Max normalized gradient step size. Defaults to :math:`1 / \beta` if unspecified.
    * **b1** (:py:attr:`~pyxu.info.ptype.Real`)
      --
      1st-order gradient exponential decay :math:`\beta_{1} \in [0, 1)`.
    * **b2** (:py:attr:`~pyxu.info.ptype.Real`)
      --
      2nd-order gradient exponential decay :math:`\beta_{2} \in [0, 1)`.
    * **m0** (:py:attr:`~pyxu.info.ptype.NDArray`, :py:obj:`None`)
      --
      (..., N) initial 1st-order gradient estimate corresponding to each initial point. Defaults to the null
      vector if unspecified.
    * **v0** (:py:attr:`~pyxu.info.ptype.NDArray`, :py:obj:`None`)
      --
      (..., N) initial 2nd-order gradient estimate corresponding to each initial point. Defaults to the null
      vector if unspecified.
    * **p** (:py:attr:`~pyxu.info.ptype.Real`)
      --
      PAdam power parameter :math:`p \in (0, 0.5]`. Must be specified for PAdam; unused otherwise.
    * **eps_adam** (:py:attr:`~pyxu.info.ptype.Real`)
      --
      Adam noise parameter :math:`\epsilon`. Only used when ``variant="adam"``. Defaults to 1e-6.
    * **eps_var** (:py:attr:`~pyxu.info.ptype.Real`)
      --
      Avoids division by zero when the estimated gradient variance is too small. Defaults to 1e-6.
    * **\*\*kwargs** (:py:class:`~collections.abc.Mapping`)
      --
      Other keyword parameters passed on to :py:meth:`pyxu.abc.Solver.fit`.

    Note
    ----
    If provided, `m0` and `v0` must be broadcastable with `x0`.
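
    For instance, several initial points can share one warm-started moment estimate through broadcasting
    (an illustrative sketch; only the shapes matter here):

    .. code-block:: python3

       x0 = np.zeros((5, N))  # 5 initial points
       m0 = np.ones((N,))     # broadcasts against every initial point
       slvr.fit(x0=x0, m0=m0)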

    Example
    -------
    Consider the following optimization problem:

    .. math::

       \min_{\mathbf{x}\in\mathbb{R}^N} \Vert{\mathbf{x}-\mathbf{1}}\Vert_2^2

    .. code-block:: python3

       import numpy as np

       from pyxu.operator import SquaredL2Norm
       from pyxu.opt.solver import Adam

       N = 3
       f = SquaredL2Norm(dim=N).asloss(1)

       slvr = Adam(f)
       slvr.fit(
           x0=np.zeros((N,)),
           variant="padam",
           p=0.25,
       )
       x_opt = slvr.solution()
       np.allclose(x_opt, 1, rtol=1e-4)  # True
    """

    def __init__(
        self,
        f: pxa.DiffFunc,
        **kwargs,
    ):
        kwargs.update(
            log_var=kwargs.get("log_var", ("x",)),
        )
        super().__init__(**kwargs)

        self._f = f

    def m_init(  # default values from https://github.com/pmelchior/proxmin/blob/master/proxmin/algorithms.py
        self,
        x0: pxt.NDArray,
        variant: str = "adam",
        a: pxt.Real = None,
        b1: pxt.Real = 0.9,
        b2: pxt.Real = 0.999,
        m0: pxt.NDArray = None,
        v0: pxt.NDArray = None,
        p: pxt.Real = 0.5,
        eps_adam: pxt.Real = 1e-6,
        eps_var: pxt.Real = 1e-6,
    ):
        mst = self._mstate  # shorthand
        xp = pxu.get_array_module(x0)
        mst["x"] = x0

        mst["variant"] = self.__parse_variant(variant)

        if a is None:
            # Auto-infer the step size from the diff-Lipschitz constant, if finite and non-zero.
            g = lambda _: math.isclose(self._f.diff_lipschitz, _)
            if g(0) or g(math.inf):
                error_msg = "Cannot auto-infer step size: choose it manually."
                raise pxw.AutoInferenceWarning(error_msg)
            else:
                mst["a"] = 1.0 / self._f.diff_lipschitz
        else:
            assert a > 0, f"Parameter[a] must be positive, got {a}."
            mst["a"] = a

        assert 0 <= b1 < 1, f"Parameter[b1]: expected value in [0, 1), got {b1}."
        mst["b1"] = b1

        assert 0 <= b2 < 1, f"Parameter[b2]: expected value in [0, 1), got {b2}."
        mst["b2"] = b2

        if m0 is None:
            mst["mean"] = xp.zeros_like(x0)
        elif m0.shape == x0.shape:
            # No broadcasting involved
            mst["mean"] = m0
        else:
            x0, m0 = xp.broadcast_arrays(x0, m0)
            mst["mean"] = m0.copy()

        if v0 is None:
            mst["variance"] = xp.zeros_like(x0)
        elif v0.shape == x0.shape:
            # No broadcasting involved
            mst["variance"] = v0
        else:
            x0, v0 = xp.broadcast_arrays(x0, v0)
            mst["variance"] = v0.copy()
        mst["variance_hat"] = mst["variance"]  # running max, used by AMSGrad/PAdam

        assert 0 < p <= 0.5, f"Parameter[p]: expected value in (0, 0.5], got {p}."
        mst["padam_p"] = p

        assert eps_adam > 0, f"Parameter[eps_adam]: expected positive value, got {eps_adam}."
        mst["eps_adam"] = eps_adam

        assert eps_var > 0, f"Parameter[eps_var]: expected positive value, got {eps_var}."
        mst["eps_variance"] = eps_var

    def m_step(self):
        mst = self._mstate  # shorthand

        x, a = mst["x"], mst["a"]
        xp = pxu.get_array_module(x)
        gm = pxu.copy_if_unsafe(self._f.grad(x))
        gv = gm.copy()

        ## Part 1: evaluate phi/psi ============================
        m, b1 = mst["mean"], mst["b1"]
        # In-place implementation of -----------------
        #   m = b1 * m + (1 - b1) * g
        m *= b1
        gm *= 1 - b1
        m += gm
        # --------------------------------------------
        mst["mean"] = m

        v, b2 = mst["variance"], mst["b2"]
        # In-place implementation of -----------------
        #   v = b2 * v + (1 - b2) * (g ** 2)
        v *= b2
        gv **= 2
        gv *= 1 - b2
        v += gv
        # --------------------------------------------
        mst["variance"] = v.clip(mst["eps_variance"], None)  # avoid division-by-zero
        mst["variance_hat"] = xp.maximum(mst["variance_hat"], mst["variance"])

        phi = self.__phi(t=self._astate["idx"])
        psi = self.__psi(t=self._astate["idx"])
        ## =====================================================

        ## Part 2: take a step in the gradient's direction =====
        # In-place implementation of -----------------
        #   x = x - a * (phi / psi)
        phi /= psi
        phi *= a
        mst["x"] = x - phi

    def default_stop_crit(self) -> pxa.StoppingCriterion:
        from pyxu.opt.stop import RelError

        # Described in [ProxAdam]_ and used in their implementation:
        # https://github.com/pmelchior/proxmin/blob/master/proxmin/algorithms.py
        rel_error = RelError(
            eps=1e-4,
            var="x",
            rank=self._f.dim_rank,
            f=None,
            norm=2,
            satisfy_all=True,
        )
        return rel_error

    def objective_func(self) -> pxt.NDArray:
        return self._f.apply(self._mstate["x"])

    def solution(self) -> pxt.NDArray:
        """
        Returns
        -------
        x: NDArray
            (..., N) solution.
        """
        data, _ = self.stats()
        return data.get("x")

    def __phi(self, t: int):
        mst = self._mstate
        var = mst["variant"]
        if var == "adam":
            out = mst["mean"].copy()
            out /= 1 - (mst["b1"] ** t)  # bias correction
        elif var in ["amsgrad", "padam"]:
            out = mst["mean"].copy()  # copy to allow in-place updates outside __phi()
        return out

    def __psi(self, t: int):
        mst = self._mstate
        xp = pxu.get_array_module(mst["x"])
        var = mst["variant"]
        if var == "adam":
            out = xp.sqrt(mst["variance"])
            out /= xp.sqrt(1 - mst["b2"] ** t)
            out += mst["eps_adam"]
        elif var == "amsgrad":
            out = xp.sqrt(mst["variance_hat"])
        elif var == "padam":
            out = mst["variance_hat"] ** mst["padam_p"]
        return out

    def __parse_variant(self, variant: str) -> str:
        supported_variants = {"adam", "amsgrad", "padam"}
        if (v := variant.lower().strip()) not in supported_variants:
            raise ValueError(f"Unsupported variant '{variant}'.")
        return v