1import collections.abc as cabc
2import functools
3import operator
4import typing as typ
5import warnings
6
7import numpy as np
8
9import pyxu.abc as pxa
10import pyxu.info.deps as pxd
11import pyxu.info.ptype as pxt
12import pyxu.info.warning as pxw
13import pyxu.runtime as pxrt
14import pyxu.util as pxu
15from pyxu.operator.linop.pad import Pad
16from pyxu.operator.linop.select import Trim
17from pyxu.operator.linop.stencil._stencil import _Stencil
18
# Public API of this module: the stencil operator, its correlation alias, and the convolution variant.
__all__ = ["Stencil", "Correlate", "Convolve"]
24
25
[docs]
26class Stencil(pxa.SquareOp):
27 r"""
28 Multi-dimensional JIT-compiled stencil.
29
30 Stencils are a common computational pattern in which array elements are updated according to some fixed pattern
31 called the *stencil kernel*. Notable examples include multi-dimensional convolutions, correlations and finite
32 differences. (See Notes for a definition.)
33
34 Stencils can be evaluated efficiently on CPU/GPU architectures.
35
36 Several boundary conditions are supported. Moreover boundary conditions may differ per axis.
37
38 .. rubric:: Implementation Notes
39
40 * Numba (and its ``@stencil`` decorator) is used to compile efficient machine code from a stencil kernel
41 specification. This has 2 consequences:
42
43 * :py:class:`~pyxu.operator.Stencil` instances are **not arraymodule-agnostic**: they will only work with NDArrays
44 belonging to the same array module as `kernel`.
45 * Compiled stencils are not **precision-agnostic**: they will only work on NDArrays with the same dtype as
46 `kernel`. A warning is emitted if inputs must be cast to the kernel dtype.
47
48 * Stencil kernels can be specified in two forms: (See :py:meth:`~pyxu.operator.Stencil.__init__` for details.)
49
50 * A single non-separable :math:`D`-dimensional kernel :math:`k[i_{1},\ldots,i_{D}]` of shape
51 :math:`(K_{1},\ldots,K_{D})`.
52 * A sequence of separable :math:`1`-dimensional kernel(s) :math:`k_{d}[i]` of shapes
53 :math:`(K_{1},),\ldots,(K_{D},)` such that :math:`k[i_{1},\ldots,i_{D}] = \Pi_{d=1}^{D} k_{d}[i_{d}]`.
54
55 .. rubric:: Mathematical Notes
56
57 Given a :math:`D`-dimensional array :math:`x\in\mathbb{R}^{N_1\times\cdots\times N_D}` and kernel
58 :math:`k\in\mathbb{R}^{K_1\times\cdots\times K_D}` with center :math:`(c_1, \ldots, c_D)`, the output of the stencil
59 operator is an array :math:`y\in\mathbb{R}^{N_1\times\cdots\times N_D}` given by:
60
61 .. math::
62
63 y[i_{1},\ldots,i_{D}]
64 =
        \sum_{q_{1},\ldots,q_{D}=0}^{K_{1}-1,\ldots,K_{D}-1}
66 x[i_{1} - c_{1} + q_{1},\ldots,i_{D} - c_{D} + q_{D}]
67 \,\cdot\,
68 k[q_{1},\ldots,q_{D}].
69
70 This corresponds to a *correlation* with a shifted version of the kernel :math:`k`.
71
72 Numba stencils assume summation terms involving out-of-bound indices of :math:`x` are set to zero.
73 :py:class:`~pyxu.operator.Stencil` lifts this constraint by extending the stencil to boundary values via pre-padding
74 and post-trimming. Concretely, any stencil operator :math:`S` instantiated with :py:class:`~pyxu.operator.Stencil`
75 can be written as the composition :math:`S = TS_0P`, where :math:`(T, S_0, P)` are trimming, stencil with
76 zero-padding conditions, and padding operators respectively. This construct allows
77 :py:class:`~pyxu.operator.Stencil` to handle complex boundary conditions under which :math:`S` *may not be a proper
78 stencil* (i.e., varying kernel) but can still be implemented efficiently via a proper stencil upon appropriate
79 trimming/padding.
80
81 For example consider the decomposition of the following (improper) stencil operator:
82
83 .. code-block:: python3
84
85 >>> S = Stencil(
86 ... dim_shape=(5,),
87 ... kernel=np.r_[1, 2, -3],
88 ... center=(2,),
89 ... mode="reflect",
90 ... )
91
92 >>> S.asarray()
93 [[-3 2 1 0 0]
94 [ 2 -2 0 0 0]
95 [ 1 2 -3 0 0]
96 [ 0 1 2 -3 0]
97 [ 0 0 1 2 -3]] # Improper stencil (kernel varies across rows)
98 =
99 [[0 0 1 0 0 0 0 0 0]
100 [0 0 0 1 0 0 0 0 0]
101 [0 0 0 0 1 0 0 0 0]
102 [0 0 0 0 0 1 0 0 0]
103 [0 0 0 0 0 0 1 0 0]] # Trimming
104 @
105 [[-3 0 0 0 0 0 0 0 0]
106 [ 2 -3 0 0 0 0 0 0 0]
107 [ 1 2 -3 0 0 0 0 0 0]
108 [ 0 1 2 -3 0 0 0 0 0]
109 [ 0 0 1 2 -3 0 0 0 0]
110 [ 0 0 0 1 2 -3 0 0 0]
111 [ 0 0 0 0 1 2 -3 0 0]
112 [ 0 0 0 0 0 1 2 -3 0]
113 [ 0 0 0 0 0 0 1 2 -3]] # Proper stencil (Toeplitz structure)
114 @
115 [[0 0 1 0 0]
116 [0 1 0 0 0]
117 [1 0 0 0 0]
118 [0 1 0 0 0]
119 [0 0 1 0 0]
120 [0 0 0 1 0]
121 [0 0 0 0 1]
122 [0 0 0 1 0]
123 [0 0 1 0 0]] # Padding with reflect mode.
124
125 Note that the adjoint of a stencil operator may not necessarily be a stencil operator, or the associated center and
126 boundary conditions may be hard to predict. For example, the adjoint of the stencil operator defined above is given
127 by:
128
129 .. code-block:: python3
130
131 >>> S.T.asarray()
132 [[-3 2 1 0 0],
133 [ 2 -2 2 1 0],
134 [ 1 0 -3 2 1],
135 [ 0 0 0 -3 2],
136 [ 0 0 0 0 -3]]
137
138 which resembles a stencil with time-reversed kernel, but with weird (if not improper) boundary conditions. This can
139 also be seen from the fact that :math:`S^\ast = P^\ast S_0^\ast T^\ast = P^\ast S_0^\ast P_0,` and :math:`P^\ast` is
140 in general not a trimming operator. (See :py:class:`~pyxu.operator.Pad`.)
141
    The same holds for gram/cogram operators. Consider indeed the following order-1 forward finite-difference operator
143 with zero-padding:
144
145 .. code-block:: python3
146
147 >>> S = Stencil(
148 ... dim_shape=(5,),
149 ... kernel=np.r_[-1, 1],
150 ... center=(0,),
151 ... mode='constant',
152 ... )
153
154 >>> S.gram().asarray()
155 [[ 1 -1 0 0 0]
156 [-1 2 -1 0 0]
157 [ 0 -1 2 -1 0]
158 [ 0 0 -1 2 -1]
159 [ 0 0 0 -1 2]]
160
161 We observe that the Gram differs from the order 2 centered finite-difference operator. (Reduced-order derivative on
162 one side.)
163
164 Example
165 -------
166
167 * **Moving average of a 1D signal**
168
169 Let :math:`x[n]` denote a 1D signal. The weighted moving average
170
171 .. math::
172
173 y[n] = x[n-2] + 2 x[n-1] + 3 x[n]
174
175 can be viewed as the output of the 3-point stencil of kernel :math:`k = [1, 2, 3]`.
176
177 .. code-block:: python3
178
179 import numpy as np
180 from pyxu.operator import Stencil
181
182 x = np.arange(10) # [0 1 2 3 4 5 6 7 8 9]
183
184 op = Stencil(
185 dim_shape=x.shape,
186 kernel=np.array([1, 2, 3]),
187 center=(2,), # k[2] applies on x[n]
188 )
189
190 y = op.apply(x) # [0 3 8 14 20 26 32 38 44 50]
191
192
    * **Non-separable image filtering**
194
195 Let :math:`x[n, m]` denote a 2D image. The blurred image
196
197 .. math::
198
199 y[n, m] = 2 x[n-1,m-1] + 3 x[n-1,m+1] + 4 x[n+1,m-1] + 5 x[n+1,m+1]
200
201 can be viewed as the output of the 9-point stencil
202
203 .. math::
204
205 k =
206 \left[
207 \begin{array}{ccc}
208 2 & 0 & 3 \\
209 0 & 0 & 0 \\
210 4 & 0 & 5
211 \end{array}
212 \right].
213
214 .. code-block:: python3
215
216 import numpy as np
217 from pyxu.operator import Stencil
218
219 x = np.arange(64).reshape(8, 8) # square image
220 # [[ 0 1 2 3 4 5 6 7]
221 # [ 8 9 10 11 12 13 14 15]
222 # [16 17 18 19 20 21 22 23]
223 # [24 25 26 27 28 29 30 31]
224 # [32 33 34 35 36 37 38 39]
225 # [40 41 42 43 44 45 46 47]
226 # [48 49 50 51 52 53 54 55]
227 # [56 57 58 59 60 61 62 63]]
228
229 op = Stencil(
230 dim_shape=x.shape,
231 kernel=np.array(
232 [[2, 0, 3],
233 [0, 0, 0],
234 [4, 0, 5]]),
235 center=(1, 1), # k[1, 1] applies on x[n, m]
236 )
237
238 y = op.apply(x)
239 # [[ 45 82 91 100 109 118 127 56 ]
240 # [ 88 160 174 188 202 216 230 100 ]
241 # [152 272 286 300 314 328 342 148 ]
242 # [216 384 398 412 426 440 454 196 ]
243 # [280 496 510 524 538 552 566 244 ]
244 # [344 608 622 636 650 664 678 292 ]
245 # [408 720 734 748 762 776 790 340 ]
246 # [147 246 251 256 261 266 271 108 ]]
247
    * **Separable image filtering**
249
      Let :math:`x[n, m]` denote a 2D image. The filtered image
251
252 .. math::
253
254 \begin{align*}
255 y[n, m] = & + 4 x[n-1,m-1] + 5 x[n-1,m] + 6 x[n-1,m+1] \\
256 & + 8 x[n ,m-1] + 10 x[n ,m] + 12 x[n ,m+1] \\
257 & + 12 x[n+1,m-1] + 15 x[n+1,m] + 18 x[n+1,m+1]
258 \end{align*}
259
260 can be viewed as the output of the 9-point stencil
261
262 .. math::
263
264 k_{2D} =
265 \left[
266 \begin{array}{ccc}
267 4 & 5 & 6 \\
268 8 & 10 & 12 \\
269 12 & 15 & 18 \\
270 \end{array}
271 \right].
272
273 Notice however that :math:`y[n, m]` can be implemented more efficiently by factoring the 9-point stencil as a
274 cascade of two 3-point stencils:
275
276 .. math::
277
278 k_{2D}
279 = k_{1} k_{2}^{T}
280 = \left[
281 \begin{array}{c}
282 1 \\ 2 \\ 3
283 \end{array}
284 \right]
285 \left[
            \begin{array}{ccc}
287 4 & 5 & 6
288 \end{array}
289 \right].
290
      Separable stencils are supported and should be preferred when applicable.
292
293 .. code-block:: python3
294
295 import numpy as np
296 from pyxu.operator import Stencil
297
298 x = np.arange(64).reshape(8, 8) # square image
299 # [[ 0 1 2 3 4 5 6 7]
300 # [ 8 9 10 11 12 13 14 15]
301 # [16 17 18 19 20 21 22 23]
302 # [24 25 26 27 28 29 30 31]
303 # [32 33 34 35 36 37 38 39]
304 # [40 41 42 43 44 45 46 47]
305 # [48 49 50 51 52 53 54 55]
306 # [56 57 58 59 60 61 62 63]]
307
          op_2D = Stencil(  # using non-separable kernel
309 dim_shape=x.shape,
310 kernel=np.array(
311 [[ 4, 5, 6],
312 [ 8, 10, 12],
313 [12, 15, 18]]),
314 center=(1, 1), # k[1, 1] applies on x[n, m]
315 )
          op_sep = Stencil(  # using separable kernels
317 dim_shape=x.shape,
318 kernel=[
319 np.array([1, 2, 3]), # k1: stencil along 1st axis
320 np.array([4, 5, 6]), # k2: stencil along 2nd axis
321 ],
322 center=(1, 1), # k1[1] * k2[1] applies on x[n, m]
323 )
324
325 y_2D = op_2D.apply(x)
326 y_sep = op_sep.apply(x) # np.allclose(y_2D, y_sep) -> True
327 # [[ 294 445 520 595 670 745 820 511 ]
328 # [ 740 1062 1152 1242 1332 1422 1512 930 ]
329 # [1268 1782 1872 1962 2052 2142 2232 1362 ]
330 # [1796 2502 2592 2682 2772 2862 2952 1794 ]
331 # [2324 3222 3312 3402 3492 3582 3672 2226 ]
332 # [2852 3942 4032 4122 4212 4302 4392 2658 ]
333 # [3380 4662 4752 4842 4932 5022 5112 3090 ]
334 # [1778 2451 2496 2541 2586 2631 2676 1617 ]]
335
336 .. Warning::
337
338 For large, non-separable kernels, stencil compilation can be time-consuming. Depending on your computer's
        architecture, using the :py:class:`~pyxu.operator.FFTCorrelate` operator might offer a more efficient solution.
340 However, the performance improvement varies, so we recommend evaluating this alternative in your specific
341 environment.
342
343 See Also
344 --------
345 :py:class:`~pyxu.operator.Convolve`,
346 :py:class:`~pyxu.operator._Stencil`,
347 :py:class:`~pyxu.operator.FFTCorrelate`,
348 :py:class:`~pyxu.operator.FFTConvolve`
349 """
350
351 KernelSpec = typ.Union[
352 pxt.NDArray, # (k1,...,kD) non-seperable kernel
353 cabc.Sequence[pxt.NDArray], # [(k1,), ..., (kD,)] seperable kernels
354 ]
355
[docs]
356 def __init__(
357 self,
358 dim_shape: pxt.NDArrayShape,
359 kernel: KernelSpec,
360 center: _Stencil.IndexSpec,
361 mode: Pad.ModeSpec = "constant",
362 enable_warnings: bool = True,
363 ):
364 r"""
365 Parameters
366 ----------
367 dim_shape: NDArrayShape
368 (M1,...,MD) input dimensions.
369 kernel: ~pyxu.operator.Stencil.KernelSpec
370 Stencil coefficients. Two forms are accepted:
371
372 * NDArray of rank-:math:`D`: denotes a non-seperable stencil.
373 * tuple[NDArray_1, ..., NDArray_D]: a sequence of 1D stencils such that dimension[k] is filtered by stencil
374 `kernel[k]`, that is:
375
376 .. math::
377
378 k = k_1 \otimes\cdots\otimes k_D,
379
380 or in Python: ``k = functools.reduce(numpy.multiply.outer, kernel)``.
381
382 center: ~pyxu.operator._Stencil.IndexSpec
383 (i1,...,iD) index of the stencil's center.
384
385 `center` defines how a kernel is overlaid on inputs to produce outputs.
386
387 mode: str, :py:class:`list` ( str )
388 Boundary conditions. Multiple forms are accepted:
389
390 * str: unique mode shared amongst dimensions.
391 Must be one of:
392
393 * 'constant' (zero-padding)
394 * 'wrap'
395 * 'reflect'
396 * 'symmetric'
397 * 'edge'
398 * tuple[str, ...]: dimension[k] uses `mode[k]` as boundary condition.
399
400 (See :py:func:`numpy.pad` for details.)
401 enable_warnings: bool
402 If ``True``, emit a warning in case of precision mis-match issues.
403 """
404 super().__init__(
405 dim_shape=dim_shape,
406 codim_shape=dim_shape,
407 )
408 _kernel, _center, _mode = self._canonical_repr(self.dim_shape, kernel, center, mode)
409
410 # Pad/Trim operators
411 pad_width = self._compute_pad_width(_kernel, _center, _mode)
412 self._pad = Pad(
413 dim_shape=dim_shape,
414 pad_width=pad_width,
415 mode=_mode,
416 )
417 self._trim = Trim(
418 dim_shape=self._pad.codim_shape,
419 trim_width=pad_width,
420 )
421
422 # Kernels (These _Stencil() instances are not used as-is in apply/adjoint calls: their ._[kernel,center]
423 # attributes are used directly there instead to bypass Numba serialization limits. These _Stencil() objects are
424 # used however for all other Operator public methods.)
425 # It is moreover advantageous to instantiate them once here to cache JIT-compile kernels upfront.
426 self._st_fw = self._init_fw(_kernel, _center)
427 self._st_bw = self._init_bw(_kernel, _center)
428
429 self._dispatch_params = dict() # Extra kwargs passed to _Stencil.apply()
430 self._dtype = _kernel[0].dtype # useful constant
431 self._enable_warnings = bool(enable_warnings)
432
433 # We know a crude Lipschitz bound by default. Since computing it takes (code) space,
434 # the estimate is computed as a special case of estimate_lipschitz()
435 self.lipschitz = self.estimate_lipschitz(__rule=True)
436
437 def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
438 x = self._pad.apply(arr)
439 y = self._stencil_chain(
440 x=self._cast_warn(x),
441 stencils=self._st_fw,
442 )
443 z = self._trim.apply(y)
444 return z
445
446 def adjoint(self, arr: pxt.NDArray) -> pxt.NDArray:
447 x = self._trim.adjoint(arr)
448 y = self._stencil_chain(
449 x=self._cast_warn(x),
450 stencils=self._st_bw,
451 )
452 z = self._pad.adjoint(y)
453 return z
454
488
489 def estimate_lipschitz(self, **kwargs) -> pxt.Real:
490 no_eval = "__rule" in kwargs
491 if no_eval:
492 # Analytic upper bound from Young's convolution inequality:
493 # \norm{x \ast h}{2} \le \norm{x}{2}\norm{h}{1}
494 #
495 # -> L \le \norm{h}{1}
496 kernels = [st._kernel for st in self._st_fw]
497 kernel = functools.reduce(operator.mul, kernels, 1)
498 L_st = np.linalg.norm(pxu.to_NUMPY(kernel).reshape(-1), ord=1)
499
500 L_pad = self._pad.lipschitz
501 L_trim = self._trim.lipschitz
502
503 L = L_trim * L_st * L_pad # upper bound
504 else:
505 L = super().estimate_lipschitz(**kwargs)
506 return L
507
508 def asarray(self, **kwargs) -> pxt.NDArray:
509 # Stencil.apply() prefers precision provided at init-time.
510 xp = pxu.get_array_module(self._st_fw[0]._kernel)
511 _A = super().asarray(xp=xp, dtype=self._dtype)
512
513 xp = kwargs.get("xp", pxd.NDArrayInfo.NUMPY.module())
514 dtype = kwargs.get("dtype", pxrt.Width.DOUBLE.value)
515 A = xp.array(pxu.to_NUMPY(_A), dtype=dtype)
516 return A
517
518 def trace(self, **kwargs) -> pxt.Real:
519 if all(m == "constant" for m in self._pad._mode):
520 # tr = (kernel center coefficient) * dim_size
521 tr = functools.reduce(
522 operator.mul,
523 [st._kernel[tuple(st._center)] for st in self._st_fw],
524 1,
525 )
526 tr *= self.dim_size
527 else:
528 # Standard algorithm, with computations restricted to precision supported by
529 # Stencil.apply().
530 kwargs.update(dtype=self._dtype)
531 tr = super().trace(**kwargs)
532 return float(tr)
533
534 # Helper routines (public) ------------------------------------------------
535 @property
536 def kernel(self) -> KernelSpec:
537 r"""
538 Stencil kernel coefficients.
539
540 Returns
541 -------
542 kern: ~pyxu.operator.Stencil.KernelSpec
543 Stencil coefficients.
544
545 If the kernel is non-seperable, a single array is returned.
546 Otherwise :math:`D` arrays are returned, one per axis.
547 """
548 if len(self._st_fw) == 1:
549 kern = self._st_fw[0]._kernel
550 else:
551 kern = [st._kernel for st in self._st_fw]
552 return kern
553
554 @property
555 def center(self) -> _Stencil.IndexSpec:
556 """
557 Stencil central position.
558
559 Returns
560 -------
561 ctr: ~pyxu.operator._Stencil.IndexSpec
562 Stencil central position.
563 """
564 if len(self._st_fw) == 1:
565 ctr = self._st_fw[0]._center
566 else:
567 ctr = [st._center[d] for (d, st) in enumerate(self._st_fw)]
568 return tuple(ctr)
569
[docs]
570 def visualize(self) -> str:
571 r"""
572 Show the :math:`D`-dimensional stencil kernel.
573
574 The stencil's center is identified by surrounding parentheses.
575
576 Example
577 -------
578 .. code-block:: python3
579
580 S = Stencil(
581 dim_shape=(5, 6),
582 kernel=[
583 np.r_[3, 2, 1],
584 np.r_[2, -1, 3, 1],
585 ],
586 center=(1, 2),
587 )
588 print(S.visualize()) # [[6.0 -3.0 9.0 3.0]
589 # [4.0 -2.0 (6.0) 2.0]
590 # [2.0 -1.0 3.0 1.0]]
591 """
592 kernels = [st._kernel for st in self._st_fw]
593 kernel = functools.reduce(operator.mul, kernels, 1)
594
595 kernel = pxu.to_NUMPY(kernel).astype(str)
596 kernel[self.center] = "(" + kernel[self.center] + ")"
597
598 kern = np.array2string(kernel).replace("'", "")
599 return kern
600
601 # Helper routines (internal) ----------------------------------------------
602 @staticmethod
603 def _canonical_repr(dim_shape, kernel, center, mode):
604 # Create canonical representations
605 # * `_kernel`: list[ndarray[float], ...]
606 # * `_center`: list[ndarray[int], ...]
607 # * `_mode`: list[str, ...]
608 #
609 # `dim_shape`` is already assumed in tuple-form.
610 N = len(dim_shape)
611 assert len(center) == N
612
613 kernel = pxu.compute(kernel, traverse=True)
614 try:
615 # array input -> non-seperable filter
616 pxu.get_array_module(kernel)
617 assert kernel.ndim == N
618 _kernel = [kernel]
619 _center = [np.array(center, dtype=int)]
620 except Exception:
621 # sequence input -> seperable filter(s)
622 assert len(kernel) == N # one filter per dimension
623
624 _kernel = [None] * N
625 for i in range(N):
626 sh = [1] * N
627 sh[i] = -1
628 _kernel[i] = kernel[i].reshape(sh)
629
630 _center = np.zeros((N, N), dtype=int)
631 _center[np.diag_indices(N)] = center
632
633 _mode = Pad( # get `mode` in canonical form
634 (3,) * _kernel[0].ndim,
635 pad_width=1,
636 mode=mode,
637 )._mode
638
639 return _kernel, _center, _mode
640
641 @staticmethod
642 def _compute_pad_width(_kernel, _center, _mode) -> Pad.WidthSpec:
643 N = _kernel[0].ndim
644 pad_width = [None] * N
645 for i in range(N):
646 if len(_kernel) == 1: # non-seperable filter
647 c = _center[0][i]
648 n = _kernel[0].shape[i]
649 else: # seperable filter(s)
650 c = _center[i][i]
651 n = _kernel[i].size
652
653 # 1. Pad/Trim operators are shared amongst [apply,adjoint]():
654 # lhs/rhs are thus padded equally.
655 # 2. When mode != "constant", pad width must match kernel dimensions to retain border
656 # effects.
657 if _mode[i] == "constant":
658 p = max(c, n - c - 1)
659 else: # anything else supported by Pad()
660 p = n - 1
661 pad_width[i] = (p, p)
662 return tuple(pad_width)
663
664 @staticmethod
665 def _init_fw(_kernel, _center) -> list:
666 # Initialize kernels used in apply().
667 # The returned objects must have the following fields:
668 # * _kernel: ndarray[float] (D,)
669 # * _center: ndarray[int] (D,)
670 _st_fw = [None] * len(_kernel)
671 for i, (k_fw, c_fw) in enumerate(zip(_kernel, _center)):
672 _st_fw[i] = _Stencil.init(kernel=k_fw, center=c_fw)
673 return _st_fw
674
675 @staticmethod
676 def _init_bw(_kernel, _center) -> list:
677 # Initialize kernels used in adjoint().
678 # The returned objects must have the following fields:
679 # * _kernel: ndarray[float] (D,)
680 # * _center: ndarray[int] (D,)
681 _st_bw = [None] * len(_kernel)
682 _kernel, _center = Stencil._bw_equivalent(_kernel, _center)
683 for i, (k_bw, c_bw) in enumerate(zip(_kernel, _center)):
684 _st_bw[i] = _Stencil.init(kernel=k_bw, center=c_bw)
685 return _st_bw
686
687 @staticmethod
688 def _bw_equivalent(_kernel, _center):
689 # Transform FW kernel/center specification to BW variant.
690 k_bw = [np.flip(k_fw) for k_fw in _kernel]
691
692 if len(_kernel) == 1: # non-seperable filter
693 c_bw = [(_kernel[0].shape - _center[0]) - 1]
694 else: # seperable filter(s)
695 N = _kernel[0].ndim
696 c_bw = np.zeros((N, N), dtype=int)
697 for i in range(N):
698 c_bw[i, i] = _kernel[i].shape[i] - _center[i][i] - 1
699 return k_bw, c_bw
700
701 def _stencil_chain(self, x: pxt.NDArray, stencils: list) -> pxt.NDArray:
702 # Apply sequence of stencils to `x`.
703 #
704 # x: (..., M1,...,MD)
705 # y: (..., M1,...,MD)
706
707 # _Stencil() instances cannot be serialized by Dask, so we pass around _[kernel,center] directly.
708 # _Stencil(kernel,center) was compiled in __init__() though, hence re-instantiating _Stencil() here is free.
709 kernel = [st._kernel for st in stencils]
710 center = [st._center for st in stencils]
711
712 def _chain(x, kernel, center, dispatch_params):
713 stencils = [_Stencil.init(k, c) for (k, c) in zip(kernel, center)]
714
715 xp = pxu.get_array_module(x)
716 if len(stencils) == 1:
717 x = xp.require(x, requirements="C")
718 y = x.copy()
719 else:
720 # [2023.04.17, Sepand]
721 # In-place updates of `x` breaks thread-safety of Stencil().
722 x, y = x.copy(), x.copy()
723
724 for st in stencils:
725 st.apply(x, y, **dispatch_params)
726 x, y = y, x
727 y = x
728 return y
729
730 ndi = pxd.NDArrayInfo.from_obj(x)
731 if ndi == pxd.NDArrayInfo.DASK:
732 stack_depth = (0,) * (x.ndim - self.dim_rank)
733 y = x.map_overlap(
734 _chain,
735 depth=stack_depth + self._pad._pad_width,
736 dtype=x.dtype,
737 meta=x._meta,
738 # extra _chain() kwargs -------------------
739 kernel=kernel,
740 center=center,
741 dispatch_params=self._dispatch_params,
742 )
743 else: # NUMPY/CUPY
744 y = _chain(x, kernel, center, self._dispatch_params)
745 return y
746
747 def _cast_warn(self, arr: pxt.NDArray) -> pxt.NDArray:
748 if arr.dtype == self._dtype:
749 out = arr
750 else:
751 if self._enable_warnings:
752 msg = "Computation may not be performed at the requested precision."
753 warnings.warn(msg, pxw.PrecisionWarning)
754 out = arr.astype(dtype=self._dtype)
755 return out
756
757
# Stencil evaluates a correlation (see its Mathematical Notes), hence this alias.
Correlate = Stencil  #: Alias of :py:class:`~pyxu.operator.Stencil`.
759
760
[docs]
class Convolve(Stencil):
    r"""
    Multi-dimensional JIT-compiled convolution.

    Inputs are convolved with the given kernel.

    Notes
    -----
    Given a :math:`D`-dimensional array :math:`x\in\mathbb{R}^{N_1 \times\cdots\times N_D}` and kernel
    :math:`k\in\mathbb{R}^{K_1 \times\cdots\times K_D}` with center :math:`(c_1, \ldots, c_D)`, the output of the
    convolution operator is an array :math:`y\in\mathbb{R}^{N_1 \times\cdots\times N_D}` given by:

    .. math::

       y[i_{1},\ldots,i_{D}]
       =
       \sum_{q_{1},\ldots,q_{D}=0}^{K_{1}-1,\ldots,K_{D}-1}
       x[i_{1} - q_{1} + c_{1},\ldots,i_{D} - q_{D} + c_{D}]
       \,\cdot\,
       k[q_{1},\ldots,q_{D}].

    The convolution is implemented via :py:class:`~pyxu.operator.Stencil`. To do so, the convolution kernel is
    transformed to the equivalent correlation kernel (substitute :math:`q_{d} \to K_{d} - 1 - q_{d}`):

    .. math::

       y[i_{1},\ldots,i_{D}]
       =
       \sum_{q_{1},\ldots,q_{D}=0}^{K_{1}-1,\ldots,K_{D}-1}
       &x[i_{1} + q_{1} - (K_{1} - 1 - c_{1}),\ldots,i_{D} + q_{D} - (K_{D} - 1 - c_{D})] \\
       &\cdot\,
       k[K_{1} - 1 - q_{1},\ldots,K_{D} - 1 - q_{D}].

    This corresponds to a correlation with a flipped kernel, whose center is
    :math:`(K_{1} - 1 - c_{1}, \ldots, K_{D} - 1 - c_{D})`.

    As a consequence, the :py:attr:`~pyxu.operator.Stencil.kernel` and :py:attr:`~pyxu.operator.Stencil.center`
    properties of a :py:class:`~pyxu.operator.Convolve` instance report the flipped (correlation) kernel and center,
    i.e. the coefficients actually applied by :py:meth:`~pyxu.operator.Stencil.apply`.

    .. Warning::

        For large, non-separable kernels, stencil compilation can be time-consuming. Depending on your computer's
        architecture, using the :py:class:`~pyxu.operator.FFTConvolve` operator might offer a more efficient solution.
        However, the performance improvement varies, so we recommend evaluating this alternative in your specific
        environment.

    Examples
    --------
    .. code-block:: python3

       import numpy as np
       from scipy.ndimage import convolve
       from pyxu.operator import Convolve

       x = np.array([
            [1, 2, 0, 0],
            [5, 3, 0, 4],
            [0, 0, 0, 7],
            [9, 3, 0, 0],
       ])
       k = np.array([
            [1, 1, 1],
            [1, 1, 0],
            [1, 0, 0],
       ])
       op = Convolve(
           dim_shape=x.shape,
           kernel=k,
           center=(1, 1),
           mode="constant",
       )

       y_op = op.apply(x)
       y_sp = convolve(x, k, mode="constant", origin=0)  # np.allclose(y_op, y_sp) -> True
       # [[11 10  7  4],
       #  [10  3 11 11],
       #  [15 12 14  7],
       #  [12  3  7  0]]

    See Also
    --------
    :py:class:`~pyxu.operator.Stencil`
    """

    def __init__(
        self,
        dim_shape: pxt.NDArrayShape,
        kernel: Stencil.KernelSpec,
        center: _Stencil.IndexSpec,
        mode: Pad.ModeSpec = "constant",
        enable_warnings: bool = True,
    ):
        r"""
        See :py:meth:`~pyxu.operator.Stencil.__init__` for a description of the arguments.
        """
        super().__init__(
            dim_shape=dim_shape,
            kernel=kernel,
            center=center,
            mode=mode,
            enable_warnings=enable_warnings,
        )

        # Convolution == correlation with the flipped kernel & center: Stencil.__init__() built the flipped variant
        # as the backward (adjoint) stencils, so swapping FW/BW yields the convolution and its adjoint. Pad widths and
        # the Lipschitz bound computed by Stencil.__init__() are invariant under this flip.
        self._st_fw, self._st_bw = self._st_bw, self._st_fw