Source code for pyxu.operator.linop.fft.czt

  1import collections.abc as cabc
  2import functools
  3
  4import numpy as np
  5import scipy.special as sps
  6
  7import pyxu.abc as pxa
  8import pyxu.info.deps as pxd
  9import pyxu.info.ptype as pxt
 10import pyxu.runtime as pxrt
 11import pyxu.util as pxu
 12
 13__all__ = [
 14    "CZT",
 15]
 16
 17
class CZT(pxa.LinOp):
    r"""
    Multi-dimensional Chirp Z-Transform (CZT) :math:`C: \mathbb{C}^{N_{1} \times\cdots\times N_{D}} \to
    \mathbb{C}^{M_{1} \times\cdots\times M_{D}}`.

    The 1D CZT of parameters :math:`(A, W, M)` is defined as:

    .. math::

       (C \, \mathbf{x})[k]
       =
       \sum_{n=0}^{N-1} \mathbf{x}[n] A^{-n} W^{nk},

    where :math:`\mathbf{x} \in \mathbb{C}^{N}`, :math:`A, W \in \mathbb{C}`, and :math:`k = \{0, \ldots, M-1\}`.

    A D-dimensional CZT corresponds to taking a 1D CZT along each transform axis.

    .. rubric:: Implementation Notes

    For stability reasons, this implementation assumes :math:`A, W \in \mathbb{C}` lie on the unit circle.

    See Also
    --------
    :py:class:`~pyxu.operator.FFT`
    """

    def __init__(
        self,
        dim_shape: pxt.NDArrayShape,
        axes: pxt.NDArrayAxis,
        M,
        A,
        W,
        **kwargs,
    ):
        r"""
        Parameters
        ----------
        dim_shape: NDArrayShape
            (N1,...,ND) dimensions of the input :math:`\mathbf{x} \in \mathbb{C}^{N_{1} \times\cdots\times N_{D}}`.
        axes: NDArrayAxis
            Axes over which to compute the CZT. If ``None``, all axes are used.
        M : int, list(int)
            Length of the transform per axis.
        A : complex, list(complex)
            Circular offset from the positive real-axis per axis.
        W : complex, list(complex)
            Circular spacing between transform points per axis.
        kwargs: dict
            Extra kwargs passed to :py:class:`~pyxu.operator.FFT`.

        Notes
        -----
        `M`, `A` and `W` are either scalars (broadcast to every transform axis) or sequences holding one entry per
        element of `axes`, in the same order as `axes`.
        """
        dim_shape = pxu.as_canonical_shape(dim_shape)
        if axes is None:
            axes = tuple(range(len(dim_shape)))
        self._axes = pxu.as_canonical_axes(axes, len(dim_shape))
        _M, self._A, self._W = self._canonical_repr(self._axes, M, A, W)

        # Only the transformed axes change length (N -> M).
        codim_shape = list(dim_shape)
        for i, ax in enumerate(self._axes):
            codim_shape[ax] = _M[i]

        # Complex arrays are exposed as real arrays with a trailing axis of size 2.
        super().__init__(
            dim_shape=(*dim_shape, 2),
            codim_shape=(*codim_shape, 2),
        )
        self._kwargs = kwargs

        # We know a crude Lipschitz bound by default. Since computing it takes (code) space,
        # the estimate is computed as a special case of estimate_lipschitz()
        self.lipschitz = self.estimate_lipschitz(__rule=True)

    def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
        r"""
        Parameters
        ----------
        arr: NDArray
            (..., N1,...,ND,2) inputs :math:`\mathbf{x} \in \mathbb{C}^{N_{1} \times\cdots\times N_{D}}` viewed as a
            real array. (See :py:func:`~pyxu.util.view_as_real`.)

        Returns
        -------
        out: NDArray
            (..., M1,...,MD,2) outputs :math:`\hat{\mathbf{x}} \in \mathbb{C}^{M_{1} \times\cdots\times M_{D}}` viewed
            as a real array. (See :py:func:`~pyxu.util.view_as_real`.)
        """
        x = pxu.view_as_complex(pxu.require_viewable(arr))
        y = self.capply(x)
        out = pxu.view_as_real(pxu.require_viewable(y))
        return out

    def capply(self, arr: pxt.NDArray) -> pxt.NDArray:
        r"""
        Parameters
        ----------
        arr: NDArray
            (..., N1,...,ND) inputs :math:`\mathbf{x} \in \mathbb{C}^{N_{1} \times\cdots\times N_{D}}`.

        Returns
        -------
        out: NDArray
            (..., M1,...,MD) outputs :math:`\hat{\mathbf{x}} \in \mathbb{C}^{M_{1} \times\cdots\times M_{D}}`.
        """
        AWk2, FWk2, Wk2, extract, fft = self._get_meta(arr)

        # (1) pre-modulate the input per transform axis
        arr = arr.copy()  # for in-place updates
        for _AWk2 in AWk2:
            arr *= _AWk2

        # (2) filter in the FFT domain
        y = fft.capply(arr)
        for _FWk2 in FWk2:
            y *= _FWk2

        # (3) back-transform, keep the M valid samples per axis, then post-modulate
        out = fft.cpinv(y, damp=0)[extract]
        for _Wk2 in Wk2:
            out *= _Wk2
        return out

    def adjoint(self, arr: pxt.NDArray) -> pxt.NDArray:
        r"""
        Parameters
        ----------
        arr: NDArray
            (..., M1,...,MD,2) inputs :math:`\hat{\mathbf{x}} \in \mathbb{C}^{M_{1} \times\cdots\times M_{D}}` viewed
            as a real array. (See :py:func:`~pyxu.util.view_as_real`.)

        Returns
        -------
        out: NDArray
            (..., N1,...,ND,2) outputs :math:`\mathbf{x} \in \mathbb{C}^{N_{1} \times\cdots\times N_{D}}` viewed as a
            real array. (See :py:func:`~pyxu.util.view_as_real`.)
        """
        x = pxu.view_as_complex(pxu.require_viewable(arr))
        y = self.cadjoint(x)
        out = pxu.view_as_real(pxu.require_viewable(y))
        return out

    def cadjoint(self, arr: pxt.NDArray) -> pxt.NDArray:
        r"""
        Parameters
        ----------
        arr: NDArray
            (..., M1,...,MD) inputs :math:`\hat{\mathbf{x}} \in \mathbb{C}^{M_{1} \times\cdots\times M_{D}}`.

        Returns
        -------
        out: NDArray
            (..., N1,...,ND) outputs :math:`\mathbf{x} \in \mathbb{C}^{N_{1} \times\cdots\times N_{D}}`.
        """
        # CZT^{*}(y,M,A,W)[n] = CZT(y,N,A=1,W=W*)[n] * A^{n}
        # (valid because |A| = 1, hence conj(A^{-n}) = A^{n}.)
        czt = CZT(
            dim_shape=self.codim_shape[:-1],
            axes=self._axes,
            M=[self.dim_shape[ax] for ax in self._axes],
            A=1,
            W=np.conj(self._W),
            **self._kwargs,
        )
        out = czt.capply(arr)

        # Re-scale outputs per axis.
        xp = pxu.get_array_module(out)
        cdtype = pxrt.CWidth(out.dtype).value
        for i, ax in enumerate(self._axes):
            A = self._A[i]
            N = self.dim_shape[ax]

            # Trailing singleton axes so that the length-N vector broadcasts along `ax`.
            expand = (np.newaxis,) * (self.dim_rank - 2 - ax)

            An = A ** xp.arange(N)[(..., *expand)]
            out *= An.astype(cdtype)
        return out

    def estimate_lipschitz(self, **kwargs) -> pxt.Real:
        """
        Return a Lipschitz constant of the operator.

        If the keyword ``__rule`` is present in `kwargs`, a closed-form (crude) upper bound is returned without
        evaluating the operator. Otherwise the computation is delegated to the parent class.
        """
        no_eval = "__rule" in kwargs
        if no_eval:
            # We know that
            #     L^{2} = \sigma_{max}^{2}(C) = \lambda_{max}(C.gram) = \lambda_{max}(C.cogram)
            # We know that C.[co]gram correspond to linear convolution of the input with a Dirichlet kernel, i.e.
            #     ( Gx)[n] = \sum_{q=0}^{N-1} x[q] h1[n-q],  h1[n] = A^{n} W^{-(M-1)/2 n} \sin[p M/2 n] / \sin[p 1/2 n]
            #     (CGy)[n] = \sum_{q=0}^{M-1} y[q] h2[n-q],  h2[n] =       W^{ (N-1)/2 n} \sin[p N/2 n] / \sin[p 1/2 n]
            #            p = \arg{W}
            # From Young's convolution inequality, we have the upper bound
            #     \norm{x \ast h}{2} \le \norm{x}{2}\norm{h}{1}
            # Therefore both \norm{h1}{1} and \norm{h2}{1} upper-bound L^{2}; we keep the larger one, i.e.
            #     L^{2} <= max(\norm{h1}{1}, \norm{h2}{1})
            # which is conservative but always valid.
            L2 = 1
            for i, ax in enumerate(self._axes):
                N = self.dim_shape[ax]
                M = self.codim_shape[ax]
                p = np.angle(self._W[i])

                # diric(x, n) = sin(n x / 2) / (n sin(x / 2)); the unit-modulus phase terms drop out of |h|.
                h1 = sps.diric(p * np.arange(-N, N + 1), n=M) * M
                norm1 = np.fabs(h1).sum()

                h2 = sps.diric(p * np.arange(-M, M + 1), n=N) * N
                norm2 = np.fabs(h2).sum()

                # Separable transform: per-axis bounds multiply.
                L2 *= max(norm1, norm2)
            L = np.sqrt(L2)
        else:
            L = super().estimate_lipschitz(**kwargs)
        return L

    def asarray(self, **kwargs) -> pxt.NDArray:
        """
        Dense array representation of the operator.

        Recognized keys in `kwargs`: ``xp`` (array module) and ``dtype`` (real precision); both have defaults.
        """
        # We compute 1D transforms per axis, then Kronecker product them.

        # Since non-NP backends may be faulty, we do everything in NUMPY ...
        A_1D = [None] * (D := self.dim_rank - 1)
        for ax in range(D):
            N = self.dim_shape[ax]
            if ax in self._axes:
                # (_A, _W) are stored in the order of `self._axes`, which need not be ascending:
                # look the position up instead of counting transformed axes seen so far.
                i = self._axes.index(ax)
                M = self.codim_shape[ax]

                n = np.arange(N)
                m = np.arange(M)
                _A, _W = self._A[i], self._W[i]
                A_1D[ax] = (_W ** np.outer(m, n)) * (_A ** (-n))  # (M, N)
            else:
                A_1D[ax] = np.eye(N)  # axis is left untouched

        # Outer product yields axes (M1,N1,...,MD,ND); regroup them as (M1,...,MD,N1,...,ND).
        A_ND = functools.reduce(np.multiply.outer, A_1D)
        B_ND = np.transpose(
            A_ND,
            axes=np.r_[
                np.arange(0, 2 * D, 2),
                np.arange(1, 2 * D, 2),
            ],
        )

        # ... then use the backend/precision user asked for.
        xp = kwargs.get("xp", pxd.NDArrayInfo.default().module())
        dtype = kwargs.get("dtype", pxrt.Width.DOUBLE.value)
        C = xp.array(
            pxu.as_real_op(B_ND, dim_rank=D),
            dtype=pxrt.Width(dtype).value,
        )
        return C

    # Helper routines (internal) ----------------------------------------------
    @staticmethod
    def _canonical_repr(axes, M, A, W):
        # Create canonical representations
        #   * `_M`: tuple(int)
        #   * `_A`: tuple(complex)
        #   * `_W`: tuple(complex)
        #
        # `axes` is already assumed in tuple-form.
        # Each output has one entry per element of `axes`, in the same order.
        def as_seq(x, N, _type):
            # Scalars and length-1 iterables are broadcast to length N.
            if isinstance(x, cabc.Iterable):
                _x = tuple(x)
            else:
                _x = (x,)
            if len(_x) == 1:
                _x *= N  # broadcast
            assert len(_x) == N

            return tuple(map(_type, _x))

        _M = as_seq(M, len(axes), int)
        _A = as_seq(A, len(axes), complex)
        _W = as_seq(W, len(axes), complex)
        assert all(m > 0 for m in _M)
        assert np.allclose(np.abs(_A), 1)  # A must lie on the unit circle
        assert np.allclose(np.abs(_W), 1)  # W must lie on the unit circle
        return _M, _A, _W

    def _get_meta(self, x: pxt.NDArray):
        # x: (..., N1,...,ND) [complex]
        #
        # Computes/Initializes (per axis):
        #   * `AWk2`: list[NDArray] pre-FFT modulation vectors.
        #   * `FWk2`: list[NDArray] FFT of convolution filters.
        #   * `Wk2`: list[NDArray] post-FFT modulation vectors.
        #   * `extract`: tuple[slice] FFT interval to extract.
        #   * `fft`: FFT object to transform the input.
        from pyxu.operator import FFT

        ndi = pxd.NDArrayInfo.from_obj(x)
        if ndi == pxd.NDArrayInfo.DASK:
            xp = pxu.get_array_module(x._meta)
        else:
            xp = ndi.module()
        xpf = FFT.fft_backend(xp)
        cdtype = pxrt.CWidth(x.dtype).value

        # Initialize FFT to transform inputs.
        # Each transform axis must hold a linear convolution of length N + M - 1.
        fft_shape = list(self.dim_shape[:-1])
        for ax in self._axes:
            fft_shape[ax] += self.codim_shape[ax] - 1
        fft_shape = FFT.next_fast_len(fft_shape)
        fft = FFT(
            dim_shape=fft_shape,
            axes=self._axes,
            **self._kwargs,
        )

        # Build modulation vectors (Wk2, AWk2, FWk2).
        Wk2, AWk2, FWk2 = [], [], []
        for i, ax in enumerate(self._axes):
            A = self._A[i]
            W = self._W[i]
            N = self.dim_shape[ax]
            M = self.codim_shape[ax]
            L = fft.dim_shape[ax]

            k = xp.arange(max(M, N))
            _Wk2 = W ** ((k**2) / 2)
            _AWk2 = (A ** -k[:N]) * _Wk2[:N]
            # Filter taps cover lags -(N-1),...,M-1, zero-padded to the FFT length L.
            _FWk2 = xpf.fft(
                xp.concatenate([_Wk2[(N - 1) : 0 : -1], _Wk2[:M]]).conj(),
                n=L,
            )
            _Wk2 = _Wk2[:M]

            # Trailing singleton axes so that each vector broadcasts along `ax`.
            expand = (np.newaxis,) * (self.dim_rank - 2 - ax)
            Wk2.append(_Wk2.astype(cdtype)[(..., *expand)])
            AWk2.append(_AWk2.astype(cdtype)[(..., *expand)])
            FWk2.append(_FWk2.astype(cdtype)[(..., *expand)])

        # Build (extract,)
        # Lag 0 of the filter sits at index N-1, so the M outputs live in [N-1, N+M-1).
        N_stack = x.ndim - (self.dim_rank - 1)
        extract = [slice(None)] * x.ndim
        for ax in self._axes:
            N = self.dim_shape[ax]
            M = self.codim_shape[ax]

            extract[N_stack + ax] = slice(N - 1, N + M - 1)
        extract = tuple(extract)

        return AWk2, FWk2, Wk2, extract, fft