import numpy as np

import pyxu.abc as pxa
import pyxu.info.ptype as pxt
import pyxu.math as pxm
import pyxu.util as pxu

__all__ = [
    "NLCG",
]


class NLCG(pxa.Solver):
    r"""
    Nonlinear Conjugate Gradient Method (NLCG).

    The Nonlinear Conjugate Gradient method finds a local minimum of the problem

    .. math::

       \min_{\mathbf{x}\in\mathbb{R}^{N}} f(\mathbf{x}),

    where :math:`f: \mathbb{R}^{N} \to \mathbb{R}` is a *differentiable* functional. When :math:`f` is quadratic, NLCG
    is equivalent to the Conjugate Gradient (CG) method. NLCG hence has similar convergence behaviour to CG if
    :math:`f` is locally-quadratic. The convergence speed may however be slower due to its line-search overhead
    [NumOpt_NocWri]_.

    The norm of the `gradient <https://www.wikiwand.com/en/Nonlinear_conjugate_gradient_method>`_ :math:`\nabla f_k =
    \nabla f(\mathbf{x}_k)` is used as the default stopping criterion. By default, the iterations stop when the norm
    of the gradient is smaller than 1e-4.

    Multiple variants of NLCG exist. They differ mainly in how the weights applied to conjugate directions are updated.
    Two popular variants are implemented (a small NumPy sketch of both rules follows the formulas):

    * The Fletcher-Reeves variant:

      .. math::

         \beta_k^\text{FR}
         =
         \frac{
             \Vert{\nabla f_{k+1}}\Vert_{2}^{2}
         }{
             \Vert{\nabla f_{k}}\Vert_{2}^{2}
         }.

    * The Polak-Ribière+ method:

      .. math::

         \beta_k^\text{PR}
         =
         \frac{
             \nabla f_{k+1}^T\left(\nabla f_{k+1} - \nabla f_k\right)
         }{
             \Vert{\nabla f_{k}}\Vert_{2}^{2}
         } \\
         \beta_k^\text{PR+}
         =
         \max\left(0, \beta_k^\text{PR}\right).

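    For illustration only, both update rules can be written directly in NumPy. This is a standalone sketch with
    illustrative array names (``g_k``, ``g_kp1`` denote two consecutive gradients); batched inputs are handled
    slightly differently inside the solver:

    .. code-block:: python3

       import numpy as np

       rng = np.random.default_rng(0)
       g_k, g_kp1 = rng.standard_normal((2, 5))  # two consecutive gradients

       beta_fr = (g_kpl @ g_kp1) / (g_k @ g_k) if False else (g_kp1 @ g_kp1) / (g_k @ g_k)  # Fletcher-Reeves
       beta_pr = (g_kp1 @ (g_kp1 - g_k)) / (g_k @ g_k)  # Polak-Ribière
       beta_prp = max(0.0, beta_pr)                     # Polak-Ribière+
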
    Parameters (``__init__()``)
    ---------------------------
    * **f** (:py:class:`~pyxu.abc.DiffFunc`)
      --
      Differentiable function :math:`f`.
    * **\*\*kwargs** (:py:class:`~collections.abc.Mapping`)
      --
      Other keyword parameters passed on to :py:meth:`pyxu.abc.Solver.__init__`.

    Parameters (``fit()``)
    ----------------------
    * **x0** (:py:attr:`~pyxu.info.ptype.NDArray`)
      --
      (..., N) initial point(s).
    * **variant** ("PR", "FR")
      --
      Name of the NLCG variant to use:

      * "PR": Polak-Ribière+ variant (default).
      * "FR": Fletcher-Reeves variant.
    * **restart_rate** (:py:attr:`~pyxu.info.ptype.Integer`, :py:obj:`None`)
      --
      Number of iterations after which restart is applied.

      By default, restart is done after :math:`N` iterations.
    * **\*\*kwargs** (:py:class:`~collections.abc.Mapping`)
      --
      Optional parameters forwarded to :py:func:`~pyxu.math.backtracking_linesearch`.

      If `a0` is unspecified and :math:`\nabla f` is :math:`\beta`-Lipschitz continuous, then `a0` is auto-chosen as
      :math:`\beta^{-1}`. Users are expected to set `a0` if its value cannot be auto-inferred.

      Other keyword parameters are passed on to :py:meth:`pyxu.abc.Solver.fit`.

    Example
    -------
    Consider the following quadratic optimization problem:

    .. math::

       \min_{\mathbf{x}} \Vert{A\mathbf{x}-\mathbf{b}}\Vert_2^2.

    This problem is strictly convex, hence NLCG will converge to the optimal solution:

    .. code-block:: python3

       import numpy as np

       import pyxu.operator as pxo
       import pyxu.opt.solver as pxsl

       N, a, b = 5, 3, 1
       f = pxo.SquaredL2Norm(N).asloss(b).argscale(a)  # \norm(Ax - b)**2

       nlcg = pxsl.NLCG(f)
       nlcg.fit(x0=np.zeros((N,)), variant="FR")
       x_opt = nlcg.solution()
       np.allclose(x_opt, 1/a)  # True

    Note however that the CG method is preferable in this context since it omits the linesearch overhead. This
    overhead depends on the cost of applying :math:`A` and may be significant.
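
    The linesearch can also be tuned through ``fit()``. As a sketch (the values and parameter roles below are purely
    illustrative; see :py:func:`~pyxu.math.backtracking_linesearch` for the exact semantics), the keyword arguments
    `a0`, `r` and `c` are forwarded to the linesearch:

    .. code-block:: python3

       nlcg = pxsl.NLCG(f)
       nlcg.fit(
           x0=np.zeros((N,)),
           variant="PR",  # Polak-Ribière+ (default)
           a0=1e-1,       # initial linesearch step size
           r=0.5,         # step-shrinkage factor
           c=1e-4,        # sufficient-decrease constant
       )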
124 """
125
126 def __init__(self, f: pxa.DiffFunc, **kwargs):
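        # Log the iterate `x` by default so that solution() can recover it from stats().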
        kwargs.update(
            log_var=kwargs.get("log_var", ("x",)),
        )
        super().__init__(**kwargs)

        self._f = f

    def m_init(
        self,
        x0: pxt.NDArray,
        variant: str = "PR",
        restart_rate: pxt.Integer = None,
        **kwargs,
    ):
        mst = self._mstate  # shorthand

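        # If the user did not provide an initial linesearch step `a0`, fall back to 1/diff_lipschitz
        # (only possible when the gradient's Lipschitz constant is finite and non-zero).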
        if (a0 := kwargs.get("a0")) is None:
            d_l = self._f.diff_lipschitz
            if np.isclose(d_l, np.inf) or np.isclose(d_l, 0):
                msg = "[NLCG] cannot auto-infer initial step size: specify `a0` manually in NLCG.fit()"
                raise ValueError(msg)
            else:
                a0 = 1.0 / d_l

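        # Restart frequency: user-supplied, or (by default) the dimension N of the ambient space.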
        if restart_rate is not None:
            assert restart_rate >= 1
            mst["restart_rate"] = int(restart_rate)
        else:
            mst["restart_rate"] = x0.shape[-1]

        import pyxu.math.linesearch as ls

        mst["x"] = x0
        mst["gradient"] = self._f.grad(x0)
        mst["conjugate_dir"] = -mst["gradient"].copy()
        mst["variant"] = self.__parse_variant(variant)
        mst["ls_a0"] = a0
        mst["ls_r"] = kwargs.get("r", ls.LINESEARCH_DEFAULT_R)
        mst["ls_c"] = kwargs.get("c", ls.LINESEARCH_DEFAULT_C)
        mst["ls_a_k"] = mst["ls_a0"]

    def m_step(self):
        mst = self._mstate  # shorthand
        x_k, g_k, p_k = mst["x"], mst["gradient"], mst["conjugate_dir"]

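        # Backtracking linesearch along the current conjugate direction p_k.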
        a_k = pxm.backtracking_linesearch(
            f=self._f,
            x=x_k,
            gradient=g_k,
            direction=p_k,
            a0=mst["ls_a0"],
            r=mst["ls_r"],
            c=mst["ls_c"],
        )
        # In-place implementation of -----------------
        #   x_kp1 = x_k + p_k * a_k
        x_kp1 = p_k.copy()
        x_kp1 *= a_k
        x_kp1 += x_k
        # --------------------------------------------
        g_kp1 = self._f.grad(x_kp1)

        # Because NLCG can only generate N conjugate vectors in an N-dimensional space, it makes sense
        # to restart NLCG every N iterations.
        if self._astate["idx"] % mst["restart_rate"] == 0:
            beta_kp1 = 0.0
        else:
            beta_kp1 = self.__compute_beta(g_k, g_kp1)

        # In-place implementation of -----------------
        #   p_kp1 = -g_kp1 + beta_kp1 * p_k
        p_kp1 = p_k.copy()
        p_kp1 *= beta_kp1
        p_kp1 -= g_kp1
        # --------------------------------------------

        mst["x"], mst["gradient"], mst["conjugate_dir"], mst["ls_a_k"] = x_kp1, g_kp1, p_kp1, a_k

    def default_stop_crit(self) -> pxa.StoppingCriterion:
        from pyxu.opt.stop import AbsError

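        # Stop when the 2-norm of the gradient falls below 1e-4 (see class docstring).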
        stop_crit = AbsError(
            eps=1e-4,
            var="gradient",
            rank=self._f.dim_rank,
            f=None,
            norm=2,
            satisfy_all=True,
        )
        return stop_crit

    def objective_func(self) -> pxt.NDArray:
        return self._f(self._mstate["x"])

    def solution(self) -> pxt.NDArray:
        """
        Returns
        -------
        x: NDArray
            (..., N) solution.
        """
        data, _ = self.stats()
        return data.get("x")

    def __compute_beta(self, g_k: pxt.NDArray, g_kp1: pxt.NDArray) -> pxt.NDArray:
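        # Gradients have shape (..., N); the reductions below keep the trailing axis so that
        # beta broadcasts against the (..., N) conjugate directions in m_step().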
        v = self._mstate["variant"]
        xp = pxu.get_array_module(g_k)
        if v == "fr":  # Fletcher-Reeves
            gn_k = xp.linalg.norm(g_k, axis=-1, keepdims=True)
            gn_kp1 = xp.linalg.norm(g_kp1, axis=-1, keepdims=True)
            beta = (gn_kp1 / gn_k) ** 2
        elif v == "pr":  # Polak-Ribière+
            gn_k = xp.linalg.norm(g_k, axis=-1, keepdims=True)
            numerator = (g_kp1 * (g_kp1 - g_k)).sum(axis=-1, keepdims=True)
            beta = numerator / (gn_k**2)
            beta = beta.clip(min=0)
        return beta  # (..., 1)

    def __parse_variant(self, variant: str) -> str:
        supported_variants = {"fr", "pr"}
        if (v := variant.lower().strip()) not in supported_variants:
            raise ValueError(f"Unsupported variant '{variant}'.")
        return v