Source code for pyxu.operator.interop.source

  1import collections.abc as cabc
  2import types
  3
  4import pyxu.abc as pxa
  5import pyxu.info.ptype as pxt
  6import pyxu.util as pxu
  7
  8__all__ = [
  9    "from_source",
 10]
 11
 12
def from_source(
    cls: pxt.OpC,
    dim_shape: pxt.NDArrayShape,
    codim_shape: pxt.NDArrayShape,
    embed: dict = None,
    vectorize: pxt.VarName = frozenset(),
    **kwargs,
) -> pxt.OpT:
    r"""
    Build an :py:class:`~pyxu.abc.Operator` out of user-supplied low-level pieces.

    Parameters
    ----------
    cls: OpC
        Operator sub-class to instantiate.
    dim_shape: NDArrayShape
        Operator domain shape (M1,...,MD).
    codim_shape: NDArrayShape
        Operator co-domain shape (N1,...,NK).
    embed: dict
        (k[str], v[value]) pairs to attach to the created operator as attributes.

        `embed` is the place to store extra information which the arithmetic methods of the synthesized
        :py:class:`~pyxu.abc.Operator` need to access.
    vectorize: VarName
        Arithmetic methods to vectorize. (A single name or a collection of names.)

        Use `vectorize` when an arithmetic method given via `kwargs` (ex: :py:meth:`~pyxu.abc.Map.apply`) does not
        support stacking dimensions.
    kwargs: dict
        (k[str], v[callable]) pairs to use as arithmetic methods.

        Keys must be entries from :py:meth:`~pyxu.abc.Property.arithmetic_methods`.

        Arithmetic attributes/methods which are omitted default to those provided by `cls`.

    Returns
    -------
    op: OpT
        Pyxu-compliant operator :math:`A: \mathbb{R}^{M_{1} \times\cdots\times M_{D}} \to \mathbb{R}^{N_{1}
        \times\cdots\times N_{K}}`.

    Notes
    -----
    * Arithmetic methods, when provided, must follow the Pyxu interface exactly. In particular, the following
      arithmetic methods, if supplied, **must** have the following interface:

      .. code-block:: python3

         def apply(self, arr: pxt.NDArray) -> pxt.NDArray                   # (..., M1,...,MD) -> (..., N1,...,NK)
         def grad(self, arr: pxt.NDArray) -> pxt.NDArray                    # (..., M1,...,MD) -> (..., M1,...,MD)
         def adjoint(self, arr: pxt.NDArray) -> pxt.NDArray                 # (..., N1,...,NK) -> (..., M1,...,MD)
         def prox(self, arr: pxt.NDArray, tau: pxt.Real) -> pxt.NDArray     # (..., M1,...,MD) -> (..., M1,...,MD)
         def pinv(self, arr: pxt.NDArray, damp: pxt.Real) -> pxt.NDArray    # (..., N1,...,NK) -> (..., M1,...,MD)

      Moreover, the methods above **must** accept stacking dimensions in ``arr``. If this does not hold, consider
      populating `vectorize`.

    * Auto-vectorization decorates the `kwargs`-specified arithmetic methods with :py:func:`~pyxu.util.vectorize`.
      It may be less efficient than explicitly providing a vectorized implementation.

    Examples
    --------
    Creation of the custom element-wise differential operator :math:`f(\mathbf{x}) = \mathbf{x}^{2}`.

    .. code-block:: python3

       N = 5
       f = from_source(
           cls=pyxu.abc.DiffMap,
           dim_shape=N,
           codim_shape=N,
           apply=lambda _, arr: arr**2,
       )
       x = np.arange(N)
       y = f(x)  # [0, 1, 4, 9, 16]
       dL = f.diff_lipschitz  # inf (default value provided by DiffMap class.)

    In practice we know that :math:`f` has a finite-valued diff-Lipschitz constant. It is thus recommended to set it
    too when instantiating via ``from_source``:

    .. code-block:: python3

       N = 5
       f = from_source(
           cls=pyxu.abc.DiffMap,
           dim_shape=N,
           codim_shape=N,
           embed=dict(
               # special form to set (diff-)Lipschitz attributes via from_source()
               _diff_lipschitz=2,
           ),
           apply=lambda _, arr: arr**2,
       )
       x = np.arange(N)
       y = f(x)  # [0, 1, 4, 9, 16]
       dL = f.diff_lipschitz  # 2  <- instead of inf
    """
    # `None` stands for "nothing to embed".
    attributes = dict() if embed is None else embed

    # A lone method name is shorthand for a one-element collection.
    names = (vectorize,) if isinstance(vectorize, str) else vectorize

    builder = _FromSource(
        cls=cls,
        dim_shape=dim_shape,
        codim_shape=codim_shape,
        embed=attributes,
        vectorize=frozenset(names),
        **kwargs,
    )
    return builder.op()
class _FromSource:
    """
    Helper behind from_source(): checks its inputs, then attaches user-supplied arithmetic methods and extra
    attributes to a freshly-instantiated operator.

    See from_source() for a detailed description of the parameters.
    """

    def __init__(
        self,
        cls: pxt.OpC,
        dim_shape: pxt.NDArrayShape,
        codim_shape: pxt.NDArrayShape,
        embed: dict,
        vectorize: frozenset[str],
        **kwargs,
    ):
        """
        Validate the inputs and instantiate the operator which op() will later customize.

        Raises
        ------
        AssertionError
            If `cls` is not one of the core operator classes, or if `embed` is not a mapping.
        ValueError
            If `kwargs` contains names which are not arithmetic methods, or if `vectorize` names a method which
            cannot be vectorized.
        """
        from pyxu.abc.operator import _core_operators

        # Raised explicitly (instead of via `assert`) so that the check is not stripped under `python -O`.
        # Exception type and message are unchanged.
        if cls not in _core_operators():
            raise AssertionError(f"Unknown Operator type: {cls}.")
        self._op = cls(  # ensure shape well-formed
            dim_shape=dim_shape,
            codim_shape=codim_shape,
        )

        # Arithmetic methods to attach to `_op`.
        meth = frozenset.union(*[p.arithmetic_methods() for p in pxa.Property])
        unknown = set(kwargs) - meth
        if unknown:
            msg_head = "Unknown arithmetic methods:"
            # Sorted so that the error message is deterministic.
            msg_tail = ", ".join([f"{name}()" for name in sorted(unknown)])
            raise ValueError(f"{msg_head} {msg_tail}")
        self._kwargs = kwargs

        # Extra attributes to attach to `_op`.
        # Same remark as above: explicit raise so that the check survives `python -O`.
        if not isinstance(embed, cabc.Mapping):
            raise AssertionError(f"embed: expected a mapping, got {type(embed)}.")
        self._embed = embed

        # Add-on vectorization functionality.
        self._vkwargs = self._parse_vectorize(vectorize)
        self._vectorize = vectorize

    def op(self) -> pxt.OpT:
        """
        Attach the user-supplied arithmetic methods and embedded attributes to the operator, then return it.

        Only names listed in the arithmetic methods of the operator's own properties are considered; other entries of
        `kwargs` are left unused.
        """
        target = self._op  # shorthand
        for prop in target.properties():
            for name in prop.arithmetic_methods():
                func = self._kwargs.get(name)
                if not func:
                    # Nothing supplied by the user: keep what `cls` provides.
                    # vectorize() does NOT kick in for default-provided methods.
                    # (We assume they are Pyxu-compliant from the start.)
                    continue

                if name in self._vectorize:
                    decorate = pxu.vectorize(**self._vkwargs[name])
                    func = decorate(func)

                # Bind to the instance so that `func` receives the operator as first argument.
                setattr(target, name, types.MethodType(func, target))

        # Embed extra attributes
        for name, attr in self._embed.items():
            setattr(target, name, attr)

        return target

    def _parse_vectorize(self, vectorize: frozenset[str]) -> dict:
        """
        Build the keyword arguments to give to vectorize() for each method which can be vectorized.

        Parameters
        ----------
        vectorize: frozenset[str]
            Names of the arithmetic methods which the user wants vectorized.

        Returns
        -------
        vkwargs: dict
            (method name, vectorize() keyword arguments) pairs, for apply/prox/grad/adjoint/pinv.

        Raises
        ------
        ValueError
            If `vectorize` contains a name other than the five above.
        """
        dim = self._op.dim_shape
        codim = self._op.codim_shape

        # (input shape, output shape) of each vectorizable method.
        io_shapes = dict(
            apply=(dim, codim),
            prox=(dim, dim),
            grad=(dim, dim),
            adjoint=(codim, dim),
            pinv=(codim, dim),
        )
        vkwargs = {  # Parameter hints for vectorize()
            name: dict(
                i="arr",  # Pyxu arithmetic methods broadcast along parameter `arr`.
                dim_shape=shape_in,
                codim_shape=shape_out,
            )
            for name, (shape_in, shape_out) in io_shapes.items()
        }

        if not (vectorize <= set(vkwargs)):  # un-recognized arithmetic method
            msg_head = "Can only vectorize arithmetic methods"
            msg_tail = ", ".join([f"{name}()" for name in vkwargs])
            raise ValueError(f"{msg_head} {msg_tail}")
        return vkwargs