1import pyxu.info.deps as pxd
2import pyxu.info.ptype as pxt
3import pyxu.runtime as pxrt
4
# Public API of this module: the names exported by ``from <module> import *``.
__all__ = ["as_real_op", "require_viewable", "view_as_real", "view_as_complex"]
11
12
[docs]
def require_viewable(x: pxt.NDArray) -> pxt.NDArray:
    """
    Return `x`, or a C-ordered copy of it, so that real/complex views can be taken.

    Real/complex view manipulations are feasible if the last axis is contiguous, i.e. if consecutive
    elements along it are exactly one item apart in memory.

    Parameters
    ----------
    x: NDArray

    Returns
    -------
    y: NDArray
        `x` itself when no copy is needed, otherwise ``x.copy(order="C")``.

    Raises
    ------
    NotImplementedError
        If `x` is neither a NumPy, CuPy nor Dask array.
    """
    N = pxd.NDArrayInfo
    backend = N.from_obj(x)

    if backend == N.DASK:
        # Dask graphs carry no memory-layout information -> nothing to fix, always safe.
        return x

    if backend not in (N.NUMPY, N.CUPY):
        raise NotImplementedError(f"require_viewable() not yet defined for {backend}.")

    # Contiguous last axis <=> its stride equals the size of a single element.
    last_axis_contiguous = x.strides[-1] == x.dtype.itemsize
    return x if last_axis_contiguous else x.copy(order="C")
41
42
[docs]
def view_as_complex(x: pxt.NDArray) -> pxt.NDArray:
    r"""
    View real-valued array as its complex-valued bijection. (Inverse of :py:func:`~pyxu.util.view_as_real`.)

    Parameters
    ----------
    x: NDArray
        (..., N, 2) real-valued array.

    Returns
    -------
    y: NDArray
        (..., N) complex-valued array.

    Raises
    ------
    ValueError
        If the dtype of `x` is not one of the real widths known to :py:mod:`pyxu.runtime`.

    Examples
    --------

    .. code-block:: python3

       import numpy as np
       from pyxu.util import view_as_real, view_as_complex

       x = np.array([[0., 1],
                     [2 , 3],
                     [4 , 5]])
       y = view_as_complex(x)  # array([0.+1.j, 2.+3.j, 4.+5.j])
       view_as_real(y) == x  # True

    Notes
    -----
    * Complex-valued inputs are returned unchanged, whatever their rank.
    * The result is a view: real/complex view manipulations are feasible if the last axis of `x` is
      contiguous. Use :py:func:`~pyxu.util.require_viewable` beforehand if unsure.

    See Also
    --------
    :py:func:`~pyxu.util.as_real_op`,
    :py:func:`~pyxu.util.view_as_real`
    """
    # Pass-through test comes first so that complex inputs of *any* rank are returned unchanged, as
    # documented (and as done in `as_real_op`).
    if _is_complex(x):
        return x
    assert x.ndim >= 2

    # Read the dtype outside the `try`: the error message below must always be able to reference it.
    r_dtype = x.dtype
    try:
        c_dtype = pxrt.Width(r_dtype).complex.value
    except Exception as err:
        raise ValueError(f"Unsupported dtype {r_dtype}.") from err
    assert x.shape[-1] == 2, "Last array dimension should contain real/imaginary terms only."

    y = x.view(c_dtype)  # (..., N, 1): each (real, imag) pair fused into one complex item
    y = y[..., 0]  # (..., N)
    return y
93
94
[docs]
def view_as_real(x: pxt.NDArray) -> pxt.NDArray:
    r"""
    View complex-valued array as its real-valued bijection. (Inverse of :py:func:`~pyxu.util.view_as_complex`.)

    Parameters
    ----------
    x: NDArray
        (..., N) complex-valued array.

    Returns
    -------
    y: NDArray
        (..., N, 2) real-valued array.

    Raises
    ------
    ValueError
        If the dtype of `x` is not one of the complex widths known to :py:mod:`pyxu.runtime`.

    Examples
    --------

    .. code-block:: python3

       import numpy as np
       from pyxu.util import view_as_real, view_as_complex

       x = np.r_[0+1j, 2+3j, 4+5j]
       y = view_as_real(x)  # array([[0., 1.],
                            #        [2., 3.],
                            #        [4., 5.]])
       view_as_complex(y) == x  # True

    Notes
    -----
    * Real-valued inputs are returned unchanged, whatever their rank.
    * The result is a view: real/complex view manipulations are feasible if the last axis of `x` is
      contiguous. Use :py:func:`~pyxu.util.require_viewable` beforehand if unsure.

    See Also
    --------
    :py:func:`~pyxu.util.as_real_op`,
    :py:func:`~pyxu.util.view_as_complex`
    """
    # Pass-through test comes first so that real inputs of *any* rank (including 0-d) are returned
    # unchanged, as documented (and as done in `as_real_op`).
    if _is_real(x):
        return x
    assert x.ndim >= 1

    # Read the dtype outside the `try`: the error message below must always be able to reference it.
    c_dtype = x.dtype
    try:
        r_dtype = pxrt.CWidth(c_dtype).real.value
    except Exception as err:
        raise ValueError(f"Unsupported dtype {c_dtype}.") from err

    y = x.view(r_dtype)  # (..., 2N): (real, imag) pairs interleaved along the last axis

    if pxd.NDArrayInfo.from_obj(x) == pxd.NDArrayInfo.DASK:
        # Split the interleaved last axis block-by-block: output chunks are x's chunks plus a
        # trailing size-2 axis holding the (real, imag) pair.
        y = y.map_blocks(  # (..., N, 2)
            lambda blk: blk.reshape(*blk.shape[:-1], blk.shape[-1] // 2, 2),
            chunks=(*x.chunks, 2),
            new_axis=x.ndim,
            meta=y._meta,
        )
    else:
        y = y.reshape(*x.shape, 2)  # (..., N, 2)
    return y
158
159
[docs]
def as_real_op(A: pxt.NDArray, dim_rank: pxt.Integer = None) -> pxt.NDArray:
    r"""
    View complex-valued linear operator as its real-valued equivalent.

    Useful to transform complex-valued matrix/vector products to their real-valued counterparts.

    Parameters
    ----------
    A: NDArray
        (N1,...,NK, M1,...,MD) complex-valued matrix.
    dim_rank: Integer
        Dimension rank :math:`D`. (Can be omitted if `A` is 2D, in which case any value given is ignored.)

    Returns
    -------
    A_r: NDArray
        (N1,...,NK,2, M1,...,MD,2) real-valued equivalent.

    Raises
    ------
    ValueError
        If the dtype of `A` is not one of the complex widths known to :py:mod:`pyxu.runtime`.

    Examples
    --------

    .. code-block:: python3

       import numpy as np
       import pyxu.util.complex as cpl

       codim_shape = (1,2,3)
       dim_shape = (4,5,6,7)
       dim_rank = len(dim_shape)

       rng = np.random.default_rng(0)
       A =      rng.standard_normal((*codim_shape, *dim_shape)) \
         + 1j * rng.standard_normal((*codim_shape, *dim_shape))    # (1,2,3  |4,5,6,7  )
       A_r = cpl.as_real_op(A, dim_rank=dim_rank)                  # (1,2,3,2|4,5,6,7,2)

       x =      rng.standard_normal(dim_shape) \
         + 1j * rng.standard_normal(dim_shape)                     # (4,5,6,7  )
       x_r = cpl.view_as_real(x)                                   # (4,5,6,7,2)

       y = np.tensordot(A, x, axes=dim_rank)                       # (1,2,3  )
       y_r = np.tensordot(A_r, x_r, axes=dim_rank+1)               # (1,2,3,2)

       np.allclose(y, cpl.view_as_complex(y_r))                    # True

    Notes
    -----
    Real-valued matrices are returned unchanged.

    See Also
    --------
    :py:func:`~pyxu.util.view_as_real`,
    :py:func:`~pyxu.util.view_as_complex`
    """
    if _is_real(A):
        return A

    # Read the dtype outside the `try`: the error message below must always be able to reference it.
    c_dtype = A.dtype
    try:
        r_dtype = pxrt.CWidth(c_dtype).real.value
    except Exception as err:
        raise ValueError(f"Unsupported dtype {c_dtype}.") from err

    if A.ndim == 2:
        dim_rank = 1  # only valid split of a 2D operator: user-supplied value is ignored.
    else:  # ND (or degenerate 0D/1D) operator -> dim_rank mandatory & (1 <= dim_rank < A.ndim)
        assert dim_rank is not None, "Dimension rank must be specified for ND operators."
        assert 1 <= dim_rank < A.ndim
    dim_shape = A.shape[-dim_rank:]
    codim_shape = A.shape[:-dim_rank]
    codim_rank = len(codim_shape)

    xp = pxd.NDArrayInfo.from_obj(A).module()
    A_r = xp.zeros((*codim_shape, 2, *dim_shape, 2), dtype=r_dtype)

    # Each complex entry a = a_re + j a_im becomes the 2x2 real block
    #     [[a_re, -a_im],
    #      [a_im,  a_re]]
    # whose row index is the trailing codim axis and column index the trailing dim axis.
    # Explicit tuple indices are used (instead of `A_r[*sel]`) so this also runs on Python < 3.11.
    lead = (slice(None),) * codim_rank  # all codomain axes
    tail = (slice(None),) * dim_rank  # all domain axes
    A_r[(*lead, 0, *tail, 0)] = A.real
    A_r[(*lead, 1, *tail, 1)] = A.real
    A_r[(*lead, 0, *tail, 1)] = -A.imag
    A_r[(*lead, 1, *tail, 0)] = A.imag
    return A_r
253
254
def _is_real(x: pxt.NDArray) -> bool:
    """
    Tell whether `x` has a real dtype known to :py:class:`pyxu.runtime.Width`.

    Any failure (unknown dtype, object lacking a ``dtype`` attribute, ...) is reported as ``False``.
    """
    try:
        width = pxrt.Width(x.dtype)
        answer = bool(width)
    except Exception:
        answer = False
    return answer
260
261
def _is_complex(x: pxt.NDArray) -> bool:
    """
    Tell whether `x` has a complex dtype known to :py:class:`pyxu.runtime.CWidth`.

    Any failure (unknown dtype, object lacking a ``dtype`` attribute, ...) is reported as ``False``.
    """
    try:
        width = pxrt.CWidth(x.dtype)
        answer = bool(width)
    except Exception:
        answer = False
    return answer