1import numpy as np
2import scipy.optimize as sopt
3
4import pyxu.abc as pxa
5import pyxu.info.deps as pxd
6import pyxu.info.ptype as pxt
7import pyxu.util as pxu
8
# Public API of this module: the norm-type functionals defined below.
__all__ = [
    "L1Norm", "L2Norm",
    "SquaredL2Norm", "SquaredL1Norm",
    "LInfinityNorm", "L21Norm",
    "PositiveL1Norm",
]
18
19
[docs]
class L1Norm(pxa.ProxFunc):
    r"""
    :math:`\ell_{1}`-norm, :math:`\Vert\mathbf{x}\Vert_{1} := \sum_{i} |x_{i}|`.

    Its proximal operator is element-wise soft-thresholding:
    :math:`\mathrm{sign}(x_{i}) \max(|x_{i}| - \tau, 0)`.
    """

    def __init__(self, dim_shape: pxt.NDArrayShape):
        super().__init__(dim_shape=dim_shape, codim_shape=1)
        # |‖x‖₁ - ‖y‖₁| <= ‖x - y‖₁ <= sqrt(N) ‖x - y‖₂ with N = dim_size.
        self.lipschitz = np.sqrt(self.dim_size)

    def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
        # Reduce over the trailing `dim_rank` core axes; leading stacking axes survive.
        xp = pxu.get_array_module(arr)
        core_axes = tuple(range(-self.dim_rank, 0))
        total = xp.fabs(arr).sum(axis=core_axes)
        return total[..., np.newaxis]

    def prox(self, arr: pxt.NDArray, tau: pxt.Real) -> pxt.NDArray:
        # Shrink magnitudes by `tau`, clamp at zero, then restore the original signs.
        xp = pxu.get_array_module(arr)
        shrunk = xp.fmax(0, xp.fabs(arr) - tau)
        shrunk *= xp.sign(arr)
        return shrunk
43
44
[docs]
class L2Norm(pxa.ProxFunc):
    r"""
    :math:`\ell_{2}`-norm, :math:`\Vert\mathbf{x}\Vert_{2} := \sqrt{\sum_{i} |x_{i}|^{2}}`.

    Its proximal operator is block soft-thresholding:
    :math:`\mathbf{x} \max(0, 1 - \tau / \Vert\mathbf{x}\Vert_{2})`.
    """

    def __init__(self, dim_shape: pxt.NDArrayShape):
        super().__init__(dim_shape=dim_shape, codim_shape=1)
        self.lipschitz = 1
        # The gradient x / ‖x‖₂ is not Lipschitz near the origin.
        self.diff_lipschitz = np.inf

    def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
        xp = pxu.get_array_module(arr)
        core_axes = tuple(range(-self.dim_rank, 0))
        sq_sum = (arr**2).sum(axis=core_axes)
        return xp.sqrt(sq_sum)[..., np.newaxis]

    def prox(self, arr: pxt.NDArray, tau: pxt.Real) -> pxt.NDArray:
        xp = pxu.get_array_module(arr)
        # fmax(‖x‖, tau) makes the factor exactly 0 whenever ‖x‖ <= tau.
        factor = 1 - tau / xp.fmax(self.apply(arr), tau)  # (..., 1)

        # Append (dim_rank - 1) singleton axes so `factor` broadcasts over the core dims.
        index = (Ellipsis,) + (np.newaxis,) * (self.dim_rank - 1)
        out = arr.copy()
        out *= factor[index]
        return out
72
73
[docs]
class SquaredL2Norm(pxa.QuadraticFunc):
    r"""
    :math:`\ell^{2}_{2}`-norm, :math:`\Vert\mathbf{x}\Vert^{2}_{2} := \sum_{i} |x_{i}|^{2}`.

    Gradient: :math:`2\mathbf{x}`.
    Proximal operator: :math:`\mathbf{x} / (1 + 2\tau)`.
    """

    def __init__(self, dim_shape: pxt.NDArrayShape):
        super().__init__(dim_shape=dim_shape, codim_shape=1)
        self.lipschitz = np.inf
        # grad(x) = 2x, hence its Lipschitz constant is 2.
        self.diff_lipschitz = 2

    def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
        core_axes = tuple(range(-self.dim_rank, 0))
        return (arr**2).sum(axis=core_axes)[..., np.newaxis]

    def grad(self, arr: pxt.NDArray) -> pxt.NDArray:
        return 2 * arr

    def prox(self, arr: pxt.NDArray, tau: pxt.Real) -> pxt.NDArray:
        # argmin_y ½‖y - x‖² + tau ‖y‖² is reached at y = x / (1 + 2 tau).
        out = arr.copy()
        out /= 2 * tau + 1
        return out

    def _quad_spec(self):
        from pyxu.operator import HomothetyOp, NullFunc

        # Q = 2·I, no linear term, no offset; reproduces ‖x‖² assuming the
        # ½⟨Qx, x⟩ + c(x) + t convention of QuadraticFunc.
        Q = HomothetyOp(dim_shape=self.dim_shape, cst=2)
        c = NullFunc(dim_shape=self.dim_shape)
        t = 0
        return (Q, c, t)
107
108
[docs]
class SquaredL1Norm(pxa.ProxFunc):
    r"""
    :math:`\ell^{2}_{1}`-norm, :math:`\Vert\mathbf{x}\Vert^{2}_{1} := (\sum_{i} |x_{i}|)^{2}`.

    Note
    ----
    * Computing :py:meth:`~pyxu.abc.ProxFunc.prox` is unavailable with DASK inputs.
      (Inefficient exact solution at scale.)
    """

    def __init__(self, dim_shape: pxt.NDArrayShape):
        super().__init__(
            dim_shape=dim_shape,
            codim_shape=1,
        )
        self.lipschitz = np.inf

        # prox() below reduces over the whole array it receives (`.item()`, `.sum()`),
        # so it is wrapped with vectorize() — presumably to evaluate one stacked sample at a time.
        vectorize = pxu.vectorize(
            i="arr",
            dim_shape=self.dim_shape,
            codim_shape=self.dim_shape,
        )
        self.prox = vectorize(self.prox)

    def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
        y = L1Norm(dim_shape=self.dim_shape).apply(arr)
        y **= 2
        return y

    def prox(self, arr: pxt.NDArray, tau: pxt.Real) -> pxt.NDArray:
        r"""
        Proximal operator of :math:`\tau \Vert\cdot\Vert_{1}^{2}`.

        Uses :math:`\Vert\mathbf{x}\Vert_{1}^{2} = \min_{\boldsymbol{\lambda} \in \Delta} \sum_{i} x_{i}^{2} / \lambda_{i}`
        (:math:`\Delta` the probability simplex). For fixed weights the minimizer is
        :math:`y_{i} = z_{i} \lambda_{i} / (\lambda_{i} + 2\tau)`. The optimal weights are
        :math:`\lambda_{i} = (|z_{i}| \sqrt{\tau/\mu} - 2\tau)_{+}`, where :math:`\mu > 0` is the root of
        :math:`f(\mu) = \sum_{i} \lambda_{i}(\mu) - 1`.

        Raises
        ------
        NotImplementedError
            If `arr` is a DASK array.
        ValueError
            If the root-finding for :math:`\mu` does not converge.
        """
        ndi = pxd.NDArrayInfo.from_obj(arr)
        if ndi == pxd.NDArrayInfo.DASK:
            raise NotImplementedError("Not implemented at scale.")

        norm = self.apply(arr).item()
        if not (norm > 0) or tau == 0:
            # z = 0 is a fixed point, and the prox of 0·f is the identity.
            return pxu.read_only(arr)

        xp = ndi.module()
        z_abs = xp.fabs(arr)

        def weights(mu):
            # lambda_i(mu) = (|z_i| sqrt(tau / mu) - 2 tau)_+ ; non-increasing in mu.
            return (z_abs * xp.sqrt(tau / mu) - 2 * tau).clip(0, None)

        # Part 1: Compute \mu_opt -----------------------------------------
        # Data-scaled bracket (a fixed lower bound such as 1e-12 fails for small |z|):
        # * mu_hi = max|z|^2 / (4 tau): every lambda_i is 0, so f(mu_hi) = -1 < 0.
        # * mu_lo = tau max|z|^2 / (4 (1 + 2 tau)^2): the largest entry alone gives
        #   lambda = 2 + 2 tau, so f(mu_lo) >= 1 + 2 tau > 0.
        z_max = float(z_abs.max())
        mu_hi = z_max**2 / (4 * tau)
        mu_lo = tau * z_max**2 / (4 * (1 + 2 * tau) ** 2)
        mu_opt, res = sopt.brentq(
            f=lambda mu: weights(mu).sum() - 1,
            a=mu_lo,
            b=mu_hi,
            # brentq's default xtol (2e-12) is absolute, but mu scales like |z|^2 / tau:
            # use a tolerance relative to the bracket instead (kept strictly positive).
            xtol=max(1e-12 * mu_lo, np.finfo(np.float64).tiny),
            full_output=True,
            disp=False,
        )
        if not res.converged:
            raise ValueError("Computing mu_opt did not converge.")

        # Part 2: Compute \lambda -----------------------------------------
        lambda_ = weights(mu_opt)

        # Part 3: Compute \prox -------------------------------------------
        y = arr.copy()
        y *= lambda_ / (lambda_ + 2 * tau)
        return y
169
170
[docs]
class LInfinityNorm(pxa.ProxFunc):
    r"""
    :math:`\ell_{\infty}`-norm, :math:`\Vert\mathbf{x}\Vert_{\infty} := \max_{i} |x_{i}|`.

    Note
    ----
    * Computing :py:meth:`~pyxu.abc.ProxFunc.prox` is unavailable with DASK inputs.
      (Inefficient exact solution at scale.)
    """

    def __init__(self, dim_shape: pxt.NDArrayShape):
        super().__init__(
            dim_shape=dim_shape,
            codim_shape=1,
        )
        self.lipschitz = 1

        # prox() below reduces over the whole array it receives (`.item()`, `.sum()`),
        # so it is wrapped with vectorize() — presumably to evaluate one stacked sample at a time.
        vectorize = pxu.vectorize(
            i="arr",
            dim_shape=self.dim_shape,
            codim_shape=self.dim_shape,
        )
        self.prox = vectorize(self.prox)

    def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
        xp = pxu.get_array_module(arr)
        axis = tuple(range(-self.dim_rank, 0))
        y = xp.fabs(arr).max(axis=axis)[..., np.newaxis]
        return y

    def prox(self, arr: pxt.NDArray, tau: pxt.Real) -> pxt.NDArray:
        r"""
        Proximal operator of :math:`\tau \Vert\cdot\Vert_{\infty}`.

        By Moreau's decomposition, :math:`\mathbf{prox}_{\tau\Vert\cdot\Vert_{\infty}}(\mathbf{z})
        = \mathbf{z} - \Pi_{\tau B_{1}}(\mathbf{z})`. Here :math:`\Pi_{\tau B_{1}}` is the projection onto the
        :math:`\ell_{1}`-ball of radius :math:`\tau`.

        * If :math:`\Vert\mathbf{z}\Vert_{1} \le \tau`, the projection is :math:`\mathbf{z}` itself and the
          prox is :math:`\mathbf{0}`.
        * Otherwise the prox is :math:`\mathrm{sign}(z_{i}) \min(|z_{i}|, \mu)`. Here
          :math:`\mu \in [0, \max_{i}|z_{i}|]` solves :math:`\sum_{i} (|z_{i}| - \mu)_{+} = \tau`.

        Raises
        ------
        NotImplementedError
            If `arr` is a DASK array.
        """
        ndi = pxd.NDArrayInfo.from_obj(arr)
        if ndi == pxd.NDArrayInfo.DASK:
            raise NotImplementedError("Not implemented at scale.")

        xp = ndi.module()
        z_abs = xp.fabs(arr)

        # The decisive quantity is the l1-norm, not max|z|: checking max|z| > tau both
        # misses cases where clipping is needed and returns z where the answer is 0.
        if float(z_abs.sum()) <= tau:
            return xp.zeros_like(arr)

        # f(0) = ‖z‖₁ - tau > 0 and f(max|z|) = -tau <= 0, so the root is bracketed.
        mu_max = self.apply(arr).item()
        mu_opt = sopt.brentq(
            f=lambda mu: (z_abs - mu).clip(0, None).sum() - tau,
            a=0,
            b=mu_max,
        )
        return xp.sign(arr) * xp.fmin(z_abs, mu_opt)
220
221
[docs]
class L21Norm(pxa.ProxFunc):
    r"""
    Mixed :math:`\ell_{2}-\ell_{1}` norm, :math:`\Vert\mathbf{x}\Vert_{2, 1} := \sum_{i} \sqrt{\sum_{j} x_{i, j}^{2}}`.

    The :math:`\ell_{2}` norm is taken over the `l2_axis` core axes, giving one value per group. The
    :math:`\ell_{1}` norm (a plain sum of group norms) is taken over the remaining core axes.
    """

    def __init__(
        self,
        dim_shape: pxt.NDArrayShape,
        l2_axis: pxt.NDArrayAxis = (0,),
    ):
        r"""
        Parameters
        ----------
        l2_axis: NDArrayAxis
            Axis (or axes) along which the :math:`\ell_{2}` norm is applied.
        """
        super().__init__(dim_shape=dim_shape, codim_shape=1)
        assert self.dim_rank >= 2

        group_axes = pxu.as_canonical_axes(l2_axis, rank=self.dim_rank)
        sum_axes = tuple(ax for ax in range(self.dim_rank) if ax not in group_axes)

        self.lipschitz = np.inf
        self._l1_axis = sum_axes
        self._l2_axis = group_axes

    def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
        xp = pxu.get_array_module(arr)
        # Core axes are stored relative to the core dims; shift past the leading stacking axes.
        n_stack = max(arr.ndim - self.dim_rank, 0)
        l2_ax = tuple(n_stack + ax for ax in self._l2_axis)
        l1_ax = tuple(n_stack + ax for ax in self._l1_axis)

        group_norm = (arr**2).sum(axis=l2_ax, keepdims=True)
        xp.sqrt(group_norm, out=group_norm)
        total = group_norm.sum(axis=l1_ax, keepdims=True)

        # Drop every (now singleton) core axis, then append the codim axis of size 1.
        return total.squeeze(l1_ax + l2_ax)[..., np.newaxis]

    def prox(self, arr: pxt.NDArray, tau: pxt.Real) -> pxt.NDArray:
        xp = pxu.get_array_module(arr)
        n_stack = max(arr.ndim - self.dim_rank, 0)
        l2_ax = tuple(n_stack + ax for ax in self._l2_axis)

        group_norm = (arr**2).sum(axis=l2_ax, keepdims=True)
        xp.sqrt(group_norm, out=group_norm)

        # Group soft-thresholding: each group is scaled by max(0, 1 - tau / ‖group‖₂).
        result = arr.copy()
        result *= 1 - tau / xp.fmax(group_norm, tau)
        return result
276
277
[docs]
class PositiveL1Norm(pxa.ProxFunc):
    r"""
    :math:`\ell_{1}`-norm, with a positivity constraint.

    .. math::

       f(\mathbf{x})
       :=
       \lVert\mathbf{x}\rVert_{1} + \iota_{+}(\mathbf{x}),

    .. math::

       \textbf{prox}_{\tau f}(\mathbf{z})
       :=
       \max(\mathrm{soft}_\tau(\mathbf{z}), \mathbf{0})

    Equivalently, the prox is :math:`\max(\mathbf{z} - \tau, \mathbf{0})` element-wise.

    See Also
    --------
    :py:class:`~pyxu.operator.PositiveOrthant`
    """

    def __init__(self, dim_shape: pxt.NDArrayShape):
        super().__init__(dim_shape=dim_shape, codim_shape=1)
        # Imported locally, presumably to avoid an import cycle within pyxu.operator.
        from pyxu.operator.func.indicator import PositiveOrthant

        self._indicator = PositiveOrthant(dim_shape=dim_shape)
        self._l1norm = L1Norm(dim_shape=dim_shape)
        self.lipschitz = np.inf

    def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
        # Indicator term (constraint) plus the plain l1 term.
        constraint = self._indicator(arr)
        magnitude = self._l1norm(arr)
        return constraint + magnitude

    def prox(self, arr: pxt.NDArray, tau: pxt.Real) -> pxt.NDArray:
        # One-sided soft-thresholding: shift down by tau and clamp below at 0.
        shifted = arr - tau
        return shifted.clip(0, None)
306 self._indicator = PositiveOrthant(dim_shape=dim_shape)
307 self._l1norm = L1Norm(dim_shape=dim_shape)
308 self.lipschitz = np.inf
309
310 def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
311 return self._indicator(arr) + self._l1norm(arr)
312
313 def prox(self, arr: pxt.NDArray, tau: pxt.Real) -> pxt.NDArray:
314 y = (arr - tau).clip(0, None)
315 return y