Source code for pyxu.operator.func.indicator

  1import numpy as np
  2
  3import pyxu.abc as pxa
  4import pyxu.info.ptype as pxt
  5import pyxu.operator.func.norm as pxf
  6import pyxu.runtime as pxrt
  7import pyxu.util as pxu
  8
# Public names of this module (what `from <module> import *` exposes).
# The private bases `_IndicatorFunction` and `_NormBall` are intentionally left out.
__all__ = [
    "L1Ball",
    "L2Ball",
    "LInfinityBall",
    "PositiveOrthant",
    "HyperSlab",
    "RangeSet",
]
 17
 18
 19class _IndicatorFunction(pxa.ProxFunc):
 20    def __init__(self, dim_shape: pxt.NDArrayShape):
 21        super().__init__(
 22            dim_shape=dim_shape,
 23            codim_shape=1,
 24        )
 25        self.lipschitz = np.inf
 26
 27    @staticmethod
 28    def _bool2indicator(x: pxt.NDArray, dtype: pxt.DType) -> pxt.NDArray:
 29        # x: NDarray[bool]
 30        # y: NDarray[(0, \inf), dtype]
 31        xp = pxu.get_array_module(x)
 32        cast = lambda _: np.array(_, dtype=dtype)[()]
 33        y = xp.where(x, cast(0), cast(np.inf))
 34        return y
 35
 36
 37class _NormBall(_IndicatorFunction):
 38    def __init__(
 39        self,
 40        dim_shape: pxt.NDArrayShape,
 41        ord: pxt.Integer,
 42        radius: pxt.Real,
 43    ):
 44        super().__init__(dim_shape=dim_shape)
 45        self._ord = ord
 46        self._radius = float(radius)
 47
 48    def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
 49        from pyxu.opt.stop import _norm
 50
 51        norm = _norm(arr, ord=self._ord, rank=self.dim_rank)  # (..., 1)
 52        out = self._bool2indicator(norm <= self._radius, arr.dtype)
 53        return out
 54
 55    def prox(self, arr: pxt.NDArray, tau: pxt.Real) -> pxt.NDArray:
 56        klass = {  # class of proximal operator to use
 57            1: pxf.LInfinityNorm,
 58            2: pxf.L2Norm,
 59            np.inf: pxf.L1Norm,
 60        }[self._ord]
 61        op = klass(dim_shape=self.dim_shape)
 62
 63        out = arr.copy()
 64        out -= op.prox(arr, tau=self._radius)
 65        return out
 66
 67
def L1Ball(dim_shape: pxt.NDArrayShape, radius: pxt.Real = 1) -> pxt.OpT:
    r"""
    Build the indicator function of the :math:`\ell_{1}`-ball of radius :math:`r`.

    .. math::

       \iota_{1}^{r}(\mathbf{x})
       :=
       \begin{cases}
           0 & \|\mathbf{x}\|_{1} \le r \\
           \infty & \text{otherwise}.
       \end{cases}

    .. math::

       \text{prox}_{\tau\, \iota_{1}^{r}}(\mathbf{x})
       :=
       \mathbf{x} - \text{prox}_{r\, \ell_{\infty}}(\mathbf{x})

    Parameters
    ----------
    dim_shape: NDArrayShape
        Shape of the functional's domain.
    radius: Real
        Radius of the ball.  (Defaults to 1, i.e. the unit ball.)

    Returns
    -------
    op: OpT
        Norm-ball indicator with order 1, named "L1Ball".

    Note
    ----
    * Computing :py:meth:`~pyxu.abc.ProxFunc.prox` is unavailable with DASK inputs.
      (Inefficient exact solution at scale.)
    """
    ball = _NormBall(dim_shape=dim_shape, ord=1, radius=radius)
    ball._name = "L1Ball"
    return ball
105 106
def L2Ball(dim_shape: pxt.NDArrayShape, radius: pxt.Real = 1) -> pxt.OpT:
    r"""
    Build the indicator function of the :math:`\ell_{2}`-ball of radius :math:`r`.

    .. math::

       \iota_{2}^{r}(\mathbf{x})
       :=
       \begin{cases}
           0 & \|\mathbf{x}\|_{2} \le r \\
           \infty & \text{otherwise}.
       \end{cases}

    .. math::

       \text{prox}_{\tau\, \iota_{2}^{r}}(\mathbf{x})
       :=
       \mathbf{x} - \text{prox}_{r\, \ell_{2}}(\mathbf{x})

    Parameters
    ----------
    dim_shape: NDArrayShape
        Shape of the functional's domain.
    radius: Real
        Radius of the ball.  (Defaults to 1, i.e. the unit ball.)

    Returns
    -------
    op: OpT
        Norm-ball indicator with order 2, named "L2Ball".
    """
    ball = _NormBall(dim_shape=dim_shape, ord=2, radius=radius)
    ball._name = "L2Ball"
    return ball
139 140
def LInfinityBall(dim_shape: pxt.NDArrayShape, radius: pxt.Real = 1) -> pxt.OpT:
    r"""
    Build the indicator function of the :math:`\ell_{\infty}`-ball of radius :math:`r`.

    .. math::

       \iota_{\infty}^{r}(\mathbf{x})
       :=
       \begin{cases}
           0 & \|\mathbf{x}\|_{\infty} \le r \\
           \infty & \text{otherwise}.
       \end{cases}

    .. math::

       \text{prox}_{\tau\, \iota_{\infty}^{r}}(\mathbf{x})
       :=
       \mathbf{x} - \text{prox}_{r\, \ell_{1}}(\mathbf{x})

    Parameters
    ----------
    dim_shape: NDArrayShape
        Shape of the functional's domain.
    radius: Real
        Radius of the ball.  (Defaults to 1, i.e. the unit ball.)

    Returns
    -------
    op: OpT
        Norm-ball indicator with order ``np.inf``, named "LInfinityBall".
    """
    ball = _NormBall(dim_shape=dim_shape, ord=np.inf, radius=radius)
    ball._name = "LInfinityBall"
    return ball
173 174
class PositiveOrthant(_IndicatorFunction):
    r"""
    Indicator function of the positive orthant (all entries non-negative).

    .. math::

       \iota_{+}(\mathbf{x})
       :=
       \begin{cases}
           0 & \min{\mathbf{x}} \ge 0,\\
           \infty & \text{otherwise}.
       \end{cases}

    .. math::

       \text{prox}_{\tau\, \iota_{+}}(\mathbf{x})
       :=
       \max(\mathbf{x}, \mathbf{0})
    """

    def __init__(self, dim_shape: pxt.NDArrayShape):
        """
        Parameters
        ----------
        dim_shape: NDArrayShape
            Shape of the functional's domain.
        """
        super().__init__(dim_shape=dim_shape)

    def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
        """
        Return 0 when every entry over the trailing `dim_rank` axes is >= 0, infinity otherwise.
        """
        core_axes = tuple(range(-self.dim_rank, 0))
        non_negative = arr >= 0
        in_set = non_negative.all(axis=core_axes)
        # Re-append a trailing singleton axis after the reduction.
        return self._bool2indicator(in_set[..., np.newaxis], arr.dtype)

    def prox(self, arr: pxt.NDArray, tau: pxt.Real) -> pxt.NDArray:
        """
        Clip negative entries of `arr` to 0.  `tau` is not used.
        """
        return arr.clip(0, None)
207 208
class HyperSlab(_IndicatorFunction):
    r"""
    Indicator function of a hyperslab, i.e. the region between two parallel hyperplanes.

    .. math::

       \iota_{\mathbf{a}}^{l,u}(\mathbf{x})
       :=
       \begin{cases}
           0 & l \le \langle \mathbf{a}, \mathbf{x} \rangle \le u \\
           \infty & \text{otherwise}.
       \end{cases}

    .. math::

       \text{prox}_{\tau\, \iota_{\mathbf{a}}^{l,u}}(\mathbf{x})
       :=
       \begin{cases}
           \mathbf{x} + \frac{l - \langle \mathbf{a}, \mathbf{x} \rangle}{\|\mathbf{a}\|^{2}} \mathbf{a} & \langle \mathbf{a}, \mathbf{x} \rangle < l, \\
           \mathbf{x} + \frac{u - \langle \mathbf{a}, \mathbf{x} \rangle}{\|\mathbf{a}\|^{2}} \mathbf{a} & \langle \mathbf{a}, \mathbf{x} \rangle > u, \\
           \mathbf{x} & \text{otherwise}.
       \end{cases}
    """

    def __init__(self, a: pxa.LinFunc, lb: pxt.Real, ub: pxt.Real):
        """
        Parameters
        ----------
        a: ~pyxu.abc.operator.LinFunc
            Linear functional with domain (M1,...,MD).
        lb: Real
            Lower bound.
        ub: Real
            Upper bound.  Must be strictly greater than `lb`.
        """
        assert lb < ub
        super().__init__(dim_shape=a.dim_shape)

        # Internally the functional and both bounds are rescaled by the same factor, so that all
        # computations take place in normalized coordinates.
        scale = a.lipschitz  # \norm{a}{2}
        self._a = a / scale
        self._l = lb / scale
        self._u = ub / scale

    def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
        """
        Return 0 when the (normalized) functional value lies in [l, u], infinity otherwise.
        """
        value = self._a.apply(arr)  # (..., 1)
        above_lower = self._l <= value
        below_upper = value <= self._u
        in_set = (above_lower & below_upper).all(axis=-1, keepdims=True)
        return self._bool2indicator(in_set, arr.dtype)  # (..., 1)

    def prox(self, arr: pxt.NDArray, tau: pxt.Real) -> pxt.NDArray:
        """
        Move `arr` along the normalized direction of the functional by the amount its functional value
        falls short of the lower bound or exceeds the upper bound.  `tau` is not used.
        """
        xp = pxu.get_array_module(arr)

        direction = self._a.adjoint(xp.ones(1, dtype=arr.dtype))  # (M1,...,MD)
        pad = (np.newaxis,) * (self.dim_rank - 1)
        value = self._a.apply(arr)[(..., *pad)]  # (..., 1,...,1)

        # Shortfall w.r.t. the lower bound: kept only where positive.
        shortfall = self._l - value
        shortfall[shortfall <= 0] = 0

        # Excess w.r.t. the upper bound: kept only where negative.
        excess = self._u - value
        excess[excess >= 0] = 0

        out = arr.copy()
        out += shortfall * direction
        out += excess * direction
        return out
276 277
class RangeSet(_IndicatorFunction):
    r"""
    Indicator function of a range set.

    .. math::

       \iota_{\mathbf{A}}^{R}(\mathbf{x})
       :=
       \begin{cases}
           0 & \mathbf{x} \in \text{span}(\mathbf{A}) \\
           \infty & \text{otherwise}.
       \end{cases}

    .. math::

       \text{prox}_{\tau\, \iota_{\mathbf{A}}^{R}}(\mathbf{x})
       :=
       \mathbf{A} (\mathbf{A}^{T} \mathbf{A})^{-1} \mathbf{A}^{T} \mathbf{x}.
    """

    def __init__(self, A: pxa.LinOp):
        """
        Parameters
        ----------
        A: ~pyxu.abc.operator.LinOp
            Linear operator whose range defines the set.  The functional's domain is `A.codim_shape`.
        """
        super().__init__(dim_shape=A.codim_shape)
        self._A = A

    def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
        """
        Return 0 when `arr` is (numerically) left unchanged by :py:meth:`prox`, infinity otherwise.
        """
        # I'm in range(A) if prox(x)==x.
        axis = tuple(range(-self.dim_rank, 0))
        y = self.prox(arr, tau=1)
        in_set = self.isclose(y, arr).all(axis=axis)  # (...,)
        out = self._bool2indicator(in_set[..., np.newaxis], arr.dtype)
        return out  # (..., 1)

    def prox(self, arr: pxt.NDArray, tau: pxt.Real) -> pxt.NDArray:
        """
        Apply `A` to the undamped pseudo-inverse solution `A.pinv(arr, damp=0)`.  `tau` is not used.
        """
        y = self._A.pinv(arr, damp=0)
        out = self._A.apply(y)
        return out

    @staticmethod
    def isclose(a: pxt.NDArray, b: pxt.NDArray) -> pxt.NDArray:
        """
        Equivalent of `xp.isclose`, but where atol is automatically chosen based on input's `dtype`.

        Parameters
        ----------
        a: NDArray
            First operand; its `dtype` selects the absolute tolerance.
        b: NDArray
            Second operand.

        Returns
        -------
        eq: NDArray[bool]
            Element-wise closeness of `a` and `b`.
        """
        atol = {
            pxrt.Width.SINGLE.value: 2e-4,
            pxrt.Width.DOUBLE.value: 1e-8,
        }
        # Numbers obtained by:
        # * \sum_{k >= (p+1)//2} 2^{-k}, where p=<number of mantissa bits>; then
        # * round up value to 3 significant decimal digits.
        # N_mantissa = [23, 52] for [single, double] respectively.
        xp = pxu.get_array_module(a)

        # Fallback must be a tolerance *value*, not the dictionary key: dtypes absent from the
        # table (this should only occur for integer types) reuse the double-precision tolerance.
        fallback = atol[pxrt.Width.DOUBLE.value]
        prec = atol.get(a.dtype, fallback)
        eq = xp.isclose(a, b, atol=prec)
        return eq