Source code for pyxu.math.linalg
1import numpy as np
2
3import pyxu.abc as pxa
4import pyxu.info.deps as pxd
5import pyxu.info.ptype as pxt
6import pyxu.info.warning as pxw
7import pyxu.runtime as pxrt
8
# Names exported by `from pyxu.math.linalg import *`.
__all__ = ["hutchpp", "trace"]
13
14
[docs]
def trace(
    op: pxa.SquareOp,
    xp: pxt.ArrayModule = None,
    dtype: pxt.DType = None,
) -> pxt.Real:
    r"""
    Exact trace of a linear operator based on multiple evaluations of the forward operator.

    The operator is probed with every canonical basis vector of its domain. The flat index :math:`i` is mapped to an
    N-D position in the domain (``op.dim_shape``) and co-domain (``op.codim_shape``) via
    :py:func:`numpy.unravel_index`, and the trace is accumulated as :math:`\sum_{i} (\text{op}(e_{i}))_{i}`.

    Parameters
    ----------
    op: ~pyxu.abc.operator.SquareOp
    xp: ArrayModule
        Array module used for internal computations. (Default: NumPy.)
    dtype: DType
        Precision to use for internal computations. (Default: double precision.)

    Returns
    -------
    tr: Real
        Exact value of tr(op).

    Notes
    -----
    This function calls ``op.apply()`` once per domain element, i.e. ``op.dim_size`` times. Use :py:func:`hutchpp` for
    a cheaper stochastic estimate on large operators.
    """
    if xp is None:
        xp = pxd.NDArrayInfo.default().module()

    if dtype is None:
        # Double precision is used regardless of the runtime precision setting.
        dtype = pxrt.Width.DOUBLE.value

    tr = 0
    for i in range(op.dim_size):
        # Same flat index `i`, expressed in the (possibly different) N-D layouts of the domain and co-domain.
        idx_in = np.unravel_index(i, op.dim_shape)
        idx_out = np.unravel_index(i, op.codim_shape)

        # i-th canonical basis vector of the domain.
        e = xp.zeros(op.dim_shape, dtype=dtype)
        e[idx_in] = 1
        tr += op.apply(e)[idx_out]
    return float(tr)
51
52
[docs]
def hutchpp(
    op: pxa.SquareOp,
    m: pxt.Integer = 4002,
    xp: pxt.ArrayModule = None,
    dtype: pxt.DType = None,
    seed: pxt.Integer = None,
) -> pxt.Real:
    r"""
    Stochastic trace estimation of a linear operator based on the Hutch++ algorithm. (Specifically `algorithm 3 from
    this paper <https://arxiv.org/abs/2010.09649>`_.)

    Parameters
    ----------
    op: ~pyxu.abc.operator.SquareOp
    m: Integer
        Number of queries used to estimate the trace of the linear operator. Must be at least 4.

        `m` is set to 4002 by default based on the analysis of the variance described in theorem 10. This default
        corresponds to having an estimation error smaller than 0.01 with probability 0.9.
    xp: ArrayModule
        Array module used for internal computations. (Default: NumPy.)
    dtype: DType
        Precision to use for internal computations. (Default: double precision.)
    seed: Integer
        Seed for the random number generator.

    Returns
    -------
    tr: Real
        Stochastic estimate of tr(op).

    Raises
    ------
    ValueError
        If `m` < 4, i.e. too few queries to build both the low-rank sketch and the residual correction.

    Notes
    -----
    With :math:`A` the operator viewed as an :math:`N \times N` matrix, the estimate is

    .. math::

       \text{tr}(Q^{T} A Q) + \frac{1}{n_{g}} \text{tr}\big(G^{T} (I - Q Q^{T}) A (I - Q Q^{T}) G\big),

    where :math:`Q` is an orthonormal basis of :math:`\text{span}(A S)`. Here :math:`S` has :math:`n_{s} = \lfloor
    (m+2)/4 \rfloor` standard-normal columns, and :math:`G` has :math:`n_{g} = \lfloor (m-2)/2 \rfloor` random
    :math:`\pm 1` columns. The query count is :math:`2 n_{s} + n_{g} \le m`.
    """
    from pyxu.operator import ReshapeAxes

    if m < 4:
        raise ValueError(f"Parameter[m] must be >= 4, got {m}.")
    n_s = (m + 2) // 4  # columns of the range sketch S
    n_g = (m - 2) // 2  # Rademacher probes used for the residual correction

    if xp is None:
        xp = pxd.NDArrayInfo.default().module()
    if using_dask := (xp == pxd.NDArrayInfo.DASK.module()):
        msg = "\n".join(
            [
                "DASK.linalg.qr() has limitations.",
                "[More info] https://docs.dask.org/en/stable/_modules/dask/array/linalg.html#qr",
            ]
        )
        pxw.warn_dask_perf(msg)

    if dtype is None:
        # Double precision is used regardless of the runtime precision setting.
        dtype = pxrt.Width.DOUBLE.value

    # To avoid constant reshaping below, we use the 2D-equivalent operator.
    lhs = ReshapeAxes(dim_shape=op.codim_shape, codim_shape=op.codim_size)
    rhs = ReshapeAxes(dim_shape=op.dim_size, codim_shape=op.dim_shape)
    op = lhs * op * rhs

    rng = xp.random.default_rng(seed=seed)
    s = rng.standard_normal(size=(op.dim_size, n_s), dtype=dtype)
    # +/-1 probes. Cast explicitly: an integer `g` would promote `q @ (q.T @ g)` (and everything downstream) to
    # float64 regardless of `dtype`.
    g = (rng.integers(0, 2, size=(op.dim_size, n_g)) * 2 - 1).astype(dtype)

    data = op.apply(s.T).T  # (dim, n_s)

    kwargs = dict(mode="reduced")
    if using_dask:
        # DASK path: collapse the column axis into a single chunk and drop `mode` before calling qr().
        data = data.rechunk({0: "auto", 1: -1})
        kwargs.pop("mode", None)

    q, _ = xp.linalg.qr(data, **kwargs)  # orthonormal basis of span(op(S))
    proj = g - q @ (q.T @ g)  # (I - Q Q^T) G

    # Trace of op restricted to span(Q).
    tr = (op.apply(q.T) @ q).trace()
    # Hutchinson estimate of the residual trace, normalized by the number of probes actually drawn.
    # (1 / n_g) == 2 / (m - 2) exactly when `m` is even; odd `m` previously used a mismatched scale.
    tr += (1 / n_g) * (op.apply(proj.T) @ proj).trace()
    return float(tr)