import collections.abc as cabc
import datetime as dt
import time
import warnings

import numpy as np

import pyxu.abc as pxa
import pyxu.info.ptype as pxt
import pyxu.util as pxu
from pyxu.info.plugin import _load_entry_points
11
# Names exported by this module. The list is then handed, together with the module namespace and the
# "pyxu.opt.stop" entry-point group, to `_load_entry_points`, whose result becomes the final `__all__`
# (presumably so that registered plugins can extend this module -- see pyxu.info.plugin).
__all__ = [
    "AbsError",
    "ManualStop",
    "MaxDuration",
    "MaxIter",
    "Memorize",
    "RelError",
]
__all__ = _load_entry_points(globals(), group="pyxu.opt.stop", names=__all__)

# Signature of the optional pre-processing callables accepted by AbsError/RelError: NDArray -> NDArray.
SVFunction = cabc.Callable[[pxt.NDArray], pxt.NDArray]
24
25
26def _norm(x: pxt.NDArray, ord: pxt.Integer, rank: pxt.Integer) -> pxt.NDArray:
27 # x: (..., M1,...,MD) [rank=D]
28 # n: (..., 1) [`ord`-norm of `x`, computed over last `rank` axes]
29 xp = pxu.get_array_module(x)
30 axis = tuple(range(-rank, 0))
31 if ord == 0:
32 n = xp.sum(~xp.isclose(x, 0), axis=axis)
33 elif ord == 1:
34 n = xp.sum(xp.fabs(x), axis=axis)
35 elif ord == 2:
36 n = xp.sqrt(xp.sum(x**2, axis=axis))
37 elif ord == np.inf:
38 n = xp.max(xp.fabs(x), axis=axis)
39 else:
40 n = xp.power(xp.sum(x**ord, axis=axis), 1 / ord)
41 return n[..., np.newaxis]
42
43
class MaxIter(pxa.StoppingCriterion):
    """
    Stop iterative solver after a fixed number of iterations.

    .. note::

       If you want to add a grace period to a solver, i.e. for it to do *at least* N iterations before stopping based
       on the value of another criterion, you can AND :py:class:`~pyxu.opt.stop.MaxIter` with the other criterion.

       .. code-block:: python3

          sc = MaxIter(n=5) & AbsError(eps=0.1)
          # If N_iter < 5 -> never stop.
          # If N_iter >= 5 -> stop if AbsError() decides to.
    """

    def __init__(self, n: pxt.Integer):
        """
        Parameters
        ----------
        n: Integer
            Max number of iterations allowed.

        Raises
        ------
        ValueError
            If `n` cannot be converted to a positive integer.
        """
        try:
            n_max = int(n)
        except (TypeError, ValueError, OverflowError) as exc:
            raise ValueError(f"n: expected positive integer, got {n}.") from exc
        # Explicit check rather than `assert`, so validation is not stripped under `python -O`.
        if n_max <= 0:
            raise ValueError(f"n: expected positive integer, got {n}.")
        self._n = n_max
        self._i = 0  # number of stop() queries since construction / last clear().

    def stop(self, state: cabc.Mapping) -> bool:
        """Count one more query; request termination once more than `n` queries have been made."""
        self._i += 1
        return self._i > self._n

    def info(self) -> cabc.Mapping[str, float]:
        """Report the number of queries made so far."""
        return dict(N_iter=self._i)

    def clear(self):
        """Reset the query counter."""
        self._i = 0
83
84
class ManualStop(pxa.StoppingCriterion):
    """
    Continue-forever criterion.

    This class is useful when calling :py:meth:`~pyxu.abc.Solver.fit` with mode=MANUAL/ASYNC to defer the stopping
    decision to an explicit call by the user, i.e.:

    * mode=MANUAL: user must stop calling ``next(solver.steps())``;
    * mode=ASYNC: user must call :py:meth:`~pyxu.abc.Solver.stop`.
    """

    def stop(self, state: cabc.Mapping) -> bool:
        # Never request termination: the decision is left entirely to the user.
        keep_going = False
        return keep_going

    def info(self) -> cabc.Mapping[str, float]:
        # This criterion tracks nothing, hence reports nothing.
        return {}

    def clear(self):
        # Stateless: there is nothing to reset.
        return None
104
105
class MaxDuration(pxa.StoppingCriterion):
    """
    Stop iterative solver after a specified duration has elapsed.

    Elapsed time is measured with a monotonic clock (:py:func:`time.monotonic`), so it is not affected by system
    clock updates (e.g. NTP corrections or daylight-saving changes). The timer starts when the object is constructed
    and is restarted by :py:meth:`clear`.
    """

    def __init__(self, t: dt.timedelta):
        """
        Parameters
        ----------
        t: ~datetime.timedelta
            Max runtime allowed.

        Raises
        ------
        ValueError
            If `t` is not a positive duration.
        """
        try:
            positive = bool(t > dt.timedelta())
        except (TypeError, ValueError) as exc:
            raise ValueError(f"t: expected positive duration, got {t}.") from exc
        # Explicit check rather than `assert`, so validation is not stripped under `python -O`.
        if not positive:
            raise ValueError(f"t: expected positive duration, got {t}.")
        self._t_max = t
        # Monotonic timestamps [s]: only their difference is meaningful.
        self._t_start = time.monotonic()
        self._t_now = self._t_start

    def _elapsed(self) -> float:
        # Seconds between the timer start and the most recent stop() query.
        return self._t_now - self._t_start

    def stop(self, state: cabc.Mapping) -> bool:
        self._t_now = time.monotonic()
        # Compare as timedeltas so any `t` accepted by __init__ compares exactly as before.
        return dt.timedelta(seconds=self._elapsed()) > self._t_max

    def info(self) -> cabc.Mapping[str, float]:
        return dict(duration=self._elapsed())

    def clear(self):
        self._t_start = time.monotonic()
        self._t_now = self._t_start
137
138
class Memorize(pxa.StoppingCriterion):
    """
    Memorize a variable. (Special :py:class:`~pyxu.abc.StoppingCriterion` mostly useful for tracking objective
    functions in :py:class:`~pyxu.abc.Solver`.)

    This criterion never requests termination: :py:meth:`stop` only records the current value of the variable, and
    :py:meth:`info` reports it.
    """

    def __init__(self, var: pxt.VarName):
        """
        Parameters
        ----------
        var: VarName
            Variable in :py:attr:`pyxu.abc.Solver._mstate` to query. Must be a scalar or NDArray (1D).
        """
        self._var = var
        self._val = np.r_[0]  # value recorded by the most recent stop() call.

    def stop(self, state: cabc.Mapping) -> bool:
        value = state[self._var]
        if isinstance(value, pxt.Real):
            value = np.r_[value]  # promote scalars to 1D arrays.
        assert value.ndim == 1

        self._val = pxu.compute(value)
        return False  # recording only: never stops the solver.

    def info(self) -> cabc.Mapping[str, float]:
        label = f"Memorize[{self._var}]"
        val = self._val
        if val.size != 1:
            return {
                f"{label}_min": float(val.min()),
                f"{label}_max": float(val.max()),
            }
        # max() of a single-element array is that element.
        return {label: float(val.max())}

    def clear(self):
        self._val = np.r_[0]
176
177
class AbsError(pxa.StoppingCriterion):
    """
    Stop iterative solver after absolute norm of a variable (or function thereof) reaches threshold.

    At each query the criterion evaluates ``||f(x)||_norm`` over the trailing `rank` axes and compares it to `eps`.
    """

    def __init__(
        self,
        eps: pxt.Real,
        var: pxt.VarName = "x",
        rank: pxt.Integer = 1,
        f: SVFunction = None,
        norm: pxt.Real = 2,
        satisfy_all: bool = True,
    ):
        """
        Parameters
        ----------
        eps: Real
            Positive threshold.
        var: VarName
            Variable in :py:attr:`pyxu.abc.Solver._mstate` to query.
            Must hold an NDArray.
        rank: Integer
            Array rank K of monitored variable **after** applying `f`. (See below.)
        f: ~collections.abc.Callable
            Optional function to pre-apply to ``_mstate[var]`` before applying the norm. Defaults to the identity
            function. The callable should have the same semantics as :py:meth:`~pyxu.abc.Map.apply`:

              (..., M1,...,MD) -> (..., N1,...,NK)
        norm: Integer, Real
            Ln norm to use >= 0. (Default: L2.)
        satisfy_all: bool
            If True (default) and ``_mstate[var]`` is multi-dimensional, stop if all evaluation points lie at or
            below threshold. If False, stop as soon as any evaluation point lies at or below threshold.

        Raises
        ------
        ValueError
            If `eps` is not positive or `norm` is negative.
        """
        # Explicit checks rather than `assert`, so validation is not stripped under `python -O`.
        # Inputs that cannot be compared (TypeError/ValueError) are rejected like out-of-range ones.
        try:
            eps_ok = bool(eps > 0)
        except (TypeError, ValueError):
            eps_ok = False
        if not eps_ok:
            raise ValueError(f"eps: expected positive threshold, got {eps}.")
        self._eps = eps

        self._var = var
        self._rank = int(rank)
        self._f = f if (f is not None) else (lambda _: _)

        try:
            norm_ok = bool(norm >= 0)
        except (TypeError, ValueError):
            norm_ok = False
        if not norm_ok:
            raise ValueError(f"norm: expected non-negative, got {norm}.")
        self._norm = norm

        self._satisfy_all = satisfy_all
        self._val = np.r_[0]  # last computed Ln norm(s) in stop().

    def stop(self, state: cabc.Mapping) -> bool:
        fx = self._f(state[self._var])  # (..., N1,...,NK)
        self._val = _norm(fx, ord=self._norm, rank=self._rank)  # (..., 1)

        xp = pxu.get_array_module(fx)
        rule = xp.all if self._satisfy_all else xp.any
        decision = rule(self._val <= self._eps)  # 0-d: reduced over all evaluation points.

        # pxu.compute() presumably evaluates lazy (e.g. Dask) arrays so info() sees concrete values.
        self._val, decision = pxu.compute(self._val, decision)
        return decision

    def info(self) -> cabc.Mapping[str, float]:
        if self._val.size == 1:
            data = {f"AbsError[{self._var}]": float(self._val.max())}  # takes the only element available.
        else:
            data = {
                f"AbsError[{self._var}]_min": float(self._val.min()),
                f"AbsError[{self._var}]_max": float(self._val.max()),
            }
        return data

    def clear(self):
        self._val = np.r_[0]
255
256
class RelError(pxa.StoppingCriterion):
    """
    Stop iterative solver after relative norm change of a variable (or function thereof) reaches threshold.

    Given the previous and current iterates ``x_prev`` and ``x``, the criterion tests
    ``||f(x) - f(x_prev)||_norm <= eps * ||f(x_prev)||_norm`` over the trailing `rank` axes. The first query after
    construction / :py:meth:`clear` only buffers ``x`` and never stops.
    """

    def __init__(
        self,
        eps: pxt.Real,
        var: pxt.VarName = "x",
        rank: pxt.Integer = 1,
        f: SVFunction = None,
        norm: pxt.Real = 2,
        satisfy_all: bool = True,
    ):
        """
        Parameters
        ----------
        eps: Real
            Positive threshold.
        var: VarName
            Variable in :py:attr:`pyxu.abc.Solver._mstate` to query.
            Must hold an NDArray.
        rank: Integer
            Array rank K of monitored variable **after** applying `f`. (See below.)
        f: ~collections.abc.Callable
            Optional function to pre-apply to ``_mstate[var]`` before applying the norm. Defaults to the identity
            function. The callable should have the same semantics as :py:meth:`~pyxu.abc.Map.apply`:

              (..., M1,...,MD) -> (..., N1,...,NK)
        norm: Integer, Real
            Ln norm to use >= 0. (Default: L2.)
        satisfy_all: bool
            If True (default) and ``_mstate[var]`` is multi-dimensional, stop if all evaluation points lie at or
            below threshold. If False, stop as soon as any evaluation point lies at or below threshold.

        Raises
        ------
        ValueError
            If `eps` is not positive or `norm` is negative.
        """
        # Explicit checks rather than `assert`, so validation is not stripped under `python -O`.
        # Inputs that cannot be compared (TypeError/ValueError) are rejected like out-of-range ones.
        try:
            eps_ok = bool(eps > 0)
        except (TypeError, ValueError):
            eps_ok = False
        if not eps_ok:
            raise ValueError(f"eps: expected positive threshold, got {eps}.")
        self._eps = eps

        self._var = var
        self._rank = int(rank)
        self._f = f if (f is not None) else (lambda _: _)

        try:
            norm_ok = bool(norm >= 0)
        except (TypeError, ValueError):
            norm_ok = False
        if not norm_ok:
            raise ValueError(f"norm: expected non-negative, got {norm}.")
        self._norm = norm

        self._satisfy_all = satisfy_all
        self._val = np.r_[0]  # last computed Ln rel-norm(s) in stop().
        self._x_prev = None  # buffered var from last query.

    def stop(self, state: cabc.Mapping) -> bool:
        x = state[self._var]  # (..., M1,...,MD)

        if self._x_prev is None:
            # Insufficient history to evaluate a relative change: buffer `x` and defer the decision.
            self._x_prev = x.copy()
            fx_prev = self._f(self._x_prev)  # (..., N1,...,NK)

            # Force the 1st info() call to have the same (..., 1) layout as further calls.
            # `ndim - rank` (rather than `-rank`) keeps the full batch shape when rank == 0.
            n_batch = max(fx_prev.ndim - self._rank, 0)
            self._val = np.zeros(shape=(*fx_prev.shape[:n_batch], 1))
            return False

        xp = pxu.get_array_module(x)
        rule = xp.all if self._satisfy_all else xp.any

        fx_prev = self._f(self._x_prev)  # (..., N1,...,NK)
        numerator = _norm(self._f(x) - fx_prev, ord=self._norm, rank=self._rank)  # (..., 1)
        denominator = _norm(fx_prev, ord=self._norm, rank=self._rank)  # (..., 1)
        # Multiplicative form: the decision itself never divides by zero.
        decision = rule(numerator <= self._eps * denominator)  # 0-d: reduced over all evaluation points.

        with warnings.catch_warnings():
            # Store relative improvement values for info(). The problematic case 0/0 -> NaN is mapped to 0
            # (no relative improvement).
            warnings.simplefilter("ignore")
            ratio = numerator / denominator  # (..., 1)
            self._val = xp.where(xp.isnan(ratio), 0, ratio)
        self._x_prev = x.copy()

        self._x_prev, self._val, decision = pxu.compute(self._x_prev, self._val, decision)
        return decision

    def info(self) -> cabc.Mapping[str, float]:
        if self._val.size == 1:
            data = {f"RelError[{self._var}]": float(self._val.max())}  # takes the only element available.
        else:
            data = {
                f"RelError[{self._var}]_min": float(self._val.min()),
                f"RelError[{self._var}]_max": float(self._val.max()),
            }
        return data

    def clear(self):
        self._val = np.r_[0]
        self._x_prev = None