Source code for pyxu.operator.linop.fft.fft

  1import functools
  2import inspect
  3
  4import numpy as np
  5
  6import pyxu.abc as pxa
  7import pyxu.info.deps as pxd
  8import pyxu.info.ptype as pxt
  9import pyxu.runtime as pxrt
 10import pyxu.util as pxu
 11
 12__all__ = [
 13    "FFT",
 14]
 15
 16
class FFT(pxa.NormalOp):
    r"""
    Multi-dimensional Discrete Fourier Transform (DFT) :math:`A: \mathbb{C}^{M_{1} \times\cdots\times M_{D}} \to
    \mathbb{C}^{M_{1} \times\cdots\times M_{D}}`.

    The FFT is defined as follows:

    .. math::

       (A \, \mathbf{x})[\mathbf{k}]
       =
       \sum_{\mathbf{n}} \mathbf{x}[\mathbf{n}]
       \exp\left[-j 2 \pi \langle \frac{\mathbf{n}}{\mathbf{M}}, \mathbf{k} \rangle \right],

    .. math::

       (A^{*} \, \hat{\mathbf{x}})[\mathbf{n}]
       =
       \sum_{\mathbf{k}} \hat{\mathbf{x}}[\mathbf{k}]
       \exp\left[j 2 \pi \langle \frac{\mathbf{n}}{\mathbf{M}}, \mathbf{k} \rangle \right],

    .. math::

       (\mathbf{x}, \, \hat{\mathbf{x}}) \in \mathbb{C}^{M_{1} \times\cdots\times M_{D}},
       \quad
       (\mathbf{n}, \, \mathbf{k}) \in \{0, \ldots, M_{1}-1\} \times\cdots\times \{0, \ldots, M_{D}-1\},

    where :math:`\mathbf{M} = (M_{1}, \ldots, M_{D})` and the division :math:`\mathbf{n} / \mathbf{M}` is taken
    element-wise. (The formulas above transform every axis; when ``axes`` is given, the sums and exponents only
    involve the selected axes.)

    The DFT is taken over any number of axes by means of the Fast Fourier Transform algorithm (FFT).


    .. rubric:: Implementation Notes

    * The CPU implementation uses `SciPy's FFT implementation <https://docs.scipy.org/doc/scipy/reference/fft.html>`_.
    * The GPU implementation uses cuFFT via `CuPy <https://docs.cupy.dev/en/latest/reference/scipy_fft.html>`_.
    * The DASK implementation evaluates the FFT in chunks using the `CZT algorithm
      <https://en.wikipedia.org/wiki/Chirp_Z-transform>`_.

      Caveat: the cost of assembling the DASK graph grows with the total number of chunks; just calling ``FFT.apply()``
      may take a few seconds or more if inputs are highly chunked. Performance is ~7-10x slower than equivalent
      non-chunked NUMPY version (assuming it fits in memory).


    Examples
    --------

    * 1D DFT of a cosine pulse.

      .. code-block:: python3

         import numpy as np
         from pyxu.operator import FFT
         import pyxu.util as pxu

         N = 10
         op = FFT(N)

         x = np.cos(2 * np.pi / N * np.arange(N), dtype=complex)  # (N,)
         x_r = pxu.view_as_real(x)  # (N, 2)

         y_r = op.apply(x_r)  # (N, 2)
         y = pxu.view_as_complex(y_r)  # (N,)
         # [0, N/2, 0, 0, 0, 0, 0, 0, 0, N/2]

         z_r = op.adjoint(op.apply(x_r))  # (N, 2)
         z = pxu.view_as_complex(z_r)  # (N,)
         # np.allclose(z, N * x) -> True

    * 1D DFT of a complex exponential pulse.

      .. code-block:: python3

         import numpy as np
         from pyxu.operator import FFT
         import pyxu.util as pxu

         N = 10
         op = FFT(N)

         x = np.exp(1j * 2 * np.pi / N * np.arange(N))  # (N,)
         x_r = pxu.view_as_real(x)  # (N, 2)

         y_r = op.apply(x_r)  # (N, 2)
         y = pxu.view_as_complex(y_r)  # (N,)
         # [0, N, 0, 0, 0, 0, 0, 0, 0, 0]

         z_r = op.adjoint(op.apply(x_r))  # (N, 2)
         z = pxu.view_as_complex(z_r)  # (N,)
         # np.allclose(z, N * x) -> True

    * 2D DFT of an image

      .. code-block:: python3

         import numpy as np
         from pyxu.operator import FFT
         import pyxu.util as pxu

         N_h, N_w = 10, 8
         op = FFT((N_h, N_w))

         x = np.pad(  # (N_h, N_w)
             np.ones((N_h//2, N_w//2), dtype=complex),
             pad_width=((0, N_h//2), (0, N_w//2)),
         )
         x_r = pxu.view_as_real(x)  # (N_h, N_w, 2)

         y_r = op.apply(x_r)  # (N_h, N_w, 2)
         y = pxu.view_as_complex(y_r)  # (N_h, N_w)

         z_r = op.adjoint(op.apply(x_r))  # (N_h, N_w, 2)
         z = pxu.view_as_complex(z_r)  # (N_h, N_w)
         # np.allclose(z, (N_h * N_w) * x) -> True
    """

    def __init__(
        self,
        dim_shape: pxt.NDArrayShape,
        axes: pxt.NDArrayAxis = None,
        **kwargs,
    ):
        r"""
        Parameters
        ----------
        dim_shape: NDArrayShape
            (M1,...,MD) dimensions of the input :math:`\mathbf{x} \in \mathbb{C}^{M_{1} \times\cdots\times M_{D}}`.
        axes: NDArrayAxis
            Axes over which to compute the FFT. If not given, all axes are used.
        kwargs: dict
            Extra kwargs passed to :py:func:`scipy.fft.fftn` or :py:func:`cupyx.scipy.fft.fftn`.

            Supported parameters for :py:func:`scipy.fft.fftn` are:

            * workers: int = None

            Supported parameters for :py:func:`cupyx.scipy.fft.fftn` are:

            * NOT SUPPORTED FOR NOW

            Default values are chosen if unspecified.
        """
        dim_shape = pxu.as_canonical_shape(dim_shape)
        # Complex arrays are exposed to the operator framework as real arrays with a trailing (2,) axis.
        super().__init__(
            dim_shape=(*dim_shape, 2),
            codim_shape=(*dim_shape, 2),
        )

        if axes is None:
            axes = tuple(range(self.codim_rank - 1))
        axes = pxu.as_canonical_axes(axes, rank=self.codim_rank - 1)
        self._axes = tuple(sorted(set(axes)))  # drop duplicates

        # Backend-specific FFT kwargs; filtered against each backend's signature in _transform().
        self._kwargs = {
            pxd.NDArrayInfo.NUMPY: dict(
                workers=kwargs.get("workers", None),
            ),
            pxd.NDArrayInfo.CUPY: dict(),
            pxd.NDArrayInfo.DASK: dict(),
        }
        self.lipschitz = self.estimate_lipschitz()

    def estimate_lipschitz(self, **kwargs) -> pxt.Real:
        # A^{*} A = (prod of transformed sizes) * I (see gram()), so the norm is known exactly.
        M = [self.codim_shape[ax] for ax in self._axes]
        L = np.sqrt(np.prod(M))
        return L

    def gram(self) -> pxt.OpT:
        from pyxu.operator import HomothetyOp

        G = HomothetyOp(
            dim_shape=self.dim_shape,
            cst=self.lipschitz**2,
        )
        return G

    def pinv(self, arr: pxt.NDArray, damp: pxt.Real, **kwargs) -> pxt.NDArray:
        out = self.adjoint(arr)
        out /= (self.lipschitz**2) + damp
        return out

    def cpinv(self, arr: pxt.NDArray, damp: pxt.Real, **kwargs) -> pxt.NDArray:
        r"""
        Parameters
        ----------
        arr: NDArray
            (..., M1,...,MD) inputs :math:`\mathbf{x} \in \mathbb{C}^{M_{1} \times\cdots\times M_{D}}`.

        Returns
        -------
        out: NDArray
            (..., M1,...,MD) pseudo-inverse :math:`\hat{\mathbf{x}} \in \mathbb{C}^{M_{1} \times\cdots\times M_{D}}`.
        """
        out = self.cadjoint(arr)
        out /= (self.lipschitz**2) + damp
        return out

    def dagger(self, damp: pxt.Real, **kwargs) -> pxt.OpT:
        op = self.T / ((self.lipschitz**2) + damp)
        return op

    def svdvals(self, **kwargs) -> pxt.NDArray:
        D = pxa.UnitOp.svdvals(self, **kwargs) * self.lipschitz
        return D

    def asarray(self, **kwargs) -> pxt.NDArray:
        # We compute 1D transforms per axis, then Kronecker product them.

        # Since non-NP backends may be faulty, we do everything in NUMPY ...
        A_1D = [None] * (D := self.dim_rank - 1)
        for ax in range(D):
            N = self.dim_shape[ax]
            if ax in self._axes:
                n = np.arange(N)
                # exp(-2j*pi*n*k/N) is N-periodic in n*k: reducing n*k modulo N first keeps the phase in
                # [0, 2*pi) and avoids precision loss from large n*k products when N is big.
                A_1D[ax] = np.exp((-2j * np.pi / N) * (np.outer(n, n) % N))
            else:
                A_1D[ax] = np.eye(N)

        A_ND = functools.reduce(np.multiply.outer, A_1D)  # (N1,N1,N2,N2,...)
        B_ND = np.transpose(  # (N1,N2,...,N1,N2,...)
            A_ND,
            axes=np.r_[
                np.arange(0, 2 * D, 2),
                np.arange(1, 2 * D, 2),
            ],
        )

        # ... then use the backend/precision user asked for.
        xp = kwargs.get("xp", pxd.NDArrayInfo.default().module())
        dtype = kwargs.get("dtype", pxrt.Width.DOUBLE.value)
        C = xp.array(
            pxu.as_real_op(B_ND, dim_rank=D),
            dtype=pxrt.Width(dtype).value,
        )
        return C

    def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
        r"""
        Parameters
        ----------
        arr: NDArray
            (..., M1,...,MD,2) inputs :math:`\mathbf{x} \in \mathbb{C}^{M_{1} \times\cdots\times M_{D}}` viewed as a
            real array. (See :py:func:`~pyxu.util.view_as_real`.)

        Returns
        -------
        out: NDArray
            (..., M1,...,MD,2) outputs :math:`\hat{\mathbf{x}} \in \mathbb{C}^{M_{1} \times\cdots\times M_{D}}` viewed
            as a real array. (See :py:func:`~pyxu.util.view_as_real`.)
        """
        x = pxu.view_as_complex(pxu.require_viewable(arr))  # (..., M1,...,MD)
        y = self.capply(x)
        out = pxu.view_as_real(pxu.require_viewable(y))  # (..., M1,...,MD,2)
        return out

    def capply(self, arr: pxt.NDArray) -> pxt.NDArray:
        r"""
        Parameters
        ----------
        arr: NDArray
            (..., M1,...,MD) inputs :math:`\mathbf{x} \in \mathbb{C}^{M_{1} \times\cdots\times M_{D}}`.

        Returns
        -------
        out: NDArray
            (..., M1,...,MD) outputs :math:`\hat{\mathbf{x}} \in \mathbb{C}^{M_{1} \times\cdots\times M_{D}}`.
        """
        out = self._transform(arr, mode="fw")
        return out

    def adjoint(self, arr: pxt.NDArray) -> pxt.NDArray:
        r"""
        Parameters
        ----------
        arr: NDArray
            (..., M1,...,MD,2) inputs :math:`\hat{\mathbf{x}} \in \mathbb{C}^{M_{1} \times\cdots\times M_{D}}` viewed
            as a real array. (See :py:func:`~pyxu.util.view_as_real`.)

        Returns
        -------
        out: NDArray
            (..., M1,...,MD,2) outputs :math:`\mathbf{x} \in \mathbb{C}^{M_{1} \times\cdots\times M_{D}}` viewed as a
            real array. (See :py:func:`~pyxu.util.view_as_real`.)
        """
        x = pxu.view_as_complex(pxu.require_viewable(arr))  # (..., M1,...,MD)
        y = self.cadjoint(x)
        out = pxu.view_as_real(pxu.require_viewable(y))  # (..., M1,...,MD,2)
        return out

    def cadjoint(self, arr: pxt.NDArray) -> pxt.NDArray:
        r"""
        Parameters
        ----------
        arr: NDArray
            (..., M1,...,MD) inputs :math:`\hat{\mathbf{x}} \in \mathbb{C}^{M_{1} \times\cdots\times M_{D}}`.

        Returns
        -------
        out: NDArray
            (..., M1,...,MD) outputs :math:`\mathbf{x} \in \mathbb{C}^{M_{1} \times\cdots\times M_{D}}`.
        """
        out = self._transform(arr, mode="bw")
        return out

    # Helpers (public) --------------------------------------------------------
    @classmethod
    def fft_backend(cls, xp: pxt.ArrayModule = None):
        """
        Retrieve the namespace containing [i]fftn().

        Parameters
        ----------
        xp: ArrayModule
            Array module used to compute the FFT. (Default: NumPy.)

        Returns
        -------
        xpf: ModuleType
        """
        N = pxd.NDArrayInfo  # short-hand
        if xp is None:
            xp = N.default().module()

        if xp == N.NUMPY.module():
            xpf = pxu.import_module("scipy.fft")
        elif pxd.CUPY_ENABLED and (xp == N.CUPY.module()):
            xpf = pxu.import_module("cupyx.scipy.fft")
        else:
            raise NotImplementedError

        return xpf

    @classmethod
    def next_fast_len(
        cls,
        dim_shape: pxt.NDArrayShape,
        axes: pxt.NDArrayAxis = None,
        xp: pxt.ArrayModule = None,
    ) -> pxt.NDArrayShape:
        r"""
        Retrieve the next-best dimensions to perform an FFT.

        Parameters
        ----------
        dim_shape: NDArrayShape
            (M1,...,MD) dimensions of the input :math:`\mathbf{x} \in \mathbb{C}^{M_{1} \times\cdots\times M_{D}}`.
        axes: NDArrayAxis
            Axes over which to compute the FFT. If not given, all axes are used.
        xp: ArrayModule
            Which array module used to compute the FFT. (Default: NumPy.)

        Returns
        -------
        opt_shape: NDArrayShape
            FFT shape (N1,...,ND) >= (M1,...,MD).
        """
        xpf = cls.fft_backend(xp)

        dim_shape = pxu.as_canonical_shape(dim_shape)
        if axes is None:
            axes = tuple(range(len(dim_shape)))
        axes = pxu.as_canonical_axes(axes, rank=len(dim_shape))

        # Only the transformed axes are enlarged; the others keep their size.
        opt_shape = list(dim_shape)
        for ax in axes:
            opt_shape[ax] = xpf.next_fast_len(dim_shape[ax])
        return tuple(opt_shape)

    # Helpers (internal) ------------------------------------------------------
    def _transform(self, x: pxt.NDArray, mode: str) -> pxt.NDArray:
        # Parameters
        # ----------
        # x: NDArray [real/complex]
        #     (..., M1,...,MD) array to transform.
        #     [(..., L1,...,LD), Lk <= Mk works too: will be zero-padded as required.]
        # mode: str
        #     Transform direction:
        #
        #     * 'fw': fftn(norm="backward")
        #     * 'bw': ifftn(norm="forward")
        #
        # Returns
        # -------
        # y: NDArray [complex]
        #     (..., M1,...,MD) transformed array.
        N = pxd.NDArrayInfo  # shorthand
        ndi = N.from_obj(x)
        xp = ndi.module()

        # Transformed axes counted from the end, so any leading stacking dimensions of `x` are skipped.
        axes = tuple(ax - (self.codim_rank - 1) for ax in self._axes)
        if ndi == N.DASK:
            # Entries must have right shape for CZT: pad if required.
            pad_width = [(0, 0)] * x.ndim
            for ax in axes:
                # `ax` is negative and dim_shape ends with the (2,) real/imag axis, hence the extra -1.
                pad_width[ax] = (0, self.dim_shape[ax - 1] - x.shape[ax])

            y = xp.pad(x, pad_width)
            for ax in axes:
                y = self._chunked_transform1D(y, mode, ax)
        else:  # NUMPY/CUPY
            xpf = self.fft_backend(xp)

            func, norm = dict(  # ref: scipy.fft norm conventions
                fw=(xpf.fftn, "backward"),
                bw=(xpf.ifftn, "forward"),
            )[mode]

            # `self._kwargs` contains parameters understood by different FFT backends.
            # Need to drop all non-standard parameters.
            sig = inspect.Signature.from_callable(func)
            kwargs = {k: v for (k, v) in self._kwargs[ndi].items() if (k in sig.parameters)}

            N_FFT = tuple(self.dim_shape[ax] for ax in self._axes)
            y = func(
                x=x,
                s=N_FFT,
                axes=axes,
                norm=norm,
                **kwargs,
            )
        return y

    @staticmethod
    def _chunked_transform1D(x: pxt.NDArray, mode: str, axis: int) -> pxt.NDArray:
        # Same signature as _transform(), but:
        # * limited to DASK inputs;
        # * performs 1D transform along chosen axis.

        def _mod_czt(x, M, A, W, n0, k0, axis):
            # 1D Chirp Z-Transform, followed by modulation with W**(k0 * [n0:n0+M]).
            #
            # This is a stripped-down version for performing chunked FFTs: don't use for other purposes.
            #
            # Parameters
            # ----------
            # x : NDArray
            #     (..., N, ...) NUMPY/CUPY array. [real/complex]
            # M : int
            #     Length of the transform.
            # A : complex
            #     Circular offset from the positive real-axis.
            # W : complex
            #     Circular spacing between transform points.
            # k0, n0: int
            #     Modulation coefficients.
            # axis : int
            #     Dimension of `x` along which the samples are stored.
            #
            # Returns
            # -------
            # z: NDArray
            #     (..., M, ...) modulated CZT along the axis indicated by `axis`.
            #     The precision matches that of `x`; real-valued `x` yields the complex dtype of the same width.
            #     [Note that SciPy's CZT implementation does not guarantee this.]

            # set backend -------------------------------------
            xp = pxu.get_array_module(x)
            xpf = FFT.fft_backend(xp)

            # constants ---------------------------------------
            N = x.shape[axis]
            N_FFT = xpf.next_fast_len(N + M - 1)
            swap = np.arange(x.ndim)  # permutation moving `axis` last (self-inverse)
            swap[[axis, -1]] = [-1, axis]
            # Complex dtype of the same width as `x`: real inputs must be promoted up-front since the in-place
            # multiplications below cannot cast a complex result back into a real buffer.
            cdtype = xp.result_type(x.dtype, np.complex64)

            # filters -----------------------------------------
            k = xp.arange(max(M, N))
            Wk2 = W ** ((k**2) / 2)
            AWk2 = (A ** -k[:N]) * Wk2[:N]
            FWk2 = xpf.fft(
                xp.r_[Wk2[(N - 1) : 0 : -1], Wk2[:M]].conj(),
                n=N_FFT,
            )
            Wk2 = Wk2[:M]

            # transform inputs --------------------------------
            # astype() always copies, so the caller's chunk is never modified in place.
            x = x.transpose(*swap).astype(cdtype, order="C")
            x *= AWk2
            y = xpf.fft(x, n=N_FFT)
            y *= FWk2
            z = xpf.ifft(y)[..., (N - 1) : (N + M - 1)]
            z *= Wk2

            # modulate CZT ------------------------------------
            z *= W ** (xp.arange(n0, n0 + M) * k0)
            return z.transpose(*swap)

        def block_ip(x: list[pxt.NDArray], k: pxt.NDArray, sign: int, axis: int) -> pxt.NDArray:
            # Block-defined inner-product.
            #
            #     y[:,...,:,k,:,...,:] = \sum_{n} x[:,...,:,n,:,...,:] W^{nk}
            #
            # `x`: list of chunks along transformed dimension.
            # `k`: consecutive sequence of output frequencies along transformed dimension.
            # `sign`: sign of the exponent.
            # `axis`: transformed axis.
            ndi = pxd.NDArrayInfo.from_obj(k)
            xp = ndi.module()

            chunks = tuple(_x.shape[axis] for _x in x)
            N, M = sum(chunks), len(k)
            W = xp.exp(sign * 2j * np.pi / N)
            S = xp.cumsum(xp.r_[0, chunks])  # start index of each input chunk

            # Each input chunk contributes a modulated CZT evaluated at this block's output frequencies.
            y = 0
            for idx_n, _x in enumerate(x):
                y += _mod_czt(
                    x=_x,
                    M=M,
                    A=W ** (-k[0]),
                    W=W,
                    n0=k[0],
                    k0=S[idx_n],
                    axis=axis,
                )
            return y

        sign = dict(fw=-1, bw=1)[mode]
        xp = pxd.NDArrayInfo.DASK.module()
        try:  # `x` complex-valued
            cdtype = pxrt.CWidth(x.dtype).value
        except Exception:  # `x` was real-valued
            cdtype = pxrt.Width(x.dtype).complex.value

        ip_ind = tuple(range(x.ndim))
        x_ind = list(range(x.ndim))
        x_ind[axis] = -1  # block_ip() along this axis
        k = xp.arange(x.shape[axis], chunks=x.chunks[axis])
        k_ind = (ip_ind[axis],)

        y = xp.blockwise(
            *(block_ip, ip_ind),
            *(x, x_ind, k, k_ind),
            dtype=cdtype,
            align_arrays=False,
            concatenate=False,
            meta=x._meta,
            # extra kwargs for block_ip()
            sign=sign,
            axis=axis,
        )
        return y