Source code for pyxu.operator.linop.select

  1import collections.abc as cabc
  2import typing as typ
  3
  4import numpy as np
  5
  6import pyxu.abc as pxa
  7import pyxu.info.deps as pxd
  8import pyxu.info.ptype as pxt
  9import pyxu.operator.interop.source as px_src
 10import pyxu.util as pxu
 11
 12__all__ = [
 13    "SubSample",
 14    "Trim",
 15]
 16
 17
class SubSample(pxa.LinOp):
    r"""
    Multi-dimensional sub-sampling operator.

    This operator extracts a subset of the input matching the provided subset specifier.
    Its Lipschitz constant is 1.

    Examples
    --------
    .. code-block:: python3

       ### Extract even samples of a 1D signal.
       import numpy as np
       import pyxu.operator as pxo
       x = np.arange(10)
       S = pxo.SubSample(
             x.shape,
             slice(0, None, 2),
       )
       y = S(x)  # array([0, 2, 4, 6, 8])


    .. code-block:: python3

       ### Extract columns[1, 3, -1] from a 2D matrix
       import numpy as np
       import pyxu.operator as pxo
       x = np.arange(3 * 40).reshape(3, 40)  # the input
       S = pxo.SubSample(
             x.shape,
             slice(None),  # take all rows
             [1, 3, -1],   # and these columns
       )
       y = S(x)  # array([[  1,   3,  39],
                 #        [ 41,  43,  79],
                 #        [ 81,  83, 119]])

    .. code-block:: python3

       ### Extract all red rows of an (D,H,W) RGB image matching a boolean mask.
       import numpy as np
       import pyxu.operator as pxo
       x = np.arange(3 * 5 * 4).reshape(3, 5, 4)
       mask = np.r_[True, False, False, True, False]
       S = pxo.SubSample(
             x.shape,
             0,            # red channel
             mask,         # row selector
             slice(None),  # all columns; this field can be omitted.
       )
       y = S(x)  # array([[[ 0,  1,  2,  3],
                 #         [12, 13, 14, 15]]])
    """

    IndexSpec = typ.Union[
        pxt.Integer,
        cabc.Sequence[pxt.Integer],
        cabc.Sequence[bool],
        slice,
    ]

    TrimSpec = typ.Union[
        pxt.Integer,
        cabc.Sequence[pxt.Integer],
        cabc.Sequence[tuple[pxt.Integer, pxt.Integer]],
    ]

    def __init__(
        self,
        dim_shape: pxt.NDArrayShape,
        *indices: IndexSpec,
    ):
        """
        Parameters
        ----------
        dim_shape: NDArrayShape
            (M1,...,MD) domain dimensions.
        indices: ~pyxu.operator.SubSample.IndexSpec
            Sub-sample specifier per dimension. (See examples.)

            Valid specifiers are:

            * integers
            * 1D sequence of int/bool-s
            * slices

            Unspecified trailing dimensions are not sub-sampled.

        Raises
        ------
        IndexError
            If an integer specifier lies outside ``[-Mk, Mk)`` for its axis.

        Notes
        -----
        The co-dimension rank **always** matches the dimension rank, i.e. sub-sampling does not drop dimensions.
        Single-element dimensions can be removed by composing :py:class:`~pyxu.operator.SubSample` with
        :py:class:`~pyxu.operator.SqueezeAxes`.

        Integer specifiers are converted to length-1 slices, which is what keeps the rank intact.
        Sequence specifiers are forwarded as-is to array indexing. If several axes receive a sequence, NumPy
        advanced-indexing rules apply and the sequences are broadcast together, which may drop dimensions.
        Use at most one sequence specifier if rank preservation matters.

        The adjoint writes the input back with plain item assignment.
        NOTE(review): repeated entries in a sequence specifier are therefore not accumulated by
        :py:meth:`adjoint`, and the :py:meth:`cogram` / Lipschitz claims assume unique indices.
        """
        super().__init__(
            dim_shape=dim_shape,
            codim_shape=dim_shape,  # temporary; just to validate dim_shape
        )
        assert 1 <= len(indices) <= self.dim_rank

        # Explicitize missing trailing indices.
        idx = [slice(None)] * self.dim_rank
        for i, _idx in enumerate(indices):
            idx[i] = _idx

        # Replace integer indices with length-1 slices so the axis is kept in the output.
        for i, _idx in enumerate(idx):
            if isinstance(_idx, pxt.Integer):
                M = self.dim_shape[i]
                # Validate before normalizing: a bare modulo would silently wrap out-of-range indices.
                if not (-M <= _idx < M):
                    raise IndexError(f"index {_idx} is out of bounds for axis {i} with size {M}")
                _idx = _idx % M  # get rid of negative indices
                idx[i] = slice(_idx, _idx + 1)

        # Compute output shape, then re-instantiate `self`.
        self._idx = tuple(idx)
        # Index a zero-copy broadcast view to infer the output shape without allocating dim_shape memory.
        out = np.broadcast_to(0, self.dim_shape)[self._idx]
        super().__init__(
            dim_shape=dim_shape,
            codim_shape=out.shape,
        )
        self.lipschitz = 1

    def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
        """
        Sub-sample the data.

        Parameters
        ----------
        arr: NDArray
            (..., M1,...,MD) data points.

        Returns
        -------
        out: NDArray
            (..., N1,...,ND) sub-sampled data points.
        """
        sh = arr.shape[: -self.dim_rank]  # leading stacking dimensions

        # Leave stacking dimensions untouched; sub-sample only the trailing core dimensions.
        selector = ((slice(None),) * len(sh)) + self._idx
        out = arr[selector]
        # Basic indexing may return a view of `arr`: mark it read-only to prevent aliasing writes.
        return pxu.read_only(out)

    def adjoint(self, arr: pxt.NDArray) -> pxt.NDArray:
        """
        Up-sample the data.

        Parameters
        ----------
        arr: NDArray
            (..., N1,...,ND) data points.

        Returns
        -------
        out: NDArray
            (..., M1,...,MD) up-sampled data points. (Zero-filled.)
        """
        sh = arr.shape[: -self.codim_rank]  # leading stacking dimensions

        ndi = pxd.NDArrayInfo.from_obj(arr)
        kwargs = dict(
            shape=(*sh, *self.dim_shape),
            dtype=arr.dtype,
        )
        if ndi == pxd.NDArrayInfo.DASK:
            # Keep the input's chunking along stacking dimensions; let Dask pick core chunks.
            stack_chunks = arr.chunks[: -self.codim_rank]
            core_chunks = ("auto",) * self.dim_rank
            kwargs.update(chunks=stack_chunks + core_chunks)
        out = ndi.module().zeros(**kwargs)

        # Scatter the input back into the sub-sampled positions; everything else stays zero.
        selector = ((slice(None),) * len(sh)) + self._idx
        out[selector] = arr
        return out

    def svdvals(self, **kwargs) -> pxt.NDArray:
        # Rows of a (unique-index) sub-sampling matrix are orthonormal: all singular values equal 1.
        D = pxa.UnitOp.svdvals(self, **kwargs)
        return D

    def gram(self) -> pxt.OpT:
        # S^* S zero-fills everything outside the selected set, i.e. an orthogonal projection.
        def op_apply(_, arr: pxt.NDArray) -> pxt.NDArray:
            _op = _._op
            out = _op.adjoint(_op.apply(arr))
            return out

        G = px_src.from_source(
            cls=pxa.OrthProjOp,
            dim_shape=self.dim_shape,
            codim_shape=self.dim_shape,
            embed=dict(_op=self),
            apply=op_apply,
        )
        return G

    def cogram(self) -> pxt.OpT:
        from pyxu.operator import IdentityOp

        # S S^* re-extracts exactly what the adjoint scattered: identity on the co-domain.
        CG = IdentityOp(dim_shape=self.codim_shape)
        return CG

    def pinv(self, arr: pxt.NDArray, damp: pxt.Real, **kwargs) -> pxt.NDArray:
        # (S^* S + damp I)^{-1} S^* y = S^* y / (1 + damp), since S^* y lies in the range of the projection S^* S.
        out = self.adjoint(arr)
        out /= 1 + damp  # in-place is safe: adjoint() returns a freshly-allocated array
        return out

    def dagger(self, damp: pxt.Real, **kwargs) -> pxt.OpT:
        # Operator form of pinv(): a scaled adjoint.
        op = self.T / (1 + damp)
        return op
219 220
def Trim(
    dim_shape: pxt.NDArrayShape,
    trim_width: SubSample.TrimSpec,
) -> pxt.OpT:
    """
    Multi-dimensional trimming operator.

    This operator trims the input array in each dimension according to specified widths.

    Parameters
    ----------
    dim_shape: NDArrayShape
        (M1,...,MD) domain dimensions.
    trim_width: ~pyxu.operator.SubSample.TrimSpec
        Number of values trimmed from the edges of each axis.
        Multiple forms are accepted:

        * ``int``: trim each dimension's head/tail by `trim_width`.
        * ``tuple[int, ...]``: trim dimension[k]'s head/tail by `trim_width[k]`.
        * ``tuple[tuple[int, int], ...]``: trim dimension[k]'s head/tail by `trim_width[k][0]` / `trim_width[k][1]`
          respectively.

        Per-axis entries of the last two forms may be mixed, e.g. ``(1, (2, 0))``.
        NumPy arrays are accepted and interpreted like the equivalent (nested) lists.

    Returns
    -------
    op: OpT
        :py:class:`~pyxu.operator.SubSample` instance extracting the trimmed region.

    Raises
    ------
    ValueError
        If a trim width is negative, or if the head + tail widths of an axis exceed its size.
    """
    dim_shape = pxu.as_canonical_shape(dim_shape)
    N_dim = len(dim_shape)

    # NumPy arrays are not `cabc.Sequence`s: convert them to (nested) lists / scalars first.
    if isinstance(trim_width, np.ndarray):
        trim_width = trim_width.tolist()

    is_seq = lambda _: isinstance(_, cabc.Sequence)
    if not is_seq(trim_width):  # int-form: same width on every axis
        trim_width = (trim_width,) * N_dim
    assert len(trim_width) == N_dim, "dim_shape/trim_width are length-mismatched."

    # translate `trim_width` to `indices` needed for SubSample, normalizing each axis independently
    indices = []
    for axis, (w, dim_size) in enumerate(zip(trim_width, dim_shape)):
        w_head, w_tail = w if is_seq(w) else (w, w)
        if w_head < 0 or w_tail < 0:
            raise ValueError(f"trim widths must be non-negative: got ({w_head}, {w_tail}) for axis {axis}.")
        if w_head + w_tail > dim_size:
            raise ValueError(
                f"trim widths ({w_head}, {w_tail}) exceed the size {dim_size} of axis {axis}."
            )
        # Validation guarantees a non-negative stop, so the slice cannot wrap around from the end.
        indices.append(slice(w_head, dim_size - w_tail))

    op = SubSample(dim_shape, *indices)
    return op