1import functools
2import inspect
3
4import numpy as np
5
6import pyxu.abc as pxa
7import pyxu.info.deps as pxd
8import pyxu.info.ptype as pxt
9import pyxu.runtime as pxrt
10import pyxu.util as pxu
11
# Public API of this module.
__all__ = ["FFT"]
15
16
[docs]
class FFT(pxa.NormalOp):
    r"""
    Multi-dimensional Discrete Fourier Transform (DFT) :math:`A: \mathbb{C}^{M_{1} \times\cdots\times M_{D}} \to
    \mathbb{C}^{M_{1} \times\cdots\times M_{D}}`.

    The DFT is defined as follows:

    .. math::

       (A \, \mathbf{x})[\mathbf{k}]
       =
       \sum_{\mathbf{n}} \mathbf{x}[\mathbf{n}]
       \exp\left[-j 2 \pi \langle \frac{\mathbf{n}}{\mathbf{M}}, \mathbf{k} \rangle \right],

    .. math::

       (A^{*} \, \hat{\mathbf{x}})[\mathbf{n}]
       =
       \sum_{\mathbf{k}} \hat{\mathbf{x}}[\mathbf{k}]
       \exp\left[j 2 \pi \langle \frac{\mathbf{n}}{\mathbf{M}}, \mathbf{k} \rangle \right],

    .. math::

       (\mathbf{x}, \, \hat{\mathbf{x}}) \in \mathbb{C}^{M_{1} \times\cdots\times M_{D}},
       \quad
       (\mathbf{n}, \, \mathbf{k}) \in \{0, \ldots, M_{1}-1\} \times\cdots\times \{0, \ldots, M_{D}-1\},

    where :math:`\mathbf{M} = (M_{1}, \ldots, M_{D})` and :math:`\mathbf{n} / \mathbf{M}` denotes element-wise division.

    The DFT is taken over any number of axes by means of the Fast Fourier Transform algorithm (FFT). Axes which are not
    transformed are left untouched, i.e. the sums and exponentials above only involve the transformed axes.

    Note that :math:`A` is not normalized: :math:`A^{*} A = (\prod_{d} M_{d}) \, \mathbf{I}`, where the product runs
    over the transformed axes only.


    .. rubric:: Implementation Notes

    * The CPU implementation uses `SciPy's FFT implementation <https://docs.scipy.org/doc/scipy/reference/fft.html>`_.
    * The GPU implementation uses cuFFT via `CuPy <https://docs.cupy.dev/en/latest/reference/scipy_fft.html>`_.
    * The DASK implementation evaluates the FFT in chunks using the `CZT algorithm
      <https://en.wikipedia.org/wiki/Chirp_Z-transform>`_.

      Caveat: the cost of assembling the DASK graph grows with the total number of chunks; just calling ``FFT.apply()``
      may take a few seconds or more if inputs are highly chunked. Performance is ~7-10x slower than the equivalent
      non-chunked NUMPY version (assuming it fits in memory).


    Examples
    --------

    * 1D DFT of a cosine pulse.

      .. code-block:: python3

         import numpy as np
         from pyxu.operator import FFT
         import pyxu.util as pxu

         N = 10
         op = FFT(N)

         x = np.cos(2 * np.pi / N * np.arange(N), dtype=complex)  # (N,)
         x_r = pxu.view_as_real(x)  # (N, 2)

         y_r = op.apply(x_r)  # (N, 2)
         y = pxu.view_as_complex(y_r)  # (N,)
         # [0, N/2, 0, 0, 0, 0, 0, 0, 0, N/2]

         z_r = op.adjoint(op.apply(x_r))  # (N, 2)
         z = pxu.view_as_complex(z_r)  # (N,)
         # np.allclose(z, N * x) -> True

    * 1D DFT of a complex exponential pulse.

      .. code-block:: python3

         import numpy as np
         from pyxu.operator import FFT
         import pyxu.util as pxu

         N = 10
         op = FFT(N)

         x = np.exp(1j * 2 * np.pi / N * np.arange(N))  # (N,)
         x_r = pxu.view_as_real(x)  # (N, 2)

         y_r = op.apply(x_r)  # (N, 2)
         y = pxu.view_as_complex(y_r)  # (N,)
         # [0, N, 0, 0, 0, 0, 0, 0, 0, 0]

         z_r = op.adjoint(op.apply(x_r))  # (N, 2)
         z = pxu.view_as_complex(z_r)  # (N,)
         # np.allclose(z, N * x) -> True

    * 2D DFT of an image

      .. code-block:: python3

         import numpy as np
         from pyxu.operator import FFT
         import pyxu.util as pxu

         N_h, N_w = 10, 8
         op = FFT((N_h, N_w))

         x = np.pad(  # (N_h, N_w)
             np.ones((N_h//2, N_w//2), dtype=complex),
             pad_width=((0, N_h//2), (0, N_w//2)),
         )
         x_r = pxu.view_as_real(x)  # (N_h, N_w, 2)

         y_r = op.apply(x_r)  # (N_h, N_w, 2)
         y = pxu.view_as_complex(y_r)  # (N_h, N_w)

         z_r = op.adjoint(op.apply(x_r))  # (N_h, N_w, 2)
         z = pxu.view_as_complex(z_r)  # (N_h, N_w)
         # np.allclose(z, (N_h * N_w) * x) -> True
    """

    def __init__(
        self,
        dim_shape: pxt.NDArrayShape,
        axes: pxt.NDArrayAxis = None,
        **kwargs,
    ):
        r"""
        Parameters
        ----------
        dim_shape: NDArrayShape
            (M1,...,MD) dimensions of the input :math:`\mathbf{x} \in \mathbb{C}^{M_{1} \times\cdots\times M_{D}}`.
        axes: NDArrayAxis
            Axes over which to compute the FFT. If not given, all axes are used.
        kwargs: dict
            Extra kwargs passed to :py:func:`scipy.fft.fftn` or :py:func:`cupyx.scipy.fft.fftn`.

            Supported parameters for :py:func:`scipy.fft.fftn` are:

            * workers: int = None

            Supported parameters for :py:func:`cupyx.scipy.fft.fftn` are:

            * NOT SUPPORTED FOR NOW

            Default values are chosen if unspecified.
        """
        dim_shape = pxu.as_canonical_shape(dim_shape)
        # Complex inputs are stored as real arrays with a trailing (real, imag) axis of size 2.
        super().__init__(
            dim_shape=(*dim_shape, 2),
            codim_shape=(*dim_shape, 2),
        )

        if axes is None:
            axes = tuple(range(self.codim_rank - 1))
        axes = pxu.as_canonical_axes(axes, rank=self.codim_rank - 1)
        self._axes = tuple(sorted(set(axes)))  # drop duplicates

        # Backend-specific kwargs forwarded to [i]fftn(); unknown keys are filtered out in _transform().
        self._kwargs = {
            pxd.NDArrayInfo.NUMPY: dict(
                workers=kwargs.get("workers", None),
            ),
            pxd.NDArrayInfo.CUPY: dict(),
            pxd.NDArrayInfo.DASK: dict(),
        }
        self.lipschitz = self.estimate_lipschitz()

    def estimate_lipschitz(self, **kwargs) -> pxt.Real:
        r"""
        Lipschitz constant of the (unnormalized) DFT.

        Returns
        -------
        L: Real
            :math:`\sqrt{\prod_{d} M_{d}}`, the product running over the transformed axes.
        """
        M = [self.codim_shape[ax] for ax in self._axes]
        L = np.sqrt(np.prod(M))
        return L

    def gram(self) -> pxt.OpT:
        r"""
        Gram operator :math:`A^{*} A`, which for the DFT is the homothety :math:`L^{2} \mathbf{I}`.
        """
        from pyxu.operator import HomothetyOp

        G = HomothetyOp(
            dim_shape=self.dim_shape,
            cst=self.lipschitz**2,
        )
        return G

    def pinv(self, arr: pxt.NDArray, damp: pxt.Real, **kwargs) -> pxt.NDArray:
        r"""
        Damped pseudo-inverse :math:`(A^{*} A + \tau \mathbf{I})^{-1} A^{*} \mathbf{y}` on real-viewed inputs.

        Since :math:`A^{*} A = L^{2} \mathbf{I}`, this reduces to a scaled adjoint.

        Parameters
        ----------
        arr: NDArray
            (..., M1,...,MD,2) inputs viewed as a real array. (See :py:func:`~pyxu.util.view_as_real`.)
        damp: Real
            Damping factor :math:`\tau`.

        Returns
        -------
        out: NDArray
            (..., M1,...,MD,2) pseudo-inverse viewed as a real array.
        """
        out = self.adjoint(arr)
        out /= (self.lipschitz**2) + damp
        return out

    def cpinv(self, arr: pxt.NDArray, damp: pxt.Real, **kwargs) -> pxt.NDArray:
        r"""
        Parameters
        ----------
        arr: NDArray
            (..., M1,...,MD) inputs :math:`\mathbf{x} \in \mathbb{C}^{M_{1} \times\cdots\times M_{D}}`.

        Returns
        -------
        out: NDArray
            (..., M1,...,MD) pseudo-inverse :math:`\hat{\mathbf{x}} \in \mathbb{C}^{M_{1} \times\cdots\times M_{D}}`.
        """
        out = self.cadjoint(arr)
        out /= (self.lipschitz**2) + damp
        return out

    def dagger(self, damp: pxt.Real, **kwargs) -> pxt.OpT:
        r"""
        Damped pseudo-inverse operator :math:`A^{*} / (L^{2} + \tau)`.
        """
        op = self.T / ((self.lipschitz**2) + damp)
        return op

    def svdvals(self, **kwargs) -> pxt.NDArray:
        r"""
        Singular values: those of a unitary operator (all ones), scaled by :math:`L`.
        """
        D = pxa.UnitOp.svdvals(self, **kwargs) * self.lipschitz
        return D

    def asarray(self, **kwargs) -> pxt.NDArray:
        r"""
        Dense (real-viewed) matrix representation of the operator.

        Optional kwargs ``xp`` (array module) and ``dtype`` (real precision) select the output backend/precision.
        """
        # We compute 1D transforms per axis, then Kronecker product them.

        # Since non-NP backends may be faulty, we do everything in NUMPY ...
        A_1D = [None] * (D := self.dim_rank - 1)
        for ax in range(D):
            N = self.dim_shape[ax]
            if ax in self._axes:
                n = np.arange(N)
                A_1D[ax] = np.exp((-2j * np.pi / N) * np.outer(n, n))
            else:  # untransformed axis: identity
                A_1D[ax] = np.eye(N)

        # Outer product yields axes (out_1, in_1, ..., out_D, in_D): move all output axes first.
        A_ND = functools.reduce(np.multiply.outer, A_1D)
        B_ND = np.transpose(
            A_ND,
            axes=np.r_[
                np.arange(0, 2 * D, 2),
                np.arange(1, 2 * D, 2),
            ],
        )

        # ... then use the backend/precision user asked for.
        xp = kwargs.get("xp", pxd.NDArrayInfo.default().module())
        dtype = kwargs.get("dtype", pxrt.Width.DOUBLE.value)
        C = xp.array(
            pxu.as_real_op(B_ND, dim_rank=D),
            dtype=pxrt.Width(dtype).value,
        )
        return C

    def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
        r"""
        Parameters
        ----------
        arr: NDArray
            (..., M1,...,MD,2) inputs :math:`\mathbf{x} \in \mathbb{C}^{M_{1} \times\cdots\times M_{D}}` viewed as a
            real array. (See :py:func:`~pyxu.util.view_as_real`.)

        Returns
        -------
        out: NDArray
            (..., M1,...,MD,2) outputs :math:`\hat{\mathbf{x}} \in \mathbb{C}^{M_{1} \times\cdots\times M_{D}}` viewed
            as a real array. (See :py:func:`~pyxu.util.view_as_real`.)
        """
        x = pxu.view_as_complex(pxu.require_viewable(arr))  # (..., M1,...,MD)
        y = self.capply(x)
        out = pxu.view_as_real(pxu.require_viewable(y))  # (..., M1,...,MD,2)
        return out

    def capply(self, arr: pxt.NDArray) -> pxt.NDArray:
        r"""
        Parameters
        ----------
        arr: NDArray
            (..., M1,...,MD) inputs :math:`\mathbf{x} \in \mathbb{C}^{M_{1} \times\cdots\times M_{D}}`.

        Returns
        -------
        out: NDArray
            (..., M1,...,MD) outputs :math:`\hat{\mathbf{x}} \in \mathbb{C}^{M_{1} \times\cdots\times M_{D}}`.
        """
        out = self._transform(arr, mode="fw")
        return out

    def adjoint(self, arr: pxt.NDArray) -> pxt.NDArray:
        r"""
        Parameters
        ----------
        arr: NDArray
            (..., M1,...,MD,2) inputs :math:`\hat{\mathbf{x}} \in \mathbb{C}^{M_{1} \times\cdots\times M_{D}}` viewed
            as a real array. (See :py:func:`~pyxu.util.view_as_real`.)

        Returns
        -------
        out: NDArray
            (..., M1,...,MD,2) outputs :math:`\mathbf{x} \in \mathbb{C}^{M_{1} \times\cdots\times M_{D}}` viewed as a
            real array. (See :py:func:`~pyxu.util.view_as_real`.)
        """
        x = pxu.view_as_complex(pxu.require_viewable(arr))  # (..., M1,...,MD)
        y = self.cadjoint(x)
        out = pxu.view_as_real(pxu.require_viewable(y))  # (..., M1,...,MD,2)
        return out

    def cadjoint(self, arr: pxt.NDArray) -> pxt.NDArray:
        r"""
        Parameters
        ----------
        arr: NDArray
            (..., M1,...,MD) inputs :math:`\hat{\mathbf{x}} \in \mathbb{C}^{M_{1} \times\cdots\times M_{D}}`.

        Returns
        -------
        out: NDArray
            (..., M1,...,MD) outputs :math:`\mathbf{x} \in \mathbb{C}^{M_{1} \times\cdots\times M_{D}}`.
        """
        out = self._transform(arr, mode="bw")
        return out

    # Helpers (public) --------------------------------------------------------
    @classmethod
    def fft_backend(cls, xp: pxt.ArrayModule = None):
        """
        Retrieve the namespace containing [i]fftn().

        Parameters
        ----------
        xp: ArrayModule
            Array module used to compute the FFT. (Default: NumPy.)

        Returns
        -------
        xpf: ModuleType

        Raises
        ------
        NotImplementedError
            If `xp` is neither NumPy nor (when enabled) CuPy.
        """
        N = pxd.NDArrayInfo  # short-hand
        if xp is None:
            xp = N.default().module()

        if xp == N.NUMPY.module():
            xpf = pxu.import_module("scipy.fft")
        elif pxd.CUPY_ENABLED and (xp == N.CUPY.module()):
            xpf = pxu.import_module("cupyx.scipy.fft")
        else:
            raise NotImplementedError

        return xpf

    @classmethod
    def next_fast_len(
        cls,
        dim_shape: pxt.NDArrayShape,
        axes: pxt.NDArrayAxis = None,
        xp: pxt.ArrayModule = None,
    ) -> pxt.NDArrayShape:
        r"""
        Retrieve the next-best dimensions to perform an FFT.

        Parameters
        ----------
        dim_shape: NDArrayShape
            (M1,...,MD) dimensions of the input :math:`\mathbf{x} \in \mathbb{C}^{M_{1} \times\cdots\times M_{D}}`.
        axes: NDArrayAxis
            Axes over which to compute the FFT. If not given, all axes are used.
        xp: ArrayModule
            Which array module used to compute the FFT. (Default: NumPy.)

        Returns
        -------
        opt_shape: NDArrayShape
            FFT shape (N1,...,ND) >= (M1,...,MD). Untransformed axes keep their original length.
        """
        xpf = cls.fft_backend(xp)

        dim_shape = pxu.as_canonical_shape(dim_shape)
        if axes is None:
            axes = tuple(range(len(dim_shape)))
        axes = pxu.as_canonical_axes(axes, rank=len(dim_shape))

        opt_shape = list(dim_shape)
        for ax in axes:
            opt_shape[ax] = xpf.next_fast_len(dim_shape[ax])
        return tuple(opt_shape)

    # Helpers (internal) ------------------------------------------------------
    def _transform(self, x: pxt.NDArray, mode: str) -> pxt.NDArray:
        # Parameters
        # ----------
        # x: NDArray [real/complex]
        #     (..., M1,...,MD) array to transform.
        #     [(..., L1,...,LD), Lk <= Mk works too: will be zero-padded as required.]
        # mode: str
        #     Transform direction:
        #
        #     * 'fw': fftn(norm="backward")
        #     * 'bw': ifftn(norm="forward")
        #
        # Returns
        # -------
        # y: NDArray [complex]
        #     (..., M1,...,MD) transformed array.
        N = pxd.NDArrayInfo  # shorthand
        ndi = N.from_obj(x)
        xp = ndi.module()

        # Express transformed axes as negative indices so any leading stacking dimensions of `x` are skipped.
        axes = tuple(ax - (self.codim_rank - 1) for ax in self._axes)
        if ndi == N.DASK:
            # Entries must have right shape for CZT: pad if required.
            # [dim_shape = (M1,...,MD,2): the trailing (real, imag) axis shifts negative indices by 1.]
            pad_width = [(0, 0)] * x.ndim
            for ax in axes:
                pad_width[ax] = (0, self.dim_shape[ax - 1] - x.shape[ax])

            y = xp.pad(x, pad_width)
            for ax in axes:
                y = self._chunked_transform1D(y, mode, ax)
        else:  # NUMPY/CUPY
            xpf = self.fft_backend(xp)

            func, norm = dict(  # ref: scipy.fft norm conventions
                fw=(xpf.fftn, "backward"),
                bw=(xpf.ifftn, "forward"),
            )[mode]

            # `self._kwargs` contains parameters understood by different FFT backends.
            # Need to drop all non-standard parameters.
            sig = inspect.Signature.from_callable(func)
            kwargs = {k: v for (k, v) in self._kwargs[ndi].items() if (k in sig.parameters)}

            N_FFT = tuple(self.dim_shape[ax] for ax in self._axes)
            y = func(
                x=x,
                s=N_FFT,
                axes=axes,
                norm=norm,
                **kwargs,
            )
        return y

    @staticmethod
    def _chunked_transform1D(x: pxt.NDArray, mode: str, axis: int) -> pxt.NDArray:
        # Same signature as _transform(), but:
        # * limited to DASK inputs;
        # * performs 1D transform along chosen axis.

        def _mod_czt(x, M, A, W, n0, k0, axis):
            # 1D Chirp Z-Transform, followed by modulation with W**(k0 * [n0:n0+M]).
            #
            # This is a stripped-down version for performing chunked FFTs: don't use for other purposes.
            #
            # Parameters
            # ----------
            # x : NDArray
            #     (..., N, ...) NUMPY/CUPY array. [real or complex]
            # M : int
            #     Length of the transform.
            # A : complex
            #     Circular offset from the positive real-axis.
            # W : complex
            #     Circular spacing between transform points.
            # k0, n0: int
            #     Modulation coefficients.
            # axis : int
            #     Dimension of `x` along which the samples are stored.
            #
            # Returns
            # -------
            # z: NDArray
            #     (..., M, ...) modulated CZT along the axis indicated by `axis`.
            #     The precision matches that of `x`.
            #     [Note that SciPy's CZT implementation does not guarantee this.]

            # set backend -------------------------------------
            xp = pxu.get_array_module(x)
            xpf = FFT.fft_backend(xp)

            # constants ---------------------------------------
            N = x.shape[axis]
            N_FFT = xpf.next_fast_len(N + M - 1)
            swap = np.arange(x.ndim)  # permutation exchanging `axis` and the last axis (self-inverse)
            swap[[axis, -1]] = [-1, axis]

            # filters -----------------------------------------
            k = xp.arange(max(M, N))
            Wk2 = W ** ((k**2) / 2)
            AWk2 = (A ** -k[:N]) * Wk2[:N]
            FWk2 = xpf.fft(
                xp.r_[Wk2[(N - 1) : 0 : -1], Wk2[:M]].conj(),
                n=N_FFT,
            )
            Wk2 = Wk2[:M]

            # transform inputs --------------------------------
            # Promote to the complex dtype matching `x`'s precision *before* the in-place products below:
            # storing complex products into a real-valued buffer would otherwise raise a casting error.
            # (float32/complex64 -> complex64, float64/complex128 -> complex128)
            cdtype = xp.result_type(x.dtype, xp.complex64)
            x = x.transpose(*swap).astype(cdtype)  # astype() copies: the input chunk is never modified
            x *= AWk2
            y = xpf.fft(x, n=N_FFT)
            y *= FWk2
            z = xpf.ifft(y)[..., (N - 1) : (N + M - 1)]
            z *= Wk2

            # modulate CZT ------------------------------------
            z *= W ** (xp.arange(n0, n0 + M) * k0)
            return z.transpose(*swap)

        def block_ip(x: list[pxt.NDArray], k: pxt.NDArray, sign: int, axis: int) -> pxt.NDArray:
            # Block-defined inner-product.
            #
            #   y[:,...,:,k,:,...,:] = \sum_{n} x[:,...,:,n,:,...,:] W^{nk}
            #
            # `x`: list of chunks along transformed dimension.
            # `k`: consecutive sequence of output frequencies along transformed dimension.
            # `sign`: sign of the exponent.
            # `axis`: transformed axis.
            ndi = pxd.NDArrayInfo.from_obj(k)
            xp = ndi.module()

            chunks = tuple(_x.shape[axis] for _x in x)
            N, M = sum(chunks), len(k)
            W = xp.exp(sign * 2j * np.pi / N)
            S = xp.cumsum(xp.r_[0, chunks])  # global start index of each input chunk

            # Chunk starting at global index S contributes W^{S*k} * CZT_{A=W^{-k[0]}}(chunk).
            y = 0
            for idx_n, _x in enumerate(x):
                y += _mod_czt(
                    x=_x,
                    M=M,
                    A=W ** (-k[0]),
                    W=W,
                    n0=k[0],
                    k0=S[idx_n],
                    axis=axis,
                )
            return y

        sign = dict(fw=-1, bw=1)[mode]
        xp = pxd.NDArrayInfo.DASK.module()
        try:  # `x` complex-valued
            cdtype = pxrt.CWidth(x.dtype).value
        except Exception:  # `x` was real-valued
            cdtype = pxrt.Width(x.dtype).complex.value

        ip_ind = tuple(range(x.ndim))
        x_ind = list(range(x.ndim))
        x_ind[axis] = -1  # block_ip() along this axis
        k = xp.arange(x.shape[axis], chunks=x.chunks[axis])
        k_ind = (ip_ind[axis],)

        y = xp.blockwise(
            *(block_ip, ip_ind),
            *(x, x_ind, k, k_ind),
            dtype=cdtype,
            align_arrays=False,
            concatenate=False,
            meta=x._meta,
            # extra kwargs for block_ip()
            sign=sign,
            axis=axis,
        )
        return y