import warnings

import numpy as np

import pyxu.abc as pxa
import pyxu.info.ptype as pxt
import pyxu.runtime as pxrt
import pyxu.util as pxu

__all__ = [
    "CG",
]


class CG(pxa.Solver):
    r"""
    Conjugate Gradient Method.

    The Conjugate Gradient method solves the minimization problem

    .. math::

       \min_{\mathbf{x}\in\mathbb{R}^{M_{1} \times\cdots\times M_{D}}}
       \frac{1}{2} \langle \mathbf{x}, \mathbf{A} \mathbf{x} \rangle - \langle \mathbf{x}, \mathbf{b} \rangle,

    where :math:`\mathbf{A}: \mathbb{R}^{M_{1} \times\cdots\times M_{D}} \to \mathbb{R}^{M_{1} \times\cdots\times
    M_{D}}` is a *symmetric*, *positive definite* operator and :math:`\mathbf{b} \in \mathbb{R}^{M_{1}
    \times\cdots\times M_{D}}`.  Since :math:`\mathbf{A}` is positive definite, the unique minimizer is also the
    solution of the linear system :math:`\mathbf{A} \mathbf{x} = \mathbf{b}`.

    The norm of the `explicit residual <https://www.wikiwand.com/en/Conjugate_gradient_method>`_
    :math:`\mathbf{r}_{k+1} := \mathbf{b} - \mathbf{A} \mathbf{x}_{k+1}` is used as the default stopping criterion.
    This provides a guaranteed level of accuracy both in exact arithmetic and in the presence of round-off errors.
    By default, the iterations stop when the norm of the explicit residual is smaller than 1e-4.
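
    For reference, starting from :math:`\mathbf{r}_{0} = \mathbf{p}_{0} = \mathbf{b} - \mathbf{A} \mathbf{x}_{0}`,
    each iteration applies the standard CG recurrences (up to periodic restarts and the occasional explicit
    re-evaluation of the residual):

    .. math::

       \alpha_{k} = \frac{\langle \mathbf{r}_{k}, \mathbf{r}_{k} \rangle}
                         {\langle \mathbf{p}_{k}, \mathbf{A} \mathbf{p}_{k} \rangle},
       \qquad
       \mathbf{x}_{k+1} = \mathbf{x}_{k} + \alpha_{k} \mathbf{p}_{k},
       \qquad
       \mathbf{r}_{k+1} = \mathbf{r}_{k} - \alpha_{k} \mathbf{A} \mathbf{p}_{k},

    .. math::

       \beta_{k} = \frac{\langle \mathbf{r}_{k+1}, \mathbf{r}_{k+1} \rangle}
                        {\langle \mathbf{r}_{k}, \mathbf{r}_{k} \rangle},
       \qquad
       \mathbf{p}_{k+1} = \mathbf{r}_{k+1} + \beta_{k} \mathbf{p}_{k}.
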

    Parameters (``__init__()``)
    ---------------------------
    * **A** (:py:class:`~pyxu.abc.PosDefOp`)
      --
      Positive-definite operator :math:`\mathbf{A}: \mathbb{R}^{M_{1} \times\cdots\times M_{D}} \to
      \mathbb{R}^{M_{1} \times\cdots\times M_{D}}`.
    * **\*\*kwargs** (:py:class:`~collections.abc.Mapping`)
      --
      Other keyword parameters passed on to :py:meth:`pyxu.abc.Solver.__init__`.

    Parameters (``fit()``)
    ----------------------
    * **b** (:py:attr:`~pyxu.info.ptype.NDArray`)
      --
      (..., M1,...,MD) :math:`\mathbf{b}` terms in the CG cost function.

      All problems are solved in parallel.
    * **x0** (:py:attr:`~pyxu.info.ptype.NDArray`, :py:obj:`None`)
      --
      (..., M1,...,MD) initial point(s).

      Must be broadcastable with `b` if provided. Defaults to 0.
    * **restart_rate** (:py:attr:`~pyxu.info.ptype.Integer`)
      --
      Number of iterations after which a restart is applied.

      By default, a restart is done after 'n' iterations, where 'n' corresponds to the dimension of
      :math:`\mathbf{A}`.
    * **\*\*kwargs** (:py:class:`~collections.abc.Mapping`)
      --
      Other keyword parameters passed on to :py:meth:`pyxu.abc.Solver.fit`.
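
    Example
    -------
    A minimal usage sketch, assuming ``CG`` has been imported from this module and ``A`` is a user-supplied
    :py:class:`~pyxu.abc.PosDefOp` instance (its ``dim_shape`` attribute is assumed here to give the input shape):

    .. code-block:: python

       import numpy as np

       # A: user-supplied positive-definite operator (placeholder, not constructed here).
       b = np.ones(A.dim_shape)  # right-hand side, shape (M1,...,MD)

       cg = CG(A)
       cg.fit(b=b)        # iterate until the explicit residual norm drops below 1e-4
       x = cg.solution()  # (M1,...,MD) estimate of the solution of A x = b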
    """

    def __init__(self, A: pxa.PosDefOp, **kwargs):
        kwargs.update(
            log_var=kwargs.get("log_var", ("x",)),
        )
        super().__init__(**kwargs)

        self._A = A

    def m_init(
        self,
        b: pxt.NDArray,
        x0: pxt.NDArray = None,
        restart_rate: pxt.Integer = None,
    ):
        mst = self._mstate  # shorthand

        if restart_rate is not None:
            assert restart_rate >= 1
            mst["restart_rate"] = int(restart_rate)
        else:
            mst["restart_rate"] = self._A.dim_size

        xp = pxu.get_array_module(b)
        if x0 is None:
            mst["b"] = b
            mst["x"] = xp.zeros_like(b)
        elif b.shape == x0.shape:
            # No broadcasting involved
            mst["b"] = b
            mst["x"] = x0.copy()
        else:
            # In-place updates involving b/x won't work if shapes differ -> coerce to largest.
            mst["b"], x0 = xp.broadcast_arrays(b, x0)
            mst["x"] = x0.copy()

        mst["residual"] = mst["b"].copy()
        mst["residual"] -= self._A.apply(mst["x"])
        mst["conjugate_dir"] = mst["residual"].copy()

    def m_step(self):
        mst = self._mstate  # shorthand
        x, r, p = mst["x"], mst["residual"], mst["conjugate_dir"]
        xp = pxu.get_array_module(x)
        reduce = lambda x: x.sum(  # (..., M1,...,MD) -> (..., 1,...,1)
            axis=tuple(range(-self._A.dim_rank, 0)),
            keepdims=True,
        )

        Ap = self._A.apply(p)  # (..., M1,...,MD)
        rr = reduce(r**2)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            alpha = xp.nan_to_num(rr / reduce(p * Ap))  # (..., 1,...,1)
        x += alpha * p

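        # If any (batched) squared residual norm has already dropped to machine precision, re-evaluate
        # the residual explicitly as b - A x (one extra A.apply() call, but more accurate); otherwise
        # update it with the cheaper recurrence r -= alpha * A p.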
        if pxu.compute(xp.any(rr <= pxrt.Width(rr.dtype).eps())):  # explicit eval
            r[:] = mst["b"]
            r -= self._A.apply(x)
        else:  # implicit eval
            r -= alpha * Ap

        # Because CG can only generate N conjugate vectors in an N-dimensional space, it makes sense
        # to restart CG every N iterations.
        if self._astate["idx"] % mst["restart_rate"] == 0:  # explicit eval
            beta = 0
            r[:] = mst["b"]
            r -= self._A.apply(x)
        else:  # implicit eval
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                beta = xp.nan_to_num(reduce(r**2) / rr)
        p *= beta
        p += r

        # For homogeneity with other solver code. Optional in CG due to in-place computations.
        mst["x"], mst["residual"], mst["conjugate_dir"] = x, r, p

    def default_stop_crit(self) -> pxa.StoppingCriterion:
        from pyxu.opt.stop import AbsError

        stop_crit = AbsError(
            eps=1e-4,
            var="residual",
            rank=self._A.dim_rank,
            f=None,
            norm=2,
            satisfy_all=True,
        )
        return stop_crit

    def objective_func(self) -> pxt.NDArray:
        x = self._mstate["x"]  # (..., M1,...,MD)
        b = self._mstate["b"]  # (..., M1,...,MD)

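        # Evaluate f(x) = <x, A x>/2 - <x, b> in-place: form (A x)/2, subtract b, multiply by x,
        # then sum over the trailing D dimensions of each problem.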
        f = self._A.apply(x)
        f = pxu.copy_if_unsafe(f)
        f /= 2
        f -= b
        f *= x

        axis = tuple(range(-self._A.dim_rank, 0))
        y = f.sum(axis=axis)[..., np.newaxis]  # (..., 1)
        return y

    def solution(self) -> pxt.NDArray:
        """
        Returns
        -------
        x: NDArray
            (..., M1,...,MD) solution.
        """
        data, _ = self.stats()
        return data.get("x")