1import pyxu.info.deps as pxd
  2import pyxu.info.ptype as pxt
  3import pyxu.runtime as pxrt
  4
  5__all__ = [
  6    "as_real_op",
  7    "require_viewable",
  8    "view_as_real",
  9    "view_as_complex",
 10]
 11
 12
[docs]
 13def require_viewable(x: pxt.NDArray) -> pxt.NDArray:
 14    """
 15    Copy array if required to do real/complex view manipulations.
 16
 17    Real/complex view manipulations are feasible if the last axis is contiguous.
 18
 19    Parameters
 20    ----------
 21    x: NDArray
 22
 23    Returns
 24    -------
 25    y: NDArray
 26    """
 27    N = pxd.NDArrayInfo
 28    ndi = N.from_obj(x)
 29    if ndi == N.DASK:
 30        # No notion of contiguity for Dask graphs -> always safe.
 31        y = x
 32    elif ndi in (N.NUMPY, N.CUPY):
 33        if x.strides[-1] == x.dtype.itemsize:
 34            y = x
 35        else:
 36            y = x.copy(order="C")
 37    else:
 38        msg = f"require_viewable() not yet defined for {ndi}."
 39        raise NotImplementedError(msg)
 40    return y 
 41
 42
[docs]
 43def view_as_complex(x: pxt.NDArray) -> pxt.NDArray:
 44    r"""
 45    View real-valued array as its complex-valued bijection.  (Inverse of :py:func:`~pyxu.util.view_as_real`.)
 46
 47    Parameters
 48    ----------
 49    x: NDArray
 50        (..., N, 2) real-valued array.
 51
 52    Returns
 53    -------
 54    y: NDArray
 55        (..., N) complex-valued array.
 56
 57    Examples
 58    --------
 59
 60    .. code-block:: python3
 61
 62       from pyxu.util import view_as_real, view_as_complex
 63       x = np.array([[0., 1],
 64                     [2 , 3],
 65                     [4 , 5]])
 66       y = view_as_complex(x)  # array([0.+1.j, 2.+3.j, 4.+5.j])
 67       view_as_real(y) == x    # True
 68
 69    Notes
 70    -----
 71    Complex-valued inputs are returned unchanged.
 72
 73    See Also
 74    --------
 75    :py:func:`~pyxu.util.as_real_op`,
 76    :py:func:`~pyxu.util.view_as_real`
 77    """
 78    assert x.ndim >= 2
 79    if _is_complex(x):
 80        return x
 81
 82    try:
 83        r_dtype = x.dtype
 84        r_width = pxrt.Width(r_dtype)
 85        c_dtype = r_width.complex.value
 86    except Exception:
 87        raise ValueError(f"Unsupported dtype {r_dtype}.")
 88    assert x.shape[-1] == 2, "Last array dimension should contain real/imaginary terms only."
 89
 90    y = x.view(c_dtype)  # (..., N, 1)
 91    y = y[..., 0]  # (..., N)
 92    return y 
 93
 94
[docs]
 95def view_as_real(x: pxt.NDArray) -> pxt.NDArray:
 96    r"""
 97    View complex-valued array as its real-valued bijection.  (Inverse of :py:func:`~pyxu.util.view_as_complex`.)
 98
 99    Parameters
100    ----------
101    x: NDArray
102        (..., N) complex-valued array.
103
104    Returns
105    -------
106    y: NDArray
107        (..., N, 2) real-valued array.
108
109    Examples
110    --------
111
112    .. code-block:: python3
113
114       from pyxu.util import view_as_real, view_as_complex
115       x = np.r_[0+1j, 2+3j, 4+5j]
116       y = view_as_real(x)               # array([[0., 1.],
117                                         #        [2., 3.],
118                                         #        [4., 5.]])
119       view_as_complex(y) == x           # True
120
121    Notes
122    -----
123    Real-valued inputs are returned unchanged.
124
125    See Also
126    --------
127    :py:func:`~pyxu.util.as_real_op`,
128    :py:func:`~pyxu.util.view_as_complex`
129    """
130    assert x.ndim >= 1
131    if _is_real(x):
132        return x
133
134    try:
135        c_dtype = x.dtype
136        c_width = pxrt.CWidth(c_dtype)
137        r_dtype = c_width.real.value
138    except Exception:
139        raise ValueError(f"Unsupported dtype {c_dtype}.")
140
141    y = x.view(r_dtype)  # (..., 2N)
142
143    ndi = pxd.NDArrayInfo.from_obj(x)
144    if ndi == pxd.NDArrayInfo.DASK:
145        y = y.map_blocks(  # (..., N, 2)
146            lambda blk: blk.reshape(
147                *blk.shape[:-1],
148                blk.shape[-1] // 2,
149                2,
150            ),
151            chunks=(*x.chunks, 2),
152            new_axis=x.ndim,
153            meta=y._meta,
154        )
155    else:
156        y = y.reshape(*x.shape, 2)  # (..., N, 2)
157    return y 
158
159
[docs]
def as_real_op(A: pxt.NDArray, dim_rank: pxt.Integer = None) -> pxt.NDArray:
    r"""
    View complex-valued linear operator as its real-valued equivalent.

    Useful to transform complex-valued matrix/vector products to their real-valued counterparts.

    Parameters
    ----------
    A: NDArray
        (N1,...,NK, M1,...,MD) complex-valued matrix.
    dim_rank: Integer
        Dimension rank :math:`D`. (Can be omitted if `A` is 2D, in which case it is ignored and :math:`D=1` is used.)

    Returns
    -------
    A_r: NDArray
        (N1,...,NK,2, M1,...,MD,2) real-valued equivalent.

    Raises
    ------
    ValueError
        If the dtype of `A` has no real-valued counterpart in ``pyxu.runtime.CWidth``.

    Examples
    --------

    .. code-block:: python3

       import numpy as np
       import pyxu.util.complex as cpl

       codim_shape = (1,2,3)
       dim_shape = (4,5,6,7)
       dim_rank = len(dim_shape)

       rng = np.random.default_rng(0)
       A =      rng.standard_normal((*codim_shape, *dim_shape)) \
         + 1j * rng.standard_normal((*codim_shape, *dim_shape))    # (1,2,3  |4,5,6,7  )
       A_r = cpl.as_real_op(A, dim_rank=dim_rank)                  # (1,2,3,2|4,5,6,7,2)

       x =      rng.standard_normal(dim_shape) \
         + 1j * rng.standard_normal(dim_shape)                     # (4,5,6,7  )
       x_r = cpl.view_as_real(x)                                   # (4,5,6,7,2)

       y = np.tensordot(A, x, axes=dim_rank)                       # (1,2,3  )
       y_r = np.tensordot(A_r, x_r, axes=dim_rank+1)               # (1,2,3,2)

       np.allclose(y, cpl.view_as_complex(y_r))                    # True


    Notes
    -----
    * Real-valued matrices are returned unchanged.
    * For non-2D `A`, `dim_rank` is mandatory and must satisfy ``1 <= dim_rank < A.ndim`` (checked with ``assert``).

    See Also
    --------
    :py:func:`~pyxu.util.view_as_real`,
    :py:func:`~pyxu.util.view_as_complex`
    """
    if _is_real(A):
        return A

    # Read the dtype outside the `try`: only the CWidth lookup should be reported as "unsupported dtype".
    c_dtype = A.dtype
    try:
        r_dtype = pxrt.CWidth(c_dtype).real.value
    except Exception as exc:
        raise ValueError(f"Unsupported dtype {c_dtype}.") from exc

    if A.ndim == 2:
        dim_rank = 1  # doesn't matter what the user specified.
    else:  # A.ndim != 2
        # `dim_rank` is mandatory and must leave at least one co-dimension axis.
        assert dim_rank is not None, "Dimension rank must be specified for ND operators."
        assert 1 <= dim_rank < A.ndim, f"dim_rank must lie in [1, {A.ndim}), got {dim_rank}."
    dim_shape = A.shape[-dim_rank:]
    codim_shape = A.shape[:-dim_rank]
    codim_rank = len(codim_shape)

    xp = pxd.NDArrayInfo.from_obj(A).module()
    A_r = xp.zeros((*codim_shape, 2, *dim_shape, 2), dtype=r_dtype)

    # Indices are built as explicit tuples: the `A_r[*a, *b]` subscript form only parses on Python >= 3.11.
    codim_all = (slice(None),) * codim_rank
    dim_all = (slice(None),) * dim_rank
    A_re, A_im = A.real, A.imag  # computed once, reused below

    # Block layout [codim part, dim part], with part 0 = real, 1 = imaginary:
    #   y_re = A_re x_re - A_im x_im
    #   y_im = A_im x_re + A_re x_im
    A_r[(*codim_all, 0, *dim_all, 0)] = A_re
    A_r[(*codim_all, 1, *dim_all, 1)] = A_re
    A_r[(*codim_all, 0, *dim_all, 1)] = -A_im
    A_r[(*codim_all, 1, *dim_all, 0)] = A_im
    return A_r
253
254
def _is_real(x: pxt.NDArray) -> bool:
    """
    Whether `x.dtype` is recognized by ``pxrt.Width`` (i.e. a supported real-valued dtype).

    Any failure during the lookup (unsupported dtype, missing ``dtype`` attribute, ...) yields False.
    """
    try:
        verdict = bool(pxrt.Width(x.dtype))
    except Exception:
        verdict = False
    return verdict
260
261
def _is_complex(x: pxt.NDArray) -> bool:
    """
    Whether `x.dtype` is recognized by ``pxrt.CWidth`` (i.e. a supported complex-valued dtype).

    Any failure during the lookup (unsupported dtype, missing ``dtype`` attribute, ...) yields False.
    """
    try:
        verdict = bool(pxrt.CWidth(x.dtype))
    except Exception:
        verdict = False
    return verdict