1import numpy as np
2
3import pyxu.abc as pxa
4import pyxu.info.ptype as pxt
5import pyxu.operator.func.norm as pxf
6import pyxu.runtime as pxrt
7import pyxu.util as pxu
8
# Public names exported by this module (indicator functions of the sets defined below).
__all__ = [
    "L1Ball",
    "L2Ball",
    "LInfinityBall",
    "PositiveOrthant",
    "HyperSlab",
    "RangeSet",
]
17
18
class _IndicatorFunction(pxa.ProxFunc):
    """
    Shared base class of the indicator functions defined in this module.

    It fixes the co-domain shape to 1, sets the Lipschitz constant to infinity, and provides a helper to turn
    boolean set-membership masks into 0/inf values.
    """

    def __init__(self, dim_shape: pxt.NDArrayShape):
        """
        Parameters
        ----------
        dim_shape: NDArrayShape
            Shape of the functional's domain.
        """
        super().__init__(dim_shape=dim_shape, codim_shape=1)
        self.lipschitz = np.inf

    @staticmethod
    def _bool2indicator(x: pxt.NDArray, dtype: pxt.DType) -> pxt.NDArray:
        """
        Map a boolean membership mask to indicator values.

        Parameters
        ----------
        x: NDArray[bool]
            Mask; True where the point belongs to the set.
        dtype: DType
            Data type used to build the 0/inf values.

        Returns
        -------
        y: NDArray
            Same shape as `x`: 0 where `x` is True, inf elsewhere.
        """
        xp = pxu.get_array_module(x)
        # Indexing a 0-d array with `[()]` yields a scalar of the requested dtype.
        inside = np.array(0, dtype=dtype)[()]
        outside = np.array(np.inf, dtype=dtype)[()]
        return xp.where(x, inside, outside)
35
36
class _NormBall(_IndicatorFunction):
    """
    Indicator function of a norm ball of given order (1, 2 or inf) and radius.

    Instances are created through the `L1Ball`, `L2Ball` and `LInfinityBall` factories.
    """

    def __init__(
        self,
        dim_shape: pxt.NDArrayShape,
        ord: pxt.Integer,
        radius: pxt.Real,
    ):
        """
        Parameters
        ----------
        dim_shape: NDArrayShape
            Shape of the functional's domain.
        ord: Integer
            Norm order. `prox` only supports 1, 2 and `np.inf`.
        radius: Real
            Ball radius; stored as a float.
        """
        super().__init__(dim_shape=dim_shape)
        self._radius = float(radius)
        self._ord = ord

    def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
        """
        Evaluate the indicator: 0 where the norm of `arr` is at most the radius, inf otherwise.
        """
        from pyxu.opt.stop import _norm

        inside = _norm(arr, ord=self._ord, rank=self.dim_rank) <= self._radius  # (..., 1)
        return self._bool2indicator(inside, arr.dtype)

    def prox(self, arr: pxt.NDArray, tau: pxt.Real) -> pxt.NDArray:
        """
        Compute `arr - prox_{radius * other_norm}(arr)`, where `other_norm` is looked up from the ball's order.

        `tau` is not used by the computation. A `KeyError` is raised if the order is not 1, 2 or `np.inf`.
        """
        # Norm whose proximal operator is subtracted from the input, keyed by the ball's order.
        paired_norm = {
            1: pxf.LInfinityNorm,
            2: pxf.L2Norm,
            np.inf: pxf.L1Norm,
        }
        norm_op = paired_norm[self._ord](dim_shape=self.dim_shape)

        # Work on a copy so the caller's array is left untouched by the in-place subtraction.
        result = arr.copy()
        result -= norm_op.prox(arr, tau=self._radius)
        return result
66
67
[docs]
def L1Ball(dim_shape: pxt.NDArrayShape, radius: pxt.Real = 1) -> pxt.OpT:
    r"""
    Build the indicator function of the :math:`\ell_{1}`-ball of radius :math:`r`.

    .. math::

       \iota_{1}^{r}(\mathbf{x})
       :=
       \begin{cases}
           0 & \|\mathbf{x}\|_{1} \le r \\
           \infty & \text{otherwise}.
       \end{cases}

    .. math::

       \text{prox}_{\tau\, \iota_{1}^{r}}(\mathbf{x})
       :=
       \mathbf{x} - \text{prox}_{r\, \ell_{\infty}}(\mathbf{x})

    Parameters
    ----------
    dim_shape: NDArrayShape
        Shape of the functional's domain.
    radius: Real
        Ball radius. (Default: unit ball.)

    Returns
    -------
    op: OpT
        Operator named "L1Ball".

    Note
    ----
    * Computing :py:meth:`~pyxu.abc.ProxFunc.prox` is unavailable with DASK inputs.
      (Inefficient exact solution at scale.)
    """
    ball = _NormBall(dim_shape=dim_shape, ord=1, radius=radius)
    ball._name = "L1Ball"
    return ball
105
106
[docs]
def L2Ball(dim_shape: pxt.NDArrayShape, radius: pxt.Real = 1) -> pxt.OpT:
    r"""
    Build the indicator function of the :math:`\ell_{2}`-ball of radius :math:`r`.

    .. math::

       \iota_{2}^{r}(\mathbf{x})
       :=
       \begin{cases}
           0 & \|\mathbf{x}\|_{2} \le r \\
           \infty & \text{otherwise}.
       \end{cases}

    .. math::

       \text{prox}_{\tau\, \iota_{2}^{r}}(\mathbf{x})
       :=
       \mathbf{x} - \text{prox}_{r\, \ell_{2}}(\mathbf{x})

    Parameters
    ----------
    dim_shape: NDArrayShape
        Shape of the functional's domain.
    radius: Real
        Ball radius. (Default: unit ball.)

    Returns
    -------
    op: OpT
        Operator named "L2Ball".
    """
    ball = _NormBall(dim_shape=dim_shape, ord=2, radius=radius)
    ball._name = "L2Ball"
    return ball
139
140
[docs]
def LInfinityBall(dim_shape: pxt.NDArrayShape, radius: pxt.Real = 1) -> pxt.OpT:
    r"""
    Build the indicator function of the :math:`\ell_{\infty}`-ball of radius :math:`r`.

    .. math::

       \iota_{\infty}^{r}(\mathbf{x})
       :=
       \begin{cases}
           0 & \|\mathbf{x}\|_{\infty} \le r \\
           \infty & \text{otherwise}.
       \end{cases}

    .. math::

       \text{prox}_{\tau\, \iota_{\infty}^{r}}(\mathbf{x})
       :=
       \mathbf{x} - \text{prox}_{r\, \ell_{1}}(\mathbf{x})

    Parameters
    ----------
    dim_shape: NDArrayShape
        Shape of the functional's domain.
    radius: Real
        Ball radius. (Default: unit ball.)

    Returns
    -------
    op: OpT
        Operator named "LInfinityBall".
    """
    ball = _NormBall(dim_shape=dim_shape, ord=np.inf, radius=radius)
    ball._name = "LInfinityBall"
    return ball
173
174
[docs]
class PositiveOrthant(_IndicatorFunction):
    r"""
    Indicator function of the positive orthant.

    .. math::

       \iota_{+}(\mathbf{x})
       :=
       \begin{cases}
           0 & \min{\mathbf{x}} \ge 0,\\
           \infty & \text{otherwise}.
       \end{cases}

    .. math::

       \text{prox}_{\tau\, \iota_{+}}(\mathbf{x})
       :=
       \max(\mathbf{x}, \mathbf{0})
    """

    def __init__(self, dim_shape: pxt.NDArrayShape):
        """
        Parameters
        ----------
        dim_shape: NDArrayShape
            Shape of the functional's domain.
        """
        super().__init__(dim_shape=dim_shape)

    def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
        """
        Return 0 where every entry along the trailing `dim_rank` axes is non-negative, inf otherwise.
        """
        # The trailing `dim_rank` axes are reduced; leading axes are preserved.
        reduce_axes = tuple(range(-self.dim_rank, 0))
        non_negative = arr >= 0
        member = non_negative.all(axis=reduce_axes)  # (...,)
        return self._bool2indicator(member[..., np.newaxis], arr.dtype)  # (..., 1)

    def prox(self, arr: pxt.NDArray, tau: pxt.Real) -> pxt.NDArray:
        """
        Clip `arr` from below at 0. `tau` is not used by the computation.
        """
        return arr.clip(0, None)
207
208
[docs]
class HyperSlab(_IndicatorFunction):
    r"""
    Indicator function of a hyperslab.

    .. math::

       \iota_{\mathbf{a}}^{l,u}(\mathbf{x})
       :=
       \begin{cases}
           0 & l \le \langle \mathbf{a}, \mathbf{x} \rangle \le u \\
           \infty & \text{otherwise}.
       \end{cases}

    .. math::

       \text{prox}_{\tau\, \iota_{\mathbf{a}}^{l,u}}(\mathbf{x})
       :=
       \begin{cases}
           \mathbf{x} + \frac{l - \langle \mathbf{a}, \mathbf{x} \rangle}{\|\mathbf{a}\|^{2}} \mathbf{a} & \langle \mathbf{a}, \mathbf{x} \rangle < l, \\
           \mathbf{x} + \frac{u - \langle \mathbf{a}, \mathbf{x} \rangle}{\|\mathbf{a}\|^{2}} \mathbf{a} & \langle \mathbf{a}, \mathbf{x} \rangle > u, \\
           \mathbf{x} & \text{otherwise}.
       \end{cases}
    """

    def __init__(self, a: pxa.LinFunc, lb: pxt.Real, ub: pxt.Real):
        """
        Parameters
        ----------
        a: ~pyxu.abc.operator.LinFunc
            Linear functional with domain (M1,...,MD).
        lb: Real
            Lower bound.
        ub: Real
            Upper bound. Must be strictly greater than `lb`.
        """
        # Kept as an assertion (not a raised ValueError) so the exception type seen by callers is unchanged.
        # Note that assertions are skipped when Python runs with -O.
        assert lb < ub
        super().__init__(dim_shape=a.dim_shape)

        # Everything happens internally in normalized coordinates.
        _norm = a.lipschitz  # \norm{a}{2}
        self._a = a / _norm
        self._l = lb / _norm
        self._u = ub / _norm

    def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
        """
        Return 0 where the normalized functional's value lies in [l, u] (normalized bounds), inf otherwise.
        """
        y = self._a.apply(arr)  # (..., 1)
        in_set = ((self._l <= y) & (y <= self._u)).all(axis=-1, keepdims=True)
        out = self._bool2indicator(in_set, arr.dtype)  # (..., 1)
        return out

    def prox(self, arr: pxt.NDArray, tau: pxt.Real) -> pxt.NDArray:
        """
        Move `arr` along the normalized direction `a` until the functional's value lies in [l, u].

        `tau` is not used by the computation.
        """
        xp = pxu.get_array_module(arr)

        a = self._a.adjoint(xp.ones(1, dtype=arr.dtype))  # (M1,...,MD)
        expand = (np.newaxis,) * (self.dim_rank - 1)
        # Explicit tuple index: equivalent to `[..., *expand]`, but star-unpacking directly inside a subscript
        # is only valid syntax from Python 3.11 onwards.
        y = self._a.apply(arr)[(..., *expand)]  # (..., 1,...,1)
        out = arr.copy()

        # Correction towards the lower bound: non-zero only where <a, x> < l.
        l_corr = self._l - y
        l_corr[l_corr <= 0] = 0
        out += l_corr * a

        # Correction towards the upper bound: non-zero only where <a, x> > u.
        u_corr = self._u - y
        u_corr[u_corr >= 0] = 0
        out += u_corr * a

        return out
276
277
[docs]
class RangeSet(_IndicatorFunction):
    r"""
    Indicator function of a range set.

    .. math::

       \iota_{\mathbf{A}}^{R}(\mathbf{x})
       :=
       \begin{cases}
           0 & \mathbf{x} \in \text{span}(\mathbf{A}) \\
           \infty & \text{otherwise}.
       \end{cases}

    .. math::

       \text{prox}_{\tau\, \iota_{\mathbf{A}}^{R}}(\mathbf{x})
       :=
       \mathbf{A} (\mathbf{A}^{T} \mathbf{A})^{-1} \mathbf{A}^{T} \mathbf{x}.
    """

    def __init__(self, A: pxa.LinOp):
        """
        Parameters
        ----------
        A: ~pyxu.abc.operator.LinOp
            Linear operator whose range defines the set. The indicator's domain shape is `A.codim_shape`.
        """
        super().__init__(dim_shape=A.codim_shape)
        self._A = A

    def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
        """
        Return 0 where `arr` is (numerically) unchanged by `prox`, inf otherwise.
        """
        # I'm in range(A) if prox(x)==x.
        axis = tuple(range(-self.dim_rank, 0))
        y = self.prox(arr, tau=1)
        in_set = self.isclose(y, arr).all(axis=axis)  # (...,)
        out = self._bool2indicator(in_set[..., np.newaxis], arr.dtype)
        return out  # (..., 1)

    def prox(self, arr: pxt.NDArray, tau: pxt.Real) -> pxt.NDArray:
        """
        Apply `A` to the undamped pseudo-inverse solution `A.pinv(arr, damp=0)`.

        `tau` is not used by the computation.
        """
        y = self._A.pinv(arr, damp=0)
        out = self._A.apply(y)
        return out

    @staticmethod
    def isclose(a: pxt.NDArray, b: pxt.NDArray) -> pxt.NDArray:
        """
        Equivalent of `xp.isclose`, but where atol is automatically chosen based on input's `dtype`.

        Parameters
        ----------
        a: NDArray
            First operand; its dtype selects the absolute tolerance.
        b: NDArray
            Second operand.

        Returns
        -------
        eq: NDArray[bool]
            Element-wise closeness of `a` and `b`.
        """
        atol = {
            pxrt.Width.SINGLE.value: 2e-4,
            pxrt.Width.DOUBLE.value: 1e-8,
        }
        # Numbers obtained by:
        # * \sum_{k >= (p+1)//2} 2^{-k}, where p=<number of mantissa bits>; then
        # * round up value to 3 significant decimal digits.
        # N_mantissa = [23, 52] for [single, double] respectively.
        xp = pxu.get_array_module(a)
        # Fall back to the double-precision *tolerance* (not the dtype used as dictionary key) when the input
        # dtype is not listed; this should only occur for integer types.
        prec = atol.get(a.dtype, atol[pxrt.Width.DOUBLE.value])
        eq = xp.isclose(a, b, atol=prec)
        return eq