# Source code for pyxu.opt.stop

import collections.abc as cabc
import datetime as dt
import time
import warnings

import numpy as np

import pyxu.abc as pxa
import pyxu.info.ptype as pxt
import pyxu.util as pxu
from pyxu.info.plugin import _load_entry_points

 12__all__ = [
 13    "AbsError",
 14    "ManualStop",
 15    "MaxDuration",
 16    "MaxIter",
 17    "Memorize",
 18    "RelError",
 19]
 20
 21__all__ = _load_entry_points(globals(), group="pyxu.opt.stop", names=__all__)
 22
 23SVFunction = cabc.Callable[[pxt.NDArray], pxt.NDArray]
 24
 25
 26def _norm(x: pxt.NDArray, ord: pxt.Integer, rank: pxt.Integer) -> pxt.NDArray:
 27    # x: (..., M1,...,MD) [rank=D]
 28    # n: (..., 1)         [`ord`-norm of `x`, computed over last `rank` axes]
 29    xp = pxu.get_array_module(x)
 30    axis = tuple(range(-rank, 0))
 31    if ord == 0:
 32        n = xp.sum(~xp.isclose(x, 0), axis=axis)
 33    elif ord == 1:
 34        n = xp.sum(xp.fabs(x), axis=axis)
 35    elif ord == 2:
 36        n = xp.sqrt(xp.sum(x**2, axis=axis))
 37    elif ord == np.inf:
 38        n = xp.max(xp.fabs(x), axis=axis)
 39    else:
 40        n = xp.power(xp.sum(x**ord, axis=axis), 1 / ord)
 41    return n[..., np.newaxis]
 42
 43
class MaxIter(pxa.StoppingCriterion):
    """
    Stop iterative solver after a fixed number of iterations.

    .. note::

       If you want to add a grace period to a solver, i.e. for it to do *at least* N iterations before stopping based
       on the value of another criteria, you can AND :py:class:`~pyxu.opt.stop.MaxIter` with the other criteria.

       .. code-block:: python3

          sc = MaxIter(n=5) & AbsError(eps=0.1)
          # If N_iter <= 5 -> never stop.
          # If N_iter > 5  -> stop if AbsError() decides to.
    """

    def __init__(self, n: pxt.Integer):
        """
        Parameters
        ----------
        n: Integer
            Max number of iterations allowed.

        Raises
        ------
        ValueError
            If `n` cannot be converted to a positive integer.
        """
        # Explicit checks instead of `assert`: assertions are stripped under `python -O`, which would silently
        # accept n <= 0.
        try:
            n_max = int(n)
        except (TypeError, ValueError, OverflowError) as exc:
            raise ValueError(f"n: expected positive integer, got {n}.") from exc
        if n_max <= 0:
            raise ValueError(f"n: expected positive integer, got {n}.")
        self._n = n_max
        self._i = 0  # number of stop() calls since construction or the last clear().

    def stop(self, state: cabc.Mapping) -> bool:
        # Count this call first: the first `n` calls return False, call `n+1` onwards return True.
        self._i += 1
        return self._i > self._n

    def info(self) -> cabc.Mapping[str, float]:
        return dict(N_iter=self._i)

    def clear(self):
        self._i = 0


class ManualStop(pxa.StoppingCriterion):
    """
    Criterion which never fires: the solver keeps iterating until told otherwise.

    Pair it with :py:meth:`~pyxu.abc.Solver.fit` in mode=MANUAL/ASYNC to leave the stopping decision to an explicit
    action from the user:

    * mode=MANUAL: stop calling ``next(solver.steps())``;
    * mode=ASYNC: call :py:meth:`~pyxu.abc.Solver.stop`.
    """

    def stop(self, state: cabc.Mapping) -> bool:
        # `state` is deliberately ignored: termination is never requested from here.
        keep_going = True
        return not keep_going

    def info(self) -> cabc.Mapping[str, float]:
        # Stateless criterion: nothing to report.
        return {}

    def clear(self):
        # Stateless criterion: nothing to reset.
        return None


class MaxDuration(pxa.StoppingCriterion):
    """
    Stop iterative solver after a specified duration has elapsed.

    Elapsed time is measured from construction (or the last :py:meth:`clear`) to the most recent :py:meth:`stop` call.
    """

    def __init__(self, t: dt.timedelta):
        """
        Parameters
        ----------
        t: ~datetime.timedelta
            Max runtime allowed.

        Raises
        ------
        ValueError
            If `t` is not a positive duration.
        """
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        try:
            valid = bool(t > dt.timedelta())
        except (TypeError, ValueError):
            valid = False
        if not valid:
            raise ValueError(f"t: expected positive duration, got {t}.")
        self._t_max = t

        # Timestamps (in seconds) come from a monotonic clock: unlike naive `datetime.now()`, it cannot jump when
        # the wall clock is adjusted (DST change, NTP sync), which would corrupt the elapsed-time measurement.
        self._t_start = time.monotonic()
        self._t_now = self._t_start

    def stop(self, state: cabc.Mapping) -> bool:
        self._t_now = time.monotonic()
        return self._elapsed() > self._t_max

    def info(self) -> cabc.Mapping[str, float]:
        return dict(duration=self._elapsed().total_seconds())

    def clear(self):
        self._t_start = time.monotonic()
        self._t_now = self._t_start

    def _elapsed(self) -> dt.timedelta:
        # Time between the reference start point and the latest stop() call.
        return dt.timedelta(seconds=self._t_now - self._t_start)


class Memorize(pxa.StoppingCriterion):
    """
    Memorize a variable. (Special :py:class:`~pyxu.abc.StoppingCriterion` mostly useful for tracking objective
    functions in :py:class:`~pyxu.abc.Solver`.)

    This criterion never requests termination: :py:meth:`stop` only records the latest value of the monitored
    variable so that :py:meth:`info` can report it.
    """

    def __init__(self, var: pxt.VarName):
        """
        Parameters
        ----------
        var: VarName
            Variable in :py:attr:`pyxu.abc.Solver._mstate` to query. Must be a scalar or NDArray (1D).
        """
        self._var = var
        self._val = np.r_[0]  # value recorded by the latest stop() call (placeholder until the first call).

    def stop(self, state: cabc.Mapping) -> bool:
        value = state[self._var]
        if isinstance(value, pxt.Real):
            value = np.r_[value]  # promote scalars to 1-element 1D arrays.
        assert value.ndim == 1

        self._val = pxu.compute(value)
        return False  # memorization only.

    def info(self) -> cabc.Mapping[str, float]:
        val = self._val
        if val.size != 1:
            return {
                f"Memorize[{self._var}]_min": float(val.min()),
                f"Memorize[{self._var}]_max": float(val.max()),
            }
        # A single element: max() simply extracts it.
        return {f"Memorize[{self._var}]": float(val.max())}

    def clear(self):
        self._val = np.r_[0]


class AbsError(pxa.StoppingCriterion):
    """
    Stop iterative solver after absolute norm of a variable (or function thereof) reaches threshold.
    """

    def __init__(
        self,
        eps: pxt.Real,
        var: pxt.VarName = "x",
        rank: pxt.Integer = 1,
        f: SVFunction = None,
        norm: pxt.Real = 2,
        satisfy_all: bool = True,
    ):
        """
        Parameters
        ----------
        eps: Real
            Positive threshold.
        var: VarName
            Variable in :py:attr:`pyxu.abc.Solver._mstate` to query.
            Must hold an NDArray.
        rank: Integer
            Array rank K of monitored variable **after** applying `f`. (See below.)
        f: ~collections.abc.Callable
            Optional function to pre-apply to ``_mstate[var]`` before applying the norm. Defaults to the identity
            function. The callable should have the same semantics as :py:meth:`~pyxu.abc.Map.apply`:

              (..., M1,...,MD) -> (..., N1,...,NK)
        norm: Integer, Real
            Ln norm to use >= 0. (Default: L2.)
        satisfy_all: bool
            If True (default) and ``_mstate[var]`` is multi-dimensional, stop if all evaluation points lie below
            threshold. If False, stop as soon as any evaluation point does.

        Raises
        ------
        ValueError
            If `eps` is not positive or `norm` is negative.
        """
        # Explicit checks instead of `assert` (stripped under `python -O`). `not (eps > 0)` also rejects NaN.
        try:
            eps_ok = bool(eps > 0)
        except (TypeError, ValueError):
            eps_ok = False
        if not eps_ok:
            raise ValueError(f"eps: expected positive threshold, got {eps}.")
        self._eps = eps

        self._var = var
        self._rank = int(rank)
        self._f = f if (f is not None) else (lambda _: _)

        try:
            norm_ok = bool(norm >= 0)
        except (TypeError, ValueError):
            norm_ok = False
        if not norm_ok:
            raise ValueError(f"norm: expected non-negative, got {norm}.")
        self._norm = norm

        self._satisfy_all = satisfy_all
        self._val = np.r_[0]  # last computed Ln norm(s) in stop().

    def stop(self, state: cabc.Mapping) -> bool:
        fx = self._f(state[self._var])  # (..., N1,...,NK)
        self._val = _norm(fx, ord=self._norm, rank=self._rank)  # (..., 1)

        xp = pxu.get_array_module(fx)
        rule = xp.all if self._satisfy_all else xp.any
        decision = rule(self._val <= self._eps)  # reduced over all axes -> scalar

        self._val, decision = pxu.compute(self._val, decision)
        return bool(decision)  # plain Python bool, as annotated (not a 0-d array / numpy bool).

    def info(self) -> cabc.Mapping[str, float]:
        if self._val.size == 1:
            data = {f"AbsError[{self._var}]": float(self._val.max())}  # takes the only element available.
        else:
            data = {
                f"AbsError[{self._var}]_min": float(self._val.min()),
                f"AbsError[{self._var}]_max": float(self._val.max()),
            }
        return data

    def clear(self):
        self._val = np.r_[0]


class RelError(pxa.StoppingCriterion):
    """
    Stop iterative solver after relative norm change of a variable (or function thereof) reaches threshold.

    The monitored quantity is ``||f(x) - f(x_prev)|| / ||f(x_prev)||``, where ``x_prev`` is the value of
    ``_mstate[var]`` seen at the previous :py:meth:`stop` call. The first call only buffers ``x`` and never stops.
    """

    def __init__(
        self,
        eps: pxt.Real,
        var: pxt.VarName = "x",
        rank: pxt.Integer = 1,
        f: SVFunction = None,
        norm: pxt.Real = 2,
        satisfy_all: bool = True,
    ):
        """
        Parameters
        ----------
        eps: Real
            Positive threshold.
        var: VarName
            Variable in :py:attr:`pyxu.abc.Solver._mstate` to query.
            Must hold an NDArray.
        rank: Integer
            Array rank K of monitored variable **after** applying `f`. (See below.)
        f: ~collections.abc.Callable
            Optional function to pre-apply to ``_mstate[var]`` before applying the norm. Defaults to the identity
            function. The callable should have the same semantics as :py:meth:`~pyxu.abc.Map.apply`:

              (..., M1,...,MD) -> (..., N1,...,NK)
        norm: Integer, Real
            Ln norm to use >= 0. (Default: L2.)
        satisfy_all: bool
            If True (default) and ``_mstate[var]`` is multi-dimensional, stop if all evaluation points lie below
            threshold. If False, stop as soon as any evaluation point does.

        Raises
        ------
        ValueError
            If `eps` is not positive or `norm` is negative.
        """
        # Explicit checks instead of `assert` (stripped under `python -O`). `not (eps > 0)` also rejects NaN.
        try:
            eps_ok = bool(eps > 0)
        except (TypeError, ValueError):
            eps_ok = False
        if not eps_ok:
            raise ValueError(f"eps: expected positive threshold, got {eps}.")
        self._eps = eps

        self._var = var
        self._rank = int(rank)
        self._f = f if (f is not None) else (lambda _: _)

        try:
            norm_ok = bool(norm >= 0)
        except (TypeError, ValueError):
            norm_ok = False
        if not norm_ok:
            raise ValueError(f"norm: expected non-negative, got {norm}.")
        self._norm = norm

        self._satisfy_all = satisfy_all
        self._val = np.r_[0]  # last computed Ln rel-norm(s) in stop().
        self._x_prev = None  # buffered var from last query.

    def stop(self, state: cabc.Mapping) -> bool:
        x = state[self._var]  # (..., M1,...,MD)

        if self._x_prev is None:
            self._x_prev = x.copy()
            fx_prev = self._f(self._x_prev)  # (..., N1,...,NK)

            # force 1st .info() call to have same format as further calls.
            sh = fx_prev.shape[: -self._rank]
            self._val = np.zeros(shape=(*sh, 1))
            return False  # decision deferred: insufficient history to evaluate rel-err.

        xp = pxu.get_array_module(x)
        rule = xp.all if self._satisfy_all else xp.any

        fx_prev = self._f(self._x_prev)  # (..., N1,...,NK)
        numerator = _norm(self._f(x) - fx_prev, ord=self._norm, rank=self._rank)  # (..., 1)
        denominator = _norm(fx_prev, ord=self._norm, rank=self._rank)  # (..., 1)
        decision = rule(numerator <= self._eps * denominator)  # reduced over all axes -> scalar

        with warnings.catch_warnings():
            # Relative improvement for info(); 0/0 yields NaN (with a RuntimeWarning), silenced here.
            warnings.simplefilter("ignore")
            ratio = numerator / denominator  # (..., 1)
        # NaN (0/0) means "no relative improvement" -> report 0. `xp.where` builds a new array instead of
        # assigning through a boolean mask in place, which lazy/immutable array backends may not support.
        self._val = xp.where(xp.isnan(ratio), 0, ratio)
        self._x_prev = x.copy()

        self._x_prev, self._val, decision = pxu.compute(self._x_prev, self._val, decision)
        return bool(decision)  # plain Python bool, as annotated (not a 0-d array / numpy bool).

    def info(self) -> cabc.Mapping[str, float]:
        if self._val.size == 1:
            data = {f"RelError[{self._var}]": float(self._val.max())}  # takes the only element available.
        else:
            data = {
                f"RelError[{self._var}]_min": float(self._val.min()),
                f"RelError[{self._var}]_max": float(self._val.max()),
            }
        return data

    def clear(self):
        self._val = np.r_[0]
        self._x_prev = None