Source code for pyxu.operator.map.base
1import numpy as np
2
3import pyxu.abc as pxa
4import pyxu.info.deps as pxd
5import pyxu.info.ptype as pxt
6import pyxu.operator.interop.source as px_src
7import pyxu.util as pxu
8
# Public API of this module.
__all__ = ["ConstantValued"]
12
13
[docs]
def ConstantValued(
    dim_shape: pxt.NDArrayShape,
    codim_shape: pxt.NDArrayShape,
    cst: pxt.Real,
) -> pxt.OpT:
    r"""
    Constant-valued operator :math:`C: \mathbb{R}^{M_{1} \times\cdots\times M_{D}} \to
    \mathbb{R}^{N_{1} \times\cdots\times N_{K}}`.

    Every input is mapped to the array of shape ``codim_shape`` whose entries all equal ``cst``.

    Parameters
    ----------
    dim_shape: NDArrayShape
        (M1,...,MD) operator input shape.
    codim_shape: NDArrayShape
        (N1,...,NK) operator output shape.
    cst: Real
        Value taken by every output entry (converted to ``float``).

    Returns
    -------
    op: OpT
        * ``NullFunc(dim_shape)`` if ``cst == 0`` and ``codim_shape == (1,)``;
        * ``NullOp(dim_shape, codim_shape)`` if ``cst == 0`` otherwise;
        * a :py:class:`~pyxu.abc.ProxDiffFunc` if ``codim_shape == (1,)``;
        * a :py:class:`~pyxu.abc.DiffMap` otherwise.

        In the two non-null cases ``lipschitz`` and ``diff_lipschitz`` are set to 0.

    Notes
    -----
    * The Jacobian is the null operator.
    * In the scalar case, the gradient is identically zero and the proximal operator is the identity:
      adding a constant to the objective does not move its minimizer.
    * Only an *exact* zero constant is dispatched to the null operators, so arbitrarily small non-zero
      constants are preserved.
    """
    codim_shape = pxu.as_canonical_shape(codim_shape)
    is_scalar = codim_shape == (1,)

    cst = float(cst)
    if cst == 0:
        # Exact comparison on purpose: np.isclose(cst, 0) uses atol=1e-8 by default and would silently
        # replace legitimate small constants (e.g. 1e-9) by the null operator.
        from pyxu.operator import NullFunc, NullOp

        if is_scalar:
            return NullFunc(dim_shape=dim_shape)
        return NullOp(
            dim_shape=dim_shape,
            codim_shape=codim_shape,
        )

    # In the callbacks below, the first argument `_` is the operator instance built by from_source();
    # it exposes dim_/codim_ shapes & ranks and the embedded `_cst` value.

    def op_apply(_, arr: pxt.NDArray) -> pxt.NDArray:
        # Output shape: (*stack_shape, *codim_shape), where stack_shape are the leading dims of `arr`.
        ndi = pxd.NDArrayInfo.from_obj(arr)
        xp = ndi.module()
        stack_shape = arr.shape[: -_.dim_rank]

        kwargs = dict()
        if ndi == pxd.NDArrayInfo.DASK:
            # Keep the chunking of the stacking dimensions; let Dask choose chunks for the core dims.
            stack_chunks = arr.chunks[: -_.dim_rank]
            core_chunks = ("auto",) * _.codim_rank
            kwargs.update(chunks=stack_chunks + core_chunks)

        # Broadcasting a 0-d array yields a cheap (read-only) view instead of a filled copy.
        return xp.broadcast_to(
            xp.array(_._cst, arr.dtype),
            (*stack_shape, *_.codim_shape),
            **kwargs,
        )

    def op_jacobian(_, arr: pxt.NDArray) -> pxt.OpT:
        # The derivative of a constant map is zero everywhere, independently of `arr`.
        from pyxu.operator import NullOp

        return NullOp(
            dim_shape=_.dim_shape,
            codim_shape=_.codim_shape,
        )

    def op_grad(_, arr: pxt.NDArray) -> pxt.NDArray:
        # Gradient of a constant function: zeros with the same shape (and Dask chunks) as `arr`.
        ndi = pxd.NDArrayInfo.from_obj(arr)
        xp = ndi.module()

        kwargs = dict()
        if ndi == pxd.NDArrayInfo.DASK:
            kwargs.update(chunks=arr.chunks)

        return xp.broadcast_to(
            xp.array(0, arr.dtype),
            arr.shape,
            **kwargs,
        )

    def op_prox(_, arr: pxt.NDArray, tau: pxt.Real) -> pxt.NDArray:
        # argmin_z { cst + ||z - arr||^2 / (2 tau) } = arr, i.e. the prox of a constant is the identity.
        # A read-only view is returned so callers cannot mutate `arr` through the result.
        return pxu.read_only(arr)

    op = px_src.from_source(
        cls=pxa.ProxDiffFunc if is_scalar else pxa.DiffMap,
        dim_shape=dim_shape,
        codim_shape=codim_shape,
        embed=dict(
            _name="ConstantValued",
            _cst=cst,
        ),
        apply=op_apply,
        jacobian=op_jacobian,
        grad=op_grad,
        prox=op_prox,
    )
    # A constant map (and its zero derivative) are 0-Lipschitz.
    op.lipschitz = 0
    op.diff_lipschitz = 0
    return op