pyxu.math

backtracking_linesearch(f, x, direction, gradient=None, a0=None, r=0.5, c=0.0001)

Backtracking line search algorithm based on the Armijo-Goldstein condition.

Parameters:
  • f (DiffFunc) – Differentiable functional.

  • x (NDArray) – (…, M1,…,MD) initial search point(s).

  • direction (NDArray) – (…, M1,…,MD) search direction(s) corresponding to initial point(s).

  • gradient (NDArray) –

    (…, M1,…,MD) gradient of f at initial search point(s).

    Supplying gradient when it is already known is an optimization; if unspecified, it is computed via grad().

  • a0 (Real) –

    Initial step size.

    If unspecified and \(\nabla f\) is \(\beta\)-Lipschitz continuous, then a0 is auto-chosen as \(\frac{1}{\beta}\).

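  • r (Real) – Step reduction factor: the trial step size is multiplied by r each time the sufficient-decrease test fails.

  • c (Real) – Bound reduction factor: the constant scaling the linear bound in the Armijo-Goldstein sufficient-decrease condition.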
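
The a0 default can be illustrated on a quadratic. The snippet is a standalone NumPy sketch, not the pyxu API: for \(f(x) = \frac{1}{2} x^T Q x - b^T x\) with \(Q\) symmetric positive definite, \(\nabla f\) is \(\beta\)-Lipschitz with \(\beta = \lambda_{\max}(Q)\). Since \(f(x - \nabla f(x)/\beta) \le f(x) - \|\nabla f(x)\|^2 / (2\beta)\), the step \(\frac{1}{\beta}\) along the steepest-descent direction passes the Armijo-Goldstein test at the first trial whenever c ≤ 1/2.

```python
import numpy as np

rng = np.random.default_rng(0)
B = rng.standard_normal((5, 5))
Q = B @ B.T + np.eye(5)  # symmetric positive definite
b = rng.standard_normal(5)

def f(x):
    return 0.5 * x @ Q @ x - b @ x

def grad(x):
    return Q @ x - b

# grad is beta-Lipschitz with beta = largest eigenvalue of Q.
beta = np.linalg.eigvalsh(Q).max()
a0 = 1.0 / beta

x = rng.standard_normal(5)
g = grad(x)
d = -g        # steepest-descent direction
c = 1e-4      # same value as the documented default

# Armijo-Goldstein sufficient-decrease test at the initial step a0.
accepted = f(x + a0 * d) <= f(x) + c * a0 * (g @ d)
print(accepted)  # True
```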

Returns:

a – (…, 1) step sizes.

Return type:

NDArray

Notes

  • Performing a line search with Dask inputs is inefficient due to the iterative nature of the algorithm.
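
The batched backtracking loop can be sketched in plain NumPy as follows. This is an illustration under stated assumptions, not pyxu's implementation: the hypothetical backtracking_sketch takes plain callables f (returning shape (…, 1)) and grad_f, assumes a single core dimension (D = 1), uses a0=1.0 as a placeholder default, and adds a hypothetical max_iter safeguard. Each pass needs the concrete outcome of np.all(ok) before continuing, which is why lazily evaluated arrays gain little here.

```python
import numpy as np

def backtracking_sketch(f, grad_f, x, direction, a0=1.0, r=0.5, c=1e-4, max_iter=100):
    """Armijo backtracking for a batch of points of shape (..., M).

    Returns step sizes of shape (..., 1), one per search point.
    """
    fx = f(x)                                               # (..., 1)
    slope = np.sum(grad_f(x) * direction, axis=-1, keepdims=True)  # directional derivative
    a = np.full(fx.shape, a0, dtype=float)                  # one trial step per point
    for _ in range(max_iter):
        # Sufficient-decrease test, evaluated for every point in the batch.
        ok = f(x + a * direction) <= fx + c * a * slope
        if np.all(ok):
            break
        a = np.where(ok, a, r * a)  # shrink only the steps that failed
    return a  # if max_iter is hit, the last (unverified) steps are returned

def f(x):
    return 0.5 * np.sum(x**2, axis=-1, keepdims=True)

def grad_f(x):
    return x

x = np.array([[1.0, 2.0], [3.0, -1.0]])
a = backtracking_sketch(f, grad_f, x, -grad_f(x), a0=4.0)
print(a.shape)  # (2, 1)
```

For this quadratic the test reduces to \((1 - a)^2 \le 1 - 2ca\), so the trials 4 and 2 are rejected and both points settle on a = 1.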

hutchpp(op, m=4002, xp=None, dtype=None, seed=None)

Stochastic trace estimation of a linear operator based on the Hutch++ algorithm. (Specifically, Algorithm 3 from this paper.)

Parameters:
  • op (SquareOp)

  • m (Integer) –

    Number of queries used to estimate the trace of the linear operator.

    m is set to 4002 by default based on the variance analysis in Theorem 10. This default corresponds to an estimation error smaller than 0.01 with probability 0.9.

  • xp (ArrayModule) – Array module used for internal computations. (Default: NumPy.)

  • dtype (DType) – Precision to use for internal computations. (Default: current runtime precision.)

  • seed (Integer) – Seed for the random number generator.

Returns:

tr – Stochastic estimate of tr(op).

Return type:

Real
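
The idea behind Hutch++ can be sketched on a dense NumPy matrix. The hypothetical hutchpp_sketch below is the basic variant that splits the m queries evenly into three blocks (sketch, exact low-rank part, Hutchinson on the remainder); it is not pyxu's code and may split m differently from Algorithm 3 used by hutchpp. Only products of A with matrices are needed, so A could be any operator exposing such products.

```python
import numpy as np

def hutchpp_sketch(A, m, seed=None):
    """Basic Hutch++ estimate of tr(A) from about m matrix-vector products."""
    rng = np.random.default_rng(seed)
    n = A.shape[0]
    k = m // 3                                  # queries per block; leftovers are unused
    S = rng.choice([-1.0, 1.0], size=(n, k))    # Rademacher sketching matrix
    G = rng.choice([-1.0, 1.0], size=(n, k))    # Rademacher probes for the remainder
    Q, _ = np.linalg.qr(A @ S)                  # orthonormal basis of range(A S): k products
    G_perp = G - Q @ (Q.T @ G)                  # project probes off range(Q)
    low_rank = np.trace(Q.T @ (A @ Q))          # trace captured by range(Q): k products
    residual = np.trace(G_perp.T @ (A @ G_perp)) / k  # Hutchinson estimate: k products
    return low_rank + residual

rng = np.random.default_rng(0)
B = rng.standard_normal((200, 10))
A = B @ B.T                                     # PSD matrix of rank 10
est = hutchpp_sketch(A, m=60, seed=1)
print(est, np.trace(A))
```

Because rank(A) = 10 ≤ m/3 here, range(Q) contains range(A), the residual term vanishes, and the estimate matches tr(A) up to rounding; for full-rank operators the residual term carries the stochastic error.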

trace(op, xp=None, dtype=None)

Exact trace of a linear operator based on multiple evaluations of the forward operator.

Parameters:
  • op (SquareOp)

  • xp (ArrayModule) – Array module used for internal computations. (Default: NumPy.)

  • dtype (DType) – Precision to use for internal computations. (Default: current runtime precision.)

Returns:

tr – Exact value of tr(op).

Return type:

Real
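
One straightforward way to obtain an exact trace from forward evaluations is \(\mathrm{tr}(A) = \sum_i e_i^T A e_i\), probing with each standard basis vector \(e_i\). The hypothetical trace_sketch below assumes only a matvec callable and the dimension n; it is an illustration, not necessarily how trace is implemented. It costs n evaluations, which is what makes a stochastic estimator attractive for large operators.

```python
import numpy as np

def trace_sketch(matvec, n, dtype=np.float64):
    """Exact trace from n forward evaluations: tr(A) = sum_i e_i^T A e_i."""
    tr = 0.0
    e = np.zeros(n, dtype=dtype)
    for i in range(n):
        e[i] = 1.0
        tr += matvec(e)[i]   # i-th diagonal entry of A
        e[i] = 0.0
    return tr

A = np.arange(9.0).reshape(3, 3)
tr = trace_sketch(lambda v: A @ v, 3)
print(tr)  # 12.0
```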