Source code for pyxu.operator.linop.stencil.stencil

import collections.abc as cabc
import functools
import operator
import typing as typ
import warnings

import numpy as np

import pyxu.abc as pxa
import pyxu.info.deps as pxd
import pyxu.info.ptype as pxt
import pyxu.info.warning as pxw
import pyxu.runtime as pxrt
import pyxu.util as pxu
from pyxu.operator.linop.pad import Pad
from pyxu.operator.linop.select import Trim
from pyxu.operator.linop.stencil._stencil import _Stencil

__all__ = [
    "Stencil",
    "Correlate",
    "Convolve",
]


class Stencil(pxa.SquareOp):
    r"""
    Multi-dimensional JIT-compiled stencil.

    Stencils are a common computational pattern in which array elements are updated according to some fixed pattern
    called the *stencil kernel*. Notable examples include multi-dimensional convolutions, correlations and finite
    differences. (See Mathematical Notes for a definition.)

    Stencils can be evaluated efficiently on CPU/GPU architectures.

    Several boundary conditions are supported. Moreover, boundary conditions may differ per axis.

    .. rubric:: Implementation Notes

    * Numba (and its ``@stencil`` decorator) is used to compile efficient machine code from a stencil kernel
      specification. This has 2 consequences:

      * :py:class:`~pyxu.operator.Stencil` instances are **not arraymodule-agnostic**: they will only work with NDArrays
        belonging to the same array module as `kernel`.
      * Compiled stencils are not **precision-agnostic**: they will only work on NDArrays with the same dtype as
        `kernel`. A warning is emitted if inputs must be cast to the kernel dtype.

    * Stencil kernels can be specified in two forms: (See :py:meth:`~pyxu.operator.Stencil.__init__` for details.)

      * A single non-separable :math:`D`-dimensional kernel :math:`k[i_{1},\ldots,i_{D}]` of shape
        :math:`(K_{1},\ldots,K_{D})`.
      * A sequence of separable :math:`1`-dimensional kernel(s) :math:`k_{d}[i]` of shapes
        :math:`(K_{1},),\ldots,(K_{D},)` such that :math:`k[i_{1},\ldots,i_{D}] = \prod_{d=1}^{D} k_{d}[i_{d}]`.

    .. rubric:: Mathematical Notes

    Given a :math:`D`-dimensional array :math:`x\in\mathbb{R}^{N_1\times\cdots\times N_D}` and kernel
    :math:`k\in\mathbb{R}^{K_1\times\cdots\times K_D}` with center :math:`(c_1, \ldots, c_D)`, the output of the stencil
    operator is an array :math:`y\in\mathbb{R}^{N_1\times\cdots\times N_D}` given by:

    .. math::

       y[i_{1},\ldots,i_{D}]
       =
       \sum_{q_{1},\ldots,q_{D}=0}^{K_{1}-1,\ldots,K_{D}-1}
       x[i_{1} - c_{1} + q_{1},\ldots,i_{D} - c_{D} + q_{D}]
       \,\cdot\,
       k[q_{1},\ldots,q_{D}].

    This corresponds to a *correlation* with a shifted version of the kernel :math:`k`.

    Numba stencils assume summation terms involving out-of-bound indices of :math:`x` are set to zero.
    :py:class:`~pyxu.operator.Stencil` lifts this constraint by extending the stencil to boundary values via pre-padding
    and post-trimming. Concretely, any stencil operator :math:`S` instantiated with :py:class:`~pyxu.operator.Stencil`
    can be written as the composition :math:`S = TS_0P`, where :math:`(T, S_0, P)` are trimming, stencil with
    zero-padding conditions, and padding operators respectively. This construct allows
    :py:class:`~pyxu.operator.Stencil` to handle complex boundary conditions under which :math:`S` *may not be a proper
    stencil* (i.e., varying kernel) but can still be implemented efficiently via a proper stencil upon appropriate
    trimming/padding.

    For example, consider the decomposition of the following (improper) stencil operator:

    .. code-block:: python3

       >>> S = Stencil(
       ...     dim_shape=(5,),
       ...     kernel=np.r_[1, 2, -3],
       ...     center=(2,),
       ...     mode="reflect",
       ... )

       >>> S.asarray()
       [[-3  2  1  0  0]
        [ 2 -2  0  0  0]
        [ 1  2 -3  0  0]
        [ 0  1  2 -3  0]
        [ 0  0  1  2 -3]]  # Improper stencil (kernel varies across rows)
       =
       [[0 0 1 0 0 0 0 0 0]
        [0 0 0 1 0 0 0 0 0]
        [0 0 0 0 1 0 0 0 0]
        [0 0 0 0 0 1 0 0 0]
        [0 0 0 0 0 0 1 0 0]]  # Trimming
       @
       [[-3  0  0  0  0  0  0  0  0]
        [ 2 -3  0  0  0  0  0  0  0]
        [ 1  2 -3  0  0  0  0  0  0]
        [ 0  1  2 -3  0  0  0  0  0]
        [ 0  0  1  2 -3  0  0  0  0]
        [ 0  0  0  1  2 -3  0  0  0]
        [ 0  0  0  0  1  2 -3  0  0]
        [ 0  0  0  0  0  1  2 -3  0]
        [ 0  0  0  0  0  0  1  2 -3]]  # Proper stencil (Toeplitz structure)
       @
       [[0 0 1 0 0]
        [0 1 0 0 0]
        [1 0 0 0 0]
        [0 1 0 0 0]
        [0 0 1 0 0]
        [0 0 0 1 0]
        [0 0 0 0 1]
        [0 0 0 1 0]
        [0 0 1 0 0]]  # Padding with reflect mode.

    Note that the adjoint of a stencil operator may not necessarily be a stencil operator, or the associated center and
    boundary conditions may be hard to predict. For example, the adjoint of the stencil operator defined above is given
    by:

    .. code-block:: python3

       >>> S.T.asarray()
       [[-3  2  1  0  0]
        [ 2 -2  2  1  0]
        [ 1  0 -3  2  1]
        [ 0  0  0 -3  2]
        [ 0  0  0  0 -3]]

    which resembles a stencil with time-reversed kernel, but with unusual (if not improper) boundary conditions. This
    can also be seen from the fact that :math:`S^\ast = P^\ast S_0^\ast T^\ast = P^\ast S_0^\ast P_0`, where
    :math:`T^\ast = P_0` is a zero-padding operator, and :math:`P^\ast` is in general not a trimming operator. (See
    :py:class:`~pyxu.operator.Pad`.)

    The same holds for Gram/cogram operators. Consider indeed the following order-1 forward finite-difference operator
    with zero-padding:

    .. code-block:: python3

       >>> S = Stencil(
       ...     dim_shape=(5,),
       ...     kernel=np.r_[-1, 1],
       ...     center=(0,),
       ...     mode='constant',
       ... )

       >>> S.gram().asarray()
       [[ 1 -1  0  0  0]
        [-1  2 -1  0  0]
        [ 0 -1  2 -1  0]
        [ 0  0 -1  2 -1]
        [ 0  0  0 -1  2]]

    We observe that the Gram operator differs from the (negated) order-2 centered finite-difference operator: its first
    row reduces to an order-1 difference (reduced-order derivative on one side).

    Example
    -------

    * **Moving average of a 1D signal**

      Let :math:`x[n]` denote a 1D signal. The weighted moving average

      .. math::

         y[n] = x[n-2] + 2 x[n-1] + 3 x[n]

      can be viewed as the output of the 3-point stencil of kernel :math:`k = [1, 2, 3]`.

      .. code-block:: python3

         import numpy as np
         from pyxu.operator import Stencil

         x = np.arange(10)  # [0 1 2 3 4 5 6 7 8 9]

         op = Stencil(
             dim_shape=x.shape,
             kernel=np.array([1, 2, 3]),
             center=(2,),  # k[2] applies on x[n]
         )

         y = op.apply(x)  # [ 0  3  8 14 20 26 32 38 44 50]


    * **Non-separable image filtering**

      Let :math:`x[n, m]` denote a 2D image. The blurred image

      .. math::

         y[n, m] = 2 x[n-1,m-1] + 3 x[n-1,m+1] + 4 x[n+1,m-1] + 5 x[n+1,m+1]

      can be viewed as the output of the 9-point stencil

      .. math::

         k =
         \left[
         \begin{array}{ccc}
            2 & 0 & 3 \\
            0 & 0 & 0 \\
            4 & 0 & 5
         \end{array}
         \right].

      .. code-block:: python3

         import numpy as np
         from pyxu.operator import Stencil

         x = np.arange(64).reshape(8, 8)  # square image
         # [[ 0  1  2  3  4  5  6  7]
         #  [ 8  9 10 11 12 13 14 15]
         #  [16 17 18 19 20 21 22 23]
         #  [24 25 26 27 28 29 30 31]
         #  [32 33 34 35 36 37 38 39]
         #  [40 41 42 43 44 45 46 47]
         #  [48 49 50 51 52 53 54 55]
         #  [56 57 58 59 60 61 62 63]]

         op = Stencil(
             dim_shape=x.shape,
             kernel=np.array(
                 [[2, 0, 3],
                  [0, 0, 0],
                  [4, 0, 5]]),
             center=(1, 1),  # k[1, 1] applies on x[n, m]
         )

         y = op.apply(x)
         # [[ 45  82  91 100 109 118 127  56]
         #  [ 88 160 174 188 202 216 230 100]
         #  [152 272 286 300 314 328 342 148]
         #  [216 384 398 412 426 440 454 196]
         #  [280 496 510 524 538 552 566 244]
         #  [344 608 622 636 650 664 678 292]
         #  [408 720 734 748 762 776 790 340]
         #  [147 246 251 256 261 266 271 108]]

    * **Separable image filtering**

      Let :math:`x[n, m]` denote a 2D image. The filtered image

      .. math::

         y[n, m] = & \phantom{+} 4 x[n-1,m-1] + 5 x[n-1,m] + 6 x[n-1,m+1] \\
                   & + 8 x[n,m-1] + 10 x[n,m] + 12 x[n,m+1] \\
                   & + 12 x[n+1,m-1] + 15 x[n+1,m] + 18 x[n+1,m+1]

      can be viewed as the output of the 9-point stencil

      .. math::

         k_{2D} =
         \left[
         \begin{array}{ccc}
            4 & 5 & 6 \\
            8 & 10 & 12 \\
            12 & 15 & 18
         \end{array}
         \right].

      Notice however that :math:`y[n, m]` can be implemented more efficiently by factoring the 9-point stencil as a
      cascade of two 3-point stencils:

      .. math::

         k_{2D}
         = k_{1} k_{2}^{T}
         = \left[
         \begin{array}{c}
            1 \\ 2 \\ 3
         \end{array}
         \right]
         \left[
         \begin{array}{ccc}
            4 & 5 & 6
         \end{array}
         \right].

      Separable stencils are supported and should be preferred when applicable.

      .. code-block:: python3

         import numpy as np
         from pyxu.operator import Stencil

         x = np.arange(64).reshape(8, 8)  # square image
         # [[ 0  1  2  3  4  5  6  7]
         #  [ 8  9 10 11 12 13 14 15]
         #  [16 17 18 19 20 21 22 23]
         #  [24 25 26 27 28 29 30 31]
         #  [32 33 34 35 36 37 38 39]
         #  [40 41 42 43 44 45 46 47]
         #  [48 49 50 51 52 53 54 55]
         #  [56 57 58 59 60 61 62 63]]

         op_2D = Stencil(  # using non-separable kernel
             dim_shape=x.shape,
             kernel=np.array(
                 [[ 4,  5,  6],
                  [ 8, 10, 12],
                  [12, 15, 18]]),
             center=(1, 1),  # k[1, 1] applies on x[n, m]
         )
         op_sep = Stencil(  # using separable kernels
             dim_shape=x.shape,
             kernel=[
                 np.array([1, 2, 3]),  # k1: stencil along 1st axis
                 np.array([4, 5, 6]),  # k2: stencil along 2nd axis
             ],
             center=(1, 1),  # k1[1] * k2[1] applies on x[n, m]
         )

         y_2D = op_2D.apply(x)
         y_sep = op_sep.apply(x)  # np.allclose(y_2D, y_sep) -> True
         # [[ 294  445  520  595  670  745  820  511]
         #  [ 740 1062 1152 1242 1332 1422 1512  930]
         #  [1268 1782 1872 1962 2052 2142 2232 1362]
         #  [1796 2502 2592 2682 2772 2862 2952 1794]
         #  [2324 3222 3312 3402 3492 3582 3672 2226]
         #  [2852 3942 4032 4122 4212 4302 4392 2658]
         #  [3380 4662 4752 4842 4932 5022 5112 3090]
         #  [1778 2451 2496 2541 2586 2631 2676 1617]]

    .. Warning::

       For large, non-separable kernels, stencil compilation can be time-consuming. Depending on your computer's
       architecture, using the :py:class:`~pyxu.operator.FFTCorrelate` operator might offer a more efficient solution.
       However, the performance improvement varies, so we recommend evaluating this alternative in your specific
       environment.

    See Also
    --------
    :py:class:`~pyxu.operator.Convolve`,
    :py:class:`~pyxu.operator._Stencil`,
    :py:class:`~pyxu.operator.FFTCorrelate`,
    :py:class:`~pyxu.operator.FFTConvolve`
    """

    KernelSpec = typ.Union[
        pxt.NDArray,  # (k1,...,kD) non-separable kernel
        cabc.Sequence[pxt.NDArray],  # [(k1,), ..., (kD,)] separable kernels
    ]

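The correlation sum from the Mathematical Notes can be cross-checked with a plain NumPy loop over kernel taps. The sketch below is a hypothetical reference (`correlate_ref` is not part of Pyxu) and assumes zero ('constant') boundaries only. It reproduces the moving-average example and the separable versus non-separable equivalence from the class docstring.

```python
import itertools

import numpy as np


def correlate_ref(x: np.ndarray, k: np.ndarray, center: tuple) -> np.ndarray:
    """Hypothetical reference: y[i] = sum_q x[i - c + q] * k[q], zero outside x."""
    assert x.ndim == k.ndim == len(center)
    y = np.zeros(x.shape, dtype=np.result_type(x, k))
    for q in itertools.product(*(range(K) for K in k.shape)):
        s = tuple(qd - cd for qd, cd in zip(q, center))  # tap q reads x[i + s]
        # Output indices i such that i + s stays inside x (other terms are zero).
        out_sl = tuple(slice(max(0, -sd), min(N, N - sd)) for sd, N in zip(s, x.shape))
        if any(sl.start >= sl.stop for sl in out_sl):
            continue  # this kernel tap never overlaps x
        in_sl = tuple(slice(sl.start + sd, sl.stop + sd) for sl, sd in zip(out_sl, s))
        y[out_sl] += k[q] * x[in_sl]
    return y


# Moving average from the docstring: k = [1, 2, 3], center = (2,).
y1 = correlate_ref(np.arange(10), np.array([1, 2, 3]), (2,))
print(y1)  # [ 0  3  8 14 20 26 32 38 44 50]

# Separable case: the 2D kernel is the outer product of the 1D kernels, and
# cascading one 1D correlation per axis gives the same result.
x = np.arange(64).reshape(8, 8)
k1, k2 = np.array([1, 2, 3]), np.array([4, 5, 6])
y_2D = correlate_ref(x, np.multiply.outer(k1, k2), (1, 1))
y_sep = correlate_ref(correlate_ref(x, k1.reshape(3, 1), (1, 0)), k2.reshape(1, 3), (0, 1))
print(y_2D[0, 0], np.array_equal(y_2D, y_sep))  # 294 True
```

The cascade applies one 1D correlation per axis, which is the cost saving that motivates passing separable kernels to `Stencil`.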
    def __init__(
        self,
        dim_shape: pxt.NDArrayShape,
        kernel: KernelSpec,
        center: _Stencil.IndexSpec,
        mode: Pad.ModeSpec = "constant",
        enable_warnings: bool = True,
    ):
        r"""
        Parameters
        ----------
        dim_shape: NDArrayShape
            (M1,...,MD) input dimensions.
        kernel: ~pyxu.operator.Stencil.KernelSpec
            Stencil coefficients. Two forms are accepted:

            * NDArray of rank-:math:`D`: denotes a non-separable stencil.
            * tuple[NDArray_1, ..., NDArray_D]: a sequence of 1D stencils such that dimension[k] is filtered by stencil
              `kernel[k]`, that is:

              .. math::

                 k = k_1 \otimes\cdots\otimes k_D,

              or in Python: ``k = functools.reduce(numpy.multiply.outer, kernel)``.

        center: ~pyxu.operator._Stencil.IndexSpec
            (i1,...,iD) index of the stencil's center.

            `center` defines how a kernel is overlaid on inputs to produce outputs.

        mode: str, :py:class:`list` ( str )
            Boundary conditions. Multiple forms are accepted:

            * str: unique mode shared amongst dimensions.
              Must be one of:

              * 'constant' (zero-padding)
              * 'wrap'
              * 'reflect'
              * 'symmetric'
              * 'edge'
            * tuple[str, ...]: dimension[k] uses `mode[k]` as boundary condition.

            (See :py:func:`numpy.pad` for details.)
        enable_warnings: bool
            If ``True``, emit a warning in case of precision mismatch issues.
        """
        super().__init__(
            dim_shape=dim_shape,
            codim_shape=dim_shape,
        )
        _kernel, _center, _mode = self._canonical_repr(self.dim_shape, kernel, center, mode)

        # Pad/Trim operators
        pad_width = self._compute_pad_width(_kernel, _center, _mode)
        self._pad = Pad(
            dim_shape=dim_shape,
            pad_width=pad_width,
            mode=_mode,
        )
        self._trim = Trim(
            dim_shape=self._pad.codim_shape,
            trim_width=pad_width,
        )

        # Kernels (These _Stencil() instances are not used as-is in apply/adjoint calls: their ._[kernel,center]
        # attributes are used directly there instead to bypass Numba serialization limits. These _Stencil() objects are
        # used however for all other Operator public methods.)
        # It is moreover advantageous to instantiate them once here to cache JIT-compiled kernels upfront.
        self._st_fw = self._init_fw(_kernel, _center)
        self._st_bw = self._init_bw(_kernel, _center)

        self._dispatch_params = dict()  # Extra kwargs passed to _Stencil.apply()
        self._dtype = _kernel[0].dtype  # useful constant
        self._enable_warnings = bool(enable_warnings)

        # We know a crude Lipschitz bound by default. Since computing it takes (code) space,
        # the estimate is computed as a special case of estimate_lipschitz()
        self.lipschitz = self.estimate_lipschitz(__rule=True)

    def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
        x = self._pad.apply(arr)
        y = self._stencil_chain(
            x=self._cast_warn(x),
            stencils=self._st_fw,
        )
        z = self._trim.apply(y)
        return z

    def adjoint(self, arr: pxt.NDArray) -> pxt.NDArray:
        x = self._trim.adjoint(arr)
        y = self._stencil_chain(
            x=self._cast_warn(x),
            stencils=self._st_bw,
        )
        z = self._pad.adjoint(y)
        return z

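`apply()` above is the composition S = T S0 P described in the class docstring. A minimal 1D NumPy sketch of the same pipeline follows. It assumes `numpy.pad` semantics for the padding step and the per-axis pad-width rule of `_compute_pad_width()`; `pad_width_rule`, `apply_tsp` and `as_matrix` are hypothetical helpers written for illustration, not Pyxu functions. Building the matrix column by column reproduces the reflect-mode listing, its adjoint, and the Gram of the forward difference from the docstring.

```python
import numpy as np


def pad_width_rule(n: int, c: int, mode: str) -> int:
    # Same per-axis rule as Stencil._compute_pad_width() (symmetric padding).
    return max(c, n - c - 1) if mode == "constant" else n - 1


def apply_tsp(x: np.ndarray, k: np.ndarray, c: int, mode: str) -> np.ndarray:
    """Hypothetical 1D sketch of S = T S0 P: pad, zero-boundary stencil, trim."""
    p = pad_width_rule(k.size, c, mode)
    xpad = np.pad(x, (p, p), mode=mode)  # P ('constant' pads with zeros)
    N = xpad.size
    y = np.zeros(N, dtype=np.result_type(xpad, k))
    for i in range(N):  # S0: y[i] = sum_q xpad[i - c + q] k[q], zero outside [0, N)
        for q in range(k.size):
            j = i - c + q
            if 0 <= j < N:
                y[i] += xpad[j] * k[q]
    return y[p : N - p]  # T: drop the padded borders


def as_matrix(k: np.ndarray, c: int, mode: str, n: int = 5) -> np.ndarray:
    # Column j of S is S applied to the j-th canonical basis vector.
    return np.stack([apply_tsp(e, k, c, mode) for e in np.eye(n, dtype=int)], axis=1)


S = as_matrix(np.r_[1, 2, -3], c=2, mode="reflect")  # matches S.asarray() in the docstring
D = as_matrix(np.r_[-1, 1], c=0, mode="constant")  # order-1 forward difference
print(S)
print(S.T)  # adjoint listing
print(D.T @ D)  # Gram listing
```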
    def configure_dispatcher(self, **kwargs):
        """
        (Only applies if `kernel` is a CuPy array.)

        Configure stencil Dispatcher.

        See :py:meth:`~pyxu.operator._Stencil.apply` for accepted options.

        Example
        -------
        .. code-block:: python3

           import cupy as cp
           from pyxu.operator import Stencil

           x = cp.arange(10)

           op = Stencil(
               dim_shape=x.shape,
               kernel=cp.array([1, 2, 3]),  # kernel must live in the same array module as inputs
               center=(1,),
           )

           y = op.apply(x)  # uses default threadsperblock/blockspergrid values

           op.configure_dispatcher(
               threadsperblock=50,
               blockspergrid=3,
           )
           y = op.apply(x)  # supplied values used instead
        """
        # Store options verbatim: they are forwarded as **kwargs to _Stencil.apply().
        self._dispatch_params.update(kwargs)

    def estimate_lipschitz(self, **kwargs) -> pxt.Real:
        no_eval = "__rule" in kwargs
        if no_eval:
            # Analytic upper bound from Young's convolution inequality:
            #     \norm{x \ast h}{2} \le \norm{x}{2}\norm{h}{1}
            #
            # -> L \le \norm{h}{1}
            kernels = [st._kernel for st in self._st_fw]
            kernel = functools.reduce(operator.mul, kernels, 1)
            L_st = np.linalg.norm(pxu.to_NUMPY(kernel).reshape(-1), ord=1)

            L_pad = self._pad.lipschitz
            L_trim = self._trim.lipschitz

            L = L_trim * L_st * L_pad  # upper bound
        else:
            L = super().estimate_lipschitz(**kwargs)
        return L

    def asarray(self, **kwargs) -> pxt.NDArray:
        # Stencil.apply() prefers precision provided at init-time.
        xp = pxu.get_array_module(self._st_fw[0]._kernel)
        _A = super().asarray(xp=xp, dtype=self._dtype)

        xp = kwargs.get("xp", pxd.NDArrayInfo.NUMPY.module())
        dtype = kwargs.get("dtype", pxrt.Width.DOUBLE.value)
        A = xp.array(pxu.to_NUMPY(_A), dtype=dtype)
        return A

    def trace(self, **kwargs) -> pxt.Real:
        if all(m == "constant" for m in self._pad._mode):
            # tr = (kernel center coefficient) * dim_size
            tr = functools.reduce(
                operator.mul,
                [st._kernel[tuple(st._center)] for st in self._st_fw],
                1,
            )
            tr *= self.dim_size
        else:
            # Standard algorithm, with computations restricted to precision supported by
            # Stencil.apply().
            kwargs.update(dtype=self._dtype)
            tr = super().trace(**kwargs)
        return float(tr)

    # Helper routines (public) ------------------------------------------------
    @property
    def kernel(self) -> KernelSpec:
        r"""
        Stencil kernel coefficients.

        Returns
        -------
        kern: ~pyxu.operator.Stencil.KernelSpec
            Stencil coefficients.

            If the kernel is non-separable, a single array is returned.
            Otherwise :math:`D` arrays are returned, one per axis.
        """
        if len(self._st_fw) == 1:
            kern = self._st_fw[0]._kernel
        else:
            kern = [st._kernel for st in self._st_fw]
        return kern

    @property
    def center(self) -> _Stencil.IndexSpec:
        """
        Stencil central position.

        Returns
        -------
        ctr: ~pyxu.operator._Stencil.IndexSpec
            Stencil central position.
        """
        if len(self._st_fw) == 1:
            ctr = self._st_fw[0]._center
        else:
            ctr = [st._center[d] for (d, st) in enumerate(self._st_fw)]
        return tuple(ctr)

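`estimate_lipschitz()` and `trace()` above rely on two closed-form facts that hold for zero ('constant') boundaries: the operator norm is at most the l1-norm of the kernel, and every diagonal entry equals the center coefficient. The NumPy check below builds a dense 1D matrix under that zero-boundary assumption; `stencil_matrix` is a hypothetical helper, not a Pyxu API.

```python
import numpy as np


def stencil_matrix(k: np.ndarray, c: int, n: int) -> np.ndarray:
    """Hypothetical dense matrix of a 1D stencil with zero ('constant') boundaries."""
    S = np.zeros((n, n))
    for i in range(n):
        for q in range(k.size):
            j = i - c + q  # input index read by output i through k[q]
            if 0 <= j < n:
                S[i, j] = k[q]
    return S


k, c, n = np.array([1.0, 2.0, -3.0]), 2, 16
S = stencil_matrix(k, c, n)

L_true = np.linalg.norm(S, ord=2)  # spectral norm = exact Lipschitz constant
L_bound = np.abs(k).sum()  # l1-norm bound, as in estimate_lipschitz(__rule=True)
tr = k[c] * n  # constant mode: each diagonal entry is the center coefficient
print(L_true <= L_bound, np.isclose(np.trace(S), tr))  # True True
```

For other modes the padding operator has a Lipschitz constant above one, which is why the method multiplies the l1-norm by `L_pad`.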
    def visualize(self) -> str:
        r"""
        Show the :math:`D`-dimensional stencil kernel.

        The stencil's center is identified by surrounding parentheses.

        Example
        -------
        .. code-block:: python3

           import numpy as np
           from pyxu.operator import Stencil

           S = Stencil(
               dim_shape=(5, 6),
               kernel=[
                   np.r_[3, 2, 1],
                   np.r_[2, -1, 3, 1],
               ],
               center=(1, 2),
           )
           print(S.visualize())  # [[6.0 -3.0 9.0 3.0]
                                 #  [4.0 -2.0 (6.0) 2.0]
                                 #  [2.0 -1.0 3.0 1.0]]
        """
        kernels = [st._kernel for st in self._st_fw]
        kernel = functools.reduce(operator.mul, kernels, 1)

        kernel = pxu.to_NUMPY(kernel).astype(str)
        kernel[self.center] = "(" + kernel[self.center] + ")"

        kern = np.array2string(kernel).replace("'", "")
        return kern
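`visualize()` builds the D-dimensional kernel by multiplying the reshaped 1D kernels with broadcasting. The sketch below, which only assumes NumPy, checks that this matches the tensor product `functools.reduce(numpy.multiply.outer, kernel)` quoted in `__init__()`, using the kernels of the `visualize()` example with integer coefficients. The reshape mimics what `_canonical_repr()` does; the variable names are illustrative.

```python
import functools
import operator

import numpy as np

# 1D kernels of the visualize() example, reshaped as in Stencil._canonical_repr():
# kernel d gets shape (1, ..., K_d, ..., 1) so that products broadcast along axis d.
raw = [np.r_[3, 2, 1], np.r_[2, -1, 3, 1]]
D = len(raw)
shaped = [k.reshape([-1 if i == d else 1 for i in range(D)]) for d, k in enumerate(raw)]

# Broadcasted product (as in visualize()) vs. explicit tensor (outer) product.
k_bcast = functools.reduce(operator.mul, shaped, 1)
k_outer = functools.reduce(np.multiply.outer, raw)
assert np.array_equal(k_bcast, k_outer)

# Mark the center with parentheses, as visualize() does.
center = (1, 2)
txt = k_bcast.astype(str)
txt[center] = "(" + txt[center] + ")"
print(k_bcast.tolist())  # [[6, -3, 9, 3], [4, -2, 6, 2], [2, -1, 3, 1]]
print(txt[center])  # (6)
```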

    # Helper routines (internal) ----------------------------------------------
    @staticmethod
    def _canonical_repr(dim_shape, kernel, center, mode):
        # Create canonical representations
        #   * `_kernel`: list[ndarray[float], ...]
        #   * `_center`: list[ndarray[int], ...]
        #   * `_mode`: list[str, ...]
        #
        # `dim_shape` is already assumed in tuple-form.
        N = len(dim_shape)
        assert len(center) == N

        kernel = pxu.compute(kernel, traverse=True)
        try:
            # array input -> non-separable filter
            pxu.get_array_module(kernel)
            assert kernel.ndim == N
            _kernel = [kernel]
            _center = [np.array(center, dtype=int)]
        except Exception:
            # sequence input -> separable filter(s)
            assert len(kernel) == N  # one filter per dimension

            _kernel = [None] * N
            for i in range(N):
                sh = [1] * N
                sh[i] = -1
                _kernel[i] = kernel[i].reshape(sh)

            _center = np.zeros((N, N), dtype=int)
            _center[np.diag_indices(N)] = center

        _mode = Pad(  # get `mode` in canonical form
            (3,) * _kernel[0].ndim,
            pad_width=1,
            mode=mode,
        )._mode

        return _kernel, _center, _mode

    @staticmethod
    def _compute_pad_width(_kernel, _center, _mode) -> Pad.WidthSpec:
        N = _kernel[0].ndim
        pad_width = [None] * N
        for i in range(N):
            if len(_kernel) == 1:  # non-separable filter
                c = _center[0][i]
                n = _kernel[0].shape[i]
            else:  # separable filter(s)
                c = _center[i][i]
                n = _kernel[i].size

            # 1. Pad/Trim operators are shared amongst [apply,adjoint]():
            #    lhs/rhs are thus padded equally.
            # 2. When mode != "constant", pad width must match kernel dimensions to retain border
            #    effects.
            if _mode[i] == "constant":
                p = max(c, n - c - 1)
            else:  # anything else supported by Pad()
                p = n - 1
            pad_width[i] = (p, p)
        return tuple(pad_width)

    @staticmethod
    def _init_fw(_kernel, _center) -> list:
        # Initialize kernels used in apply().
        # The returned objects must have the following fields:
        # * _kernel: ndarray[float] (k1,...,kD)
        # * _center: ndarray[int] (D,)
        _st_fw = [None] * len(_kernel)
        for i, (k_fw, c_fw) in enumerate(zip(_kernel, _center)):
            _st_fw[i] = _Stencil.init(kernel=k_fw, center=c_fw)
        return _st_fw

    @staticmethod
    def _init_bw(_kernel, _center) -> list:
        # Initialize kernels used in adjoint().
        # The returned objects must have the following fields:
        # * _kernel: ndarray[float] (k1,...,kD)
        # * _center: ndarray[int] (D,)
        _st_bw = [None] * len(_kernel)
        _kernel, _center = Stencil._bw_equivalent(_kernel, _center)
        for i, (k_bw, c_bw) in enumerate(zip(_kernel, _center)):
            _st_bw[i] = _Stencil.init(kernel=k_bw, center=c_bw)
        return _st_bw

    @staticmethod
    def _bw_equivalent(_kernel, _center):
        # Transform FW kernel/center specification to BW variant.
        k_bw = [np.flip(k_fw) for k_fw in _kernel]

        if len(_kernel) == 1:  # non-separable filter
            c_bw = [(_kernel[0].shape - _center[0]) - 1]
        else:  # separable filter(s)
            N = _kernel[0].ndim
            c_bw = np.zeros((N, N), dtype=int)
            for i in range(N):
                c_bw[i, i] = _kernel[i].shape[i] - _center[i][i] - 1
        return k_bw, c_bw

    def _stencil_chain(self, x: pxt.NDArray, stencils: list) -> pxt.NDArray:
        # Apply sequence of stencils to `x`.
        #
        # x: (..., M1,...,MD)
        # y: (..., M1,...,MD)

        # _Stencil() instances cannot be serialized by Dask, so we pass around _[kernel,center] directly.
        # _Stencil(kernel,center) was compiled in __init__() though, hence re-instantiating _Stencil() here is free.
        kernel = [st._kernel for st in stencils]
        center = [st._center for st in stencils]

        def _chain(x, kernel, center, dispatch_params):
            stencils = [_Stencil.init(k, c) for (k, c) in zip(kernel, center)]

            xp = pxu.get_array_module(x)
            if len(stencils) == 1:
                x = xp.require(x, requirements="C")
                y = x.copy()
            else:
                # [2023.04.17, Sepand]
                # In-place updates of `x` break thread-safety of Stencil().
                x, y = x.copy(), x.copy()

            for st in stencils:
                st.apply(x, y, **dispatch_params)
                x, y = y, x
            y = x
            return y

        ndi = pxd.NDArrayInfo.from_obj(x)
        if ndi == pxd.NDArrayInfo.DASK:
            stack_depth = (0,) * (x.ndim - self.dim_rank)
            y = x.map_overlap(
                _chain,
                depth=stack_depth + self._pad._pad_width,
                dtype=x.dtype,
                meta=x._meta,
                # extra _chain() kwargs -------------------
                kernel=kernel,
                center=center,
                dispatch_params=self._dispatch_params,
            )
        else:  # NUMPY/CUPY
            y = _chain(x, kernel, center, self._dispatch_params)
        return y

    def _cast_warn(self, arr: pxt.NDArray) -> pxt.NDArray:
        if arr.dtype == self._dtype:
            out = arr
        else:
            if self._enable_warnings:
                msg = "Computation may not be performed at the requested precision."
                warnings.warn(msg, pxw.PrecisionWarning)
            out = arr.astype(dtype=self._dtype)
        return out
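`adjoint()` reuses `_stencil_chain()` with the kernels from `_bw_equivalent()`: each kernel is flipped and its center becomes K - c - 1. Under zero boundaries this is exactly the transpose of the forward stencil, which a dot-product test confirms. A minimal 1D sketch follows; `corr1d` is a hypothetical zero-boundary helper, not Pyxu's Numba kernel.

```python
import numpy as np


def corr1d(x: np.ndarray, k: np.ndarray, c: int) -> np.ndarray:
    """Hypothetical zero-boundary 1D stencil: y[i] = sum_q x[i - c + q] * k[q]."""
    n = x.size
    y = np.zeros(n)
    for i in range(n):
        for q in range(k.size):
            j = i - c + q
            if 0 <= j < n:
                y[i] += x[j] * k[q]
    return y


rng = np.random.default_rng(0)
k, c = rng.standard_normal(4), 1  # forward kernel/center
k_bw, c_bw = np.flip(k), k.size - c - 1  # as in Stencil._bw_equivalent()

x, y = rng.standard_normal(12), rng.standard_normal(12)
lhs = np.dot(corr1d(x, k, c), y)  # <S x, y>
rhs = np.dot(x, corr1d(y, k_bw, c_bw))  # <x, S_bw y>
print(np.isclose(lhs, rhs))  # True
```

With padding modes other than 'constant', the full adjoint also involves the adjoints of `Pad` and `Trim`, which is why `adjoint()` wraps the flipped stencil between them.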


Correlate = Stencil  #: Alias of :py:class:`~pyxu.operator.Stencil`.


class Convolve(Stencil):
    r"""
    Multi-dimensional JIT-compiled convolution.

    Inputs are convolved with the given kernel.

    Notes
    -----
    Given a :math:`D`-dimensional array :math:`x\in\mathbb{R}^{N_1 \times\cdots\times N_D}` and kernel
    :math:`k\in\mathbb{R}^{K_1 \times\cdots\times K_D}` with center :math:`(c_1, \ldots, c_D)`, the output of the
    convolution operator is an array :math:`y\in\mathbb{R}^{N_1 \times\cdots\times N_D}` given by:

    .. math::

       y[i_{1},\ldots,i_{D}]
       =
       \sum_{q_{1},\ldots,q_{D}=0}^{K_{1}-1,\ldots,K_{D}-1}
       x[i_{1} - q_{1} + c_{1},\ldots,i_{D} - q_{D} + c_{D}]
       \,\cdot\,
       k[q_{1},\ldots,q_{D}].

    The convolution is implemented via :py:class:`~pyxu.operator.Stencil`. To do so, the convolution kernel is
    transformed to the equivalent correlation kernel:

    .. math::

       y[i_{1},\ldots,i_{D}]
       =
       \sum_{q_{1},\ldots,q_{D}=0}^{K_{1}-1,\ldots,K_{D}-1}
       &x[i_{1} + q_{1} - (K_{1} - 1 - c_{1}),\ldots,i_{D} + q_{D} - (K_{D} - 1 - c_{D})] \\
       &\cdot\,
       k[K_{1} - 1 - q_{1},\ldots,K_{D} - 1 - q_{D}].

    This corresponds to a correlation with the flipped kernel :math:`k[K_{1} - 1 - q_{1},\ldots,K_{D} - 1 - q_{D}]`
    and the flipped center :math:`(K_{1} - 1 - c_{1},\ldots,K_{D} - 1 - c_{D})`.

    .. Warning::

       For large, non-separable kernels, stencil compilation can be time-consuming. Depending on your computer's
       architecture, using the :py:class:`~pyxu.operator.FFTConvolve` operator might offer a more efficient solution.
       However, the performance improvement varies, so we recommend evaluating this alternative in your specific
       environment.

    Examples
    --------
    .. code-block:: python3

       import numpy as np
       from scipy.ndimage import convolve
       from pyxu.operator import Convolve

       x = np.array([
           [1, 2, 0, 0],
           [5, 3, 0, 4],
           [0, 0, 0, 7],
           [9, 3, 0, 0],
       ])
       k = np.array([
           [1, 1, 1],
           [1, 1, 0],
           [1, 0, 0],
       ])
       op = Convolve(
           dim_shape=x.shape,
           kernel=k,
           center=(1, 1),
           mode="constant",
       )

       y_op = op.apply(x)
       y_sp = convolve(x, k, mode="constant", origin=0)  # np.allclose(y_op, y_sp) -> True
       # [[11 10  7  4]
       #  [10  3 11 11]
       #  [15 12 14  7]
       #  [12  3  7  0]]

    See Also
    --------
    :py:class:`~pyxu.operator.Stencil`
    """

    def __init__(
        self,
        dim_shape: pxt.NDArrayShape,
        kernel: Stencil.KernelSpec,
        center: _Stencil.IndexSpec,
        mode: Pad.ModeSpec = "constant",
        enable_warnings: bool = True,
    ):
        r"""
        See :py:meth:`~pyxu.operator.Stencil.__init__` for a description of the arguments.
        """
        super().__init__(
            dim_shape=dim_shape,
            kernel=kernel,
            center=center,
            mode=mode,
            enable_warnings=enable_warnings,
        )

        # flip FW/BW kernels (& centers)
        self._st_fw, self._st_bw = self._st_bw, self._st_fw
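`Convolve` only swaps the forward and backward stencils, so it correlates with the flipped kernel and flipped center K - 1 - c stated in its Notes. The 1D NumPy sketch below compares this against `numpy.convolve` in full mode, whose output is offset by the center c. It assumes zero boundaries, and `corr1d` is a hypothetical helper rather than a Pyxu function.

```python
import numpy as np


def corr1d(x: np.ndarray, k: np.ndarray, c: int) -> np.ndarray:
    """Hypothetical zero-boundary 1D correlation: y[i] = sum_q x[i - c + q] * k[q]."""
    n = x.size
    y = np.zeros(n)
    for i in range(n):
        for q in range(k.size):
            j = i - c + q
            if 0 <= j < n:
                y[i] += x[j] * k[q]
    return y


x = np.array([1.0, 2.0, 0.0, 5.0, 3.0, 0.0, 4.0])
k, c = np.array([1.0, 1.0, 1.0, 0.5]), 1

# Convolution as correlation with the flipped kernel and flipped center K - 1 - c.
y_conv = corr1d(x, np.flip(k), k.size - 1 - c)

# Reference: y[i] = sum_q x[i - q + c] k[q] = np.convolve(x, k, "full")[i + c].
y_ref = np.convolve(x, k, mode="full")[c : c + x.size]
print(np.allclose(y_conv, y_ref))  # True
```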