Source code for pyxu.math.linalg

  1import numpy as np
  3import as pxa
  4import as pxd
  5import as pxt
  6import as pxw
  7import pyxu.runtime as pxrt
  9__all__ = [
 10    "hutchpp",
 11    "trace",
def trace(
    op: pxa.SquareOp,
    xp: pxt.ArrayModule = None,
    dtype: pxt.DType = None,
) -> pxt.Real:
    r"""
    Exact trace of a linear operator based on multiple evaluations of the forward operator.

    The operator is applied once per canonical basis vector ``e_i`` of its domain (``op.dim_size`` evaluations in
    total), and the ``i``-th (flat-index) entry of each output is accumulated.

    Parameters
    ----------
    op: SquareOp
        Square linear operator whose trace is computed.
    xp: ArrayModule
        Array module used for internal computations. (Default: NumPy.)
    dtype: DType
        Precision to use for internal computations. (Default: double precision.)

    Returns
    -------
    tr: Real
        Exact value of tr(op).
    """
    if xp is None:
        xp = pxd.NDArrayInfo.default().module()

    if dtype is None:
        dtype = pxrt.Width.DOUBLE.value

    tr = 0
    for i in range(op.dim_size):
        # Domain and co-domain may have different N-D shapes (their flat sizes are assumed equal for a
        # SquareOp): locate flat index `i` in each of them.
        idx_in = np.unravel_index(i, op.dim_shape)
        idx_out = np.unravel_index(i, op.codim_shape)

        # Canonical basis vector e_i, allocated fresh on every iteration.
        e = xp.zeros(op.dim_shape, dtype=dtype)
        e[idx_in] = 1
        tr += op.apply(e)[idx_out]  # i-th entry of op(e_i): the i-th diagonal element
    return float(tr)
def hutchpp(
    op: pxa.SquareOp,
    m: pxt.Integer = 4002,
    xp: pxt.ArrayModule = None,
    dtype: pxt.DType = None,
    seed: pxt.Integer = None,
) -> pxt.Real:
    r"""
    Stochastic trace estimation of a linear operator based on the Hutch++ algorithm. (Specifically `algorithm 3 from
    this paper <https://arxiv.org/abs/2010.09649>`_.)

    Parameters
    ----------
    op: SquareOp
        Square linear operator whose trace is estimated.
    m: Integer
        Number of queries used to estimate the trace of the linear operator. Must satisfy ``m >= 2``.

        `m` is set to 4002 by default based on the analysis of the variance described in theorem 10. This default
        corresponds to having an estimation error smaller than 0.01 with probability 0.9.

        The budget is split into ``(m+2)//4`` Gaussian sketching vectors (each costing two evaluations of `op`) and
        ``(m-2)//2`` Rademacher probe vectors (one evaluation each). When ``m < 4`` there are no probe vectors and
        only the low-rank term ``tr(Q^T op Q)`` is returned.
    xp: ArrayModule
        Array module used for internal computations. (Default: NumPy.)
    dtype: DType
        Precision to use for internal computations. (Default: double precision.)
    seed: Integer
        Seed for the random number generator.

    Returns
    -------
    tr: Real
        Stochastic estimate of tr(op).

    Raises
    ------
    ValueError
        If ``m < 2``.
    """
    if m < 2:
        raise ValueError(f"m: expected at least 2 queries, got {m}.")
    # Query budget: 2 * n_sketch (A @ S, then A @ Q) + n_probe (A @ proj) <= m.
    n_sketch = (m + 2) // 4
    n_probe = (m - 2) // 2

    from pyxu.operator import ReshapeAxes

    if xp is None:
        xp = pxd.NDArrayInfo.default().module()
    if using_dask := (xp == pxd.NDArrayInfo.DASK.module()):
        msg = "\n".join(
            [
                "DASK.linalg.qr() has limitations.",
                "[More info]",
            ]
        )
        pxw.warn_dask_perf(msg)

    if dtype is None:
        dtype = pxrt.Width.DOUBLE.value

    # To avoid constant reshaping below, we use the 2D-equivalent operator.
    lhs = ReshapeAxes(dim_shape=op.codim_shape, codim_shape=op.codim_size)
    rhs = ReshapeAxes(dim_shape=op.dim_size, codim_shape=op.dim_shape)
    op = lhs * op * rhs

    rng = xp.random.default_rng(seed=seed)
    s = rng.standard_normal(size=(op.dim_size, n_sketch), dtype=dtype)
    # Rademacher (+/-1) probes. `integers()` yields an integer array: cast it so that `proj` (and therefore the
    # inputs fed to `op`) stay in the requested precision instead of being promoted by int/float arithmetic.
    g = (rng.integers(0, 2, size=(op.dim_size, n_probe)) * 2 - 1).astype(dtype)

    data = op.apply(s.T).T  # (dim, n_sketch) = A @ S

    kwargs = dict(mode="reduced")
    if using_dask:
        # dask path: single chunk along the column axis, and `qr()` is called without `mode`.
        data = data.rechunk({0: "auto", 1: -1})
        kwargs.pop("mode", None)

    q, _ = xp.linalg.qr(data, **kwargs)  # orthonormal basis of range(A @ S)

    # Low-rank part: tr(Q^T A Q). Rows of op.apply(q.T) are (A q_i)^T.
    tr = (op.apply(q.T) @ q).trace()

    if n_probe > 0:
        # Residual part: Hutchinson estimate on the orthogonal complement of range(Q), normalised by the *actual*
        # number of probe vectors (2/(m-2) is only equal to 1/n_probe for even m).
        proj = g - q @ (q.T @ g)
        tr += (1 / n_probe) * (op.apply(proj.T) @ proj).trace()
    return float(tr)