Source code for pyxu.info.deps
1import collections.abc as cabc
2import enum
3import importlib.util
4import types
5
6import dask.array
7import numpy
8import packaging.version as pkgv
9import scipy.sparse
10
#: True if CuPy-based backends are available (CuPy importable and its CUDA runtime usable).
CUPY_ENABLED: bool = importlib.util.find_spec("cupy") is not None
if CUPY_ENABLED:
    try:
        import cupy
        import cupyx.scipy.sparse
        import cupyx.scipy.sparse.linalg

        # `cupy.is_available()` signals a missing CUDA device through its boolean return value
        # (it does not necessarily raise), so the result must be honored rather than discarded.
        CUPY_ENABLED = bool(cupy.is_available())
    except Exception:
        # Best-effort probe: any failure while importing/initializing CuPy
        # (missing hardware, drivers or runtime) disables the CuPy backends.
        CUPY_ENABLED = False
22
23
[docs]
@enum.unique
class NDArrayInfo(enum.Enum):
    """
    Supported dense array backends.

    CuPy-related queries (:py:meth:`type`, :py:meth:`module`) return placeholders instead of
    raising when ``CUPY_ENABLED`` is False.
    """

    NUMPY = enum.auto()
    DASK = enum.auto()
    CUPY = enum.auto()

    @classmethod
    def default(cls) -> "NDArrayInfo":
        """Default array backend to use (NumPy)."""
        return cls.NUMPY

    def type(self) -> type:
        """
        Array type associated to a backend.

        For ``CUPY`` this is ``cupy.ndarray`` when CuPy is usable, and ``type(None)`` otherwise,
        so that ``isinstance()`` checks against it can only ever match ``None``.
        """
        if self is NDArrayInfo.NUMPY:
            return numpy.ndarray
        if self is NDArrayInfo.DASK:
            return dask.array.core.Array
        if self is NDArrayInfo.CUPY:
            if not CUPY_ENABLED:
                return type(None)
            return cupy.ndarray
        raise ValueError(f"No known array type for {self.name}.")

    @classmethod
    def from_obj(cls, obj) -> "NDArrayInfo":
        """
        Find array backend associated to `obj`.

        Backends are tried in declaration order; the first one whose :py:meth:`type` matches wins.

        Raises
        ------
        ValueError
            If `obj` is None or matches no backend.
        """
        found = None
        if obj is not None:
            found = next((ndi for ndi in cls if isinstance(obj, ndi.type())), None)
        if found is None:
            raise ValueError(f"No known array type to match {obj}.")
        return found

    @classmethod
    def from_flag(cls, gpu: bool) -> "NDArrayInfo":
        """
        Find array backend suitable for in-memory CPU/GPU computing.

        A truthy `gpu` selects ``CUPY``, otherwise ``NUMPY``; CuPy availability is not checked here.
        """
        return cls.CUPY if gpu else cls.NUMPY

    def module(self, linalg: bool = False) -> types.ModuleType:
        """
        Python module associated to an array backend.

        Parameters
        ----------
        linalg: bool
            Return the linear-algebra submodule with identical API to :py:mod:`numpy.linalg`.

        Returns
        -------
        module
            The backend module (or its ``linalg`` submodule); ``None`` for ``CUPY`` when CuPy is
            unavailable.
        """
        if self is NDArrayInfo.NUMPY:
            base, sub = numpy, numpy.linalg
        elif self is NDArrayInfo.DASK:
            base, sub = dask.array, dask.array.linalg
        elif self is NDArrayInfo.CUPY:
            base = cupy if CUPY_ENABLED else None
            sub = None if base is None else base.linalg
        else:
            raise ValueError(f"No known module(s) for {self.name}.")
        return sub if linalg else base
88
89
[docs]
@enum.unique
class SparseArrayInfo(enum.Enum):
    """
    Supported sparse array backends.

    CuPy-related queries (:py:meth:`type`, :py:meth:`module`) return placeholders instead of
    raising when ``CUPY_ENABLED`` is False.
    """

    SCIPY_SPARSE = enum.auto()
    CUPY_SPARSE = enum.auto()

    @classmethod
    def default(cls) -> "SparseArrayInfo":
        """Default sparse array backend to use (SciPy)."""
        return cls.SCIPY_SPARSE

    def type(self) -> type:
        """
        Array type associated to a backend.

        For ``CUPY_SPARSE`` this is ``cupyx.scipy.sparse.spmatrix`` when CuPy is usable, and
        ``type(None)`` otherwise, so ``isinstance()`` checks against it can only match ``None``.
        """
        if self.name == "SCIPY_SPARSE":
            # All `*matrix` classes descend from `spmatrix`.
            # NOTE(review): newer SciPy `*_array` sparse classes may not derive from `spmatrix`;
            # confirm whether they should be recognized as well.
            return scipy.sparse.spmatrix
        elif self.name == "CUPY_SPARSE":
            return cupyx.scipy.sparse.spmatrix if CUPY_ENABLED else type(None)
        else:
            raise ValueError(f"No known array type for {self.name}.")

    @classmethod
    def from_obj(cls, obj) -> "SparseArrayInfo":
        """
        Find sparse array backend associated to `obj`.

        Backends are tried in declaration order; the first one whose :py:meth:`type` matches wins.

        Raises
        ------
        ValueError
            If `obj` is None or matches no backend.
        """
        if obj is not None:
            for sai in cls:
                if isinstance(obj, sai.type()):
                    return sai
        # Report `obj` itself: the loop variable is unbound when `obj is None`, and otherwise
        # would name an unrelated backend.
        raise ValueError(f"No known array type to match {obj}.")

    def module(self, linalg: bool = False) -> types.ModuleType:
        """
        Python module associated to an array backend.

        Parameters
        ----------
        linalg: bool
            Return the linear-algebra submodule with identical API to :py:mod:`scipy.sparse.linalg`.

        Returns
        -------
        module
            The backend module (or its ``linalg`` submodule); ``None`` for ``CUPY_SPARSE`` when
            CuPy is unavailable.
        """
        if self.name == "SCIPY_SPARSE":
            xp = scipy.sparse
            xpl = xp.linalg
        elif self.name == "CUPY_SPARSE":
            xp = cupyx.scipy.sparse if CUPY_ENABLED else None
            xpl = xp if (xp is None) else cupyx.scipy.sparse.linalg
        else:
            raise ValueError(f"No known array module for {self.name}.")
        return xpl if linalg else xp
141
142
[docs]
def supported_array_types() -> cabc.Collection[type]:
    """
    List of all supported dense array types in current Pyxu install.

    The CuPy backend is skipped unless ``CUPY_ENABLED`` is True; the result is deduplicated.
    """
    usable = {
        ndi.type()
        for ndi in NDArrayInfo
        if CUPY_ENABLED or ndi != NDArrayInfo.CUPY
    }
    return tuple(usable)
150
151
[docs]
def supported_array_modules() -> cabc.Collection[types.ModuleType]:
    """
    List of all supported dense array modules in current Pyxu install.

    The CuPy backend is skipped unless ``CUPY_ENABLED`` is True; the result is deduplicated.
    """
    usable = {
        ndi.module()
        for ndi in NDArrayInfo
        if CUPY_ENABLED or ndi != NDArrayInfo.CUPY
    }
    return tuple(usable)
159
160
[docs]
def supported_sparse_types() -> cabc.Collection[type]:
    """
    List of all supported sparse array types in current Pyxu install.

    The CuPy sparse backend is skipped unless ``CUPY_ENABLED`` is True; the result is deduplicated.
    """
    usable = {
        sai.type()
        for sai in SparseArrayInfo
        if CUPY_ENABLED or sai != SparseArrayInfo.CUPY_SPARSE
    }
    return tuple(usable)
168
169
[docs]
def supported_sparse_modules() -> cabc.Collection[types.ModuleType]:
    """
    List of all supported sparse array modules in current Pyxu install.

    The CuPy sparse backend is skipped unless ``CUPY_ENABLED`` is True; the result is deduplicated.
    """
    usable = {
        sai.module()
        for sai in SparseArrayInfo
        if CUPY_ENABLED or sai != SparseArrayInfo.CUPY_SPARSE
    }
    return tuple(usable)
177
178
# Supported version ranges for JAX and PyTorch, keyed by "min" / "max".
# NOTE(review): whether each bound is inclusive is decided by the consumers of these dicts
# (not visible here) — confirm before relying on it.
JAX_SUPPORT = {
    "min": pkgv.Version("0.4.8"),
    "max": pkgv.Version("1.0"),
}
PYTORCH_SUPPORT = {
    "min": pkgv.Version("2.0"),
    "max": pkgv.Version("3.0"),
}
187
# Public API of this module.
__all__ = [
    # CuPy availability flag and backend enumerations.
    "CUPY_ENABLED",
    "NDArrayInfo",
    "SparseArrayInfo",
    # Backends usable in the current install.
    "supported_array_types",
    "supported_array_modules",
    "supported_sparse_types",
    "supported_sparse_modules",
    # Supported JAX / PyTorch version ranges.
    "JAX_SUPPORT",
    "PYTORCH_SUPPORT",
]