Source code for pyxu.math.linesearch
1import pyxu.abc as pxa
2import pyxu.info.deps as pxd
3import pyxu.info.ptype as pxt
4
__all__ = ["backtracking_linesearch"]

# Default step-reduction factor `r` used by backtracking_linesearch().
LINESEARCH_DEFAULT_R = 0.5

# Default sufficient-decrease constant `c` used by backtracking_linesearch().
LINESEARCH_DEFAULT_C = 1e-4

# Stand-in for `numpy.newaxis` (which is `None`): lets this module insert singleton axes
# via indexing without importing NumPy.
newaxis = None
13
14
[docs]
def backtracking_linesearch(
    f: pxa.DiffFunc,
    x: pxt.NDArray,
    direction: pxt.NDArray,
    gradient: pxt.NDArray = None,
    a0: pxt.Real = None,
    r: pxt.Real = LINESEARCH_DEFAULT_R,
    c: pxt.Real = LINESEARCH_DEFAULT_C,
) -> pxt.NDArray:
    r"""
    Backtracking line search algorithm based on the `Armijo-Goldstein condition
    <https://www.wikiwand.com/en/Backtracking_line_search>`_.

    Starting from the step size `a0`, each step size :math:`a` is multiplied by `r` until the sufficient-decrease
    condition

    .. math::

       f(x + a \, d) \le f(x) + c \, a \, \langle \nabla f(x), d \rangle

    holds, where :math:`d` is the search direction. Every point of a stacked input is refined independently.

    Parameters
    ----------
    f: ~pyxu.abc.operator.DiffFunc
        Differentiable functional.
    x: NDArray
        (..., M1,...,MD) initial search point(s).
    direction: NDArray
        (..., M1,...,MD) search direction(s) corresponding to initial point(s).

        Should be a descent direction, i.e. :math:`\langle \nabla f(x), d \rangle < 0`. Otherwise the condition above
        is generally only satisfied once the step is too small to change `x` numerically.
    gradient: NDArray
        (..., M1,...,MD) gradient of `f` at initial search point(s).

        Specifying `gradient` when known is an optimization: it will be autocomputed via
        :py:meth:`~pyxu.abc.DiffFunc.grad` if unspecified.
    a0: Real
        Initial step size (must be positive).

        If unspecified and :math:`\nabla f` is :math:`\beta`-Lipschitz continuous, then `a0` is auto-chosen as
        :math:`\frac{1}{\beta}`.
    r: Real
        Step reduction factor, in :math:`(0, 1)`.
    c: Real
        Sufficient-decrease (Armijo) constant, in :math:`(0, 1)`.

    Returns
    -------
    a: NDArray
        (..., 1) step sizes.

    Notes
    -----
    * Performing a line-search with DASK inputs is inefficient due to iterative nature of algorithm.
    """
    import math

    ndi = pxd.NDArrayInfo.from_obj(x)
    xp = ndi.module()

    assert 0 < r < 1
    assert 0 < c < 1
    if a0 is None:
        a0 = 1.0 / f.diff_lipschitz
        assert a0 > 0, "a0: cannot auto-set step size."
    else:
        assert a0 > 0

    if gradient is None:
        gradient = f.grad(x)

    # Loop invariants:
    # * core_axes: the trailing (M1,...,MD) axes, reduced in the inner product.
    # * expand: index turning a (..., 1) array into (..., 1,...,1) so step sizes broadcast against `direction`.
    core_axes = tuple(range(-f.dim_rank, 0))
    expand = (Ellipsis, *((newaxis,) * (f.dim_rank - 1)))

    f_x = f.apply(x)  # (..., 1)
    d_f = c * xp.sum(gradient * direction, axis=core_axes)[..., newaxis]  # c * <grad, direction>: (..., 1)

    def refine(a: pxt.NDArray) -> pxt.NDArray:
        # Flag step sizes violating the sufficient-decrease condition.
        #
        # Parameters
        # ----------
        # a : NDArray
        #     (..., 1) current step size(s).
        #
        # Returns
        # -------
        # mask : NDArray[bool]
        #     (..., 1) entries whose step size must still be reduced.
        lhs = f.apply(x + a[expand] * direction)  # (..., 1)
        rhs = f_x + a * d_f  # (..., 1)
        return lhs > rhs

    a = xp.full_like(d_f, fill_value=a0, dtype=x.dtype)
    while xp.any(mask := refine(a)):
        # Out-of-place masked update instead of `a[mask] *= r`. Boolean-mask compound assignment is NumPy-specific:
        # Dask would build it from a 1-D, unknown-length `a[mask] * r`, which mis-broadcasts for stacked inputs.
        # `where` behaves identically on NumPy/CuPy/Dask. `astype` keeps the dtype of `a` (as in-place ops do) even
        # if `r` is a higher-precision NumPy scalar.
        a = xp.where(mask, a * r, a).astype(x.dtype, copy=False)

    if ndi == pxd.NDArrayInfo.DASK and any(map(math.isnan, a.shape)):
        # Only needed if `f` produced arrays with unknown chunk sizes; avoids a redundant full computation otherwise.
        a.compute_chunk_sizes()
    return a
106 return a