Source code for pyxu.util.complex

  1import pyxu.info.deps as pxd
  2import pyxu.info.ptype as pxt
  3import pyxu.runtime as pxrt
  4
  5__all__ = [
  6    "as_real_op",
  7    "require_viewable",
  8    "view_as_real",
  9    "view_as_complex",
 10]
 11
 12
def require_viewable(x: pxt.NDArray) -> pxt.NDArray:
    """
    Copy array if required to do real/complex view manipulations.

    Real/complex view manipulations are feasible if the last axis is contiguous.

    Parameters
    ----------
    x: NDArray

    Returns
    -------
    y: NDArray
        `x` itself if it can already be viewed, otherwise a C-ordered copy of `x`.

    Raises
    ------
    NotImplementedError
        If the array backend of `x` is not Dask, NumPy or CuPy.
    """
    N = pxd.NDArrayInfo
    backend = N.from_obj(x)

    if backend == N.DASK:
        # Dask graphs have no notion of contiguity -> always safe to hand back as-is.
        return x

    if backend not in (N.NUMPY, N.CUPY):
        raise NotImplementedError(f"require_viewable() not yet defined for {backend}.")

    # NumPy/CuPy: last axis is contiguous when consecutive items are exactly one itemsize apart.
    last_axis_contiguous = x.strides[-1] == x.dtype.itemsize
    return x if last_axis_contiguous else x.copy(order="C")
41 42
def view_as_complex(x: pxt.NDArray) -> pxt.NDArray:
    r"""
    View real-valued array as its complex-valued bijection. (Inverse of :py:func:`~pyxu.util.view_as_real`.)

    Parameters
    ----------
    x: NDArray
        (..., N, 2) real-valued array. ``x[..., 0]`` / ``x[..., 1]`` hold the real / imaginary parts.

    Returns
    -------
    y: NDArray
        (..., N) complex-valued array.

    Raises
    ------
    ValueError
        If the dtype of `x` has no complex-valued counterpart.

    Examples
    --------

    .. code-block:: python3

       import numpy as np
       from pyxu.util import view_as_real, view_as_complex

       x = np.array([[0., 1],
                     [2 , 3],
                     [4 , 5]])
       y = view_as_complex(x)              # array([0.+1.j, 2.+3.j, 4.+5.j])
       np.array_equal(view_as_real(y), x)  # True

    Notes
    -----
    * Complex-valued inputs are returned unchanged.
    * The output is obtained through ``x.view()``, which requires the last axis of `x` to be contiguous.
      Use :py:func:`~pyxu.util.complex.require_viewable` beforehand if this is not guaranteed.

    See Also
    --------
    :py:func:`~pyxu.util.as_real_op`,
    :py:func:`~pyxu.util.view_as_real`
    """
    assert x.ndim >= 2
    if _is_complex(x):
        return x

    # Read the dtype outside the `try`: the error message below needs it to be bound.
    r_dtype = x.dtype
    try:
        c_dtype = pxrt.Width(r_dtype).complex.value
    except Exception as err:
        raise ValueError(f"Unsupported dtype {r_dtype}.") from err
    assert x.shape[-1] == 2, "Last array dimension should contain real/imaginary terms only."

    # Re-interpret each (real, imag) pair of the last axis as one complex item.
    y = x.view(c_dtype)  # (..., N, 1)
    y = y[..., 0]  # (..., N)
    return y
93 94
def view_as_real(x: pxt.NDArray) -> pxt.NDArray:
    r"""
    View complex-valued array as its real-valued bijection. (Inverse of :py:func:`~pyxu.util.view_as_complex`.)

    Parameters
    ----------
    x: NDArray
        (..., N) complex-valued array.

    Returns
    -------
    y: NDArray
        (..., N, 2) real-valued array. ``y[..., 0]`` / ``y[..., 1]`` hold the real / imaginary parts.

    Raises
    ------
    ValueError
        If the dtype of `x` has no real-valued counterpart.

    Examples
    --------

    .. code-block:: python3

       import numpy as np
       from pyxu.util import view_as_real, view_as_complex

       x = np.r_[0+1j, 2+3j, 4+5j]
       y = view_as_real(x)                    # array([[0., 1.],
                                              #        [2., 3.],
                                              #        [4., 5.]])
       np.array_equal(view_as_complex(y), x)  # True

    Notes
    -----
    * Real-valued inputs are returned unchanged.
    * The output is obtained through ``x.view()``, which requires the last axis of `x` to be contiguous.
      Use :py:func:`~pyxu.util.complex.require_viewable` beforehand if this is not guaranteed.

    See Also
    --------
    :py:func:`~pyxu.util.as_real_op`,
    :py:func:`~pyxu.util.view_as_complex`
    """
    assert x.ndim >= 1
    if _is_real(x):
        return x

    # Read the dtype outside the `try`: the error message below needs it to be bound.
    c_dtype = x.dtype
    try:
        r_dtype = pxrt.CWidth(c_dtype).real.value
    except Exception as err:
        raise ValueError(f"Unsupported dtype {c_dtype}.") from err

    y = x.view(r_dtype)  # (..., 2N)

    ndi = pxd.NDArrayInfo.from_obj(x)
    if ndi == pxd.NDArrayInfo.DASK:
        # Split the (doubled) last axis block-by-block: (..., 2n) -> (..., n, 2).
        # The output chunking / new trailing axis must be declared explicitly to map_blocks().
        y = y.map_blocks(  # (..., N, 2)
            lambda blk: blk.reshape(
                *blk.shape[:-1],
                blk.shape[-1] // 2,
                2,
            ),
            chunks=(*x.chunks, 2),
            new_axis=x.ndim,
            meta=y._meta,
        )
    else:
        y = y.reshape(*x.shape, 2)  # (..., N, 2)
    return y
158 159
def as_real_op(A: pxt.NDArray, dim_rank: pxt.Integer = None) -> pxt.NDArray:
    r"""
    View complex-valued linear operator as its real-valued equivalent.

    Useful to transform complex-valued matrix/vector products to their real-valued counterparts.

    Parameters
    ----------
    A: NDArray
        (N1,...,NK, M1,...,MD) complex-valued matrix.
    dim_rank: Integer
        Dimension rank :math:`D`. (Can be omitted if `A` is 2D, in which case any supplied value is ignored.)

    Returns
    -------
    A_r: NDArray
        (N1,...,NK,2, M1,...,MD,2) real-valued equivalent.
        Indexing the two inserted size-2 axes as ``(c, d)``, the blocks are:
        ``(0, 0) = A.real``, ``(0, 1) = -A.imag``, ``(1, 0) = A.imag``, ``(1, 1) = A.real``.

    Raises
    ------
    ValueError
        If the dtype of `A` has no real-valued counterpart.

    Examples
    --------

    .. code-block:: python3

       import numpy as np
       import pyxu.util.complex as cpl

       codim_shape = (1,2,3)
       dim_shape = (4,5,6,7)
       dim_rank = len(dim_shape)

       rng = np.random.default_rng(0)
       A =      rng.standard_normal((*codim_shape, *dim_shape)) \
         + 1j * rng.standard_normal((*codim_shape, *dim_shape))    # (1,2,3  |4,5,6,7  )
       A_r = cpl.as_real_op(A, dim_rank=dim_rank)                  # (1,2,3,2|4,5,6,7,2)

       x =      rng.standard_normal(dim_shape) \
         + 1j * rng.standard_normal(dim_shape)                     # (4,5,6,7  )
       x_r = cpl.view_as_real(x)                                   # (4,5,6,7,2)

       y = np.tensordot(A, x, axes=dim_rank)                       # (1,2,3  )
       y_r = np.tensordot(A_r, x_r, axes=dim_rank+1)               # (1,2,3,2)

       np.allclose(y, cpl.view_as_complex(y_r))                    # True

    Notes
    -----
    Real-valued matrices are returned unchanged.

    See Also
    --------
    :py:func:`~pyxu.util.view_as_real`,
    :py:func:`~pyxu.util.view_as_complex`
    """
    if _is_real(A):
        return A

    # Read the dtype outside the `try`: the error message below needs it to be bound.
    c_dtype = A.dtype
    try:
        r_dtype = pxrt.CWidth(c_dtype).real.value
    except Exception as err:
        raise ValueError(f"Unsupported dtype {c_dtype}.") from err

    if A.ndim == 2:
        dim_rank = 1  # doesn't matter what the user specified.
    else:
        # Any other rank (including 0-D/1-D inputs): dim_rank is mandatory & must satisfy 1 <= dim_rank < A.ndim.
        assert dim_rank is not None, "Dimension rank must be specified for ND operators."
        assert 1 <= dim_rank < A.ndim
    dim_shape = A.shape[-dim_rank:]
    codim_shape = A.shape[:-dim_rank]
    codim_rank = len(codim_shape)

    xp = pxd.NDArrayInfo.from_obj(A).module()
    A_r = xp.zeros((*codim_shape, 2, *dim_shape, 2), dtype=r_dtype)

    # Real-valued block form of A = A_re + j A_im:
    #   [ A_re  -A_im ]
    #   [ A_im   A_re ]
    A_re, A_im = A.real, A.imag
    blocks = (
        (0, 0, A_re),
        (0, 1, -A_im),
        (1, 0, A_im),
        (1, 1, A_re),
    )
    codim_full = (slice(None),) * codim_rank
    dim_full = (slice(None),) * dim_rank
    for c, d, blk in blocks:
        # Explicit tuple index: same as A_r[*codim_full, c, *dim_full, d], but valid on any Python 3 version.
        A_r[(*codim_full, c, *dim_full, d)] = blk
    return A_r
def _dtype_has_width(width_cls, x: pxt.NDArray) -> bool:
    # True iff `width_cls(x.dtype)` succeeds and is truthy; any failure (no dtype, unknown dtype, ...) -> False.
    try:
        return bool(width_cls(x.dtype))
    except Exception:
        return False


def _is_real(x: pxt.NDArray) -> bool:
    # Real-valued <=> dtype is recognized by pxrt.Width.
    return _dtype_has_width(pxrt.Width, x)


def _is_complex(x: pxt.NDArray) -> bool:
    # Complex-valued <=> dtype is recognized by pxrt.CWidth.
    return _dtype_has_width(pxrt.CWidth, x)