Source code for pyxu.operator.func.loss

import numpy as np

import pyxu.abc as pxa
import pyxu.info.deps as pxd
import pyxu.info.ptype as pxt
import pyxu.util as pxu

__all__ = [
    "KLDivergence",
]

class KLDivergence(pxa.ProxFunc):
    r"""
    Generalised Kullback-Leibler divergence
    :math:`D_{KL}(\mathbf{y}||\mathbf{x}) := \sum_{i} y_{i} \log(y_{i} / x_{i}) - y_{i} + x_{i}`.
    """

    def __init__(self, data: pxt.NDArray):
        r"""
        Parameters
        ----------
        data: NDArray
            (M1,...,MD) non-negative input data.

        Examples
        --------
        .. code-block:: python3

           import numpy as np
           from pyxu.operator import KLDivergence

           y = np.arange(5)
           loss = KLDivergence(y)

           loss(2 * y)  # [3.06852819]
           np.round(loss.prox(2 * y, tau=1))  # [0. 2. 4. 6. 8.]

        Notes
        -----
        * When :math:`\mathbf{y}` and :math:`\mathbf{x}` sum to one, and hence can be interpreted as discrete
          probability distributions, the KL-divergence corresponds to the relative entropy of :math:`\mathbf{y}`
          w.r.t. :math:`\mathbf{x}`, i.e. the amount of information lost when using :math:`\mathbf{x}` to
          approximate :math:`\mathbf{y}`. It is particularly useful in the context of count data with Poisson
          distribution: the KL-divergence then corresponds (up to an additive constant) to the negative
          log-likelihood of :math:`\mathbf{y}`, where each component is independent with Poisson distribution and
          respective intensities given by :math:`\mathbf{x}`. See [FuncSphere]_ Chapter 7, Section 5 for the
          computation of its proximal operator.
        * :py:class:`~pyxu.operator.KLDivergence` is not backend-agnostic: inputs to arithmetic methods must have
          the same backend as `data`.
        * If `data` is a DASK array, its entries are assumed to be non-negative without verification. Reason: the
          operator should be quick to build under all circumstances, and this cannot be guaranteed if every entry
          of an out-of-core array must be checked for non-negativity.
        * If `data` is a DASK array, the core dimensions of arrays supplied to arithmetic methods **must** have
          the same chunk size as `data`.
        """
        super().__init__(
            dim_shape=data.shape,
            codim_shape=1,
        )

        ndi = pxd.NDArrayInfo.from_obj(data)
        if ndi != pxd.NDArrayInfo.DASK:
            assert (data >= 0).all(), "KL Divergence only defined for non-negative arguments."
        self._data = data

    def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
        # Sum the element-wise divergence over the core dimensions.
        axis = tuple(range(-self.dim_rank, 0))
        out = self._kl_div(arr, self._data)
        out = out.sum(axis=axis)[..., np.newaxis]
        return out

    def prox(self, arr: pxt.NDArray, tau: pxt.Real) -> pxt.NDArray:
        # Closed-form prox of tau * KL(data || .): the non-negative root of
        #     u**2 - (arr - tau) * u - tau * data = 0,
        # applied component-wise.
        xp = pxu.get_array_module(arr)
        x = arr - tau
        out = x + xp.sqrt((x**2) + ((4 * tau) * self._data))
        out *= 0.5
        return out

    @staticmethod
    def _kl_div(x: pxt.NDArray, data: pxt.NDArray) -> pxt.NDArray:
        # Element-wise KL-divergence
        N = pxd.NDArrayInfo  # short-hand
        ndi = N.from_obj(x)

        if ndi == N.NUMPY:
            sp = pxu.import_module("scipy.special")
            out = sp.kl_div(data, x)
        elif ndi == N.CUPY:
            sp = pxu.import_module("cupyx.scipy.special")
            out = sp.kl_div(data, x)
        elif ndi == N.DASK:
            # Core dimensions must share the same chunk structure as `data`.
            assert x.chunks[-data.ndim :] == data.chunks
            xp = ndi.module()

            out = xp.map_blocks(
                KLDivergence._kl_div,
                x,
                data,
                dtype=data.dtype,
                meta=data._meta,
            )
        else:
            raise NotImplementedError
        return out
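
The closed-form expression used in ``prox`` can be sanity-checked numerically. The snippet below is an illustrative sketch, not part of the module: ``brute_force_prox`` is a hypothetical helper that minimizes the proximal objective component by component with SciPy, and its output should match ``KLDivergence.prox`` up to solver tolerance.

import numpy as np
import scipy.optimize as sopt
import scipy.special as sp

from pyxu.operator import KLDivergence


def brute_force_prox(y, z, tau):
    # Hypothetical reference implementation: minimize
    #   kl_div(y_i, u) + (u - z_i)**2 / (2 * tau)
    # over u > 0, one component at a time.
    out = np.zeros_like(z, dtype=float)
    for i, (yi, zi) in enumerate(zip(y, z)):
        objective = lambda u: sp.kl_div(yi, u) + (u - zi) ** 2 / (2 * tau)
        res = sopt.minimize_scalar(objective, bounds=(1e-12, 1e3), method="bounded")
        out[i] = res.x
    return out


y = np.array([0.5, 1.0, 2.0, 4.0])
z = np.array([3.0, -1.0, 0.5, 6.0])
tau = 0.7

loss = KLDivergence(y)
print(loss.prox(z, tau=tau))        # closed form from the class above
print(brute_force_prox(y, z, tau))  # numerical minimization; ~same values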
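
The DASK-related remarks in the Notes can be illustrated as follows. This is a minimal sketch assuming Dask is installed; the array and chunk sizes are arbitrary and only serve to show the chunk-alignment requirement enforced in ``_kl_div``.

import dask.array as da
import numpy as np

from pyxu.operator import KLDivergence

y = da.from_array(np.arange(1.0, 9.0), chunks=4)  # data with chunks (4, 4)
loss = KLDivergence(y)                            # no non-negativity check for DASK inputs

x_ok = da.from_array(np.full(8, 2.0), chunks=4)   # same chunking as `y`
print(loss(x_ok).compute())                       # lazy evaluation, then compute

x_bad = da.from_array(np.full(8, 2.0), chunks=3)  # chunks (3, 3, 2) != (4, 4)
# loss(x_bad) would trip the chunk-alignment assertion in `_kl_div`.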