Source code for pyxu.math.linesearch

  1import pyxu.abc as pxa
  2import pyxu.info.deps as pxd
  3import pyxu.info.ptype as pxt
  4
  5__all__ = [
  6    "backtracking_linesearch",
  7]
  8
  9
 10LINESEARCH_DEFAULT_R = 0.5
 11LINESEARCH_DEFAULT_C = 1e-4
 12newaxis = None  # Since we don't import NumPy
 13
 14
def backtracking_linesearch(
    f: pxa.DiffFunc,
    x: pxt.NDArray,
    direction: pxt.NDArray,
    gradient: pxt.NDArray = None,
    a0: pxt.Real = None,
    r: pxt.Real = LINESEARCH_DEFAULT_R,
    c: pxt.Real = LINESEARCH_DEFAULT_C,
) -> pxt.NDArray:
    r"""
    Backtracking line search algorithm based on the `Armijo-Goldstein condition
    <https://www.wikiwand.com/en/Backtracking_line_search>`_.

    Every step size starts at `a0` and is multiplied by `r` for as long as the sufficient-decrease condition

    .. math::

       f(\mathbf{x} + a \mathbf{d}) \le f(\mathbf{x}) + c \, a \, \langle \nabla f(\mathbf{x}), \mathbf{d} \rangle

    fails at that search point. Iteration stops once the condition holds at all search points.

    Parameters
    ----------
    f: ~pyxu.abc.operator.DiffFunc
        Differentiable functional.
    x: NDArray
        (..., M1,...,MD) initial search point(s).
    direction: NDArray
        (..., M1,...,MD) search direction(s) corresponding to initial point(s).
    gradient: NDArray
        (..., M1,...,MD) gradient of `f` at initial search point(s).

        Specifying `gradient` when known is an optimization: it will be autocomputed via
        :py:meth:`~pyxu.abc.DiffFunc.grad` if unspecified.
    a0: Real
        Initial step size.

        If unspecified and :math:`\nabla f` is :math:`\beta`-Lipschitz continuous, then `a0` is auto-chosen as
        :math:`\frac{1}{\beta}`. This requires :math:`0 < \beta < \infty`.
    r: Real
        Step reduction factor, in :math:`(0, 1)`.
    c: Real
        Bound reduction factor, in :math:`(0, 1)`.

    Returns
    -------
    a: NDArray
        (..., 1) step sizes.

    Raises
    ------
    AssertionError
        If `r` or `c` lie outside :math:`(0, 1)`, if a user-provided `a0` is not positive, or if `a0` is
        unspecified and ``f.diff_lipschitz`` is not finite and strictly positive.
        (Validation relies on ``assert`` statements, which are skipped under ``python -O``.)

    Notes
    -----
    * Performing a line-search with DASK inputs is inefficient due to iterative nature of algorithm.
    * `direction` is expected to be a descent direction, i.e.
      :math:`\langle \nabla f(\mathbf{x}), \mathbf{d} \rangle < 0`. Otherwise the condition may only be met
      once `a` has been shrunk (close) to zero, which can take many iterations.
    """
    ndi = pxd.NDArrayInfo.from_obj(x)
    xp = ndi.module()  # array namespace matching the backend of `x`

    assert 0 < r < 1
    assert 0 < c < 1
    if a0 is None:
        beta = f.diff_lipschitz
        # 1/beta is only a usable step when beta is finite and strictly positive:
        # beta == inf would give a0 == 0, and beta == 0 would divide by zero
        # (ZeroDivisionError for Python floats, an infinite step for NumPy scalars).
        assert 0 < beta < float("inf"), "a0: cannot auto-set step size."
        a0 = 1.0 / beta
    else:
        assert a0 > 0

    if gradient is None:
        gradient = f.grad(x)

    f_x = f.apply(x)  # (..., 1)
    d_f = (  # c * <grad f(x), direction>, shape (..., 1)
        c
        * xp.sum(
            gradient * direction,
            axis=tuple(range(-f.dim_rank, 0)),  # reduce over the core (M1,...,MD) axes
        )[..., newaxis]
    )

    # Index appending (dim_rank - 1) singleton axes: maps (..., 1) step sizes to (..., 1,...,1) so they
    # broadcast against (..., M1,...,MD) directions. Loop-invariant, so it is built once here.
    expand = (Ellipsis,) + (newaxis,) * (f.dim_rank - 1)

    def refine(a: pxt.NDArray) -> pxt.NDArray:
        # Evaluate the Armijo condition once.
        #
        # Parameters
        # ----------
        # a : NDArray
        #     (..., 1) current step size(s).
        #
        # Returns
        # -------
        # mask : NDArray[bool]
        #     (..., 1) True where the condition fails, i.e. where the step must be reduced.
        a_D = a[expand]  # (..., 1,...,1)
        lhs = f.apply(x + a_D * direction)  # (..., 1)
        rhs = f_x + a * d_f  # (..., 1)
        return lhs > rhs

    a = xp.full_like(d_f, fill_value=a0, dtype=x.dtype)
    while xp.any(mask := refine(a)):
        a[mask] *= r  # shrink only the steps that still violate the condition

    if ndi == pxd.NDArrayInfo.DASK:
        # presumably needed because boolean-mask updates leave Dask chunk sizes unknown — TODO confirm
        a.compute_chunk_sizes()
    return a