import types

import numpy as np

import pyxu.abc as pxa
import pyxu.info.deps as pxd
import pyxu.info.ptype as pxt
import pyxu.runtime as pxrt
import pyxu.util as pxu

__all__ = [
    "BroadcastAxes",
    "RechunkAxes",
    "ReshapeAxes",
    "SqueezeAxes",
    "TransposeAxes",
]


class TransposeAxes(pxa.UnitOp):
    """
    Reverse or permute the axes of an array.
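
    Example
    -------
    A minimal usage sketch, assuming the operator is re-exported through :py:mod:`pyxu.operator`
    (shapes and axis order below are illustrative):

    .. code-block:: python3

       import numpy as np
       from pyxu.operator import TransposeAxes  # assumes re-export via pyxu.operator

       op = TransposeAxes(dim_shape=(2, 3, 4), axes=(2, 0, 1))
       x = np.arange(24).reshape(2, 3, 4)
       y = op.apply(x)       # shape (4, 2, 3)
       xb = op.adjoint(y)    # shape (2, 3, 4): adjoint applies the inverse permutation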
    """

    def __init__(
        self,
        dim_shape: pxt.NDArrayShape,
        axes: pxt.NDArrayAxis = None,
    ):
        """
        Parameters
        ----------
        axes: NDArrayAxis
            New axis order.

            If specified, must be a tuple or list which contains a permutation of [0,1,...,D-1].
            All axes are reversed if unspecified. (Default)
        """
        super().__init__(
            dim_shape=dim_shape,
            codim_shape=dim_shape,  # preliminary; just to get dim_rank computed correctly.
        )

        if axes is None:
            axes = np.arange(self.dim_rank)[::-1]
        axes = pxu.as_canonical_axes(axes, rank=self.dim_rank)
        assert len(axes) == len(set(axes)) == self.dim_rank  # right number of axes provided & no duplicates

        # update codim to right shape
        self._codim_shape = tuple(self.dim_shape[ax] for ax in axes)
        self._axes_fw = axes
        self._axes_bw = tuple(np.argsort(axes))

    def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
        sh = arr.shape[: -self.dim_rank]
        N = len(sh)
        axes = tuple(range(N)) + tuple(N + ax for ax in self._axes_fw)
        out = arr.transpose(axes)
        return out

    def adjoint(self, arr: pxt.NDArray) -> pxt.NDArray:
        sh = arr.shape[: -self.codim_rank]
        N = len(sh)
        axes = tuple(range(N)) + tuple(N + ax for ax in self._axes_bw)
        out = arr.transpose(axes)
        return out


class SqueezeAxes(pxa.UnitOp):
    """
    Remove axes of length one.
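
    Example
    -------
    A minimal usage sketch, assuming the operator is re-exported through :py:mod:`pyxu.operator`
    (shapes below are illustrative):

    .. code-block:: python3

       import numpy as np
       from pyxu.operator import SqueezeAxes  # assumes re-export via pyxu.operator

       op = SqueezeAxes(dim_shape=(1, 3, 1, 5))   # drops axes 0 and 2
       x = np.ones((1, 3, 1, 5))
       y = op.apply(x)       # shape (3, 5)
       xb = op.adjoint(y)    # shape (1, 3, 1, 5): dropped axes are re-inserted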
    """

    def __init__(
        self,
        dim_shape: pxt.NDArrayShape,
        axes: pxt.NDArrayAxis = None,
    ):
        """
        Parameters
        ----------
        axes: NDArrayAxis
            Axes to drop.

            If unspecified, all axes of shape 1 will be dropped.
            If an axis is selected with shape greater than 1, an error is raised.

        Notes
        -----
        * 1D arrays cannot be squeezed.
        * Given a D-dimensional input, at most D-1 dimensions may be dropped.
        """
        super().__init__(
            dim_shape=dim_shape,
            codim_shape=dim_shape,  # preliminary; just to get dim_rank computed correctly.
        )

        dim_shape = np.array(self.dim_shape)  # for advanced indexing below.
        if axes is None:
            axes = np.arange(self.dim_rank)[dim_shape == 1]
        axes = pxu.as_canonical_axes(axes, rank=self.dim_rank)
        axes = np.unique(axes)  # drop duplicates
        if len(axes) > 0:
            assert np.all(dim_shape[axes] == 1)  # only squeezing size-1 dimensions
            assert len(axes) < self.dim_rank  # cannot squeeze to 0d array.

        # update codim to right shape
        self._codim_shape = tuple(dim_shape[ax] for ax in range(self.dim_rank) if ax not in axes)
        self._idx_fw = tuple(0 if (ax in axes) else slice(None) for ax in range(self.dim_rank))
        self._idx_bw = tuple(np.newaxis if (ax in axes) else slice(None) for ax in range(self.dim_rank))

    def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
        out = arr[..., *self._idx_fw]
        return out

    def adjoint(self, arr: pxt.NDArray) -> pxt.NDArray:
        out = arr[..., *self._idx_bw]
        return out


class ReshapeAxes(pxa.UnitOp):
    """
    Reshape an array.

    Notes
    -----
    * If `codim_shape` is an integer, the result is a 1D array of that length. One co-dimension can be -1:
      in this case, its value is inferred from the array's size and the remaining dimensions.
    * Reshaping DASK inputs may be sub-optimal based on the array's chunk structure: use at your own risk.
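
    Example
    -------
    A minimal usage sketch, assuming the operator is re-exported through :py:mod:`pyxu.operator`
    (shapes below are illustrative):

    .. code-block:: python3

       import numpy as np
       from pyxu.operator import ReshapeAxes  # assumes re-export via pyxu.operator

       op = ReshapeAxes(dim_shape=(4, 6), codim_shape=(2, -1, 3))  # -1 is inferred as 4
       x = np.arange(24).reshape(4, 6)
       y = op.apply(x)       # shape (2, 4, 3)
       xb = op.adjoint(y)    # shape (4, 6)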
    """

    def __init__(
        self,
        dim_shape: pxt.NDArrayShape,
        codim_shape: pxt.NDArrayShape,
    ):
        dim_shape = pxu.as_canonical_shape(dim_shape)
        codim_shape = pxu.as_canonical_shape(codim_shape)

        if all(ax >= 1 for ax in codim_shape):
            pass  # all good
        elif sum(ax == -1 for ax in codim_shape) == 1:
            # infer missing dimension value
            size = np.prod(dim_shape) // abs(np.prod(codim_shape))

            codim_shape = list(codim_shape)
            codim_shape[codim_shape.index(-1)] = size
        else:
            raise ValueError("Only one -1 entry allowed.")

        super().__init__(
            dim_shape=dim_shape,
            codim_shape=codim_shape,
        )
        assert self.dim_size == self.codim_size  # reshaping does not change cell count.

    def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
        sh = arr.shape[: -self.dim_rank]
        out = arr.reshape(*sh, *self.codim_shape)
        return out

    def adjoint(self, arr: pxt.NDArray) -> pxt.NDArray:
        sh = arr.shape[: -self.codim_rank]
        out = arr.reshape(*sh, *self.dim_shape)
        return out

    def cogram(self) -> pxt.OpT:
        from pyxu.operator import IdentityOp

        return IdentityOp(dim_shape=self.codim_shape)


class BroadcastAxes(pxa.LinOp):
    """
    Broadcast an array.
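
    Example
    -------
    A minimal usage sketch, assuming the operator is re-exported through :py:mod:`pyxu.operator`
    (shapes below are illustrative):

    .. code-block:: python3

       import numpy as np
       from pyxu.operator import BroadcastAxes  # assumes re-export via pyxu.operator

       op = BroadcastAxes(dim_shape=(3, 1), codim_shape=(5, 3, 4))
       x = np.arange(3).reshape(3, 1)
       y = op.apply(x)       # shape (5, 3, 4): read-only broadcasted view
       xb = op.adjoint(y)    # shape (3, 1): sums over the broadcasted axes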
    """

    def __init__(
        self,
        dim_shape: pxt.NDArrayShape,
        codim_shape: pxt.NDArrayShape,
    ):
        super().__init__(
            dim_shape=dim_shape,
            codim_shape=codim_shape,
        )

        # Fail if not broadcastable.
        assert self.codim_size >= self.dim_size
        np.broadcast_shapes(self.dim_shape, self.codim_shape)

        self.lipschitz = self.estimate_lipschitz()

    def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
        # Compute (expand,) assuming no stacking dimensions.
        rank_diff = self.codim_rank - self.dim_rank
        expand = (np.newaxis,) * rank_diff

        # Extend (expand,) to handle stacking dimensions.
        sh = arr.shape[: -self.dim_rank]
        expand = ((slice(None),) * len(sh)) + expand

        xp = pxu.get_array_module(arr)
        y = xp.broadcast_to(
            arr[expand],
            (*sh, *self.codim_shape),
        )
        return pxu.read_only(y)

    def adjoint(self, arr: pxt.NDArray) -> pxt.NDArray:
        # Compute (axis, select) assuming no stacking dimensions.
        rank_diff = self.codim_rank - self.dim_rank
        dim_shape_bcast = ((1,) * rank_diff) + self.dim_shape
        axis = filter(
            lambda i: self.codim_shape[i] != dim_shape_bcast[i],
            range(self.codim_rank),
        )
        select = (0,) * rank_diff

        # Extend (axis, select) to handle stacking dimensions
        sh = arr.shape[: -self.codim_rank]
        axis = tuple(ax + len(sh) for ax in axis)
        select = ((slice(None),) * len(sh)) + select

        y = arr.sum(axis=axis, keepdims=True)[select]
        return y

    def estimate_lipschitz(self, **kwargs) -> pxt.Real:
        L = np.sqrt(self.codim_size / self.dim_size)
        return L

    def svdvals(self, **kwargs) -> pxt.NDArray:
        gpu = kwargs.get("gpu", False)
        xp = pxd.NDArrayInfo.from_flag(gpu).module()
        dtype = kwargs.get("dtype", pxrt.Width.DOUBLE.value)

        D = xp.full(kwargs["k"], self.lipschitz, dtype=dtype)
        return D

    def gram(self) -> pxt.OpT:
        from pyxu.operator import HomothetyOp

        op = HomothetyOp(
            dim_shape=self.dim_shape,
            cst=self.codim_size / self.dim_size,
        )
        return op

    def pinv(self, arr: pxt.NDArray, damp: pxt.Real, **kwargs) -> pxt.NDArray:
        out = pxu.copy_if_unsafe(self.adjoint(arr))
        cst = self.codim_size / self.dim_size
        out /= cst + damp
        return out

    def dagger(self, damp: pxt.Real, **kwargs) -> pxt.OpT:
        cst = self.codim_size / self.dim_size
        op = self.T / (cst + damp)
        return op


def RechunkAxes(dim_shape: pxt.NDArrayShape, chunks: dict) -> pxt.OpT:
    """
    Re-chunk core dimensions to new chunk size.

    Parameters
    ----------
    dim_shape: NDArrayShape
    chunks: dict
        (ax -> chunk_size) mapping, where `chunk_size` can be:

        * int (non-negative)
        * tuple[int]

        The following special values per axis can also be used:

        * None: do not change chunks.
        * -1: do not chunk.
        * "auto": select a good chunk size.

    Returns
    -------
    op: OpT

    Notes
    -----
    * :py:meth:`~pyxu.abc.Map.apply` is a no-op if inputs are not DASK arrays.
    * :py:meth:`~pyxu.abc.LinOp.adjoint` is always a no-op.
    * Chunk sizes along stacking dimensions are not modified.
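
    Example
    -------
    A minimal usage sketch, assuming the operator is re-exported through :py:mod:`pyxu.operator`
    and Dask is installed (shapes and chunk sizes below are illustrative):

    .. code-block:: python3

       import dask.array as da
       from pyxu.operator import RechunkAxes  # assumes re-export via pyxu.operator

       op = RechunkAxes(dim_shape=(8, 8), chunks={0: 4, 1: -1})
       x = da.ones((5, 8, 8), chunks=(1, 2, 2))  # leading axis is a stacking dimension
       y = op.apply(x)
       # y.chunks == ((1, 1, 1, 1, 1), (4, 4), (8,)): stacking chunks are untouched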
    """
    from pyxu.operator import IdentityOp

    # Create chunking UnitOp
    op = IdentityOp(dim_shape=dim_shape).asop(pxa.UnitOp)
    op._name = "RechunkAxes"

    # Canonicalize & store chunks
    assert isinstance(chunks, dict)
    op._chunks = {
        pxu.as_canonical_axes(
            ax,
            rank=op.dim_rank,
        )[0]: chunk_size
        for (ax, chunk_size) in chunks.items()
    }

    # Update op.apply() to perform re-chunking
    def op_apply(_, arr: pxt.NDArray) -> pxt.NDArray:
        ndi = pxd.NDArrayInfo.from_obj(arr)
        if ndi != pxd.NDArrayInfo.DASK:
            out = pxu.read_only(arr)
        else:
            stack_rank = len(arr.shape[: -_.dim_rank])
            stack_chunks = {ax: arr.chunks[ax] for ax in range(stack_rank)}
            core_chunks = {ax + stack_rank: chk for (ax, chk) in _._chunks.items()}
            out = arr.rechunk(chunks=stack_chunks | core_chunks)
        return out

    op.apply = types.MethodType(op_apply, op)
    op.__call__ = types.MethodType(op_apply, op)

    # To print 'RechunkAxes' in place of 'IdentityOp'
    # [Consequence of `asop()`'s forwarding principle.]
    op._expr = types.MethodType(lambda _: (_,), op)

    return op