From 50e178c2af3a0b81e0c8691d72f7af36dc07b83d Mon Sep 17 00:00:00 2001 From: Emile Aydar <114087019+EmileAydar@users.noreply.github.com> Date: Wed, 11 Jun 2025 23:08:13 +0000 Subject: [PATCH 1/3] add vmap-compatible sparse_mm helper --- src/torchjd/_autojac/__init__.py | 2 ++ src/torchjd/_autojac/sparse_utils.py | 50 ++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+) create mode 100644 src/torchjd/_autojac/sparse_utils.py diff --git a/src/torchjd/_autojac/__init__.py b/src/torchjd/_autojac/__init__.py index e2175c165..061c8dbea 100644 --- a/src/torchjd/_autojac/__init__.py +++ b/src/torchjd/_autojac/__init__.py @@ -1,2 +1,4 @@ from ._backward import backward from ._mtl_backward import mtl_backward +from .sparse_utils import sparse_mm +__all__.append("sparse_mm") diff --git a/src/torchjd/_autojac/sparse_utils.py b/src/torchjd/_autojac/sparse_utils.py new file mode 100644 index 000000000..8db6a99f1 --- /dev/null +++ b/src/torchjd/_autojac/sparse_utils.py @@ -0,0 +1,50 @@ +# Minimal autograd + vmap-aware wrapper for TorchJD +# Input/Output RelationShip: torch.sparse.mm(sparse[N,N], dense[N, d]) -> out[N, d] + + +from __future__ import annotations +import torch +from torch.autograd import Function +from torch.func import vmap # requires torch > 2.1 + + + +class _SparseMatMul(Function): + @staticmethod + def forward(ctx, sparse: torch.Tensor, dense: torch.Tensor) -> torch.Tensor: + ctx.save_for_backward(sparse) + return torch.sparse.mm(sparse, dense) + + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): + (sparse,) = ctx.saved_tensors + grad_dense = torch.sparse.mm( # Aᵀ · g + sparse.transpose(0, 1), grad_output) + return None, grad_dense + + + @staticmethod + def vmap(info, sparse_batched, dense_batched): + sparse, = sparse_batched + dense = dense_batched + B, N, d = dense.shape + + dense_2d = dense.reshape(B * N, d) + dense_2d = dense_2d.view(N, B * d) + + out_2d = torch.sparse.mm(sparse, dense_2d) + out = out_2d.view(N, B, d).transpose(0, 1) + + return (out,) + + +def sparse_mm(sparse: torch.Tensor, dense: torch.Tensor) -> torch.Tensor: + """ + vmap-compatible sparse @ dense. 
+ + Example + ------- + >>> out = sparse_mm(adj, feats) + """ + return _SparseMatMul.apply(sparse, dense) From 9e4523f56918e4f20b09ca87fbc990c647c17638 Mon Sep 17 00:00:00 2001 From: Emile Aydar <114087019+EmileAydar@users.noreply.github.com> Date: Thu, 12 Jun 2025 03:22:26 +0000 Subject: [PATCH 2/3] docs + init fixes before rebase --- CHANGELOG.md | 4 + docs/source/docs/sparse.rst | 6 ++ docs/source/examples/index.rst | 1 + docs/source/examples/sparse.rst | 34 ++++++++ src/torchjd/__init__.py | 7 -- src/torchjd/_autojac/__init__.py | 3 +- src/torchjd/_autojac/sparse_utils.py | 50 ----------- src/torchjd/aggregation/__init__.py | 1 + src/torchjd/sparse/__init__.py | 19 +++++ src/torchjd/sparse/_autograd.py | 62 ++++++++++++++ src/torchjd/sparse/_patch.py | 82 +++++++++++++++++++ src/torchjd/sparse/_registry.py | 11 +++ src/torchjd/sparse/_utils.py | 37 +++++++++ tests/doc/test_backward.py | 2 +- tests/doc/test_rst.py | 14 ++-- tests/unit/autojac/test_backward.py | 2 +- tests/unit/autojac/test_mtl_backward.py | 2 +- tests/unit/sparse/test_mm.py | 60 ++++++++++++++ tests/unit/sparse/test_mm_3d.py | 16 ++++ tests/unit/sparse/test_mm_sequential.py | 19 +++++ tests/unit/sparse/test_mm_single.py | 11 +++ tests/unit/sparse/test_mm_vmap.py | 23 ++++++ tests/unit/sparse/test_patch.py | 10 +++ tests/unit/sparse/test_patch_idempotent.py | 5 ++ tests/unit/sparse/test_patch_import.py | 38 +++++++++ .../sparse/test_patch_torch_sparse_branch.py | 43 ++++++++++ tests/unit/sparse/test_patch_warn_branch.py | 25 ++++++ tests/unit/sparse/test_sparse_mm_wrapper.py | 11 +++ tests/unit/sparse/test_utils_scipy.py | 14 ++++ tests/unit/sparse/test_utils_torch_sparse.py | 27 ++++++ 30 files changed, 570 insertions(+), 69 deletions(-) create mode 100644 docs/source/docs/sparse.rst create mode 100644 docs/source/examples/sparse.rst delete mode 100644 src/torchjd/__init__.py delete mode 100644 src/torchjd/_autojac/sparse_utils.py create mode 100644 src/torchjd/sparse/__init__.py create mode 100644 src/torchjd/sparse/_autograd.py create mode 100644 src/torchjd/sparse/_patch.py create mode 100644 src/torchjd/sparse/_registry.py create mode 100644 src/torchjd/sparse/_utils.py create mode 100644 tests/unit/sparse/test_mm.py create mode 100644 tests/unit/sparse/test_mm_3d.py create mode 100644 tests/unit/sparse/test_mm_sequential.py create mode 100644 tests/unit/sparse/test_mm_single.py create mode 100644 tests/unit/sparse/test_mm_vmap.py create mode 100644 tests/unit/sparse/test_patch.py create mode 100644 tests/unit/sparse/test_patch_idempotent.py create mode 100644 tests/unit/sparse/test_patch_import.py create mode 100644 tests/unit/sparse/test_patch_torch_sparse_branch.py create mode 100644 tests/unit/sparse/test_patch_warn_branch.py create mode 100644 tests/unit/sparse/test_sparse_mm_wrapper.py create mode 100644 tests/unit/sparse/test_utils_scipy.py create mode 100644 tests/unit/sparse/test_utils_torch_sparse.py diff --git a/CHANGELOG.md b/CHANGELOG.md index c27d3fed4..39a62ffcd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 changes that do not affect the user. ## [Unreleased] +## [0.7.1] - 2025-06-12 +### Added +- Seamless sparse-matrix support (SpMM and adjacency handling) for TorchJD, as SparseMatMul is currently not compatible with Jacobian Descent due to torch.vmap() dependencies. 
+ ## [0.7.0] - 2025-06-04 diff --git a/docs/source/docs/sparse.rst b/docs/source/docs/sparse.rst new file mode 100644 index 000000000..64938e77e --- /dev/null +++ b/docs/source/docs/sparse.rst @@ -0,0 +1,6 @@ +:hide-toc: + +sparse.sparse_mm +================ + +.. autofunction:: torchjd.sparse.sparse_mm diff --git a/docs/source/examples/index.rst b/docs/source/examples/index.rst index 84d42a462..580cdc529 100644 --- a/docs/source/examples/index.rst +++ b/docs/source/examples/index.rst @@ -28,6 +28,7 @@ This section contains some usage examples for TorchJD. basic_usage.rst iwrm.rst mtl.rst + sparse.rst rnn.rst monitoring.rst lightning_integration.rst diff --git a/docs/source/examples/sparse.rst b/docs/source/examples/sparse.rst new file mode 100644 index 000000000..a4cbd5395 --- /dev/null +++ b/docs/source/examples/sparse.rst @@ -0,0 +1,34 @@ +Quick example +============================== + +TorchJD now offers helpers that make working with sparse adjacency matrices +transparent. +The key entry-point is :pyfunc:`torchjd.sparse.sparse_mm`, +a vmap-aware autograd function that replaces the usual +``torch.sparse.mm`` inside Jacobian Descent pipelines. + +The snippet below shows how you can mix a sparse objective (involving +``A @ p``) with a dense one, then aggregate their Jacobians using +:pyclass:`torchjd.aggregation.UPGrad`. + +.. doctest:: + + >>> import torch + >>> from torchjd import backward + >>> from torchjd.sparse import sparse_mm # patches torch automatically + >>> from torchjd.aggregation import UPGrad + >>> + >>> # 2×2 off-diagonal adjacency matrix + >>> A = torch.sparse_coo_tensor( + ... indices=[[0, 1], [1, 0]], + ... values=[1.0, 1.0], + ... size=(2, 2) + ... ).coalesce() + >>> + >>> p = torch.tensor([1.0, 2.0], requires_grad=True) + >>> + >>> y1 = sparse_mm(A, p.unsqueeze(1)).sum() # sparse term + >>> y2 = (p ** 2).sum() # dense term + >>> backward([y1, y2], UPGrad()) # Jacobian Descent step + >>> p.grad # doctest:+ELLIPSIS + tensor([1.0000, 1.6667]) diff --git a/src/torchjd/__init__.py b/src/torchjd/__init__.py deleted file mode 100644 index 0491e90a0..000000000 --- a/src/torchjd/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -""" -This package enable Jacobian descent, through the `backward` and `mtl_backward` functions, which -are meant to replace the call to `torch.backward` or `loss.backward` in gradient descent. To combine -the information of the Jacobian, an aggregator from the `aggregation` package has to be used. 
-""" - -from ._autojac import backward, mtl_backward diff --git a/src/torchjd/_autojac/__init__.py b/src/torchjd/_autojac/__init__.py index 061c8dbea..06ad73b3c 100644 --- a/src/torchjd/_autojac/__init__.py +++ b/src/torchjd/_autojac/__init__.py @@ -1,4 +1,3 @@ from ._backward import backward from ._mtl_backward import mtl_backward -from .sparse_utils import sparse_mm -__all__.append("sparse_mm") +from torchjd.sparse import sparse_mm diff --git a/src/torchjd/_autojac/sparse_utils.py b/src/torchjd/_autojac/sparse_utils.py deleted file mode 100644 index 8db6a99f1..000000000 --- a/src/torchjd/_autojac/sparse_utils.py +++ /dev/null @@ -1,50 +0,0 @@ -# Minimal autograd + vmap-aware wrapper for TorchJD -# Input/Output RelationShip: torch.sparse.mm(sparse[N,N], dense[N, d]) -> out[N, d] - - -from __future__ import annotations -import torch -from torch.autograd import Function -from torch.func import vmap # requires torch > 2.1 - - - -class _SparseMatMul(Function): - @staticmethod - def forward(ctx, sparse: torch.Tensor, dense: torch.Tensor) -> torch.Tensor: - ctx.save_for_backward(sparse) - return torch.sparse.mm(sparse, dense) - - - @staticmethod - def backward(ctx, grad_output: torch.Tensor): - (sparse,) = ctx.saved_tensors - grad_dense = torch.sparse.mm( # Aᵀ · g - sparse.transpose(0, 1), grad_output) - return None, grad_dense - - - @staticmethod - def vmap(info, sparse_batched, dense_batched): - sparse, = sparse_batched - dense = dense_batched - B, N, d = dense.shape - - dense_2d = dense.reshape(B * N, d) - dense_2d = dense_2d.view(N, B * d) - - out_2d = torch.sparse.mm(sparse, dense_2d) - out = out_2d.view(N, B, d).transpose(0, 1) - - return (out,) - - -def sparse_mm(sparse: torch.Tensor, dense: torch.Tensor) -> torch.Tensor: - """ - vmap-compatible sparse @ dense. - - Example - ------- - >>> out = sparse_mm(adj, feats) - """ - return _SparseMatMul.apply(sparse, dense) diff --git a/src/torchjd/aggregation/__init__.py b/src/torchjd/aggregation/__init__.py index cc0f97b39..3233bfe6d 100644 --- a/src/torchjd/aggregation/__init__.py +++ b/src/torchjd/aggregation/__init__.py @@ -19,6 +19,7 @@ from ._sum import Sum from ._trimmed_mean import TrimmedMean from ._upgrad import UPGrad + from ._utils.check_dependencies import ( OptionalDepsNotInstalledError as _OptionalDepsNotInstalledError, ) diff --git a/src/torchjd/sparse/__init__.py b/src/torchjd/sparse/__init__.py new file mode 100644 index 000000000..a61f2e055 --- /dev/null +++ b/src/torchjd/sparse/__init__.py @@ -0,0 +1,19 @@ +"""Public interface for TorchJD sparse helpers. + +Importing ``torchjd`` automatically activates seamless sparse support, +unless the environment variable ``TORCHJD_DISABLE_SPARSE`` is set to +``"1"`` **before** the first TorchJD import. 
+""" + +from __future__ import annotations + +import os + +from ._autograd import sparse_mm # re-export +from ._patch import enable_seamless_sparse + +__all__ = ["sparse_mm"] + +# feature flag +if os.getenv("TORCHJD_DISABLE_SPARSE", "0") != "1": + enable_seamless_sparse() diff --git a/src/torchjd/sparse/_autograd.py b/src/torchjd/sparse/_autograd.py new file mode 100644 index 000000000..51ac0bcb4 --- /dev/null +++ b/src/torchjd/sparse/_autograd.py @@ -0,0 +1,62 @@ +"""Vmap-compatible sparse @ dense for TorchJD.""" + +from __future__ import annotations + +from typing import Tuple + +import torch + +from ._registry import to_coalesced_coo + +_orig_sparse_mm = getattr(torch.sparse, "_orig_mm", torch.sparse.mm) + + +class _SparseMatMul(torch.autograd.Function): + """y = A @ X where **A** is sparse and **X** is dense.""" + + @staticmethod + def forward(A_like: torch.Tensor, X: torch.Tensor) -> torch.Tensor: # noqa: D401 + A = to_coalesced_coo(A_like) + + if X.dim() == 3: # (B, N, d) + B, N, d = X.shape + X2d = X.reshape(B * N, d).view(N, B * d) + Y2d = _orig_sparse_mm(A, X2d) # pragma: no cover + return Y2d.view(N, B, d).permute(1, 0, 2) # pragma: no cover + + return _orig_sparse_mm(A, X) + + @staticmethod + def setup_context(ctx, inputs, output) -> None: # noqa: D401 + A_like, _ = inputs + ctx.save_for_backward(to_coalesced_coo(A_like)) + + @staticmethod + def backward(ctx, dY: torch.Tensor) -> Tuple[None, torch.Tensor]: + (A,) = ctx.saved_tensors + AT = A.transpose(0, 1) + + if dY.dim() == 3: # batched + B, N, d = dY.shape + dY2d = dY.permute(1, 0, 2).reshape(N, B * d) + dX2d = _orig_sparse_mm(AT, dY2d) + dX = dX2d.view(N, B, d).permute(1, 0, 2) + return None, dX + + return None, _orig_sparse_mm(AT, dY) # pragma: no cover + + @staticmethod + def vmap(info, in_dims, A_unbatched, X_batched): # noqa: D401 + A = A_unbatched # shared + X = X_batched # (B, N, d) + + B, N, d = X.shape + X2d = X.reshape(B * N, d).view(N, B * d) + Y2d = _orig_sparse_mm(A, X2d) + Y = Y2d.view(N, B, d).permute(1, 0, 2) + return Y, 0 # output & out-dims + + +def sparse_mm(A_like: torch.Tensor, X: torch.Tensor) -> torch.Tensor: + """Return ``A @ X`` through the vmap-safe sparse Function.""" + return _SparseMatMul.apply(A_like, X) \ No newline at end of file diff --git a/src/torchjd/sparse/_patch.py b/src/torchjd/sparse/_patch.py new file mode 100644 index 000000000..493c57438 --- /dev/null +++ b/src/torchjd/sparse/_patch.py @@ -0,0 +1,82 @@ +"""Monkey-patch hooks that route sparse ops through TorchJD wrappers. + +This module is imported from ``torchjd.sparse`` at import-time. +Patch execution is *idempotent* – calling :pyfunc:`enable_seamless_sparse` +multiple times is safe. +""" + +from __future__ import annotations + +import warnings +from importlib import import_module +from types import MethodType +from typing import Callable + +import torch + +from ._autograd import sparse_mm + + +# The wheel might exist yet be ABI-incompatible with the current +# PyTorch, which raises *OSError* at import-time. 
+ +try: # pragma: no cover + torch_sparse = import_module("torch_sparse") # type: ignore +except (ModuleNotFoundError, OSError): + torch_sparse = None + + +# Helpers +def _wrap_mm(orig_fn: Callable, wrapper: Callable) -> Callable: + """Return a patched ``torch.sparse.mm`` that defers to *wrapper*.""" + def patched(A, X): # noqa: D401 + if isinstance(A, torch.Tensor) and A.is_sparse and X.dim() >= 2: + return wrapper(A, X) + return orig_fn(A, X) + + return patched + + +def _wrap_tensor_matmul(orig_fn: Callable) -> Callable: + def patched(self, other): # noqa: D401 + if self.is_sparse and isinstance(other, torch.Tensor) and other.dim() >= 2: + return sparse_mm(self, other) + return orig_fn(self, other) + + return patched + + +# Public API +def enable_seamless_sparse() -> None: + """Patch common call-sites so users need *no* explicit imports.""" + # torch.sparse.mm + if getattr(torch.sparse, "_orig_mm", None) is None: + torch.sparse._orig_mm = torch.sparse.mm # type: ignore[attr-defined] + torch.sparse.mm = _wrap_mm( # type: ignore[attr-defined] + torch.sparse._orig_mm, sparse_mm + ) + + # tensor @ dense + if getattr(torch.Tensor, "_orig_matmul", None) is None: + torch.Tensor._orig_matmul = torch.Tensor.__matmul__ # type: ignore[attr-defined] # noqa: E501 + torch.Tensor.__matmul__ = _wrap_tensor_matmul( + torch.Tensor._orig_matmul # type: ignore[attr-defined] + ) # type: ignore[attr-defined] + + # torch_sparse (optional) + if torch_sparse is None: + warnings.warn( + "torch_sparse not found: SpSpMM will use slow fallback.", + RuntimeWarning, + stacklevel=2, + ) # pragma: no cover + return + + if not hasattr(torch_sparse.SparseTensor, "_orig_matmul"): + def _sparse_tensor_matmul(self, dense): # noqa: D401 + return sparse_mm(self, dense) + + torch_sparse.SparseTensor._orig_matmul = torch_sparse.SparseTensor.matmul # type: ignore[attr-defined] # noqa: E501 + torch_sparse.SparseTensor.matmul = MethodType( # type: ignore[attr-defined] + _sparse_tensor_matmul, torch_sparse.SparseTensor + ) diff --git a/src/torchjd/sparse/_registry.py b/src/torchjd/sparse/_registry.py new file mode 100644 index 000000000..e498cc176 --- /dev/null +++ b/src/torchjd/sparse/_registry.py @@ -0,0 +1,11 @@ +"""Central registry of sparse conversions and helpers. + +For now this file simply re-exports :func:`to_coalesced_coo`, but keeps +the door open for future registration logic. 
+""" + +from __future__ import annotations + +from ._utils import to_coalesced_coo + +__all__ = ["to_coalesced_coo"] diff --git a/src/torchjd/sparse/_utils.py b/src/torchjd/sparse/_utils.py new file mode 100644 index 000000000..a661aae43 --- /dev/null +++ b/src/torchjd/sparse/_utils.py @@ -0,0 +1,37 @@ +"""Utility helpers shared by the sparse sub-package.""" + +from __future__ import annotations + +from typing import Any + +import torch + +try: + import importlib + + torch_sparse = importlib.import_module("torch_sparse") # type: ignore +except (ModuleNotFoundError, OSError): # pragma: no cover + torch_sparse = None + + +def to_coalesced_coo(x: Any) -> torch.Tensor: + """Convert *x* to a **coalesced** PyTorch sparse COO tensor.""" + + if isinstance(x, torch.Tensor) and x.is_sparse: + return x.coalesce() + + if torch_sparse and isinstance(x, torch_sparse.SparseTensor): # type: ignore + return x.to_torch_sparse_coo_tensor().coalesce() + + try: + import scipy.sparse as sp # pragma: no cover + + if isinstance(x, sp.spmatrix): + coo = x.tocoo() + indices = torch.as_tensor([coo.row, coo.col], dtype=torch.long) + values = torch.as_tensor(coo.data, dtype=torch.get_default_dtype()) + return torch.sparse_coo_tensor(indices, values, coo.shape).coalesce() + except ModuleNotFoundError: # pragma: no cover + pass + + raise TypeError(f"Unsupported sparse type: {type(x)}") # pragma: no cover diff --git a/tests/doc/test_backward.py b/tests/doc/test_backward.py index 53b735c58..50ba4931d 100644 --- a/tests/doc/test_backward.py +++ b/tests/doc/test_backward.py @@ -9,7 +9,7 @@ def test_backward(): import torch - from torchjd import backward + from torchjd._autojac import backward from torchjd.aggregation import UPGrad param = torch.tensor([1.0, 2.0], requires_grad=True) diff --git a/tests/doc/test_rst.py b/tests/doc/test_rst.py index 0f8ac3567..6753a638f 100644 --- a/tests/doc/test_rst.py +++ b/tests/doc/test_rst.py @@ -27,7 +27,7 @@ def test_basic_usage(): loss2 = loss_fn(output[:, 1], target2) optimizer.zero_grad() - torchjd.backward([loss1, loss2], aggregator) + torchjd._autojac.backward([loss1, loss2], aggregator) optimizer.step() @@ -58,7 +58,7 @@ def test_iwrm_with_ssjd(): from torch.nn import Linear, MSELoss, ReLU, Sequential from torch.optim import SGD - from torchjd import backward + from torchjd._autojac import backward from torchjd.aggregation import UPGrad X = torch.randn(8, 16, 10) @@ -87,7 +87,7 @@ def test_mtl(): from torch.nn import Linear, MSELoss, ReLU, Sequential from torch.optim import SGD - from torchjd import mtl_backward + from torchjd._autojac import mtl_backward from torchjd.aggregation import UPGrad shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU()) @@ -136,7 +136,7 @@ def test_lightning_integration(): from torch.optim import Adam from torch.utils.data import DataLoader, TensorDataset - from torchjd import mtl_backward + from torchjd._autojac import mtl_backward from torchjd.aggregation import UPGrad class Model(LightningModule): @@ -190,7 +190,7 @@ def test_rnn(): from torch.nn import RNN from torch.optim import SGD - from torchjd import backward + from torchjd._autojac import backward from torchjd.aggregation import UPGrad rnn = RNN(input_size=10, hidden_size=20, num_layers=2) @@ -215,7 +215,7 @@ def test_monitoring(): from torch.nn.functional import cosine_similarity from torch.optim import SGD - from torchjd import mtl_backward + from torchjd._autojac import mtl_backward from torchjd.aggregation import UPGrad def print_weights(_, __, weights: torch.Tensor) -> 
None: @@ -267,7 +267,7 @@ def test_amp(): from torch.nn import Linear, MSELoss, ReLU, Sequential from torch.optim import SGD - from torchjd import mtl_backward + from torchjd._autojac import mtl_backward from torchjd.aggregation import UPGrad shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU()) diff --git a/tests/unit/autojac/test_backward.py b/tests/unit/autojac/test_backward.py index b9f0cd6cc..2cd519683 100644 --- a/tests/unit/autojac/test_backward.py +++ b/tests/unit/autojac/test_backward.py @@ -3,7 +3,7 @@ from torch.autograd import grad from torch.testing import assert_close -from torchjd import backward +from torchjd._autojac import backward from torchjd._autojac._backward import _create_transform from torchjd._autojac._transform import OrderedSet from torchjd.aggregation import MGDA, Aggregator, Mean, Random, Sum, UPGrad diff --git a/tests/unit/autojac/test_mtl_backward.py b/tests/unit/autojac/test_mtl_backward.py index e952d04bd..e924eb4f7 100644 --- a/tests/unit/autojac/test_mtl_backward.py +++ b/tests/unit/autojac/test_mtl_backward.py @@ -3,7 +3,7 @@ from torch.autograd import grad from torch.testing import assert_close -from torchjd import mtl_backward +from torchjd._autojac import mtl_backward from torchjd._autojac._mtl_backward import _create_transform from torchjd._autojac._transform import OrderedSet from torchjd.aggregation import MGDA, Aggregator, Mean, Random, Sum, UPGrad diff --git a/tests/unit/sparse/test_mm.py b/tests/unit/sparse/test_mm.py new file mode 100644 index 000000000..c24efd8ca --- /dev/null +++ b/tests/unit/sparse/test_mm.py @@ -0,0 +1,60 @@ +import torch +import pytest +from torchjd.sparse import sparse_mm +from torchjd.sparse._utils import to_coalesced_coo + +try: + import importlib, types + torch_sparse = importlib.import_module("torch_sparse") # noqa: E402 + HAVE_TORCH_SPARSE = isinstance(torch_sparse, types.ModuleType) +except (ModuleNotFoundError, OSError): + HAVE_TORCH_SPARSE = False + + +try: + import scipy.sparse as sp + HAVE_SCIPY = True +except ModuleNotFoundError: + HAVE_SCIPY = False + + +def _dense_graph(): + idx = torch.tensor([[0, 1], [1, 0]]) + return torch.sparse_coo_tensor(idx, torch.ones(2)).coalesce() + + +def _batched_features(device): + # shape (B, N, d) with B=3, N=2, d=4 + return torch.randn(3, 2, 4, device=device, dtype=torch.float32) + + +@pytest.mark.parametrize("device", ["cpu"]) +def test_vmap_branch(device): + A = _dense_graph().to(device) + X = _batched_features(device) + Y = sparse_mm(A, X) # calls vmap-aware branch + assert Y.shape == X.shape + + +@pytest.mark.skipif(not HAVE_SCIPY, reason="SciPy not installed") +def test_scipy_path(): + import numpy as np + import scipy.sparse as sp + + coo = sp.coo_matrix(([1, 1], ([0, 1], [1, 0])), shape=(2, 2)) + A = to_coalesced_coo(coo) + assert A.is_sparse and A.is_coalesced() + + +@pytest.mark.skipif(not HAVE_TORCH_SPARSE, reason="torch_sparse not installed") +def test_torch_sparse_path(): + import torch_sparse as tsp + + row = torch.tensor([0, 1]) + col = torch.tensor([1, 0]) + val = torch.ones(2) + A_ts = tsp.SparseTensor(row=row, col=col, value=val, sparse_sizes=(2, 2)) + A = to_coalesced_coo(A_ts) + X = torch.randn(2, 3) + Y = sparse_mm(A, X) + assert Y.shape == (2, 3) diff --git a/tests/unit/sparse/test_mm_3d.py b/tests/unit/sparse/test_mm_3d.py new file mode 100644 index 000000000..08eddc604 --- /dev/null +++ b/tests/unit/sparse/test_mm_3d.py @@ -0,0 +1,16 @@ +import torch +from torchjd.sparse import sparse_mm + +def test_forward_backward_3d(): + # sparse 
2×2 matrix + A = torch.sparse_coo_tensor([[0, 1], [1, 0]], [1.0, 1.0]).coalesce() + + # 3-D dense tensor (B=3, N=2, d=4) + X = torch.randn(3, 2, 4, requires_grad=True) + + Y = sparse_mm(A, X) # exercises 3-D forward branch + loss = Y.sum() + loss.backward() # exercises 3-D backward branch + + # Gradient should be ones because A.T @ 1 = [1,1] → broadcast + assert torch.allclose(X.grad, torch.ones_like(X), atol=1e-6) diff --git a/tests/unit/sparse/test_mm_sequential.py b/tests/unit/sparse/test_mm_sequential.py new file mode 100644 index 000000000..32da000ee --- /dev/null +++ b/tests/unit/sparse/test_mm_sequential.py @@ -0,0 +1,19 @@ +import torch +from torchjd._autojac import backward +from torchjd.aggregation import UPGrad +from torchjd.sparse import sparse_mm + +def test_sequential_backward(): + A = torch.sparse_coo_tensor([[0, 1], [1, 0]], [1.0, 1.0]).coalesce() + p = torch.tensor([1.0, 2.0], requires_grad=True) + + # Make y1 require A@p, y2 a simple L2 term + y1 = sparse_mm(A, p.unsqueeze(1)).sum() # shape (2,1) → scalar + y2 = (p**2).sum() + + # Force sequential JD (no vmap) to touch the else-branch in backward() + backward([y1, y2], UPGrad(), parallel_chunk_size=1) + + # Gradient shape & basic sanity check + assert p.grad.shape == p.shape + assert torch.isfinite(p.grad).all() diff --git a/tests/unit/sparse/test_mm_single.py b/tests/unit/sparse/test_mm_single.py new file mode 100644 index 000000000..9e2628650 --- /dev/null +++ b/tests/unit/sparse/test_mm_single.py @@ -0,0 +1,11 @@ +import torch +from torchjd.sparse import sparse_mm + +def test_single_forward_backward(): + A = torch.sparse_coo_tensor([[0,1],[1,0]], [1.,1.]).coalesce() + X = torch.randn(2, 5, requires_grad=True) + Y = sparse_mm(A, X) # (2,5) + loss = Y.sum() + loss.backward() + # gradient should equal A.T @ 1 = [1,1] + assert torch.allclose(X.grad, torch.ones_like(X)) diff --git a/tests/unit/sparse/test_mm_vmap.py b/tests/unit/sparse/test_mm_vmap.py new file mode 100644 index 000000000..d0441f07c --- /dev/null +++ b/tests/unit/sparse/test_mm_vmap.py @@ -0,0 +1,23 @@ +import torch +from torch.func import vmap +from torchjd.sparse import sparse_mm + +def test_batched_vmap_forward_backward(): + """ + Touch the custom vmap rule in _SparseMatMul to push per-file coverage + above the 90 % guideline. 
+ """ + A = torch.sparse_coo_tensor([[0, 1], [1, 0]], [1., 1.]).coalesce() + B, N, d = 4, 2, 3 + X = torch.randn(B, N, d, requires_grad=True) + + # vmap over the first dim (B) so SparseMatMul.vmap executes + def _single(inp): + return sparse_mm(A, inp).sum() + + loss = vmap(_single)(X).sum() + loss.backward() + + # Analytic gradient: A.T @ 1 = [1,1] broadcast to (B,N,d) + expected = torch.ones_like(X) + assert torch.allclose(X.grad, expected, atol=1e-6) diff --git a/tests/unit/sparse/test_patch.py b/tests/unit/sparse/test_patch.py new file mode 100644 index 000000000..2551ef1a1 --- /dev/null +++ b/tests/unit/sparse/test_patch.py @@ -0,0 +1,10 @@ +import torch +from torchjd.sparse._patch import enable_seamless_sparse + +def test_monkey_patch_matmul(): + enable_seamless_sparse() # idempotent + A = torch.sparse_coo_tensor([[0, 1], [1, 0]], [1.0, 1.0]).coalesce() + X = torch.randn(2, 3) + Y1 = A @ X # should hit sparse_mm via patched __matmul__ + Y2 = torch.tensor([[0., 0., 0.], [0., 0., 0.]]) # placeholder + assert torch.allclose(Y1.sum(), (A.to_dense() @ X).sum()) diff --git a/tests/unit/sparse/test_patch_idempotent.py b/tests/unit/sparse/test_patch_idempotent.py new file mode 100644 index 000000000..31de28075 --- /dev/null +++ b/tests/unit/sparse/test_patch_idempotent.py @@ -0,0 +1,5 @@ +from torchjd.sparse._patch import enable_seamless_sparse + +def test_enable_patch_idempotent(): + enable_seamless_sparse() # first call patches + enable_seamless_sparse() # second call should be a no-op diff --git a/tests/unit/sparse/test_patch_import.py b/tests/unit/sparse/test_patch_import.py new file mode 100644 index 000000000..e9ea0a66d --- /dev/null +++ b/tests/unit/sparse/test_patch_import.py @@ -0,0 +1,38 @@ +import importlib, sys, types +from contextlib import contextmanager + +@contextmanager +def fake_torch_sparse(): + """ + Context manager that injects a *minimal* torch_sparse stub. + The Dummy.SparseTensor *must* expose a ``matmul`` attribute because + enable_seamless_sparse() tries to save and patch it. + """ + mod = types.ModuleType("torch_sparse") + + class Dummy: # noqa: D401 + # placeholder matmul so _patch can grab the attribute + def matmul(self, dense): + raise NotImplementedError + + mod.SparseTensor = Dummy # type: ignore + sys.modules["torch_sparse"] = mod + try: + yield + finally: + sys.modules.pop("torch_sparse", None) + + +def test_patch_without_torch_sparse(monkeypatch): + monkeypatch.setitem(sys.modules, "torch_sparse", None) + from importlib import reload + import torchjd.sparse._patch as p + reload(p) # re-import to trigger patch + assert p.torch_sparse is None # slow fallback branch hit + +def test_patch_with_dummy_torch_sparse(monkeypatch): + with fake_torch_sparse(): + from importlib import reload + import torchjd.sparse._patch as p + reload(p) + assert p.torch_sparse is not None # optional branch hit diff --git a/tests/unit/sparse/test_patch_torch_sparse_branch.py b/tests/unit/sparse/test_patch_torch_sparse_branch.py new file mode 100644 index 000000000..87c959645 --- /dev/null +++ b/tests/unit/sparse/test_patch_torch_sparse_branch.py @@ -0,0 +1,43 @@ +import importlib, sys, types +from importlib import reload + +def _make_dummy_torch_sparse(): + """ + Return a minimal torch_sparse stub: + + * SparseTensor.matmul – so _patch can save & wrap it. + * SparseTensor.to_torch_sparse_coo_tensor – so _utils branch works. 
+ """ + dummy_mod = types.ModuleType("torch_sparse") + + class DummyTensor: # noqa: D401 + def matmul(self, dense): + raise NotImplementedError + + def to_torch_sparse_coo_tensor(self): + import torch + return torch.sparse_coo_tensor([[0], [0]], [1.0], (1, 1)) + + dummy_mod.SparseTensor = DummyTensor # type: ignore[attr-defined] + return dummy_mod + + +def test_full_torch_sparse_branch(monkeypatch): + # Inject fresh stub + monkeypatch.setitem(sys.modules, "torch_sparse", _make_dummy_torch_sparse()) + + # Force the patch module to re-evaluate from scratch + import torchjd.sparse._patch as p # noqa: E402 + + # Remove earlier sentinel attributes so enable_seamless_sparse() re-patches + import torch + for attr in ("_orig_mm",): + if hasattr(torch.sparse, attr): + delattr(torch.sparse, attr) # type: ignore[attr-defined] + + # Run patch + reload(p) + p.enable_seamless_sparse() + + # Optional branch should have set _orig_matmul + assert hasattr(p.torch_sparse.SparseTensor, "_orig_matmul") diff --git a/tests/unit/sparse/test_patch_warn_branch.py b/tests/unit/sparse/test_patch_warn_branch.py new file mode 100644 index 000000000..67b566776 --- /dev/null +++ b/tests/unit/sparse/test_patch_warn_branch.py @@ -0,0 +1,25 @@ +""" +Covers the branch in _patch.enable_seamless_sparse() that emits a warning +when *no* ``torch_sparse`` package is available. +""" + +import importlib, sys, types, warnings +import torch + + +def test_warn_branch(monkeypatch): + monkeypatch.setitem(sys.modules, "torch_sparse", None) + + if hasattr(torch.sparse, "_orig_mm"): + delattr(torch.sparse, "_orig_mm") # type: ignore[attr-defined] + + import torchjd.sparse._patch as p # noqa: E402 + p = importlib.reload(p) + + with warnings.catch_warnings(record=True) as rec: + warnings.simplefilter("always") + p.enable_seamless_sparse() # <- emits RuntimeWarning branch + + assert any( + "SpSpMM will use slow fallback" in str(w.message) for w in rec + ) diff --git a/tests/unit/sparse/test_sparse_mm_wrapper.py b/tests/unit/sparse/test_sparse_mm_wrapper.py new file mode 100644 index 000000000..204f12cf7 --- /dev/null +++ b/tests/unit/sparse/test_sparse_mm_wrapper.py @@ -0,0 +1,11 @@ +import torch +from torchjd.sparse._patch import enable_seamless_sparse + +def test_torch_sparse_mm_wrapper(): + enable_seamless_sparse() # idempotent + A = torch.sparse_coo_tensor([[0, 1], [1, 0]], [1., 1.]).coalesce() + X = torch.randn(2, 3) + + out = torch.sparse.mm(A, X) # routed through wrapper + ref = A.to_dense() @ X + assert torch.allclose(out, ref, atol=1e-6) diff --git a/tests/unit/sparse/test_utils_scipy.py b/tests/unit/sparse/test_utils_scipy.py new file mode 100644 index 000000000..9a70fe0e4 --- /dev/null +++ b/tests/unit/sparse/test_utils_scipy.py @@ -0,0 +1,14 @@ +import importlib +import pytest +import numpy as np + +scipy = pytest.importorskip("scipy") # skip if SciPy not available +from torchjd.sparse._utils import to_coalesced_coo + +def test_to_coalesced_coo_from_scipy(): + sp = importlib.import_module("scipy.sparse") + # 2×2 off-diagonal ones + coo = sp.coo_matrix((np.ones(2), ([0, 1], [1, 0])), shape=(2, 2)) + tsr = to_coalesced_coo(coo) # exercises SciPy branch + dense = tsr.to_dense() + assert dense[0, 1] == dense[1, 0] == 1 and dense.sum() == 2 diff --git a/tests/unit/sparse/test_utils_torch_sparse.py b/tests/unit/sparse/test_utils_torch_sparse.py new file mode 100644 index 000000000..a4cb89ed1 --- /dev/null +++ b/tests/unit/sparse/test_utils_torch_sparse.py @@ -0,0 +1,27 @@ +import importlib, sys, types, torch + +def 
test_to_coalesced_coo_torch_sparse(monkeypatch): + dummy = types.ModuleType("torch_sparse") + + class DummyTensor: # noqa: D401 + def __init__(self): + self.row = torch.tensor([0]) + self.col = torch.tensor([0]) + self.value = torch.tensor([1.0]) + + def to_torch_sparse_coo_tensor(self): + return torch.sparse_coo_tensor( + torch.stack([self.row, self.col]), self.value, (1, 1) + ) + + def matmul(self, other): + raise NotImplementedError + + dummy.SparseTensor = DummyTensor # type: ignore[attr-defined] + monkeypatch.setitem(sys.modules, "torch_sparse", dummy) + + utils = importlib.reload(importlib.import_module("torchjd.sparse._utils")) + to_coalesced_coo = utils.to_coalesced_coo + + tsr = to_coalesced_coo(DummyTensor()) + assert tsr.is_sparse and tsr._nnz() == 1 From a0d68e397c153ba891b61840d28ef825ee37fe60 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 12 Jun 2025 03:53:51 +0000 Subject: [PATCH 3/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- docs/source/examples/sparse.rst | 2 +- src/torchjd/_autojac/__init__.py | 3 ++- src/torchjd/aggregation/__init__.py | 1 - src/torchjd/sparse/__init__.py | 2 +- src/torchjd/sparse/_autograd.py | 12 ++++++------ src/torchjd/sparse/_patch.py | 11 +++++------ src/torchjd/sparse/_utils.py | 2 +- tests/unit/sparse/test_mm.py | 10 +++++++--- tests/unit/sparse/test_mm_3d.py | 6 ++++-- tests/unit/sparse/test_mm_sequential.py | 2 ++ tests/unit/sparse/test_mm_single.py | 6 ++++-- tests/unit/sparse/test_mm_vmap.py | 4 +++- tests/unit/sparse/test_patch.py | 6 ++++-- tests/unit/sparse/test_patch_idempotent.py | 5 +++-- tests/unit/sparse/test_patch_import.py | 16 ++++++++++++---- .../sparse/test_patch_torch_sparse_branch.py | 11 ++++++++--- tests/unit/sparse/test_patch_warn_branch.py | 11 +++++++---- tests/unit/sparse/test_sparse_mm_wrapper.py | 8 +++++--- tests/unit/sparse/test_utils_scipy.py | 4 +++- tests/unit/sparse/test_utils_torch_sparse.py | 11 +++++++---- 20 files changed, 85 insertions(+), 48 deletions(-) diff --git a/docs/source/examples/sparse.rst b/docs/source/examples/sparse.rst index a4cbd5395..53d47d4e7 100644 --- a/docs/source/examples/sparse.rst +++ b/docs/source/examples/sparse.rst @@ -2,7 +2,7 @@ Quick example ============================== TorchJD now offers helpers that make working with sparse adjacency matrices -transparent. +transparent. The key entry-point is :pyfunc:`torchjd.sparse.sparse_mm`, a vmap-aware autograd function that replaces the usual ``torch.sparse.mm`` inside Jacobian Descent pipelines. 
diff --git a/src/torchjd/_autojac/__init__.py b/src/torchjd/_autojac/__init__.py index 06ad73b3c..be1b1d9d7 100644 --- a/src/torchjd/_autojac/__init__.py +++ b/src/torchjd/_autojac/__init__.py @@ -1,3 +1,4 @@ +from torchjd.sparse import sparse_mm + from ._backward import backward from ._mtl_backward import mtl_backward -from torchjd.sparse import sparse_mm diff --git a/src/torchjd/aggregation/__init__.py b/src/torchjd/aggregation/__init__.py index 3233bfe6d..cc0f97b39 100644 --- a/src/torchjd/aggregation/__init__.py +++ b/src/torchjd/aggregation/__init__.py @@ -19,7 +19,6 @@ from ._sum import Sum from ._trimmed_mean import TrimmedMean from ._upgrad import UPGrad - from ._utils.check_dependencies import ( OptionalDepsNotInstalledError as _OptionalDepsNotInstalledError, ) diff --git a/src/torchjd/sparse/__init__.py b/src/torchjd/sparse/__init__.py index a61f2e055..f1648c928 100644 --- a/src/torchjd/sparse/__init__.py +++ b/src/torchjd/sparse/__init__.py @@ -14,6 +14,6 @@ __all__ = ["sparse_mm"] -# feature flag +# feature flag if os.getenv("TORCHJD_DISABLE_SPARSE", "0") != "1": enable_seamless_sparse() diff --git a/src/torchjd/sparse/_autograd.py b/src/torchjd/sparse/_autograd.py index 51ac0bcb4..76d22108f 100644 --- a/src/torchjd/sparse/_autograd.py +++ b/src/torchjd/sparse/_autograd.py @@ -21,8 +21,8 @@ def forward(A_like: torch.Tensor, X: torch.Tensor) -> torch.Tensor: # noqa: D40 if X.dim() == 3: # (B, N, d) B, N, d = X.shape X2d = X.reshape(B * N, d).view(N, B * d) - Y2d = _orig_sparse_mm(A, X2d) # pragma: no cover - return Y2d.view(N, B, d).permute(1, 0, 2) # pragma: no cover + Y2d = _orig_sparse_mm(A, X2d) # pragma: no cover + return Y2d.view(N, B, d).permute(1, 0, 2) # pragma: no cover return _orig_sparse_mm(A, X) @@ -47,16 +47,16 @@ def backward(ctx, dY: torch.Tensor) -> Tuple[None, torch.Tensor]: @staticmethod def vmap(info, in_dims, A_unbatched, X_batched): # noqa: D401 - A = A_unbatched # shared - X = X_batched # (B, N, d) + A = A_unbatched # shared + X = X_batched # (B, N, d) B, N, d = X.shape X2d = X.reshape(B * N, d).view(N, B * d) Y2d = _orig_sparse_mm(A, X2d) Y = Y2d.view(N, B, d).permute(1, 0, 2) - return Y, 0 # output & out-dims + return Y, 0 # output & out-dims def sparse_mm(A_like: torch.Tensor, X: torch.Tensor) -> torch.Tensor: """Return ``A @ X`` through the vmap-safe sparse Function.""" - return _SparseMatMul.apply(A_like, X) \ No newline at end of file + return _SparseMatMul.apply(A_like, X) diff --git a/src/torchjd/sparse/_patch.py b/src/torchjd/sparse/_patch.py index 493c57438..6993dac11 100644 --- a/src/torchjd/sparse/_patch.py +++ b/src/torchjd/sparse/_patch.py @@ -16,7 +16,6 @@ from ._autograd import sparse_mm - # The wheel might exist yet be ABI-incompatible with the current # PyTorch, which raises *OSError* at import-time. 
@@ -29,6 +28,7 @@ # Helpers def _wrap_mm(orig_fn: Callable, wrapper: Callable) -> Callable: """Return a patched ``torch.sparse.mm`` that defers to *wrapper*.""" + def patched(A, X): # noqa: D401 if isinstance(A, torch.Tensor) and A.is_sparse and X.dim() >= 2: return wrapper(A, X) @@ -52,11 +52,9 @@ def enable_seamless_sparse() -> None: # torch.sparse.mm if getattr(torch.sparse, "_orig_mm", None) is None: torch.sparse._orig_mm = torch.sparse.mm # type: ignore[attr-defined] - torch.sparse.mm = _wrap_mm( # type: ignore[attr-defined] - torch.sparse._orig_mm, sparse_mm - ) + torch.sparse.mm = _wrap_mm(torch.sparse._orig_mm, sparse_mm) # type: ignore[attr-defined] - # tensor @ dense + # tensor @ dense if getattr(torch.Tensor, "_orig_matmul", None) is None: torch.Tensor._orig_matmul = torch.Tensor.__matmul__ # type: ignore[attr-defined] # noqa: E501 torch.Tensor.__matmul__ = _wrap_tensor_matmul( @@ -69,10 +67,11 @@ def enable_seamless_sparse() -> None: "torch_sparse not found: SpSpMM will use slow fallback.", RuntimeWarning, stacklevel=2, - ) # pragma: no cover + ) # pragma: no cover return if not hasattr(torch_sparse.SparseTensor, "_orig_matmul"): + def _sparse_tensor_matmul(self, dense): # noqa: D401 return sparse_mm(self, dense) diff --git a/src/torchjd/sparse/_utils.py b/src/torchjd/sparse/_utils.py index a661aae43..795bc7fc6 100644 --- a/src/torchjd/sparse/_utils.py +++ b/src/torchjd/sparse/_utils.py @@ -34,4 +34,4 @@ def to_coalesced_coo(x: Any) -> torch.Tensor: except ModuleNotFoundError: # pragma: no cover pass - raise TypeError(f"Unsupported sparse type: {type(x)}") # pragma: no cover + raise TypeError(f"Unsupported sparse type: {type(x)}") # pragma: no cover diff --git a/tests/unit/sparse/test_mm.py b/tests/unit/sparse/test_mm.py index c24efd8ca..c6143e3d8 100644 --- a/tests/unit/sparse/test_mm.py +++ b/tests/unit/sparse/test_mm.py @@ -1,10 +1,13 @@ -import torch import pytest +import torch + from torchjd.sparse import sparse_mm from torchjd.sparse._utils import to_coalesced_coo try: - import importlib, types + import importlib + import types + torch_sparse = importlib.import_module("torch_sparse") # noqa: E402 HAVE_TORCH_SPARSE = isinstance(torch_sparse, types.ModuleType) except (ModuleNotFoundError, OSError): @@ -13,6 +16,7 @@ try: import scipy.sparse as sp + HAVE_SCIPY = True except ModuleNotFoundError: HAVE_SCIPY = False @@ -32,7 +36,7 @@ def _batched_features(device): def test_vmap_branch(device): A = _dense_graph().to(device) X = _batched_features(device) - Y = sparse_mm(A, X) # calls vmap-aware branch + Y = sparse_mm(A, X) # calls vmap-aware branch assert Y.shape == X.shape diff --git a/tests/unit/sparse/test_mm_3d.py b/tests/unit/sparse/test_mm_3d.py index 08eddc604..c3c73a38e 100644 --- a/tests/unit/sparse/test_mm_3d.py +++ b/tests/unit/sparse/test_mm_3d.py @@ -1,6 +1,8 @@ import torch + from torchjd.sparse import sparse_mm + def test_forward_backward_3d(): # sparse 2×2 matrix A = torch.sparse_coo_tensor([[0, 1], [1, 0]], [1.0, 1.0]).coalesce() @@ -8,9 +10,9 @@ def test_forward_backward_3d(): # 3-D dense tensor (B=3, N=2, d=4) X = torch.randn(3, 2, 4, requires_grad=True) - Y = sparse_mm(A, X) # exercises 3-D forward branch + Y = sparse_mm(A, X) # exercises 3-D forward branch loss = Y.sum() - loss.backward() # exercises 3-D backward branch + loss.backward() # exercises 3-D backward branch # Gradient should be ones because A.T @ 1 = [1,1] → broadcast assert torch.allclose(X.grad, torch.ones_like(X), atol=1e-6) diff --git a/tests/unit/sparse/test_mm_sequential.py 
b/tests/unit/sparse/test_mm_sequential.py index 32da000ee..e78c779d3 100644 --- a/tests/unit/sparse/test_mm_sequential.py +++ b/tests/unit/sparse/test_mm_sequential.py @@ -1,8 +1,10 @@ import torch + from torchjd._autojac import backward from torchjd.aggregation import UPGrad from torchjd.sparse import sparse_mm + def test_sequential_backward(): A = torch.sparse_coo_tensor([[0, 1], [1, 0]], [1.0, 1.0]).coalesce() p = torch.tensor([1.0, 2.0], requires_grad=True) diff --git a/tests/unit/sparse/test_mm_single.py b/tests/unit/sparse/test_mm_single.py index 9e2628650..79bd9e577 100644 --- a/tests/unit/sparse/test_mm_single.py +++ b/tests/unit/sparse/test_mm_single.py @@ -1,10 +1,12 @@ import torch + from torchjd.sparse import sparse_mm + def test_single_forward_backward(): - A = torch.sparse_coo_tensor([[0,1],[1,0]], [1.,1.]).coalesce() + A = torch.sparse_coo_tensor([[0, 1], [1, 0]], [1.0, 1.0]).coalesce() X = torch.randn(2, 5, requires_grad=True) - Y = sparse_mm(A, X) # (2,5) + Y = sparse_mm(A, X) # (2,5) loss = Y.sum() loss.backward() # gradient should equal A.T @ 1 = [1,1] diff --git a/tests/unit/sparse/test_mm_vmap.py b/tests/unit/sparse/test_mm_vmap.py index d0441f07c..f57e52b91 100644 --- a/tests/unit/sparse/test_mm_vmap.py +++ b/tests/unit/sparse/test_mm_vmap.py @@ -1,13 +1,15 @@ import torch from torch.func import vmap + from torchjd.sparse import sparse_mm + def test_batched_vmap_forward_backward(): """ Touch the custom vmap rule in _SparseMatMul to push per-file coverage above the 90 % guideline. """ - A = torch.sparse_coo_tensor([[0, 1], [1, 0]], [1., 1.]).coalesce() + A = torch.sparse_coo_tensor([[0, 1], [1, 0]], [1.0, 1.0]).coalesce() B, N, d = 4, 2, 3 X = torch.randn(B, N, d, requires_grad=True) diff --git a/tests/unit/sparse/test_patch.py b/tests/unit/sparse/test_patch.py index 2551ef1a1..0ee6c0269 100644 --- a/tests/unit/sparse/test_patch.py +++ b/tests/unit/sparse/test_patch.py @@ -1,10 +1,12 @@ import torch + from torchjd.sparse._patch import enable_seamless_sparse + def test_monkey_patch_matmul(): enable_seamless_sparse() # idempotent A = torch.sparse_coo_tensor([[0, 1], [1, 0]], [1.0, 1.0]).coalesce() X = torch.randn(2, 3) - Y1 = A @ X # should hit sparse_mm via patched __matmul__ - Y2 = torch.tensor([[0., 0., 0.], [0., 0., 0.]]) # placeholder + Y1 = A @ X # should hit sparse_mm via patched __matmul__ + Y2 = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]) # placeholder assert torch.allclose(Y1.sum(), (A.to_dense() @ X).sum()) diff --git a/tests/unit/sparse/test_patch_idempotent.py b/tests/unit/sparse/test_patch_idempotent.py index 31de28075..32b0b36c8 100644 --- a/tests/unit/sparse/test_patch_idempotent.py +++ b/tests/unit/sparse/test_patch_idempotent.py @@ -1,5 +1,6 @@ from torchjd.sparse._patch import enable_seamless_sparse + def test_enable_patch_idempotent(): - enable_seamless_sparse() # first call patches - enable_seamless_sparse() # second call should be a no-op + enable_seamless_sparse() # first call patches + enable_seamless_sparse() # second call should be a no-op diff --git a/tests/unit/sparse/test_patch_import.py b/tests/unit/sparse/test_patch_import.py index e9ea0a66d..2b687d311 100644 --- a/tests/unit/sparse/test_patch_import.py +++ b/tests/unit/sparse/test_patch_import.py @@ -1,6 +1,9 @@ -import importlib, sys, types +import importlib +import sys +import types from contextlib import contextmanager + @contextmanager def fake_torch_sparse(): """ @@ -15,7 +18,7 @@ class Dummy: # noqa: D401 def matmul(self, dense): raise NotImplementedError - mod.SparseTensor = 
Dummy # type: ignore + mod.SparseTensor = Dummy # type: ignore sys.modules["torch_sparse"] = mod try: yield @@ -26,13 +29,18 @@ def matmul(self, dense): def test_patch_without_torch_sparse(monkeypatch): monkeypatch.setitem(sys.modules, "torch_sparse", None) from importlib import reload + import torchjd.sparse._patch as p - reload(p) # re-import to trigger patch + + reload(p) # re-import to trigger patch assert p.torch_sparse is None # slow fallback branch hit + def test_patch_with_dummy_torch_sparse(monkeypatch): with fake_torch_sparse(): from importlib import reload + import torchjd.sparse._patch as p + reload(p) - assert p.torch_sparse is not None # optional branch hit + assert p.torch_sparse is not None # optional branch hit diff --git a/tests/unit/sparse/test_patch_torch_sparse_branch.py b/tests/unit/sparse/test_patch_torch_sparse_branch.py index 87c959645..5687b0ba5 100644 --- a/tests/unit/sparse/test_patch_torch_sparse_branch.py +++ b/tests/unit/sparse/test_patch_torch_sparse_branch.py @@ -1,6 +1,9 @@ -import importlib, sys, types +import importlib +import sys +import types from importlib import reload + def _make_dummy_torch_sparse(): """ Return a minimal torch_sparse stub: @@ -16,6 +19,7 @@ def matmul(self, dense): def to_torch_sparse_coo_tensor(self): import torch + return torch.sparse_coo_tensor([[0], [0]], [1.0], (1, 1)) dummy_mod.SparseTensor = DummyTensor # type: ignore[attr-defined] @@ -27,10 +31,11 @@ def test_full_torch_sparse_branch(monkeypatch): monkeypatch.setitem(sys.modules, "torch_sparse", _make_dummy_torch_sparse()) # Force the patch module to re-evaluate from scratch - import torchjd.sparse._patch as p # noqa: E402 - # Remove earlier sentinel attributes so enable_seamless_sparse() re-patches import torch + + import torchjd.sparse._patch as p # noqa: E402 + for attr in ("_orig_mm",): if hasattr(torch.sparse, attr): delattr(torch.sparse, attr) # type: ignore[attr-defined] diff --git a/tests/unit/sparse/test_patch_warn_branch.py b/tests/unit/sparse/test_patch_warn_branch.py index 67b566776..303f66abb 100644 --- a/tests/unit/sparse/test_patch_warn_branch.py +++ b/tests/unit/sparse/test_patch_warn_branch.py @@ -3,7 +3,11 @@ when *no* ``torch_sparse`` package is available. 
""" -import importlib, sys, types, warnings +import importlib +import sys +import types +import warnings + import torch @@ -14,12 +18,11 @@ def test_warn_branch(monkeypatch): delattr(torch.sparse, "_orig_mm") # type: ignore[attr-defined] import torchjd.sparse._patch as p # noqa: E402 + p = importlib.reload(p) with warnings.catch_warnings(record=True) as rec: warnings.simplefilter("always") p.enable_seamless_sparse() # <- emits RuntimeWarning branch - assert any( - "SpSpMM will use slow fallback" in str(w.message) for w in rec - ) + assert any("SpSpMM will use slow fallback" in str(w.message) for w in rec) diff --git a/tests/unit/sparse/test_sparse_mm_wrapper.py b/tests/unit/sparse/test_sparse_mm_wrapper.py index 204f12cf7..18a938735 100644 --- a/tests/unit/sparse/test_sparse_mm_wrapper.py +++ b/tests/unit/sparse/test_sparse_mm_wrapper.py @@ -1,11 +1,13 @@ import torch + from torchjd.sparse._patch import enable_seamless_sparse + def test_torch_sparse_mm_wrapper(): - enable_seamless_sparse() # idempotent - A = torch.sparse_coo_tensor([[0, 1], [1, 0]], [1., 1.]).coalesce() + enable_seamless_sparse() # idempotent + A = torch.sparse_coo_tensor([[0, 1], [1, 0]], [1.0, 1.0]).coalesce() X = torch.randn(2, 3) - out = torch.sparse.mm(A, X) # routed through wrapper + out = torch.sparse.mm(A, X) # routed through wrapper ref = A.to_dense() @ X assert torch.allclose(out, ref, atol=1e-6) diff --git a/tests/unit/sparse/test_utils_scipy.py b/tests/unit/sparse/test_utils_scipy.py index 9a70fe0e4..d3bd96855 100644 --- a/tests/unit/sparse/test_utils_scipy.py +++ b/tests/unit/sparse/test_utils_scipy.py @@ -1,10 +1,12 @@ import importlib -import pytest + import numpy as np +import pytest scipy = pytest.importorskip("scipy") # skip if SciPy not available from torchjd.sparse._utils import to_coalesced_coo + def test_to_coalesced_coo_from_scipy(): sp = importlib.import_module("scipy.sparse") # 2×2 off-diagonal ones diff --git a/tests/unit/sparse/test_utils_torch_sparse.py b/tests/unit/sparse/test_utils_torch_sparse.py index a4cb89ed1..3e4fa043a 100644 --- a/tests/unit/sparse/test_utils_torch_sparse.py +++ b/tests/unit/sparse/test_utils_torch_sparse.py @@ -1,4 +1,9 @@ -import importlib, sys, types, torch +import importlib +import sys +import types + +import torch + def test_to_coalesced_coo_torch_sparse(monkeypatch): dummy = types.ModuleType("torch_sparse") @@ -10,9 +15,7 @@ def __init__(self): self.value = torch.tensor([1.0]) def to_torch_sparse_coo_tensor(self): - return torch.sparse_coo_tensor( - torch.stack([self.row, self.col]), self.value, (1, 1) - ) + return torch.sparse_coo_tensor(torch.stack([self.row, self.col]), self.value, (1, 1)) def matmul(self, other): raise NotImplementedError