Merged
28 commits
a975e7e
feat(autojac): Add `jac_tensors` to `backward`
PierreQuinton Feb 13, 2026
e68caf9
Add usage example using `jac_tensors` to `backward`.
PierreQuinton Feb 13, 2026
e1ee125
Add tests for `jac_tensors` in test_backward
PierreQuinton Feb 13, 2026
873d56f
Add another usage example to `backward`
PierreQuinton Feb 13, 2026
ff0cb23
Remove TODOs
PierreQuinton Feb 13, 2026
a65a193
Fix usage of methods in `tests.utils.tensors.`
PierreQuinton Feb 13, 2026
7d97fd7
Merge branch 'main' into add-jac-tensors-to-backward
ValerianRey Feb 13, 2026
940c3c3
Merge branch 'main' into add-jac-tensors-to-backward
ValerianRey Feb 13, 2026
271dbbb
Fix second usage example
PierreQuinton Feb 14, 2026
49c0934
Fix third usage example
PierreQuinton Feb 14, 2026
1d82773
Remove outdated comment
PierreQuinton Feb 14, 2026
36fead0
Improve the third example.
PierreQuinton Feb 14, 2026
b91da92
Improve docstring of backward
ValerianRey Feb 15, 2026
2bcfa88
Merge branch 'main' into add-jac-tensors-to-backward
ValerianRey Feb 15, 2026
ca9ee9f
Explain analogy between jac_tensors and grad_tensors
ValerianRey Feb 15, 2026
d9dbea1
Extract create_jac_tensors_dict and create_transform, remove checked_…
ValerianRey Feb 15, 2026
89302d4
Merge branch 'main' into add-jac-tensors-to-backward
ValerianRey Feb 15, 2026
554ad32
Add jac_outputs parameter to jac function
ValerianRey Feb 16, 2026
cdb2a31
Remove create_transform in jac
ValerianRey Feb 16, 2026
175c91d
Re-add the explanation of the number of rows of the jacobian
ValerianRey Feb 16, 2026
fec3a13
Revert "Remove create_transform in jac"
ValerianRey Feb 16, 2026
b92327a
simplify _create_transform in jac
ValerianRey Feb 16, 2026
2d0ebb0
Add early validation for inconsistent Jacobian dimensions
ValerianRey Feb 16, 2026
3ca0938
Add early validation for Jacobian length mismatch
ValerianRey Feb 16, 2026
fb332a4
Add early validation for Jacobian shape mismatch
ValerianRey Feb 16, 2026
4c7fa3c
Add changelog entries for jac_tensors and jac_outputs parameters
ValerianRey Feb 16, 2026
7582d0a
Merge branch 'main' into add-jac-tensors-to-backward
ValerianRey Feb 16, 2026
51ecb4a
Apply claude suggestions
ValerianRey Feb 16, 2026
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -13,6 +13,11 @@ changelog does not include internal changes that do not affect the user.
- Added the function `torchjd.autojac.jac`. It's the same as `torchjd.autojac.backward` except that
it returns the Jacobians as a tuple instead of storing them in the `.jac` fields of the inputs.
Its interface is analogous to that of `torch.autograd.grad`.
- Added a `jac_tensors` parameter to `backward`, allowing the computed Jacobians to be
pre-multiplied by initial Jacobians. This enables multi-step chain-rule computations and is
analogous to the `grad_tensors` parameter in `torch.autograd.backward`.
- Added a `jac_outputs` parameter to `jac`, allowing the computed Jacobians to be pre-multiplied
by initial Jacobians. This is analogous to the `grad_outputs` parameter in `torch.autograd.grad`.
- Added a `scale_mode` parameter to `AlignedMTL` and `AlignedMTLWeighting`, allowing one to choose
between `"min"`, `"median"`, and `"rmse"` scaling.
- Added an attribute `gramian_weighting` to all aggregators that use a gramian-based `Weighting`.
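To make the ``grad_tensors`` analogy described above concrete, here is a minimal usage sketch (not part of the diff; the expected values follow from the ``backward`` semantics documented below). With a single-row ``jac_tensors``, ``backward`` behaves like ``torch.autograd.backward`` with a scalar ``grad_tensors``, except that the result is accumulated in ``.jac`` instead of ``.grad``:

>>> import torch
>>>
>>> from torchjd.autojac import backward
>>>
>>> param = torch.tensor([1., 2.], requires_grad=True)
>>> loss = (param ** 2).sum()  # scalar output, gradient is 2 * param = [2., 4.]
>>>
>>> # A single-row initial Jacobian of value 2. plays the role of
>>> # grad_tensors=torch.tensor(2.) in torch.autograd.backward.
>>> backward(loss, jac_tensors=torch.tensor([2.]))
>>>
>>> param.jac
tensor([[4., 8.]])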
105 changes: 79 additions & 26 deletions src/torchjd/autojac/_backward.py
@@ -3,21 +3,37 @@
from torch import Tensor

from ._transform import AccumulateJac, Diagonalize, Init, Jac, OrderedSet, Transform
from ._utils import as_checked_ordered_set, check_optional_positive_chunk_size, get_leaf_tensors
from ._utils import (
as_checked_ordered_set,
check_consistent_first_dimension,
check_matching_length,
check_matching_shapes,
check_optional_positive_chunk_size,
get_leaf_tensors,
)


def backward(
tensors: Sequence[Tensor] | Tensor,
jac_tensors: Sequence[Tensor] | Tensor | None = None,
inputs: Iterable[Tensor] | None = None,
retain_graph: bool = False,
parallel_chunk_size: int | None = None,
) -> None:
r"""
Computes the Jacobians of all values in ``tensors`` with respect to all ``inputs`` and
accumulates them in the ``.jac`` fields of the ``inputs``.

:param tensors: The tensor or tensors to differentiate. Should be non-empty. The Jacobians will
have one row for each value of each of these tensors.
Computes the Jacobians of ``tensors`` with respect to ``inputs``, left-multiplied by
``jac_tensors`` (or identity if ``jac_tensors`` is ``None``), and accumulates the results in the
``.jac`` fields of the ``inputs``.

:param tensors: The tensor or tensors to differentiate. Should be non-empty.
:param jac_tensors: The initial Jacobians to backpropagate, analogous to the ``grad_tensors``
parameter of ``torch.autograd.backward``. If provided, it must have the same structure as
``tensors``, and each tensor in ``jac_tensors`` must match the shape of the corresponding
tensor in ``tensors``, with an extra leading dimension representing the number of rows of
the resulting Jacobian (e.g. the number of losses). All tensors in ``jac_tensors`` must
have the same first dimension. If ``None``, defaults to the identity matrix. In this case,
the standard Jacobian of ``tensors`` is computed, with one row for each value in
``tensors``.
:param inputs: The tensors with respect to which the Jacobians must be computed. These must have
their ``requires_grad`` flag set to ``True``. If not provided, defaults to the leaf tensors
that were used to compute the ``tensors`` parameter.
@@ -32,7 +48,7 @@ def backward(
.. admonition::
Example

The following code snippet showcases a simple usage of ``backward``.
This example shows a simple usage of ``backward``.

>>> import torch
>>>
@@ -52,6 +68,33 @@ def backward(
The ``.jac`` field of ``param`` now contains the Jacobian of
:math:`\begin{bmatrix}y_1 \\ y_2\end{bmatrix}` with respect to ``param``.

.. admonition::
Example

This is the same example as before, except that we explicitly specify ``jac_tensors`` as
the rows of the identity matrix (which is equivalent to using the default ``None``).

>>> import torch
>>>
>>> from torchjd.autojac import backward
>>>
>>> param = torch.tensor([1., 2.], requires_grad=True)
>>> # Compute arbitrary quantities that are functions of param
>>> y1 = torch.tensor([-1., 1.]) @ param
>>> y2 = (param ** 2).sum()
>>>
>>> J1 = torch.tensor([1.0, 0.0])
>>> J2 = torch.tensor([0.0, 1.0])
>>>
>>> backward([y1, y2], jac_tensors=[J1, J2])
>>>
>>> param.jac
tensor([[-1., 1.],
[ 2., 4.]])

Instead of using the identity ``jac_tensors``, you can backpropagate some Jacobians obtained
by a call to :func:`torchjd.autojac.jac` on a later part of the computation graph.

.. warning::
To differentiate in parallel, ``backward`` relies on ``torch.vmap``, which has some
limitations: `it does not work on the output of compiled functions
@@ -73,34 +116,44 @@
else:
inputs_ = OrderedSet(inputs)

backward_transform = _create_transform(
tensors=tensors_,
inputs=inputs_,
retain_graph=retain_graph,
parallel_chunk_size=parallel_chunk_size,
)
jac_tensors_dict = _create_jac_tensors_dict(tensors_, jac_tensors)
transform = _create_transform(tensors_, inputs_, parallel_chunk_size, retain_graph)
transform(jac_tensors_dict)


def _create_jac_tensors_dict(
tensors: OrderedSet[Tensor],
opt_jac_tensors: Sequence[Tensor] | Tensor | None,
) -> dict[Tensor, Tensor]:
"""
Creates a dictionary mapping tensors to their corresponding Jacobians.

backward_transform({})
:param tensors: The tensors to differentiate.
:param opt_jac_tensors: The initial Jacobians to backpropagate. If ``None``, defaults to
identity.
"""
if opt_jac_tensors is None:
# Transform that creates gradient outputs containing only ones.
init = Init(tensors)
# Transform that turns the gradients into Jacobians.
diag = Diagonalize(tensors)
return (diag << init)({})
jac_tensors = [opt_jac_tensors] if isinstance(opt_jac_tensors, Tensor) else opt_jac_tensors
check_matching_length(jac_tensors, tensors, "jac_tensors", "tensors")
check_matching_shapes(jac_tensors, tensors, "jac_tensors", "tensors")
check_consistent_first_dimension(jac_tensors, "jac_tensors")
return dict(zip(tensors, jac_tensors, strict=True))


def _create_transform(
tensors: OrderedSet[Tensor],
inputs: OrderedSet[Tensor],
retain_graph: bool,
parallel_chunk_size: int | None,
retain_graph: bool,
) -> Transform:
"""Creates the backward transform."""

# Transform that creates gradient outputs containing only ones.
init = Init(tensors)

# Transform that turns the gradients into Jacobians.
diag = Diagonalize(tensors)

"""Creates the backward transform that computes and accumulates Jacobians."""
# Transform that computes the required Jacobians.
jac = Jac(tensors, inputs, parallel_chunk_size, retain_graph)

# Transform that accumulates the result in the .jac field of the inputs.
accumulate = AccumulateJac()

return accumulate << jac << diag << init
return accumulate << jac
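The docstring above mentions that, instead of the identity ``jac_tensors``, one can backpropagate Jacobians obtained from a call to ``torchjd.autojac.jac`` on a later part of the computation graph. A minimal sketch of that pattern (not part of the diff; it mirrors the two-step ``jac_outputs`` example added to ``jac`` below, so the expected values are the same):

>>> import torch
>>>
>>> from torchjd.autojac import backward, jac
>>>
>>> x = torch.tensor([1., 2.], requires_grad=True)
>>> h = x ** 2  # intermediate tensor
>>> y1 = h.sum()
>>> y2 = torch.tensor([1., -1.]) @ h
>>>
>>> # Step 1: Jacobian of [y1, y2] with respect to the intermediate tensor h.
>>> jac_h = jac([y1, y2], [h])[0]  # Shape: [2, 2]
>>>
>>> # Step 2: backpropagate it through the earlier part of the graph.
>>> backward(h, jac_tensors=jac_h)
>>>
>>> x.jac
tensor([[ 2., 4.],
[ 2., -4.]])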
99 changes: 74 additions & 25 deletions src/torchjd/autojac/_jac.py
@@ -9,6 +9,9 @@
from torchjd.autojac._transform._ordered_set import OrderedSet
from torchjd.autojac._utils import (
as_checked_ordered_set,
check_consistent_first_dimension,
check_matching_length,
check_matching_shapes,
check_optional_positive_chunk_size,
get_leaf_tensors,
)
@@ -17,19 +20,27 @@
def jac(
outputs: Sequence[Tensor] | Tensor,
inputs: Iterable[Tensor] | None = None,
jac_outputs: Sequence[Tensor] | Tensor | None = None,
retain_graph: bool = False,
parallel_chunk_size: int | None = None,
) -> tuple[Tensor, ...]:
r"""
Computes the Jacobian of all values in ``outputs`` with respect to all ``inputs``. Returns the
result as a tuple, with one Jacobian per input tensor. The returned Jacobian with respect to
input ``t`` has shape ``[m] + t.shape``.
Computes the Jacobians of ``outputs`` with respect to ``inputs``, left-multiplied by
``jac_outputs`` (or identity if ``jac_outputs`` is ``None``), and returns the result as a tuple,
with one Jacobian per input tensor. The returned Jacobian with respect to input ``t`` has shape
``[m] + t.shape``.

:param outputs: The tensor or tensors to differentiate. Should be non-empty. The Jacobians will
have one row for each value of each of these tensors.
:param outputs: The tensor or tensors to differentiate. Should be non-empty.
:param inputs: The tensors with respect to which the Jacobian must be computed. These must have
their ``requires_grad`` flag set to ``True``. If not provided, defaults to the leaf tensors
that were used to compute the ``outputs`` parameter.
:param jac_outputs: The initial Jacobians to backpropagate, analogous to the ``grad_outputs``
parameter of ``torch.autograd.grad``. If provided, it must have the same structure as
``outputs``, and each tensor in ``jac_outputs`` must match the shape of the corresponding
tensor in ``outputs``, with an extra leading dimension representing the number of rows of
the resulting Jacobian (e.g. the number of losses). All tensors in ``jac_outputs`` must
have the same first dimension. If ``None``, defaults to the identity matrix. In this case,
the standard Jacobian of ``outputs`` is computed, with one row for each value in
``outputs``.
:param retain_graph: If ``False``, the graph used to compute the grad will be freed. Defaults to
``False``.
:param parallel_chunk_size: The number of scalars to differentiate simultaneously in the
Expand Down Expand Up @@ -60,7 +71,7 @@ def jac(
>>> jacobians = jac([y1, y2], [param])
>>>
>>> jacobians
(tensor([-1., 1.],
(tensor([[-1., 1.],
[ 2., 4.]]),)

.. admonition::
Expand Down Expand Up @@ -99,6 +110,34 @@ def jac(
gradients are exactly orthogonal (they have an inner product of 0), but they conflict with
the third gradient (inner product of -1 and -3).

.. admonition::
Example

This example shows how to apply the chain rule, using the ``jac_outputs`` parameter to compute
the Jacobian in two steps.

>>> import torch
>>>
>>> from torchjd.autojac import jac
>>>
>>> x = torch.tensor([1., 2.], requires_grad=True)
>>> # Compose functions: x -> h -> y
>>> h = x ** 2
>>> y1 = h.sum()
>>> y2 = torch.tensor([1., -1.]) @ h
>>>
>>> # Step 1: Compute d[y1,y2]/dh
>>> jac_h = jac([y1, y2], [h])[0] # Shape: [2, 2]
>>>
>>> # Step 2: Use chain rule to compute d[y1,y2]/dx = (d[y1,y2]/dh) @ (dh/dx)
>>> jac_x = jac(h, [x], jac_outputs=jac_h)[0]
>>>
>>> jac_x
tensor([[ 2., 4.],
[ 2., -4.]])

This two-step computation is equivalent to directly computing ``jac([y1, y2], [x])``.

.. warning::
To differentiate in parallel, ``jac`` relies on ``torch.vmap``, which has some
limitations: `it does not work on the output of compiled functions
@@ -122,30 +161,40 @@
inputs_with_repetition = list(inputs) # Create a list to avoid emptying generator
inputs_ = OrderedSet(inputs_with_repetition)

jac_transform = _create_transform(
outputs=outputs_,
inputs=inputs_,
retain_graph=retain_graph,
parallel_chunk_size=parallel_chunk_size,
)

result = jac_transform({})
jac_outputs_dict = _create_jac_outputs_dict(outputs_, jac_outputs)
transform = _create_transform(outputs_, inputs_, parallel_chunk_size, retain_graph)
result = transform(jac_outputs_dict)
return tuple(result[input] for input in inputs_with_repetition)


def _create_jac_outputs_dict(
outputs: OrderedSet[Tensor],
opt_jac_outputs: Sequence[Tensor] | Tensor | None,
) -> dict[Tensor, Tensor]:
"""
Creates a dictionary mapping outputs to their corresponding Jacobians.

:param outputs: The tensors to differentiate.
:param opt_jac_outputs: The initial Jacobians to backpropagate. If ``None``, defaults to
identity.
"""
if opt_jac_outputs is None:
# Transform that creates gradient outputs containing only ones.
init = Init(outputs)
# Transform that turns the gradients into Jacobians.
diag = Diagonalize(outputs)
return (diag << init)({})
jac_outputs = [opt_jac_outputs] if isinstance(opt_jac_outputs, Tensor) else opt_jac_outputs
check_matching_length(jac_outputs, outputs, "jac_outputs", "outputs")
check_matching_shapes(jac_outputs, outputs, "jac_outputs", "outputs")
check_consistent_first_dimension(jac_outputs, "jac_outputs")
return dict(zip(outputs, jac_outputs, strict=True))


def _create_transform(
outputs: OrderedSet[Tensor],
inputs: OrderedSet[Tensor],
retain_graph: bool,
parallel_chunk_size: int | None,
retain_graph: bool,
) -> Transform:
# Transform that creates gradient outputs containing only ones.
init = Init(outputs)

# Transform that turns the gradients into Jacobians.
diag = Diagonalize(outputs)

# Transform that computes the required Jacobians.
jac = Jac(outputs, inputs, parallel_chunk_size, retain_graph)

return jac << diag << init
return Jac(outputs, inputs, parallel_chunk_size, retain_graph)
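A small sketch (not part of the diff) of the shape convention documented for ``jac_outputs`` when the outputs have different shapes: each entry prepends one dimension of a common size ``k`` to the shape of its corresponding output. Here ``k = 2``, so the Jacobians for the scalar ``y1`` and the length-3 ``y2`` have shapes ``[2]`` and ``[2, 3]``:

>>> import torch
>>>
>>> from torchjd.autojac import jac
>>>
>>> x = torch.tensor([1., 2., 3.], requires_grad=True)
>>> y1 = x.sum()  # shape []
>>> y2 = 2 * x  # shape [3]
>>>
>>> jac_y1 = torch.tensor([1., 0.])  # shape [2]
>>> jac_y2 = torch.stack([torch.zeros(3), torch.ones(3)])  # shape [2, 3]
>>>
>>> jac([y1, y2], [x], jac_outputs=[jac_y1, jac_y2])
(tensor([[1., 1., 1.],
[2., 2., 2.]]),)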
4 changes: 2 additions & 2 deletions src/torchjd/autojac/_jac_to_grad.py
@@ -6,6 +6,7 @@
from torchjd.aggregation import Aggregator

from ._accumulation import TensorWithJac, accumulate_grads, is_tensor_with_jac
from ._utils import check_consistent_first_dimension


def jac_to_grad(
Expand Down Expand Up @@ -67,8 +68,7 @@ def jac_to_grad(

jacobians = [t.jac for t in tensors_]

if not all(jacobian.shape[0] == jacobians[0].shape[0] for jacobian in jacobians[1:]):
raise ValueError("All Jacobians should have the same number of rows.")
check_consistent_first_dimension(jacobians, "tensors.jac")

if not retain_jac:
_free_jacs(tensors_)
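For reference, a hypothetical sketch of what the extracted ``check_consistent_first_dimension`` helper could look like (its actual implementation lives in ``torchjd.autojac._utils`` and is not shown in this diff); it is assumed to behave like the inline check it replaces, which raised a ``ValueError`` when the Jacobians did not all have the same number of rows:

from collections.abc import Sequence

from torch import Tensor


def check_consistent_first_dimension(tensors: Sequence[Tensor], name: str) -> None:
    # Hypothetical implementation: reject tensors whose first dimensions differ.
    first_dims = {tensor.shape[0] for tensor in tensors}
    if len(first_dims) > 1:
        raise ValueError(
            f"All tensors in `{name}` should have the same first dimension, got first "
            f"dimensions {sorted(first_dims)}."
        )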