1import collections.abc as cabc
2import typing as typ
3
4import numpy as np
5
6import pyxu.abc as pxa
7import pyxu.info.deps as pxd
8import pyxu.info.ptype as pxt
9import pyxu.operator.interop.source as px_src
10import pyxu.util as pxu
11
# Names exported by a star-import of this module.
__all__ = ["SubSample", "Trim"]
16
17
[docs]
class SubSample(pxa.LinOp):
    r"""
    Multi-dimensional sub-sampling operator.

    This operator extracts a subset of the input matching the provided subset specifier.
    Its Lipschitz constant is 1.

    Examples
    --------
    .. code-block:: python3

       ### Extract even samples of a 1D signal.
       import pyxu.operator as pxo
       x = np.arange(10)
       S = pxo.SubSample(
             x.shape,
             slice(0, None, 2),
       )
       y = S(x)  # array([0, 2, 4, 6, 8])


    .. code-block:: python3

       ### Extract columns[1, 3, -1] from a 2D matrix
       import pyxu.operator as pxo
       x = np.arange(3 * 40).reshape(3, 40)  # the input
       S = pxo.SubSample(
             x.shape,
             slice(None),  # take all rows
             [1, 3, -1],   # and these columns
       )
       y = S(x)  # array([[  1,   3,  39],
                 #        [ 41,  43,  79],
                 #        [ 81,  83, 119]])

    .. code-block:: python3

       ### Extract all red rows of an (D,H,W) RGB image matching a boolean mask.
       import pyxu.operator as pxo
       x = np.arange(3 * 5 * 4).reshape(3, 5, 4)
       mask = np.r_[True, False, False, True, False]
       S = pxo.SubSample(
             x.shape,
             0,            # red channel
             mask,         # row selector
             slice(None),  # all columns; this field can be omitted.
       )
       y = S(x)  # array([[[ 0,  1,  2,  3],
                 #         [12, 13, 14, 15]]])
    """

    # Accepted per-dimension sub-sampling specifiers.
    IndexSpec = typ.Union[
        pxt.Integer,
        cabc.Sequence[pxt.Integer],
        cabc.Sequence[bool],
        slice,
    ]

    # Accepted trim-width specifiers (consumed by `Trim`).
    TrimSpec = typ.Union[
        pxt.Integer,
        cabc.Sequence[pxt.Integer],
        cabc.Sequence[tuple[pxt.Integer, pxt.Integer]],
    ]

    def __init__(
        self,
        dim_shape: pxt.NDArrayShape,
        *indices: IndexSpec,
    ):
        """
        Parameters
        ----------
        dim_shape: NDArrayShape
            (M1,...,MD) domain dimensions.
        indices: ~pyxu.operator.SubSample.IndexSpec
            Sub-sample specifier per dimension. (See examples.)

            Valid specifiers are:

            * integers
            * 1D sequence of int/bool-s
            * slices

            Unspecified trailing dimensions are not sub-sampled.

        Raises
        ------
        AssertionError
            If the number of specifiers is not in [1, D].
        IndexError
            If an integer specifier lies outside [-Mk, Mk) for its dimension.

        Notes
        -----
        The co-dimension rank **always** matches the dimension rank, i.e. sub-sampling does not drop dimensions.
        Single-element dimensions can be removed by composing :py:class:`~pyxu.operator.SubSample` with
        :py:class:`~pyxu.operator.SqueezeAxes`.
        """
        super().__init__(
            dim_shape=dim_shape,
            codim_shape=dim_shape,  # temporary; just to validate dim_shape
        )
        assert 1 <= len(indices) <= self.dim_rank, "Expected between 1 and dim_rank index specifiers."

        # Explicitize missing trailing indices.
        idx = [slice(None)] * self.dim_rank
        for i, _idx in enumerate(indices):
            idx[i] = _idx

        # Replace integer indices with length-1 slices so that the axis is preserved.
        for i, _idx in enumerate(idx):
            if isinstance(_idx, pxt.Integer):
                M = self.dim_shape[i]
                # Out-of-range sequence entries raise IndexError during shape inference below;
                # reject out-of-range integers likewise instead of letting them wrap modulo M.
                if not (-M <= _idx < M):
                    raise IndexError(f"index {_idx} is out of bounds for axis {i} with size {M}")
                _idx = _idx % M  # get rid of negative indices
                idx[i] = slice(_idx, _idx + 1)

        # Compute output shape, then re-instantiate `self`.
        # Indexing a broadcast scalar yields the output shape without allocating a full-size array.
        self._idx = tuple(idx)
        out = np.broadcast_to(0, self.dim_shape)[self._idx]
        super().__init__(
            dim_shape=dim_shape,
            codim_shape=out.shape,
        )
        self.lipschitz = 1

    def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
        """
        Sub-sample the data.

        Parameters
        ----------
        arr: NDArray
            (..., M1,...,MD) data points.

        Returns
        -------
        out: NDArray
            (..., N1,..,NK) sub-sampled data points.
        """
        sh = arr.shape[: -self.dim_rank]  # leading (stacking) dimensions

        # Stacking dimensions are taken whole; core dimensions are sub-sampled.
        selector = ((slice(None),) * len(sh)) + self._idx
        out = arr[selector]
        return pxu.read_only(out)

    def adjoint(self, arr: pxt.NDArray) -> pxt.NDArray:
        """
        Up-sample the data.

        Parameters
        ----------
        arr: NDArray
            (..., N1,...,NK) data points.

        Returns
        -------
        out: NDArray
            (..., M1,...,MD) up-sampled data points. (Zero-filled.)
        """
        sh = arr.shape[: -self.codim_rank]  # leading (stacking) dimensions

        ndi = pxd.NDArrayInfo.from_obj(arr)
        kwargs = dict(
            shape=(*sh, *self.dim_shape),
            dtype=arr.dtype,
        )
        if ndi == pxd.NDArrayInfo.DASK:
            # Keep the input's chunking on stacking dimensions; let the backend choose core chunks.
            stack_chunks = arr.chunks[: -self.codim_rank]
            core_chunks = ("auto",) * self.dim_rank
            kwargs.update(chunks=stack_chunks + core_chunks)
        out = ndi.module().zeros(**kwargs)

        # Scatter the samples back to the positions they were read from.
        # NOTE(review): this is plain assignment, not accumulation -- it assumes the specifiers
        # select each input element at most once; confirm for sequence specifiers with repeats.
        selector = ((slice(None),) * len(sh)) + self._idx
        out[selector] = arr
        return out

    def svdvals(self, **kwargs) -> pxt.NDArray:
        """
        Compute singular values by delegating to :py:meth:`pyxu.abc.UnitOp.svdvals`.

        Parameters
        ----------
        **kwargs
            Forwarded unchanged to ``UnitOp.svdvals``.
        """
        D = pxa.UnitOp.svdvals(self, **kwargs)
        return D

    def gram(self) -> pxt.OpT:
        """
        Build the Gram operator, applied as ``adjoint(apply(arr))``.

        Returns
        -------
        G: OpT
            :py:class:`~pyxu.abc.OrthProjOp` with dimension and co-dimension shape (M1,...,MD).
        """

        def op_apply(_, arr: pxt.NDArray) -> pxt.NDArray:
            # `_` is the generated operator; `_op` is the SubSample instance embedded below.
            _op = _._op
            out = _op.adjoint(_op.apply(arr))
            return out

        G = px_src.from_source(
            cls=pxa.OrthProjOp,
            dim_shape=self.dim_shape,
            codim_shape=self.dim_shape,
            embed=dict(_op=self),
            apply=op_apply,
        )
        return G

    def cogram(self) -> pxt.OpT:
        """
        Build the co-Gram operator: the identity over the co-domain.

        Returns
        -------
        CG: OpT
            :py:class:`~pyxu.operator.IdentityOp` of shape (N1,...,NK).
        """
        # Imported locally rather than at module level (presumably to avoid a circular import).
        from pyxu.operator import IdentityOp

        CG = IdentityOp(dim_shape=self.codim_shape)
        return CG

    def pinv(self, arr: pxt.NDArray, damp: pxt.Real, **kwargs) -> pxt.NDArray:
        """
        Evaluate the (damped) pseudo-inverse as ``adjoint(arr) / (1 + damp)``.

        Parameters
        ----------
        arr: NDArray
            (..., N1,...,NK) data points.
        damp: Real
            Damping term added to 1 in the denominator.
        **kwargs
            Accepted for interface compatibility; unused here.

        Returns
        -------
        out: NDArray
            (..., M1,...,MD) pseudo-inverse(s).
        """
        out = self.adjoint(arr)
        # `adjoint()` returns a freshly-allocated array, so scaling in-place never touches `arr`.
        out /= 1 + damp
        return out

    def dagger(self, damp: pxt.Real, **kwargs) -> pxt.OpT:
        """
        Build the (damped) pseudo-inverse operator ``self.T / (1 + damp)``.

        Parameters
        ----------
        damp: Real
            Damping term added to 1 in the denominator.
        **kwargs
            Accepted for interface compatibility; unused here.

        Returns
        -------
        op: OpT
        """
        op = self.T / (1 + damp)
        return op
219
220
[docs]
def Trim(
    dim_shape: pxt.NDArrayShape,
    trim_width: SubSample.TrimSpec,
) -> pxt.OpT:
    """
    Multi-dimensional trimming operator.

    This operator trims the input array in each dimension according to specified widths.

    Parameters
    ----------
    dim_shape: NDArrayShape
        (M1,...,MD) domain dimensions.
    trim_width: ~pyxu.operator.SubSample.TrimSpec
        Number of values trimmed from the edges of each axis.
        Multiple forms are accepted:

        * ``int``: trim each dimension's head/tail by `trim_width`.
        * ``tuple[int, ...]``: trim dimension[k]'s head/tail by `trim_width[k]`.
        * ``tuple[tuple[int, int], ...]``: trim dimension[k]'s head/tail by `trim_width[k][0]` / `trim_width[k][1]`
          respectively.

    Returns
    -------
    op: OpT

    Raises
    ------
    AssertionError
        If `trim_width` is a sequence whose length differs from the number of dimensions.
    ValueError
        If a trim width is negative, or if head + tail exceeds the size of its dimension.
    """
    dim_shape = pxu.as_canonical_shape(dim_shape)
    N_dim = len(dim_shape)

    def is_seq(obj) -> bool:
        return isinstance(obj, cabc.Sequence)

    # transform `trim_width` to canonical form tuple[tuple[int, int], ...]
    if not is_seq(trim_width):  # int-form
        trim_width = (trim_width,) * N_dim
    assert len(trim_width) == N_dim, "dim_shape/trim_width are length-mismatched."
    # Per-axis normalisation: a lone int `w` becomes (w, w); (head, tail) pairs are kept.
    trim_width = tuple(tuple(w) if is_seq(w) else (w, w) for w in trim_width)

    # translate `trim_width` to `indices` needed for SubSample
    indices = []
    for axis, ((w_head, w_tail), dim_size) in enumerate(zip(trim_width, dim_shape)):
        # Negative or over-long widths would produce slice bounds that are re-interpreted
        # relative to the end of the axis, silently trimming the wrong amount.
        if (w_head < 0) or (w_tail < 0):
            raise ValueError(f"trim_width[{axis}]={(w_head, w_tail)} must be non-negative.")
        if w_head + w_tail > dim_size:
            raise ValueError(
                f"trim_width[{axis}]={(w_head, w_tail)} exceeds the size ({dim_size}) of dimension {axis}."
            )
        # NOTE(review): head + tail == dim_size yields an empty axis; whether that is acceptable
        # is left to SubSample's own shape validation.
        indices.append(slice(w_head, dim_size - w_tail))

    op = SubSample(dim_shape, *indices)
    return op
267 return op