import collections.abc as cabc
import functools

import numpy as np
import scipy.special as sps

import pyxu.abc as pxa
import pyxu.info.deps as pxd
import pyxu.info.ptype as pxt
import pyxu.runtime as pxrt
import pyxu.util as pxu
 13__all__ = [
 14    "CZT",
[docs] 18class CZT(pxa.LinOp): 19 r""" 20 Multi-dimensional Chirp Z-Transform (CZT) :math:`C: \mathbb{C}^{N_{1} \times\cdots\times N_{D}} \to 21 \mathbb{C}^{M_{1} \times\cdots\times M_{D}}`. 22 23 The 1D CZT of parameters :math:`(A, W, M)` is defined as: 24 25 .. math:: 26 27 (C \, \mathbf{x})[k] 28 = 29 \sum_{n=0}^{N-1} \mathbf{x}[n] A^{-n} W^{nk}, 30 31 where :math:`\mathbf{x} \in \mathbb{C}^{N}`, :math:`A, W \in \mathbb{C}`, and :math:`k = \{0, \ldots, M-1\}`. 32 33 A D-dimensional CZT corresponds to taking a 1D CZT along each transform axis. 34 35 .. rubric:: Implementation Notes 36 37 For stability reasons, this implementation assumes :math:`A, W \in \mathbb{C}` lie on the unit circle. 38 39 See Also 40 -------- 41 :py:class:`~pyxu.operator.FFT` 42 """ 43
[docs] 44 def __init__( 45 self, 46 dim_shape: pxt.NDArrayShape, 47 axes: pxt.NDArrayAxis, 48 M, 49 A, 50 W, 51 **kwargs, 52 ): 53 r""" 54 Parameters 55 ---------- 56 dim_shape: NDArrayShape 57 (N1,...,ND) dimensions of the input :math:`\mathbf{x} \in \mathbb{C}^{N_{1} \times\cdots\times N_{D}}`. 58 axes: NDArrayAxis 59 Axes over which to compute the CZT. If not given, all axes are used. 60 M : int, list(int) 61 Length of the transform per axis. 62 A : complex, list(complex) 63 Circular offset from the positive real-axis per axis. 64 W : complex, list(complex) 65 Circular spacing between transform points per axis. 66 kwargs: dict 67 Extra kwargs passed to :py:class:`~pyxu.operator.FFT`. 68 """ 69 dim_shape = pxu.as_canonical_shape(dim_shape) 70 if axes is None: 71 axes = tuple(range(len(dim_shape))) 72 self._axes = pxu.as_canonical_axes(axes, len(dim_shape)) 73 _M, self._A, self._W = self._canonical_repr(self._axes, M, A, W) 74 75 codim_shape = list(dim_shape) 76 for i, ax in enumerate(self._axes): 77 codim_shape[ax] = _M[i] 78 super().__init__( 79 dim_shape=(*dim_shape, 2), 80 codim_shape=(*codim_shape, 2), 81 ) 82 self._kwargs = kwargs 83 84 # We know a crude Lipschitz bound by default. Since computing it takes (code) space, 85 # the estimate is computed as a special case of estimate_lipschitz() 86 self.lipschitz = self.estimate_lipschitz(__rule=True)
[docs] 88 def apply(self, arr: pxt.NDArray) -> pxt.NDArray: 89 r""" 90 Parameters 91 ---------- 92 arr: NDArray 93 (..., N1,...,ND,2) inputs :math:`\mathbf{x} \in \mathbb{C}^{N_{1} \times\cdots\times N_{D}}` viewed as a 94 real array. (See :py:func:`~pyxu.util.view_as_real`.) 95 96 Returns 97 ------- 98 out: NDArray 99 (..., M1,...,MD,2) outputs :math:`\hat{\mathbf{x}} \in \mathbb{C}^{M_{1} \times\cdots\times M_{D}}` viewed 100 as a real array. (See :py:func:`~pyxu.util.view_as_real`.) 101 """ 102 x = pxu.view_as_complex(pxu.require_viewable(arr)) 103 y = self.capply(x) 104 out = pxu.view_as_real(pxu.require_viewable(y)) 105 return out
106 107 def capply(self, arr: pxt.NDArray) -> pxt.NDArray: 108 r""" 109 Parameters 110 ---------- 111 arr: NDArray 112 (..., N1,...,ND) inputs :math:`\mathbf{x} \in \mathbb{C}^{N_{1} \times\cdots\times N_{D}}`. 113 114 Returns 115 ------- 116 out: NDArray 117 (..., M1,...,MD) outputs :math:`\hat{\mathbf{x}} \in \mathbb{C}^{M_{1} \times\cdots\times M_{D}}`. 118 """ 119 AWk2, FWk2, Wk2, extract, fft = self._get_meta(arr) 120 arr = arr.copy() # for in-place updates 121 for _AWk2 in AWk2: 122 arr *= _AWk2 123 y = fft.capply(arr) 124 for _FWk2 in FWk2: 125 y *= _FWk2 126 out = fft.cpinv(y, damp=0)[extract] 127 for _Wk2 in Wk2: 128 out *= _Wk2 129 return out 130
[docs] 131 def adjoint(self, arr: pxt.NDArray) -> pxt.NDArray: 132 r""" 133 Parameters 134 ---------- 135 arr: NDArray 136 (..., M1,...,MD,2) inputs :math:`\hat{\mathbf{x}} \in \mathbb{C}^{M_{1} \times\cdots\times M_{D}}` viewed 137 as a real array. (See :py:func:`~pyxu.util.view_as_real`.) 138 139 Returns 140 ------- 141 out: NDArray 142 (..., N1,...,ND,2) outputs :math:`\mathbf{x} \in \mathbb{C}^{N_{1} \times\cdots\times N_{D}}` viewed as a 143 real array. (See :py:func:`~pyxu.util.view_as_real`.) 144 """ 145 x = pxu.view_as_complex(pxu.require_viewable(arr)) 146 y = self.cadjoint(x) 147 out = pxu.view_as_real(pxu.require_viewable(y)) 148 return out
149 150 def cadjoint(self, arr: pxt.NDArray) -> pxt.NDArray: 151 r""" 152 Parameters 153 ---------- 154 arr: NDArray 155 (..., M1,...,MD) inputs :math:`\hat{\mathbf{x}} \in \mathbb{C}^{M_{1} \times\cdots\times M_{D}}`. 156 157 Returns 158 ------- 159 out: NDArray 160 (..., N1,...,ND) outputs :math:`\mathbf{x} \in \mathbb{C}^{N_{1} \times\cdots\times N_{D}}`. 161 """ 162 # CZT^{*}(y,M,A,W)[n] = CZT(y,N,A=1,W=W*)[n] * A^{n} 163 czt = CZT( 164 dim_shape=self.codim_shape[:-1], 165 axes=self._axes, 166 M=[self.dim_shape[ax] for ax in self._axes], 167 A=1, 168 W=np.conj(self._W), 169 **self._kwargs, 170 ) 171 out = czt.capply(arr) 172 173 # Re-scale outputs per axis. 174 xp = pxu.get_array_module(out) 175 cdtype = pxrt.CWidth(out.dtype).value 176 for i, ax in enumerate(self._axes): 177 A = self._A[i] 178 N = self.dim_shape[ax] 179 expand = (np.newaxis,) * (self.dim_rank - 2 - ax) 180 181 An = A ** xp.arange(N)[..., *expand] 182 out *= An.astype(cdtype) 183 return out 184 185 def estimate_lipschitz(self, **kwargs) -> pxt.Real: 186 no_eval = "__rule" in kwargs 187 if no_eval: 188 # We know that 189 # L^{2} = \sigma_{max}^{2}(C) = \lambda_{max}(C.gram) = \lambda_{max}(C.cogram) 190 # We know that C.[co]gram correspond to linear convolution of the input with a Dirichlet kernel, i.e. 
191 # ( Gx)[n] = \sum_{q=0}^{N-1} x[q] h1[n-q], h1[n] = A^{n} W^{-(M-1)/2 n} \sin[p M/2 n] / \sin[p 1/2 n] 192 # (CGy)[n] = \sum_{q=0}^{M-1} y[q] h2[n-q], h2[n] = W^{ (N-1)/2 n} \sin[p N/2 n] / \sin[p 1/2 n] 193 # p = \arg{W} 194 # From Young's convolution inequality, we have the upper bound 195 # \norm{x \ast h}{2} \le \norm{x}{2}\norm{h}{1} 196 # Therefore 197 # L^{2} <= max(\norm{h1}{1}, \norm{h2}{1}) 198 L2 = 1 199 for i, ax in enumerate(self._axes): 200 N = self.dim_shape[ax] 201 M = self.codim_shape[ax] 202 p = np.angle(self._W[i]) 203 204 h1 = sps.diric(p * np.arange(-N, N + 1), n=M) * M 205 norm1 = np.fabs(h1).sum() 206 207 h2 = sps.diric(p * np.arange(-M, M + 1), n=N) * N 208 norm2 = np.fabs(h2).sum() 209 210 L2 *= max(norm1, norm2) 211 L = np.sqrt(L2) 212 else: 213 L = super().estimate_lipschitz(**kwargs) 214 return L 215 216 def asarray(self, **kwargs) -> pxt.NDArray: 217 # We compute 1D transforms per axis, then Kronecker product them. 218 219 # Since non-NP backends may be faulty, we do everything in NUMPY ... 220 A_1D = [None] * (D := self.dim_rank - 1) 221 i = 0 222 for ax in range(D): 223 N = self.dim_shape[ax] 224 M = self.codim_shape[ax] 225 if ax in self._axes: 226 n = np.arange(N) 227 m = np.arange(M) 228 _A, _W = self._A[i], self._W[i] 229 A_1D[ax] = (_W ** np.outer(m, n)) * (_A ** (-n)) 230 i += 1 231 else: 232 A_1D[ax] = np.eye(N) 233 234 A_ND = functools.reduce(np.multiply.outer, A_1D) 235 B_ND = np.transpose( 236 A_ND, 237 axes=np.r_[ 238 np.arange(0, 2 * D, 2), 239 np.arange(1, 2 * D, 2), 240 ], 241 ) 242 243 # ... then use the backend/precision user asked for. 
244 xp = kwargs.get("xp", pxd.NDArrayInfo.default().module()) 245 dtype = kwargs.get("dtype", pxrt.Width.DOUBLE.value) 246 C = xp.array( 247 pxu.as_real_op(B_ND, dim_rank=D), 248 dtype=pxrt.Width(dtype).value, 249 ) 250 return C 251 252 # Helper routines (internal) ---------------------------------------------- 253 @staticmethod 254 def _canonical_repr(axes, M, A, W): 255 # Create canonical representations 256 # * `_M`: tuple(int) 257 # * `_A`: tuple(complex) 258 # * `_W`: tuple(complex) 259 # 260 # `axes` is already assumed in tuple-form. 261 def as_seq(x, N, _type): 262 if isinstance(x, cabc.Iterable): 263 _x = tuple(x) 264 else: 265 _x = (x,) 266 if len(_x) == 1: 267 _x *= N # broadcast 268 assert len(_x) == N 269 270 return tuple(map(_type, _x)) 271 272 _M = as_seq(M, len(axes), int) 273 _A = as_seq(A, len(axes), complex) 274 _W = as_seq(W, len(axes), complex) 275 assert all(m > 0 for m in _M) 276 assert np.allclose(np.abs(_A), 1) 277 assert np.allclose(np.abs(_W), 1) 278 return _M, _A, _W 279 280 def _get_meta(self, x: pxt.NDArray): 281 # x: (..., M1,...,MD) [complex] 282 # 283 # Computes/Initializes (per axis): 284 # * `AWk2`: list[NDArray] pre-FFT modulation vectors. 285 # * `FWk2`: list[NDArray] FFT of convolution filters. 286 # * `Wk2`: list[NDArray] post-FFT modulation vectors. 287 # * `extract`: tuple[slice] FFT interval to extract. 288 # * `fft`: FFT object to transform the input. 289 from pyxu.operator import FFT 290 291 ndi = pxd.NDArrayInfo.from_obj(x) 292 if ndi == pxd.NDArrayInfo.DASK: 293 xp = pxu.get_array_module(x._meta) 294 else: 295 xp = ndi.module() 296 xpf = FFT.fft_backend(xp) 297 cdtype = pxrt.CWidth(x.dtype).value 298 299 # Initialize FFT to transform inputs. 
300 fft_shape = list(self.dim_shape[:-1]) 301 for i, ax in enumerate(self._axes): 302 fft_shape[ax] += self.codim_shape[ax] - 1 303 fft_shape = FFT.next_fast_len(fft_shape) 304 fft = FFT( 305 dim_shape=fft_shape, 306 axes=self._axes, 307 **self._kwargs, 308 ) 309 310 # Build modulation vectors (Wk2, AWk2, FWk2). 311 Wk2, AWk2, FWk2 = [], [], [] 312 for i, ax in enumerate(self._axes): 313 A = self._A[i] 314 W = self._W[i] 315 N = self.dim_shape[ax] 316 M = self.codim_shape[ax] 317 L = fft.dim_shape[ax] 318 319 k = xp.arange(max(M, N)) 320 _Wk2 = W ** ((k**2) / 2) 321 _AWk2 = (A ** -k[:N]) * _Wk2[:N] 322 _FWk2 = xpf.fft( 323 xp.concatenate([_Wk2[(N - 1) : 0 : -1], _Wk2[:M]]).conj(), 324 n=L, 325 ) 326 _Wk2 = _Wk2[:M] 327 328 expand = (np.newaxis,) * (self.dim_rank - 2 - ax) 329 Wk2.append(_Wk2.astype(cdtype)[..., *expand]) 330 AWk2.append(_AWk2.astype(cdtype)[..., *expand]) 331 FWk2.append(_FWk2.astype(cdtype)[..., *expand]) 332 333 # Build (extract,) 334 N_stack = x.ndim - (self.dim_rank - 1) 335 extract = [slice(None)] * x.ndim 336 for ax in self._axes: 337 N = self.dim_shape[ax] 338 M = self.codim_shape[ax] 339 L = fft.dim_shape[ax] 340 341 extract[N_stack + ax] = slice(N - 1, N + M - 1) 342 extract = tuple(extract) 343 344 return AWk2, FWk2, Wk2, extract, fft