import itertools
import math
import warnings

import numpy as np

import pyxu.abc as pxa
import pyxu.info.ptype as pxt
import pyxu.info.warning as pxw
import pyxu.operator as pxo
import pyxu.util as pxu

__all__ = [
    "PGD",
]


class PGD(pxa.Solver):
    r"""
    Proximal Gradient Descent (PGD) solver.

    PGD solves minimization problems of the form

    .. math::

       {\min_{\mathbf{x}\in\mathbb{R}^{M_{1} \times\cdots\times M_{D}}} \;
       \mathcal{F}(\mathbf{x})\;\;+\;\;\mathcal{G}(\mathbf{x})},

    where:

    * :math:`\mathcal{F}:\mathbb{R}^{M_{1} \times\cdots\times M_{D}}\rightarrow \mathbb{R}` is *convex* and
      *differentiable*, with :math:`\beta`-*Lipschitz continuous* gradient, for some :math:`\beta\in[0,+\infty[`.
    * :math:`\mathcal{G}:\mathbb{R}^{M_{1} \times\cdots\times M_{D}}\rightarrow \mathbb{R}\cup\{+\infty\}` is a
      *proper*, *lower semicontinuous* and *convex function* with a *simple proximal operator*.

    Remarks
    -------
    * The problem is *feasible* -- i.e. there exists at least one solution.

    * The algorithm is still valid if either :math:`\mathcal{F}` or :math:`\mathcal{G}` is zero.

    * The convergence is guaranteed for step sizes :math:`\tau\leq 1/\beta`.

    * Various acceleration schemes are described in [APGD]_. PGD achieves the following (optimal) *convergence rate*
      with the implemented acceleration scheme from Chambolle & Dossal:

      .. math::

         \lim\limits_{n\rightarrow \infty} n^2\left\vert \mathcal{J}(\mathbf{x}^\star)- \mathcal{J}(\mathbf{x}_n)\right\vert=0
         \qquad\&\qquad
         \lim\limits_{n\rightarrow \infty} n^2\Vert \mathbf{x}_n-\mathbf{x}_{n-1}\Vert^2_\mathcal{X}=0,

      for *some minimiser* :math:`{\mathbf{x}^\star}\in\arg\min_{\mathbf{x}\in\mathbb{R}^{M_{1} \times\cdots\times
      M_{D}}} \;\left\{\mathcal{J}(\mathbf{x}):=\mathcal{F}(\mathbf{x})+\mathcal{G}(\mathbf{x})\right\}`. In other
      words, both the objective functional and the PGD iterates :math:`\{\mathbf{x}_n\}_{n\in\mathbb{N}}` converge at a
      rate :math:`o(1/n^2)`. Significant practical *speedup* can be achieved for values of :math:`d` in the range
      :math:`[50,100]` [APGD]_.

    * The relative norm change of the primal variable is used as the default stopping criterion. By default, the
      algorithm stops when the norm of the difference between two consecutive PGD iterates
      :math:`\{\mathbf{x}_n\}_{n\in\mathbb{N}}` is smaller than 1e-4. Different stopping criteria can be used.
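      For instance, a different relative-error tolerance can be passed to ``fit()`` through its ``stop_crit``
      keyword. A minimal sketch (``solver`` and ``x0`` are placeholder names):

      .. code-block:: python

         from pyxu.opt.stop import RelError

         solver.fit(x0=x0, stop_crit=RelError(eps=1e-6, var="x"))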

    Parameters (``__init__()``)
    ---------------------------
    * **f** (:py:class:`~pyxu.abc.DiffFunc`, :py:obj:`None`)
      --
      Differentiable function :math:`\mathcal{F}`.
    * **g** (:py:class:`~pyxu.abc.ProxFunc`, :py:obj:`None`)
      --
      Proximable function :math:`\mathcal{G}`.
    * **\*\*kwargs** (:py:class:`~collections.abc.Mapping`)
      --
      Other keyword parameters passed on to :py:meth:`pyxu.abc.Solver.__init__`.

    Parameters (``fit()``)
    ----------------------
    * **x0** (:py:attr:`~pyxu.info.ptype.NDArray`)
      --
      (..., M1,...,MD) initial point(s).
    * **tau** (:py:attr:`~pyxu.info.ptype.Real`, :py:obj:`None`)
      --
      Gradient step size. Defaults to :math:`1 / \beta` if unspecified.
    * **acceleration** (:py:obj:`bool`)
      --
      If True (default), use the Chambolle & Dossal acceleration scheme.
    * **d** (:py:attr:`~pyxu.info.ptype.Real`)
      --
      Chambolle & Dossal acceleration parameter :math:`d`. Should be greater than 2. Only meaningful if
      `acceleration` is True. Defaults to 75 if unspecified.
    * **\*\*kwargs** (:py:class:`~collections.abc.Mapping`)
      --
      Other keyword parameters passed on to :py:meth:`pyxu.abc.Solver.fit`.
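
    Example
    -------
    A minimal usage sketch on a LASSO-type problem. The operator names below
    (:py:class:`~pyxu.operator.SquaredL2Norm`, :py:class:`~pyxu.operator.L1Norm`) and the ``argshift`` and scaling
    arithmetic are assumed to be available from :py:mod:`pyxu.operator`; adapt them to your own functionals.

    .. code-block:: python

       import numpy as np

       import pyxu.operator as pxo
       from pyxu.opt.solver import PGD

       N = 100
       rng = np.random.default_rng(seed=0)
       y = rng.standard_normal(N)

       # Smooth term     F(x) = ||x - y||_2^2   (differentiable, Lipschitz-continuous gradient)
       f = pxo.SquaredL2Norm(dim_shape=N).argshift(-y)
       # Non-smooth term G(x) = 0.1 * ||x||_1   (proximable via soft-thresholding)
       g = 0.1 * pxo.L1Norm(dim_shape=N)

       solver = PGD(f=f, g=g)
       solver.fit(x0=np.zeros(N))  # tau defaults to 1 / diff_lipschitz of f
       x_hat = solver.solution()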
93 """

    def __init__(
        self,
        f: pxa.DiffFunc = None,
        g: pxa.ProxFunc = None,
        **kwargs,
    ):
        kwargs.update(
            log_var=kwargs.get("log_var", ("x",)),
        )
        super().__init__(**kwargs)

        if (f is None) and (g is None):
            msg = " ".join(
                [
                    "Cannot minimize always-0 functional.",
                    "At least one of Parameter[f, g] must be specified.",
                ]
            )
            raise NotImplementedError(msg)
        elif f is None:
            self._f = pxo.NullFunc(dim_shape=g.dim_shape)
            self._g = g
        elif g is None:
            self._f = f
            self._g = pxo.NullFunc(dim_shape=f.dim_shape)
        else:
            self._f = f
            self._g = g

    def m_init(
        self,
        x0: pxt.NDArray,
        tau: pxt.Real = None,
        acceleration: bool = True,
        d: pxt.Real = 75,
    ):
        mst = self._mstate  # shorthand
        mst["x"] = mst["x_prev"] = x0

        if tau is None:
            mst["tau"] = 1 / np.array(self._f.diff_lipschitz)
            if math.isclose(mst["tau"], 0):
                # _f does not provide any "useful" diff_lipschitz constant.
                msg = "\n".join(
                    [
                        "No useful step size could be auto-determined from Parameter[f].",
                        "Consider initializing Parameter[tau] directly, or set (an estimate of) the diff-Lipschitz constant of Parameter[f] before calling fit().",
                        "Solver iterations as-is may stagnate.",
                    ]
                )
                warnings.warn(msg, pxw.AutoInferenceWarning)
            if math.isinf(mst["tau"]):
                # _f is constant-valued: \tau is a free parameter.
                mst["tau"] = 1
                msg = "\n".join(
                    [
                        rf"The gradient/proximal step size \tau is auto-set to {mst['tau']}.",
                        r"Choosing \tau manually may lead to faster convergence.",
                    ]
                )
                warnings.warn(msg, pxw.AutoInferenceWarning)
        else:
            try:
                assert tau > 0
                mst["tau"] = tau
            except Exception:
                raise ValueError(f"tau must be positive, got {tau}.")

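        # Momentum coefficients a_k: with acceleration, a_k = k / (k + d + 1), i.e. the
        # Chambolle & Dossal rule discussed in the class docstring; without acceleration,
        # a_k = 0 and the update reduces to plain (non-accelerated) proximal gradient descent.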
        if acceleration:
            try:
                assert d > 2
                mst["a"] = (k / (k + 1 + d) for k in itertools.count(start=0))
            except Exception:
                raise ValueError(f"Expected d > 2, got {d}.")
        else:
            mst["a"] = itertools.repeat(0.0)

    def m_step(self):
        mst = self._mstate  # shorthand
        a = next(mst["a"])
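        # Accelerated PGD update implemented below:
        #   y_n     = x_n + a_n * (x_n - x_{n-1})              (extrapolation)
        #   x_{n+1} = prox_{tau * g}(y_n - tau * grad F(y_n))  (forward-backward step)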

        # In-place implementation of -----------------
        # y = (1 + a) * mst["x"] - a * mst["x_prev"]
        y = mst["x"] - mst["x_prev"]
        y *= a
        y += mst["x"]
        # --------------------------------------------

        # In-place implementation of -----------------
        # z = y - mst["tau"] * self._f.grad(y)
        z = pxu.copy_if_unsafe(self._f.grad(y))
        z *= -mst["tau"]
        z += y
        # --------------------------------------------

        mst["x_prev"], mst["x"] = mst["x"], self._g.prox(z, mst["tau"])

    def default_stop_crit(self) -> pxa.StoppingCriterion:
        from pyxu.opt.stop import RelError

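        # Matches the documented default: stop once the relative change of the "x" iterate
        # (Euclidean norm taken over the variable axes) drops below 1e-4.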
        stop_crit = RelError(
            eps=1e-4,
            var="x",
            rank=self._f.dim_rank,
            f=None,
            norm=2,
            satisfy_all=True,
        )
        return stop_crit

    def objective_func(self) -> pxt.NDArray:
        func = lambda x: self._f.apply(x) + self._g.apply(x)

        y = func(self._mstate["x"])
        return y

    def solution(self) -> pxt.NDArray:
        """
        Returns
        -------
        x: NDArray
            (..., M1,...,MD) solution.
        """
        data, _ = self.stats()
        return data.get("x")