1"""
2Low-level functions used to define user-facing stencils.
3"""
4import collections.abc as cabc
5import itertools
6import pathlib as plib
7import string
8import types
9
10import numpy as np
11
12import pyxu.info.config as pxcfg
13import pyxu.info.deps as pxd
14import pyxu.info.ptype as pxt
15import pyxu.runtime as pxrt
16import pyxu.util as pxu
17
18
19def _signature(params, returns) -> str:
20 # Translate a signature of the form
21 # [in_1_spec, ..., in_N_spec] -> out_spec
22 # to Numba's string representation.
23 #
24 # Parameters
25 # ----------
26 # params: list(spec)
27 # returns: spec | None
28 #
29 # Returns
30 # -------
31 # sig: str
32 #
33 # Notes
34 # -----
35 # A parameter spec is characterized by the triplet
36 # (dtype[single/double], ndim[int], c_contiguous[bool])
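    #
    # Example
    # -------
    # As used by `_gen_code()` below:
    #     _signature([(pxrt.Width.SINGLE, 2, True)] * 2, None)
    #     -> "void(float32[:,::1], float32[:,::1])"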
    def fmt(spec) -> str:
        dtype, ndim, c_contiguous = spec

        _dtype_spec = {
            pxrt.Width.SINGLE: "float32",
            pxrt.Width.DOUBLE: "float64",
        }[pxrt.Width(dtype)]

        dim_spec = [":"] * ndim
        if c_contiguous and (ndim > 0):
            dim_spec[-1] = "::1"
        dim_spec = "[" + ",".join(dim_spec) + "]"

        _repr = _dtype_spec
        if ndim > 0:
            _repr += dim_spec
        return _repr

    sig = "".join(
        [
            "void" if (returns is None) else fmt(returns),
            "(",
            ", ".join(map(fmt, params)),
            ")",
        ]
    )
    return sig


class _Stencil:
    """
    Multi-dimensional JIT-compiled stencil. (Low-level function.)

    This low-level class creates a gu-vectorized stencil applicable to multiple inputs simultaneously.
    Only NUMPY/CUPY arrays are accepted.

    Create instances via the factory method :py:meth:`~pyxu.operator._Stencil.init`.

    Example
    -------
    Correlate a stack of images `A` with a (3, 3) kernel such that:

    .. math::

       B[n, m] = A[n-1, m] + A[n, m-1] + A[n, m+1] + A[n+1, m]

    .. code-block:: python3

       import numpy as np
       from pyxu.operator import _Stencil

       # create the stencil
       kernel = np.array([[0, 1, 0],
                          [1, 0, 1],
                          [0, 1, 0]], dtype=np.float64)
       center = (1, 1)
       stencil = _Stencil.init(kernel, center)

       # apply it to the data
       rng = np.random.default_rng()
       A = rng.normal(size=(2, 3, 4, 30, 30))  # 24 images of size (30, 30)
       B = np.zeros_like(A)
       stencil.apply(A, B)  # (2, 3, 4, 30, 30)
    """

    IndexSpec = cabc.Sequence[pxt.Integer]

    @staticmethod
    def init(
        kernel: pxt.NDArray,
        center: IndexSpec,
    ):
        """
        Parameters
        ----------
        kernel: NDArray
            (k1,...,kD) kernel coefficients.

            Only float32/64 kernels are supported.
        center: ~pyxu.operator._Stencil.IndexSpec
            (D,) index of the kernel's center.

        Returns
        -------
        st: ~pyxu.operator._Stencil
            Rank-D stencil.
        """
        dtype = kernel.dtype
        if dtype not in {_.value for _ in pxrt.Width}:
            raise ValueError(f"Unsupported kernel precision {dtype}.")

        center = np.array(center, dtype=int)
        assert center.size == kernel.ndim
        assert np.all((0 <= center) & (center < kernel.shape))

        N = pxd.NDArrayInfo
        ndi = N.from_obj(kernel)
        if ndi == N.NUMPY:
            klass = _Stencil_NP
        elif ndi == N.CUPY:
            klass = _Stencil_CP
        else:
            raise NotImplementedError

        st = klass(kernel, center)
        return st

    def apply(
        self,
        arr: pxt.NDArray,
        out: pxt.NDArray,
        **kwargs,
    ) -> pxt.NDArray:
        r"""
        Evaluate the stencil on multiple inputs.

        Parameters
        ----------
        arr: NDArray
            (..., M1,...,MD) data to process.
        out: NDArray
            (..., M1,...,MD) array to which outputs are written.
        kwargs: dict
            Extra kwargs to configure `f_jit()`, the Dispatcher instance created by Numba.

            Only relevant for GPU stencils. Supported keys:

            * blockspergrid: int
            * threadsperblock: int

            Default values are chosen if unspecified.

        Returns
        -------
        out: NDArray
            (..., M1,...,MD) outputs.

        Notes
        -----
        * `arr` and `out` must have the same type/dtype as the kernel used during instantiation.
        * Index regions in `out` where the stencil is not fully supported are set to 0.
        * :py:meth:`~pyxu.operator._Stencil.apply` may raise ``CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES`` when the number
          of GPU registers required exceeds resource limits. There are 2 solutions to this problem:

          (1) Pass the `max_registers` kwarg to f_jit()'s decorator; or
          (2) `Limit the number of threads per block <https://stackoverflow.com/a/68659008>`_.

          (1) must be set at compile time; it is thus left unbounded.
          (2) is accessible through .apply(\*\*kwargs).
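
        A sketch of remedy (2), assuming `A_gpu` and `B_gpu` are hypothetical CUPY arrays
        matching the kernel's dtype:

        .. code-block:: python3

           # Fewer threads per block => fewer registers needed per block.
           stencil.apply(A_gpu, B_gpu, threadsperblock=64)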
186 """
187 assert arr.dtype == out.dtype == self._kernel.dtype
188 assert arr.shape == out.shape
189 assert arr.flags.c_contiguous and out.flags.c_contiguous
190
191 K_dim = len(self._kernel.shape)
192 dim_shape = arr.shape[-K_dim:]
193
194 stencil = self._configure_dispatcher(arr.size, **kwargs)
195 stencil(
196 # OK since NP/CP input constraint.
197 arr.reshape(-1, *dim_shape),
198 out.reshape(-1, *dim_shape),
199 )
200 return out
201
202 def __init__(
203 self,
204 kernel: pxt.NDArray,
205 center: pxt.NDArray,
206 ):
207 self._kernel = kernel
208 self._center = center
209
210 cached_module = self._gen_code()
211 self._dispatch = cached_module.f_jit
212
213 def _gen_code(self) -> types.ModuleType:
214 # Compile Numba kernel `void f_jit(arr, out)`.
215 #
216 # The code is compiled only if unavailable beforehand.
217 #
218 # Returns
219 # -------
220 # jit_module: module
221 # A (loaded) python package containing method f_jit().
222 raise NotImplementedError
223
224 def _configure_dispatcher(self, pb_size: int, **kwargs) -> cabc.Callable:
225 # Configure `f_jit()`, the Numba Dispatcher instance.
226 #
227 # Parameters
228 # ----------
229 # pb_size: int
230 # Number of stencil evaluations.
231 # **kwargs: dict
232 #
233 # Returns
234 # -------
235 # f: callable
236 # Configured Numba Dispatcher.
237 raise NotImplementedError
238
239
class _Stencil_NP(_Stencil):
    def _gen_code(self) -> types.ModuleType:
        # Generate the code which should be compiled --------------------------
        sig_spec = (self._kernel.dtype, self._kernel.ndim + 1, True)
        signature = _signature((sig_spec,) * 2, None)

        template_file = plib.Path(__file__).parent / "_template_cpu.txt"
        with open(template_file, mode="r") as f:
            template = string.Template(f.read())
        code = template.substitute(
            signature=signature,
            stencil_spec=self.__stencil_spec(),
        )
        # ---------------------------------------------------------------------

        # Store/update cached version as needed.
        module_name = pxu.cache_module(code)
        pxcfg.cache_dir(load=True)  # make the Pyxu cache importable (if not already done)
        jit_module = pxu.import_module(module_name)
        return jit_module

    def _configure_dispatcher(self, pb_size: int, **kwargs) -> cabc.Callable:
        # Nothing to do for CPU targets.
        return self._dispatch

    def __stencil_spec(self) -> str:
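        # Emit the stencil's center expression as a string. For instance, with
        # the (3, 3) cross kernel from the class docstring (center=(1, 1)):
        #     "a[0,-1,0] + a[0,0,-1] + a[0,0,1] + a[0,1,0]"
        # Zero coefficients are dropped, and unit coefficients skip the multiplication.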
        f_fmt = {  # coef float-formatter
            pxrt.Width.SINGLE: "1.8e",
            pxrt.Width.DOUBLE: "1.16e",
        }[pxrt.Width(self._kernel.dtype)]

        entry = []
        _range = list(map(range, self._kernel.shape))
        for idx in itertools.product(*_range):
            idx_c = [i - c for (i, c) in zip(idx, self._center)]
            idx_c = ",".join(map(str, idx_c))

            cst = self._kernel[idx]
            if np.isclose(cst, 0):
                # no useless look-ups at runtime
                e = None
            elif np.isclose(cst, 1):
                # no multiplication required
                e = f"a[0,{idx_c}]"
            else:
                # general case
                e = f"({cst:{f_fmt}} * a[0,{idx_c}])"

            if e is not None:
                entry.append(e)

        spec = " + ".join(entry)
        return spec


class _Stencil_CP(_Stencil):
    def _gen_code(self) -> types.ModuleType:
        # Generate the code which should be compiled --------------------------
        sig_spec = (self._kernel.dtype, self._kernel.ndim + 1, True)
        signature = _signature((sig_spec,) * 2, None)

        template_file = plib.Path(__file__).parent / "_template_gpu.txt"
        with open(template_file, mode="r") as f:
            template = string.Template(f.read())
        code = template.substitute(
            kernel_center=str(tuple(self._center)),
            kernel_width=str(self._kernel.shape),
            signature=signature,
            stencil_spec=self.__stencil_spec(),
            unravel_spec=self.__unravel_spec(),
        )
        # ---------------------------------------------------------------------

        # Store/update cached version as needed.
        module_name = pxu.cache_module(code)
        pxcfg.cache_dir(load=True)  # make the Pyxu cache importable (if not already done)
        jit_module = pxu.import_module(module_name)
        return jit_module

    def _configure_dispatcher(self, pb_size: int, **kwargs) -> cabc.Callable:
        # Set (`threadsperblock`, `blockspergrid`) for the CUDA kernel launch.
        assert set(kwargs.keys()) <= {
            "threadsperblock",
            "blockspergrid",
        }

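        # Worked example of the defaults below: with pb_size = 54_000 stencil
        # evaluations and MaxThreadsPerBlock = 1024, tpb = 1024 and
        # bpg = (54_000 // 1024) + 1 = 53, enough blocks to cover every evaluation.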
        attr = self._kernel.device.attributes
        tpb = kwargs.get("threadsperblock", attr["MaxThreadsPerBlock"])
        bpg = kwargs.get("blockspergrid", (pb_size // tpb) + 1)
        return self._dispatch[bpg, tpb]

    def __stencil_spec(self) -> str:
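        # Emit the stencil's center expression as a string, with indices taken
        # relative to the thread's unraveled position `idx`. For instance, with
        # the (3, 3) cross kernel from the class docstring (center=(1, 1)):
        #     "arr[idx[0],idx[1]-1,idx[2]+0] + arr[idx[0],idx[1]+0,idx[2]-1] + ..."
        # Zero coefficients are dropped, and unit coefficients skip the multiplication.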
        f_fmt = {  # coef float-formatter
            pxrt.Width.SINGLE: "1.8e",
            pxrt.Width.DOUBLE: "1.16e",
        }[pxrt.Width(self._kernel.dtype)]

        entry = []
        _range = list(map(range, self._kernel.shape))
        for idx in itertools.product(*_range):
            # create a string of the form "idx[1]+i1,...,idx[K]+iK"
            idx_c = [i - c for (i, c) in zip(idx, self._center)]
            idx_c = [f"idx[{i1}]{i2:+d}" for (i1, i2) in enumerate(idx_c, start=1)]
            idx_c = ",".join(idx_c)

            cst = self._kernel[idx]
            if np.isclose(cst, 0):
                # no useless look-ups at runtime
                e = None
            elif np.isclose(cst, 1):
                # no multiplication required
                e = f"arr[idx[0],{idx_c}]"
            else:
                # general case
                e = f"({cst:{f_fmt}} * arr[idx[0],{idx_c}])"

            if e is not None:
                entry.append(e)

        spec = " + ".join(entry)
        return spec

    def __unravel_spec(self) -> str:
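        # Generate code that unravels a thread's flat offset into an (N,)-dim
        # index tuple `idx`, i.e. a code-gen'd equivalent of np.unravel_index().
        # For a 2D kernel (N = 3), the emitted block reads:
        #     left = offset
        #     blk = shape[0] * shape[1] * shape[2]
        #     blk //= shape[0]
        #     i0 = left // blk
        #     left -= i0 * blk
        #     ...(same for i1, i2)...
        #     idx = (i0, i1, i2)
        # All lines but the first are further indented by 4 spaces to slot into
        # the template.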
        N = self._kernel.ndim + 1  # +1 for the stack dimension
        entry = []

        # left = offset
        e = "left = offset"
        entry.append(e)

        # blk = prod(shape)
        e = "blk = " + " * ".join([f"shape[{n}]" for n in range(N)])
        entry.append(e)

        for n in range(N):
            # blk //= shape[n]
            e = f"blk //= shape[{n}]"
            entry.append(e)

            # i{n} = left // blk
            e = f"i{n} = left // blk"
            entry.append(e)

            # left -= i{n} * blk
            e = f"left -= i{n} * blk"
            entry.append(e)

        # idx = (i0, ..., i{N-1})
        e = "idx = (" + ", ".join([f"i{n}" for n in range(N)]) + ")"
        entry.append(e)

        # indent each entry by 4 spaces, then concatenate
        for i in range(1, len(entry)):  # 1st line skipped
            entry[i] = "    " + entry[i]
        spec = "\n".join(entry)
        return spec