1import numpy as np
2
3import pyxu.abc as pxa
4import pyxu.info.deps as pxd
5import pyxu.info.ptype as pxt
6import pyxu.operator.interop.source as px_src
7import pyxu.util as pxu
8
9__all__ = [
10 "kron",
11 "khatri_rao",
12]
13
14
def kron(A: pxt.OpT, B: pxt.OpT) -> pxt.OpT:
    r"""
    `Kronecker product <https://en.wikipedia.org/wiki/Kronecker_product#Definition>`_ :math:`A \otimes B` between two
    linear operators.

    The Kronecker product :math:`A \otimes B` is defined as

    .. math::

       A \otimes B
       =
       \left[
           \begin{array}{ccc}
               A_{11} B & \cdots & A_{1N_{A}} B \\
               \vdots & \ddots & \vdots \\
               A_{M_{A}1} B & \cdots & A_{M_{A}N_{A}} B \\
           \end{array}
       \right],

    where :math:`A : \mathbb{R}^{N_{A}} \to \mathbb{R}^{M_{A}}`, and :math:`B : \mathbb{R}^{N_{B}} \to
    \mathbb{R}^{M_{B}}`.

    Parameters
    ----------
    A: OpT
        (mA, nA) linear operator
    B: OpT
        (mB, nB) linear operator

    Returns
    -------
    op: OpT
        (mA*mB, nA*nB) linear operator.

    Notes
    -----
    This implementation is **matrix-free** by leveraging properties of the Kronecker product, i.e. :math:`A` and
    :math:`B` need not be known explicitly. In particular :math:`(A \otimes B) x` and :math:`(A \otimes B)^{*} x` are
    computed implicitly via the relation:

    .. math::

       \text{vec}(\mathbf{A}\mathbf{B}\mathbf{C})
       =
       (\mathbf{C}^{T} \otimes \mathbf{A}) \text{vec}(\mathbf{B}),

    where :math:`\mathbf{A}`, :math:`\mathbf{B}`, and :math:`\mathbf{C}` are matrices.
    """

    def _infer_op_shape(shA: pxt.NDArrayShape, shB: pxt.NDArrayShape) -> pxt.NDArrayShape:
        # (mA, nA) \kron (mB, nB) -> (mA*mB, nA*nB)
        sh = (shA[0] * shB[0], shA[1] * shB[1])
        return sh

    def _infer_op_klass(A: pxt.OpT, B: pxt.OpT) -> pxt.OpC:
        # Properties preserved by the Kronecker product:
        #   linear \kron linear -> linear
        #                          square (if output square)
        #   normal \kron normal -> normal
        #   unit \kron unit -> unit
        #   self-adj \kron self-adj -> self-adj
        #   pos-def \kron pos-def -> pos-def
        #   idemp \kron idemp -> idemp
        #   func \kron func -> func
        properties = set(A.properties() & B.properties())
        sh = _infer_op_shape(A.shape, B.shape)
        if sh[0] == sh[1]:
            properties.add(pxa.Property.LINEAR_SQUARE)
        if pxa.Property.FUNCTIONAL in properties:
            klass = pxa.LinFunc
        else:
            klass = pxa.Operator._infer_operator_type(properties)
        return klass

    def op_apply(_, arr: pxt.NDArray) -> pxt.NDArray:
        # If `x` is a vector, then:
        #     (A \kron B)(x) = vec(B * mat(x) * A.T)
        sh_prefix = arr.shape[:-1]
        sh_dim = len(sh_prefix)

        x = arr.reshape((*sh_prefix, _._A.dim, _._B.dim))  # (..., A.dim, B.dim)
        y = _._B.apply(x)  # (..., A.dim, B.codim)
        z = y.transpose((*range(sh_dim), -1, -2))  # (..., B.codim, A.dim)
        t = _._A.apply(z)  # (..., B.codim, A.codim)
        u = t.transpose((*range(sh_dim), -1, -2))  # (..., A.codim, B.codim)

        out = u.reshape((*sh_prefix, -1))  # (..., A.codim * B.codim)
        return out

    def op_adjoint(_, arr: pxt.NDArray) -> pxt.NDArray:
        # If `x` is a vector, then:
        #     (A \kron B).H(x) = vec(B.H * mat(x) * A.conj)
        sh_prefix = arr.shape[:-1]
        sh_dim = len(sh_prefix)

        x = arr.reshape((*sh_prefix, _._A.codim, _._B.codim))  # (..., A.codim, B.codim)
        y = _._B.adjoint(x)  # (..., A.codim, B.dim)
        z = y.transpose((*range(sh_dim), -1, -2))  # (..., B.dim, A.codim)
        t = _._A.adjoint(z)  # (..., B.dim, A.dim)
        u = t.transpose((*range(sh_dim), -1, -2))  # (..., A.dim, B.dim)

        out = u.reshape((*sh_prefix, -1))  # (..., A.dim * B.dim)
        return out

    def op_estimate_lipschitz(_, **kwargs) -> pxt.Real:
        # L(A \kron B) = L(A) * L(B); used when `__rule` requests the cheap product rule,
        # otherwise defer to the class' generic estimator.
        no_eval = "__rule" in kwargs
        if no_eval:
            L_A = _._A.lipschitz
            L_B = _._B.lipschitz
            L = L_A * L_B
        else:
            L = _.__class__.estimate_lipschitz(_, **kwargs)
        return L

    def op_asarray(_, **kwargs) -> pxt.NDArray:
        # (A \kron B).asarray() = A.asarray() \kron B.asarray()
        A = _._A.asarray(**kwargs)
        B = _._B.asarray(**kwargs)
        xp = kwargs.get("xp", pxd.NDArrayInfo.default().module())
        C = xp.tensordot(A, B, axes=0).transpose((0, 2, 1, 3)).reshape(_.shape)
        return C

    def op_gram(_) -> pxt.OpT:
        # (A \kron B).gram() = A.gram() \kron B.gram()
        A = _._A.gram()
        B = _._B.gram()
        op = kron(A, B)
        return op

    def op_cogram(_) -> pxt.OpT:
        # (A \kron B).cogram() = A.cogram() \kron B.cogram()
        A = _._A.cogram()
        B = _._B.cogram()
        op = kron(A, B)
        return op

    def op_svdvals(_, **kwargs) -> pxt.NDArray:
        # (A \kron B).svdvals(k)
        # = outer(
        #       A.svdvals(k),
        #       B.svdvals(k)
        #   ).top(k)
        k = kwargs.get("k", 1)

        D_A = _._A.svdvals(**kwargs)
        D_B = _._B.svdvals(**kwargs)
        xp = pxu.get_array_module(D_A)
        # The singular values of A \kron B are all pairwise products of the factors'
        # singular values: take the k largest (ascending, per svdvals() convention).
        D_C = xp.sort(xp.outer(D_A, D_B), axis=None)[-k:]
        return D_C

    def op_pinv(_, arr: pxt.NDArray, damp: pxt.Real, **kwargs) -> pxt.NDArray:
        if np.isclose(damp, 0):
            # (A \kron B).dagger() = A.dagger() \kron B.dagger()
            op_d = kron(_._A.dagger(damp, **kwargs), _._B.dagger(damp, **kwargs))
            out = op_d.apply(arr)
        else:
            # The damped pseudo-inverse does not factorize -> default algorithm.
            out = _.__class__.pinv(_, arr, damp, **kwargs)
        return out

    def op_trace(_, **kwargs) -> pxt.Real:
        # tr(A \kron B) = tr(A) * tr(B)
        # [if both square, else default algorithm]
        P = pxa.Property.LINEAR_SQUARE
        if not _.has(P):
            raise NotImplementedError

        if _._A.has(P) and _._B.has(P):
            tr = _._A.trace(**kwargs) * _._B.trace(**kwargs)
        else:
            tr = _.__class__.trace(_, **kwargs)
        return tr

    _A = A.squeeze()
    _B = B.squeeze()
    assert (klass := _infer_op_klass(_A, _B)).has(pxa.Property.LINEAR)
    is_scalar = lambda _: _.shape == (1, 1)
    if is_scalar(_A) and is_scalar(_B):
        from pyxu.operator.linop.base import HomothetyOp

        return HomothetyOp(cst=(_A.asarray() * _B.asarray()).item(), dim=1)
    elif is_scalar(_A) and (not is_scalar(_B)):
        return _A.asarray().item() * _B
    elif (not is_scalar(_A)) and is_scalar(_B):
        # NOTE: test the squeezed operator `_B` (not `B`), consistent with the other branches.
        return _A * _B.asarray().item()
    else:
        op = px_src.from_source(
            cls=klass,
            shape=_infer_op_shape(_A.shape, _B.shape),
            embed=dict(
                _name="kron",
                _A=_A,
                _B=_B,
            ),
            apply=op_apply,
            adjoint=op_adjoint,
            asarray=op_asarray,
            gram=op_gram,
            cogram=op_cogram,
            svdvals=op_svdvals,
            pinv=op_pinv,
            trace=op_trace,
            estimate_lipschitz=op_estimate_lipschitz,
            _expr=lambda _: (_._name, _._A, _._B),
        )
        op.lipschitz = op.estimate_lipschitz(__rule=True)
        return op
220
221
def khatri_rao(A: pxt.OpT, B: pxt.OpT) -> pxt.OpT:
    r"""
    `Column-wise Khatri-Rao product
    <https://en.wikipedia.org/wiki/Khatri%E2%80%93Rao_product#Column-wise_Kronecker_product>`_ :math:`A \circ B` between
    two linear operators.

    The Khatri-Rao product :math:`A \circ B` is defined as

    .. math::

       A \circ B
       =
       \left[
           \begin{array}{ccc}
               \mathbf{a}_{1} \otimes \mathbf{b}_{1} & \cdots & \mathbf{a}_{N} \otimes \mathbf{b}_{N}
           \end{array}
       \right],

    where :math:`A : \mathbb{R}^{N} \to \mathbb{R}^{M_{A}}`, :math:`B : \mathbb{R}^{N} \to \mathbb{R}^{M_{B}}`, and
    :math:`\mathbf{a}_{k}` (respectively :math:`\mathbf{b}_{k}`) denotes the :math:`k`-th column of :math:`A`
    (respectively :math:`B`).

    Parameters
    ----------
    A: OpT
        (mA, n) linear operator
    B: OpT
        (mB, n) linear operator

    Returns
    -------
    op: OpT
        (mA*mB, n) linear operator.

    Notes
    -----
    This implementation is **matrix-free** by leveraging properties of the Khatri-Rao product, i.e. :math:`A` and
    :math:`B` need not be known explicitly. In particular :math:`(A \circ B) x` and :math:`(A \circ B)^{*} x` are
    computed implicitly via the relation:

    .. math::

       \text{vec}(\mathbf{A}\text{diag}(\mathbf{b})\mathbf{C})
       =
       (\mathbf{C}^{T} \circ \mathbf{A}) \mathbf{b},

    where :math:`\mathbf{A}`, :math:`\mathbf{C}` are matrices, and :math:`\mathbf{b}` is a vector.

    Note however that a matrix-free implementation of the Khatri-Rao product does not permit the same optimizations as a
    matrix-based implementation. Thus the Khatri-Rao product as implemented here is only marginally more efficient than
    applying :py:func:`~pyxu.operator.kron` and pruning its output.
    """

    def _infer_op_shape(shA: pxt.NDArrayShape, shB: pxt.NDArrayShape) -> pxt.NDArrayShape:
        # Operands must share their domain dimension (same number of columns).
        if shA[1] != shB[1]:
            raise ValueError(f"Khatri-Rao product of {shA} and {shB} operators forbidden.")
        sh = (shA[0] * shB[0], shA[1])
        return sh

    def _infer_op_klass(A: pxt.OpT, B: pxt.OpT) -> pxt.OpC:
        # linear \kr linear -> linear
        #                      square (if output square)
        sh = _infer_op_shape(A.shape, B.shape)
        if sh[0] == 1:
            klass = pxa.LinFunc
        else:
            properties = set(pxa.LinOp.properties())
            if sh[0] == sh[1]:
                properties.add(pxa.Property.LINEAR_SQUARE)
            klass = pxa.Operator._infer_operator_type(properties)
        return klass

    def op_apply(_, arr: pxt.NDArray) -> pxt.NDArray:
        # If `x` is a vector, then:
        #     (A \kr B)(x) = vec(B * diag(x) * A.T)
        sh_prefix = arr.shape[:-1]
        sh_dim = len(sh_prefix)
        xp = pxu.get_array_module(arr)
        I = xp.eye(N=_.dim, dtype=arr.dtype)  # noqa: E741

        x = arr.reshape((*sh_prefix, 1, _.dim))  # (..., 1, dim)
        y = _._B.apply(x * I)  # (..., dim, B.codim); `x * I` materializes diag(x)
        z = y.transpose((*range(sh_dim), -1, -2))  # (..., B.codim, dim)
        t = _._A.apply(z)  # (..., B.codim, A.codim)
        u = t.transpose((*range(sh_dim), -1, -2))  # (..., A.codim, B.codim)

        out = u.reshape((*sh_prefix, -1))  # (..., A.codim * B.codim)
        return out

    def op_adjoint(_, arr: pxt.NDArray) -> pxt.NDArray:
        # If `x` is a vector, then:
        #     (A \kr B).H(x) = diag(B.H * mat(x) * A.conj)
        sh_prefix = arr.shape[:-1]
        sh_dim = len(sh_prefix)
        xp = pxu.get_array_module(arr)
        I = xp.eye(N=_.dim, dtype=arr.dtype)  # noqa: E741

        x = arr.reshape((*sh_prefix, _._A.codim, _._B.codim))  # (..., A.codim, B.codim)
        y = _._B.adjoint(x)  # (..., A.codim, B.dim)
        z = y.transpose((*range(sh_dim), -1, -2))  # (..., dim, A.codim)
        t = pxu.copy_if_unsafe(_._A.adjoint(z))  # (..., dim, dim)
        t *= I  # mask off-diagonal entries ...

        out = t.sum(axis=-1)  # (..., dim); ... then extract the diagonal.
        return out

    def op_asarray(_, **kwargs) -> pxt.NDArray:
        # (A \kr B).asarray()[:,i] = A.asarray()[:,i] \kron B.asarray()[:,i]
        A = _._A.asarray(**kwargs).T.reshape((_.dim, _._A.codim, 1))
        B = _._B.asarray(**kwargs).T.reshape((_.dim, 1, _._B.codim))
        C = (A * B).reshape((_.dim, -1)).T
        return C

    def op_estimate_lipschitz(_, **kwargs) -> pxt.Real:
        # kr(A,B) = kron(A,B) + sub-sampling -> kron(A,B).lipschitz upper-bounds kr(A,B)'s
        # Lipschitz constant.  Used for the cheap `__rule` path; otherwise defer to the
        # class' generic estimator.  (Same pattern as `kron`.)
        no_eval = "__rule" in kwargs
        if no_eval:
            L = kron(_._A, _._B).lipschitz
        else:
            L = _.__class__.estimate_lipschitz(_, **kwargs)
        return L

    _A = A.squeeze()
    _B = B.squeeze()
    assert (klass := _infer_op_klass(_A, _B)).has(pxa.Property.LINEAR)

    op = px_src.from_source(
        cls=klass,
        shape=_infer_op_shape(_A.shape, _B.shape),
        embed=dict(
            _name="khatri_rao",
            _A=_A,
            _B=_B,
        ),
        apply=op_apply,
        adjoint=op_adjoint,
        asarray=op_asarray,
        estimate_lipschitz=op_estimate_lipschitz,
        _expr=lambda _: (_._name, _._A, _._B),
    )

    # kr(A,B) = kron(A,B) + sub-sampling -> upper-bound provided by kron(A,B).lipschitz
    op.lipschitz = op.estimate_lipschitz(__rule=True)
    return op