Source code for pyxu.operator.linop.stencil._stencil

  1"""
  2Low-level functions used to define user-facing stencils.
  3"""
  4import collections.abc as cabc
  5import itertools
  6import pathlib as plib
  7import string
  8import types
  9
 10import numpy as np
 11
 12import pyxu.info.config as pxcfg
 13import pyxu.info.deps as pxd
 14import pyxu.info.ptype as pxt
 15import pyxu.runtime as pxrt
 16import pyxu.util as pxu
 17
 18
 19def _signature(params, returns) -> str:
 20    # Translate a signature of the form
 21    #     [in_1_spec, ..., in_N_spec] -> out_spec
 22    # to Numba's string representation.
 23    #
 24    # Parameters
 25    # ----------
 26    # params: list(spec)
 27    # returns: spec | None
 28    #
 29    # Returns
 30    # -------
 31    # sig: str
 32    #
 33    # Notes
 34    # -----
 35    # A parameter spec is characterized by the triplet
 36    #     (dtype[single/double], ndim[int], c_contiguous[bool])
 37    def fmt(spec) -> str:
 38        dtype, ndim, c_contiguous = spec
 39
 40        _dtype_spec = {
 41            pxrt.Width.SINGLE: "float32",
 42            pxrt.Width.DOUBLE: "float64",
 43        }[pxrt.Width(dtype)]
 44
 45        dim_spec = [":"] * ndim
 46        if c_contiguous and (ndim > 0):
 47            dim_spec[-1] = "::1"
 48        dim_spec = "[" + ",".join(dim_spec) + "]"
 49
 50        _repr = _dtype_spec
 51        if ndim > 0:
 52            _repr += dim_spec
 53        return _repr
 54
 55    sig = "".join(
 56        [
 57            "void" if (returns is None) else fmt(returns),
 58            "(",
 59            ", ".join(map(fmt, params)),
 60            ")",
 61        ]
 62    )
 63    return sig
 64
 65
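# Example (illustrative): with
#     params = [(np.dtype(np.float64), 3, True)] * 2
#     returns = None
# each spec formats to "float64[:,:,::1]" (rank-3, C-contiguous innermost
# axis), so _signature(params, returns) evaluates to
#     "void(float64[:,:,::1], float64[:,:,::1])"
# This is exactly the kind of signature built in _gen_code() below for a
# rank-2 kernel batched along one stacking axis.
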
class _Stencil:
    """
    Multi-dimensional JIT-compiled stencil. (Low-level function.)

    This low-level class creates a gu-vectorized stencil applicable on multiple inputs simultaneously.
    Only NUMPY/CUPY arrays are accepted.

    Create instances via factory method :py:meth:`~pyxu.operator._Stencil.init`.

    Example
    -------
    Correlate a stack of images `A` with a (3, 3) kernel such that:

    .. math::

       B[n, m] = A[n-1, m] + A[n, m-1] + A[n, m+1] + A[n+1, m]

    .. code-block:: python3

       import numpy as np
       from pyxu.operator import _Stencil

       # create the stencil
       kernel = np.array([[0, 1, 0],
                          [1, 0, 1],
                          [0, 1, 0]], dtype=np.float64)
       center = (1, 1)
       stencil = _Stencil.init(kernel, center)

       # apply it to the data
       rng = np.random.default_rng()
       A = rng.normal(size=(2, 3, 4, 30, 30))  # 24 images of size (30, 30)
       B = np.zeros_like(A)
       stencil.apply(A, B)  # (2, 3, 4, 30, 30)
    """

    IndexSpec = cabc.Sequence[pxt.Integer]

    @staticmethod
    def init(
        kernel: pxt.NDArray,
        center: IndexSpec,
    ):
        """
        Parameters
        ----------
        kernel: NDArray
            (k1,...,kD) kernel coefficients.

            Only float32/64 kernels are supported.
        center: ~pyxu.operator._Stencil.IndexSpec
            (D,) index of the kernel's center.

        Returns
        -------
        st: ~pyxu.operator._Stencil
            Rank-D stencil.
        """
        dtype = kernel.dtype
        if dtype not in {_.value for _ in pxrt.Width}:
            raise ValueError(f"Unsupported kernel precision {dtype}.")

        center = np.array(center, dtype=int)
        assert center.size == kernel.ndim
        assert np.all((0 <= center) & (center < kernel.shape))

        N = pxd.NDArrayInfo
        ndi = N.from_obj(kernel)
        if ndi == N.NUMPY:
            klass = _Stencil_NP
        elif ndi == N.CUPY:
            klass = _Stencil_CP
        else:
            raise NotImplementedError

        st = klass(kernel, center)
        return st

    def apply(
        self,
        arr: pxt.NDArray,
        out: pxt.NDArray,
        **kwargs,
    ) -> pxt.NDArray:
        r"""
        Evaluate stencil on multiple inputs.

        Parameters
        ----------
        arr: NDArray
            (..., M1,...,MD) data to process.
        out: NDArray
            (..., M1,...,MD) array to which outputs are written.
        kwargs: dict
            Extra kwargs to configure `f_jit()`, the Dispatcher instance created by Numba.

            Only relevant for GPU stencils, with values:

            * blockspergrid: int
            * threadsperblock: int

            Default values are chosen if unspecified.

        Returns
        -------
        out: NDArray
            (..., M1,...,MD) outputs.

        Notes
        -----
        * `arr` and `out` must have the same type/dtype as the kernel used during instantiation.
        * Index regions in `out` where the stencil is not fully supported are set to 0.
        * :py:meth:`~pyxu.operator._Stencil.apply` may raise ``CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES`` when the number of
          GPU registers required exceeds resource limits. There are 2 solutions to this problem:

          (1) Pass the `max_registers` kwarg to f_jit()'s decorator; or
          (2) `Limit the number of threads per block <https://stackoverflow.com/a/68659008>`_.

          (1) must be set at compile time; it is thus left unbounded.
          (2) is accessible through .apply(\*\*kwargs).
        """
        assert arr.dtype == out.dtype == self._kernel.dtype
        assert arr.shape == out.shape
        assert arr.flags.c_contiguous and out.flags.c_contiguous

        K_dim = len(self._kernel.shape)
        dim_shape = arr.shape[-K_dim:]

        stencil = self._configure_dispatcher(arr.size, **kwargs)
        stencil(
            # OK since NP/CP input constraint.
            arr.reshape(-1, *dim_shape),
            out.reshape(-1, *dim_shape),
        )
        return out

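    # Example (illustrative): for the (2, 3, 4, 30, 30) input of the class
    # docstring and a rank-2 kernel, K_dim = 2 and dim_shape = (30, 30), so
    # the dispatcher receives arr.reshape(-1, 30, 30), i.e. a (24, 30, 30)
    # batch. All stacking dimensions are collapsed into a single leading
    # axis, and `out` is reshaped the same way before evaluation.
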
    def __init__(
        self,
        kernel: pxt.NDArray,
        center: pxt.NDArray,
    ):
        self._kernel = kernel
        self._center = center

        cached_module = self._gen_code()
        self._dispatch = cached_module.f_jit

    def _gen_code(self) -> types.ModuleType:
        # Compile Numba kernel `void f_jit(arr, out)`.
        #
        # The code is compiled only if unavailable beforehand.
        #
        # Returns
        # -------
        # jit_module: module
        #     A (loaded) Python module containing method f_jit().
        raise NotImplementedError

    def _configure_dispatcher(self, pb_size: int, **kwargs) -> cabc.Callable:
        # Configure `f_jit()`, the Numba Dispatcher instance.
        #
        # Parameters
        # ----------
        # pb_size: int
        #     Number of stencil evaluations.
        # **kwargs: dict
        #
        # Returns
        # -------
        # f: callable
        #     Configured Numba Dispatcher.
        raise NotImplementedError

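# A minimal sketch of the code-generation contract (an assumption inferred
# from the substitution keys used in the subclasses below; the real templates
# live in `_template_cpu.txt` / `_template_gpu.txt` and are not reproduced
# here). For the CPU path, the generated module plausibly resembles:
#
#     import numba
#
#     @numba.stencil
#     def _kernel(a):
#         return ${stencil_spec}  # e.g. a[0,-1,0] + a[0,0,-1] + ...
#
#     @numba.jit("${signature}", nopython=True, nogil=True)
#     def f_jit(arr, out):
#         _kernel(arr, out=out)
#
# The `a[0, ...]` relative indices emitted by __stencil_spec() match
# numba.stencil's relative-indexing convention, whose default cval=0 is
# consistent with the zero-filled border regions documented in apply().
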
class _Stencil_NP(_Stencil):
    def _gen_code(self) -> types.ModuleType:
        # Generate the code which should be compiled --------------------------
        sig_spec = (self._kernel.dtype, self._kernel.ndim + 1, True)
        signature = _signature((sig_spec,) * 2, None)

        template_file = plib.Path(__file__).parent / "_template_cpu.txt"
        with open(template_file, mode="r") as f:
            template = string.Template(f.read())
        code = template.substitute(
            signature=signature,
            stencil_spec=self.__stencil_spec(),
        )
        # ---------------------------------------------------------------------

        # Store/update cached version as needed.
        module_name = pxu.cache_module(code)
        pxcfg.cache_dir(load=True)  # make the Pyxu cache importable (if not already done)
        jit_module = pxu.import_module(module_name)
        return jit_module

    def _configure_dispatcher(self, pb_size: int, **kwargs) -> cabc.Callable:
        # Nothing to do for CPU targets.
        return self._dispatch

    def __stencil_spec(self) -> str:
        f_fmt = {  # coef float-formatter
            pxrt.Width.SINGLE: "1.8e",
            pxrt.Width.DOUBLE: "1.16e",
        }[pxrt.Width(self._kernel.dtype)]

        entry = []
        _range = list(map(range, self._kernel.shape))
        for idx in itertools.product(*_range):
            idx_c = [i - c for (i, c) in zip(idx, self._center)]
            idx_c = ",".join(map(str, idx_c))

            cst = self._kernel[idx]
            if np.isclose(cst, 0):
                # no useless look-ups at runtime
                e = None
            elif np.isclose(cst, 1):
                # no multiplication required
                e = f"a[0,{idx_c}]"
            else:
                # general case
                e = f"({cst:{f_fmt}} * a[0,{idx_c}])"

            if e is not None:
                entry.append(e)

        spec = " + ".join(entry)
        return spec

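# Example (illustrative): for the (3, 3) cross kernel of the class docstring
# (center (1, 1), unit taps at (0,1), (1,0), (1,2), (2,1)),
# _Stencil_NP.__stencil_spec() produces the expression
#
#     a[0,-1,0] + a[0,0,-1] + a[0,0,1] + a[0,1,0]
#
# Zero-valued taps emit no look-up at all and unit taps skip the
# multiplication, so the generated code only pays for the kernel's non-zero
# coefficients. The leading 0 index leaves the stacking axis untouched.
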
class _Stencil_CP(_Stencil):
    def _gen_code(self) -> types.ModuleType:
        # Generate the code which should be compiled --------------------------
        sig_spec = (self._kernel.dtype, self._kernel.ndim + 1, True)
        signature = _signature((sig_spec,) * 2, None)

        template_file = plib.Path(__file__).parent / "_template_gpu.txt"
        with open(template_file, mode="r") as f:
            template = string.Template(f.read())
        code = template.substitute(
            kernel_center=str(tuple(self._center.tolist())),
            kernel_width=str(self._kernel.shape),
            signature=signature,
            stencil_spec=self.__stencil_spec(),
            unravel_spec=self.__unravel_spec(),
        )
        # ---------------------------------------------------------------------

        # Store/update cached version as needed.
        module_name = pxu.cache_module(code)
        pxcfg.cache_dir(load=True)  # make the Pyxu cache importable (if not already done)
        jit_module = pxu.import_module(module_name)
        return jit_module

    def _configure_dispatcher(self, pb_size: int, **kwargs) -> cabc.Callable:
        # Set (`threadsperblock`, `blockspergrid`)
        assert set(kwargs.keys()) <= {
            "threadsperblock",
            "blockspergrid",
        }

        attr = self._kernel.device.attributes
        tpb = kwargs.get("threadsperblock", attr["MaxThreadsPerBlock"])
        bpg = kwargs.get("blockspergrid", (pb_size // tpb) + 1)
        return self._dispatch[bpg, tpb]

    def __stencil_spec(self) -> str:
        f_fmt = {  # coef float-formatter
            pxrt.Width.SINGLE: "1.8e",
            pxrt.Width.DOUBLE: "1.16e",
        }[pxrt.Width(self._kernel.dtype)]

        entry = []
        _range = list(map(range, self._kernel.shape))
        for idx in itertools.product(*_range):
            # create string of form "idx[1]+i1,...,idx[K]+iK"
            idx_c = [i - c for (i, c) in zip(idx, self._center)]
            idx_c = [f"idx[{i1}]{i2:+d}" for (i1, i2) in enumerate(idx_c, start=1)]
            idx_c = ",".join(idx_c)

            cst = self._kernel[idx]
            if np.isclose(cst, 0):
                # no useless look-ups at runtime
                e = None
            elif np.isclose(cst, 1):
                # no multiplication required
                e = f"arr[idx[0],{idx_c}]"
            else:
                # general case
                e = f"({cst:{f_fmt}} * arr[idx[0],{idx_c}])"

            if e is not None:
                entry.append(e)

        spec = " + ".join(entry)
        return spec

    def __unravel_spec(self) -> str:
        N = self._kernel.ndim + 1  # 1 stack-dim
        entry = []

        # left = offset
        e = "left = offset"
        entry.append(e)

        # blk = prod(shape)
        e = "blk = " + " * ".join([f"shape[{n}]" for n in range(N)])
        entry.append(e)

        for n in range(N):
            # blk //= shape[n]
            e = f"blk //= shape[{n}]"
            entry.append(e)

            # i{n} = left // blk
            e = f"i{n} = left // blk"
            entry.append(e)

            # left -= i{n} * blk
            e = f"left -= i{n} * blk"
            entry.append(e)

        # idx = (i0, ..., i{N})
        e = "idx = (" + ", ".join([f"i{n}" for n in range(N)]) + ")"
        entry.append(e)

        # indent each entry by 4, then concatenate
        for i in range(1, len(entry)):  # 1st line skipped
            entry[i] = "    " + entry[i]
        spec = "\n".join(entry)
        return spec
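
# Example (illustrative): for a rank-2 kernel (N = 3: one stacking axis plus
# two kernel axes), _Stencil_CP.__unravel_spec() emits
#
#     left = offset
#         blk = shape[0] * shape[1] * shape[2]
#         blk //= shape[0]
#         i0 = left // blk
#         left -= i0 * blk
#         blk //= shape[1]
#         i1 = left // blk
#         left -= i1 * blk
#         blk //= shape[2]
#         i2 = left // blk
#         left -= i2 * blk
#         idx = (i0, i1, i2)
#
# This mirrors np.unravel_index for a C-ordered array: each GPU thread turns
# its flat offset into the multi-dimensional index `idx` it must update. All
# lines but the first carry a 4-space indent so the snippet can be spliced
# verbatim into the indented body of the generated GPU kernel.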