1import collections.abc as cabc
2import typing as typ
4import numpy as np
6import pyxu.abc as pxa
7import pyxu.info.ptype as pxt
8import pyxu.util as pxu
# Public API of this module.
__all__ = [
    "Pad",
]
class Pad(pxa.LinOp):
    r"""
    Multi-dimensional padding operator.

    This operator pads the input array in each dimension according to specified widths.

    Notes
    -----
    * If inputs are D-dimensional, then some of the padding of later axes are calculated from padding of previous axes
      (axes are processed in order, so corner regions replicate already-padded values).
    * The *adjoint* of the padding operator performs a cumulative summation over the original positions used to pad.
      Its effect is clear from its matrix form. For example the matrix-form of ``Pad(dim_shape=(3,), mode="wrap",
      pad_width=(1, 1))`` is:

      .. math::

         \mathbf{A}
         =
         \left[
         \begin{array}{ccc}
         0 & 0 & 1 \\
         1 & 0 & 0 \\
         0 & 1 & 0 \\
         0 & 0 & 1 \\
         1 & 0 & 0
         \end{array}
         \right].

      The adjoint of :math:`\mathbf{A}` corresponds to its matrix transpose:

      .. math::

         \mathbf{A}^{\ast}
         =
         \left[
         \begin{array}{ccccc}
         0 & 1 & 0 & 0 & 1 \\
         0 & 0 & 1 & 0 & 0 \\
         1 & 0 & 0 & 1 & 0
         \end{array}
         \right].

      This operation can be seen as a trimming (:math:`\mathbf{T}`) plus a cumulative summation (:math:`\mathbf{S}`):

      .. math::

         \mathbf{A}^{\ast}
         =
         \mathbf{T} + \mathbf{S}
         =
         \left[
         \begin{array}{ccccc}
         0 & 1 & 0 & 0 & 0 \\
         0 & 0 & 1 & 0 & 0 \\
         0 & 0 & 0 & 1 & 0
         \end{array}
         \right]
         +
         \left[
         \begin{array}{ccccc}
         0 & 0 & 0 & 0 & 1 \\
         0 & 0 & 0 & 0 & 0 \\
         1 & 0 & 0 & 0 & 0
         \end{array}
         \right],

      where both :math:`\mathbf{T}` and :math:`\mathbf{S}` are efficiently implemented in matrix-free form.

    * The Lipschitz constant of the multi-dimensional padding operator is upper-bounded by the product of Lipschitz
      constants of the uni-dimensional paddings applied per dimension, i.e.:

      .. math::

         L \le \prod_{i} L_{i}, \qquad i \in \{1, \ldots, D\},

      where :math:`L_{i}` depends on the boundary condition at the :math:`i`-th axis.

      :math:`L_{i}^{2}` corresponds to the maximum singular value of the diagonal matrix

      .. math::

         \mathbf{A}_{i}^{\ast} \mathbf{A}_{i}
         =
         \mathbf{T}_{i} \mathbf{T}_{i}^{\ast} + \mathbf{S}_{i} \mathbf{S}_{i}^{\ast}
         =
         \mathbf{I}_{N} + \mathbf{S}_{i} \mathbf{S}_{i}^{\ast},

      where the cross terms vanish since :math:`\mathbf{T}_{i}` and :math:`\mathbf{S}_{i}` have disjoint column
      supports.

      - In mode="constant", :math:`\text{diag}(\mathbf{S}_{i} \mathbf{S}_{i}^{\ast}) = \mathbf{0}`, hence
        :math:`L_{i} = 1`.
      - In mode="edge" (with :math:`N \ge 2`),

        .. math::

           \text{diag}(\mathbf{S}_{i} \mathbf{S}_{i}^{\ast})
           =
           \left[p_{lhs}, 0, \ldots, 0, p_{rhs} \right],

        hence :math:`L_{i} = \sqrt{1 + \max(p_{lhs}, p_{rhs})}`.
        If :math:`N = 1`, both borders replicate the same sample, hence :math:`L_{i} = \sqrt{1 + p_{lhs} + p_{rhs}}`.
      - In mode="symmetric", "wrap", "reflect", :math:`\text{diag}(\mathbf{S}_{i} \mathbf{S}_{i}^{\ast})` equals (up
        to a mode-dependent permutation)

        .. math::

           \text{diag}(\mathbf{S}_{i} \mathbf{S}_{i}^{\ast})
           =
           \left[1, \ldots, 1, 0, \ldots, 0\right]
           +
           \left[0, \ldots, 0, 1, \ldots, 1\right],

        hence

        .. math::

           L^{\text{wrap, symmetric}}_{i} = \sqrt{1 + \lceil\frac{p_{lhs} + p_{rhs}}{N}\rceil}, \\
           L^{\text{reflect}}_{i} = \sqrt{1 + \lceil\frac{p_{lhs} + p_{rhs}}{N-2}\rceil} \quad (N \ge 3).

        For mode="reflect" with :math:`N \le 2`, every sample is replicated at most once, hence
        :math:`L_{i} = \sqrt{1 + \min(1, p_{lhs} + p_{rhs})}`.
    """

    WidthSpec = typ.Union[
        pxt.Integer,
        cabc.Sequence[pxt.Integer],
        cabc.Sequence[tuple[pxt.Integer, pxt.Integer]],
    ]
    ModeSpec = typ.Union[str, cabc.Sequence[str]]

    def __init__(
        self,
        dim_shape: pxt.NDArrayShape,
        pad_width: WidthSpec,
        mode: ModeSpec = "constant",
    ):
        r"""
        Parameters
        ----------
        dim_shape: NDArrayShape
            (M1,...,MD) domain dimensions.
        pad_width: ~pyxu.operator.linop.pad.Pad.WidthSpec
            Number of values padded to the edges of each axis.
            Multiple forms are accepted:

            * ``int``: pad each dimension's head/tail by `pad_width`.
            * ``tuple[int, ...]``: pad dimension[k]'s head/tail by `pad_width[k]`.
            * ``tuple[tuple[int, int], ...]``: pad dimension[k]'s head/tail by `pad_width[k][0]` /
              `pad_width[k][1]` respectively.

            Widths must be non-negative. Along an axis of size M, widths are further limited to M in modes
            'wrap'/'symmetric', and to M-1 in mode 'reflect'.
        mode: str, :py:class:`list` ( str )
            Padding mode.
            Multiple forms are accepted:

            * str: unique mode shared amongst dimensions.
              Must be one of:

              * 'constant' (zero-padding)
              * 'wrap'
              * 'reflect'
              * 'symmetric'
              * 'edge'
            * tuple[str, ...]: pad dimension[k] using `mode[k]`.

            (See :py:func:`numpy.pad` for details.)

        Raises
        ------
        AssertionError
            If `pad_width`/`mode` are length-mismatched with `dim_shape`, widths are negative or exceed the per-mode
            limits, or a mode is unknown.
        ValueError
            If `mode` is neither a string nor a sequence of strings.
        """
        dim_shape = pxu.as_canonical_shape(dim_shape)
        dim_rank = len(dim_shape)

        # transform `pad_width` to canonical form tuple[tuple[int, int], ...]
        is_seq = lambda _: isinstance(_, cabc.Sequence)
        if not is_seq(pad_width):  # int-form
            pad_width = ((pad_width, pad_width),) * dim_rank
        assert len(pad_width) == dim_rank, "dim_shape/pad_width are length-mismatched."
        if not is_seq(pad_width[0]):  # tuple[int, ...] form
            pad_width = tuple((w, w) for w in pad_width)
        # Unpacking each entry enforces the (lhs, rhs) pair structure and guarantees a tuple-of-tuples is stored,
        # whatever (inner) sequence types the caller used.
        pad_width = tuple((lhs, rhs) for (lhs, rhs) in pad_width)
        assert all(0 <= min(lhs, rhs) for (lhs, rhs) in pad_width)

        # transform `mode` to canonical form tuple[str, ...]
        if isinstance(mode, str):  # shared mode
            mode = (mode,) * dim_rank
        elif isinstance(mode, cabc.Sequence):  # tuple[str, ...]: different modes
            assert len(mode) == dim_rank, "dim_shape/mode are length-mismatched."
            mode = tuple(mode)
        else:
            raise ValueError(f"Unkwown mode encountered: {mode}.")
        mode = tuple(_.strip().lower() for _ in mode)
        assert set(mode) <= {
            "constant",
            "wrap",
            "reflect",
            "symmetric",
            "edge",
        }, "Unknown mode(s) encountered."

        # Some modes have awkward interpretations when pad-widths cross certain thresholds.
        # Supported pad-widths are thus limited to sensible regions.
        for i, (M, m, (lhs, rhs)) in enumerate(zip(dim_shape, mode, pad_width)):
            w_max = dict(
                constant=np.inf,
                wrap=M,
                reflect=M - 1,
                symmetric=M,
                edge=np.inf,
            )[m]
            assert max(lhs, rhs) <= w_max, f"pad_width along dim-{i} is limited to {w_max}."

        # Instantiate op & store useful constants
        codim_shape = [M + lhs + rhs for (M, (lhs, rhs)) in zip(dim_shape, pad_width)]
        super().__init__(
            dim_shape=dim_shape,
            codim_shape=codim_shape,
        )
        self._pad_width = pad_width
        self._mode = mode

        # We know a crude Lipschitz bound by default. Since computing it takes (code) space,
        # the estimate is computed as a special case of estimate_lipschitz()
        self.lipschitz = self.estimate_lipschitz(__rule=True)

    @staticmethod
    def _rev_slice(start: int, stop: int) -> slice:
        """
        Reversed slice selecting indices ``start, start - 1, ..., stop + 1``.

        ``slice(start, -1, -1)`` would be empty since Python interprets a negative stop relative to the end of the
        axis: a stop of -1 is therefore mapped to ``None`` so that index 0 is included.
        """
        return slice(start, None if stop < 0 else stop, -1)

    def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
        """
        Pad the trailing `dim_rank` axes of `arr`; leading (stack) axes are left untouched.
        """
        sh = arr.shape[: -self.dim_rank]

        # Part 1: extend the core
        xp = pxu.get_array_module(arr)
        pad_width_sh = ((0, 0),) * len(sh)  # don't pad stack-dims
        out = xp.pad(
            array=arr,
            pad_width=pad_width_sh + self._pad_width,
            mode="constant",
            constant_values=0,
        )

        # Part 2: apply border effects (if any)
        # Axes are processed first-to-last: borders of later axes read the full extent of earlier (already filled)
        # axes, which fills the corner regions.
        for i in range(self.dim_rank, 0, -1):
            mode = self._mode[-i]
            lhs, rhs = self._pad_width[-i]
            N = self.codim_shape[-i]

            r_s = [slice(None)] * (len(sh) + self.dim_rank)  # read axial selector
            w_s = [slice(None)] * (len(sh) + self.dim_rank)  # write axial selector

            if mode == "constant":
                # no border effects
                pass
            elif mode == "wrap":
                if lhs > 0:  # Fix LHS
                    r_s[-i] = slice(N - rhs - lhs, N - rhs)
                    w_s[-i] = slice(0, lhs)
                    out[tuple(w_s)] = out[tuple(r_s)]

                if rhs > 0:  # Fix RHS
                    r_s[-i] = slice(lhs, lhs + rhs)
                    w_s[-i] = slice(N - rhs, N)
                    out[tuple(w_s)] = out[tuple(r_s)]
            elif mode == "reflect":
                if lhs > 0:  # Fix LHS
                    r_s[-i] = self._rev_slice(2 * lhs, lhs)
                    w_s[-i] = slice(0, lhs)
                    out[tuple(w_s)] = out[tuple(r_s)]

                if rhs > 0:  # Fix RHS
                    # stop reaches -1 when (lhs, rhs) == (0, M-1): must run down to index 0.
                    r_s[-i] = self._rev_slice(N - rhs - 2, N - 2 * rhs - 2)
                    w_s[-i] = slice(N - rhs, N)
                    out[tuple(w_s)] = out[tuple(r_s)]
            elif mode == "symmetric":
                if lhs > 0:  # Fix LHS
                    r_s[-i] = self._rev_slice(2 * lhs - 1, lhs - 1)
                    w_s[-i] = slice(0, lhs)
                    out[tuple(w_s)] = out[tuple(r_s)]

                if rhs > 0:  # Fix RHS
                    # stop reaches -1 when (lhs, rhs) == (0, M): must run down to index 0.
                    r_s[-i] = self._rev_slice(N - rhs - 1, N - 2 * rhs - 1)
                    w_s[-i] = slice(N - rhs, N)
                    out[tuple(w_s)] = out[tuple(r_s)]
            elif mode == "edge":
                if lhs > 0:  # Fix LHS
                    r_s[-i] = slice(lhs, lhs + 1)
                    w_s[-i] = slice(0, lhs)
                    out[tuple(w_s)] = out[tuple(r_s)]

                if rhs > 0:  # Fix RHS
                    r_s[-i] = slice(N - rhs - 1, N - rhs)
                    w_s[-i] = slice(N - rhs, N)
                    out[tuple(w_s)] = out[tuple(r_s)]

        return out

    def adjoint(self, arr: pxt.NDArray) -> pxt.NDArray:
        """
        Accumulate each padded entry onto the sample it replicates, then trim the borders.
        """
        sh = arr.shape[: -self.codim_rank]

        # Part 1: apply correction terms (if any)
        out = arr.copy()  # in-place updates below
        for i in range(1, self.codim_rank + 1):
            mode = self._mode[-i]
            lhs, rhs = self._pad_width[-i]
            N = self.codim_shape[-i]

            r_s = [slice(None)] * (len(sh) + self.codim_rank)  # read axial selector
            w_s = [slice(None)] * (len(sh) + self.codim_rank)  # write axial selector

            if mode == "constant":
                # no correction required
                pass
            elif mode == "wrap":
                if lhs > 0:  # Fix LHS
                    r_s[-i] = slice(0, lhs)
                    w_s[-i] = slice(N - rhs - lhs, N - rhs)
                    out[tuple(w_s)] += out[tuple(r_s)]

                if rhs > 0:  # Fix RHS
                    r_s[-i] = slice(N - rhs, N)
                    w_s[-i] = slice(lhs, lhs + rhs)
                    out[tuple(w_s)] += out[tuple(r_s)]
            elif mode == "reflect":
                if lhs > 0:  # Fix LHS
                    r_s[-i] = slice(lhs - 1, None, -1)
                    w_s[-i] = slice(lhs + 1, 2 * lhs + 1)
                    out[tuple(w_s)] += out[tuple(r_s)]

                if rhs > 0:  # Fix RHS
                    r_s[-i] = slice(N - 1, N - rhs - 1, -1)
                    w_s[-i] = slice(N - 2 * rhs - 1, N - rhs - 1)
                    out[tuple(w_s)] += out[tuple(r_s)]
            elif mode == "symmetric":
                if lhs > 0:  # Fix LHS
                    r_s[-i] = slice(lhs - 1, None, -1)
                    w_s[-i] = slice(lhs, 2 * lhs)
                    out[tuple(w_s)] += out[tuple(r_s)]

                if rhs > 0:  # Fix RHS
                    r_s[-i] = slice(N - 1, N - rhs - 1, -1)
                    w_s[-i] = slice(N - 2 * rhs, N - rhs)
                    out[tuple(w_s)] += out[tuple(r_s)]
            elif mode == "edge":
                if lhs > 0:  # Fix LHS
                    r_s[-i] = slice(0, lhs)
                    w_s[-i] = slice(lhs, lhs + 1)
                    out[tuple(w_s)] += out[tuple(r_s)].sum(axis=-i, keepdims=True)

                if rhs > 0:  # Fix RHS
                    r_s[-i] = slice(N - rhs, N)
                    w_s[-i] = slice(N - rhs - 1, N - rhs)
                    out[tuple(w_s)] += out[tuple(r_s)].sum(axis=-i, keepdims=True)

        # Part 2: extract the core
        selector = [slice(None)] * len(sh)
        for N, (lhs, rhs) in zip(self.codim_shape, self._pad_width):
            selector.append(slice(lhs, N - rhs))
        out = out[tuple(selector)]

        return out

    @staticmethod
    def _lipschitz_1d(M: int, mode: str, lhs: int, rhs: int) -> pxt.Real:
        """
        Lipschitz constant of a 1D padding of an axis of size `M` by (`lhs`, `rhs`) in `mode`.

        Values follow the closed-form expressions listed in the class Notes.
        """
        if mode == "constant":
            return 1
        elif mode in {"wrap", "symmetric"}:
            return np.sqrt(1 + np.ceil((lhs + rhs) / M))
        elif mode == "reflect":
            if M >= 3:
                return np.sqrt(1 + np.ceil((lhs + rhs) / (M - 2)))
            # M in {1, 2}: the (M-2) denominator degenerates, but LHS/RHS then replicate disjoint samples, so
            # each sample appears at most twice in the output.
            return np.sqrt(1 + min(1, lhs + rhs))
        elif mode == "edge":
            # With a single sample both borders replicate it, so their widths add up.
            p = (lhs + rhs) if (M == 1) else max(lhs, rhs)
            return np.sqrt(1 + p)
        else:  # unreachable: modes are validated in __init__()
            raise ValueError(f"Unknown mode: {mode}.")

    def estimate_lipschitz(self, **kwargs) -> pxt.Real:
        """
        Lipschitz constant estimate.

        When called with the private ``__rule`` keyword, returns the closed-form bound (product of per-axis 1D
        constants, see class Notes); otherwise defers to the generic estimator of the base class.
        """
        no_eval = "__rule" in kwargs
        if no_eval:
            L = np.prod(
                [
                    self._lipschitz_1d(M, m, lhs, rhs)
                    for (M, m, (lhs, rhs)) in zip(self.dim_shape, self._mode, self._pad_width)
                ]
            )
        else:
            L = super().estimate_lipschitz(**kwargs)
        return L

    def gram(self) -> pxt.OpT:
        # Zero-padding only: A^* A = I.
        if all(m == "constant" for m in self._mode):
            from pyxu.operator import IdentityOp

            op = IdentityOp(dim_shape=self.dim_shape)
        else:
            op = super().gram()
        return op

    def cogram(self) -> pxt.OpT:
        # Zero-padding only: A A^* zeroes the borders, i.e. equals (Trim^* Trim).
        if all(m == "constant" for m in self._mode):
            from pyxu.operator import Trim

            # Orthogonal projection
            op = Trim(
                dim_shape=self.codim_shape,
                trim_width=self._pad_width,
            ).gram()
        else:
            op = super().cogram()
        return op