Source code for pyxu.operator.linop.fft.filter

  1import dataclasses
  2
  3import numpy as np
  4
  5import pyxu.info.deps as pxd
  6import pyxu.info.ptype as pxt
  7import pyxu.util as pxu
  8from pyxu.operator.linop.fft.fft import FFT
  9from pyxu.operator.linop.pad import Pad
 10from pyxu.operator.linop.stencil._stencil import _Stencil
 11from pyxu.operator.linop.stencil.stencil import Stencil
 12
 13__all__ = [
 14    "FFTCorrelate",
 15    "FFTConvolve",
 16]
 17
 18
 19KernelInfo = dataclasses.make_dataclass(
 20    "KernelInfo",
 21    fields=["_kernel", "_center"],
 22)
 23
 24
class FFTCorrelate(Stencil):
    r"""
    Multi-dimensional FFT-based correlation.

    :py:class:`~pyxu.operator.FFTCorrelate` has the same interface as :py:class:`~pyxu.operator.Stencil`.

    .. rubric:: Implementation Notes

    * :py:class:`~pyxu.operator.FFTCorrelate` can scale to much larger kernels than :py:class:`~pyxu.operator.Stencil`.
    * This implementation is most efficient with "constant" boundary conditions (default).
    * Kernels must be small enough to fit in memory, i.e. unbounded kernels are not allowed.
    * Kernels should be supplied as NUMPY/CUPY arrays. DASK arrays will be evaluated if provided.
    * :py:class:`~pyxu.operator.FFTCorrelate` instances are **not arraymodule-agnostic**: they will only work with
      NDArrays belonging to the same array module as `kernel`, or DASK arrays where the chunk backend matches the
      init-supplied kernel backend.
    * A warning is emitted if inputs must be cast to the kernel dtype.
    * The input array is transformed by calling :py:class:`~pyxu.operator.FFT`.
    * When operating on DASK inputs, the kernel DFT is computed per chunk at the best size to handle the inputs.
      This was deemed preferable to pre-computing a huge kernel DFT once, then sending it to each worker process to
      compute its chunk.

    See Also
    --------
    :py:class:`~pyxu.operator.Stencil`,
    :py:class:`~pyxu.operator.FFTConvolve`
    """

    def __init__(
        self,
        dim_shape: pxt.NDArrayShape,
        kernel: Stencil.KernelSpec,
        center: _Stencil.IndexSpec,
        mode: Pad.ModeSpec = "constant",
        enable_warnings: bool = True,
        **kwargs,
    ):
        r"""
        Parameters
        ----------
        dim_shape: NDArrayShape
            (M1,...,MD) input dimensions.
        kernel: ~pyxu.operator.Stencil.KernelSpec
            Kernel coefficients. Two forms are accepted:

            * NDArray of rank-:math:`D`: denotes a non-separable kernel.
            * tuple[NDArray_1, ..., NDArray_D]: a sequence of 1D kernels such that dimension[k] is filtered by kernel
              `kernel[k]`, that is:

              .. math::

                 k = k_1 \otimes\cdots\otimes k_D,

              or in Python: ``k = functools.reduce(numpy.multiply.outer, kernel)``.

        center: ~pyxu.operator._Stencil.IndexSpec
            (i1,...,iD) index of the kernel's center.

            `center` defines how a kernel is overlaid on inputs to produce outputs.

        mode: str, :py:class:`list` ( str )
            Boundary conditions. Multiple forms are accepted:

            * str: unique mode shared amongst dimensions.
              Must be one of:

              * 'constant' (zero-padding)
              * 'wrap'
              * 'reflect'
              * 'symmetric'
              * 'edge'
            * tuple[str, ...]: dimension[k] uses `mode[k]` as boundary condition.

            (See :py:func:`numpy.pad` for details.)
        enable_warnings: bool
            If ``True``, emit a warning in case of precision mis-match issues.
        kwargs: dict
            Extra kwargs forwarded to :py:class:`~pyxu.operator.FFT`.
        """
        super().__init__(
            dim_shape=dim_shape,
            kernel=kernel,
            center=center,
            mode=mode,
            enable_warnings=enable_warnings,
        )
        self._fft_kwargs = kwargs  # Extra kwargs passed to FFT()

    def configure_dispatcher(self, **kwargs):
        raise NotImplementedError("Irrelevant for FFT-backed filtering.")

    # Helper routines (internal) ----------------------------------------------
    @staticmethod
    def _compute_pad_width(_kernel, _center, _mode) -> Pad.WidthSpec:
        # Per-axis (lhs, rhs) pad widths applied around the stencil chain.
        N = _kernel[0].ndim
        pad_width = [None] * N
        for i in range(N):
            if _mode[i] == "constant":
                # FFT already implements padding with zeros to size N+K-1.
                pad_width[i] = (0, 0)
            else:
                if len(_kernel) == 1:  # non-separable filter
                    n = _kernel[0].shape[i]
                else:  # separable filter(s)
                    n = _kernel[i].size

                # 1. Pad/Trim operators are shared amongst [apply,adjoint]():
                #    lhs/rhs are thus padded equally.
                # 2. Pad width must match kernel dimensions to retain border effects.
                pad_width[i] = (n - 1, n - 1)
        return tuple(pad_width)

    @staticmethod
    def _init_fw(_kernel, _center) -> list:
        # Initialize kernels used in apply().
        # The returned objects must have the following fields:
        # * _kernel: NDArray[float] kernel coefficients
        # * _center: NDArray[int] (D,) index of the kernel's center
        #
        # _stencil_chain() always performs a (linear) convolution, so the correlation kernels are
        # converted to their convolution-form equivalent (via Stencil._bw_equivalent) before storage.
        _st_fw = [None] * len(_kernel)
        _kernel, _center = Stencil._bw_equivalent(_kernel, _center)
        for i, (k_fw, c_fw) in enumerate(zip(_kernel, _center)):
            _st_fw[i] = KernelInfo(k_fw, c_fw)
        return _st_fw

    @staticmethod
    def _init_bw(_kernel, _center) -> list:
        # Initialize kernels used in adjoint().
        # The returned objects must have the following fields:
        # * _kernel: NDArray[float] kernel coefficients
        # * _center: NDArray[int] (D,) index of the kernel's center
        #
        # Kernels are stored as-is: convolving with the correlation kernel realizes the adjoint.
        _st_bw = [None] * len(_kernel)
        for i, (k_bw, c_bw) in enumerate(zip(_kernel, _center)):
            _st_bw[i] = KernelInfo(k_bw, c_bw)
        return _st_bw

    def _stencil_chain(self, x: pxt.NDArray, stencils: list) -> pxt.NDArray:
        # Apply sequence of stencils to `x`.
        #
        # x: (..., M1,...,MD)
        # z: (..., M1,...,MD)

        # Contrary to Stencil._stencil_chain(), the `stencils` parameter is picklable directly.
        def _chain(x, stencils, fft_kwargs):
            xp = pxu.get_array_module(x)
            xpf = FFT.fft_backend(xp)

            # Compute constants -----------------------------------------------
            if uni_kernel := (len(stencils) == 1):
                M = np.r_[stencils[0]._kernel.shape]
                dim_rank = stencils[0]._kernel.ndim
            else:
                M = np.array([st._kernel.size for st in stencils])
                dim_rank = len(stencils)
            Np = np.r_[x.shape[-dim_rank:]]
            # DFT size >= N+M-1 so that the product in the DFT domain is a linear (not circular) convolution.
            L = FFT.next_fast_len(Np + M - 1, xp=xp)
            axes = tuple(range(-dim_rank, 0))

            # Apply stencils in DFT domain ------------------------------------
            fft = FFT(L, axes, **fft_kwargs)
            Z = fft.capply(x)
            if uni_kernel:
                if len(M) == 1:  # 1D kernel
                    K = xpf.fft(stencils[0]._kernel, n=L[0])
                else:  # ND kernel
                    K = fft.capply(stencils[0]._kernel)
                Z *= K
            else:
                for ax, st in enumerate(stencils):
                    K = xpf.fft(st._kernel, n=L[ax], axis=ax)
                    Z *= K
            z = fft.cpinv(Z, damp=0).real

            # Extract ROI -----------------------------------------------------
            if uni_kernel:
                center = stencils[0]._center
            else:
                center = [st._center[i] for (i, st) in enumerate(stencils)]
            extract = [slice(c, c + n) for (c, n) in zip(center, Np)]
            # Explicit tuple: `z[..., *extract]` is only valid syntax from Python 3.11 onwards.
            return z[(..., *extract)]

        ndi = pxd.NDArrayInfo.from_obj(x)
        if ndi == pxd.NDArrayInfo.DASK:
            # Compute (depth,boundary) values for [overlap,trim_internal]()
            N_stack = x.ndim - self.dim_rank
            depth = {ax: 0 for ax in range(x.ndim)}
            for ax in range(self.dim_rank):
                if len(stencils) == 1:  # non-separable filter: a single D-dimensional kernel
                    st = stencils[0]
                    n = st._kernel.shape[ax]
                else:  # separable filter(s): one kernel per axis
                    st = stencils[ax]
                    n = st._kernel.size
                # Read the center from the same stencil as `n`: indexing `stencils[ax]` unconditionally
                # raises IndexError for non-separable kernels of rank >= 2.
                c = st._center[ax]
                max_dist = max(c, n - c)
                depth[N_stack + ax] = max_dist
            boundary = 0

            xp = ndi.module()
            x_overlap = xp.overlap.overlap(  # Share padding between chunks
                x,
                depth=depth,
                boundary=boundary,
            )
            z_overlap = x_overlap.map_blocks(  # Map _chain() to each chunk
                func=_chain,
                dtype=x.dtype,
                chunks=x_overlap.chunks,
                meta=x._meta,
                # extra _chain() kwargs -------------------
                stencils=stencils,
                fft_kwargs=self._fft_kwargs,
            )
            z = xp.overlap.trim_internal(  # Trim inter-chunk excess
                z_overlap,
                axes=depth,
                boundary=boundary,
            )
        else:
            z = _chain(x, stencils, self._fft_kwargs)
        return z
class FFTConvolve(FFTCorrelate):
    r"""
    Multi-dimensional FFT-based convolution.

    :py:class:`~pyxu.operator.FFTConvolve` has the same interface as :py:class:`~pyxu.operator.Convolve`.

    Implementation notes are shared with :py:class:`~pyxu.operator.FFTCorrelate`; see that class for details.

    See Also
    --------
    :py:class:`~pyxu.operator.Stencil`,
    :py:class:`~pyxu.operator.FFTCorrelate`
    """

    def __init__(
        self,
        dim_shape: pxt.NDArrayShape,
        kernel: Stencil.KernelSpec,
        center: _Stencil.IndexSpec,
        mode: Pad.ModeSpec = "constant",
        enable_warnings: bool = True,
        **kwargs,
    ):
        r"""
        See :py:meth:`~pyxu.operator.FFTCorrelate.__init__` for a description of the arguments.
        """
        super().__init__(
            dim_shape=dim_shape,
            kernel=kernel,
            center=center,
            mode=mode,
            enable_warnings=enable_warnings,
            **kwargs,
        )

        # Reuse the correlation machinery with the forward/backward kernel sets (and their centers)
        # exchanged: what FFTCorrelate uses in adjoint() becomes apply(), and vice-versa.
        correlate_fw, correlate_bw = self._st_fw, self._st_bw
        self._st_fw = correlate_bw
        self._st_bw = correlate_fw