
Commit 7103b6e

PierreQuinton authored
feat(autojac): Add tangents to backward and jac (#562)
* Add `jac_tensors` parameter to `backward`
* Add `jac_outputs` parameter to `jac`
* Add some extra usage examples using these parameters
* Add extra tests and doctests
* Add changelog entry

Co-authored-by: Valérian Rey <valerian.rey@gmail.com>
Co-authored-by: Claude Sonnet 4.5 <noreply@anthropic.com>
1 parent 4ad95e7 commit 7103b6e

9 files changed

Lines changed: 506 additions & 63 deletions


CHANGELOG.md

Lines changed: 5 additions & 0 deletions
@@ -13,6 +13,11 @@ changelog does not include internal changes that do not affect the user.
 - Added the function `torchjd.autojac.jac`. It's the same as `torchjd.autojac.backward` except that
   it returns the Jacobians as a tuple instead of storing them in the `.jac` fields of the inputs.
   Its interface is analogous to that of `torch.autograd.grad`.
+- Added a `jac_tensors` parameter to `backward`, allowing the Jacobian computation to be
+  pre-multiplied by initial Jacobians. This enables multi-step chain-rule computations and is
+  analogous to the `grad_tensors` parameter of `torch.autograd.backward`.
+- Added a `jac_outputs` parameter to `jac`, allowing the Jacobian computation to be pre-multiplied
+  by initial Jacobians. This is analogous to the `grad_outputs` parameter of `torch.autograd.grad`.
 - Added a `scale_mode` parameter to `AlignedMTL` and `AlignedMTLWeighting`, allowing to choose
   between `"min"`, `"median"`, and `"rmse"` scaling.
 - Added an attribute `gramian_weighting` to all aggregators that use a gramian-based `Weighting`.
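
To make the analogy with `torch.autograd` concrete, here is a small sketch (not part of the commit; the tensor values are arbitrary and the expected results are hand-computed) contrasting a single vector-Jacobian product obtained with `grad_outputs` against the stacked rows that the new `jac_outputs` parameter describes.

```python
import torch

from torchjd.autojac import jac

param = torch.tensor([1., 2.], requires_grad=True)
y1 = (param ** 2).sum()               # dy1/dparam = [2., 4.]
y2 = torch.tensor([3., 1.]) @ param   # dy2/dparam = [3., 1.]

# With plain autograd, grad_outputs selects one linear combination of the outputs,
# so each row of the Jacobian needs its own call.
row1 = torch.autograd.grad(
    [y1, y2], [param], grad_outputs=[torch.tensor(1.), torch.tensor(0.)], retain_graph=True
)[0]
row2 = torch.autograd.grad(
    [y1, y2], [param], grad_outputs=[torch.tensor(0.), torch.tensor(1.)], retain_graph=True
)[0]

# With jac_outputs, each per-output tensor gains one leading dimension that stacks those
# rows, so a single call returns the full (pre-multiplied) Jacobian.
full = jac([y1, y2], [param], jac_outputs=[torch.tensor([1., 0.]), torch.tensor([0., 1.])])[0]

assert torch.equal(full, torch.stack([row1, row2]))  # expected: [[2., 4.], [3., 1.]]
```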

src/torchjd/autojac/_backward.py

Lines changed: 79 additions & 26 deletions
@@ -3,21 +3,37 @@
 from torch import Tensor
 
 from ._transform import AccumulateJac, Diagonalize, Init, Jac, OrderedSet, Transform
-from ._utils import as_checked_ordered_set, check_optional_positive_chunk_size, get_leaf_tensors
+from ._utils import (
+    as_checked_ordered_set,
+    check_consistent_first_dimension,
+    check_matching_length,
+    check_matching_shapes,
+    check_optional_positive_chunk_size,
+    get_leaf_tensors,
+)
 
 
 def backward(
     tensors: Sequence[Tensor] | Tensor,
+    jac_tensors: Sequence[Tensor] | Tensor | None = None,
     inputs: Iterable[Tensor] | None = None,
     retain_graph: bool = False,
     parallel_chunk_size: int | None = None,
 ) -> None:
     r"""
-    Computes the Jacobians of all values in ``tensors`` with respect to all ``inputs`` and
-    accumulates them in the ``.jac`` fields of the ``inputs``.
-
-    :param tensors: The tensor or tensors to differentiate. Should be non-empty. The Jacobians will
-        have one row for each value of each of these tensors.
+    Computes the Jacobians of ``tensors`` with respect to ``inputs``, left-multiplied by
+    ``jac_tensors`` (or identity if ``jac_tensors`` is ``None``), and accumulates the results in
+    the ``.jac`` fields of the ``inputs``.
+
+    :param tensors: The tensor or tensors to differentiate. Should be non-empty.
+    :param jac_tensors: The initial Jacobians to backpropagate, analogous to the ``grad_tensors``
+        parameter of ``torch.autograd.backward``. If provided, it must have the same structure as
+        ``tensors``, and each tensor in ``jac_tensors`` must match the shape of the corresponding
+        tensor in ``tensors``, with an extra leading dimension representing the number of rows of
+        the resulting Jacobian (e.g. the number of losses). All tensors in ``jac_tensors`` must
+        have the same first dimension. If ``None``, defaults to the identity matrix; in this case,
+        the standard Jacobian of ``tensors`` is computed, with one row for each value in
+        ``tensors``.
     :param inputs: The tensors with respect to which the Jacobians must be computed. These must have
         their ``requires_grad`` flag set to ``True``. If not provided, defaults to the leaf tensors
         that were used to compute the ``tensors`` parameter.
@@ -32,7 +48,7 @@ def backward(
     .. admonition::
         Example
 
-        The following code snippet showcases a simple usage of ``backward``.
+        This example shows a simple usage of ``backward``.
 
         >>> import torch
         >>>
@@ -52,6 +68,33 @@ def backward(
         The ``.jac`` field of ``param`` now contains the Jacobian of
         :math:`\begin{bmatrix}y_1 \\ y_2\end{bmatrix}` with respect to ``param``.
 
+    .. admonition::
+        Example
+
+        This is the same example as before, except that we explicitly specify ``jac_tensors`` as
+        the rows of the identity matrix (which is equivalent to using the default ``None``).
+
+        >>> import torch
+        >>>
+        >>> from torchjd.autojac import backward
+        >>>
+        >>> param = torch.tensor([1., 2.], requires_grad=True)
+        >>> # Compute arbitrary quantities that are functions of param
+        >>> y1 = torch.tensor([-1., 1.]) @ param
+        >>> y2 = (param ** 2).sum()
+        >>>
+        >>> J1 = torch.tensor([1.0, 0.0])
+        >>> J2 = torch.tensor([0.0, 1.0])
+        >>>
+        >>> backward([y1, y2], jac_tensors=[J1, J2])
+        >>>
+        >>> param.jac
+        tensor([[-1., 1.],
+                [ 2., 4.]])
+
+        Instead of using the identity ``jac_tensors``, you can backpropagate Jacobians obtained by
+        a call to :func:`torchjd.autojac.jac` on a later part of the computation graph.
+
     .. warning::
         To differentiate in parallel, ``backward`` relies on ``torch.vmap``, which has some
         limitations: `it does not work on the output of compiled functions
@@ -73,34 +116,44 @@
     else:
         inputs_ = OrderedSet(inputs)
 
-    backward_transform = _create_transform(
-        tensors=tensors_,
-        inputs=inputs_,
-        retain_graph=retain_graph,
-        parallel_chunk_size=parallel_chunk_size,
-    )
+    jac_tensors_dict = _create_jac_tensors_dict(tensors_, jac_tensors)
+    transform = _create_transform(tensors_, inputs_, parallel_chunk_size, retain_graph)
+    transform(jac_tensors_dict)
+
+
+def _create_jac_tensors_dict(
+    tensors: OrderedSet[Tensor],
+    opt_jac_tensors: Sequence[Tensor] | Tensor | None,
+) -> dict[Tensor, Tensor]:
+    """
+    Creates a dictionary mapping tensors to their corresponding Jacobians.
 
-    backward_transform({})
+    :param tensors: The tensors to differentiate.
+    :param opt_jac_tensors: The initial Jacobians to backpropagate. If ``None``, defaults to
+        identity.
+    """
+    if opt_jac_tensors is None:
+        # Transform that creates gradient outputs containing only ones.
+        init = Init(tensors)
+        # Transform that turns the gradients into Jacobians.
+        diag = Diagonalize(tensors)
+        return (diag << init)({})
+    jac_tensors = [opt_jac_tensors] if isinstance(opt_jac_tensors, Tensor) else opt_jac_tensors
+    check_matching_length(jac_tensors, tensors, "jac_tensors", "tensors")
+    check_matching_shapes(jac_tensors, tensors, "jac_tensors", "tensors")
+    check_consistent_first_dimension(jac_tensors, "jac_tensors")
+    return dict(zip(tensors, jac_tensors, strict=True))
 
 
 def _create_transform(
     tensors: OrderedSet[Tensor],
     inputs: OrderedSet[Tensor],
-    retain_graph: bool,
     parallel_chunk_size: int | None,
+    retain_graph: bool,
 ) -> Transform:
-    """Creates the backward transform."""
-
-    # Transform that creates gradient outputs containing only ones.
-    init = Init(tensors)
-
-    # Transform that turns the gradients into Jacobians.
-    diag = Diagonalize(tensors)
-
+    """Creates the backward transform that computes and accumulates Jacobians."""
     # Transform that computes the required Jacobians.
     jac = Jac(tensors, inputs, parallel_chunk_size, retain_graph)
-
     # Transform that accumulates the result in the .jac field of the inputs.
     accumulate = AccumulateJac()
-
-    return accumulate << jac << diag << init
+    return accumulate << jac
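
The new docstring mentions feeding `backward` with Jacobians produced by `jac` on a later part of the graph. The following sketch is not part of the diff (the variable names and values are made up); it mirrors the two-step chain-rule doctest added to `jac` below, but accumulates the result in `param.jac` instead of returning it.

```python
import torch

from torchjd.autojac import backward, jac

param = torch.tensor([1., 2.], requires_grad=True)
h = param * torch.tensor([3., 1.])   # earlier part of the graph: param -> h
y1 = h.sum()                         # later part of the graph: h -> y1, y2
y2 = (h ** 2).sum()

# Step 1: Jacobian of [y1, y2] with respect to the intermediate tensor h (shape [2, 2]).
jac_h = jac([y1, y2], [h])[0]

# Step 2: backpropagate that Jacobian from h down to param via jac_tensors.
backward(h, jac_tensors=jac_h)

# param.jac should now hold the Jacobian of [y1, y2] with respect to param, i.e. the same
# result as a one-step backward([y1, y2]); hand-computed: [[3., 1.], [18., 4.]].
print(param.jac)
```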

src/torchjd/autojac/_jac.py

Lines changed: 74 additions & 25 deletions
@@ -9,6 +9,9 @@
 from torchjd.autojac._transform._ordered_set import OrderedSet
 from torchjd.autojac._utils import (
     as_checked_ordered_set,
+    check_consistent_first_dimension,
+    check_matching_length,
+    check_matching_shapes,
     check_optional_positive_chunk_size,
     get_leaf_tensors,
 )
@@ -17,19 +20,27 @@
 def jac(
     outputs: Sequence[Tensor] | Tensor,
     inputs: Iterable[Tensor] | None = None,
+    jac_outputs: Sequence[Tensor] | Tensor | None = None,
     retain_graph: bool = False,
     parallel_chunk_size: int | None = None,
 ) -> tuple[Tensor, ...]:
     r"""
-    Computes the Jacobian of all values in ``outputs`` with respect to all ``inputs``. Returns the
-    result as a tuple, with one Jacobian per input tensor. The returned Jacobian with respect to
-    input ``t`` has shape ``[m] + t.shape``.
+    Computes the Jacobians of ``outputs`` with respect to ``inputs``, left-multiplied by
+    ``jac_outputs`` (or identity if ``jac_outputs`` is ``None``), and returns the result as a
+    tuple, with one Jacobian per input tensor. The returned Jacobian with respect to input ``t``
+    has shape ``[m] + t.shape``.
 
-    :param outputs: The tensor or tensors to differentiate. Should be non-empty. The Jacobians will
-        have one row for each value of each of these tensors.
+    :param outputs: The tensor or tensors to differentiate. Should be non-empty.
     :param inputs: The tensors with respect to which the Jacobian must be computed. These must have
         their ``requires_grad`` flag set to ``True``. If not provided, defaults to the leaf tensors
         that were used to compute the ``outputs`` parameter.
+    :param jac_outputs: The initial Jacobians to backpropagate, analogous to the ``grad_outputs``
+        parameter of ``torch.autograd.grad``. If provided, it must have the same structure as
+        ``outputs``, and each tensor in ``jac_outputs`` must match the shape of the corresponding
+        tensor in ``outputs``, with an extra leading dimension representing the number of rows of
+        the resulting Jacobian (e.g. the number of losses). If ``None``, defaults to the identity
+        matrix; in this case, the standard Jacobian of ``outputs`` is computed, with one row for
+        each value in ``outputs``.
     :param retain_graph: If ``False``, the graph used to compute the grad will be freed. Defaults to
         ``False``.
     :param parallel_chunk_size: The number of scalars to differentiate simultaneously in the
@@ -60,7 +71,7 @@ def jac(
         >>> jacobians = jac([y1, y2], [param])
         >>>
         >>> jacobians
-        (tensor([-1., 1.],
+        (tensor([[-1., 1.],
                 [ 2., 4.]]),)
 
     .. admonition::
@@ -99,6 +110,34 @@ def jac(
        gradients are exactly orthogonal (they have an inner product of 0), but they conflict with
        the third gradient (inner product of -1 and -3).
 
+    .. admonition::
+        Example
+
+        This example shows how to apply the chain rule using the ``jac_outputs`` parameter to
+        compute the Jacobian in two steps.
+
+        >>> import torch
+        >>>
+        >>> from torchjd.autojac import jac
+        >>>
+        >>> x = torch.tensor([1., 2.], requires_grad=True)
+        >>> # Compose functions: x -> h -> y
+        >>> h = x ** 2
+        >>> y1 = h.sum()
+        >>> y2 = torch.tensor([1., -1.]) @ h
+        >>>
+        >>> # Step 1: Compute d[y1,y2]/dh
+        >>> jac_h = jac([y1, y2], [h])[0]  # Shape: [2, 2]
+        >>>
+        >>> # Step 2: Use the chain rule to compute d[y1,y2]/dx = (d[y1,y2]/dh) @ (dh/dx)
+        >>> jac_x = jac(h, [x], jac_outputs=jac_h)[0]
+        >>>
+        >>> jac_x
+        tensor([[ 2., 4.],
+                [ 2., -4.]])
+
+        This two-step computation is equivalent to directly computing ``jac([y1, y2], [x])``.
+
     .. warning::
         To differentiate in parallel, ``jac`` relies on ``torch.vmap``, which has some
         limitations: `it does not work on the output of compiled functions
@@ -122,30 +161,40 @@ def jac(
     inputs_with_repetition = list(inputs)  # Create a list to avoid emptying generator
     inputs_ = OrderedSet(inputs_with_repetition)
 
-    jac_transform = _create_transform(
-        outputs=outputs_,
-        inputs=inputs_,
-        retain_graph=retain_graph,
-        parallel_chunk_size=parallel_chunk_size,
-    )
-
-    result = jac_transform({})
+    jac_outputs_dict = _create_jac_outputs_dict(outputs_, jac_outputs)
+    transform = _create_transform(outputs_, inputs_, parallel_chunk_size, retain_graph)
+    result = transform(jac_outputs_dict)
     return tuple(result[input] for input in inputs_with_repetition)
 
 
+def _create_jac_outputs_dict(
+    outputs: OrderedSet[Tensor],
+    opt_jac_outputs: Sequence[Tensor] | Tensor | None,
+) -> dict[Tensor, Tensor]:
+    """
+    Creates a dictionary mapping outputs to their corresponding Jacobians.
+
+    :param outputs: The tensors to differentiate.
+    :param opt_jac_outputs: The initial Jacobians to backpropagate. If ``None``, defaults to
+        identity.
+    """
+    if opt_jac_outputs is None:
+        # Transform that creates gradient outputs containing only ones.
+        init = Init(outputs)
+        # Transform that turns the gradients into Jacobians.
+        diag = Diagonalize(outputs)
+        return (diag << init)({})
+    jac_outputs = [opt_jac_outputs] if isinstance(opt_jac_outputs, Tensor) else opt_jac_outputs
+    check_matching_length(jac_outputs, outputs, "jac_outputs", "outputs")
+    check_matching_shapes(jac_outputs, outputs, "jac_outputs", "outputs")
+    check_consistent_first_dimension(jac_outputs, "jac_outputs")
+    return dict(zip(outputs, jac_outputs, strict=True))
+
+
 def _create_transform(
     outputs: OrderedSet[Tensor],
     inputs: OrderedSet[Tensor],
-    retain_graph: bool,
     parallel_chunk_size: int | None,
+    retain_graph: bool,
 ) -> Transform:
-    # Transform that creates gradient outputs containing only ones.
-    init = Init(outputs)
-
-    # Transform that turns the gradients into Jacobians.
-    diag = Diagonalize(outputs)
-
-    # Transform that computes the required Jacobians.
-    jac = Jac(outputs, inputs, parallel_chunk_size, retain_graph)
-
-    return jac << diag << init
+    return Jac(outputs, inputs, parallel_chunk_size, retain_graph)
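
As a complement to the doctests above, here is a small sketch (not from the commit; the values and expected result are hand-computed) of the `jac_outputs` shape contract for a single non-scalar output: each initial Jacobian must have shape `[m, *output.shape]`, and the result is that matrix applied to the output's Jacobian.

```python
import torch

from torchjd.autojac import jac

x = torch.tensor([1., 2., 3.], requires_grad=True)
out = x ** 2                        # single output of shape [3]

# Two rows (m=2) mixing the three values of `out`: shape [2, 3] = [m] + out.shape.
J0 = torch.tensor([[1., 1., 1.],
                   [1., 0., -1.]])

result = jac(out, [x], jac_outputs=J0)[0]

# d(out)/dx = diag([2., 4., 6.]), so the expected result is
# J0 @ diag([2., 4., 6.]) = [[2., 4., 6.], [2., 0., -6.]].
print(result)
```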

src/torchjd/autojac/_jac_to_grad.py

Lines changed: 2 additions & 2 deletions
@@ -6,6 +6,7 @@
 from torchjd.aggregation import Aggregator
 
 from ._accumulation import TensorWithJac, accumulate_grads, is_tensor_with_jac
+from ._utils import check_consistent_first_dimension
 
 
 def jac_to_grad(
@@ -67,8 +68,7 @@ def jac_to_grad(
 
     jacobians = [t.jac for t in tensors_]
 
-    if not all(jacobian.shape[0] == jacobians[0].shape[0] for jacobian in jacobians[1:]):
-        raise ValueError("All Jacobians should have the same number of rows.")
+    check_consistent_first_dimension(jacobians, "tensors.jac")
 
     if not retain_jac:
         _free_jacs(tensors_)
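
The implementation of `check_consistent_first_dimension` lives in `_utils` and is not part of this diff. The inline check it replaces here, together with the call sites added in `backward` and `jac`, suggests behaviour along these lines (a sketch only; the exact error message is a guess):

```python
from collections.abc import Sequence

from torch import Tensor


def check_consistent_first_dimension(tensors: Sequence[Tensor], name: str) -> None:
    # Sketch based on the inline check removed from jac_to_grad: every tensor in the
    # sequence must have the same number of rows (first dimension).
    if not all(t.shape[0] == tensors[0].shape[0] for t in tensors[1:]):
        raise ValueError(f"All tensors in `{name}` should have the same number of rows.")
```

The same requirement applies to the new `jac_tensors` and `jac_outputs` arguments: all initial Jacobians passed in one call must agree on their number of rows.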
