Source code for pyxu.operator.linop.base

  1import types
  2import typing as typ
  3import warnings
  4
  5import numpy as np
  6
  7import pyxu.abc as pxa
  8import pyxu.info.deps as pxd
  9import pyxu.info.ptype as pxt
 10import pyxu.info.warning as pxw
 11import pyxu.operator.interop.source as px_src
 12import pyxu.runtime as pxrt
 13import pyxu.util as pxu
 14
# Public API of this module. `_ExplicitLinOp` (defined below) is not exported.
__all__ = [
    "IdentityOp",
    "NullOp",
    "NullFunc",
    "HomothetyOp",
    "DiagonalOp",
]
 22
 23
class IdentityOp(pxa.OrthProjOp):
    """
    Identity operator.

    Both :py:meth:`apply` and :py:meth:`adjoint` return the input, wrapped by ``pxu.read_only()``.
    """

    def __init__(self, dim_shape: pxt.NDArrayShape):
        # Square operator: the co-domain has the same shape as the domain.
        super().__init__(dim_shape=dim_shape, codim_shape=dim_shape)
        self.lipschitz = 1

    def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
        return pxu.read_only(arr)

    def adjoint(self, arr: pxt.NDArray) -> pxt.NDArray:
        # Self-adjoint: same output as apply().
        return pxu.read_only(arr)

    def svdvals(self, **kwargs) -> pxt.NDArray:
        # Re-use the UnitOp implementation.
        return pxa.UnitOp.svdvals(self, **kwargs)

    def asarray(self, **kwargs) -> pxt.NDArray:
        # Dense (*codim_shape, *dim_shape) representation, obtained by reshaping an identity matrix.
        dtype = kwargs.get("dtype", pxrt.Width.DOUBLE.value)
        xp = kwargs.get("xp", pxd.NDArrayInfo.default().module())
        eye = xp.eye(N=self.dim_size, dtype=dtype)
        return eye.reshape(*self.codim_shape, *self.dim_shape)

    def pinv(self, arr: pxt.NDArray, damp: pxt.Real, **kwargs) -> pxt.NDArray:
        # Damped pseudo-inverse of the identity: arr / (1 + damp), computed on a copy of `arr`.
        scale = 1 + damp
        out = arr.copy()
        out /= scale
        return out

    def dagger(self, damp: pxt.Real, **kwargs) -> pxt.OpT:
        # Operator form of pinv(): a scaling by 1 / (1 + damp).
        return HomothetyOp(dim_shape=self.dim_shape, cst=1 / (1 + damp))

    def trace(self, **kwargs) -> pxt.Real:
        # One unit entry per diagonal element.
        return self.dim_size
66 67
class NullOp(pxa.LinOp):
    """
    Null operator.

    This operator maps any input vector on the null vector.
    """

    def __init__(
        self,
        dim_shape: pxt.NDArrayShape,
        codim_shape: pxt.NDArrayShape,
    ):
        super().__init__(dim_shape=dim_shape, codim_shape=codim_shape)
        self.lipschitz = 0

    def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
        # Output: a broadcast 0 of shape (*stack_shape, *codim_shape), in the same array module/dtype as `arr`.
        ndi = pxd.NDArrayInfo.from_obj(arr)
        xp = ndi.module()
        stack_shape = arr.shape[: -self.dim_rank]
        bcast_kwargs = dict()
        if ndi == pxd.NDArrayInfo.DASK:
            # Keep the input's chunking along stacking axes; let DASK choose chunks for the core axes.
            bcast_kwargs["chunks"] = arr.chunks[: -self.dim_rank] + ("auto",) * self.codim_rank
        zero = xp.array(0, arr.dtype)
        return xp.broadcast_to(zero, (*stack_shape, *self.codim_shape), **bcast_kwargs)

    def adjoint(self, arr: pxt.NDArray) -> pxt.NDArray:
        # Output: a broadcast 0 of shape (*stack_shape, *dim_shape), in the same array module/dtype as `arr`.
        ndi = pxd.NDArrayInfo.from_obj(arr)
        xp = ndi.module()
        stack_shape = arr.shape[: -self.codim_rank]
        bcast_kwargs = dict()
        if ndi == pxd.NDArrayInfo.DASK:
            # Keep the input's chunking along stacking axes; let DASK choose chunks for the core axes.
            bcast_kwargs["chunks"] = arr.chunks[: -self.codim_rank] + ("auto",) * self.dim_rank
        zero = xp.array(0, arr.dtype)
        return xp.broadcast_to(zero, (*stack_shape, *self.dim_shape), **bcast_kwargs)

    def svdvals(self, **kwargs) -> pxt.NDArray:
        # Every singular value of the null operator is 0.
        gpu = kwargs.get("gpu", False)
        dtype = kwargs.get("dtype", pxrt.Width.DOUBLE.value)
        xp = pxd.NDArrayInfo.from_flag(gpu).module()
        return xp.zeros(kwargs["k"], dtype=dtype)

    def gram(self) -> pxt.OpT:
        # Null operator on the domain.
        gram_op = NullOp(dim_shape=self.dim_shape, codim_shape=self.dim_shape)
        return gram_op.asop(pxa.SelfAdjointOp)

    def cogram(self) -> pxt.OpT:
        # Null operator on the co-domain.
        cogram_op = NullOp(dim_shape=self.codim_shape, codim_shape=self.codim_shape)
        return cogram_op.asop(pxa.SelfAdjointOp)

    def asarray(self, **kwargs) -> pxt.NDArray:
        # Dense (*codim_shape, *dim_shape) representation: a broadcast 0.
        xp = kwargs.get("xp", pxd.NDArrayInfo.default().module())
        dtype = kwargs.get("dtype", pxrt.Width.DOUBLE.value)
        zero = xp.array(0, dtype=dtype)
        return xp.broadcast_to(zero, (*self.codim_shape, *self.dim_shape))

    def trace(self, **kwargs) -> pxt.Real:
        return 0
149 150
def NullFunc(dim_shape: pxt.NDArrayShape) -> pxt.OpT:
    """
    Null functional.

    This functional maps any input vector on the null scalar.

    Parameters
    ----------
    dim_shape: NDArrayShape
        Shape of the functional's domain.

    Returns
    -------
    op: OpT
        A :py:class:`~pyxu.operator.NullOp` with co-domain shape 1, cast to :py:class:`~pyxu.abc.LinFunc`.
    """
    null_op = NullOp(dim_shape=dim_shape, codim_shape=1)
    func = null_op.asop(pxa.LinFunc)
    func._name = "NullFunc"
    return func
163 164
def HomothetyOp(dim_shape: pxt.NDArrayShape, cst: pxt.Real) -> pxt.OpT:
    """
    Constant scaling operator.

    Parameters
    ----------
    dim_shape: NDArrayShape
        (M1,...,MD) shape of operator's domain (and co-domain).
    cst: Real
        Scaling factor.

    Returns
    -------
    op: OpT
        Scaling operator.

        * If `cst` is close to 0 (``np.isclose``), a :py:class:`~pyxu.operator.NullOp`.
        * If `cst` is close to 1, a :py:class:`~pyxu.operator.IdentityOp`.
        * Otherwise, a :py:class:`~pyxu.abc.PosDefOp` when `cst` > 0, else a :py:class:`~pyxu.abc.SelfAdjointOp`.

    Note
    ----
    This operator is not defined in terms of :py:func:`~pyxu.operator.DiagonalOp` since it is array-backend-agnostic.
    """
    assert isinstance(cst, pxt.Real), f"cst: expected real, got {cst}."

    # Degenerate scalings are delegated to dedicated operators.
    if np.isclose(cst, 0):
        return NullOp(dim_shape=dim_shape, codim_shape=dim_shape)
    if np.isclose(cst, 1):
        return IdentityOp(dim_shape=dim_shape)

    def _damped_inv(self, damp: pxt.Real) -> pxt.Real:
        # Scalar factor of the damped pseudo-inverse: cst / (cst**2 + damp).
        return self._cst / (self._cst**2 + damp)

    def _apply(self, arr: pxt.NDArray) -> pxt.NDArray:
        out = arr.copy()
        out *= self._cst
        return out

    def _svdvals(self, **kwargs) -> pxt.NDArray:
        # All singular values equal |cst|.
        gpu = kwargs.get("gpu", False)
        dtype = kwargs.get("dtype", pxrt.Width.DOUBLE.value)
        xp = pxd.NDArrayInfo.from_flag(gpu).module()
        return xp.full(shape=kwargs["k"], fill_value=abs(self._cst), dtype=dtype)

    def _pinv(self, arr: pxt.NDArray, damp: pxt.Real, **kwargs) -> pxt.NDArray:
        out = arr.copy()
        out *= _damped_inv(self, damp)
        return out

    def _dagger(self, damp: pxt.Real, **kwargs) -> pxt.OpT:
        return HomothetyOp(cst=_damped_inv(self, damp), dim_shape=self.dim_shape)

    def _gram(self):
        # Also used as cogram(): the operator is square and self-adjoint.
        return HomothetyOp(cst=self._cst**2, dim_shape=self.dim_shape)

    def _estimate_lipschitz(self, **kwargs) -> pxt.Real:
        return abs(self._cst)

    def _trace(self, **kwargs):
        return self._cst * self.dim_size

    klass = pxa.PosDefOp if (cst > 0) else pxa.SelfAdjointOp
    op = px_src.from_source(
        cls=klass,
        dim_shape=dim_shape,
        codim_shape=dim_shape,
        embed=dict(
            _name="HomothetyOp",
            _cst=cst,
            _lipschitz=float(abs(cst)),
        ),
        apply=_apply,
        svdvals=_svdvals,
        pinv=_pinv,
        gram=_gram,
        cogram=_gram,
        trace=_trace,
        estimate_lipschitz=_estimate_lipschitz,
    )
    # dagger() is attached after construction, as a bound method of the new instance.
    op.dagger = types.MethodType(_dagger, op)
    return op
257 258
def DiagonalOp(
    vec: pxt.NDArray,
    dim_shape: pxt.NDArrayShape = None,
    enable_warnings: bool = True,
) -> pxt.OpT:
    r"""
    Element-wise scaling operator.

    Note
    ----
    * :py:func:`~pyxu.operator.DiagonalOp` instances are **not arraymodule-agnostic**: they will only work with NDArrays
      belonging to the same array module as `vec`. Moreover, inner computations may cast input arrays when the
      precision of `vec` does not match the user-requested precision. If such a situation occurs, a warning is raised.
    * If `vec` is a DASK array, the operator will be a :py:class:`~pyxu.abc.SelfAdjointOp`. If `vec` is a NUMPY/CUPY
      array, the created operator specializes to :py:class:`~pyxu.abc.PosDefOp` when possible. Specialization is not
      automatic for DASK inputs because operators should be quick to build under all circumstances, and this is not
      guaranteed if we have to check that all entries are positive for out-of-core arrays. Users who know that all
      `vec` entries are positive can manually cast to :py:class:`~pyxu.abc.PosDefOp` afterwards if required.

    Parameters
    ----------
    vec: NDArray
        Scale factors. If `dim_shape` is provided, then `vec` must be broadcastable to an array of shape `dim_shape`.
    dim_shape: NDArrayShape
        (M1,...,MD) shape of operator's domain.
        Defaults to the shape of `vec` when omitted.
    enable_warnings: bool
        If ``True``, emit a warning in case of precision mis-match issues.
    """
    if dim_shape is None:
        dim_shape = vec.shape
    else:
        dim_shape = pxu.as_canonical_shape(dim_shape)
        # np.broadcast_shapes() raises if the shapes are not broadcast-compatible. Compatibility alone is not enough:
        # `vec` must broadcast *into* `dim_shape`, i.e. the common shape must be `dim_shape` itself. (Comparing the
        # shapes element-wise via zip() would align them from the left and accept a `vec` of higher rank.)
        sh = np.broadcast_shapes(vec.shape, dim_shape)
        assert tuple(sh) == tuple(dim_shape), "vec and dim_shape are incompatible."

    def damped_inverse(v, damp: pxt.Real):
        # Element-wise v / (v**2 + damp), computed in `v`'s array module.
        # NaN entries (e.g. 0/0 when an entry of `v` is 0 and damp == 0) are set to 0.
        xp = pxu.get_array_module(v)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            scale = v / (v**2 + damp)
        scale[xp.isnan(scale)] = 0
        return scale

    def op_apply(_, arr):
        if (_._vec.dtype != arr.dtype) and _._enable_warnings:
            msg = "Computation may not be performed at the requested precision."
            warnings.warn(msg, pxw.PrecisionWarning)
        out = arr.copy()
        out *= _._vec
        return out

    def op_asarray(_, **kwargs) -> pxt.NDArray:
        # Build the diagonal matrix in `vec`'s module, then transfer to the requested module/dtype.
        xp = pxu.get_array_module(_._vec)
        vec = xp.broadcast_to(_._vec, _.dim_shape)
        A = xp.diag(vec.reshape(-1)).reshape((*_.codim_shape, *_.dim_shape))

        xp = kwargs.get("xp", pxd.NDArrayInfo.default().module())
        dtype = kwargs.get("dtype", pxrt.Width.DOUBLE.value)
        B = xp.array(pxu.to_NUMPY(A), dtype=dtype)
        return B

    def op_gram(_):
        return DiagonalOp(
            vec=_._vec**2,
            dim_shape=_.dim_shape,
            enable_warnings=_._enable_warnings,
        )

    def op_svdvals(_, **kwargs):
        gpu = kwargs.get("gpu", False)
        xp = pxd.NDArrayInfo.from_flag(gpu).module()  # array module of the *output*
        k = kwargs["k"]
        dtype = kwargs.get("dtype", pxrt.Width.DOUBLE.value)

        # Singular values are the |entries| of `vec` broadcast to dim_shape. They are computed with `vec`'s own array
        # module: the output module (chosen via `gpu`) may differ from it, and only DASK provides topk().
        v_ndi = pxd.NDArrayInfo.from_obj(_._vec)
        v_xp = v_ndi.module()
        sv = v_xp.broadcast_to(v_xp.abs(_._vec), _.dim_shape).reshape(-1)
        if v_ndi == pxd.NDArrayInfo.DASK:
            # topk() avoids a full out-of-core sort, but returns the k largest values in *descending* order:
            # flip to match the ascending order of the dense branch.
            D = pxu.to_NUMPY(v_xp.topk(sv, k))[::-1].copy()
        else:
            D = v_xp.sort(sv)[-k:]  # k largest values, ascending order
            if v_xp is xp:
                return D.astype(dtype)
            D = pxu.to_NUMPY(D)
        return xp.array(D, dtype=dtype)

    def op_pinv(_, arr: pxt.NDArray, damp: pxt.Real, **kwargs) -> pxt.NDArray:
        out = arr.copy()
        out *= damped_inverse(_._vec, damp)
        return out

    def op_dagger(_, damp: pxt.Real, **kwargs) -> pxt.OpT:
        return DiagonalOp(
            vec=damped_inverse(_._vec, damp),
            dim_shape=_.dim_shape,
            enable_warnings=_._enable_warnings,
        )

    def op_trace(_, **kwargs):
        xp = pxu.get_array_module(_._vec)
        vec = xp.broadcast_to(_._vec, _.dim_shape)
        return float(vec.sum())

    def op_estimate_lipschitz(_, **kwargs):
        # Lipschitz constant = largest |entry| of the operator's own `_vec`.
        xp = pxu.get_array_module(_._vec)
        _.lipschitz = float(xp.abs(_._vec).max())
        return _.lipschitz

    ndi = pxd.NDArrayInfo.from_obj(vec)
    if ndi == pxd.NDArrayInfo.DASK:
        # Positivity is not checked for DASK inputs (see Note above).
        klass = pxa.SelfAdjointOp
    else:
        positive = (vec > 0).all()
        klass = pxa.PosDefOp if positive else pxa.SelfAdjointOp
    op = px_src.from_source(
        cls=klass,
        dim_shape=dim_shape,
        codim_shape=dim_shape,
        embed=dict(
            _name="DiagonalOp",
            _vec=vec,
            _enable_warnings=bool(enable_warnings),
        ),
        apply=op_apply,
        estimate_lipschitz=op_estimate_lipschitz,
        asarray=op_asarray,
        gram=op_gram,
        cogram=op_gram,
        svdvals=op_svdvals,
        pinv=op_pinv,
        trace=op_trace,
    )
    op.dagger = types.MethodType(op_dagger, op)
    return op
def _ExplicitLinOp(
    cls: pxt.OpC,
    mat: typ.Union[pxt.NDArray, pxt.SparseArray],
    dim_rank: pxt.Integer = None,
    enable_warnings: bool = True,
) -> pxt.OpT:
    r"""
    Build a linear operator from its matrix representation.

    Given an array :math:`\mathbf{A} \in \mathbb{R}^{N_{1} \times\cdots\times N_{K} \times M_{1} \times\cdots\times
    M_{D}}`, the *explicit linear operator* associated to :math:`\mathbf{A}` is defined as

    .. math::

       [\mathbf{A}\mathbf{x}]_{n_{1},\ldots,n_{K}}
       =
       \langle \mathbf{A}[n_{1},\ldots,n_{K},\ldots], \mathbf{x} \rangle
       \qquad
       \forall \mathbf{x} \in \mathbb{R}^{M_{1} \times\cdots\times M_{D}}.

    Parameters
    ----------
    cls: OpC
        LinOp sub-class to instantiate.
    mat: NDArray, SparseArray
        (N1,...,NK, M1,...,MD) matrix generator.
        The input array can be *dense* or *sparse*.
        Accepted 2D sparse arrays are:

        * CPU: COO/CSC/CSR/BSR
        * GPU: COO/CSC/CSR
    dim_rank: Integer
        Rank of operator's domain. (D)
        It can be omitted if `mat` is 2D, since it is then inferred to be 1.
    enable_warnings: bool
        If ``True``, emit a warning in case of precision mis-match issues.

    Notes
    -----
    * :py:class:`~pyxu.operator.linop.base._ExplicitLinOp` instances are **not arraymodule-agnostic**:
      they will only work with NDArrays belonging to the same (dense) array module as `mat`. Moreover, inner
      computations may cast input arrays when the precision of `mat` does not match the user-requested precision. If
      such a situation occurs, a warning is raised.

    * The matrix provided to :py:func:`~pyxu.operator.linop.base._ExplicitLinOp` is used as-is and can be accessed via
      ``.mat``.
    """

    def is_dense(A) -> bool:
        # Ensure `A` is a supported array format, then return
        #   `True`  if `A` is dense,
        #   `False` if `A` is sparse.
        # Raises ValueError if `A` belongs to neither family.
        for info, dense in (
            (pxd.NDArrayInfo, True),
            (pxd.SparseArrayInfo, False),
        ):
            try:
                info.from_obj(A)
            except Exception:
                continue  # not this family: try the next one
            return dense
        raise ValueError("mat: format could not be inferred.")

    def tensordot(A, b, dim_rank, warn: bool):
        # Parameters
        # ----------
        # A: (N1,...,NK, M1,...,MD) dense or sparse (2D)
        # b: (S1,...,SL, M1,...,MD) dense
        # dim_rank: D
        # warn: bool
        #
        # Returns
        # -------
        # out: (S1,...,SL, N1,...,NK) dense
        if (A.dtype != b.dtype) and warn:
            msg = "Computation may not be performed at the requested precision."
            warnings.warn(msg, pxw.PrecisionWarning)

        dim_shape = A.shape[-dim_rank:]
        dim_size = np.prod(dim_shape)
        codim_shape = A.shape[:-dim_rank]
        codim_size = np.prod(codim_shape)

        sh = b.shape[:-dim_rank]
        if not is_dense(A):  # sparse matrix -> necessarily 2D
            b = b.reshape(-1, dim_size)
            out = A.dot(b.T).T  # (prod(sh), codim_size)
            out = out.reshape(*sh, codim_size)
        else:  # ND dense array
            N = pxd.NDArrayInfo  # short-hand
            ndi = N.from_obj(A)
            xp = ndi.module()

            if ndi != N.DASK:
                # NUMPY/CUPY.tensordot() works -> use it.
                out = xp.tensordot(  # (S1,...,SL, N1,...,NK)
                    b,  # (S1,...,SL, M1,...,MD)
                    A,  # (N1,...,NK, M1,...,MD)
                    axes=[
                        list(range(-dim_rank, 0)),
                        list(range(-dim_rank, 0)),
                    ],
                )
            else:  # DASK-backed `mat`
                # DASK.tensordot() broken -> use 2D-ops instead
                msg = "[2023.12] DASK's tensordot() is broken. -> fallback onto 2D-shaped ops."
                pxw.warn_dask_perf(msg)

                A_2D = A.reshape(codim_size, dim_size)
                b = b.reshape(-1, dim_size)
                out = A_2D.dot(b.T).T  # (prod(sh), codim_size)
                out = out.reshape(*sh, *codim_shape)
        return out

    def op_apply(_, arr: pxt.NDArray) -> pxt.NDArray:
        out = tensordot(
            A=_.mat,
            b=arr,
            dim_rank=_.dim_rank,
            warn=_._enable_warnings,
        )
        return out

    def op_adjoint(_, arr: pxt.NDArray) -> pxt.NDArray:
        if is_dense(_.mat):
            # Move the domain axes (M1,...,MD) in front of the co-domain axes (N1,...,NK).
            axes = (
                *tuple(range(-_.dim_rank, 0)),
                *tuple(range(_.codim_rank)),
            )
        else:
            axes = None  # transposes all axes for 2D sparse arrays
        out = tensordot(
            A=_.mat.transpose(axes),
            b=arr,
            dim_rank=_.codim_rank,
            warn=_._enable_warnings,
        )
        return out

    def op_estimate_lipschitz(_, **kwargs) -> pxt.Real:
        N = pxd.NDArrayInfo
        S = pxd.SparseArrayInfo

        # Determine the dense array module matching `mat`.
        if is_dense(_.mat):
            ndi = N.from_obj(_.mat)
        else:
            sdi = S.from_obj(_.mat)
            if sdi == S.SCIPY_SPARSE:
                ndi = N.NUMPY
            elif sdi == S.CUPY_SPARSE:
                ndi = N.CUPY
            else:
                raise NotImplementedError

        # Force the backend/precision of `mat`, then defer to the class' default estimator.
        kwargs.update(
            xp=ndi.module(),
            gpu=ndi == N.CUPY,
            dtype=_.mat.dtype,
        )
        klass = _.__class__
        return klass.estimate_lipschitz(_, **kwargs)

    def op_asarray(_, **kwargs) -> pxt.NDArray:
        N = pxd.NDArrayInfo
        xp = kwargs.get("xp", N.default().module())
        dtype = kwargs.get("dtype", pxrt.Width.DOUBLE.value)

        if is_dense(_.mat):
            A = _.mat.astype(dtype, copy=False)
        else:
            A = _.mat.astype(dtype).toarray()  # `copy field not ubiquitous`
        B = xp.array(pxu.to_NUMPY(A), dtype=dtype)
        return B

    def op_trace(_, **kwargs) -> pxt.Real:
        if _.dim_size != _.codim_size:
            raise NotImplementedError
        elif len(_.mat.shape) == 2:  # dense or sparse
            try:
                tr = _.mat.trace()
            except AttributeError:
                # Not all sparse types have a .trace() method ...
                tr = _.mat.diagonal().sum()
        else:  # ND dense arrays only
            # The "diagonal" is the set of entries mat[unravel(i, codim_shape), unravel(i, dim_shape)].
            N = pxd.NDArrayInfo
            ndi = N.from_obj(_.mat)
            if ndi == N.DASK:
                # We don't want to reshape `mat` if DASK-backed for performance reasons, so the trace is built by
                # indexing the "diagonal" manually.
                tr = 0
                for idx in range(_.dim_size):
                    dim_idx = np.unravel_index(idx, _.dim_shape)
                    codim_idx = np.unravel_index(idx, _.codim_shape)
                    tr += _.mat[(*codim_idx, *dim_idx)]
            else:
                # NUMPY/CUPY: gather all diagonal entries in one vectorized indexing operation.
                xp = ndi.module()
                idx = xp.arange(_.dim_size)
                codim_idx = xp.unravel_index(idx, _.codim_shape)
                dim_idx = xp.unravel_index(idx, _.dim_shape)
                tr = _.mat[(*codim_idx, *dim_idx)].sum()
        return float(tr)

    is_dense(mat)  # We were given a dense/sparse array ...
    # ... but is dim_rank correctly specified?
    assert len(mat.shape) >= 2, "Only 2D+ arrays are supported."
    if len(mat.shape) == 2:
        dim_rank = 1  # doesn't matter what the user specified.
    else:  # rank > 2
        # if ND -> mandatory supplied & (1 <= dim_rank < mat.ndim)
        assert dim_rank is not None, "Dimension rank must be specified for ND operators."
        assert 1 <= dim_rank < len(mat.shape)

    op = px_src.from_source(
        cls=cls,
        dim_shape=mat.shape[-dim_rank:],
        codim_shape=mat.shape[:-dim_rank],
        embed=dict(
            _name="_ExplicitLinOp",
            mat=mat,
            _enable_warnings=bool(enable_warnings),
        ),
        apply=op_apply,
        adjoint=op_adjoint,
        asarray=op_asarray,
        trace=op_trace,
        estimate_lipschitz=op_estimate_lipschitz,
    )
    return op