Source code for pyxu.operator.map.base

  1import numpy as np
  2
  3import pyxu.abc as pxa
  4import pyxu.info.deps as pxd
  5import pyxu.info.ptype as pxt
  6import pyxu.operator.interop.source as px_src
  7import pyxu.util as pxu
  8
  9__all__ = [
 10    "ConstantValued",
 11]
 12
 13
def ConstantValued(
    dim_shape: pxt.NDArrayShape,
    codim_shape: pxt.NDArrayShape,
    cst: pxt.Real,
) -> pxt.OpT:
    r"""
    Constant-valued operator :math:`C: \mathbb{R}^{M_{1} \times\cdots\times M_{D}} \to
    \mathbb{R}^{N_{1} \times\cdots\times N_{K}}`.

    Parameters
    ----------
    dim_shape: NDArrayShape
        (M1,...,MD) domain dimensions.
    codim_shape: NDArrayShape
        (N1,...,NK) co-domain dimensions.
    cst: Real
        Value taken by every output entry, irrespective of the input.
        Values for which ``np.isclose(cst, 0)`` holds (NumPy default tolerances) are treated as
        exactly zero.

    Returns
    -------
    op: OpT
        * If `cst` is (close to) zero: a ``NullFunc`` when ``codim_shape == (1,)``, otherwise a
          ``NullOp``.
        * Otherwise: a ``ProxDiffFunc`` when ``codim_shape == (1,)``, otherwise a ``DiffMap``,
          whose ``apply`` broadcasts `cst` to ``(*stack_shape, *codim_shape)``, whose
          ``jacobian`` is a ``NullOp``, whose ``grad`` is all-zeros and whose ``prox`` returns
          its input (read-only).

        In every case ``op.lipschitz`` and ``op.diff_lipschitz`` are set to 0.
    """
    codim_shape = pxu.as_canonical_shape(codim_shape)
    is_scalar = codim_shape == (1,)
    value = float(cst)

    if np.isclose(value, 0):
        # A zero constant is exactly the null operator: reuse the dedicated classes.
        if is_scalar:
            from pyxu.operator import NullFunc

            op = NullFunc(dim_shape=dim_shape)
        else:
            from pyxu.operator import NullOp

            op = NullOp(dim_shape=dim_shape, codim_shape=codim_shape)
    else:
        # NOTE: closures keep the `_` / `arr` / `tau` parameter names used throughout pyxu, since
        # `from_source` wires them as methods of the generated operator.
        def op_apply(_, arr: pxt.NDArray) -> pxt.NDArray:
            ndi = pxd.NDArrayInfo.from_obj(arr)
            xp = ndi.module()
            batch_shape = arr.shape[: -_.dim_rank]
            fill = xp.array(_._cst, arr.dtype)
            out_shape = (*batch_shape, *_.codim_shape)
            if ndi == pxd.NDArrayInfo.DASK:
                # Keep the caller's chunking along stacking axes; let Dask choose core-axis chunks.
                out_chunks = arr.chunks[: -_.dim_rank] + ("auto",) * _.codim_rank
                return xp.broadcast_to(fill, out_shape, chunks=out_chunks)
            return xp.broadcast_to(fill, out_shape)

        def op_jacobian(_, arr: pxt.NDArray) -> pxt.OpT:
            # The derivative of a constant vanishes everywhere.
            from pyxu.operator import NullOp

            return NullOp(dim_shape=_.dim_shape, codim_shape=_.codim_shape)

        def op_grad(_, arr: pxt.NDArray) -> pxt.NDArray:
            ndi = pxd.NDArrayInfo.from_obj(arr)
            xp = ndi.module()
            zero = xp.array(0, arr.dtype)
            if ndi == pxd.NDArrayInfo.DASK:
                # Mirror the input's chunk structure.
                return xp.broadcast_to(zero, arr.shape, chunks=arr.chunks)
            return xp.broadcast_to(zero, arr.shape)

        def op_prox(_, arr: pxt.NDArray, tau: pxt.NDArray) -> pxt.NDArray:
            # Adding a constant does not move the minimizer: prox is the identity.
            return pxu.read_only(arr)

        base_cls = pxa.ProxDiffFunc if is_scalar else pxa.DiffMap
        op = px_src.from_source(
            cls=base_cls,
            dim_shape=dim_shape,
            codim_shape=codim_shape,
            embed={"_name": "ConstantValued", "_cst": value},
            apply=op_apply,
            jacobian=op_jacobian,
            grad=op_grad,
            prox=op_prox,
        )

    op.lipschitz = 0
    op.diff_lipschitz = 0
    return op