Source code for pyxu.operator.linop.pad

import collections.abc as cabc
import typing as typ

import numpy as np

import pyxu.abc as pxa
import pyxu.info.ptype as pxt
import pyxu.util as pxu

__all__ = [
    "Pad",
]


class Pad(pxa.LinOp):
    r"""
    Multi-dimensional padding operator.

    This operator pads the input array in each dimension according to specified widths.

    Notes
    -----
    * If inputs are D-dimensional, then some of the padding of later axes is calculated from the padding of
      previous axes.
    * The *adjoint* of the padding operator performs a cumulative summation over the original positions used to pad.
      Its effect is best understood from its matrix form. For example, the matrix form of
      ``Pad(dim_shape=(3,), mode="wrap", pad_width=(1, 1))`` is:

      .. math::

         \mathbf{A}
         =
         \left[
            \begin{array}{ccc}
                0 & 0 & 1 \\
                1 & 0 & 0 \\
                0 & 1 & 0 \\
                0 & 0 & 1 \\
                1 & 0 & 0
            \end{array}
         \right].

      The adjoint of :math:`\mathbf{A}` corresponds to its matrix transpose:

      .. math::

         \mathbf{A}^{\ast}
         =
         \left[
            \begin{array}{ccccc}
                0 & 1 & 0 & 0 & 1 \\
                0 & 0 & 1 & 0 & 0 \\
                1 & 0 & 0 & 1 & 0
            \end{array}
         \right].

      This operation can be seen as a trimming (:math:`\mathbf{T}`) plus a cumulative summation
      (:math:`\mathbf{S}`):

      .. math::

         \mathbf{A}^{\ast}
         =
         \mathbf{T} + \mathbf{S}
         =
         \left[
            \begin{array}{ccccc}
                0 & 1 & 0 & 0 & 0 \\
                0 & 0 & 1 & 0 & 0 \\
                0 & 0 & 0 & 1 & 0
            \end{array}
         \right]
         +
         \left[
            \begin{array}{ccccc}
                0 & 0 & 0 & 0 & 1 \\
                0 & 0 & 0 & 0 & 0 \\
                1 & 0 & 0 & 0 & 0
            \end{array}
         \right],

      where both :math:`\mathbf{T}` and :math:`\mathbf{S}` are efficiently implemented in matrix-free form.

    * The Lipschitz constant of the multi-dimensional padding operator is upper-bounded by the product of Lipschitz
      constants of the uni-dimensional paddings applied per dimension, i.e.:

      .. math::

         L \le \prod_{i} L_{i}, \qquad i \in \{1, \ldots, D\},

      where :math:`L_{i}` depends on the boundary condition at the :math:`i`-th axis.

      :math:`L_{i}^{2}` corresponds to the maximum singular value of the diagonal matrix

      .. math::

         \mathbf{A}_{i}^{\ast} \mathbf{A}_{i}
         =
         \mathbf{T}_{i}^{\ast} \mathbf{T}_{i} + \mathbf{S}_{i}^{\ast} \mathbf{S}_{i}
         =
         \mathbf{I}_{N} + \mathbf{S}_{i}^{\ast} \mathbf{S}_{i}.

      - In mode="constant", :math:`\text{diag}(\mathbf{S}_{i}^{\ast} \mathbf{S}_{i}) = \mathbf{0}`, hence
        :math:`L_{i} = 1`.
      - In mode="edge",

        .. math::

           \text{diag}(\mathbf{S}_{i}^{\ast} \mathbf{S}_{i})
           =
           \left[p_{lhs}, 0, \ldots, 0, p_{rhs} \right],

        hence :math:`L_{i} = \sqrt{1 + \max(p_{lhs}, p_{rhs})}`.
      - In mode="symmetric", "wrap", "reflect", :math:`\text{diag}(\mathbf{S}_{i}^{\ast} \mathbf{S}_{i})` equals
        (up to a mode-dependent permutation)

        .. math::

           \text{diag}(\mathbf{S}_{i}^{\ast} \mathbf{S}_{i})
           =
           \left[1, \ldots, 1, 0, \ldots, 0\right]
           +
           \left[0, \ldots, 0, 1, \ldots, 1\right],

        hence

        .. math::

           L^{\text{wrap, symmetric}}_{i} = \sqrt{1 + \lceil\frac{p_{lhs} + p_{rhs}}{N}\rceil}, \\
           L^{\text{reflect}}_{i} = \sqrt{1 + \lceil\frac{p_{lhs} + p_{rhs}}{N-2}\rceil}.
    """

    WidthSpec = typ.Union[
        pxt.Integer,
        cabc.Sequence[pxt.Integer],
        cabc.Sequence[tuple[pxt.Integer, pxt.Integer]],
    ]
    ModeSpec = typ.Union[str, cabc.Sequence[str]]

    def __init__(
        self,
        dim_shape: pxt.NDArrayShape,
        pad_width: WidthSpec,
        mode: ModeSpec = "constant",
    ):
        r"""
        Parameters
        ----------
        dim_shape: NDArrayShape
            (M1,...,MD) domain dimensions.
        pad_width: ~pyxu.operator.linop.pad.Pad.WidthSpec
            Number of values padded to the edges of each axis.
            Multiple forms are accepted:

            * ``int``: pad each dimension's head/tail by `pad_width`.
            * ``tuple[int, ...]``: pad dimension[k]'s head/tail by `pad_width[k]`.
            * ``tuple[tuple[int, int], ...]``: pad dimension[k]'s head/tail by `pad_width[k][0]` /
              `pad_width[k][1]` respectively.
        mode: str, :py:class:`list` ( str )
            Padding mode.
            Multiple forms are accepted:

            * str: unique mode shared amongst dimensions.
              Must be one of:

              * 'constant' (zero-padding)
              * 'wrap'
              * 'reflect'
              * 'symmetric'
              * 'edge'
            * tuple[str, ...]: pad dimension[k] using `mode[k]`.

            (See :py:func:`numpy.pad` for details.)
        """
        dim_shape = pxu.as_canonical_shape(dim_shape)
        dim_rank = len(dim_shape)

        # transform `pad_width` to canonical form tuple[tuple[int, int], ...]
        is_seq = lambda _: isinstance(_, cabc.Sequence)
        if not is_seq(pad_width):  # int-form
            pad_width = ((pad_width, pad_width),) * dim_rank
        assert len(pad_width) == dim_rank, "dim_shape/pad_width are length-mismatched."
        if not is_seq(pad_width[0]):  # tuple[int, ...] form
            pad_width = tuple((w, w) for w in pad_width)
        else:  # tuple[tuple[int, int], ...] form
            pass
        assert all(0 <= min(lhs, rhs) for (lhs, rhs) in pad_width)
        pad_width = tuple(pad_width)

        # transform `mode` to canonical form tuple[str, ...]
        if isinstance(mode, str):  # shared mode
            mode = (mode,) * dim_rank
        elif isinstance(mode, cabc.Sequence):  # tuple[str, ...]: different modes
            assert len(mode) == dim_rank, "dim_shape/mode are length-mismatched."
            mode = tuple(mode)
        else:
            raise ValueError(f"Unknown mode encountered: {mode}.")
        mode = tuple(map(lambda _: _.strip().lower(), mode))
        assert set(mode) <= {
            "constant",
            "wrap",
            "reflect",
            "symmetric",
            "edge",
        }, "Unknown mode(s) encountered."

        # Some modes have awkward interpretations when pad-widths cross certain thresholds.
        # Supported pad-widths are thus limited to sensible regions.
        for i in range(dim_rank):
            M = dim_shape[i]
            w_max = dict(
                constant=np.inf,
                wrap=M,
                reflect=M - 1,
                symmetric=M,
                edge=np.inf,
            )[mode[i]]
            lhs, rhs = pad_width[i]
            assert max(lhs, rhs) <= w_max, f"pad_width along dim-{i} is limited to {w_max}."

        # Instantiate op & store useful constants
        codim_shape = list(dim_shape)
        for i, (lhs, rhs) in enumerate(pad_width):
            codim_shape[i] += lhs + rhs
        super().__init__(
            dim_shape=dim_shape,
            codim_shape=codim_shape,
        )
        self._pad_width = pad_width
        self._mode = mode

        # We know a crude Lipschitz bound by default. Since computing it takes (code) space,
        # the estimate is computed as a special case of estimate_lipschitz().
        self.lipschitz = self.estimate_lipschitz(__rule=True)

    def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
        sh = arr.shape[: -self.dim_rank]

        # Part 1: extend the core
        xp = pxu.get_array_module(arr)
        pad_width_sh = ((0, 0),) * len(sh)  # don't pad stack-dims
        out = xp.pad(
            array=arr,
            pad_width=pad_width_sh + self._pad_width,
            mode="constant",
            constant_values=0,
        )

        # Part 2: apply border effects (if any)
        for i in range(self.dim_rank, 0, -1):
            mode = self._mode[-i]
            lhs, rhs = self._pad_width[-i]
            N = self.codim_shape[-i]

            r_s = [slice(None)] * (len(sh) + self.dim_rank)  # read axial selector
            w_s = [slice(None)] * (len(sh) + self.dim_rank)  # write axial selector

            if mode == "constant":
                # no border effects
                pass
            elif mode == "wrap":
                if lhs > 0:  # Fix LHS
                    r_s[-i] = slice(N - rhs - lhs, N - rhs)
                    w_s[-i] = slice(0, lhs)
                    out[tuple(w_s)] = out[tuple(r_s)]

                if rhs > 0:  # Fix RHS
                    r_s[-i] = slice(lhs, lhs + rhs)
                    w_s[-i] = slice(N - rhs, N)
                    out[tuple(w_s)] = out[tuple(r_s)]
            elif mode == "reflect":
                if lhs > 0:  # Fix LHS
                    r_s[-i] = slice(2 * lhs, lhs, -1)
                    w_s[-i] = slice(0, lhs)
                    out[tuple(w_s)] = out[tuple(r_s)]

                if rhs > 0:  # Fix RHS
                    r_s[-i] = slice(N - rhs - 2, N - 2 * rhs - 2, -1)
                    w_s[-i] = slice(N - rhs, N)
                    out[tuple(w_s)] = out[tuple(r_s)]
            elif mode == "symmetric":
                if lhs > 0:  # Fix LHS
                    r_s[-i] = slice(2 * lhs - 1, lhs - 1, -1)
                    w_s[-i] = slice(0, lhs)
                    out[tuple(w_s)] = out[tuple(r_s)]

                if rhs > 0:  # Fix RHS
                    r_s[-i] = slice(N - rhs - 1, N - 2 * rhs - 1, -1)
                    w_s[-i] = slice(N - rhs, N)
                    out[tuple(w_s)] = out[tuple(r_s)]
            elif mode == "edge":
                if lhs > 0:  # Fix LHS
                    r_s[-i] = slice(lhs, lhs + 1)
                    w_s[-i] = slice(0, lhs)
                    out[tuple(w_s)] = out[tuple(r_s)]

                if rhs > 0:  # Fix RHS
                    r_s[-i] = slice(N - rhs - 1, N - rhs)
                    w_s[-i] = slice(N - rhs, N)
                    out[tuple(w_s)] = out[tuple(r_s)]

        return out

    def adjoint(self, arr: pxt.NDArray) -> pxt.NDArray:
        sh = arr.shape[: -self.codim_rank]

        # Part 1: apply correction terms (if any)
        out = arr.copy()  # in-place updates below
        for i in range(1, self.codim_rank + 1):
            mode = self._mode[-i]
            lhs, rhs = self._pad_width[-i]
            N = self.codim_shape[-i]

            r_s = [slice(None)] * (len(sh) + self.codim_rank)  # read axial selector
            w_s = [slice(None)] * (len(sh) + self.codim_rank)  # write axial selector

            if mode == "constant":
                # no correction required
                pass
            elif mode == "wrap":
                if lhs > 0:  # Fix LHS
                    r_s[-i] = slice(0, lhs)
                    w_s[-i] = slice(N - rhs - lhs, N - rhs)
                    out[tuple(w_s)] += out[tuple(r_s)]

                if rhs > 0:  # Fix RHS
                    r_s[-i] = slice(N - rhs, N)
                    w_s[-i] = slice(lhs, lhs + rhs)
                    out[tuple(w_s)] += out[tuple(r_s)]
            elif mode == "reflect":
                if lhs > 0:  # Fix LHS
                    r_s[-i] = slice(lhs - 1, None, -1)
                    w_s[-i] = slice(lhs + 1, 2 * lhs + 1)
                    out[tuple(w_s)] += out[tuple(r_s)]

                if rhs > 0:  # Fix RHS
                    r_s[-i] = slice(N - 1, N - rhs - 1, -1)
                    w_s[-i] = slice(N - 2 * rhs - 1, N - rhs - 1)
                    out[tuple(w_s)] += out[tuple(r_s)]
            elif mode == "symmetric":
                if lhs > 0:  # Fix LHS
                    r_s[-i] = slice(lhs - 1, None, -1)
                    w_s[-i] = slice(lhs, 2 * lhs)
                    out[tuple(w_s)] += out[tuple(r_s)]

                if rhs > 0:  # Fix RHS
                    r_s[-i] = slice(N - 1, N - rhs - 1, -1)
                    w_s[-i] = slice(N - 2 * rhs, N - rhs)
                    out[tuple(w_s)] += out[tuple(r_s)]
            elif mode == "edge":
                if lhs > 0:  # Fix LHS
                    r_s[-i] = slice(0, lhs)
                    w_s[-i] = slice(lhs, lhs + 1)
                    out[tuple(w_s)] += out[tuple(r_s)].sum(axis=-i, keepdims=True)

                if rhs > 0:  # Fix RHS
                    r_s[-i] = slice(N - rhs, N)
                    w_s[-i] = slice(N - rhs - 1, N - rhs)
                    out[tuple(w_s)] += out[tuple(r_s)].sum(axis=-i, keepdims=True)

        # Part 2: extract the core
        selector = [slice(None)] * len(sh)
        for N, (lhs, rhs) in zip(self.codim_shape, self._pad_width):
            s = slice(lhs, N - rhs)
            selector.append(s)
        out = out[tuple(selector)]

        return out

    def estimate_lipschitz(self, **kwargs) -> pxt.Real:
        no_eval = "__rule" in kwargs
        if no_eval:
            L = []  # 1D pad-op Lipschitz constants
            for M, m, (lhs, rhs) in zip(self.dim_shape, self._mode, self._pad_width):
                if m == "constant":
                    _L = 1
                elif m in {"wrap", "symmetric"}:
                    _L = np.sqrt(1 + np.ceil((lhs + rhs) / M))
                elif m == "reflect":
                    _L = np.sqrt(1 + np.ceil((lhs + rhs) / (M - 2)))
                elif m == "edge":
                    _L = np.sqrt(1 + max(lhs, rhs))
                L.append(_L)
            L = np.prod(L)
        else:
            L = super().estimate_lipschitz(**kwargs)
        return L

    def gram(self) -> pxt.OpT:
        if all(m == "constant" for m in self._mode):
            from pyxu.operator import IdentityOp

            op = IdentityOp(dim_shape=self.dim_shape)
        else:
            op = super().gram()
        return op

    def cogram(self) -> pxt.OpT:
        if all(m == "constant" for m in self._mode):
            from pyxu.operator import Trim

            # Orthogonal projection
            op = Trim(
                dim_shape=self.codim_shape,
                trim_width=self._pad_width,
            ).gram()
        else:
            op = super().cogram()
        return op
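

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the pyxu source): the three accepted
# `pad_width` forms documented in __init__ request the same padding on a 2D
# domain, and hence yield the same codomain shape.
if __name__ == "__main__":
    a = Pad(dim_shape=(4, 5), pad_width=2)                 # int form
    b = Pad(dim_shape=(4, 5), pad_width=(2, 2))            # per-axis width
    c = Pad(dim_shape=(4, 5), pad_width=((2, 2), (2, 2)))  # per-axis (lhs, rhs)
    assert tuple(a.codim_shape) == tuple(b.codim_shape) == tuple(c.codim_shape) == (8, 9)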
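
# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the pyxu source): reproduces the
# matrix example from the class docstring on a concrete NumPy vector, i.e.
# "wrap" padding of a length-3 input by one sample on each side, followed by
# the adjoint (trim + cumulative summation).
if __name__ == "__main__":
    op = Pad(dim_shape=(3,), pad_width=1, mode="wrap")  # head/tail padded by 1
    x = np.array([1.0, 2.0, 3.0])
    y = op.apply(x)    # A x   -> [3., 1., 2., 3., 1.]
    z = op.adjoint(y)  # A* y  -> [2., 2., 6.]
    print(y, z)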
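
# ---------------------------------------------------------------------------
# Illustrative checks (not part of the pyxu source): exercise two properties
# documented in the Notes above on small NumPy inputs:
# (1) with mode="constant", gram() reduces to the identity on the domain;
# (2) with mode="edge", the crude Lipschitz rule evaluates to
#     sqrt(1 + max(p_lhs, p_rhs)).
if __name__ == "__main__":
    x = np.array([1.0, 2.0, 3.0])

    op_c = Pad(dim_shape=(3,), pad_width=1, mode="constant")
    assert np.allclose(op_c.gram().apply(x), x)  # zero-pad then trim is a no-op

    op_e = Pad(dim_shape=(3,), pad_width=((1, 2),), mode="edge")  # lhs=1, rhs=2
    assert np.isclose(op_e.estimate_lipschitz(__rule=True), np.sqrt(3))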