Source code for pyxu.operator.func.norm

  1import numpy as np
  2import scipy.optimize as sopt
  4import as pxa
  5import as pxd
  6import as pxt
  7import pyxu.util as pxu
  9__all__ = [
 10    "L1Norm",
 11    "L2Norm",
 12    "SquaredL2Norm",
 13    "SquaredL1Norm",
 14    "LInfinityNorm",
 15    "L21Norm",
 16    "PositiveL1Norm",
class L1Norm(pxa.ProxFunc):
    r"""
    :math:`\ell_{1}`-norm, :math:`\Vert\mathbf{x}\Vert_{1} := \sum_{i} |x_{i}|`.
    """

    def __init__(self, dim_shape: pxt.NDArrayShape):
        super().__init__(dim_shape=dim_shape, codim_shape=1)
        # |‖x‖₁ - ‖y‖₁| ≤ ‖x - y‖₁ ≤ sqrt(N) ‖x - y‖₂, with N = number of core entries.
        self.lipschitz = np.sqrt(self.dim_size)

    def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
        """Sum of absolute values over the core (trailing ``dim_rank``) axes, with a trailing axis of size 1."""
        xp = pxu.get_array_module(arr)
        core_axes = tuple(range(-self.dim_rank, 0))
        total = xp.fabs(arr).sum(axis=core_axes)
        return total[..., np.newaxis]

    def prox(self, arr: pxt.NDArray, tau: pxt.Real) -> pxt.NDArray:
        """Soft-thresholding: shrink each magnitude by ``tau`` (floored at 0) and keep its sign."""
        xp = pxu.get_array_module(arr)
        magnitude = xp.fmax(0, xp.fabs(arr) - tau)
        return xp.sign(arr) * magnitude
class L2Norm(pxa.ProxFunc):
    r"""
    :math:`\ell_{2}`-norm, :math:`\Vert\mathbf{x}\Vert_{2} := \sqrt{\sum_{i} |x_{i}|^{2}}`.
    """

    def __init__(self, dim_shape: pxt.NDArrayShape):
        super().__init__(dim_shape=dim_shape, codim_shape=1)
        self.lipschitz = 1
        # Not differentiable at the origin.
        self.diff_lipschitz = np.inf

    def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
        """Euclidean norm over the core (trailing ``dim_rank``) axes, with a trailing axis of size 1."""
        xp = pxu.get_array_module(arr)
        core_axes = tuple(range(-self.dim_rank, 0))
        sq_sum = (arr**2).sum(axis=core_axes)
        return xp.sqrt(sq_sum)[..., np.newaxis]

    def prox(self, arr: pxt.NDArray, tau: pxt.Real) -> pxt.NDArray:
        r"""
        Block soft-thresholding: :math:`\mathbf{x} \left(1 - \tau / \max(\Vert\mathbf{x}\Vert_{2}, \tau)\right)`.
        """
        xp = pxu.get_array_module(arr)
        core_axes = tuple(range(-self.dim_rank, 0))
        # keepdims=True leaves singleton core axes so the factor broadcasts against `arr`.
        norm = xp.sqrt((arr**2).sum(axis=core_axes, keepdims=True))
        shrunk = arr.copy()
        shrunk *= 1 - tau / xp.fmax(norm, tau)
        return shrunk
class SquaredL2Norm(pxa.QuadraticFunc):
    r"""
    :math:`\ell^{2}_{2}`-norm, :math:`\Vert\mathbf{x}\Vert^{2}_{2} := \sum_{i} |x_{i}|^{2}`.
    """

    def __init__(self, dim_shape: pxt.NDArrayShape):
        super().__init__(dim_shape=dim_shape, codim_shape=1)
        self.lipschitz = np.inf
        # grad(x) = 2x, hence a gradient Lipschitz constant of 2.
        self.diff_lipschitz = 2

    def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
        """Sum of squares over the core (trailing ``dim_rank``) axes, with a trailing axis of size 1."""
        core_axes = tuple(range(-self.dim_rank, 0))
        energy = (arr**2).sum(axis=core_axes)
        return energy[..., np.newaxis]

    def grad(self, arr: pxt.NDArray) -> pxt.NDArray:
        """Gradient :math:`2\mathbf{x}`."""
        return 2 * arr

    def prox(self, arr: pxt.NDArray, tau: pxt.Real) -> pxt.NDArray:
        r"""Closed form: minimizer of :math:`\tau\Vert\mathbf{x}\Vert_{2}^{2} + \frac{1}{2}\Vert\mathbf{x} - \mathbf{z}\Vert_{2}^{2}` is :math:`\mathbf{z} / (2\tau + 1)`."""
        scaled = arr.copy()
        scaled /= 2 * tau + 1
        return scaled

    def _quad_spec(self):
        # Q = 2 Id, no linear term, no constant. This matches ‖x‖² under the
        # assumed convention f(x) = ½⟨Qx, x⟩ + c(x) + t — TODO confirm against pxa.QuadraticFunc.
        from pyxu.operator import HomothetyOp, NullFunc

        return (
            HomothetyOp(dim_shape=self.dim_shape, cst=2),
            NullFunc(dim_shape=self.dim_shape),
            0,
        )
class SquaredL1Norm(pxa.ProxFunc):
    r"""
    :math:`\ell^{2}_{1}`-norm, :math:`\Vert\mathbf{x}\Vert^{2}_{1} := (\sum_{i} |x_{i}|)^{2}`.

    Note
    ----
    * Computing :py:meth:`~pyxu.abc.ProxFunc.prox` is unavailable with DASK inputs.
      (Inefficient exact solution at scale.)
    """

    def __init__(self, dim_shape: pxt.NDArrayShape):
        super().__init__(
            dim_shape=dim_shape,
            codim_shape=1,
        )
        self.lipschitz = np.inf

        # prox(): vectorize.
        # prox() calls `.item()` on apply(arr), which only works for a single
        # (non-batched) input; pxu.vectorize is relied upon to feed it one input at a time.
        vectorize = pxu.vectorize(
            i="arr",
            dim_shape=self.dim_shape,
            codim_shape=self.dim_shape,
        )
        self.prox = vectorize(self.prox)

    def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
        y = L1Norm(dim_shape=self.dim_shape).apply(arr)
        y **= 2
        return y

    def prox(self, arr: pxt.NDArray, tau: pxt.Real) -> pxt.NDArray:
        r"""
        Proximal operator of :math:`\tau \Vert\cdot\Vert_{1}^{2}`.

        With :math:`z = |\mathbf{x}|` and
        :math:`\lambda_{i}(\mu) = \max(z_{i} \sqrt{\tau / \mu} - 2\tau, 0)`, find
        :math:`\mu^{\star}` such that :math:`\sum_{i} \lambda_{i}(\mu^{\star}) = 1`, then

        .. math::

           \textbf{prox}(\mathbf{x})_{i} = x_{i} \frac{\lambda_{i}}{\lambda_{i} + 2\tau}.

        If :math:`\mathbf{x} = \mathbf{0}` or :math:`\tau = 0`, the input is returned (read-only).

        Raises
        ------
        NotImplementedError
            For DASK inputs.
        ValueError
            If the root-finder for :math:`\mu^{\star}` does not converge.
        """
        ndi = pxd.NDArrayInfo.from_obj(arr)
        if ndi == pxd.NDArrayInfo.DASK:
            raise NotImplementedError("Not implemented at scale.")

        norm = self.apply(arr).item()
        if norm > 0 and tau != 0:
            xp = ndi.module()
            z = xp.fabs(arr)  # hoisted: reused by every root-finder evaluation
            z_max = z.max().item()

            def _lambda(mu: float):
                # \lambda_{i}(mu): non-increasing in mu, hence so is its sum.
                return (z * xp.sqrt(tau / mu) - 2 * tau).clip(0, None)

            # Part 1: Compute \mu_opt -----------------------------------------
            # Bracket the root of phi(mu) = sum_i lambda_i(mu) - 1 analytically:
            # * mu_hi = z_max^2 / (4 tau): z_i sqrt(tau / mu_hi) <= 2 tau for all i,
            #   so phi(mu_hi) = -1 < 0.
            # * mu_lo = tau z_max^2 / (2 (1 + 2 tau)^2): the largest entry alone gives
            #   lambda = sqrt(2) (1 + 2 tau) - 2 tau, so phi(mu_lo) >= (sqrt(2) - 1)(1 + 2 tau) > 0.
            # A fixed lower bound / absolute tolerance would break down when mu_opt is
            # tiny (small |arr| or large tau), hence the data-scaled bracket and xtol.
            mu_lo = tau * z_max**2 / (2 * (1 + 2 * tau) ** 2)
            mu_hi = z_max**2 / (4 * tau)
            mu_opt, res = sopt.brentq(
                f=lambda mu: _lambda(mu).sum() - 1,
                a=mu_lo,
                b=mu_hi,
                xtol=1e-12 * mu_lo,
                full_output=True,
                disp=False,
            )
            if not res.converged:
                raise ValueError("Computing mu_opt did not converge.")

            # Part 2: Compute \lambda -----------------------------------------
            lambda_ = _lambda(mu_opt)

            # Part 3: Compute \prox -------------------------------------------
            y = arr.copy()
            y *= lambda_ / (lambda_ + 2 * tau)
        else:
            y = pxu.read_only(arr)

        return y
class LInfinityNorm(pxa.ProxFunc):
    r"""
    :math:`\ell_{\infty}`-norm, :math:`\Vert\mathbf{x}\Vert_{\infty} := \max_{i} |x_{i}|`.

    Note
    ----
    * Computing :py:meth:`~pyxu.abc.ProxFunc.prox` is unavailable with DASK inputs.
      (Inefficient exact solution at scale.)
    """

    def __init__(self, dim_shape: pxt.NDArrayShape):
        super().__init__(
            dim_shape=dim_shape,
            codim_shape=1,
        )
        self.lipschitz = 1

        # prox(): vectorize -- prox() reduces the whole input to Python scalars
        # (`.item()`), so it is relied upon to receive one input at a time.
        vectorize = pxu.vectorize(
            i="arr",
            dim_shape=self.dim_shape,
            codim_shape=self.dim_shape,
        )
        self.prox = vectorize(self.prox)

    def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
        xp = pxu.get_array_module(arr)
        axis = tuple(range(-self.dim_rank, 0))
        y = xp.fabs(arr).max(axis=axis)[..., np.newaxis]
        return y

    def prox(self, arr: pxt.NDArray, tau: pxt.Real) -> pxt.NDArray:
        r"""
        Proximal operator of :math:`\tau \Vert\cdot\Vert_{\infty}`.

        By Moreau's decomposition,
        :math:`\textbf{prox}_{\tau\Vert\cdot\Vert_{\infty}}(\mathbf{x}) = \mathbf{x} - \Pi_{\tau B_{1}}(\mathbf{x})`,
        where :math:`\Pi_{\tau B_{1}}` projects onto the :math:`\ell_{1}`-ball of radius :math:`\tau`.
        Hence:

        * if :math:`\Vert\mathbf{x}\Vert_{1} \le \tau`, the prox is :math:`\mathbf{0}`;
        * otherwise it is :math:`\text{sign}(\mathbf{x}) \min(|\mathbf{x}|, \mu)`, where
          :math:`\mu \in [0, \Vert\mathbf{x}\Vert_{\infty}]` solves
          :math:`\sum_{i} \max(|x_{i}| - \mu, 0) = \tau`.

        Raises
        ------
        NotImplementedError
            For DASK inputs.
        """
        ndi = pxd.NDArrayInfo.from_obj(arr)
        if ndi == pxd.NDArrayInfo.DASK:
            raise NotImplementedError("Not implemented at scale.")

        xp = ndi.module()
        mag = xp.fabs(arr)
        # The decision must use the l1-norm, not the l-infinity norm. E.g. x = [1], tau = 2
        # has prox 0, and x = [1, 1], tau = 1.5 has prox [0.25, 0.25].
        if mag.sum().item() > tau:
            mu_max = mag.max().item()
            # f(0) = ||x||_1 - tau > 0 and f(mu_max) = -tau <= 0: valid bracket.
            mu_opt = sopt.brentq(
                f=lambda mu: (mag - mu).clip(0, None).sum() - tau,
                a=0,
                b=mu_max,
            )
            y = xp.sign(arr) * xp.fmin(mag, mu_opt)
        else:
            # x lies inside the tau-l1-ball: x - Pi(x) = 0.
            y = xp.zeros_like(arr)

        return y
class L21Norm(pxa.ProxFunc):
    r"""
    Mixed :math:`\ell_{2}-\ell_{1}` norm, :math:`\Vert\mathbf{x}\Vert_{2, 1} := \sum_{i} \sqrt{\sum_{j} x_{i, j}^{2}}`.

    The :math:`\ell_{2}` norm is taken over ``l2_axis``; the remaining core axes are summed (:math:`\ell_{1}`).
    """

    def __init__(
        self,
        dim_shape: pxt.NDArrayShape,
        l2_axis: pxt.NDArrayAxis = (0,),
    ):
        r"""
        Parameters
        ----------
        l2_axis: NDArrayAxis
            Axis (or axes) along which the :math:`\ell_{2}` norm is applied.
        """
        super().__init__(dim_shape=dim_shape, codim_shape=1)
        assert self.dim_rank >= 2

        canon_l2 = pxu.as_canonical_axes(l2_axis, rank=self.dim_rank)
        self.lipschitz = np.inf
        # Core axes not reduced by the l2 norm are reduced by the l1 norm.
        self._l1_axis = tuple(ax for ax in range(self.dim_rank) if ax not in canon_l2)
        self._l2_axis = canon_l2

    def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
        """Sum over ``_l1_axis`` of the l2 norms taken over ``_l2_axis``, with a trailing axis of size 1."""
        xp = pxu.get_array_module(arr)
        # Core axes are stored relative to the core shape; shift past the batch dimensions.
        n_batch = len(arr.shape[: -self.dim_rank])
        ax2 = tuple(n_batch + ax for ax in self._l2_axis)
        ax1 = tuple(n_batch + ax for ax in self._l1_axis)

        group_norm = (arr**2).sum(axis=ax2, keepdims=True)
        xp.sqrt(group_norm, out=group_norm)

        total = group_norm.sum(axis=ax1, keepdims=True)
        # Drop every (now singleton) core axis, then append the size-1 codomain axis.
        return total.squeeze(ax1 + ax2)[..., np.newaxis]

    def prox(self, arr: pxt.NDArray, tau: pxt.Real) -> pxt.NDArray:
        """Group soft-thresholding: scale each l2-group by ``1 - tau / max(||group||_2, tau)``."""
        xp = pxu.get_array_module(arr)
        n_batch = len(arr.shape[: -self.dim_rank])
        ax2 = tuple(n_batch + ax for ax in self._l2_axis)

        group_norm = (arr**2).sum(axis=ax2, keepdims=True)
        xp.sqrt(group_norm, out=group_norm)

        shrink = 1 - tau / xp.fmax(group_norm, tau)
        result = arr.copy()
        result *= shrink
        return result
class PositiveL1Norm(pxa.ProxFunc):
    r"""
    :math:`\ell_{1}`-norm, with a positivity constraint.

    .. math::

       f(\mathbf{x})
       :=
       \lVert\mathbf{x}\rVert_{1} + \iota_{+}(\mathbf{x}),

    .. math::

       \textbf{prox}_{\tau f}(\mathbf{z})
       :=
       \max(\mathrm{soft}_\tau(\mathbf{z}), \mathbf{0})

    See Also
    --------
    :py:class:`~pyxu.operator.PositiveOrthant`
    """

    def __init__(self, dim_shape: pxt.NDArrayShape):
        # Local import, as in the original module layout.
        from pyxu.operator.func.indicator import PositiveOrthant

        super().__init__(dim_shape=dim_shape, codim_shape=1)
        self._indicator = PositiveOrthant(dim_shape=dim_shape)
        self._l1norm = L1Norm(dim_shape=dim_shape)
        self.lipschitz = np.inf

    def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
        """Indicator of the positive orthant plus the l1-norm."""
        penalty = self._l1norm(arr)
        constraint = self._indicator(arr)
        return constraint + penalty

    def prox(self, arr: pxt.NDArray, tau: pxt.Real) -> pxt.NDArray:
        """``max(soft_tau(z), 0)`` simplifies to ``max(z - tau, 0)``."""
        shifted = arr - tau
        return shifted.clip(0, None)