Source code for pyxu.opt.solver.nlcg

import numpy as np

import pyxu.abc as pxa
import pyxu.info.ptype as pxt
import pyxu.math as pxm
import pyxu.util as pxu

__all__ = [
    "NLCG",
]

class NLCG(pxa.Solver):
    r"""
    Nonlinear Conjugate Gradient Method (NLCG).

    The Nonlinear Conjugate Gradient method finds a local minimum of the problem

    .. math::

       \min_{\mathbf{x}\in\mathbb{R}^{N}} f(\mathbf{x}),

    where :math:`f: \mathbb{R}^{N} \to \mathbb{R}` is a *differentiable* functional. When :math:`f` is quadratic, NLCG
    is equivalent to the Conjugate Gradient (CG) method, and it hence has similar convergence behaviour to CG when
    :math:`f` is locally quadratic. Convergence may however be slower due to the line-search overhead
    [NumOpt_NocWri]_.

    The norm of the `gradient <https://www.wikiwand.com/en/Nonlinear_conjugate_gradient_method>`_ :math:`\nabla f_k =
    \nabla f(\mathbf{x}_k)` is used as the default stopping criterion: iterations stop once it drops below 1e-4.

    Multiple variants of NLCG exist. They differ mainly in how the weights applied to conjugate directions are
    updated. Two popular variants are implemented:

    * The Fletcher-Reeves variant:

      .. math::

         \beta_k^\text{FR}
         =
         \frac{
           \Vert{\nabla f_{k+1}}\Vert_{2}^{2}
         }{
           \Vert{\nabla f_{k}}\Vert_{2}^{2}
         }.

    * The Polak-Ribière+ variant:

      .. math::

         \beta_k^\text{PR}
         =
         \frac{
           \nabla f_{k+1}^T\left(\nabla f_{k+1} - \nabla f_k\right)
         }{
           \Vert{\nabla f_{k}}\Vert_{2}^{2}
         } \\
         \beta_k^\text{PR+}
         =
         \max\left(0, \beta_k^\text{PR}\right).

    Parameters (``__init__()``)
    ---------------------------
    * **f** (:py:class:`~pyxu.abc.DiffFunc`)
      --
      Differentiable functional :math:`f`.
    * **\*\*kwargs** (:py:class:`~collections.abc.Mapping`)
      --
      Other keyword parameters passed on to :py:meth:`pyxu.abc.Solver.__init__`.

    Parameters (``fit()``)
    ----------------------
    * **x0** (:py:attr:`~pyxu.info.ptype.NDArray`)
      --
      (..., N) initial point(s).
    * **variant** ("PR", "FR")
      --
      Name of the NLCG variant to use:

      * "PR": Polak-Ribière+ variant (default).
      * "FR": Fletcher-Reeves variant.
    * **restart_rate** (:py:attr:`~pyxu.info.ptype.Integer`, :py:obj:`None`)
      --
      Number of iterations after which a restart is applied.

      By default, a restart is done after :math:`N` iterations.
    * **\*\*kwargs** (:py:class:`~collections.abc.Mapping`)
      --
      Optional parameters forwarded to :py:func:`~pyxu.math.backtracking_linesearch`.

      If `a0` is unspecified and :math:`\nabla f` is :math:`\beta`-Lipschitz continuous, then `a0` is auto-chosen as
      :math:`\beta^{-1}`. Users are expected to set `a0` if its value cannot be auto-inferred.

      Other keyword parameters are passed on to :py:meth:`pyxu.abc.Solver.fit`.

    Example
    -------
    Consider the following quadratic optimization problem:

    .. math::

       \min_{\mathbf{x}} \Vert{A\mathbf{x}-\mathbf{b}}\Vert_2^2

    This problem is strictly convex, hence NLCG will converge to the optimal solution:

    .. code-block:: python3

       import numpy as np

       import pyxu.operator as pxo
       import pyxu.opt.solver as pxsl

       N, a, b = 5, 3, 1
       f = pxo.SquaredL2Norm(N).asloss(b).argscale(a)  # \norm(Ax - b)**2

       nlcg = pxsl.NLCG(f)
       nlcg.fit(x0=np.zeros((N,)), variant="FR")
       x_opt = nlcg.solution()
       np.allclose(x_opt, 1/a)  # True

    Note however that the CG method is preferable in this context since it omits the line-search overhead, whose cost
    depends on the cost of applying :math:`A` and may be significant.
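
    If :math:`\nabla f`'s Lipschitz constant is unknown (or infinite), `a0` cannot be auto-inferred and must be given
    explicitly. A minimal sketch, re-using the objects above with an illustrative (not tuned) step size:

    .. code-block:: python3

       nlcg.fit(x0=np.zeros((N,)), variant="PR", a0=1e-2)  # `a0` value chosen for illustration only
       x_opt = nlcg.solution()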
    """

    def __init__(self, f: pxa.DiffFunc, **kwargs):
        kwargs.update(
            log_var=kwargs.get("log_var", ("x",)),
        )
        super().__init__(**kwargs)

        self._f = f

    def m_init(
        self,
        x0: pxt.NDArray,
        variant: str = "PR",
        restart_rate: pxt.Integer = None,
        **kwargs,
    ):
        mst = self._mstate  # shorthand

        if (a0 := kwargs.get("a0")) is None:
            d_l = self._f.diff_lipschitz
            if np.isclose(d_l, np.inf) or np.isclose(d_l, 0):
                msg = "[NLCG] cannot auto-infer initial step size: specify `a0` manually in NLCG.fit()"
                raise ValueError(msg)
            else:
                a0 = 1.0 / d_l

        if restart_rate is not None:
            assert restart_rate >= 1
            mst["restart_rate"] = int(restart_rate)
        else:
            mst["restart_rate"] = x0.shape[-1]

        import pyxu.math.linesearch as ls

        mst["x"] = x0
        mst["gradient"] = self._f.grad(x0)
        mst["conjugate_dir"] = -mst["gradient"].copy()
        mst["variant"] = self.__parse_variant(variant)
        mst["ls_a0"] = a0
        mst["ls_r"] = kwargs.get("r", ls.LINESEARCH_DEFAULT_R)
        mst["ls_c"] = kwargs.get("c", ls.LINESEARCH_DEFAULT_C)
        mst["ls_a_k"] = mst["ls_a0"]

    def m_step(self):
        mst = self._mstate  # shorthand
        x_k, g_k, p_k = mst["x"], mst["gradient"], mst["conjugate_dir"]

        a_k = pxm.backtracking_linesearch(
            f=self._f,
            x=x_k,
            gradient=g_k,
            direction=p_k,
            a0=mst["ls_a0"],
            r=mst["ls_r"],
            c=mst["ls_c"],
        )
        # In-place implementation of -----------------
        #   x_kp1 = x_k + p_k * a_k
        x_kp1 = p_k.copy()
        x_kp1 *= a_k
        x_kp1 += x_k
        # --------------------------------------------
        g_kp1 = self._f.grad(x_kp1)

        # Because NLCG can only generate N conjugate vectors in an N-dimensional space, it makes sense
        # to restart NLCG every N iterations.
        if self._astate["idx"] % mst["restart_rate"] == 0:
            beta_kp1 = 0.0
        else:
            beta_kp1 = self.__compute_beta(g_k, g_kp1)

        # In-place implementation of -----------------
        #   p_kp1 = -g_kp1 + beta_kp1 * p_k
        p_kp1 = p_k.copy()
        p_kp1 *= beta_kp1
        p_kp1 -= g_kp1
        # --------------------------------------------

        mst["x"], mst["gradient"], mst["conjugate_dir"], mst["ls_a_k"] = x_kp1, g_kp1, p_kp1, a_k

    def default_stop_crit(self) -> pxa.StoppingCriterion:
        from pyxu.opt.stop import AbsError

        stop_crit = AbsError(
            eps=1e-4,
            var="gradient",
            rank=self._f.dim_rank,
            f=None,
            norm=2,
            satisfy_all=True,
        )
        return stop_crit

    def objective_func(self) -> pxt.NDArray:
        return self._f(self._mstate["x"])

    def solution(self) -> pxt.NDArray:
        """
        Returns
        -------
        x: NDArray
            (..., N) solution.
        """
        data, _ = self.stats()
        return data.get("x")

    def __compute_beta(self, g_k: pxt.NDArray, g_kp1: pxt.NDArray) -> pxt.NDArray:
        v = self._mstate["variant"]
        xp = pxu.get_array_module(g_k)
        if v == "fr":  # Fletcher-Reeves
            gn_k = xp.linalg.norm(g_k, axis=-1, keepdims=True)
            gn_kp1 = xp.linalg.norm(g_kp1, axis=-1, keepdims=True)
            beta = (gn_kp1 / gn_k) ** 2
        elif v == "pr":  # Polak-Ribière+
            gn_k = xp.linalg.norm(g_k, axis=-1, keepdims=True)
            numerator = (g_kp1 * (g_kp1 - g_k)).sum(axis=-1, keepdims=True)
            beta = numerator / (gn_k**2)
            beta = beta.clip(min=0)
        return beta  # (..., 1)

    def __parse_variant(self, variant: str) -> str:
        supported_variants = {"fr", "pr"}
        if (v := variant.lower().strip()) not in supported_variants:
            raise ValueError(f"Unsupported variant '{variant}'.")
        return v
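
The coefficient updates implemented in __compute_beta() can be cross-checked against the formulas in the class
docstring with plain NumPy. A minimal standalone sketch (not part of the module; the gradient values are illustrative
only):

    import numpy as np

    g_k = np.array([1.0, -2.0, 0.5])    # gradient at iterate k
    g_kp1 = np.array([0.2, 0.1, -0.3])  # gradient at iterate k+1

    beta_fr = np.linalg.norm(g_kp1) ** 2 / np.linalg.norm(g_k) ** 2       # Fletcher-Reeves
    beta_pr = max(0.0, g_kp1 @ (g_kp1 - g_k) / np.linalg.norm(g_k) ** 2)  # Polak-Ribière+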