Source code for pyxu.opt.solver.cg

  1import warnings
  2
  3import numpy as np
  4
  5import pyxu.abc as pxa
  6import pyxu.info.ptype as pxt
  7import pyxu.runtime as pxrt
  8import pyxu.util as pxu
  9
 10__all__ = [
 11    "CG",
 12]
 13
 14
class CG(pxa.Solver):
    r"""
    Conjugate Gradient Method.

    The Conjugate Gradient method solves the minimization problem

    .. math::

       \min_{\mathbf{x}\in\mathbb{R}^{M_{1} \times\cdots\times M_{D}}}
       \frac{1}{2} \langle \mathbf{x}, \mathbf{A} \mathbf{x} \rangle - \langle \mathbf{x}, \mathbf{b} \rangle,

    where :math:`\mathbf{A}: \mathbb{R}^{{M_{1} \times\cdots\times M_{D}}} \to \mathbb{R}^{{M_{1} \times\cdots\times
    M_{D}}}` is a *symmetric* *positive definite* operator, and :math:`\mathbf{b} \in \mathbb{R}^{{M_{1}
    \times\cdots\times M_{D}}}`.

    The norm of the `explicit residual <https://www.wikiwand.com/en/Conjugate_gradient_method>`_ :math:`\mathbf
    {r}_{k+1}:=\mathbf{b}-\mathbf{Ax}_{k+1}` is used as the default stopping criterion. This provides a guaranteed
    level of accuracy both in exact arithmetic and in the presence of round-off errors. By default, the iterations stop
    when the norm of the explicit residual is smaller than 1e-4.

    Parameters (``__init__()``)
    ---------------------------
    * **A** (:py:class:`~pyxu.abc.PosDefOp`)
      --
      Positive-definite operator :math:`\mathbf{A}: \mathbb{R}^{M_{1} \times\cdots\times M_{D}} \to \mathbb{R}^{M_{1}
      \times\cdots\times M_{D}}`.
    * **\*\*kwargs** (:py:class:`~collections.abc.Mapping`)
      --
      Other keyword parameters passed on to :py:meth:`pyxu.abc.Solver.__init__`.

    Parameters (``fit()``)
    ----------------------
    * **b** (:py:attr:`~pyxu.info.ptype.NDArray`)
      --
      (..., M1,...,MD) :math:`\mathbf{b}` terms in the CG cost function.

      All problems are solved in parallel.
    * **x0** (:py:attr:`~pyxu.info.ptype.NDArray`, :py:obj:`None`)
      --
      (..., M1,...,MD) initial point(s).

      Must be broadcastable with `b` if provided. Defaults to 0.
    * **restart_rate** (:py:attr:`~pyxu.info.ptype.Integer`)
      --
      Number of iterations after which restart is applied.

      By default, a restart is done after 'n' iterations, where 'n' corresponds to the dimension of :math:`\mathbf{A}`.
    * **\*\*kwargs** (:py:class:`~collections.abc.Mapping`)
      --
      Other keyword parameters passed on to :py:meth:`pyxu.abc.Solver.fit`.
    """

    def __init__(self, A: pxa.PosDefOp, **kwargs):
        # Log the iterate "x" by default unless the caller overrides log_var.
        kwargs.update(
            log_var=kwargs.get("log_var", ("x",)),
        )
        super().__init__(**kwargs)

        self._A = A

    def m_init(
        self,
        b: pxt.NDArray,
        x0: pxt.NDArray = None,
        restart_rate: pxt.Integer = None,
    ):
        """
        Initialize solver state: starting point, residual, and first conjugate direction.
        """
        mst = self._mstate  # shorthand

        if restart_rate is not None:
            assert restart_rate >= 1
            mst["restart_rate"] = int(restart_rate)
        else:
            # CG generates at most dim_size conjugate directions -> restart then.
            mst["restart_rate"] = self._A.dim_size

        xp = pxu.get_array_module(b)
        if x0 is None:
            mst["b"] = b
            mst["x"] = xp.zeros_like(b)
        elif b.shape == x0.shape:
            # No broadcasting involved
            mst["b"] = b
            mst["x"] = x0.copy()
        else:
            # In-place updates involving b/x won't work if shapes differ -> coerce to largest.
            mst["b"], x0 = xp.broadcast_arrays(b, x0)
            mst["x"] = x0.copy()

        # r0 = b - A x0; first search direction p0 = r0.
        mst["residual"] = mst["b"].copy()
        mst["residual"] -= self._A.apply(mst["x"])
        mst["conjugate_dir"] = mst["residual"].copy()

    def m_step(self):
        """
        Perform one CG update of (x, residual, conjugate_dir), with periodic explicit restarts.
        """
        mst = self._mstate  # shorthand
        x, r, p = mst["x"], mst["residual"], mst["conjugate_dir"]
        xp = pxu.get_array_module(x)
        reduce = lambda x: x.sum(  # (..., M1,...,MD) -> (..., 1,...,1)
            axis=tuple(range(-self._A.dim_rank, 0)),
            keepdims=True,
        )

        Ap = self._A.apply(p)  # (..., M1,...,MD)
        rr = reduce(r**2)
        with warnings.catch_warnings():
            # 0/0 may occur once a problem instance has converged; nan_to_num maps it to 0.
            warnings.simplefilter("ignore")
            alpha = xp.nan_to_num(rr / reduce(p * Ap))  # (..., 1,...1)
        x += alpha * p

        if pxu.compute(xp.any(rr <= pxrt.Width(rr.dtype).eps())):  # explicit eval
            # Residual is at machine precision: recompute it explicitly to limit round-off drift.
            r[:] = mst["b"]
            r -= self._A.apply(x)
        else:  # implicit eval
            r -= alpha * Ap

        # Because CG can only generate N conjugate vectors in an N-dimensional space, it makes sense
        # to restart CG every N iterations.
        if self._astate["idx"] % mst["restart_rate"] == 0:  # explicit eval
            beta = 0
            r[:] = mst["b"]
            r -= self._A.apply(x)
        else:  # implicit eval
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                beta = xp.nan_to_num(reduce(r**2) / rr)
        p *= beta
        p += r

        # for homogeneity with other solver code. Optional in CG due to in-place computations.
        mst["x"], mst["residual"], mst["conjugate_dir"] = x, r, p

    def default_stop_crit(self) -> pxa.StoppingCriterion:
        """
        Stop when the 2-norm of the explicit residual drops below 1e-4.
        """
        from pyxu.opt.stop import AbsError

        stop_crit = AbsError(
            eps=1e-4,
            var="residual",
            rank=self._A.dim_rank,
            f=None,
            norm=2,
            satisfy_all=True,
        )
        return stop_crit

    def objective_func(self) -> pxt.NDArray:
        """
        Evaluate the CG cost 1/2 <x, Ax> - <x, b> for each problem instance.

        Returns
        -------
        y: NDArray
            (..., 1) objective value(s).
        """
        x = self._mstate["x"]  # (..., M1,...,MD)
        b = self._mstate["b"]  # (..., M1,...,MD)

        # f = x * (A(x)/2 - b); summing over the core dimensions yields the objective.
        f = self._A.apply(x)
        f = pxu.copy_if_unsafe(f)
        f /= 2
        f -= b
        f *= x

        # BUGFIX: the axis tuple must cover the last `dim_rank` axes, i.e. range(-dim_rank, 0).
        # `range(dim_rank, 0)` is empty for positive dim_rank, making the sum a no-op and
        # returning an un-reduced array.  (Matches the `reduce` helper in m_step().)
        axis = tuple(range(-self._A.dim_rank, 0))
        y = f.sum(axis=axis)[..., np.newaxis]  # (..., 1)
        return y

    def solution(self) -> pxt.NDArray:
        """
        Returns
        -------
        x: NDArray
            (..., M1,...,MD) solution.
        """
        data, _ = self.stats()
        return data.get("x")