import typing as typ
import warnings
from functools import wraps

import numpy as np
import packaging.version as pkgv

import pyxu.abc as pxa
import pyxu.info.deps as pxd
import pyxu.info.ptype as pxt
import pyxu.info.warning as pxw
import pyxu.operator.interop.source as px_src
import pyxu.runtime as pxrt
from pyxu.util import import_module

torch = import_module("torch", fail_on_error=False)
if torch is not None:
    version = pkgv.Version(torch.__version__)
    supported = pxd.PYTORCH_SUPPORT
    assert supported["min"] <= version < supported["max"]

    import torch._dynamo as dynamo
    import torch.func as functorch

    TorchTensor = torch.Tensor
else:
    TorchTensor = typ.TypeVar("torch.Tensor")

__all__ = [
    "from_torch",
]


def _traceable(f):
    # Needed to compile functorch transforms. See this issue: https://github.com/pytorch/pytorch/issues/98822
    f = dynamo.allow_in_graph(f)

    @wraps(f)
    def wrapper(*args, **kwargs):
        return f(*args, **kwargs)

    return wrapper


def _to_torch(arr: pxt.NDArray, requires_grad: bool = False) -> TorchTensor:
    r"""
    NumPy/CuPy -> PyTorch conversion.

    Convert a NumPy-like array into a PyTorch tensor, sharing data, dtype and device.

    Parameters
    ----------
    arr: NDArray
        Input array.
    requires_grad: bool
        If autograd should record operations on the returned tensor.

    Returns
    -------
    tensor: torch.Tensor
        Output tensor.

    Notes
    -----
    The returned tensor and input array share the same memory. Modifications to the tensor will be reflected in the
    ndarray and vice versa. The returned tensor is not resizable.
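
    Examples
    --------
    A minimal sketch of the zero-copy behaviour (assuming a NumPy-backed input array):

    .. code-block:: python3

       import numpy as np

       arr = np.arange(5.0)
       tsr = _to_torch(arr)  # shares memory with `arr`
       tsr[0] = -1.0
       assert arr[0] == -1.0  # the update is visible from the NumPy side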
    """
    if pxd.NDArrayInfo.from_obj(arr) == pxd.NDArrayInfo.CUPY:
        with torch.device("cuda", arr.device.id):
            return torch.as_tensor(arr).requires_grad_(requires_grad)
    else:
        return torch.from_numpy(arr).requires_grad_(requires_grad)


def _from_torch(tensor: TorchTensor) -> pxt.NDArray:
    r"""
    PyTorch -> NumPy/CuPy conversion.

    Convert a PyTorch tensor into a NumPy-like array, sharing data, dtype and device.

    Parameters
    ----------
    tensor: torch.Tensor
        Input tensor.

    Returns
    -------
    arr: NDArray
        Output array.

    Notes
    -----
    The returned array and input tensor share the same memory. Modifications to the tensor will be reflected in the
    ndarray and vice versa.
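
    Examples
    --------
    A minimal sketch of the reverse conversion (assuming a CPU-resident tensor):

    .. code-block:: python3

       import torch

       tsr = torch.arange(5.0)
       arr = _from_torch(tsr)  # NumPy view on the tensor's memory
       arr[0] = -1.0
       assert tsr[0] == -1.0  # the update is visible from the PyTorch side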
    """
    if tensor.get_device() == -1:
        return tensor.detach().numpy(force=False)
    else:
        cp = pxd.NDArrayInfo.CUPY.module()
        with cp.cuda.Device(tensor.get_device()):
            return cp.asarray(tensor.detach())


def from_torch(
    cls: pxt.OpC,
    dim_shape: pxt.NDArrayShape,
    codim_shape: pxt.NDArrayShape,
    vectorize: pxt.VarName = frozenset(),
    jit: bool = False,
    enable_warnings: bool = True,
    **kwargs,
) -> pxt.OpT:
    r"""
    Define an :py:class:`~pyxu.abc.Operator` from PyTorch functions.

    Parameters
    ----------
    cls: OpC
        Operator sub-class to instantiate.
    dim_shape: NDArrayShape
        Operator domain shape (M1,...,MD).
    codim_shape: NDArrayShape
        Operator co-domain shape (N1,...,NK).
    kwargs: dict
        (k[str], v[callable]) pairs to use as arithmetic methods.

        Keys are restricted to the following arithmetic methods:

        .. code-block:: python3

           apply(), grad(), prox(), pinv(), adjoint()

        Omitted arithmetic methods default to those provided by `cls`, or are auto-inferred via auto-diff rules.
    vectorize: VarName
        Arithmetic methods to vectorize.

        `vectorize` is useful if an arithmetic method provided to `kwargs` does not support stacking dimensions.
    jit: bool
        Currently has no effect (reserved for future compatibility). In the future, if ``True``, Torch-backed
        arithmetic methods will be JIT-compiled for better performance.
    enable_warnings: bool
        If ``True``, emit warnings in case of precision/zero-copy issues.

    Returns
    -------
    op: OpT
        Pyxu-compliant operator :math:`A: \mathbb{R}^{M_{1} \times\cdots\times M_{D}} \to \mathbb{R}^{N_{1}
        \times\cdots\times N_{K}}`.

    Notes
    -----
    * If provided, arithmetic methods must conform exactly to the interface below:

      .. code-block:: python3

         def apply(arr: torch.Tensor) -> torch.Tensor                 # (..., M1,...,MD) -> (..., N1,...,NK)
         def grad(arr: torch.Tensor) -> torch.Tensor                  # (..., M1,...,MD) -> (..., M1,...,MD)
         def adjoint(arr: torch.Tensor) -> torch.Tensor               # (..., N1,...,NK) -> (..., M1,...,MD)
         def prox(arr: torch.Tensor, tau: pxt.Real) -> torch.Tensor   # (..., M1,...,MD) -> (..., M1,...,MD)
         def pinv(arr: torch.Tensor, damp: pxt.Real) -> torch.Tensor  # (..., N1,...,NK) -> (..., M1,...,MD)

      Moreover, the methods above **must** accept stacking dimensions in ``arr``. If this does not hold, consider
      populating `vectorize`.

    * Auto-vectorization consists of decorating `kwargs`-specified arithmetic methods with :py:func:`torch.vmap`. See
      the `PyTorch documentation <https://pytorch.org/docs/stable/func.ux_limitations.html#vmap-limitations>`_ for
      known limitations.

    * Arithmetic methods are **not currently JIT-ed**, even if `jit` is set to ``True``. This is due to the
      undocumented and currently poor interaction between :py:mod:`torch.func` transforms and
      :py:func:`torch.compile`. See `this issue <https://github.com/pytorch/pytorch/issues/98822>`_ for additional
      details.

    * For :py:class:`~pyxu.abc.DiffMap` (or subclasses thereof), the methods :py:meth:`~pyxu.abc.DiffMap.jacobian`,
      :py:meth:`~pyxu.abc.DiffFunc.grad` and :py:meth:`~pyxu.abc.LinOp.adjoint` are defined implicitly if not
      provided, using the auto-differentiation transforms from :py:mod:`torch.func`. As detailed `on this page
      <https://pytorch.org/docs/stable/func.ux_limitations.html>`_, such transforms work well on pure functions (that
      is, functions where the output is completely determined by the input and that do not involve side effects like
      mutation), but may fail on more complex functions. Moreover, :py:mod:`torch.func` does not yet have full
      coverage over PyTorch operations. For functions that call a :py:class:`torch.nn.Module`, see `here
      <https://pytorch.org/docs/stable/func.api.html#utilities-for-working-with-torch-nn-modules>`_ for some
      utilities.

    .. Warning::

       Operators created with this wrapper do not support Dask inputs for now.
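
    Examples
    --------
    A minimal usage sketch (assumes PyTorch is installed and that :py:func:`from_torch` is importable from
    :py:mod:`pyxu.operator.interop`; the squared-L2 functional below is purely illustrative):

    .. code-block:: python3

       import numpy as np
       import torch

       import pyxu.abc as pxa
       from pyxu.operator.interop import from_torch

       def apply(x: torch.Tensor) -> torch.Tensor:
           # Squared L2-norm, stacking-dim aware: (..., M) -> (..., 1).
           return (x ** 2).sum(dim=-1, keepdim=True)

       op = from_torch(
           cls=pxa.DiffFunc,
           dim_shape=5,
           codim_shape=1,
           apply=apply,
       )

       x = np.arange(5.0)
       y = op.apply(x)  # [30.]
       g = op.grad(x)   # 2 * x, auto-inferred via torch.func auto-diff.

    Since ``apply()`` above already handles stacking dimensions, `vectorize` is left empty; per-sample
    implementations should instead be declared via ``vectorize="apply"``.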
    """
    if isinstance(vectorize, str):
        vectorize = (vectorize,)
    vectorize = frozenset(vectorize)

    src = _FromTorch(
        cls=cls,
        dim_shape=dim_shape,
        codim_shape=codim_shape,
        vectorize=vectorize,
        jit=bool(jit),
        enable_warnings=bool(enable_warnings),
        **kwargs,
    )
    op = src.op()
    return op


class _FromTorch(px_src._FromSource):
    # supported methods in __init__(**kwargs)
    _meth = frozenset({"apply", "grad", "prox", "pinv", "adjoint"})

    def __init__(  # See from_torch() for a detailed description.
        self,
        cls: pxt.OpC,
        dim_shape: pxt.NDArrayShape,
        codim_shape: pxt.NDArrayShape,
        vectorize: frozenset[str],
        jit: bool,  # Unused for now
        enable_warnings: bool,
        **kwargs,
    ):
        self._batch_size = kwargs.pop("batch_size", None)
        self._dtype = kwargs.pop("dtype", None)
        super().__init__(
            cls=cls,
            dim_shape=dim_shape,
            codim_shape=codim_shape,
            embed=dict(),
            vectorize=vectorize,
            **kwargs,
        )

        self._jit = False  # JIT-compilation is currently deactivated until torch.func goes out of beta.
        self._enable_warnings = enable_warnings

        # Only a subset of arithmetic methods allowed from from_torch().
        if not (set(self._kwargs) <= self._meth):
            msg_head = "Unsupported arithmetic methods:"
            unsupported = set(self._kwargs) - self._meth
            msg_tail = ", ".join([f"{name}()" for name in unsupported])
            raise ValueError(f"{msg_head} {msg_tail}")

    def op(self) -> pxt.OpT:
        # Idea: modify `**kwargs` from constructor to [when required]:
        #   1. auto-define omitted methods.      [_infer_missing()]
        #   2. auto-vectorize via vmap().        [_auto_vectorize()]
        #   3. JIT-compile via compile().        [_compile()]
        #   4. TORCH<>NumPy/CuPy conversions.    [_interface()]
        # Note: JIT-compilation is currently deactivated due to the undocumented interaction of torch.func transforms
        # and torch.compile. Will be reactivated once torch.func goes out of beta.

        self._infer_missing()
        self._compile()
        self._auto_vectorize()
        t_state, kwargs = self._interface()

        _op = px_src.from_source(
            cls=self._op.__class__,
            dim_shape=self._op.dim_shape,
            codim_shape=self._op.codim_shape,
            embed=dict(
                _batch_size=self._batch_size,
                _dtype=self._dtype,
                _jit=self._jit,
                _enable_warnings=self._enable_warnings,
                _coerce=self._coerce,
                _torch=t_state,
            ),
            # vectorize=None,  # see top-level comment.
            **kwargs,
        )
        return _op

    def _infer_missing(self):
        # The following methods must be auto-inferred if missing from `kwargs`:
        #
        #     grad(), adjoint()
        #
        # Missing methods are auto-inferred via auto-diff rules and added to `_kwargs`.
        # At the end of _infer_missing(), all torch-funcs required for _interface() have been added to `_kwargs`.
        #
        # Notes
        # -----
        # This method does NOT produce vectorized implementations: _auto_vectorize() is responsible for this.
        self._vectorize = set(self._vectorize)  # to allow updates below

        nl_difffunc = all(  # non-linear diff-func
            [
                self._op.has(pxa.Property.DIFFERENTIABLE_FUNCTION),
                not self._op.has(pxa.Property.LINEAR),
            ]
        )
        if nl_difffunc and ("grad" not in self._kwargs):
            apply_no_vec = self._copy_function(self._kwargs["apply"])

            def f_grad(tensor: TorchTensor) -> TorchTensor:
                apply = lambda tnsr: torch.squeeze(apply_no_vec(tnsr))
                grad = functorch.grad(apply)
                return grad(tensor)

            self._vectorize.add("grad")
            self._kwargs["grad"] = f_grad

        non_selfadj = all(  # linear, but not self-adjoint
            [
                self._op.has(pxa.Property.LINEAR),
                not self._op.has(pxa.Property.LINEAR_SELF_ADJOINT),
            ]
        )
        if non_selfadj and ("adjoint" not in self._kwargs):
            apply_no_vec = self._copy_function(self._kwargs["apply"])

            def f_adjoint(tensor: TorchTensor) -> TorchTensor:
                primal = torch.zeros(self._op.dim_shape, dtype=tensor.dtype)
                _, f_vjp = functorch.vjp(apply_no_vec, primal)
                return f_vjp(tensor)[0]  # f_vjp returns a tuple

            self._vectorize.add("adjoint")
            self._kwargs["adjoint"] = f_adjoint

        self._vectorize = frozenset(self._vectorize)

    def _compile(self):
        # JIT-compile user-specified [or _infer_missing()-added] arithmetic methods via torch.compile().
        #
        # Modifies `_kwargs` to hold compiled torch-funcs.
        # Note: Currently deactivated until torch.func goes out of beta.

        if self._jit:
            for name in self._kwargs:
                if name in self._meth:
                    func = self._kwargs[name]  # necessarily torch_func
                    self._kwargs[name] = torch.compile(func)

    def _auto_vectorize(self):
        # Vectorize user-specified [or _infer_missing()-added] arithmetic methods via torch.vmap().
        #
        # Modifies `_kwargs` to hold vectorized torch-funcs.

        for name in self._kwargs:
            if name in self._vectorize:
                func = self._kwargs[name]  # necessarily torch_func
                if name in ["prox", "pinv"]:
                    # These methods have two arguments, but vectorization should be for the first argument only.
                    self._kwargs[name] = _traceable(torch.vmap(func, in_dims=(0, None), chunk_size=self._batch_size))
                else:
                    self._kwargs[name] = _traceable(torch.vmap(func, chunk_size=self._batch_size))

    def _interface(self):
        # Arithmetic methods supplied in `kwargs`:
        #
        #   * take `torch.Tensor` inputs
        #   * do not have the `self` parameter. (Reason: to be JIT-compatible.)
        #
        # This method creates modified arithmetic functions to match Pyxu's API, and `state` required for them to work.
        #
        # Returns
        # -------
        # t_state: dict
        #     Torch functions referenced by wrapper arithmetic methods. (See below.)
        # kwargs: dict
        #     Pyxu-compatible functions which can be submitted to interop.from_source().

        # torch_func's (potentially auto_vec/jitted at this stage)
        t_state = {name: obj for name, obj in self._kwargs.items() if name in self._meth}

        # Pyxu-compatible functions
        kwargs = {name: getattr(self.__class__, name) for name in self._kwargs}

        # Special cases.
        # (Reason: not in `_kwargs` [c.f. from_torch() docstring], but need to be inferred.)
        for name in ("jacobian", "_quad_spec", "_expr", "asarray"):
            kwargs[name] = getattr(self.__class__, name)

        return t_state, kwargs

    def _coerce(self, arr: pxt.NDArray) -> pxt.NDArray:
        # Coerce inputs (and raise a warning) in case of precision mis-matches.
        if self._dtype is not None and (arr.dtype != self._dtype):
            if self._enable_warnings:
                msg = f"Precision mis-match! Input array was coerced to {self._dtype} precision automatically."
                warnings.warn(msg, pxw.PrecisionWarning)
            arr = arr.astype(self._dtype, copy=False)
        return arr

    # Wrapper arithmetic methods ----------------------------------------------
    def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
        bsh = arr.shape[: -self.dim_rank]
        bsize = (int(np.prod(bsh)),)
        arr = arr.reshape(bsize + self.dim_shape)
        arr = self._coerce(arr)
        tensor = _to_torch(arr)
        func = self._torch["apply"]
        out = _from_torch(func(tensor))
        return out.reshape(bsh + self.codim_shape)

    def grad(self, arr: pxt.NDArray) -> pxt.NDArray:
        bsh = arr.shape[: -self.dim_rank]
        bsize = (int(np.prod(bsh)),)
        arr = arr.reshape(bsize + self.dim_shape)
        arr = self._coerce(arr)
        tensor = _to_torch(arr)
        func = self._torch["grad"]
        out = _from_torch(func(tensor))
        return out.reshape(bsh + self.dim_shape)

    def adjoint(self, arr: pxt.NDArray) -> pxt.NDArray:
        bsh = arr.shape[: -self.codim_rank]
        bsize = (int(np.prod(bsh)),)
        arr = arr.reshape(bsize + self.codim_shape)
        arr = self._coerce(arr)
        tensor = _to_torch(arr)
        func = self._torch["adjoint"]
        out = _from_torch(func(tensor))
        return out.reshape(bsh + self.dim_shape)

    def prox(self, arr: pxt.NDArray, tau: pxt.Real) -> pxt.NDArray:
        bsh = arr.shape[: -self.dim_rank]
        bsize = (int(np.prod(bsh)),)
        arr = arr.reshape(bsize + self.dim_shape)
        arr = self._coerce(arr)
        tensor = _to_torch(arr)
        func = self._torch["prox"]
        out = _from_torch(func(tensor, tau))
        return out.reshape(bsh + self.dim_shape)

    def pinv(self, arr: pxt.NDArray, damp: pxt.Real, **kwargs) -> pxt.NDArray:
        bsh = arr.shape[: -self.codim_rank]
        bsize = (int(np.prod(bsh)),)
        arr = arr.reshape(bsize + self.codim_shape)
        arr = self._coerce(arr)
        tensor = _to_torch(arr)
        func = self._torch["pinv"]
        out = func(tensor, damp)  # positional args only if auto-vectorized.
        out = _from_torch(out)
        return out.reshape(bsh + self.dim_shape)

    def jacobian(self, arr: pxt.NDArray) -> pxt.OpT:
        try:
            # Use the class' method if available ...
            klass = self.__class__
            op = klass.jacobian(self, arr)
        except NotImplementedError:
            # ... and fall back to auto-inference if undefined.
            # bsh = arr.shape[: -self.codim_rank]
            # bsize = (int(np.prod(bsh)),)
            # arr = arr.reshape(bsize + self.codim_shape)
            arr = self._coerce(arr)
            primal = _to_torch(arr)
            f = self._torch["apply"]

            def jf_apply(tan: TorchTensor) -> TorchTensor:
                return functorch.jvp(f, primals=(primal,), tangents=(tan,))[1]

            def jf_adjoint(cotan: TorchTensor) -> TorchTensor:
                _, f_vjp = functorch.vjp(f, primal)
                return f_vjp(cotan)[0]  # f_vjp returns a tuple

            klass = pxa.LinFunc if (self.codim_shape == (1,)) else pxa.LinOp
            op = from_torch(
                cls=klass,
                dim_shape=self.dim_shape,
                codim_shape=self.codim_shape,
                vectorize=("apply", "adjoint"),
                batch_size=self._batch_size,
                jit=self._jit,
                dtype=self._dtype,
                enable_warnings=self._enable_warnings,
                apply=jf_apply,
                adjoint=jf_adjoint,
            )
        return op

    def _quad_spec(self):
        if self.has(pxa.Property.QUADRATIC):
            # auto-infer (Q, c, t)
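            #
            # For a quadratic functional f(x) = (1/2) <x, Q x> + <c, x> + t, one has
            # \grad_{f}(x) = Q x + c, hence c = \grad_{f}(0), Q x = \grad_{f}(x) - \grad_{f}(0), and t = f(0).
            # The helpers below follow this recipe.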
            _grad = self._torch["grad"]

            def Q_apply(tensor: TorchTensor) -> TorchTensor:
                # \grad_{f}(x) = Q x + c = Q x + \grad_{f}(0)
                # ==> Q x = \grad_{f}(x) - \grad_{f}(0)
                z = torch.zeros_like(tensor)
                out = _grad(tensor) - _grad(z)
                return out

            def c_apply(tensor: TorchTensor) -> TorchTensor:
                z = torch.zeros_like(tensor)
                c = _grad(z)
                out = torch.sum(c * tensor, dtype=tensor.dtype).unsqueeze(-1)
                return out

            Q = from_torch(
                apply=Q_apply,
                dim_shape=self.dim_shape,
                codim_shape=self.dim_shape,
                cls=pxa.PosDefOp,
                vectorize="apply",
                batch_size=self._batch_size,
                jit=self._jit,
                dtype=self._dtype,
                enable_warnings=self._enable_warnings,
            )

            c = from_torch(
                apply=c_apply,
                dim_shape=self.dim_shape,
                codim_shape=1,
                cls=pxa.LinFunc,
                vectorize="apply",
                batch_size=self._batch_size,
                jit=self._jit,
                dtype=self._dtype,
                enable_warnings=self._enable_warnings,
            )

            # We cannot know a-priori which backend the supplied torch-apply() function works with.
            # Consequence: to compute `t`, we must try different backends until one works.
            f = self._torch["apply"]
            t = self._compute_t(f)

            return (Q, c, t)
        else:
            raise NotImplementedError

    def _compute_t(self, f):
        try:
            with torch.device("cpu"):
                return float(f(torch.zeros(self.dim_shape)))
        except Exception:
            for id in range(torch.cuda.device_count()):
                try:
                    with torch.device("cuda", id):
                        return float(f(torch.zeros(self.dim_shape)))
                except Exception:
                    continue
            raise RuntimeError("Failed to compute t")

    def asarray(self, **kwargs) -> pxt.NDArray:
        if self.has(pxa.Property.LINEAR):
            # Torch operators don't accept DASK inputs: cannot call Lin[Op,Func].asarray() with a user-specified
            # `xp` value. -> We arbitrarily perform computations using the NUMPY backend, then cast as needed.
            N = pxd.NDArrayInfo  # shorthand
            dtype = kwargs.get("dtype", pxrt.Width.DOUBLE.value)
            xp = kwargs.get("xp", N.default().module())

            klass = self.__class__
            A = klass.asarray(self, dtype=dtype, xp=N.NUMPY.module())

            # Not the most efficient method, but fail-proof
            return xp.array(A, dtype=dtype)
        else:
            raise NotImplementedError

    def _expr(self):
        torch_funcs = list(self._torch.keys())
        return ("from_torch", *torch_funcs)

    def _copy_function(self, fn):
        # Create a copy of a function's code, defaults, and closure.
        # This method is necessary to create a separate instance of a function, preserving its
        # original properties, while allowing modifications to the copied instance without affecting
        # the original. This is particularly useful for functions that need to be wrapped or altered
        # dynamically within the class.

        import types

        new_fn = types.FunctionType(
            fn.__code__, fn.__globals__, name=fn.__name__, argdefs=fn.__defaults__, closure=fn.__closure__
        )
        new_fn.__doc__ = fn.__doc__
        new_fn.__annotations__ = fn.__annotations__
        new_fn.__dict__.update(fn.__dict__)
        return new_fn