1import dataclasses
2
3import numpy as np
4
5import pyxu.info.deps as pxd
6import pyxu.info.ptype as pxt
7import pyxu.util as pxu
8from pyxu.operator.linop.fft.fft import FFT
9from pyxu.operator.linop.pad import Pad
10from pyxu.operator.linop.stencil._stencil import _Stencil
11from pyxu.operator.linop.stencil.stencil import Stencil
12
# Names exported by ``from <module> import *``.
__all__ = ["FFTCorrelate", "FFTConvolve"]
17
18
# Lightweight record pairing a kernel's coefficients (`_kernel`) with the index of its center (`_center`).
KernelInfo = dataclasses.make_dataclass(
    "KernelInfo",
    fields=["_kernel", "_center"],
)
# Before Python 3.12, make_dataclass() leaves `__module__` set to "types", so pickle cannot locate the class and
# instances cannot be serialized. Pin it to this module so KernelInfo objects are picklable on every supported
# interpreter.
KernelInfo.__module__ = __name__
23
24
[docs]
25class FFTCorrelate(Stencil):
26 r"""
27 Multi-dimensional FFT-based correlation.
28
29 :py:class:`~pyxu.operator.FFTCorrelate` has the same interface as :py:class:`~pyxu.operator.Stencil`.
30
31 .. rubric:: Implementation Notes
32
33 * :py:class:`~pyxu.operator.FFTCorrelate` can scale to much larger kernels than :py:class:`~pyxu.operator.Stencil`.
34 * This implementation is most efficient with "constant" boundary conditions (default).
35 * Kernels must be small enough to fit in memory, i.e. unbounded kernels are not allowed.
36 * Kernels should be supplied an NUMPY/CUPY arrays. DASK arrays will be evaluated if provided.
37 * :py:class:`~pyxu.operator.FFTCorrelate` instances are **not arraymodule-agnostic**: they will only work with
38 NDArrays belonging to the same array module as `kernel`, or DASK arrays where the chunk backend matches the
39 init-supplied kernel backend.
40 * A warning is emitted if inputs must be cast to the kernel dtype.
41 * The input array is transformed by calling :py:class:`~pyxu.operator.FFT`.
42 * When operating on DASK inputs, the kernel DFT is computed per chunk at the best size to handle the inputs.
43 This was deemed preferable than pre-computing a huge kernel DFT once, then sending it to each worker process to
44 compute its chunk.
45
46 See Also
47 --------
48 :py:class:`~pyxu.operator.Stencil`,
49 :py:class:`~pyxu.operator.FFTConvolve`
50 """
51
[docs]
52 def __init__(
53 self,
54 dim_shape: pxt.NDArrayShape,
55 kernel: Stencil.KernelSpec,
56 center: _Stencil.IndexSpec,
57 mode: Pad.ModeSpec = "constant",
58 enable_warnings: bool = True,
59 **kwargs,
60 ):
61 r"""
62 Parameters
63 ----------
64 dim_shape: NDArrayShape
65 (M1,...,MD) input dimensions.
66 kernel: ~pyxu.operator.Stencil.KernelSpec
67 Kernel coefficients. Two forms are accepted:
68
69 * NDArray of rank-:math:`D`: denotes a non-seperable kernel.
70 * tuple[NDArray_1, ..., NDArray_D]: a sequence of 1D kernels such that dimension[k] is filtered by kernel
71 `kernel[k]`, that is:
72
73 .. math::
74
75 k = k_1 \otimes\cdots\otimes k_D,
76
77 or in Python: ``k = functools.reduce(numpy.multiply.outer, kernel)``.
78
79 center: ~pyxu.operator._Stencil.IndexSpec
80 (i1,...,iD) index of the kernel's center.
81
82 `center` defines how a kernel is overlaid on inputs to produce outputs.
83
84 mode: str, :py:class:`list` ( str )
85 Boundary conditions. Multiple forms are accepted:
86
87 * str: unique mode shared amongst dimensions.
88 Must be one of:
89
90 * 'constant' (zero-padding)
91 * 'wrap'
92 * 'reflect'
93 * 'symmetric'
94 * 'edge'
95 * tuple[str, ...]: dimension[k] uses `mode[k]` as boundary condition.
96
97 (See :py:func:`numpy.pad` for details.)
98 enable_warnings: bool
99 If ``True``, emit a warning in case of precision mis-match issues.
100 kwargs: dict
101 Extra kwargs forwarded to :py:class:`~pyxu.operator.FFT`.
102 """
103 super().__init__(
104 dim_shape=dim_shape,
105 kernel=kernel,
106 center=center,
107 mode=mode,
108 enable_warnings=enable_warnings,
109 )
110 self._fft_kwargs = kwargs # Extra kwargs passed to FFT()
111
114
115 # Helper routines (internal) ----------------------------------------------
116 @staticmethod
117 def _compute_pad_width(_kernel, _center, _mode) -> Pad.WidthSpec:
118 N = _kernel[0].ndim
119 pad_width = [None] * N
120 for i in range(N):
121 if _mode[i] == "constant":
122 # FFT already implements padding with zeros to size N+K-1.
123 pad_width[i] = (0, 0)
124 else:
125 if len(_kernel) == 1: # non-seperable filter
126 n = _kernel[0].shape[i]
127 else: # seperable filter(s)
128 n = _kernel[i].size
129
130 # 1. Pad/Trim operators are shared amongst [apply,adjoint]():
131 # lhs/rhs are thus padded equally.
132 # 2. Pad width must match kernel dimensions to retain border effects.
133 pad_width[i] = (n - 1, n - 1)
134 return tuple(pad_width)
135
136 @staticmethod
137 def _init_fw(_kernel, _center) -> list:
138 # Initialize kernels used in apply().
139 # The returned objects must have the following fields:
140 # * _kernel: ndarray[float] (D,)
141 # * _center: ndarray[int] (D,)
142
143 # Store kernels in convolution form.
144 _st_fw = [None] * len(_kernel)
145 _kernel, _center = Stencil._bw_equivalent(_kernel, _center)
146 for i, (k_fw, c_fw) in enumerate(zip(_kernel, _center)):
147 _st_fw[i] = KernelInfo(k_fw, c_fw)
148 return _st_fw
149
150 @staticmethod
151 def _init_bw(_kernel, _center) -> list:
152 # Initialize kernels used in adjoint().
153 # The returned objects must have the following fields:
154 # * _kernel: ndarray[float] (D,)
155 # * _center: ndarray[int] (D,)
156
157 # Store kernels in convolution form.
158 _st_bw = [None] * len(_kernel)
159 for i, (k_bw, c_bw) in enumerate(zip(_kernel, _center)):
160 _st_bw[i] = KernelInfo(k_bw, c_bw)
161 return _st_bw
162
163 def _stencil_chain(self, x: pxt.NDArray, stencils: list) -> pxt.NDArray:
164 # Apply sequence of stencils to `x`.
165 #
166 # x: (..., M1,...,MD)
167 # z: (..., M1,...,MD)
168
169 # Contrary to Stencil._stencil_chain(), the `stencils` parameter is picklable directly.
170 def _chain(x, stencils, fft_kwargs):
171 xp = pxu.get_array_module(x)
172 xpf = FFT.fft_backend(xp)
173
174 # Compute constants -----------------------------------------------
175 if uni_kernel := (len(stencils) == 1):
176 M = np.r_[stencils[0]._kernel.shape]
177 dim_rank = stencils[0]._kernel.ndim
178 else:
179 M = np.array([st._kernel.size for st in stencils])
180 dim_rank = len(stencils)
181 Np = np.r_[x.shape[-dim_rank:]]
182 L = FFT.next_fast_len(Np + M - 1, xp=xp)
183 axes = tuple(range(-dim_rank, 0))
184
185 # Apply stencils in DFT domain ------------------------------------
186 fft = FFT(L, axes, **fft_kwargs)
187 Z = fft.capply(x)
188 if uni_kernel:
189 if len(M) == 1: # 1D kernel
190 K = xpf.fft(stencils[0]._kernel, n=L[0])
191 else: # ND kernel
192 K = fft.capply(stencils[0]._kernel)
193 Z *= K
194 else:
195 for ax, st in enumerate(stencils):
196 K = xpf.fft(st._kernel, n=L[ax], axis=ax)
197 Z *= K
198 z = fft.cpinv(Z, damp=0).real
199
200 # Extract ROI -----------------------------------------------------
201 if uni_kernel:
202 center = stencils[0]._center
203 else:
204 center = [st._center[i] for (i, st) in enumerate(stencils)]
205 extract = [slice(c, c + n) for (c, n) in zip(center, Np)]
206 return z[..., *extract]
207
208 ndi = pxd.NDArrayInfo.from_obj(x)
209 if ndi == pxd.NDArrayInfo.DASK:
210 # Compute (depth,boundary) values for [overlap,trim_internal]()
211 N_stack = x.ndim - self.dim_rank
212 depth = {ax: 0 for ax in range(x.ndim)}
213 for ax in range(self.dim_rank):
214 if len(stencils) == 1: # non-seperable filter
215 n = stencils[0]._kernel.shape[ax]
216 else: # seperable filter(s)
217 n = stencils[ax]._kernel.size
218 c = stencils[ax]._center[ax]
219 max_dist = max(c, n - c)
220 depth[N_stack + ax] = max_dist
221 boundary = 0
222
223 xp = ndi.module()
224 x_overlap = xp.overlap.overlap( # Share padding between chunks
225 x,
226 depth=depth,
227 boundary=boundary,
228 )
229 z_overlap = x_overlap.map_blocks( # Map _chain() to each chunk
230 func=_chain,
231 dtype=x.dtype,
232 chunks=x_overlap.chunks,
233 meta=x._meta,
234 # extra _chain() kwargs -------------------
235 stencils=stencils,
236 fft_kwargs=self._fft_kwargs,
237 )
238 z = xp.overlap.trim_internal( # Trim inter-chunk excess
239 z_overlap,
240 axes=depth,
241 boundary=boundary,
242 )
243 else:
244 z = _chain(x, stencils, self._fft_kwargs)
245 return z
246
247
[docs]
class FFTConvolve(FFTCorrelate):
    r"""
    Multi-dimensional FFT-based convolution.

    :py:class:`~pyxu.operator.FFTConvolve` has the same interface as :py:class:`~pyxu.operator.Convolve`.

    Implementation notes are given in :py:class:`~pyxu.operator.FFTCorrelate`.

    See Also
    --------
    :py:class:`~pyxu.operator.Stencil`,
    :py:class:`~pyxu.operator.FFTCorrelate`
    """

    def __init__(
        self,
        dim_shape: pxt.NDArrayShape,
        kernel: Stencil.KernelSpec,
        center: _Stencil.IndexSpec,
        mode: Pad.ModeSpec = "constant",
        enable_warnings: bool = True,
        **kwargs,
    ):
        r"""
        Arguments are described in :py:meth:`~pyxu.operator.FFTCorrelate.__init__`.
        """
        super().__init__(
            dim_shape=dim_shape,
            kernel=kernel,
            center=center,
            mode=mode,
            enable_warnings=enable_warnings,
            **kwargs,
        )

        # Exchange the roles of the forward/backward kernels (and their centers) set up by the parent class.
        correlate_fw, correlate_bw = self._st_fw, self._st_bw
        self._st_fw = correlate_bw
        self._st_bw = correlate_fw