Source code for pyxu.operator.func.loss
import numpy as np

import pyxu.abc as pxa
import pyxu.info.deps as pxd
import pyxu.info.ptype as pxt
import pyxu.util as pxu

__all__ = [
    "KLDivergence",
]

class KLDivergence(pxa.ProxFunc):
    r"""
    Generalised Kullback-Leibler divergence
    :math:`D_{KL}(\mathbf{y}||\mathbf{x}) := \sum_{i} y_{i} \log(y_{i} / x_{i}) - y_{i} + x_{i}`.
    """

    def __init__(self, data: pxt.NDArray):
        r"""
        Parameters
        ----------
        data: NDArray
            (M1,...,MD) non-negative input data.

        Examples
        --------
        .. code-block:: python3

           import numpy as np
           from pyxu.operator import KLDivergence

           y = np.arange(5)
           loss = KLDivergence(y)

           loss(2 * y)                        # [3.06852819]
           np.round(loss.prox(2 * y, tau=1))  # [0. 2. 4. 6. 8.]

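
        The same computation also runs with a DASK-backed `data` array. A minimal sketch, assuming
        ``dask.array`` is available; inputs to arithmetic methods must then share the chunk layout of
        `data` (see the Notes below):

        .. code-block:: python3

           import dask.array as da
           import numpy as np
           from pyxu.operator import KLDivergence

           y = da.from_array(np.arange(5.0), chunks=5)
           loss = KLDivergence(y)

           loss(2 * y).compute()                        # ~[3.06852819], i.e. 10 * (1 - log(2))
           np.round(loss.prox(2 * y, tau=1).compute())  # [0. 2. 4. 6. 8.]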
        Notes
        -----
        * When :math:`\mathbf{y}` and :math:`\mathbf{x}` sum to one, and hence can be interpreted as discrete
          probability distributions, the KL-divergence corresponds to the relative entropy of :math:`\mathbf{y}`
          w.r.t. :math:`\mathbf{x}`, i.e. the amount of information lost when using :math:`\mathbf{x}` to
          approximate :math:`\mathbf{y}`. It is particularly useful in the context of count data with Poisson
          distribution: the KL-divergence then corresponds (up to an additive constant) to the negative
          log-likelihood of :math:`\mathbf{y}`, whose components are independent Poisson random variables with
          respective intensities given by :math:`\mathbf{x}`. See [FuncSphere]_ Chapter 7, Section 5 for the
          computation of its proximal operator; the resulting closed form is reproduced after this list.
        * :py:class:`~pyxu.operator.KLDivergence` is not backend-agnostic: inputs to arithmetic methods must have
          the same backend as `data`.
        * If `data` is a DASK array, its entries are assumed non-negative and are not checked. Reason: the
          operator should be quick to build under all circumstances, which cannot be guaranteed if every entry of
          an out-of-core array must be verified to be non-negative.
        * If `data` is a DASK array, the core-dimensions of arrays supplied to arithmetic methods **must** have
          the same chunk-size as `data`.
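
        For reference, ``prox()`` applies the closed-form expression below component-wise; this is a
        transcription of the implementation (see [FuncSphere]_ Chapter 7, Section 5 for its derivation):

        .. math::

           \text{prox}_{\tau D_{KL}(\mathbf{y}||\cdot)}(\mathbf{x})_{i}
           = \frac{1}{2} \left( x_{i} - \tau + \sqrt{(x_{i} - \tau)^{2} + 4 \tau y_{i}} \right).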
        """
        super().__init__(
            dim_shape=data.shape,
            codim_shape=1,
        )

        ndi = pxd.NDArrayInfo.from_obj(data)
        if ndi != pxd.NDArrayInfo.DASK:
            # Entry-wise validation is skipped for DASK inputs; see the Notes section above.
            assert (data >= 0).all(), "KL Divergence only defined for non-negative arguments."
        self._data = data

    def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
        # Sum the element-wise KL terms over the core (non-batch) dimensions.
        axis = tuple(range(-self.dim_rank, 0))
        out = self._kl_div(arr, self._data)
        out = out.sum(axis=axis)[..., np.newaxis]
        return out

    def prox(self, arr: pxt.NDArray, tau: pxt.Real) -> pxt.NDArray:
        # Closed-form proximal step: 0.5 * (arr - tau + sqrt((arr - tau)**2 + 4 * tau * data)).
        xp = pxu.get_array_module(arr)
        x = arr - tau
        out = x + xp.sqrt((x**2) + ((4 * tau) * self._data))
        out *= 0.5
        return out

    @staticmethod
    def _kl_div(x: pxt.NDArray, data: pxt.NDArray) -> pxt.NDArray:
        # Element-wise KL-divergence, dispatched on the backend of `x`.
        N = pxd.NDArrayInfo  # short-hand
        ndi = N.from_obj(x)

        if ndi == N.NUMPY:
            sp = pxu.import_module("scipy.special")
            out = sp.kl_div(data, x)
        elif ndi == N.CUPY:
            sp = pxu.import_module("cupyx.scipy.special")
            out = sp.kl_div(data, x)
        elif ndi == N.DASK:
            # The core dimensions of `x` must share the chunk layout of `data` so that the
            # element-wise kernel can be re-applied chunk-by-chunk.
            assert x.chunks[-data.ndim :] == data.chunks
            xp = ndi.module()

            out = xp.map_blocks(
                KLDivergence._kl_div,
                x,
                data,
                dtype=data.dtype,
                meta=data._meta,
            )
        else:
            raise NotImplementedError
        return out