# Source code for pyxu.operator.misc

import math
import types

import numpy as np

import pyxu.abc as pxa
import pyxu.info.deps as pxd
import pyxu.info.ptype as pxt
import pyxu.runtime as pxrt
import pyxu.util as pxu
 11__all__ = [
 12    "BroadcastAxes",
 13    "RechunkAxes",
 14    "ReshapeAxes",
 15    "SqueezeAxes",
 16    "TransposeAxes",
 17]
 18
 19
class TransposeAxes(pxa.UnitOp):
    """
    Reverse or permute the axes of an array.

    Only the trailing core dimensions are permuted; leading (stacking) dimensions stay where they are.
    """

    def __init__(
        self,
        dim_shape: pxt.NDArrayShape,
        axes: pxt.NDArrayAxis = None,
    ):
        """
        Parameters
        ----------
        dim_shape: NDArrayShape
            Shape of the (core) input.
        axes: NDArrayAxis
            New axis order.

            If specified, must be a tuple or list which contains a permutation of [0,1,...,D-1].
            All axes are reversed if unspecified. (Default)
        """
        # codim_shape is provisional here: it is only passed so that `dim_rank` is available below.
        super().__init__(dim_shape=dim_shape, codim_shape=dim_shape)

        rank = self.dim_rank
        perm = np.arange(rank)[::-1] if axes is None else axes
        perm = pxu.as_canonical_axes(perm, rank=rank)
        # Exactly `rank` entries with no repeats => `perm` is a permutation of the core axes.
        assert len(perm) == len(set(perm)) == rank

        # Output shape is the input shape read in permuted order.
        self._codim_shape = tuple(self.dim_shape[p] for p in perm)
        self._axes_fw = perm
        # argsort of a permutation is its inverse: used to undo the transpose in adjoint().
        self._axes_bw = tuple(np.argsort(perm))

    def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
        # Keep stacking axes in front, then permute the core axes (offset past the stacking axes).
        n_stack = len(arr.shape[: -self.dim_rank])
        order = (*range(n_stack), *(n_stack + a for a in self._axes_fw))
        return arr.transpose(order)

    def adjoint(self, arr: pxt.NDArray) -> pxt.NDArray:
        # Same as apply(), but with the inverse permutation.
        n_stack = len(arr.shape[: -self.codim_rank])
        order = (*range(n_stack), *(n_stack + a for a in self._axes_bw))
        return arr.transpose(order)
[docs] 69class SqueezeAxes(pxa.UnitOp): 70 """ 71 Remove axes of length one. 72 """ 73
[docs] 74 def __init__( 75 self, 76 dim_shape: pxt.NDArrayShape, 77 axes: pxt.NDArrayAxis = None, 78 ): 79 """ 80 Parameters 81 ---------- 82 axes: NDArrayAxis 83 Axes to drop. 84 85 If unspecified, all axes of shape 1 will be dropped. 86 If an axis is selected with shape greater than 1, an error is raised. 87 88 Notes 89 ----- 90 * 1D arrays cannot be squeezed. 91 * Given a D-dimensional input, at most D-1 dimensions may be dropped. 92 """ 93 super().__init__( 94 dim_shape=dim_shape, 95 codim_shape=dim_shape, # preliminary; just to get dim_rank computed correctly. 96 ) 97 98 dim_shape = np.array(self.dim_shape) # for advanced indexing below. 99 if axes is None: 100 axes = np.arange(self.dim_rank)[dim_shape == 1] 101 axes = pxu.as_canonical_axes(axes, rank=self.dim_rank) 102 axes = np.unique(axes) # drop duplicates 103 if len(axes) > 0: 104 assert np.all(dim_shape[axes] == 1) # only squeezing size-1 dimensions 105 assert len(axes) < self.dim_rank # cannot squeeze to 0d array. 106 107 # update codim to right shape 108 self._codim_shape = tuple(dim_shape[ax] for ax in range(self.dim_rank) if ax not in axes) 109 self._idx_fw = tuple(0 if (ax in axes) else slice(None) for ax in range(self.dim_rank)) 110 self._idx_bw = tuple(np.newaxis if (ax in axes) else slice(None) for ax in range(self.dim_rank))
111 112 def apply(self, arr: pxt.NDArray) -> pxt.NDArray: 113 out = arr[..., *self._idx_fw] 114 return out 115 116 def adjoint(self, arr: pxt.NDArray) -> pxt.NDArray: 117 out = arr[..., *self._idx_bw] 118 return out
119 120
class ReshapeAxes(pxa.UnitOp):
    """
    Reshape an array.

    Notes
    -----
    * If an integer, then the result will be a 1D array of that length. One co-dimension can be -1. In this case, the
      value is inferred from the length of the array and remaining dimensions.
    * Reshaping DASK inputs may be sub-optimal based on the array's chunk structure: use at your own risk.
    """

    def __init__(
        self,
        dim_shape: pxt.NDArrayShape,
        codim_shape: pxt.NDArrayShape,
    ):
        """
        Parameters
        ----------
        dim_shape: NDArrayShape
            Shape of the (core) input.
        codim_shape: NDArrayShape
            Shape of the (core) output. Entries must be positive, except for at most one -1 entry whose value is
            inferred so that the number of cells is preserved.

        Raises
        ------
        ValueError
            If `codim_shape` has several -1 entries, or any entry that is neither positive nor -1.
        """
        dim_shape = pxu.as_canonical_shape(dim_shape)
        codim_shape = pxu.as_canonical_shape(codim_shape)

        n_infer = sum(ax == -1 for ax in codim_shape)
        if n_infer > 1:
            raise ValueError("Only one -1 entry allowed.")
        if any((ax != -1) and (ax < 1) for ax in codim_shape):
            # Reject 0 / other negatives up-front: they previously produced a misleading error or a division by 0.
            raise ValueError(f"codim_shape entries must be positive (or a single -1), got {codim_shape}.")
        if n_infer == 1:
            # Infer the missing dimension from the total size. math.prod() keeps everything a Python int.
            known = math.prod(ax for ax in codim_shape if ax != -1)
            codim_shape = list(codim_shape)
            codim_shape[codim_shape.index(-1)] = math.prod(dim_shape) // known

        super().__init__(
            dim_shape=dim_shape,
            codim_shape=codim_shape,
        )
        # Also catches a non-divisible size when -1 was inferred.
        assert self.dim_size == self.codim_size  # reshaping does not change cell count.

    def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
        sh = arr.shape[: -self.dim_rank]  # stacking dimensions are kept as-is
        out = arr.reshape(*sh, *self.codim_shape)
        return out

    def adjoint(self, arr: pxt.NDArray) -> pxt.NDArray:
        sh = arr.shape[: -self.codim_rank]
        out = arr.reshape(*sh, *self.dim_shape)
        return out

    def cogram(self) -> pxt.OpT:
        # Reshaping is a bijection on cells, so A A^* is the identity on the co-domain.
        from pyxu.operator import IdentityOp

        return IdentityOp(dim_shape=self.codim_shape)
class BroadcastAxes(pxa.LinOp):
    """
    Broadcast an array.

    Core dimensions are broadcast from `dim_shape` to `codim_shape`: missing leading core axes are inserted, then
    size-1 axes are expanded. Leading (stacking) dimensions of the input are preserved.
    """

    def __init__(
        self,
        dim_shape: pxt.NDArrayShape,
        codim_shape: pxt.NDArrayShape,
    ):
        super().__init__(
            dim_shape=dim_shape,
            codim_shape=codim_shape,
        )

        # Fail if not broadcastable.
        assert self.codim_size >= self.dim_size
        # np.broadcast_shapes() is symmetric: e.g. (1, 3) and (4, 1) are "compatible" (common shape (4, 3)), yet
        # (1, 3) cannot be broadcast *to* (4, 1). Broadcasting dim -> codim requires the common shape to BE codim.
        bcast_shape = np.broadcast_shapes(self.dim_shape, self.codim_shape)
        assert bcast_shape == tuple(self.codim_shape), f"Cannot broadcast {self.dim_shape} to {self.codim_shape}."

        self.lipschitz = self.estimate_lipschitz()

    def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
        # Insert the missing leading core axes right after the stacking dimensions.
        rank_diff = self.codim_rank - self.dim_rank
        sh = arr.shape[: -self.dim_rank]
        expand = ((slice(None),) * len(sh)) + ((np.newaxis,) * rank_diff)

        xp = pxu.get_array_module(arr)
        y = xp.broadcast_to(
            arr[expand],
            (*sh, *self.codim_shape),
        )
        # Broadcast outputs may alias input memory: hand them out read-only.
        return pxu.read_only(y)

    def adjoint(self, arr: pxt.NDArray) -> pxt.NDArray:
        # Input shape aligned to codim rank by left-padding with 1s (NumPy broadcasting convention).
        rank_diff = self.codim_rank - self.dim_rank
        dim_shape_bcast = ((1,) * rank_diff) + self.dim_shape

        sh = arr.shape[: -self.codim_rank]
        N = len(sh)
        # Sum over every core axis that apply() expanded (offset past the stacking axes) ...
        axis = tuple(N + i for i in range(self.codim_rank) if self.codim_shape[i] != dim_shape_bcast[i])
        # ... then drop the `rank_diff` leading core axes that apply() inserted.
        select = ((slice(None),) * N) + ((0,) * rank_diff)

        y = arr.sum(axis=axis, keepdims=True)[select]
        return y

    def estimate_lipschitz(self, **kwargs) -> pxt.Real:
        # gram() is (codim_size / dim_size) * Identity, so the operator norm is the square root of that ratio.
        L = np.sqrt(self.codim_size / self.dim_size)
        return L

    def svdvals(self, **kwargs) -> pxt.NDArray:
        """
        Return the `k` leading singular values; all equal `self.lipschitz` since gram() is a scaled identity.

        Keyword arguments read: `k` (required), `gpu` (default False), `dtype` (default double precision).
        """
        gpu = kwargs.get("gpu", False)
        xp = pxd.NDArrayInfo.from_flag(gpu).module()
        dtype = kwargs.get("dtype", pxrt.Width.DOUBLE.value)

        D = xp.full(kwargs["k"], self.lipschitz, dtype=dtype)
        return D

    def gram(self) -> pxt.OpT:
        from pyxu.operator import HomothetyOp

        op = HomothetyOp(
            dim_shape=self.dim_shape,
            cst=self.codim_size / self.dim_size,
        )
        return op

    def pinv(self, arr: pxt.NDArray, damp: pxt.Real, **kwargs) -> pxt.NDArray:
        # (A^* A + damp I)^{-1} A^* arr, with A^* A = cst * I (see gram()).
        out = pxu.copy_if_unsafe(self.adjoint(arr))
        cst = self.codim_size / self.dim_size
        out /= cst + damp
        return out

    def dagger(self, damp: pxt.Real, **kwargs) -> pxt.OpT:
        # Operator form of pinv(): a scaled adjoint.
        cst = self.codim_size / self.dim_size
        op = self.T / (cst + damp)
        return op
def RechunkAxes(dim_shape: pxt.NDArrayShape, chunks: dict) -> pxt.OpT:
    """
    Re-chunk core dimensions to new chunk size.

    Parameters
    ----------
    dim_shape: NDArrayShape
    chunks: dict
        (ax -> chunk_size) mapping, where `chunk_size` can be:

        * int (non-negative)
        * tuple[int]

        The following special values per axis can also be used:

        * None: do not change chunks.
        * -1: do not chunk.
        * "auto": select a good chunk size.

    Returns
    -------
    op: OpT

    Notes
    -----
    * :py:meth:`~pyxu.abc.Map.apply` is a no-op if inputs are not DASK arrays.
    * :py:meth:`~pyxu.abc.LinOp.adjoint` is always a no-op.
    * Chunk sizes along stacking dimensions are not modified.
    """
    from pyxu.operator import IdentityOp

    # Start from an identity UnitOp and patch its forward evaluation.
    op = IdentityOp(dim_shape=dim_shape).asop(pxa.UnitOp)
    op._name = "RechunkAxes"

    # Store chunk sizes keyed by canonical (non-negative) core axis.
    assert isinstance(chunks, dict)
    canonical_chunks = dict()
    for ax, chunk_size in chunks.items():
        key = pxu.as_canonical_axes(ax, rank=op.dim_rank)[0]
        canonical_chunks[key] = chunk_size
    op._chunks = canonical_chunks

    def _rechunk(self, arr: pxt.NDArray) -> pxt.NDArray:
        # Non-DASK inputs pass through unchanged (read-only view).
        if pxd.NDArrayInfo.from_obj(arr) != pxd.NDArrayInfo.DASK:
            return pxu.read_only(arr)

        # Keep stacking-axis chunks as they are; shift core-axis requests past the stacking axes.
        n_stack = len(arr.shape[: -self.dim_rank])
        new_chunks = {ax: arr.chunks[ax] for ax in range(n_stack)}
        new_chunks.update({n_stack + ax: chk for ax, chk in self._chunks.items()})
        return arr.rechunk(chunks=new_chunks)

    op.apply = types.MethodType(_rechunk, op)
    # NOTE(review): implicit `op(x)` looks up `__call__` on the type, not the instance, so this assignment only
    # matters for explicit `op.__call__(x)`; presumably the base `__call__` dispatches to `self.apply` — confirm.
    op.__call__ = types.MethodType(_rechunk, op)

    def _expr(self):
        # Make the op print as 'RechunkAxes' rather than 'IdentityOp'
        # (consequence of `asop()`'s forwarding principle).
        return (self,)

    op._expr = types.MethodType(_expr, op)

    return op