Source code for pyxu.operator.interop.torch

import typing as typ
import warnings
from functools import wraps

import numpy as np
import packaging.version as pkgv

import pyxu.abc as pxa
import pyxu.info.deps as pxd
import pyxu.info.ptype as pxt
import pyxu.info.warning as pxw
import pyxu.operator.interop.source as px_src
import pyxu.runtime as pxrt
from pyxu.util import import_module

torch = import_module("torch", fail_on_error=False)
if torch is not None:
    version = pkgv.Version(torch.__version__)
    supported = pxd.PYTORCH_SUPPORT
    assert supported["min"] <= version < supported["max"]

    import torch._dynamo as dynamo
    import torch.func as functorch

    TorchTensor = torch.Tensor
else:
    TorchTensor = typ.TypeVar("torch.Tensor")

__all__ = [
    "from_torch",
]


def _traceable(f):
    # Needed to compile functorch transforms. See this issue: https://github.com/pytorch/pytorch/issues/98822
    f = dynamo.allow_in_graph(f)

    @wraps(f)
    def wrapper(*args, **kwargs):
        return f(*args, **kwargs)

    return wrapper
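
# Illustrative note (not part of the original module): `_traceable` is applied to the
# outputs of torch.vmap() in `_FromTorch._auto_vectorize()` further below, e.g.
#
#     vec_apply = _traceable(torch.vmap(apply, chunk_size=None))
#
# so that the vectorized callable remains usable inside a torch.compile() graph.
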
def _to_torch(arr: pxt.NDArray, requires_grad: bool = False) -> TorchTensor:
    r"""
    NumPy/CuPy -> PyTorch conversion.

    Convert a NumPy-like array into a PyTorch tensor, sharing data, dtype and device.

    Parameters
    ----------
    arr: NDArray
        Input array.
    requires_grad: bool
        If autograd should record operations on the returned tensor.

    Returns
    -------
    tensor: torch.Tensor
        Output tensor.

    Notes
    -----
    The returned tensor and input array share the same memory. Modifications to the tensor will be reflected in the
    ndarray and vice versa. The returned tensor is not resizable.
    """
    if pxd.NDArrayInfo.from_obj(arr) == pxd.NDArrayInfo.CUPY:
        with torch.device("cuda", arr.device.id):
            return torch.as_tensor(arr).requires_grad_(requires_grad)
    else:
        return torch.from_numpy(arr).requires_grad_(requires_grad)

def _from_torch(tensor: TorchTensor) -> pxt.NDArray:
    r"""
    PyTorch -> NumPy/CuPy conversion.

    Convert a PyTorch tensor into a NumPy-like array, sharing data, dtype and device.

    Parameters
    ----------
    tensor: torch.Tensor
        Input tensor.

    Returns
    -------
    arr: NDArray
        Output array.

    Notes
    -----
    The returned array and input tensor share the same memory. Modifications to the tensor will be reflected in the
    ndarray and vice versa.
    """
    if tensor.get_device() == -1:
        return tensor.detach().numpy(force=False)
    else:
        cp = pxd.NDArrayInfo.CUPY.module()
        with cp.cuda.Device(tensor.get_device()):
            return cp.asarray(tensor.detach())
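
# Illustrative sketch (not part of the original module): the two helpers above are
# zero-copy inverses of each other on CPU, so an in-place update on the tensor is
# visible through the original NumPy array.
#
#   >>> import numpy as np
#   >>> x = np.zeros(3)
#   >>> t = _to_torch(x)        # shares memory with `x`
#   >>> t += 1                  # in-place update
#   >>> x
#   array([1., 1., 1.])
#   >>> _from_torch(t)          # shares memory with `t` (and hence `x`)
#   array([1., 1., 1.])
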
def from_torch(
    cls: pxt.OpC,
    dim_shape: pxt.NDArrayShape,
    codim_shape: pxt.NDArrayShape,
    vectorize: pxt.VarName = frozenset(),
    jit: bool = False,
    enable_warnings: bool = True,
    **kwargs,
) -> pxt.OpT:
    r"""
    Define an :py:class:`~pyxu.abc.Operator` from PyTorch functions.

    Parameters
    ----------
    cls: OpC
        Operator sub-class to instantiate.
    dim_shape: NDArrayShape
        Operator domain shape (M1,...,MD).
    codim_shape: NDArrayShape
        Operator co-domain shape (N1,...,NK).
    kwargs: dict
        (k[str], v[callable]) pairs to use as arithmetic methods.

        Keys are restricted to the following arithmetic methods:

        .. code-block:: python3

           apply(), grad(), prox(), pinv(), adjoint()

        Omitted arithmetic methods default to those provided by `cls`, or are auto-inferred via auto-diff rules.
    vectorize: VarName
        Arithmetic methods to vectorize.

        `vectorize` is useful if an arithmetic method provided to `kwargs` does not support stacking dimensions.
    jit: bool
        Currently has no effect (for future-compatibility only). In the future, if ``True``, then Torch-backed
        arithmetic methods will be JIT-compiled for better performance.
    enable_warnings: bool
        If ``True``, emit warnings in case of precision/zero-copy issues.

    Returns
    -------
    op: OpT
        Pyxu-compliant operator :math:`A: \mathbb{R}^{M_{1} \times\cdots\times M_{D}} \to \mathbb{R}^{N_{1}
        \times\cdots\times N_{K}}`.

    Notes
    -----
    * If provided, arithmetic methods must conform exactly to the interface below:

      .. code-block:: python3

         def apply(arr: torch.Tensor) -> torch.Tensor                 # (..., M1,...,MD) -> (..., N1,...,NK)
         def grad(arr: torch.Tensor) -> torch.Tensor                  # (..., M1,...,MD) -> (..., M1,...,MD)
         def adjoint(arr: torch.Tensor) -> torch.Tensor               # (..., N1,...,NK) -> (..., M1,...,MD)
         def prox(arr: torch.Tensor, tau: pxt.Real) -> torch.Tensor   # (..., M1,...,MD) -> (..., M1,...,MD)
         def pinv(arr: torch.Tensor, damp: pxt.Real) -> torch.Tensor  # (..., N1,...,NK) -> (..., M1,...,MD)

      Moreover, the methods above **must** accept stacking dimensions in ``arr``. If this does not hold, consider
      populating `vectorize`.

    * Auto-vectorization consists in decorating `kwargs`-specified arithmetic methods with :py:func:`torch.vmap`. See
      the `PyTorch documentation <https://pytorch.org/docs/stable/func.ux_limitations.html#vmap-limitations>`_ for
      known limitations.

    * Arithmetic methods are **not currently JIT-ed**, even if `jit` is set to ``True``. This is because of the
      undocumented and currently poor interaction between :py:mod:`torch.func` transforms and :py:func:`torch.compile`.
      See `this issue <https://github.com/pytorch/pytorch/issues/98822>`_ for additional details.

    * For :py:class:`~pyxu.abc.DiffMap` (or subclasses thereof), the methods :py:meth:`~pyxu.abc.DiffMap.jacobian`,
      :py:meth:`~pyxu.abc.DiffFunc.grad` and :py:meth:`~pyxu.abc.LinOp.adjoint` are defined implicitly if not provided
      using the auto-differentiation transforms from :py:mod:`torch.func`. As detailed `on this page
      <https://pytorch.org/docs/stable/func.ux_limitations.html>`_, such transforms work well on pure functions (that
      is, functions where the output is completely determined by the input and that do not involve side effects like
      mutation), but may fail on more complex functions. Moreover, :py:mod:`torch.func` does not yet have full coverage
      over PyTorch operations. For functions that call a :py:class:`torch.nn.Module`, see `here
      <https://pytorch.org/docs/stable/func.api.html#utilities-for-working-with-torch-nn-modules>`_ for some utilities.

    .. Warning::

       Operators created with this wrapper do not support Dask inputs for now.
    """
    if isinstance(vectorize, str):
        vectorize = (vectorize,)
    vectorize = frozenset(vectorize)

    src = _FromTorch(
        cls=cls,
        dim_shape=dim_shape,
        codim_shape=codim_shape,
        vectorize=vectorize,
        jit=bool(jit),
        enable_warnings=bool(enable_warnings),
        **kwargs,
    )
    op = src.op()
    return op
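
# Illustrative sketch (not part of the original module): wrapping an element-wise
# Torch function as a Pyxu operator. The shapes and the choice of `pxa.DiffMap`
# below are demonstration-only assumptions.
#
#   >>> import numpy as np
#   >>> import torch
#   >>> import pyxu.abc as pxa
#   >>> from pyxu.operator.interop.torch import from_torch
#   >>> op = from_torch(
#   ...     cls=pxa.DiffMap,
#   ...     dim_shape=(5,),
#   ...     codim_shape=(5,),
#   ...     apply=lambda x: torch.sin(x),  # torch.Tensor -> torch.Tensor, accepts stacking dims
#   ... )
#   >>> y = op.apply(np.linspace(0, 1, 5))   # NumPy in -> NumPy out
#   >>> J = op.jacobian(np.zeros(5))         # auto-inferred via torch.func
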

class _FromTorch(px_src._FromSource):
    # supported methods in __init__(**kwargs)
    _meth = frozenset({"apply", "grad", "prox", "pinv", "adjoint"})

    def __init__(  # See from_torch() for a detailed description.
        self,
        cls: pxt.OpC,
        dim_shape: pxt.NDArrayShape,
        codim_shape: pxt.NDArrayShape,
        vectorize: frozenset[str],
        jit: bool,  # Unused for now
        enable_warnings: bool,
        **kwargs,
    ):
        self._batch_size = kwargs.pop("batch_size", None)
        self._dtype = kwargs.pop("dtype", None)
        super().__init__(
            cls=cls,
            dim_shape=dim_shape,
            codim_shape=codim_shape,
            embed=dict(),
            vectorize=vectorize,
            **kwargs,
        )

        self._jit = False  # JIT-compilation is currently deactivated until torch.func goes out of beta.
        self._enable_warnings = enable_warnings

        # Only a subset of arithmetic methods allowed from from_torch().
        if not (set(self._kwargs) <= self._meth):
            msg_head = "Unsupported arithmetic methods:"
            unsupported = set(self._kwargs) - self._meth
            msg_tail = ", ".join([f"{name}()" for name in unsupported])
            raise ValueError(f"{msg_head} {msg_tail}")

    def op(self) -> pxt.OpT:
        # Idea: modify `**kwargs` from constructor to [when required]:
        #   1. auto-define omitted methods.    [_infer_missing()]
        #   2. auto-vectorize via vmap().      [_auto_vectorize()]
        #   3. JIT-compile via compile().      [_compile()]
        #   4. TORCH<>NumPy/CuPy conversions.  [_interface()]
        # Note: JIT-compilation is currently deactivated due to the undocumented interaction of torch.func transforms
        # and torch.compile. Will be reactivated once torch.func goes out of beta.

        self._infer_missing()
        self._compile()
        self._auto_vectorize()
        t_state, kwargs = self._interface()

        _op = px_src.from_source(
            cls=self._op.__class__,
            dim_shape=self._op.dim_shape,
            codim_shape=self._op.codim_shape,
            embed=dict(
                _batch_size=self._batch_size,
                _dtype=self._dtype,
                _jit=self._jit,
                _enable_warnings=self._enable_warnings,
                _coerce=self._coerce,
                _torch=t_state,
            ),
            # vectorize=None,  # see top-level comment.
            **kwargs,
        )
        return _op

    def _infer_missing(self):
        # The following methods must be auto-inferred if missing from `kwargs`:
        #
        #     grad(), adjoint()
        #
        # Missing methods are auto-inferred via auto-diff rules and added to `_kwargs`.
        # At the end of _infer_missing(), all torch-funcs required for _interface() have been added to `_kwargs`.
        #
        # Notes
        # -----
        # This method does NOT produce vectorized implementations: _auto_vectorize() is responsible for this.
        self._vectorize = set(self._vectorize)  # to allow updates below

        nl_difffunc = all(  # non-linear diff-func
            [
                self._op.has(pxa.Property.DIFFERENTIABLE_FUNCTION),
                not self._op.has(pxa.Property.LINEAR),
            ]
        )
        if nl_difffunc and ("grad" not in self._kwargs):
            apply_no_vec = self._copy_function(self._kwargs["apply"])

            def f_grad(tensor: TorchTensor) -> TorchTensor:
                apply = lambda tnsr: torch.squeeze(apply_no_vec(tnsr))
                grad = functorch.grad(apply)
                return grad(tensor)

            self._vectorize.add("grad")
            self._kwargs["grad"] = f_grad

        non_selfadj = all(  # linear, but not self-adjoint
            [
                self._op.has(pxa.Property.LINEAR),
                not self._op.has(pxa.Property.LINEAR_SELF_ADJOINT),
            ]
        )
        if non_selfadj and ("adjoint" not in self._kwargs):
            apply_no_vec = self._copy_function(self._kwargs["apply"])

            def f_adjoint(tensor: TorchTensor) -> TorchTensor:
                primal = torch.zeros(self._op.dim_shape, dtype=tensor.dtype)
                _, f_vjp = functorch.vjp(apply_no_vec, primal)
                return f_vjp(tensor)[0]  # f_vjp returns a tuple

            self._vectorize.add("adjoint")
            self._kwargs["adjoint"] = f_adjoint

        self._vectorize = frozenset(self._vectorize)

    def _compile(self):
        # JIT-compile user-specified [or _infer_missing()-added] arithmetic methods via torch.compile().
        #
        # Modifies `_kwargs` to hold compiled torch-funcs.
        # Note: Currently deactivated until torch.func goes out of beta.

        if self._jit:
            for name in self._kwargs:
                if name in self._meth:
                    func = self._kwargs[name]  # necessarily torch_func
                    self._kwargs[name] = torch.compile(func)

    def _auto_vectorize(self):
        # Vectorize user-specified [or _infer_missing()-added] arithmetic methods via torch.vmap().
        #
        # Modifies `_kwargs` to hold vectorized torch-funcs.

        for name in self._kwargs:
            if name in self._vectorize:
                func = self._kwargs[name]  # necessarily torch_func
                if name in [
                    "prox",
                    "pinv",
                ]:  # These methods have two arguments, but vectorization should be for the first argument only.
                    self._kwargs[name] = _traceable(torch.vmap(func, in_dims=(0, None), chunk_size=self._batch_size))
                else:
                    self._kwargs[name] = _traceable(torch.vmap(func, chunk_size=self._batch_size))

    def _interface(self):
        # Arithmetic methods supplied in `kwargs`:
        #
        #   * take `torch.Tensor` inputs
        #   * do not have the `self` parameter. (Reason: to be JIT-compatible.)
        #
        # This method creates modified arithmetic functions to match Pyxu's API, and `state` required for them to work.
        #
        # Returns
        # -------
        # t_state: dict
        #     Torch functions referenced by wrapper arithmetic methods. (See below.)
        # kwargs: dict
        #     Pyxu-compatible functions which can be submitted to interop.from_source().

        # torch_func's (potentially auto_vec/jitted at this stage)
        t_state = {name: obj for name, obj in self._kwargs.items() if name in self._meth}

        # Pyxu-compatible functions
        kwargs = {name: getattr(self.__class__, name) for name in self._kwargs}

        # Special cases.
        # (Reason: not in `_kwargs` [c.f. from_torch() docstring], but need to be inferred.)
        for name in ("jacobian", "_quad_spec", "_expr", "asarray"):
            kwargs[name] = getattr(self.__class__, name)

        return t_state, kwargs

    def _coerce(self, arr: pxt.NDArray) -> pxt.NDArray:
        # Coerce inputs (and raise a warning) in case of precision mis-matches.
        if self._dtype is not None and (arr.dtype != self._dtype):
            if self._enable_warnings:
                msg = f"Precision mis-match! Input array was coerced to {self._dtype} precision automatically."
                warnings.warn(msg, pxw.PrecisionWarning)
            arr = arr.astype(self._dtype, copy=False)
        return arr

    # Wrapper arithmetic methods ----------------------------------------------
    def apply(self, arr: pxt.NDArray) -> pxt.NDArray:
        bsh = arr.shape[: -self.dim_rank]
        bsize = (int(np.prod(bsh)),)
        arr = arr.reshape(bsize + self.dim_shape)
        arr = self._coerce(arr)
        tensor = _to_torch(arr)
        func = self._torch["apply"]
        out = _from_torch(func(tensor))
        return out.reshape(bsh + self.codim_shape)

    def grad(self, arr: pxt.NDArray) -> pxt.NDArray:
        bsh = arr.shape[: -self.dim_rank]
        bsize = (int(np.prod(bsh)),)
        arr = arr.reshape(bsize + self.dim_shape)
        arr = self._coerce(arr)
        tensor = _to_torch(arr)
        func = self._torch["grad"]
        out = _from_torch(func(tensor))
        return out.reshape(bsh + self.dim_shape)

    def adjoint(self, arr: pxt.NDArray) -> pxt.NDArray:
        bsh = arr.shape[: -self.codim_rank]
        bsize = (int(np.prod(bsh)),)
        arr = arr.reshape(bsize + self.codim_shape)
        arr = self._coerce(arr)
        tensor = _to_torch(arr)
        func = self._torch["adjoint"]
        out = _from_torch(func(tensor))
        return out.reshape(bsh + self.dim_shape)

    def prox(self, arr: pxt.NDArray, tau: pxt.Real) -> pxt.NDArray:
        bsh = arr.shape[: -self.dim_rank]
        bsize = (int(np.prod(bsh)),)
        arr = arr.reshape(bsize + self.dim_shape)
        arr = self._coerce(arr)
        tensor = _to_torch(arr)
        func = self._torch["prox"]
        out = _from_torch(func(tensor, tau))
        return out.reshape(bsh + self.dim_shape)

    def pinv(self, arr: pxt.NDArray, damp: pxt.Real, **kwargs) -> pxt.NDArray:
        bsh = arr.shape[: -self.codim_rank]
        bsize = (int(np.prod(bsh)),)
        arr = arr.reshape(bsize + self.codim_shape)
        arr = self._coerce(arr)
        tensor = _to_torch(arr)
        func = self._torch["pinv"]
        out = func(tensor, damp)  # positional args only if auto-vectorized.
        out = _from_torch(out)
        return out.reshape(bsh + self.dim_shape)

    def jacobian(self, arr: pxt.NDArray) -> pxt.OpT:
        try:
            # Use the class' method if available ...
            klass = self.__class__
            op = klass.jacobian(self, arr)
        except NotImplementedError:
            # ... and fallback to auto-inference if undefined.
            # bsh = arr.shape[:-self.codim_rank]
            # bsize = (int(np.prod(bsh)), )
            # arr = arr.reshape(bsize + self.codim_shape)
            arr = self._coerce(arr)
            primal = _to_torch(arr)
            f = self._torch["apply"]

            def jf_apply(tan: TorchTensor) -> TorchTensor:
                return functorch.jvp(f, primals=(primal,), tangents=(tan,))[1]

            def jf_adjoint(cotan: TorchTensor) -> TorchTensor:
                _, f_vjp = functorch.vjp(f, primal)
                return f_vjp(cotan)[0]  # f_vjp returns a tuple

            klass = pxa.LinFunc if (self.codim_shape == (1,)) else pxa.LinOp
            op = from_torch(
                cls=klass,
                dim_shape=self.dim_shape,
                codim_shape=self.codim_shape,
                vectorize=("apply", "adjoint"),
                batch_size=self._batch_size,
                jit=self._jit,
                dtype=self._dtype,
                enable_warnings=self._enable_warnings,
                apply=jf_apply,
                adjoint=jf_adjoint,
            )
        return op

    def _quad_spec(self):
        if self.has(pxa.Property.QUADRATIC):
            # auto-infer (Q, c, t)
            _grad = self._torch["grad"]

            def Q_apply(tensor: TorchTensor) -> TorchTensor:
                # \grad_{f}(x) = Q x + c = Q x + \grad_{f}(0)
                # ==> Q x = \grad_{f}(x) - \grad_{f}(0)
                z = torch.zeros_like(tensor)
                out = _grad(tensor) - _grad(z)
                return out

            def c_apply(tensor: TorchTensor) -> TorchTensor:
                z = torch.zeros_like(tensor)
                c = _grad(z)
                out = torch.sum(c * tensor, dtype=tensor.dtype).unsqueeze(-1)
                return out

            Q = from_torch(
                apply=Q_apply,
                dim_shape=self.dim_shape,
                codim_shape=self.dim_shape,
                cls=pxa.PosDefOp,
                vectorize="apply",
                batch_size=self._batch_size,
                jit=self._jit,
                dtype=self._dtype,
                enable_warnings=self._enable_warnings,
            )

            c = from_torch(
                apply=c_apply,
                dim_shape=self.dim_shape,
                codim_shape=1,
                cls=pxa.LinFunc,
                vectorize="apply",
                batch_size=self._batch_size,
                jit=self._jit,
                dtype=self._dtype,
                enable_warnings=self._enable_warnings,
            )

            # We cannot know a-priori which backend the supplied torch-apply() function works with.
            # Consequence: to compute `t`, we must try different backends until one works.
            f = self._torch["apply"]
            t = self._compute_t(f)

            return (Q, c, t)
        else:
            raise NotImplementedError

    def _compute_t(self, f):
        try:
            with torch.device("cpu"):
                return float(f(torch.zeros(self.dim_shape)))
        except Exception:
            for id in range(torch.cuda.device_count()):
                try:
                    with torch.device("cuda", id):
                        return float(f(torch.zeros(self.dim_shape)))
                except Exception:
                    continue
            raise RuntimeError("Failed to compute t")

    def asarray(self, **kwargs) -> pxt.NDArray:
        if self.has(pxa.Property.LINEAR):
            # Torch operators don't accept DASK inputs: cannot call Lin[Op,Func].asarray() with user-specified `xp`
            # value.
            # -> We arbitrarily perform computations using the NUMPY backend, then cast as needed.
            N = pxd.NDArrayInfo  # shorthand
            dtype = kwargs.get("dtype", pxrt.Width.DOUBLE.value)
            xp = kwargs.get("xp", N.default().module())

            klass = self.__class__
            A = klass.asarray(self, dtype=dtype, xp=N.NUMPY.module())

            # Not the most efficient method, but fail-proof
            return xp.array(A, dtype=dtype)
        else:
            raise NotImplementedError

    def _expr(self):
        torch_funcs = list(self._torch.keys())
        return ("from_torch", *torch_funcs)

    def _copy_function(self, fn):
        # Create a copy of a function's code, defaults, and closure.
        # This method is necessary to create a separate instance of a function, preserving its
        # original properties, while allowing modifications to the copied instance without affecting
        # the original. This is particularly useful for functions that need to be wrapped or altered
        # dynamically within the class.

        import types

        new_fn = types.FunctionType(
            fn.__code__, fn.__globals__, name=fn.__name__, argdefs=fn.__defaults__, closure=fn.__closure__
        )
        new_fn.__doc__ = fn.__doc__
        new_fn.__annotations__ = fn.__annotations__
        new_fn.__dict__.update(fn.__dict__)
        return new_fn
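

# Illustrative sketch (not part of the original module): the adjoint auto-inference
# path of `_FromTorch._infer_missing()` relies on torch.func.vjp. For a linear map
# f(x) = x @ A.T, the vector-Jacobian product at any primal returns A^T y, i.e. the
# adjoint. The matrix, shapes, and choice of `pxa.LinOp` below are demonstration-only
# assumptions.
#
#   >>> import numpy as np
#   >>> import torch
#   >>> import pyxu.abc as pxa
#   >>> from pyxu.operator.interop.torch import from_torch
#   >>> A = torch.arange(12, dtype=torch.float64).reshape(3, 4)
#   >>> op = from_torch(
#   ...     cls=pxa.LinOp,
#   ...     dim_shape=(4,),
#   ...     codim_shape=(3,),
#   ...     apply=lambda x: x @ A.T,   # adjoint() omitted: inferred via torch.func.vjp
#   ... )
#   >>> y = np.ones(3)
#   >>> np.allclose(op.adjoint(y), A.T.numpy() @ y)
#   True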