1import collections.abc as cabc
2import types
3
4import pyxu.abc as pxa
5import pyxu.info.ptype as pxt
6import pyxu.util as pxu
7
__all__ = ["from_source"]  # public API of this module
11
12
def from_source(
    cls: pxt.OpC,
    dim_shape: pxt.NDArrayShape,
    codim_shape: pxt.NDArrayShape,
    embed: dict = None,
    vectorize: pxt.VarName = frozenset(),
    **kwargs,
) -> pxt.OpT:
    r"""
    Build an :py:class:`~pyxu.abc.Operator` from low-level building blocks.

    Parameters
    ----------
    cls: OpC
        Operator sub-class to instantiate.
    dim_shape: NDArrayShape
        Shape (M1,...,MD) of the operator's domain.
    codim_shape: NDArrayShape
        Shape (N1,...,NK) of the operator's co-domain.
    embed: dict
        (k[str], v[value]) pairs set as attributes on the created operator.

        `embed` is handy to attach extra information to a synthesized :py:class:`~pyxu.abc.Operator`, e.g. for use
        by its arithmetic methods.
    vectorize: VarName
        Name(s) of the `kwargs`-supplied arithmetic methods to vectorize.

        Use `vectorize` when an arithmetic method given in `kwargs` (ex: :py:meth:`~pyxu.abc.Map.apply`) cannot
        handle stacking dimensions.
    kwargs: dict
        (k[str], v[callable]) pairs used as arithmetic methods.

        Keys must be entries of :py:meth:`~pyxu.abc.Property.arithmetic_methods`.

        Arithmetic attributes/methods which are not supplied fall back to those defined by `cls`.

    Returns
    -------
    op: OpT
        Pyxu-compliant operator :math:`A: \mathbb{R}^{M_{1} \times\cdots\times M_{D}} \to \mathbb{R}^{N_{1}
        \times\cdots\times N_{K}}`.

    Raises
    ------
    AssertionError
        If `cls` is not a core Pyxu operator class, or `embed` is not a mapping.
    ValueError
        If `kwargs` contains a name which is not an arithmetic method, or `vectorize` names a method which cannot be
        vectorized.

    Notes
    -----
    * Supplied arithmetic methods must abide by the Pyxu interface exactly. In particular, the following arithmetic
      methods, if given, **must** have the signatures below:

      .. code-block:: python3

         def apply(self, arr: pxt.NDArray) -> pxt.NDArray  # (..., M1,...,MD) -> (..., N1,...,NK)
         def grad(self, arr: pxt.NDArray) -> pxt.NDArray  # (..., M1,...,MD) -> (..., M1,...,MD)
         def adjoint(self, arr: pxt.NDArray) -> pxt.NDArray  # (..., N1,...,NK) -> (..., M1,...,MD)
         def prox(self, arr: pxt.NDArray, tau: pxt.Real) -> pxt.NDArray  # (..., M1,...,MD) -> (..., M1,...,MD)
         def pinv(self, arr: pxt.NDArray, damp: pxt.Real) -> pxt.NDArray  # (..., N1,...,NK) -> (..., M1,...,MD)

      These methods **must** also accept stacking dimensions in ``arr``. If they do not, list them in `vectorize`.

    * Auto-vectorization wraps the `kwargs`-supplied arithmetic methods with :py:func:`~pyxu.util.vectorize`. It may
      be slower than a hand-written vectorized implementation.

    Examples
    --------
    Creation of the custom element-wise differentiable operator :math:`f(\mathbf{x}) = \mathbf{x}^{2}`.

    .. code-block:: python3

       N = 5
       f = from_source(
           cls=pyxu.abc.DiffMap,
           dim_shape=N,
           codim_shape=N,
           apply=lambda _, arr: arr**2,
       )
       x = np.arange(N)
       y = f(x)  # [0, 1, 4, 9, 16]
       dL = f.diff_lipschitz  # inf (default value provided by DiffMap class.)

    In practice :math:`f` is known to have a finite diff-Lipschitz constant, so it is recommended to set it as well
    when instantiating via ``from_source``:

    .. code-block:: python3

       N = 5
       f = from_source(
           cls=pyxu.abc.DiffMap,
           dim_shape=N,
           codim_shape=N,
           embed=dict(
               # special form to set (diff-)Lipschitz attributes via from_source()
               _diff_lipschitz=2,
           ),
           apply=lambda _, arr: arr**2,
       )
       x = np.arange(N)
       y = f(x)  # [0, 1, 4, 9, 16]
       dL = f.diff_lipschitz  # 2 <- instead of inf
    """
    # No `embed` means "attach nothing": substitute a fresh, empty mapping.
    extras = dict() if (embed is None) else embed

    # A lone method name is promoted to a 1-element collection before freezing.
    if isinstance(vectorize, str):
        targets = frozenset((vectorize,))
    else:
        targets = frozenset(vectorize)

    builder = _FromSource(
        cls=cls,
        dim_shape=dim_shape,
        codim_shape=codim_shape,
        embed=extras,
        vectorize=targets,
        **kwargs,
    )
    return builder.op()
129
130
131class _FromSource: # See from_source() for a detailed description.
132 def __init__(
133 self,
134 cls: pxt.OpC,
135 dim_shape: pxt.NDArrayShape,
136 codim_shape: pxt.NDArrayShape,
137 embed: dict,
138 vectorize: frozenset[str],
139 **kwargs,
140 ):
141 from pyxu.abc.operator import _core_operators
142
143 assert cls in _core_operators(), f"Unknown Operator type: {cls}."
144 self._op = cls( # ensure shape well-formed
145 dim_shape=dim_shape,
146 codim_shape=codim_shape,
147 )
148
149 # Arithmetic methods to attach to `_op`.
150 meth = frozenset.union(*[p.arithmetic_methods() for p in pxa.Property])
151 if not (set(kwargs) <= meth):
152 msg_head = "Unknown arithmetic methods:"
153 unknown = set(kwargs) - meth
154 msg_tail = ", ".join([f"{name}()" for name in unknown])
155 raise ValueError(f"{msg_head} {msg_tail}")
156 self._kwargs = kwargs
157
158 # Extra attributes to attach to `_op`.
159 assert isinstance(embed, cabc.Mapping)
160 self._embed = embed
161
162 # Add-on vectorization functionality.
163 self._vkwargs = self._parse_vectorize(vectorize)
164 self._vectorize = vectorize
165
166 def op(self) -> pxt.OpT:
167 _op = self._op # shorthand
168 for p in _op.properties():
169 for name in p.arithmetic_methods():
170 if func := self._kwargs.get(name, False):
171 # vectorize() do NOT kick in for default-provided methods.
172 # (We assume they are Pyxu-compliant from the start.)
173
174 if name in self._vectorize:
175 decorate = pxu.vectorize(**self._vkwargs[name])
176 func = decorate(func)
177
178 setattr(_op, name, types.MethodType(func, _op))
179
180 # Embed extra attributes
181 for name, attr in self._embed.items():
182 setattr(_op, name, attr)
183
184 return _op
185
186 def _parse_vectorize(self, vectorize: frozenset[str]):
187 vkwargs = dict( # Parameter hints for vectorize()
188 apply=dict(
189 i="arr", # Pyxu arithmetic methods broadcast along parameter `arr`.
190 dim_shape=self._op.dim_shape,
191 codim_shape=self._op.codim_shape,
192 ),
193 prox=dict(
194 i="arr",
195 dim_shape=self._op.dim_shape,
196 codim_shape=self._op.dim_shape,
197 ),
198 grad=dict(
199 i="arr",
200 dim_shape=self._op.dim_shape,
201 codim_shape=self._op.dim_shape,
202 ),
203 adjoint=dict(
204 i="arr",
205 dim_shape=self._op.codim_shape,
206 codim_shape=self._op.dim_shape,
207 ),
208 pinv=dict(
209 i="arr",
210 dim_shape=self._op.codim_shape,
211 codim_shape=self._op.dim_shape,
212 ),
213 )
214
215 if not (vectorize <= set(vkwargs)): # un-recognized arithmetic method
216 msg_head = "Can only vectorize arithmetic methods"
217 msg_tail = ", ".join([f"{name}()" for name in vkwargs])
218 raise ValueError(f"{msg_head} {msg_tail}")
219 return vkwargs