25 commits
ae6be7d
WIP: add gramian-based jac_to_grad
ValerianRey Jan 21, 2026
8bdf512
Update changelog
ValerianRey Jan 21, 2026
aaf2544
Use deque to free memory asap
ValerianRey Jan 23, 2026
64b06ad
Merge branch 'main' into optimize_jac_to_grad
ValerianRey Jan 23, 2026
745f707
Merge branch 'main' into optimize_jac_to_grad
ValerianRey Jan 28, 2026
5eb77f9
Merge branch 'main' into optimize_jac_to_grad
ValerianRey Jan 28, 2026
8f65caa
Use gramian_weighting in jac_to_grad
ValerianRey Jan 28, 2026
6fe15a4
Merge branch 'main' into optimize_jac_to_grad
ValerianRey Jan 28, 2026
d5cb5c2
Merge branch 'main' into optimize_jac_to_grad
ValerianRey Jan 29, 2026
f986950
Only optimize when no forward hooks
ValerianRey Jan 29, 2026
4cf5cbb
Make _gramian_based take aggregator instead of weighting
ValerianRey Jan 29, 2026
add549c
Add _can_skip_jacobian_combination helper function
ValerianRey Jan 29, 2026
453971a
Add test_can_skip_jacobian_combination
ValerianRey Jan 29, 2026
9d4c41c
Optimize compute_gramian for when contracted_dims=-1
ValerianRey Jan 29, 2026
48cd70b
Merge branch 'main' into optimize_jac_to_grad
ValerianRey Jan 30, 2026
8f2660d
Use TypeGuard in _can_skip_jacobian_combination
ValerianRey Jan 30, 2026
fc9bbcf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 30, 2026
3f9a6d1
Merge branch 'main' into optimize_jac_to_grad
ValerianRey Feb 1, 2026
9d9cbf0
Merge branch 'main' into optimize_jac_to_grad
ValerianRey Feb 4, 2026
b5ca226
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 4, 2026
0baa914
Merge branch 'main' into optimize_jac_to_grad
ValerianRey Feb 5, 2026
2ed1d7c
Merge branch 'main' into optimize_jac_to_grad
ValerianRey Feb 13, 2026
86be778
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 13, 2026
4ace19e
Merge branch 'main' into optimize_jac_to_grad
ValerianRey Feb 13, 2026
2a84bef
Add ruff if-else squeezing
ValerianRey Feb 13, 2026
6 changes: 2 additions & 4 deletions CHANGELOG.md
@@ -46,10 +46,8 @@ changelog does not include internal changes that do not affect the user.
jac_to_grad(shared_module.parameters(), aggregator)
```

- Removed an unnecessary memory duplication. This should significantly improve the memory efficiency
of `autojac`.
- Removed an unnecessary internal cloning of gradient. This should slightly improve the memory
efficiency of `autojac`.
- Removed several unnecessary memory duplications. This should significantly improve the memory
efficiency and speed of `autojac`.
- Increased the lower bounds of the torch (from 2.0.0 to 2.3.0) and numpy (from 1.21.0
to 1.21.2) dependencies to reflect what really works with torchjd. We now also run torchjd's tests
with the dependency lower-bounds specified in `pyproject.toml`, so we should now always accurately
19 changes: 14 additions & 5 deletions src/torchjd/_linalg/_gramian.py
@@ -35,11 +35,20 @@ def compute_gramian(t: Tensor, contracted_dims: int = -1) -> PSDTensor:
first dimension).
"""

contracted_dims = contracted_dims if contracted_dims >= 0 else contracted_dims + t.ndim
indices_source = list(range(t.ndim - contracted_dims))
indices_dest = list(range(t.ndim - 1, contracted_dims - 1, -1))
transposed = t.movedim(indices_source, indices_dest)
gramian = torch.tensordot(t, transposed, dims=contracted_dims)
# Optimization: a plain matmul on the flattened tensor is faster than moving dims and using
# tensordot, and this case happens very often, sometimes hundreds of times per jac_to_grad call.
if contracted_dims == -1:
matrix = t.unsqueeze(1) if t.ndim == 1 else t.flatten(start_dim=1)

gramian = matrix @ matrix.T

else:
contracted_dims = contracted_dims if contracted_dims >= 0 else contracted_dims + t.ndim
indices_source = list(range(t.ndim - contracted_dims))
indices_dest = list(range(t.ndim - 1, contracted_dims - 1, -1))
transposed = t.movedim(indices_source, indices_dest)
gramian = torch.tensordot(t, transposed, dims=contracted_dims)

return cast(PSDTensor, gramian)


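The inline comment above claims that, when all but the first dimension are contracted, a plain matmul on the flattened tensor matches the general `movedim` + `tensordot` path. A minimal sanity-check sketch of that equivalence (not part of the PR, using only standard PyTorch; the random tensor shape is illustrative):

```python
import torch

# Example tensor: a Jacobian with 3 rows for a parameter of shape (4, 5).
t = torch.randn(3, 4, 5)

# Fast path from the PR: flatten everything after the first dim, then use a plain matmul.
matrix = t.flatten(start_dim=1)
fast = matrix @ matrix.T

# General path: move dim 0 to the end and contract the remaining t.ndim - 1 dims.
transposed = t.movedim(0, t.ndim - 1)
general = torch.tensordot(t, transposed, dims=t.ndim - 1)

assert torch.allclose(fast, general, atol=1e-6)
print(fast.shape)  # torch.Size([3, 3]): one entry per pair of Jacobian rows
```

For 1-D inputs the fast path unsqueezes to a column vector instead of flattening, so the same `matrix @ matrix.T` logic applies.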
70 changes: 58 additions & 12 deletions src/torchjd/autojac/_jac_to_grad.py
@@ -1,9 +1,13 @@
from collections import deque
from collections.abc import Iterable
from typing import TypeGuard, cast

import torch
from torch import Tensor
from torch import Tensor, nn

from torchjd._linalg import PSDMatrix, compute_gramian
from torchjd.aggregation import Aggregator
from torchjd.aggregation._aggregator_bases import GramianWeightedAggregator

from ._accumulation import TensorWithJac, accumulate_grads, is_tensor_with_jac

@@ -65,31 +69,73 @@ def jac_to_grad(
if len(tensors_) == 0:
return

jacobians = [t.jac for t in tensors_]
jacobians = deque(t.jac for t in tensors_)

if not all(jacobian.shape[0] == jacobians[0].shape[0] for jacobian in jacobians[1:]):
if not all(jacobian.shape[0] == jacobians[0].shape[0] for jacobian in jacobians):
raise ValueError("All Jacobians should have the same number of rows.")

if not retain_jac:
_free_jacs(tensors_)

if _can_skip_jacobian_combination(aggregator):
gradients = _gramian_based(aggregator, jacobians, tensors_)
else:
gradients = _jacobian_based(aggregator, jacobians, tensors_)
accumulate_grads(tensors_, gradients)


def _can_skip_jacobian_combination(aggregator: Aggregator) -> TypeGuard[GramianWeightedAggregator]:
return isinstance(aggregator, GramianWeightedAggregator) and not _has_forward_hook(aggregator)


def _has_forward_hook(module: nn.Module) -> bool:
"""Return whether the module has any forward hook registered."""
return len(module._forward_hooks) > 0 or len(module._forward_pre_hooks) > 0


def _jacobian_based(
aggregator: Aggregator,
jacobians: deque[Tensor],
tensors: list[TensorWithJac],
) -> list[Tensor]:
jacobian_matrix = _unite_jacobians(jacobians)
gradient_vector = aggregator(jacobian_matrix)
gradients = _disunite_gradient(gradient_vector, jacobians, tensors_)
accumulate_grads(tensors_, gradients)
gradients = _disunite_gradient(gradient_vector, tensors)
return gradients


def _unite_jacobians(jacobians: list[Tensor]) -> Tensor:
jacobian_matrices = [jacobian.reshape(jacobian.shape[0], -1) for jacobian in jacobians]
def _gramian_based(
aggregator: GramianWeightedAggregator,
jacobians: deque[Tensor],
tensors: list[TensorWithJac],
) -> list[Tensor]:
weighting = aggregator.gramian_weighting
gramian = _compute_gramian_sum(jacobians)
weights = weighting(gramian)

gradients = list[Tensor]()
while jacobians:
jacobian = jacobians.popleft() # get jacobian + dereference it to free memory asap
gradients.append(torch.tensordot(weights, jacobian, dims=1))

return gradients


def _compute_gramian_sum(jacobians: deque[Tensor]) -> PSDMatrix:
gramian = sum([compute_gramian(matrix) for matrix in jacobians])
return cast(PSDMatrix, gramian)


def _unite_jacobians(jacobians: deque[Tensor]) -> Tensor:
jacobian_matrices = list[Tensor]()
while jacobians:
jacobian = jacobians.popleft() # get jacobian + dereference it to free memory asap
jacobian_matrices.append(jacobian.reshape(jacobian.shape[0], -1))
jacobian_matrix = torch.concat(jacobian_matrices, dim=1)
return jacobian_matrix


def _disunite_gradient(
gradient_vector: Tensor,
jacobians: list[Tensor],
tensors: list[TensorWithJac],
) -> list[Tensor]:
def _disunite_gradient(gradient_vector: Tensor, tensors: list[TensorWithJac]) -> list[Tensor]:
gradient_vectors = gradient_vector.split([t.numel() for t in tensors])
gradients = [g.view(t.shape) for g, t in zip(gradient_vectors, tensors, strict=True)]
return gradients
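The gramian-based path above appears to rely on two identities: the Gramian of the united Jacobian matrix equals the sum of the per-tensor Gramians, and applying the weights block by block matches applying them to the united matrix. A hedged sketch of both (not from the PR; the shapes are illustrative and `w` is a random stand-in for what `aggregator.gramian_weighting` would return):

```python
import torch

# Per-tensor Jacobian matrices: 3 objectives, parameters with 10 and 7 elements.
J1 = torch.randn(3, 10)
J2 = torch.randn(3, 7)
J = torch.cat([J1, J2], dim=1)  # what _unite_jacobians would build

# Identity 1: the Gramian of the united matrix is the sum of per-block Gramians,
# which is why summing compute_gramian over the blocks can stand in for the full Gramian.
assert torch.allclose(J @ J.T, J1 @ J1.T + J2 @ J2.T, atol=1e-6)

# Identity 2: applying the weights block by block gives the same result as applying
# them to the united matrix, so the full Jacobian never needs to be materialized.
w = torch.rand(3)  # stand-in for the output of the aggregator's gramian weighting
blockwise = torch.cat([w @ J1, w @ J2])
assert torch.allclose(w @ J, blockwise, atol=1e-6)
```

Computing the weights once from the summed Gramian and then contracting them with each stored Jacobian is what lets `_gramian_based` pop Jacobians off the deque one at a time instead of concatenating them.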
117 changes: 115 additions & 2 deletions tests/unit/autojac/test_jac_to_grad.py
@@ -2,8 +2,28 @@
from utils.asserts import assert_grad_close, assert_has_jac, assert_has_no_jac
from utils.tensors import tensor_

from torchjd.aggregation import Aggregator, Mean, PCGrad, UPGrad
from torchjd.autojac._jac_to_grad import jac_to_grad
from torchjd.aggregation import (
IMTLG,
MGDA,
Aggregator,
AlignedMTL,
ConFIG,
Constant,
DualProj,
GradDrop,
Krum,
Mean,
PCGrad,
Random,
Sum,
TrimmedMean,
UPGrad,
)
from torchjd.autojac._jac_to_grad import (
_can_skip_jacobian_combination,
_has_forward_hook,
jac_to_grad,
)


@mark.parametrize("aggregator", [Mean(), UPGrad(), PCGrad()])
@@ -101,3 +121,96 @@ def test_jacs_are_freed(retain_jac: bool):
check = assert_has_jac if retain_jac else assert_has_no_jac
check(t1)
check(t2)


def test_has_forward_hook():
"""Tests that _has_forward_hook correctly detects the presence of forward hooks."""

module = UPGrad()

def dummy_forward_hook(_module, _input, _output):
return _output

def dummy_forward_pre_hook(_module, _input):
return _input

def dummy_backward_hook(_module, _grad_input, _grad_output):
return _grad_input

def dummy_backward_pre_hook(_module, _grad_output):
return _grad_output

# Module with no hooks, or with only backward hooks, should return False
assert not _has_forward_hook(module)
module.register_full_backward_hook(dummy_backward_hook)
assert not _has_forward_hook(module)
module.register_full_backward_pre_hook(dummy_backward_pre_hook)
assert not _has_forward_hook(module)

# Module with forward hook should return True
handle1 = module.register_forward_hook(dummy_forward_hook)
assert _has_forward_hook(module)
handle2 = module.register_forward_hook(dummy_forward_hook)
assert _has_forward_hook(module)
handle1.remove()
assert _has_forward_hook(module)
handle2.remove()
assert not _has_forward_hook(module)

# Module with forward pre-hook should return True
handle3 = module.register_forward_pre_hook(dummy_forward_pre_hook)
assert _has_forward_hook(module)
handle4 = module.register_forward_pre_hook(dummy_forward_pre_hook)
assert _has_forward_hook(module)
handle3.remove()
assert _has_forward_hook(module)
handle4.remove()
assert not _has_forward_hook(module)


_PARAMETRIZATIONS = [
(AlignedMTL(), True),
(DualProj(), True),
(IMTLG(), True),
(Krum(n_byzantine=1), True),
(MGDA(), True),
(PCGrad(), True),
(UPGrad(), True),
(ConFIG(), False),
(Constant(tensor_([0.5, 0.5])), False),
(GradDrop(), False),
(Mean(), False),
(Random(), False),
(Sum(), False),
(TrimmedMean(trim_number=1), False),
]

try:
from torchjd.aggregation import CAGrad

_PARAMETRIZATIONS.append((CAGrad(c=0.5), True))
except ImportError:
pass

try:
from torchjd.aggregation import NashMTL

_PARAMETRIZATIONS.append((NashMTL(n_tasks=2), False))
except ImportError:
pass


@mark.parametrize("aggregator, expected", _PARAMETRIZATIONS)
def test_can_skip_jacobian_combination(aggregator: Aggregator, expected: bool):
"""
Tests that _can_skip_jacobian_combination correctly identifies when optimization can be used.
"""

assert _can_skip_jacobian_combination(aggregator) == expected
handle = aggregator.register_forward_hook(lambda module, input, output: output)
assert not _can_skip_jacobian_combination(aggregator)
handle.remove()
handle = aggregator.register_forward_pre_hook(lambda module, input: input)
assert not _can_skip_jacobian_combination(aggregator)
handle.remove()
assert _can_skip_jacobian_combination(aggregator) == expected