Source code for pyxu.operator.linop.kron

import numpy as np

import pyxu.abc as pxa
import pyxu.info.deps as pxd
import pyxu.info.ptype as pxt
import pyxu.operator.interop.source as px_src
import pyxu.util as pxu

__all__ = [
    "kron",
    "khatri_rao",
]

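# ---------------------------------------------------------------------------
# Illustrative sketch (editorial, not part of the library): both operators in
# this module rely on the row-major vec identity
#
#     vec(A X B^T) = (A \kron B) vec(X),
#
# stated in column-major form in the docstrings below. A quick NumPy check:
#
#   import numpy as np
#   rng = np.random.default_rng(0)
#   A = rng.standard_normal((3, 4))
#   X = rng.standard_normal((4, 5))
#   B = rng.standard_normal((2, 5))
#   lhs = (A @ X @ B.T).ravel()        # row-major vec()
#   rhs = np.kron(A, B) @ X.ravel()
#   assert np.allclose(lhs, rhs)
# ---------------------------------------------------------------------------
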
def kron(A: pxt.OpT, B: pxt.OpT) -> pxt.OpT:
    r"""
    `Kronecker product <https://en.wikipedia.org/wiki/Kronecker_product#Definition>`_ :math:`A \otimes B` between two
    linear operators.

    The Kronecker product :math:`A \otimes B` is defined as

    .. math::

       A \otimes B
       =
       \left[
           \begin{array}{ccc}
               A_{11} B & \cdots & A_{1N_{A}} B \\
               \vdots & \ddots & \vdots \\
               A_{M_{A}1} B & \cdots & A_{M_{A}N_{A}} B \\
           \end{array}
       \right],

    where :math:`A : \mathbb{R}^{N_{A}} \to \mathbb{R}^{M_{A}}`, and :math:`B : \mathbb{R}^{N_{B}} \to
    \mathbb{R}^{M_{B}}`.

    Parameters
    ----------
    A: OpT
        (mA, nA) linear operator
    B: OpT
        (mB, nB) linear operator

    Returns
    -------
    op: OpT
        (mA*mB, nA*nB) linear operator.

    Notes
    -----
    This implementation is **matrix-free** by leveraging properties of the Kronecker product, i.e. :math:`A` and
    :math:`B` need not be known explicitly. In particular :math:`(A \otimes B) x` and :math:`(A \otimes B)^{*} x` are
    computed implicitly via the relation:

    .. math::

       \text{vec}(\mathbf{A}\mathbf{B}\mathbf{C})
       =
       (\mathbf{C}^{T} \otimes \mathbf{A}) \text{vec}(\mathbf{B}),

    where :math:`\mathbf{A}`, :math:`\mathbf{B}`, and :math:`\mathbf{C}` are matrices.
    """

    def _infer_op_shape(shA: pxt.NDArrayShape, shB: pxt.NDArrayShape) -> pxt.NDArrayShape:
        sh = (shA[0] * shB[0], shA[1] * shB[1])
        return sh

    def _infer_op_klass(A: pxt.OpT, B: pxt.OpT) -> pxt.OpC:
        # linear   \kron linear   -> linear
        #                            square (if output square)
        # normal   \kron normal   -> normal
        # unit     \kron unit     -> unit
        # self-adj \kron self-adj -> self-adj
        # pos-def  \kron pos-def  -> pos-def
        # idemp    \kron idemp    -> idemp
        # func     \kron func     -> func
        properties = set(A.properties() & B.properties())
        sh = _infer_op_shape(A.shape, B.shape)
        if sh[0] == sh[1]:
            properties.add(pxa.Property.LINEAR_SQUARE)
        if pxa.Property.FUNCTIONAL in properties:
            klass = pxa.LinFunc
        else:
            klass = pxa.Operator._infer_operator_type(properties)
        return klass

    def op_apply(_, arr: pxt.NDArray) -> pxt.NDArray:
        # If `x` is a vector, then:
        #     (A \kron B)(x) = vec(B * mat(x) * A.T)
        sh_prefix = arr.shape[:-1]
        sh_dim = len(sh_prefix)

        x = arr.reshape((*sh_prefix, _._A.dim, _._B.dim))  # (..., A.dim, B.dim)
        y = _._B.apply(x)  # (..., A.dim, B.codim)
        z = y.transpose((*range(sh_dim), -1, -2))  # (..., B.codim, A.dim)
        t = _._A.apply(z)  # (..., B.codim, A.codim)
        u = t.transpose((*range(sh_dim), -1, -2))  # (..., A.codim, B.codim)

        out = u.reshape((*sh_prefix, -1))  # (..., A.codim * B.codim)
        return out

    def op_adjoint(_, arr: pxt.NDArray) -> pxt.NDArray:
        # If `x` is a vector, then:
        #     (A \kron B).H(x) = vec(B.H * mat(x) * A.conj)
        sh_prefix = arr.shape[:-1]
        sh_dim = len(sh_prefix)

        x = arr.reshape((*sh_prefix, _._A.codim, _._B.codim))  # (..., A.codim, B.codim)
        y = _._B.adjoint(x)  # (..., A.codim, B.dim)
        z = y.transpose((*range(sh_dim), -1, -2))  # (..., B.dim, A.codim)
        t = _._A.adjoint(z)  # (..., B.dim, A.dim)
        u = t.transpose((*range(sh_dim), -1, -2))  # (..., A.dim, B.dim)

        out = u.reshape((*sh_prefix, -1))  # (..., A.dim * B.dim)
        return out

    def op_estimate_lipschitz(_, **kwargs) -> pxt.Real:
        no_eval = "__rule" in kwargs
        if no_eval:
            L_A = _._A.lipschitz
            L_B = _._B.lipschitz
            L = L_A * L_B
        else:
            L = _.__class__.estimate_lipschitz(_, **kwargs)
        return L

    def op_asarray(_, **kwargs) -> pxt.NDArray:
        # (A \kron B).asarray() = A.asarray() \kron B.asarray()
        A = _._A.asarray(**kwargs)
        B = _._B.asarray(**kwargs)
        xp = kwargs.get("xp", pxd.NDArrayInfo.default().module())
        C = xp.tensordot(A, B, axes=0).transpose((0, 2, 1, 3)).reshape(_.shape)
        return C

    def op_gram(_) -> pxt.OpT:
        # (A \kron B).gram() = A.gram() \kron B.gram()
        A = _._A.gram()
        B = _._B.gram()
        op = kron(A, B)
        return op

    def op_cogram(_) -> pxt.OpT:
        # (A \kron B).cogram() = A.cogram() \kron B.cogram()
        A = _._A.cogram()
        B = _._B.cogram()
        op = kron(A, B)
        return op

    def op_svdvals(_, **kwargs) -> pxt.NDArray:
        # (A \kron B).svdvals(k)
        # = outer(
        #       A.svdvals(k),
        #       B.svdvals(k),
        #   ).top(k)
        k = kwargs.get("k", 1)

        D_A = _._A.svdvals(**kwargs)
        D_B = _._B.svdvals(**kwargs)
        xp = pxu.get_array_module(D_A)
        # top-k of all pairwise products: the top-k singular values of A \kron B
        # necessarily arise from the top-k singular values of A and B.
        D_C = xp.sort(xp.outer(D_A, D_B), axis=None)[-k:]
        return D_C

    def op_pinv(_, arr: pxt.NDArray, damp: pxt.Real, **kwargs) -> pxt.NDArray:
        if np.isclose(damp, 0):
            # (A \kron B).dagger() = A.dagger() \kron B.dagger()
            op_d = kron(_._A.dagger(damp, **kwargs), _._B.dagger(damp, **kwargs))
            out = op_d.apply(arr)
        else:
            # default algorithm
            out = _.__class__.pinv(_, arr, damp, **kwargs)
        return out

    def op_trace(_, **kwargs) -> pxt.Real:
        # tr(A \kron B) = tr(A) * tr(B)
        # [if both square, else default algorithm]
        P = pxa.Property.LINEAR_SQUARE
        if not _.has(P):
            raise NotImplementedError

        if _._A.has(P) and _._B.has(P):
            tr = _._A.trace(**kwargs) * _._B.trace(**kwargs)
        else:
            tr = _.__class__.trace(_, **kwargs)
        return tr

    _A = A.squeeze()
    _B = B.squeeze()
    assert (klass := _infer_op_klass(_A, _B)).has(pxa.Property.LINEAR)
    is_scalar = lambda _: _.shape == (1, 1)
    if is_scalar(_A) and is_scalar(_B):
        from pyxu.operator.linop.base import HomothetyOp

        return HomothetyOp(cst=(_A.asarray() * _B.asarray()).item(), dim=1)
    elif is_scalar(_A) and (not is_scalar(_B)):
        return _A.asarray().item() * _B
    elif (not is_scalar(_A)) and is_scalar(_B):
        return _A * _B.asarray().item()
    else:
        op = px_src.from_source(
            cls=klass,
            shape=_infer_op_shape(_A.shape, _B.shape),
            embed=dict(
                _name="kron",
                _A=_A,
                _B=_B,
            ),
            apply=op_apply,
            adjoint=op_adjoint,
            asarray=op_asarray,
            gram=op_gram,
            cogram=op_cogram,
            svdvals=op_svdvals,
            pinv=op_pinv,
            trace=op_trace,
            estimate_lipschitz=op_estimate_lipschitz,
            _expr=lambda _: (_._name, _._A, _._B),
        )
        op.lipschitz = op.estimate_lipschitz(__rule=True)
        return op

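# ---------------------------------------------------------------------------
# Illustrative usage sketch (editorial, not part of the library): `kron`
# checked against the dense Kronecker matrix. `pxa.LinOp.from_array` is
# assumed available as the explicit-matrix constructor; uncomment to run.
#
#   import numpy as np
#   import pyxu.abc as pxa
#
#   rng = np.random.default_rng(0)
#   A = pxa.LinOp.from_array(rng.standard_normal((3, 4)))  # (mA, nA)
#   B = pxa.LinOp.from_array(rng.standard_normal((2, 5)))  # (mB, nB)
#   C = kron(A, B)                                         # (6, 20), matrix-free
#
#   K = np.kron(A.asarray(), B.asarray())                  # dense reference
#   x = rng.standard_normal(20)
#   y = rng.standard_normal(6)
#   assert np.allclose(C.apply(x), K @ x)                  # forward
#   assert np.allclose(C.adjoint(y), K.T @ y)              # adjoint
# ---------------------------------------------------------------------------
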
def khatri_rao(A: pxt.OpT, B: pxt.OpT) -> pxt.OpT:
    r"""
    `Column-wise Khatri-Rao product
    <https://en.wikipedia.org/wiki/Khatri%E2%80%93Rao_product#Column-wise_Kronecker_product>`_ :math:`A \circ B`
    between two linear operators.

    The Khatri-Rao product :math:`A \circ B` is defined as

    .. math::

       A \circ B
       =
       \left[
           \begin{array}{ccc}
               \mathbf{a}_{1} \otimes \mathbf{b}_{1} & \cdots & \mathbf{a}_{N} \otimes \mathbf{b}_{N}
           \end{array}
       \right],

    where :math:`A : \mathbb{R}^{N} \to \mathbb{R}^{M_{A}}`, :math:`B : \mathbb{R}^{N} \to \mathbb{R}^{M_{B}}`, and
    :math:`\mathbf{a}_{k}` (respectively :math:`\mathbf{b}_{k}`) denotes the :math:`k`-th column of :math:`A`
    (respectively :math:`B`).

    Parameters
    ----------
    A: OpT
        (mA, n) linear operator
    B: OpT
        (mB, n) linear operator

    Returns
    -------
    op: OpT
        (mA*mB, n) linear operator.

    Notes
    -----
    This implementation is **matrix-free** by leveraging properties of the Khatri-Rao product, i.e. :math:`A` and
    :math:`B` need not be known explicitly. In particular :math:`(A \circ B) x` and :math:`(A \circ B)^{*} x` are
    computed implicitly via the relation:

    .. math::

       \text{vec}(\mathbf{A}\text{diag}(\mathbf{b})\mathbf{C})
       =
       (\mathbf{C}^{T} \circ \mathbf{A}) \mathbf{b},

    where :math:`\mathbf{A}`, :math:`\mathbf{C}` are matrices, and :math:`\mathbf{b}` is a vector.

    Note however that a matrix-free implementation of the Khatri-Rao product does not permit the same optimizations
    as a matrix-based implementation. Thus the Khatri-Rao product as implemented here is only marginally more
    efficient than applying :py:func:`~pyxu.operator.kron` and pruning its output.
    """

    def _infer_op_shape(shA: pxt.NDArrayShape, shB: pxt.NDArrayShape) -> pxt.NDArrayShape:
        if shA[1] != shB[1]:
            raise ValueError(f"Khatri-Rao product of {shA} and {shB} operators forbidden.")
        sh = (shA[0] * shB[0], shA[1])
        return sh

    def _infer_op_klass(A: pxt.OpT, B: pxt.OpT) -> pxt.OpC:
        # linear \kr linear -> linear
        #                      square (if output square)
        sh = _infer_op_shape(A.shape, B.shape)
        if sh[0] == 1:
            klass = pxa.LinFunc
        else:
            properties = set(pxa.LinOp.properties())
            if sh[0] == sh[1]:
                properties.add(pxa.Property.LINEAR_SQUARE)
            klass = pxa.Operator._infer_operator_type(properties)
        return klass

    def op_apply(_, arr: pxt.NDArray) -> pxt.NDArray:
        # If `x` is a vector, then:
        #     (A \kr B)(x) = vec(B * diag(x) * A.T)
        sh_prefix = arr.shape[:-1]
        sh_dim = len(sh_prefix)
        xp = pxu.get_array_module(arr)
        I = xp.eye(N=_.dim, dtype=arr.dtype)  # noqa: E741

        x = arr.reshape((*sh_prefix, 1, _.dim))  # (..., 1, dim)
        y = _._B.apply(x * I)  # (..., dim, B.codim)
        z = y.transpose((*range(sh_dim), -1, -2))  # (..., B.codim, dim)
        t = _._A.apply(z)  # (..., B.codim, A.codim)
        u = t.transpose((*range(sh_dim), -1, -2))  # (..., A.codim, B.codim)

        out = u.reshape((*sh_prefix, -1))  # (..., A.codim * B.codim)
        return out

    def op_adjoint(_, arr: pxt.NDArray) -> pxt.NDArray:
        # If `x` is a vector, then:
        #     (A \kr B).H(x) = diag(B.H * mat(x) * A.conj)
        sh_prefix = arr.shape[:-1]
        sh_dim = len(sh_prefix)
        xp = pxu.get_array_module(arr)
        I = xp.eye(N=_.dim, dtype=arr.dtype)  # noqa: E741

        x = arr.reshape((*sh_prefix, _._A.codim, _._B.codim))  # (..., A.codim, B.codim)
        y = _._B.adjoint(x)  # (..., A.codim, B.dim)
        z = y.transpose((*range(sh_dim), -1, -2))  # (..., dim, A.codim)
        t = pxu.copy_if_unsafe(_._A.adjoint(z))  # (..., dim, dim)
        t *= I  # mask off-diagonal entries

        out = t.sum(axis=-1)  # (..., dim)
        return out

    def op_asarray(_, **kwargs) -> pxt.NDArray:
        # (A \kr B).asarray()[:, i] = A.asarray()[:, i] \kron B.asarray()[:, i]
        A = _._A.asarray(**kwargs).T.reshape((_.dim, _._A.codim, 1))
        B = _._B.asarray(**kwargs).T.reshape((_.dim, 1, _._B.codim))
        C = (A * B).reshape((_.dim, -1)).T
        return C

    def op_estimate_lipschitz(_, **kwargs) -> pxt.Real:
        no_eval = "__rule" in kwargs
        if no_eval:
            # kr(A,B) = kron(A,B) + sub-sampling
            # -> upper bound provided by kron(A,B).lipschitz
            L = kron(_._A, _._B).lipschitz
        else:
            L = _.__class__.estimate_lipschitz(_, **kwargs)
        return L

    _A = A.squeeze()
    _B = B.squeeze()
    assert (klass := _infer_op_klass(_A, _B)).has(pxa.Property.LINEAR)

    op = px_src.from_source(
        cls=klass,
        shape=_infer_op_shape(_A.shape, _B.shape),
        embed=dict(
            _name="khatri_rao",
            _A=_A,
            _B=_B,
        ),
        apply=op_apply,
        adjoint=op_adjoint,
        asarray=op_asarray,
        estimate_lipschitz=op_estimate_lipschitz,
        _expr=lambda _: (_._name, _._A, _._B),
    )
    op.lipschitz = op.estimate_lipschitz(__rule=True)
    return op