import collections.abc as cabc
import functools

import numpy as np
import scipy.special as sps

import pyxu.abc as pxa
import pyxu.info.deps as pxd
import pyxu.info.ptype as pxt
import pyxu.runtime as pxrt
import pyxu.util as pxu

__all__ = [
    "CZT",
]


[docs]
18class CZT(pxa.LinOp):
19 r"""
20 Multi-dimensional Chirp Z-Transform (CZT) :math:`C: \mathbb{C}^{N_{1} \times\cdots\times N_{D}} \to
21 \mathbb{C}^{M_{1} \times\cdots\times M_{D}}`.
22
23 The 1D CZT of parameters :math:`(A, W, M)` is defined as:
24
25 .. math::
26
27 (C \, \mathbf{x})[k]
28 =
29 \sum_{n=0}^{N-1} \mathbf{x}[n] A^{-n} W^{nk},
30
31 where :math:`\mathbf{x} \in \mathbb{C}^{N}`, :math:`A, W \in \mathbb{C}`, and :math:`k = \{0, \ldots, M-1\}`.
32
33 A D-dimensional CZT corresponds to taking a 1D CZT along each transform axis.
34
35 .. rubric:: Implementation Notes
36
37 For stability reasons, this implementation assumes :math:`A, W \in \mathbb{C}` lie on the unit circle.
38
39 See Also
40 --------
41 :py:class:`~pyxu.operator.FFT`
42 """
43
[docs]
44 def __init__(
45 self,
46 dim_shape: pxt.NDArrayShape,
47 axes: pxt.NDArrayAxis,
48 M,
49 A,
50 W,
51 **kwargs,
52 ):
53 r"""
54 Parameters
55 ----------
56 dim_shape: NDArrayShape
57 (N1,...,ND) dimensions of the input :math:`\mathbf{x} \in \mathbb{C}^{N_{1} \times\cdots\times N_{D}}`.
58 axes: NDArrayAxis
59 Axes over which to compute the CZT. If not given, all axes are used.
60 M : int, list(int)
61 Length of the transform per axis.
62 A : complex, list(complex)
63 Circular offset from the positive real-axis per axis.
64 W : complex, list(complex)
65 Circular spacing between transform points per axis.
66 kwargs: dict
67 Extra kwargs passed to :py:class:`~pyxu.operator.FFT`.
68 """
69 dim_shape = pxu.as_canonical_shape(dim_shape)
70 if axes is None:
71 axes = tuple(range(len(dim_shape)))
72 self._axes = pxu.as_canonical_axes(axes, len(dim_shape))
73 _M, self._A, self._W = self._canonical_repr(self._axes, M, A, W)
74
75 codim_shape = list(dim_shape)
76 for i, ax in enumerate(self._axes):
77 codim_shape[ax] = _M[i]
78 super().__init__(
79 dim_shape=(*dim_shape, 2),
80 codim_shape=(*codim_shape, 2),
81 )
82 self._kwargs = kwargs
83
84 # We know a crude Lipschitz bound by default. Since computing it takes (code) space,
85 # the estimate is computed as a special case of estimate_lipschitz()
86 self.lipschitz = self.estimate_lipschitz(__rule=True)
87
[docs]
88 def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
89 r"""
90 Parameters
91 ----------
92 arr: NDArray
93 (..., N1,...,ND,2) inputs :math:`\mathbf{x} \in \mathbb{C}^{N_{1} \times\cdots\times N_{D}}` viewed as a
94 real array. (See :py:func:`~pyxu.util.view_as_real`.)
95
96 Returns
97 -------
98 out: NDArray
99 (..., M1,...,MD,2) outputs :math:`\hat{\mathbf{x}} \in \mathbb{C}^{M_{1} \times\cdots\times M_{D}}` viewed
100 as a real array. (See :py:func:`~pyxu.util.view_as_real`.)
101 """
102 x = pxu.view_as_complex(pxu.require_viewable(arr))
103 y = self.capply(x)
104 out = pxu.view_as_real(pxu.require_viewable(y))
105 return out
106
107 def capply(self, arr: pxt.NDArray) -> pxt.NDArray:
108 r"""
109 Parameters
110 ----------
111 arr: NDArray
112 (..., N1,...,ND) inputs :math:`\mathbf{x} \in \mathbb{C}^{N_{1} \times\cdots\times N_{D}}`.
113
114 Returns
115 -------
116 out: NDArray
117 (..., M1,...,MD) outputs :math:`\hat{\mathbf{x}} \in \mathbb{C}^{M_{1} \times\cdots\times M_{D}}`.
118 """
119 AWk2, FWk2, Wk2, extract, fft = self._get_meta(arr)
120 arr = arr.copy() # for in-place updates
121 for _AWk2 in AWk2:
122 arr *= _AWk2
123 y = fft.capply(arr)
124 for _FWk2 in FWk2:
125 y *= _FWk2
126 out = fft.cpinv(y, damp=0)[extract]
127 for _Wk2 in Wk2:
128 out *= _Wk2
129 return out
130
[docs]
131 def adjoint(self, arr: pxt.NDArray) -> pxt.NDArray:
132 r"""
133 Parameters
134 ----------
135 arr: NDArray
136 (..., M1,...,MD,2) inputs :math:`\hat{\mathbf{x}} \in \mathbb{C}^{M_{1} \times\cdots\times M_{D}}` viewed
137 as a real array. (See :py:func:`~pyxu.util.view_as_real`.)
138
139 Returns
140 -------
141 out: NDArray
142 (..., N1,...,ND,2) outputs :math:`\mathbf{x} \in \mathbb{C}^{N_{1} \times\cdots\times N_{D}}` viewed as a
143 real array. (See :py:func:`~pyxu.util.view_as_real`.)
144 """
145 x = pxu.view_as_complex(pxu.require_viewable(arr))
146 y = self.cadjoint(x)
147 out = pxu.view_as_real(pxu.require_viewable(y))
148 return out
149
150 def cadjoint(self, arr: pxt.NDArray) -> pxt.NDArray:
151 r"""
152 Parameters
153 ----------
154 arr: NDArray
155 (..., M1,...,MD) inputs :math:`\hat{\mathbf{x}} \in \mathbb{C}^{M_{1} \times\cdots\times M_{D}}`.
156
157 Returns
158 -------
159 out: NDArray
160 (..., N1,...,ND) outputs :math:`\mathbf{x} \in \mathbb{C}^{N_{1} \times\cdots\times N_{D}}`.
161 """
162 # CZT^{*}(y,M,A,W)[n] = CZT(y,N,A=1,W=W*)[n] * A^{n}
163 czt = CZT(
164 dim_shape=self.codim_shape[:-1],
165 axes=self._axes,
166 M=[self.dim_shape[ax] for ax in self._axes],
167 A=1,
168 W=np.conj(self._W),
169 **self._kwargs,
170 )
171 out = czt.capply(arr)
172
173 # Re-scale outputs per axis.
174 xp = pxu.get_array_module(out)
175 cdtype = pxrt.CWidth(out.dtype).value
176 for i, ax in enumerate(self._axes):
177 A = self._A[i]
178 N = self.dim_shape[ax]
179 expand = (np.newaxis,) * (self.dim_rank - 2 - ax)
180
181 An = A ** xp.arange(N)[..., *expand]
182 out *= An.astype(cdtype)
183 return out
184
185 def estimate_lipschitz(self, **kwargs) -> pxt.Real:
186 no_eval = "__rule" in kwargs
187 if no_eval:
188 # We know that
189 # L^{2} = \sigma_{max}^{2}(C) = \lambda_{max}(C.gram) = \lambda_{max}(C.cogram)
190 # We know that C.[co]gram correspond to linear convolution of the input with a Dirichlet kernel, i.e.
191 # ( Gx)[n] = \sum_{q=0}^{N-1} x[q] h1[n-q], h1[n] = A^{n} W^{-(M-1)/2 n} \sin[p M/2 n] / \sin[p 1/2 n]
192 # (CGy)[n] = \sum_{q=0}^{M-1} y[q] h2[n-q], h2[n] = W^{ (N-1)/2 n} \sin[p N/2 n] / \sin[p 1/2 n]
193 # p = \arg{W}
194 # From Young's convolution inequality, we have the upper bound
195 # \norm{x \ast h}{2} \le \norm{x}{2}\norm{h}{1}
196 # Therefore
197 # L^{2} <= max(\norm{h1}{1}, \norm{h2}{1})
198 L2 = 1
199 for i, ax in enumerate(self._axes):
200 N = self.dim_shape[ax]
201 M = self.codim_shape[ax]
202 p = np.angle(self._W[i])
203
204 h1 = sps.diric(p * np.arange(-N, N + 1), n=M) * M
205 norm1 = np.fabs(h1).sum()
206
207 h2 = sps.diric(p * np.arange(-M, M + 1), n=N) * N
208 norm2 = np.fabs(h2).sum()
209
210 L2 *= max(norm1, norm2)
211 L = np.sqrt(L2)
212 else:
213 L = super().estimate_lipschitz(**kwargs)
214 return L
215
216 def asarray(self, **kwargs) -> pxt.NDArray:
217 # We compute 1D transforms per axis, then Kronecker product them.
218
219 # Since non-NP backends may be faulty, we do everything in NUMPY ...
220 A_1D = [None] * (D := self.dim_rank - 1)
221 i = 0
222 for ax in range(D):
223 N = self.dim_shape[ax]
224 M = self.codim_shape[ax]
225 if ax in self._axes:
226 n = np.arange(N)
227 m = np.arange(M)
228 _A, _W = self._A[i], self._W[i]
229 A_1D[ax] = (_W ** np.outer(m, n)) * (_A ** (-n))
230 i += 1
231 else:
232 A_1D[ax] = np.eye(N)
233
234 A_ND = functools.reduce(np.multiply.outer, A_1D)
235 B_ND = np.transpose(
236 A_ND,
237 axes=np.r_[
238 np.arange(0, 2 * D, 2),
239 np.arange(1, 2 * D, 2),
240 ],
241 )
242
243 # ... then use the backend/precision user asked for.
244 xp = kwargs.get("xp", pxd.NDArrayInfo.default().module())
245 dtype = kwargs.get("dtype", pxrt.Width.DOUBLE.value)
246 C = xp.array(
247 pxu.as_real_op(B_ND, dim_rank=D),
248 dtype=pxrt.Width(dtype).value,
249 )
250 return C
251
252 # Helper routines (internal) ----------------------------------------------
253 @staticmethod
254 def _canonical_repr(axes, M, A, W):
255 # Create canonical representations
256 # * `_M`: tuple(int)
257 # * `_A`: tuple(complex)
258 # * `_W`: tuple(complex)
259 #
260 # `axes` is already assumed in tuple-form.
261 def as_seq(x, N, _type):
262 if isinstance(x, cabc.Iterable):
263 _x = tuple(x)
264 else:
265 _x = (x,)
266 if len(_x) == 1:
267 _x *= N # broadcast
268 assert len(_x) == N
269
270 return tuple(map(_type, _x))
271
272 _M = as_seq(M, len(axes), int)
273 _A = as_seq(A, len(axes), complex)
274 _W = as_seq(W, len(axes), complex)
275 assert all(m > 0 for m in _M)
276 assert np.allclose(np.abs(_A), 1)
277 assert np.allclose(np.abs(_W), 1)
278 return _M, _A, _W
279
280 def _get_meta(self, x: pxt.NDArray):
281 # x: (..., M1,...,MD) [complex]
282 #
283 # Computes/Initializes (per axis):
284 # * `AWk2`: list[NDArray] pre-FFT modulation vectors.
285 # * `FWk2`: list[NDArray] FFT of convolution filters.
286 # * `Wk2`: list[NDArray] post-FFT modulation vectors.
287 # * `extract`: tuple[slice] FFT interval to extract.
288 # * `fft`: FFT object to transform the input.
289 from pyxu.operator import FFT
290
291 ndi = pxd.NDArrayInfo.from_obj(x)
292 if ndi == pxd.NDArrayInfo.DASK:
293 xp = pxu.get_array_module(x._meta)
294 else:
295 xp = ndi.module()
296 xpf = FFT.fft_backend(xp)
297 cdtype = pxrt.CWidth(x.dtype).value
298
299 # Initialize FFT to transform inputs.
300 fft_shape = list(self.dim_shape[:-1])
301 for i, ax in enumerate(self._axes):
302 fft_shape[ax] += self.codim_shape[ax] - 1
303 fft_shape = FFT.next_fast_len(fft_shape)
304 fft = FFT(
305 dim_shape=fft_shape,
306 axes=self._axes,
307 **self._kwargs,
308 )
309
310 # Build modulation vectors (Wk2, AWk2, FWk2).
311 Wk2, AWk2, FWk2 = [], [], []
312 for i, ax in enumerate(self._axes):
313 A = self._A[i]
314 W = self._W[i]
315 N = self.dim_shape[ax]
316 M = self.codim_shape[ax]
317 L = fft.dim_shape[ax]
318
319 k = xp.arange(max(M, N))
320 _Wk2 = W ** ((k**2) / 2)
321 _AWk2 = (A ** -k[:N]) * _Wk2[:N]
322 _FWk2 = xpf.fft(
323 xp.concatenate([_Wk2[(N - 1) : 0 : -1], _Wk2[:M]]).conj(),
324 n=L,
325 )
326 _Wk2 = _Wk2[:M]
327
328 expand = (np.newaxis,) * (self.dim_rank - 2 - ax)
329 Wk2.append(_Wk2.astype(cdtype)[..., *expand])
330 AWk2.append(_AWk2.astype(cdtype)[..., *expand])
331 FWk2.append(_FWk2.astype(cdtype)[..., *expand])
332
333 # Build (extract,)
334 N_stack = x.ndim - (self.dim_rank - 1)
335 extract = [slice(None)] * x.ndim
336 for ax in self._axes:
337 N = self.dim_shape[ax]
338 M = self.codim_shape[ax]
339 L = fft.dim_shape[ax]
340
341 extract[N_stack + ax] = slice(N - 1, N + M - 1)
342 extract = tuple(extract)
343
344 return AWk2, FWk2, Wk2, extract, fft