Source code for pyxu.info.deps

  1import collections.abc as cabc
  2import enum
  3import importlib.util
  4import types
  5
  6import dask.array
  7import numpy
  8import packaging.version as pkgv
  9import scipy.sparse
 10
 11#: Show if CuPy-based backends are available.
 12CUPY_ENABLED: bool = importlib.util.find_spec("cupy") is not None
 13if CUPY_ENABLED:
 14    try:
 15        import cupy
 16        import cupyx.scipy.sparse
 17        import cupyx.scipy.sparse.linalg
 18
 19        cupy.is_available()  # will fail if hardware/drivers/runtime missing
 20    except Exception:
 21        CUPY_ENABLED = False
 22
 23
@enum.unique
class NDArrayInfo(enum.Enum):
    """
    Dense array backends understood by Pyxu.

    Every member maps to an array class (see :py:meth:`type`) and to the module exposing its NumPy-like API (see
    :py:meth:`module`).
    """

    NUMPY = enum.auto()
    DASK = enum.auto()
    CUPY = enum.auto()

    @classmethod
    def default(cls) -> "NDArrayInfo":
        """Backend used when none is specified (:py:attr:`NUMPY`)."""
        return cls.NUMPY

    def type(self) -> type:
        """
        Array class implemented by this backend.

        For :py:attr:`CUPY` without a working CuPy install, ``NoneType`` is returned so that ``isinstance`` checks
        against it only ever match ``None``.
        """
        kind = self.name
        if kind == "CUPY":
            return cupy.ndarray if CUPY_ENABLED else type(None)
        if kind == "DASK":
            return dask.array.core.Array
        if kind == "NUMPY":
            return numpy.ndarray
        raise ValueError(f"No known array type for {self.name}.")

    @classmethod
    def from_obj(cls, obj) -> "NDArrayInfo":
        """
        Identify the backend whose array class `obj` is an instance of.

        Raises
        ------
        ValueError
            If `obj` is ``None`` or is not an instance of any backend's array class.
        """
        hit = None
        if obj is not None:
            # Members are probed in definition order; the first match wins.
            hit = next((ndi for ndi in cls if isinstance(obj, ndi.type())), None)
        if hit is None:
            raise ValueError(f"No known array type to match {obj}.")
        return hit

    @classmethod
    def from_flag(cls, gpu: bool) -> "NDArrayInfo":
        """In-memory backend for GPU (:py:attr:`CUPY`) or CPU (:py:attr:`NUMPY`) computing, selected by `gpu`."""
        return cls.CUPY if gpu else cls.NUMPY

    def module(self, linalg: bool = False) -> types.ModuleType:
        """
        Python module implementing this backend's array API.

        Parameters
        ----------
        linalg: bool
            Return the linear-algebra submodule with identical API to :py:mod:`numpy.linalg`.

        Returns
        -------
        xp: module
            The backend module, its ``linalg`` submodule, or ``None`` for :py:attr:`CUPY` when CuPy is unavailable.
        """
        kind = self.name
        if kind == "NUMPY":
            xp, xpl = numpy, numpy.linalg
        elif kind == "DASK":
            xp, xpl = dask.array, dask.array.linalg
        elif kind == "CUPY":
            xp = cupy if CUPY_ENABLED else None
            xpl = None if xp is None else xp.linalg
        else:
            raise ValueError(f"No known module(s) for {self.name}.")
        return xpl if linalg else xp
@enum.unique
class SparseArrayInfo(enum.Enum):
    """
    Supported sparse array backends.
    """

    SCIPY_SPARSE = enum.auto()
    CUPY_SPARSE = enum.auto()

    @classmethod
    def default(cls) -> "SparseArrayInfo":
        """Default array backend to use (:py:attr:`SCIPY_SPARSE`)."""
        return cls.SCIPY_SPARSE

    def type(self) -> type:
        """
        Array type associated to a backend.

        For :py:attr:`CUPY_SPARSE` without a working CuPy install, ``NoneType`` is returned so that ``isinstance``
        checks against it only ever match ``None``.
        """
        if self.name == "SCIPY_SPARSE":
            # All `*matrix` classes descend from `spmatrix`.
            # NOTE(review): in recent SciPy releases the `*_array` (sparray) classes are not `spmatrix` subclasses and
            # therefore do not match here -- confirm whether they should be supported.
            return scipy.sparse.spmatrix
        elif self.name == "CUPY_SPARSE":
            return cupyx.scipy.sparse.spmatrix if CUPY_ENABLED else type(None)
        else:
            raise ValueError(f"No known array type for {self.name}.")

    @classmethod
    def from_obj(cls, obj) -> "SparseArrayInfo":
        """
        Find array backend associated to `obj`.

        Raises
        ------
        ValueError
            If `obj` is ``None`` or is not an instance of any backend's :py:meth:`type`.
        """
        if obj is not None:
            for sai in cls:
                if isinstance(obj, sai.type()):
                    return sai
        # Report `obj`, not the loop variable: when `obj` is None the loop never runs and `sai` would be unbound.
        raise ValueError(f"No known array type to match {obj}.")

    def module(self, linalg: bool = False) -> types.ModuleType:
        """
        Python module associated to an array backend.

        Parameters
        ----------
        linalg: bool
            Return the linear-algebra submodule with identical API to :py:mod:`scipy.sparse.linalg`.

        Returns
        -------
        xp: module
            The backend module, its ``linalg`` submodule, or ``None`` for :py:attr:`CUPY_SPARSE` when CuPy is
            unavailable.
        """
        if self.name == "SCIPY_SPARSE":
            xp = scipy.sparse
            xpl = xp.linalg
        elif self.name == "CUPY_SPARSE":
            xp = cupyx.scipy.sparse if CUPY_ENABLED else None
            xpl = xp if (xp is None) else cupyx.scipy.sparse.linalg
        else:
            raise ValueError(f"No known array module for {self.name}.")
        return xpl if linalg else xp
def supported_array_types() -> cabc.Collection[type]:
    """
    List of all supported dense array types in current Pyxu install.

    Returns
    -------
    types: tuple[type, ...]
        De-duplicated array types of every available :py:class:`NDArrayInfo` backend, in enum-definition order.
        The CuPy type is only included when :py:data:`CUPY_ENABLED` is true.
    """
    # `dict.fromkeys` de-duplicates like a set but preserves insertion order, so the result is identical across
    # interpreter runs (set ordering of type objects depends on their memory addresses).
    data = dict.fromkeys(
        ndi.type()
        for ndi in NDArrayInfo
        if (ndi != NDArrayInfo.CUPY) or CUPY_ENABLED
    )
    return tuple(data)
def supported_array_modules() -> cabc.Collection[types.ModuleType]:
    """
    List of all supported dense array modules in current Pyxu install.

    Returns
    -------
    modules: tuple[module, ...]
        De-duplicated array modules of every available :py:class:`NDArrayInfo` backend, in enum-definition order.
        The CuPy module is only included when :py:data:`CUPY_ENABLED` is true.
    """
    # `dict.fromkeys` de-duplicates like a set but preserves insertion order, so the result is identical across
    # interpreter runs (set ordering of module objects depends on their memory addresses).
    data = dict.fromkeys(
        ndi.module()
        for ndi in NDArrayInfo
        if (ndi != NDArrayInfo.CUPY) or CUPY_ENABLED
    )
    return tuple(data)
def supported_sparse_types() -> cabc.Collection[type]:
    """
    List of all supported sparse array types in current Pyxu install.

    Returns
    -------
    types: tuple[type, ...]
        De-duplicated array types of every available :py:class:`SparseArrayInfo` backend, in enum-definition order.
        The CuPy sparse type is only included when :py:data:`CUPY_ENABLED` is true.
    """
    # `dict.fromkeys` de-duplicates like a set but preserves insertion order, so the result is identical across
    # interpreter runs (set ordering of type objects depends on their memory addresses).
    data = dict.fromkeys(
        sai.type()
        for sai in SparseArrayInfo
        if (sai != SparseArrayInfo.CUPY_SPARSE) or CUPY_ENABLED
    )
    return tuple(data)
def supported_sparse_modules() -> cabc.Collection[types.ModuleType]:
    """
    List of all supported sparse array modules in current Pyxu install.

    Returns
    -------
    modules: tuple[module, ...]
        De-duplicated array modules of every available :py:class:`SparseArrayInfo` backend, in enum-definition order.
        The CuPy sparse module is only included when :py:data:`CUPY_ENABLED` is true.
    """
    # `dict.fromkeys` de-duplicates like a set but preserves insertion order, so the result is identical across
    # interpreter runs (set ordering of module objects depends on their memory addresses).
    data = dict.fromkeys(
        sai.module()
        for sai in SparseArrayInfo
        if (sai != SparseArrayInfo.CUPY_SPARSE) or CUPY_ENABLED
    )
    return tuple(data)
# Version bounds of optional interoperability dependencies, expressed as `packaging` versions.
# NOTE(review): whether `max` is an inclusive or exclusive bound is decided by the consumers of these dicts -- confirm.
JAX_SUPPORT = {
    "min": pkgv.Version("0.4.8"),
    "max": pkgv.Version("1.0"),
}
PYTORCH_SUPPORT = {
    "min": pkgv.Version("2.0"),
    "max": pkgv.Version("3.0"),
}

# Public API of this module.
__all__ = [
    "CUPY_ENABLED",
    "NDArrayInfo",
    "SparseArrayInfo",
    "supported_array_types",
    "supported_array_modules",
    "supported_sparse_types",
    "supported_sparse_modules",
    "JAX_SUPPORT",
    "PYTORCH_SUPPORT",
]