From a975e7eda1e3892b120bacab041234c8b72c873d Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Fri, 13 Feb 2026 09:44:37 +0100 Subject: [PATCH 01/23] feat(autojac): Add `jac_tensors` to `backward` --- src/torchjd/autojac/_backward.py | 61 ++++++++++++++--------------- tests/unit/autojac/test_backward.py | 22 +---------- 2 files changed, 30 insertions(+), 53 deletions(-) diff --git a/src/torchjd/autojac/_backward.py b/src/torchjd/autojac/_backward.py index e0188976..5854f742 100644 --- a/src/torchjd/autojac/_backward.py +++ b/src/torchjd/autojac/_backward.py @@ -2,22 +2,32 @@ from torch import Tensor -from ._transform import AccumulateJac, Diagonalize, Init, Jac, OrderedSet, Transform +from ._transform import AccumulateJac, Diagonalize, Init, Jac, OrderedSet from ._utils import as_checked_ordered_set, check_optional_positive_chunk_size, get_leaf_tensors def backward( tensors: Sequence[Tensor] | Tensor, + jac_tensors: Sequence[Tensor] | Tensor | None = None, inputs: Iterable[Tensor] | None = None, retain_graph: bool = False, parallel_chunk_size: int | None = None, ) -> None: r""" - Computes the Jacobians of all values in ``tensors`` with respect to all ``inputs`` and - accumulates them in the ``.jac`` fields of the ``inputs``. - - :param tensors: The tensor or tensors to differentiate. Should be non-empty. The Jacobians will - have one row for each value of each of these tensors. + Computes the Jacobians of ``tensors`` with respect to ``inputs``, potentially pre-multiplied by + ``jac_tensors``, and accumulates the results in the ``.jac`` fields of the ``inputs``. + + Mathematically, if ``jac_tensors`` is provided, this function computes the matrix product + :math:`J_{init} \cdot J`, where :math:`J` is the Jacobian of ``tensors`` w.r.t ``inputs``, and + :math:`J_{init}` is the concatenation of ``jac_tensors``. If ``jac_tensors`` is ``None``, it + assumes an Identity matrix, resulting in the full Jacobian. + + :param tensors: The tensor or tensors to differentiate. Should be non-empty. + :param jac_tensors: The initial Jacobian to backpropagate. If provided, it must have the same + length and structure as ``tensors``. Each tensor in ``jac_tensors`` must match the shape of + the corresponding tensor in ``tensors``, with an extra leading dimension representing the + number of rows of the resulting Jacobian. If ``None``, defaults to the Identity matrix, + resulting in the standard Jacobian of ``tensors``. :param inputs: The tensors with respect to which the Jacobians must be computed. These must have their ``requires_grad`` flag set to ``True``. If not provided, defaults to the leaf tensors that were used to compute the ``tensors`` parameter. @@ -52,6 +62,8 @@ def backward( The ``.jac`` field of ``param`` now contains the Jacobian of :math:`\begin{bmatrix}y_1 \\ y_2\end{bmatrix}` with respect to ``param``. + # TODO: Need an example with `jac_tensors` not None. + .. 
warning:: To differentiate in parallel, ``backward`` relies on ``torch.vmap``, which has some limitations: `it does not work on the output of compiled functions @@ -73,34 +85,19 @@ def backward( else: inputs_ = OrderedSet(inputs) - backward_transform = _create_transform( - tensors=tensors_, - inputs=inputs_, - retain_graph=retain_graph, - parallel_chunk_size=parallel_chunk_size, - ) - - backward_transform({}) - - -def _create_transform( - tensors: OrderedSet[Tensor], - inputs: OrderedSet[Tensor], - retain_graph: bool, - parallel_chunk_size: int | None, -) -> Transform: - """Creates the backward transform.""" - - # Transform that creates gradient outputs containing only ones. - init = Init(tensors) - - # Transform that turns the gradients into Jacobians. - diag = Diagonalize(tensors) + if jac_tensors is None: + # Transform that creates gradient outputs containing only ones. + init = Init(tensors_) + # Transform that turns the gradients into Jacobians. + diag = Diagonalize(tensors_) + jac_tensors_dict = (diag << init)({}) + else: + jac_tensors_ = as_checked_ordered_set(jac_tensors, "jac_tensors") + jac_tensors_dict = dict(zip(tensors_, jac_tensors_, strict=True)) # Transform that computes the required Jacobians. - jac = Jac(tensors, inputs, parallel_chunk_size, retain_graph) - + jac = Jac(tensors_, inputs_, parallel_chunk_size, retain_graph) # Transform that accumulates the result in the .jac field of the inputs. accumulate = AccumulateJac() - return accumulate << jac << diag << init + (accumulate << jac)(jac_tensors_dict) diff --git a/tests/unit/autojac/test_backward.py b/tests/unit/autojac/test_backward.py index 80bd06e3..df1cb065 100644 --- a/tests/unit/autojac/test_backward.py +++ b/tests/unit/autojac/test_backward.py @@ -4,28 +4,8 @@ from utils.tensors import randn_, tensor_ from torchjd.autojac import backward -from torchjd.autojac._backward import _create_transform -from torchjd.autojac._transform import OrderedSet - -def test_check_create_transform(): - """Tests that _create_transform creates a valid Transform.""" - - a1 = tensor_([1.0, 2.0], requires_grad=True) - a2 = tensor_([3.0, 4.0], requires_grad=True) - - y1 = tensor_([-1.0, 1.0]) @ a1 + a2.sum() - y2 = (a1**2).sum() + a2.norm() - - transform = _create_transform( - tensors=OrderedSet([y1, y2]), - inputs=OrderedSet([a1, a2]), - retain_graph=False, - parallel_chunk_size=None, - ) - - output_keys = transform.check_keys(set()) - assert output_keys == set() +# TODO: Add tests of the `jac_tensors` parameter. def test_jac_is_populated(): From e68caf9803d3c538565c9800830df7440cce6b89 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Fri, 13 Feb 2026 10:10:47 +0100 Subject: [PATCH 02/23] Add usage example using `jac_tensors` to `backward`. --- src/torchjd/autojac/_backward.py | 21 ++++++++++++++++++++- tests/doc/test_backward.py | 14 ++++++++++++++ 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/src/torchjd/autojac/_backward.py b/src/torchjd/autojac/_backward.py index 5854f742..eff13e91 100644 --- a/src/torchjd/autojac/_backward.py +++ b/src/torchjd/autojac/_backward.py @@ -62,7 +62,24 @@ def backward( The ``.jac`` field of ``param`` now contains the Jacobian of :math:`\begin{bmatrix}y_1 \\ y_2\end{bmatrix}` with respect to ``param``. - # TODO: Need an example with `jac_tensors` not None. + .. admonition:: + Example + + If `jac_tensors` is made of matrices whose first dimension is 1, then this function is + essentially equivalent to `autograd.grad`. 
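+        (each row of `jac_tensors` then plays the role of the `grad_outputs` vector that
+        `autograd.grad` would backpropagate)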
+ + >>> import torch + >>> + >>> from torchjd.autojac import backward + >> + >>> param = torch.tensor([1., 2.], requires_grad=True) + >>> y = torch.stack([param[0] ** 2, param[1] ** 3]) + >>> + >>> weights = torch.tensor([[0.5, 1.0]]) + >>> backward([y], jac_tensors=[weights]) + >>> + >>> param.jac + tensor([[ 1., 12.]]) .. warning:: To differentiate in parallel, ``backward`` relies on ``torch.vmap``, which has some @@ -92,6 +109,8 @@ def backward( diag = Diagonalize(tensors_) jac_tensors_dict = (diag << init)({}) else: + # TODO: Check that the first dimension of each jac_tensors is the same, and that the rest + # correspond to the shape of the corresponding element in tensors_ jac_tensors_ = as_checked_ordered_set(jac_tensors, "jac_tensors") jac_tensors_dict = dict(zip(tensors_, jac_tensors_, strict=True)) diff --git a/tests/doc/test_backward.py b/tests/doc/test_backward.py index 032f902d..5caa30de 100644 --- a/tests/doc/test_backward.py +++ b/tests/doc/test_backward.py @@ -19,3 +19,17 @@ def test_backward(): backward([y1, y2]) assert_jac_close(param, torch.tensor([[-1.0, 1.0], [2.0, 4.0]]), rtol=0.0, atol=1e-04) + + +def test_backward2(): + import torch + + from torchjd.autojac import backward + + param = torch.tensor([1.0, 2.0], requires_grad=True) + y = torch.stack([param[0] ** 2, param[1] ** 3]) + + weights = torch.tensor([[0.5, 1.0]]) + backward([y], jac_tensors=[weights]) + + assert_jac_close(param, torch.tensor([[1.0, 12.0]]), rtol=0.0, atol=1e-04) From e1ee1258fe73a67e0344a14b99e8e4ede2f85197 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Fri, 13 Feb 2026 11:40:42 +0100 Subject: [PATCH 03/23] Add tests for `jac_tensors` in test_backward --- tests/unit/autojac/test_backward.py | 101 ++++++++++++++++++++++++++++ 1 file changed, 101 insertions(+) diff --git a/tests/unit/autojac/test_backward.py b/tests/unit/autojac/test_backward.py index df1cb065..cf2b5cff 100644 --- a/tests/unit/autojac/test_backward.py +++ b/tests/unit/autojac/test_backward.py @@ -51,6 +51,107 @@ def test_value_is_correct( assert_jac_close(input, J) +@mark.parametrize("rows", [1, 2, 5]) +def test_jac_tensors_value_is_correct(rows: int): + """ + Tests that backward correctly computes the product of jac_tensors and the Jacobian. + result = jac_tensors @ Jacobian(tensors, inputs). + """ + input_size = 4 + output_size = 3 + + J_model = randn_((output_size, input_size)) + + input = randn_([input_size], requires_grad=True) + tensor = J_model @ input + + J_init = randn_((rows, output_size)) + + backward( + tensor, + jac_tensors=J_init, + inputs=[input], + ) + + expected_jac = J_init @ J_model + assert_jac_close(input, expected_jac) + + +@mark.parametrize("rows", [1, 3]) +def test_jac_tensors_multiple_components(rows: int): + """ + Tests that jac_tensors works correctly when tensors is a list of multiple tensors. The + jac_tensors must match the structure of tensors. 
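+    Since both components are functions of the same input, the expected result is
+    J1 @ jac(y1) + J2 @ jac(y2).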
+ """ + input_len = 2 + input = randn_([input_len], requires_grad=True) + + y1 = input * 2 + y2 = torch.cat([input, input[:1]]) + + J1 = randn_((rows, 2)) + J2 = randn_((rows, 3)) + + backward([y1, y2], jac_tensors=[J1, J2], inputs=[input]) + + jac_y1 = torch.eye(2) * 2 + + jac_y2 = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]]) + + expected = J1 @ jac_y1 + J2 @ jac_y2 + + assert_jac_close(input, expected) + + +def test_jac_tensors_length_mismatch(): + """Tests that backward raises an error if len(jac_tensors) != len(tensors).""" + x = tensor_([1.0, 2.0], requires_grad=True) + y1 = x * 2 + y2 = x * 3 + + J1 = randn_((2, 2)) + + with raises(ValueError): + backward([y1, y2], jac_tensors=[J1], inputs=[x]) + + +def test_jac_tensors_shape_mismatch(): + """ + Tests that backward raises an error if the shape of a tensor in jac_tensors is incompatible with + the corresponding tensor. + """ + x = tensor_([1.0, 2.0], requires_grad=True) + y = x * 2 + + J_bad = randn_((3, 5)) + + with raises((ValueError, RuntimeError)): + backward(y, jac_tensors=J_bad, inputs=[x]) + + +@mark.parametrize( + "rows_y1, rows_y2", + [ + (3, 5), + (1, 2), + ], +) +def test_jac_tensors_inconsistent_first_dimension(rows_y1: int, rows_y2: int): + """ + Tests that backward fails if the provided jac_tensors is inconsistent across the sequence. + """ + x = tensor_([1.0, 2.0], requires_grad=True) + + y1 = x * 2 + y2 = x.sum() + + j1 = randn_((rows_y1, 2)) + j2 = randn_((rows_y2,)) + + with raises((ValueError, RuntimeError)): + backward([y1, y2], jac_tensors=[j1, j2], inputs=[x]) + + def test_empty_inputs(): """Tests that backward does not fill the .jac values if no input is specified.""" From 873d56fa9a29c7df55dbe4972f76933e6c52b2d1 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Fri, 13 Feb 2026 11:46:28 +0100 Subject: [PATCH 04/23] Add another usage example to `backward` --- src/torchjd/autojac/_backward.py | 28 ++++++++++++++++++++++++++-- tests/doc/test_backward.py | 18 ++++++++++++++++++ 2 files changed, 44 insertions(+), 2 deletions(-) diff --git a/src/torchjd/autojac/_backward.py b/src/torchjd/autojac/_backward.py index eff13e91..09ae5261 100644 --- a/src/torchjd/autojac/_backward.py +++ b/src/torchjd/autojac/_backward.py @@ -65,8 +65,32 @@ def backward( .. admonition:: Example - If `jac_tensors` is made of matrices whose first dimension is 1, then this function is - essentially equivalent to `autograd.grad`. + This is the same example as before, except that we specify the ``jac_tensors`` that correspond + to the default `None` + + >>> import torch + >>> + >>> from torchjd.autojac import backward + >>> + >>> param = torch.tensor([1., 2.], requires_grad=True) + >>> # Compute arbitrary quantities that are function of param + >>> y1 = torch.tensor([-1., 1.]) @ param + >>> y2 = (param ** 2).sum() + >>> + >>> J1 = torch.tensor([1.0, 0.0]) + >>> J2 = torch.tensor([0.0, 1.0]) + >>> + >>> backward([y1, y2]) + >>> + >>> param.jac + tensor([[-1., 1.], + [ 2., 4.]]) + + .. admonition:: + Example + + If ``jac_tensors`` is made of matrices whose first dimension is 1, then this function is + essentially equivalent to ``autograd.grad``. 
>>> import torch >>> diff --git a/tests/doc/test_backward.py b/tests/doc/test_backward.py index 5caa30de..a9fa2142 100644 --- a/tests/doc/test_backward.py +++ b/tests/doc/test_backward.py @@ -26,6 +26,24 @@ def test_backward2(): from torchjd.autojac import backward + param = torch.tensor([1.0, 2.0], requires_grad=True) + # Compute arbitrary quantities that are function of param + y1 = torch.tensor([-1.0, 1.0]) @ param + y2 = (param**2).sum() + + J1 = torch.tensor([1.0, 0.0]) + J2 = torch.tensor([0.0, 1.0]) + + backward([y1, y2], jac_tensors=[J1, J2]) + + assert_jac_close(param, torch.tensor([[-1.0, 1.0], [2.0, 4.0]]), rtol=0.0, atol=1e-04) + + +def test_backward3(): + import torch + + from torchjd.autojac import backward + param = torch.tensor([1.0, 2.0], requires_grad=True) y = torch.stack([param[0] ** 2, param[1] ** 3]) From ff0cb23917f307c65c029aabf55aa304c2239b17 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Fri, 13 Feb 2026 11:47:32 +0100 Subject: [PATCH 05/23] Remove TODOs --- src/torchjd/autojac/_backward.py | 1 - tests/unit/autojac/test_backward.py | 2 -- 2 files changed, 3 deletions(-) diff --git a/src/torchjd/autojac/_backward.py b/src/torchjd/autojac/_backward.py index 09ae5261..5c3572b0 100644 --- a/src/torchjd/autojac/_backward.py +++ b/src/torchjd/autojac/_backward.py @@ -133,7 +133,6 @@ def backward( diag = Diagonalize(tensors_) jac_tensors_dict = (diag << init)({}) else: - # TODO: Check that the first dimension of each jac_tensors is the same, and that the rest # correspond to the shape of the corresponding element in tensors_ jac_tensors_ = as_checked_ordered_set(jac_tensors, "jac_tensors") jac_tensors_dict = dict(zip(tensors_, jac_tensors_, strict=True)) diff --git a/tests/unit/autojac/test_backward.py b/tests/unit/autojac/test_backward.py index cf2b5cff..3a66c833 100644 --- a/tests/unit/autojac/test_backward.py +++ b/tests/unit/autojac/test_backward.py @@ -5,8 +5,6 @@ from torchjd.autojac import backward -# TODO: Add tests of the `jac_tensors` parameter. 
- def test_jac_is_populated(): """Tests that backward correctly fills the .jac field.""" From a65a193f2b44e077b4a68ad6b9327a70ddd64594 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Fri, 13 Feb 2026 12:06:18 +0100 Subject: [PATCH 06/23] Fix usage of methods in `tests.utils.tensors.` --- tests/unit/autojac/test_backward.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/unit/autojac/test_backward.py b/tests/unit/autojac/test_backward.py index 3a66c833..177d5514 100644 --- a/tests/unit/autojac/test_backward.py +++ b/tests/unit/autojac/test_backward.py @@ -1,7 +1,7 @@ import torch from pytest import mark, raises from utils.asserts import assert_has_jac, assert_has_no_jac, assert_jac_close -from utils.tensors import randn_, tensor_ +from utils.tensors import eye_, randn_, tensor_ from torchjd.autojac import backward @@ -92,9 +92,9 @@ def test_jac_tensors_multiple_components(rows: int): backward([y1, y2], jac_tensors=[J1, J2], inputs=[input]) - jac_y1 = torch.eye(2) * 2 + jac_y1 = eye_(2) * 2 - jac_y2 = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]]) + jac_y2 = tensor_([[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]]) expected = J1 @ jac_y1 + J2 @ jac_y2 From 271dbbb0143178335115f8895d08ba534f439fa9 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Sat, 14 Feb 2026 10:37:59 +0100 Subject: [PATCH 07/23] Fix second usage example --- src/torchjd/autojac/_backward.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchjd/autojac/_backward.py b/src/torchjd/autojac/_backward.py index 5c3572b0..973bb60c 100644 --- a/src/torchjd/autojac/_backward.py +++ b/src/torchjd/autojac/_backward.py @@ -80,7 +80,7 @@ def backward( >>> J1 = torch.tensor([1.0, 0.0]) >>> J2 = torch.tensor([0.0, 1.0]) >>> - >>> backward([y1, y2]) + >>> backward([y1, y2], jac_tensors=[J1, J2]) >>> >>> param.jac tensor([[-1., 1.], From 49c0934865056a812cc753ab44f0aa183e3e0e2f Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Sat, 14 Feb 2026 10:38:32 +0100 Subject: [PATCH 08/23] Fix third usage example --- src/torchjd/autojac/_backward.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchjd/autojac/_backward.py b/src/torchjd/autojac/_backward.py index 973bb60c..e7426922 100644 --- a/src/torchjd/autojac/_backward.py +++ b/src/torchjd/autojac/_backward.py @@ -95,7 +95,7 @@ def backward( >>> import torch >>> >>> from torchjd.autojac import backward - >> + >>> >>> param = torch.tensor([1., 2.], requires_grad=True) >>> y = torch.stack([param[0] ** 2, param[1] ** 3]) >>> From 1d82773a4c5cebdc3c0ece6683c7f109afb77561 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Sat, 14 Feb 2026 10:39:29 +0100 Subject: [PATCH 09/23] Remove outdated comment --- src/torchjd/autojac/_backward.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/torchjd/autojac/_backward.py b/src/torchjd/autojac/_backward.py index e7426922..a1adf4fc 100644 --- a/src/torchjd/autojac/_backward.py +++ b/src/torchjd/autojac/_backward.py @@ -133,7 +133,6 @@ def backward( diag = Diagonalize(tensors_) jac_tensors_dict = (diag << init)({}) else: - # correspond to the shape of the corresponding element in tensors_ jac_tensors_ = as_checked_ordered_set(jac_tensors, "jac_tensors") jac_tensors_dict = dict(zip(tensors_, jac_tensors_, strict=True)) From 36fead0938d2c8a5ef447b23be51896a9c9197eb Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Sat, 14 Feb 2026 10:42:48 +0100 Subject: [PATCH 10/23] Improve the third example. 
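
The example now passes the tensors directly instead of wrapping them in
single-element lists, and it states the equivalence with ``autograd.grad``
more precisely (up to a reshape of the output):

    backward(y, jac_tensors=weights)  # instead of backward([y], jac_tensors=[weights])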
--- src/torchjd/autojac/_backward.py | 5 +++-- tests/doc/test_backward.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/torchjd/autojac/_backward.py b/src/torchjd/autojac/_backward.py index a1adf4fc..f766badf 100644 --- a/src/torchjd/autojac/_backward.py +++ b/src/torchjd/autojac/_backward.py @@ -90,7 +90,8 @@ def backward( Example If ``jac_tensors`` is made of matrices whose first dimension is 1, then this function is - essentially equivalent to ``autograd.grad``. + equivalent to the call ``autograd.grad(y, grad_tensors=weights)`` up to a reshape of the + output. >>> import torch >>> @@ -100,7 +101,7 @@ def backward( >>> y = torch.stack([param[0] ** 2, param[1] ** 3]) >>> >>> weights = torch.tensor([[0.5, 1.0]]) - >>> backward([y], jac_tensors=[weights]) + >>> backward(y, jac_tensors=weights) >>> >>> param.jac tensor([[ 1., 12.]]) diff --git a/tests/doc/test_backward.py b/tests/doc/test_backward.py index a9fa2142..6461047c 100644 --- a/tests/doc/test_backward.py +++ b/tests/doc/test_backward.py @@ -48,6 +48,6 @@ def test_backward3(): y = torch.stack([param[0] ** 2, param[1] ** 3]) weights = torch.tensor([[0.5, 1.0]]) - backward([y], jac_tensors=[weights]) + backward(y, jac_tensors=weights) assert_jac_close(param, torch.tensor([[1.0, 12.0]]), rtol=0.0, atol=1e-04) From b91da92f5148ce2d9089c85b743e56ce84c3a3f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sun, 15 Feb 2026 11:50:38 +0100 Subject: [PATCH 11/23] Improve docstring of backward --- src/torchjd/autojac/_backward.py | 47 +++++++++----------------------- tests/doc/test_backward.py | 14 ---------- 2 files changed, 13 insertions(+), 48 deletions(-) diff --git a/src/torchjd/autojac/_backward.py b/src/torchjd/autojac/_backward.py index f766badf..14c25c69 100644 --- a/src/torchjd/autojac/_backward.py +++ b/src/torchjd/autojac/_backward.py @@ -14,20 +14,16 @@ def backward( parallel_chunk_size: int | None = None, ) -> None: r""" - Computes the Jacobians of ``tensors`` with respect to ``inputs``, potentially pre-multiplied by - ``jac_tensors``, and accumulates the results in the ``.jac`` fields of the ``inputs``. - - Mathematically, if ``jac_tensors`` is provided, this function computes the matrix product - :math:`J_{init} \cdot J`, where :math:`J` is the Jacobian of ``tensors`` w.r.t ``inputs``, and - :math:`J_{init}` is the concatenation of ``jac_tensors``. If ``jac_tensors`` is ``None``, it - assumes an Identity matrix, resulting in the full Jacobian. + Computes the Jacobians of ``tensors`` with respect to ``inputs``, left-multiplied by + ``jac_tensors`` (or identity if ``jac_tensors`` is ``None``), and accumulates the results in the + ``.jac`` fields of the ``inputs``. :param tensors: The tensor or tensors to differentiate. Should be non-empty. - :param jac_tensors: The initial Jacobian to backpropagate. If provided, it must have the same - length and structure as ``tensors``. Each tensor in ``jac_tensors`` must match the shape of - the corresponding tensor in ``tensors``, with an extra leading dimension representing the - number of rows of the resulting Jacobian. If ``None``, defaults to the Identity matrix, - resulting in the standard Jacobian of ``tensors``. + :param jac_tensors: The initial Jacobians to backpropagate. If provided, it must have the same + structure as ``tensors`` and each tensor in ``jac_tensors`` must match the shape of the + corresponding tensor in ``tensors``, with an extra leading dimension representing the + number of rows of the resulting Jacobian (e.g. 
the number of losses). If ``None``, defaults + to the identity matrix. In this case, the standard Jacobian of ``tensors`` is computed. :param inputs: The tensors with respect to which the Jacobians must be computed. These must have their ``requires_grad`` flag set to ``True``. If not provided, defaults to the leaf tensors that were used to compute the ``tensors`` parameter. @@ -42,7 +38,7 @@ def backward( .. admonition:: Example - The following code snippet showcases a simple usage of ``backward``. + This example shows a simple usage of ``backward``. >>> import torch >>> @@ -65,8 +61,8 @@ def backward( .. admonition:: Example - This is the same example as before, except that we specify the ``jac_tensors`` that correspond - to the default `None` + This is the same example as before, except that we explicitly specify the identity + ``jac_tensors`` (which is equivalent to using the default `None`). >>> import torch >>> @@ -86,25 +82,8 @@ def backward( tensor([[-1., 1.], [ 2., 4.]]) - .. admonition:: - Example - - If ``jac_tensors`` is made of matrices whose first dimension is 1, then this function is - equivalent to the call ``autograd.grad(y, grad_tensors=weights)`` up to a reshape of the - output. - - >>> import torch - >>> - >>> from torchjd.autojac import backward - >>> - >>> param = torch.tensor([1., 2.], requires_grad=True) - >>> y = torch.stack([param[0] ** 2, param[1] ** 3]) - >>> - >>> weights = torch.tensor([[0.5, 1.0]]) - >>> backward(y, jac_tensors=weights) - >>> - >>> param.jac - tensor([[ 1., 12.]]) + Instead of using the identity ``jac_tensors``, you can backpropagate some Jacobians obtained + by a call to :func:`torchjd.autojac.jac` on a later part of the computation graph. .. warning:: To differentiate in parallel, ``backward`` relies on ``torch.vmap``, which has some diff --git a/tests/doc/test_backward.py b/tests/doc/test_backward.py index 6461047c..2416210e 100644 --- a/tests/doc/test_backward.py +++ b/tests/doc/test_backward.py @@ -37,17 +37,3 @@ def test_backward2(): backward([y1, y2], jac_tensors=[J1, J2]) assert_jac_close(param, torch.tensor([[-1.0, 1.0], [2.0, 4.0]]), rtol=0.0, atol=1e-04) - - -def test_backward3(): - import torch - - from torchjd.autojac import backward - - param = torch.tensor([1.0, 2.0], requires_grad=True) - y = torch.stack([param[0] ** 2, param[1] ** 3]) - - weights = torch.tensor([[0.5, 1.0]]) - backward(y, jac_tensors=weights) - - assert_jac_close(param, torch.tensor([[1.0, 12.0]]), rtol=0.0, atol=1e-04) From ca9ee9f443f800f7aead593971a4edc253181d04 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sun, 15 Feb 2026 17:44:27 +0100 Subject: [PATCH 12/23] Explain analogy between jac_tensors and grad_tensors --- src/torchjd/autojac/_backward.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/torchjd/autojac/_backward.py b/src/torchjd/autojac/_backward.py index 14c25c69..47eee8b9 100644 --- a/src/torchjd/autojac/_backward.py +++ b/src/torchjd/autojac/_backward.py @@ -19,11 +19,12 @@ def backward( ``.jac`` fields of the ``inputs``. :param tensors: The tensor or tensors to differentiate. Should be non-empty. - :param jac_tensors: The initial Jacobians to backpropagate. If provided, it must have the same - structure as ``tensors`` and each tensor in ``jac_tensors`` must match the shape of the - corresponding tensor in ``tensors``, with an extra leading dimension representing the - number of rows of the resulting Jacobian (e.g. the number of losses). If ``None``, defaults - to the identity matrix. 
In this case, the standard Jacobian of ``tensors`` is computed. + :param jac_tensors: The initial Jacobians to backpropagate, analog to the `grad_tensors` + parameter of `torch.autograd.backward`. If provided, it must have the same structure as + ``tensors`` and each tensor in ``jac_tensors`` must match the shape of the corresponding + tensor in ``tensors``, with an extra leading dimension representing the number of rows of + the resulting Jacobian (e.g. the number of losses). If ``None``, defaults to the identity + matrix. In this case, the standard Jacobian of ``tensors`` is computed. :param inputs: The tensors with respect to which the Jacobians must be computed. These must have their ``requires_grad`` flag set to ``True``. If not provided, defaults to the leaf tensors that were used to compute the ``tensors`` parameter. From d9dbea13a0297711bb74968e9e3506464119cd28 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sun, 15 Feb 2026 18:20:27 +0100 Subject: [PATCH 13/23] Extract create_jac_tensors_dict and create_transform, remove checked_ordered_set on jac_tensors --- src/torchjd/autojac/_backward.py | 46 +++++++++++++++++++++-------- tests/unit/autojac/test_backward.py | 31 +++++++++++++++++++ 2 files changed, 65 insertions(+), 12 deletions(-) diff --git a/src/torchjd/autojac/_backward.py b/src/torchjd/autojac/_backward.py index 47eee8b9..5fbbd7df 100644 --- a/src/torchjd/autojac/_backward.py +++ b/src/torchjd/autojac/_backward.py @@ -2,7 +2,7 @@ from torch import Tensor -from ._transform import AccumulateJac, Diagonalize, Init, Jac, OrderedSet +from ._transform import AccumulateJac, Diagonalize, Init, Jac, OrderedSet, Transform from ._utils import as_checked_ordered_set, check_optional_positive_chunk_size, get_leaf_tensors @@ -107,19 +107,41 @@ def backward( else: inputs_ = OrderedSet(inputs) - if jac_tensors is None: + jac_tensors_dict = _create_jac_tensors_dict(tensors_, jac_tensors) + transform = _create_transform(tensors_, inputs_, parallel_chunk_size, retain_graph) + transform(jac_tensors_dict) + + +def _create_jac_tensors_dict( + tensors: OrderedSet[Tensor], + opt_jac_tensors: Sequence[Tensor] | Tensor | None, +) -> dict[Tensor, Tensor]: + """ + Creates a dictionary mapping tensors to their corresponding Jacobians. + + :param tensors: The tensors to differentiate. + :param opt_jac_tensors: The initial Jacobians to backpropagate. If ``None``, defaults to + identity. + """ + if opt_jac_tensors is None: # Transform that creates gradient outputs containing only ones. - init = Init(tensors_) + init = Init(tensors) # Transform that turns the gradients into Jacobians. - diag = Diagonalize(tensors_) - jac_tensors_dict = (diag << init)({}) - else: - jac_tensors_ = as_checked_ordered_set(jac_tensors, "jac_tensors") - jac_tensors_dict = dict(zip(tensors_, jac_tensors_, strict=True)) - + diag = Diagonalize(tensors) + return (diag << init)({}) + jac_tensors = [opt_jac_tensors] if isinstance(opt_jac_tensors, Tensor) else opt_jac_tensors + return dict(zip(tensors, jac_tensors, strict=True)) + + +def _create_transform( + tensors: OrderedSet[Tensor], + inputs: OrderedSet[Tensor], + parallel_chunk_size: int | None, + retain_graph: bool, +) -> Transform: + """Creates the backward transform that computes and accumulates Jacobians.""" # Transform that computes the required Jacobians. 
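+    # (Jac maps a dict of {tensor: seed Jacobian} to a dict of {input: resulting Jacobian}.)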
- jac = Jac(tensors_, inputs_, parallel_chunk_size, retain_graph) + jac = Jac(tensors, inputs, parallel_chunk_size, retain_graph) # Transform that accumulates the result in the .jac field of the inputs. accumulate = AccumulateJac() - - (accumulate << jac)(jac_tensors_dict) + return accumulate << jac diff --git a/tests/unit/autojac/test_backward.py b/tests/unit/autojac/test_backward.py index 177d5514..4a354f79 100644 --- a/tests/unit/autojac/test_backward.py +++ b/tests/unit/autojac/test_backward.py @@ -4,6 +4,37 @@ from utils.tensors import eye_, randn_, tensor_ from torchjd.autojac import backward +from torchjd.autojac._backward import _create_jac_tensors_dict, _create_transform +from torchjd.autojac._transform import OrderedSet + + +@mark.parametrize("default_jac_tensors", [True, False]) +def test_check_create_transform(default_jac_tensors: bool): + """Tests that _create_transform creates a valid Transform.""" + + a1 = tensor_([1.0, 2.0], requires_grad=True) + a2 = tensor_([3.0, 4.0], requires_grad=True) + + y1 = tensor_([-1.0, 1.0]) @ a1 + a2.sum() + y2 = (a1**2).sum() + a2.norm() + + optional_jac_tensors = ( + None if default_jac_tensors else [tensor_([1.0, 0.0]), tensor_([0.0, 1.0])] + ) + + jac_tensors = _create_jac_tensors_dict( + tensors=OrderedSet([y1, y2]), + opt_jac_tensors=optional_jac_tensors, + ) + transform = _create_transform( + tensors=OrderedSet([y1, y2]), + inputs=OrderedSet([a1, a2]), + retain_graph=False, + parallel_chunk_size=None, + ) + + output_keys = transform.check_keys(set(jac_tensors.keys())) + assert output_keys == set() def test_jac_is_populated(): From 554ad325cf21712f5944955b51ae2f6e826c2e21 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 16 Feb 2026 12:05:32 +0100 Subject: [PATCH 14/23] Add jac_outputs parameter to jac function Add a jac_outputs parameter to the jac function, similar to how jac_tensors was added to backward. The parameter allows users to backpropagate custom Jacobians instead of the identity matrix, enabling chained Jacobian computations via the chain rule. - Add jac_outputs parameter after inputs in jac signature - Extract _create_jac_outputs_dict helper function - Refactor _create_transform to only create Jac transform - Add third example showing chain rule usage with jac_outputs - Add comprehensive tests for jac_outputs functionality Co-Authored-By: Claude Sonnet 4.5 --- src/torchjd/autojac/_jac.py | 91 ++++++++++++++++++------- tests/doc/test_jac.py | 18 +++++ tests/unit/autojac/test_jac.py | 120 +++++++++++++++++++++++++++++++-- 3 files changed, 201 insertions(+), 28 deletions(-) diff --git a/src/torchjd/autojac/_jac.py b/src/torchjd/autojac/_jac.py index 1c809d2d..69e5ff37 100644 --- a/src/torchjd/autojac/_jac.py +++ b/src/torchjd/autojac/_jac.py @@ -17,19 +17,26 @@ def jac( outputs: Sequence[Tensor] | Tensor, inputs: Iterable[Tensor] | None = None, + jac_outputs: Sequence[Tensor] | Tensor | None = None, retain_graph: bool = False, parallel_chunk_size: int | None = None, ) -> tuple[Tensor, ...]: r""" - Computes the Jacobian of all values in ``outputs`` with respect to all ``inputs``. Returns the - result as a tuple, with one Jacobian per input tensor. The returned Jacobian with respect to - input ``t`` has shape ``[m] + t.shape``. + Computes the Jacobians of ``outputs`` with respect to ``inputs``, left-multiplied by + ``jac_outputs`` (or identity if ``jac_outputs`` is ``None``), and returns the result as a tuple, + with one Jacobian per input tensor. 
The returned Jacobian with respect to input ``t`` has shape + ``[m] + t.shape``. - :param outputs: The tensor or tensors to differentiate. Should be non-empty. The Jacobians will - have one row for each value of each of these tensors. + :param outputs: The tensor or tensors to differentiate. Should be non-empty. :param inputs: The tensors with respect to which the Jacobian must be computed. These must have their ``requires_grad`` flag set to ``True``. If not provided, defaults to the leaf tensors that were used to compute the ``outputs`` parameter. + :param jac_outputs: The initial Jacobians to backpropagate, analog to the ``grad_outputs`` + parameter of ``torch.autograd.grad``. If provided, it must have the same structure as + ``outputs`` and each tensor in ``jac_outputs`` must match the shape of the corresponding + tensor in ``outputs``, with an extra leading dimension representing the number of rows of + the resulting Jacobian (e.g. the number of losses). If ``None``, defaults to the identity + matrix. In this case, the standard Jacobian of ``outputs`` is computed. :param retain_graph: If ``False``, the graph used to compute the grad will be freed. Defaults to ``False``. :param parallel_chunk_size: The number of scalars to differentiate simultaneously in the @@ -60,7 +67,7 @@ def jac( >>> jacobians = jac([y1, y2], [param]) >>> >>> jacobians - (tensor([-1., 1.], + (tensor([[-1., 1.], [ 2., 4.]]),) .. admonition:: @@ -99,6 +106,34 @@ def jac( gradients are exactly orthogonal (they have an inner product of 0), but they conflict with the third gradient (inner product of -1 and -3). + .. admonition:: + Example + + This example shows how to apply chain rule using the ``jac_outputs`` parameter to compute + the Jacobian in two steps. + + >>> import torch + >>> + >>> from torchjd.autojac import jac + >>> + >>> x = torch.tensor([1., 2.], requires_grad=True) + >>> # Compose functions: x -> h -> y + >>> h = x ** 2 + >>> y1 = h.sum() + >>> y2 = torch.tensor([1., -1.]) @ h + >>> + >>> # Step 1: Compute d[y1,y2]/dh + >>> jac_h = jac([y1, y2], [h])[0] # Shape: [2, 2] + >>> + >>> # Step 2: Use chain rule to compute d[y1,y2]/dx = (d[y1,y2]/dh) @ (dh/dx) + >>> jac_x = jac(h, [x], jac_outputs=jac_h)[0] + >>> + >>> jac_x + tensor([[ 2., 4.], + [ 2., -4.]]) + + This two-step computation is equivalent to directly computing ``jac([y1, y2], [x])``. + .. warning:: To differentiate in parallel, ``jac`` relies on ``torch.vmap``, which has some limitations: `it does not work on the output of compiled functions @@ -122,30 +157,40 @@ def jac( inputs_with_repetition = list(inputs) # Create a list to avoid emptying generator inputs_ = OrderedSet(inputs_with_repetition) - jac_transform = _create_transform( - outputs=outputs_, - inputs=inputs_, - retain_graph=retain_graph, - parallel_chunk_size=parallel_chunk_size, - ) - - result = jac_transform({}) + jac_outputs_dict = _create_jac_outputs_dict(outputs_, jac_outputs) + transform = _create_transform(outputs_, inputs_, parallel_chunk_size, retain_graph) + result = transform(jac_outputs_dict) return tuple(result[input] for input in inputs_with_repetition) +def _create_jac_outputs_dict( + outputs: OrderedSet[Tensor], + opt_jac_outputs: Sequence[Tensor] | Tensor | None, +) -> dict[Tensor, Tensor]: + """ + Creates a dictionary mapping outputs to their corresponding Jacobians. + + :param outputs: The tensors to differentiate. + :param opt_jac_outputs: The initial Jacobians to backpropagate. If ``None``, defaults to + identity. 
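+    :return: A dict mapping each tensor in ``outputs`` to its seed Jacobian.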
+ """ + if opt_jac_outputs is None: + # Transform that creates gradient outputs containing only ones. + init = Init(outputs) + # Transform that turns the gradients into Jacobians. + diag = Diagonalize(outputs) + return (diag << init)({}) + jac_outputs = [opt_jac_outputs] if isinstance(opt_jac_outputs, Tensor) else opt_jac_outputs + return dict(zip(outputs, jac_outputs, strict=True)) + + def _create_transform( outputs: OrderedSet[Tensor], inputs: OrderedSet[Tensor], - retain_graph: bool, parallel_chunk_size: int | None, + retain_graph: bool, ) -> Transform: - # Transform that creates gradient outputs containing only ones. - init = Init(outputs) - - # Transform that turns the gradients into Jacobians. - diag = Diagonalize(outputs) - + """Creates the jac transform that computes Jacobians.""" # Transform that computes the required Jacobians. jac = Jac(outputs, inputs, parallel_chunk_size, retain_graph) - - return jac << diag << init + return jac diff --git a/tests/doc/test_jac.py b/tests/doc/test_jac.py index f422545a..92e8745c 100644 --- a/tests/doc/test_jac.py +++ b/tests/doc/test_jac.py @@ -42,3 +42,21 @@ def test_jac_2(): rtol=0.0, atol=1e-04, ) + + +def test_jac_3(): + import torch + + from torchjd.autojac import jac + + x = torch.tensor([1.0, 2.0], requires_grad=True) + # Compose functions: x -> h -> y + h = x**2 + y1 = h.sum() + y2 = torch.tensor([1.0, -1.0]) @ h + # Step 1: Compute d[y1,y2]/dh + jac_h = jac([y1, y2], [h])[0] # Shape: [2, 2] + # Step 2: Use jac_outputs to compute d[y1,y2]/dx = (d[y1,y2]/dh) @ (dh/dx) + jac_x = jac(h, [x], jac_outputs=jac_h)[0] + + assert_close(jac_x, torch.tensor([[2.0, 4.0], [2.0, -4.0]]), rtol=0.0, atol=1e-04) diff --git a/tests/unit/autojac/test_jac.py b/tests/unit/autojac/test_jac.py index 3d776a8d..26eaefef 100644 --- a/tests/unit/autojac/test_jac.py +++ b/tests/unit/autojac/test_jac.py @@ -1,14 +1,15 @@ import torch from pytest import mark, raises from torch.testing import assert_close -from utils.tensors import randn_, tensor_ +from utils.tensors import eye_, randn_, tensor_ from torchjd.autojac import jac -from torchjd.autojac._jac import _create_transform +from torchjd.autojac._jac import _create_jac_outputs_dict, _create_transform from torchjd.autojac._transform import OrderedSet -def test_check_create_transform(): +@mark.parametrize("default_jac_outputs", [True, False]) +def test_check_create_transform(default_jac_outputs: bool): """Tests that _create_transform creates a valid Transform.""" a1 = tensor_([1.0, 2.0], requires_grad=True) @@ -17,14 +18,22 @@ def test_check_create_transform(): y1 = tensor_([-1.0, 1.0]) @ a1 + a2.sum() y2 = (a1**2).sum() + a2.norm() + optional_jac_outputs = ( + None if default_jac_outputs else [tensor_([1.0, 0.0]), tensor_([0.0, 1.0])] + ) + + jac_outputs = _create_jac_outputs_dict( + outputs=OrderedSet([y1, y2]), + opt_jac_outputs=optional_jac_outputs, + ) transform = _create_transform( outputs=OrderedSet([y1, y2]), inputs=OrderedSet([a1, a2]), - retain_graph=False, parallel_chunk_size=None, + retain_graph=False, ) - output_keys = transform.check_keys(set()) + output_keys = transform.check_keys(set(jac_outputs.keys())) assert output_keys == {a1, a2} @@ -76,6 +85,107 @@ def test_value_is_correct( assert_close(jacobians[0], J) +@mark.parametrize("rows", [1, 2, 5]) +def test_jac_outputs_value_is_correct(rows: int): + """ + Tests that jac correctly computes the product of jac_outputs and the Jacobian. + result = jac_outputs @ Jacobian(outputs, inputs). 
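+    With output = J_model @ input, the Jacobian of output w.r.t. input is J_model itself, so
+    the expected result is J_init @ J_model.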
+ """ + input_size = 4 + output_size = 3 + + J_model = randn_((output_size, input_size)) + + input = randn_([input_size], requires_grad=True) + output = J_model @ input + + J_init = randn_((rows, output_size)) + + jacobians = jac( + output, + inputs=[input], + jac_outputs=J_init, + ) + + expected_jac = J_init @ J_model + assert_close(jacobians[0], expected_jac) + + +@mark.parametrize("rows", [1, 3]) +def test_jac_outputs_multiple_components(rows: int): + """ + Tests that jac_outputs works correctly when outputs is a list of multiple tensors. The + jac_outputs must match the structure of outputs. + """ + input_len = 2 + input = randn_([input_len], requires_grad=True) + + y1 = input * 2 + y2 = torch.cat([input, input[:1]]) + + J1 = randn_((rows, 2)) + J2 = randn_((rows, 3)) + + jacobians = jac([y1, y2], inputs=[input], jac_outputs=[J1, J2]) + + jac_y1 = eye_(2) * 2 + + jac_y2 = tensor_([[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]]) + + expected = J1 @ jac_y1 + J2 @ jac_y2 + + assert_close(jacobians[0], expected) + + +def test_jac_outputs_length_mismatch(): + """Tests that jac raises an error if len(jac_outputs) != len(outputs).""" + x = tensor_([1.0, 2.0], requires_grad=True) + y1 = x * 2 + y2 = x * 3 + + J1 = randn_((2, 2)) + + with raises(ValueError): + jac([y1, y2], inputs=[x], jac_outputs=[J1]) + + +def test_jac_outputs_shape_mismatch(): + """ + Tests that jac raises an error if the shape of a tensor in jac_outputs is incompatible with + the corresponding output tensor. + """ + x = tensor_([1.0, 2.0], requires_grad=True) + y = x * 2 + + J_bad = randn_((3, 5)) + + with raises((ValueError, RuntimeError)): + jac(y, inputs=[x], jac_outputs=J_bad) + + +@mark.parametrize( + "rows_y1, rows_y2", + [ + (3, 5), + (1, 2), + ], +) +def test_jac_outputs_inconsistent_first_dimension(rows_y1: int, rows_y2: int): + """ + Tests that jac fails if the provided jac_outputs is inconsistent across the sequence. 
+ """ + x = tensor_([1.0, 2.0], requires_grad=True) + + y1 = x * 2 + y2 = x.sum() + + j1 = randn_((rows_y1, 2)) + j2 = randn_((rows_y2,)) + + with raises((ValueError, RuntimeError)): + jac([y1, y2], inputs=[x], jac_outputs=[j1, j2]) + + def test_empty_inputs(): """Tests that jac does not return any jacobian no input is specified.""" From cdb2a317e2d328f5c4070e5f18b51223165fc160 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 16 Feb 2026 12:10:22 +0100 Subject: [PATCH 15/23] Remove create_transform in jac --- src/torchjd/autojac/_jac.py | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/src/torchjd/autojac/_jac.py b/src/torchjd/autojac/_jac.py index 69e5ff37..2aabcbe8 100644 --- a/src/torchjd/autojac/_jac.py +++ b/src/torchjd/autojac/_jac.py @@ -2,7 +2,6 @@ from torch import Tensor -from torchjd.autojac._transform._base import Transform from torchjd.autojac._transform._diagonalize import Diagonalize from torchjd.autojac._transform._init import Init from torchjd.autojac._transform._jac import Jac @@ -158,7 +157,7 @@ def jac( inputs_ = OrderedSet(inputs_with_repetition) jac_outputs_dict = _create_jac_outputs_dict(outputs_, jac_outputs) - transform = _create_transform(outputs_, inputs_, parallel_chunk_size, retain_graph) + transform = Jac(outputs_, inputs_, parallel_chunk_size, retain_graph) result = transform(jac_outputs_dict) return tuple(result[input] for input in inputs_with_repetition) @@ -182,15 +181,3 @@ def _create_jac_outputs_dict( return (diag << init)({}) jac_outputs = [opt_jac_outputs] if isinstance(opt_jac_outputs, Tensor) else opt_jac_outputs return dict(zip(outputs, jac_outputs, strict=True)) - - -def _create_transform( - outputs: OrderedSet[Tensor], - inputs: OrderedSet[Tensor], - parallel_chunk_size: int | None, - retain_graph: bool, -) -> Transform: - """Creates the jac transform that computes Jacobians.""" - # Transform that computes the required Jacobians. - jac = Jac(outputs, inputs, parallel_chunk_size, retain_graph) - return jac From 175c91dcc4af83265a42c3a4fd188283c34e26bb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 16 Feb 2026 12:11:58 +0100 Subject: [PATCH 16/23] Re-add the explanation of the number of rows of the jacobian --- src/torchjd/autojac/_backward.py | 3 ++- src/torchjd/autojac/_jac.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/torchjd/autojac/_backward.py b/src/torchjd/autojac/_backward.py index 5fbbd7df..70ca6d49 100644 --- a/src/torchjd/autojac/_backward.py +++ b/src/torchjd/autojac/_backward.py @@ -24,7 +24,8 @@ def backward( ``tensors`` and each tensor in ``jac_tensors`` must match the shape of the corresponding tensor in ``tensors``, with an extra leading dimension representing the number of rows of the resulting Jacobian (e.g. the number of losses). If ``None``, defaults to the identity - matrix. In this case, the standard Jacobian of ``tensors`` is computed. + matrix. In this case, the standard Jacobian of ``tensors`` is computed, with one row for + each value in the ``tensors``. :param inputs: The tensors with respect to which the Jacobians must be computed. These must have their ``requires_grad`` flag set to ``True``. If not provided, defaults to the leaf tensors that were used to compute the ``tensors`` parameter. 
diff --git a/src/torchjd/autojac/_jac.py b/src/torchjd/autojac/_jac.py index 2aabcbe8..c8f4f357 100644 --- a/src/torchjd/autojac/_jac.py +++ b/src/torchjd/autojac/_jac.py @@ -35,7 +35,8 @@ def jac( ``outputs`` and each tensor in ``jac_outputs`` must match the shape of the corresponding tensor in ``outputs``, with an extra leading dimension representing the number of rows of the resulting Jacobian (e.g. the number of losses). If ``None``, defaults to the identity - matrix. In this case, the standard Jacobian of ``outputs`` is computed. + matrix. In this case, the standard Jacobian of ``outputs`` is computed, with one row for + each value in the ``outputs``. :param retain_graph: If ``False``, the graph used to compute the grad will be freed. Defaults to ``False``. :param parallel_chunk_size: The number of scalars to differentiate simultaneously in the From fec3a139d7b4209a288294c90f51acdfe6cbdfae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 16 Feb 2026 12:25:41 +0100 Subject: [PATCH 17/23] Revert "Remove create_transform in jac" This reverts commit cdb2a317e2d328f5c4070e5f18b51223165fc160. --- src/torchjd/autojac/_jac.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/src/torchjd/autojac/_jac.py b/src/torchjd/autojac/_jac.py index c8f4f357..f614ee43 100644 --- a/src/torchjd/autojac/_jac.py +++ b/src/torchjd/autojac/_jac.py @@ -2,6 +2,7 @@ from torch import Tensor +from torchjd.autojac._transform._base import Transform from torchjd.autojac._transform._diagonalize import Diagonalize from torchjd.autojac._transform._init import Init from torchjd.autojac._transform._jac import Jac @@ -158,7 +159,7 @@ def jac( inputs_ = OrderedSet(inputs_with_repetition) jac_outputs_dict = _create_jac_outputs_dict(outputs_, jac_outputs) - transform = Jac(outputs_, inputs_, parallel_chunk_size, retain_graph) + transform = _create_transform(outputs_, inputs_, parallel_chunk_size, retain_graph) result = transform(jac_outputs_dict) return tuple(result[input] for input in inputs_with_repetition) @@ -182,3 +183,15 @@ def _create_jac_outputs_dict( return (diag << init)({}) jac_outputs = [opt_jac_outputs] if isinstance(opt_jac_outputs, Tensor) else opt_jac_outputs return dict(zip(outputs, jac_outputs, strict=True)) + + +def _create_transform( + outputs: OrderedSet[Tensor], + inputs: OrderedSet[Tensor], + parallel_chunk_size: int | None, + retain_graph: bool, +) -> Transform: + """Creates the jac transform that computes Jacobians.""" + # Transform that computes the required Jacobians. + jac = Jac(outputs, inputs, parallel_chunk_size, retain_graph) + return jac From b92327a8a983561a541e7bc475464b5fa4de599a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 16 Feb 2026 12:32:43 +0100 Subject: [PATCH 18/23] simplify _create_transform in jac --- src/torchjd/autojac/_jac.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/torchjd/autojac/_jac.py b/src/torchjd/autojac/_jac.py index f614ee43..b5f3a5a8 100644 --- a/src/torchjd/autojac/_jac.py +++ b/src/torchjd/autojac/_jac.py @@ -191,7 +191,4 @@ def _create_transform( parallel_chunk_size: int | None, retain_graph: bool, ) -> Transform: - """Creates the jac transform that computes Jacobians.""" - # Transform that computes the required Jacobians. 
- jac = Jac(outputs, inputs, parallel_chunk_size, retain_graph) - return jac + return Jac(outputs, inputs, parallel_chunk_size, retain_graph) From 2d0ebb07d2b10255627596cf2ffa727e032edfa6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 16 Feb 2026 12:49:16 +0100 Subject: [PATCH 19/23] Add early validation for inconsistent Jacobian dimensions Introduce check_consistent_first_dimension in _utils.py to validate that all Jacobians have the same first dimension. This function is now used by backward, jac, and jac_to_grad, replacing duplicate validation logic. Error messages now include the specific parameter name (jac_tensors, jac_outputs, or tensors.jac) to help users identify the source of dimension mismatches. Co-Authored-By: Claude Sonnet 4.5 --- src/torchjd/autojac/_backward.py | 8 +++++++- src/torchjd/autojac/_jac.py | 2 ++ src/torchjd/autojac/_jac_to_grad.py | 4 ++-- src/torchjd/autojac/_utils.py | 16 ++++++++++++++++ tests/unit/autojac/test_backward.py | 7 +++++-- tests/unit/autojac/test_jac.py | 7 +++++-- 6 files changed, 37 insertions(+), 7 deletions(-) diff --git a/src/torchjd/autojac/_backward.py b/src/torchjd/autojac/_backward.py index 70ca6d49..37a01f0f 100644 --- a/src/torchjd/autojac/_backward.py +++ b/src/torchjd/autojac/_backward.py @@ -3,7 +3,12 @@ from torch import Tensor from ._transform import AccumulateJac, Diagonalize, Init, Jac, OrderedSet, Transform -from ._utils import as_checked_ordered_set, check_optional_positive_chunk_size, get_leaf_tensors +from ._utils import ( + as_checked_ordered_set, + check_consistent_first_dimension, + check_optional_positive_chunk_size, + get_leaf_tensors, +) def backward( @@ -131,6 +136,7 @@ def _create_jac_tensors_dict( diag = Diagonalize(tensors) return (diag << init)({}) jac_tensors = [opt_jac_tensors] if isinstance(opt_jac_tensors, Tensor) else opt_jac_tensors + check_consistent_first_dimension(jac_tensors, "jac_tensors") return dict(zip(tensors, jac_tensors, strict=True)) diff --git a/src/torchjd/autojac/_jac.py b/src/torchjd/autojac/_jac.py index b5f3a5a8..3b417fbc 100644 --- a/src/torchjd/autojac/_jac.py +++ b/src/torchjd/autojac/_jac.py @@ -9,6 +9,7 @@ from torchjd.autojac._transform._ordered_set import OrderedSet from torchjd.autojac._utils import ( as_checked_ordered_set, + check_consistent_first_dimension, check_optional_positive_chunk_size, get_leaf_tensors, ) @@ -182,6 +183,7 @@ def _create_jac_outputs_dict( diag = Diagonalize(outputs) return (diag << init)({}) jac_outputs = [opt_jac_outputs] if isinstance(opt_jac_outputs, Tensor) else opt_jac_outputs + check_consistent_first_dimension(jac_outputs, "jac_outputs") return dict(zip(outputs, jac_outputs, strict=True)) diff --git a/src/torchjd/autojac/_jac_to_grad.py b/src/torchjd/autojac/_jac_to_grad.py index 06c9c47d..9e46caa0 100644 --- a/src/torchjd/autojac/_jac_to_grad.py +++ b/src/torchjd/autojac/_jac_to_grad.py @@ -6,6 +6,7 @@ from torchjd.aggregation import Aggregator from ._accumulation import TensorWithJac, accumulate_grads, is_tensor_with_jac +from ._utils import check_consistent_first_dimension def jac_to_grad( @@ -67,8 +68,7 @@ def jac_to_grad( jacobians = [t.jac for t in tensors_] - if not all(jacobian.shape[0] == jacobians[0].shape[0] for jacobian in jacobians[1:]): - raise ValueError("All Jacobians should have the same number of rows.") + check_consistent_first_dimension(jacobians, "tensors.jac") if not retain_jac: _free_jacs(tensors_) diff --git a/src/torchjd/autojac/_utils.py b/src/torchjd/autojac/_utils.py index e286cf20..6eae9d22 
100644 --- a/src/torchjd/autojac/_utils.py +++ b/src/torchjd/autojac/_utils.py @@ -32,6 +32,22 @@ def as_checked_ordered_set( return OrderedSet(tensors) +def check_consistent_first_dimension( + jacobians: Sequence[Tensor], + variable_name: str, +) -> None: + """ + Checks that all Jacobians have the same first dimension (number of rows). + + :param jacobians: Sequence of Jacobian tensors to validate. + :param variable_name: Name of the variable to include in the error message. + """ + if len(jacobians) > 0 and not all( + jacobian.shape[0] == jacobians[0].shape[0] for jacobian in jacobians[1:] + ): + raise ValueError(f"All Jacobians in `{variable_name}` should have the same number of rows.") + + def get_leaf_tensors(tensors: Iterable[Tensor], excluded: Iterable[Tensor]) -> OrderedSet[Tensor]: """ Gets the leaves of the autograd graph of all specified ``tensors``. diff --git a/tests/unit/autojac/test_backward.py b/tests/unit/autojac/test_backward.py index 4a354f79..98db3cc7 100644 --- a/tests/unit/autojac/test_backward.py +++ b/tests/unit/autojac/test_backward.py @@ -167,7 +167,8 @@ def test_jac_tensors_shape_mismatch(): ) def test_jac_tensors_inconsistent_first_dimension(rows_y1: int, rows_y2: int): """ - Tests that backward fails if the provided jac_tensors is inconsistent across the sequence. + Tests that backward raises a ValueError early when the provided jac_tensors have inconsistent + first dimensions. """ x = tensor_([1.0, 2.0], requires_grad=True) @@ -177,7 +178,9 @@ def test_jac_tensors_inconsistent_first_dimension(rows_y1: int, rows_y2: int): j1 = randn_((rows_y1, 2)) j2 = randn_((rows_y2,)) - with raises((ValueError, RuntimeError)): + with raises( + ValueError, match=r"All Jacobians in `jac_tensors` should have the same number of rows\." + ): backward([y1, y2], jac_tensors=[j1, j2], inputs=[x]) diff --git a/tests/unit/autojac/test_jac.py b/tests/unit/autojac/test_jac.py index 26eaefef..d37552ce 100644 --- a/tests/unit/autojac/test_jac.py +++ b/tests/unit/autojac/test_jac.py @@ -172,7 +172,8 @@ def test_jac_outputs_shape_mismatch(): ) def test_jac_outputs_inconsistent_first_dimension(rows_y1: int, rows_y2: int): """ - Tests that jac fails if the provided jac_outputs is inconsistent across the sequence. + Tests that jac raises a ValueError early when the provided jac_outputs have inconsistent first + dimensions. """ x = tensor_([1.0, 2.0], requires_grad=True) @@ -182,7 +183,9 @@ def test_jac_outputs_inconsistent_first_dimension(rows_y1: int, rows_y2: int): j1 = randn_((rows_y1, 2)) j2 = randn_((rows_y2,)) - with raises((ValueError, RuntimeError)): + with raises( + ValueError, match=r"All Jacobians in `jac_outputs` should have the same number of rows\." + ): jac([y1, y2], inputs=[x], jac_outputs=[j1, j2]) From 3ca093824a634daba22e9bd1f8a3eac09191f600 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 16 Feb 2026 13:17:10 +0100 Subject: [PATCH 20/23] Add early validation for Jacobian length mismatch Validates that jac_outputs and jac_tensors have the same length as their corresponding outputs/tensors before processing, providing clear error messages with actual lengths. 
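
For example (a sketch; the tensors are illustrative):

    backward([y1, y2], jac_tensors=[J1], inputs=[x])
    # ValueError: `jac_tensors` should have the same length as `tensors`. (got 1 and 2)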
Co-Authored-By: Claude Sonnet 4.5 --- src/torchjd/autojac/_backward.py | 2 ++ src/torchjd/autojac/_jac.py | 2 ++ src/torchjd/autojac/_utils.py | 23 ++++++++++++++++++++++- tests/unit/autojac/test_backward.py | 7 +++++-- tests/unit/autojac/test_jac.py | 7 +++++-- 5 files changed, 36 insertions(+), 5 deletions(-) diff --git a/src/torchjd/autojac/_backward.py b/src/torchjd/autojac/_backward.py index 37a01f0f..d2a989f0 100644 --- a/src/torchjd/autojac/_backward.py +++ b/src/torchjd/autojac/_backward.py @@ -6,6 +6,7 @@ from ._utils import ( as_checked_ordered_set, check_consistent_first_dimension, + check_matching_length, check_optional_positive_chunk_size, get_leaf_tensors, ) @@ -136,6 +137,7 @@ def _create_jac_tensors_dict( diag = Diagonalize(tensors) return (diag << init)({}) jac_tensors = [opt_jac_tensors] if isinstance(opt_jac_tensors, Tensor) else opt_jac_tensors + check_matching_length(jac_tensors, tensors, "jac_tensors", "tensors") check_consistent_first_dimension(jac_tensors, "jac_tensors") return dict(zip(tensors, jac_tensors, strict=True)) diff --git a/src/torchjd/autojac/_jac.py b/src/torchjd/autojac/_jac.py index 3b417fbc..9ea12c6e 100644 --- a/src/torchjd/autojac/_jac.py +++ b/src/torchjd/autojac/_jac.py @@ -10,6 +10,7 @@ from torchjd.autojac._utils import ( as_checked_ordered_set, check_consistent_first_dimension, + check_matching_length, check_optional_positive_chunk_size, get_leaf_tensors, ) @@ -183,6 +184,7 @@ def _create_jac_outputs_dict( diag = Diagonalize(outputs) return (diag << init)({}) jac_outputs = [opt_jac_outputs] if isinstance(opt_jac_outputs, Tensor) else opt_jac_outputs + check_matching_length(jac_outputs, outputs, "jac_outputs", "outputs") check_consistent_first_dimension(jac_outputs, "jac_outputs") return dict(zip(outputs, jac_outputs, strict=True)) diff --git a/src/torchjd/autojac/_utils.py b/src/torchjd/autojac/_utils.py index 6eae9d22..3abab705 100644 --- a/src/torchjd/autojac/_utils.py +++ b/src/torchjd/autojac/_utils.py @@ -1,5 +1,5 @@ from collections import deque -from collections.abc import Iterable, Sequence +from collections.abc import Iterable, Sequence, Sized from typing import cast from torch import Tensor @@ -32,6 +32,27 @@ def as_checked_ordered_set( return OrderedSet(tensors) +def check_matching_length( + seq1: Sized, + seq2: Sized, + variable_name1: str, + variable_name2: str, +) -> None: + """ + Checks that two sequences have the same length. + + :param seq1: First sequence to validate. + :param seq2: Second sequence to validate. + :param variable_name1: Name of the first variable to include in the error message. + :param variable_name2: Name of the second variable to include in the error message. + """ + if len(seq1) != len(seq2): + raise ValueError( + f"`{variable_name1}` should have the same length as `{variable_name2}`. 
" + f"(got {len(seq1)} and {len(seq2)})", + ) + + def check_consistent_first_dimension( jacobians: Sequence[Tensor], variable_name: str, diff --git a/tests/unit/autojac/test_backward.py b/tests/unit/autojac/test_backward.py index 98db3cc7..3a345cfe 100644 --- a/tests/unit/autojac/test_backward.py +++ b/tests/unit/autojac/test_backward.py @@ -133,14 +133,17 @@ def test_jac_tensors_multiple_components(rows: int): def test_jac_tensors_length_mismatch(): - """Tests that backward raises an error if len(jac_tensors) != len(tensors).""" + """Tests that backward raises a ValueError early if len(jac_tensors) != len(tensors).""" x = tensor_([1.0, 2.0], requires_grad=True) y1 = x * 2 y2 = x * 3 J1 = randn_((2, 2)) - with raises(ValueError): + with raises( + ValueError, + match=r"`jac_tensors` should have the same length as `tensors`\. \(got 1 and 2\)", + ): backward([y1, y2], jac_tensors=[J1], inputs=[x]) diff --git a/tests/unit/autojac/test_jac.py b/tests/unit/autojac/test_jac.py index d37552ce..93cd30e0 100644 --- a/tests/unit/autojac/test_jac.py +++ b/tests/unit/autojac/test_jac.py @@ -138,14 +138,17 @@ def test_jac_outputs_multiple_components(rows: int): def test_jac_outputs_length_mismatch(): - """Tests that jac raises an error if len(jac_outputs) != len(outputs).""" + """Tests that jac raises a ValueError early if len(jac_outputs) != len(outputs).""" x = tensor_([1.0, 2.0], requires_grad=True) y1 = x * 2 y2 = x * 3 J1 = randn_((2, 2)) - with raises(ValueError): + with raises( + ValueError, + match=r"`jac_outputs` should have the same length as `outputs`\. \(got 1 and 2\)", + ): jac([y1, y2], inputs=[x], jac_outputs=[J1]) From fb332a462d833bee841ac39183c1a1fe9ec4a0a6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 16 Feb 2026 13:49:26 +0100 Subject: [PATCH 21/23] Add early validation for Jacobian shape mismatch Co-Authored-By: Claude Sonnet 4.5 --- src/torchjd/autojac/_backward.py | 2 ++ src/torchjd/autojac/_jac.py | 2 ++ src/torchjd/autojac/_utils.py | 25 +++++++++++++++++++++++++ tests/unit/autojac/test_backward.py | 9 ++++++--- tests/unit/autojac/test_jac.py | 9 ++++++--- 5 files changed, 41 insertions(+), 6 deletions(-) diff --git a/src/torchjd/autojac/_backward.py b/src/torchjd/autojac/_backward.py index d2a989f0..830ae2ba 100644 --- a/src/torchjd/autojac/_backward.py +++ b/src/torchjd/autojac/_backward.py @@ -7,6 +7,7 @@ as_checked_ordered_set, check_consistent_first_dimension, check_matching_length, + check_matching_shapes, check_optional_positive_chunk_size, get_leaf_tensors, ) @@ -138,6 +139,7 @@ def _create_jac_tensors_dict( return (diag << init)({}) jac_tensors = [opt_jac_tensors] if isinstance(opt_jac_tensors, Tensor) else opt_jac_tensors check_matching_length(jac_tensors, tensors, "jac_tensors", "tensors") + check_matching_shapes(jac_tensors, tensors, "jac_tensors", "tensors") check_consistent_first_dimension(jac_tensors, "jac_tensors") return dict(zip(tensors, jac_tensors, strict=True)) diff --git a/src/torchjd/autojac/_jac.py b/src/torchjd/autojac/_jac.py index 9ea12c6e..1d2a78b9 100644 --- a/src/torchjd/autojac/_jac.py +++ b/src/torchjd/autojac/_jac.py @@ -11,6 +11,7 @@ as_checked_ordered_set, check_consistent_first_dimension, check_matching_length, + check_matching_shapes, check_optional_positive_chunk_size, get_leaf_tensors, ) @@ -185,6 +186,7 @@ def _create_jac_outputs_dict( return (diag << init)({}) jac_outputs = [opt_jac_outputs] if isinstance(opt_jac_outputs, Tensor) else opt_jac_outputs check_matching_length(jac_outputs, outputs, "jac_outputs", 
"outputs") + check_matching_shapes(jac_outputs, outputs, "jac_outputs", "outputs") check_consistent_first_dimension(jac_outputs, "jac_outputs") return dict(zip(outputs, jac_outputs, strict=True)) diff --git a/src/torchjd/autojac/_utils.py b/src/torchjd/autojac/_utils.py index 3abab705..96c8ab49 100644 --- a/src/torchjd/autojac/_utils.py +++ b/src/torchjd/autojac/_utils.py @@ -53,6 +53,31 @@ def check_matching_length( ) +def check_matching_shapes( + jacobians: Iterable[Tensor], + tensors: Iterable[Tensor], + jacobian_variable_name: str, + tensor_variable_name: str, +) -> None: + """ + Checks that the shape of each Jacobian (excluding first dimension) matches the corresponding + tensor shape. + + :param jacobians: Sequence of Jacobian tensors to validate. + :param tensors: Sequence of tensors whose shapes should match. + :param jacobian_variable_name: Name of the Jacobian variable for error messages. + :param tensor_variable_name: Name of the tensor variable for error messages. + """ + for i, (jacobian, tensor) in enumerate(zip(jacobians, tensors, strict=True)): + if jacobian.shape[1:] != tensor.shape: + raise ValueError( + f"Shape mismatch: `{jacobian_variable_name}[{i}]` has shape {tuple(jacobian.shape)} " + f"but `{tensor_variable_name}[{i}]` has shape {tuple(tensor.shape)}. " + f"The shape of `{jacobian_variable_name}[{i}]` (excluding the first dimension) " + f"should match the shape of `{tensor_variable_name}[{i}]`.", + ) + + def check_consistent_first_dimension( jacobians: Sequence[Tensor], variable_name: str, diff --git a/tests/unit/autojac/test_backward.py b/tests/unit/autojac/test_backward.py index 3a345cfe..4d75d382 100644 --- a/tests/unit/autojac/test_backward.py +++ b/tests/unit/autojac/test_backward.py @@ -149,15 +149,18 @@ def test_jac_tensors_length_mismatch(): def test_jac_tensors_shape_mismatch(): """ - Tests that backward raises an error if the shape of a tensor in jac_tensors is incompatible with - the corresponding tensor. + Tests that backward raises a ValueError early if the shape of a tensor in jac_tensors is + incompatible with the corresponding tensor. """ x = tensor_([1.0, 2.0], requires_grad=True) y = x * 2 J_bad = randn_((3, 5)) - with raises((ValueError, RuntimeError)): + with raises( + ValueError, + match=r"Shape mismatch: `jac_tensors\[0\]` has shape .* but `tensors\[0\]` has shape .*\.", + ): backward(y, jac_tensors=J_bad, inputs=[x]) diff --git a/tests/unit/autojac/test_jac.py b/tests/unit/autojac/test_jac.py index 93cd30e0..28e1dcb1 100644 --- a/tests/unit/autojac/test_jac.py +++ b/tests/unit/autojac/test_jac.py @@ -154,15 +154,18 @@ def test_jac_outputs_length_mismatch(): def test_jac_outputs_shape_mismatch(): """ - Tests that jac raises an error if the shape of a tensor in jac_outputs is incompatible with - the corresponding output tensor. + Tests that jac raises a ValueError early if the shape of a tensor in jac_outputs is + incompatible with the corresponding output tensor. 
""" x = tensor_([1.0, 2.0], requires_grad=True) y = x * 2 J_bad = randn_((3, 5)) - with raises((ValueError, RuntimeError)): + with raises( + ValueError, + match=r"Shape mismatch: `jac_outputs\[0\]` has shape .* but `outputs\[0\]` has shape .*\.", + ): jac(y, inputs=[x], jac_outputs=J_bad) From 4c7fa3cde2ce0d81130e64d5c616f9c1fc0a6611 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 16 Feb 2026 17:38:45 +0100 Subject: [PATCH 22/23] Add changelog entries for jac_tensors and jac_outputs parameters Co-Authored-By: Claude Sonnet 4.5 --- CHANGELOG.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1a4c6d1f..e59893ff 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,11 @@ changelog does not include internal changes that do not affect the user. - Added the function `torchjd.autojac.jac`. It's the same as `torchjd.autojac.backward` except that it returns the Jacobians as a tuple instead of storing them in the `.jac` fields of the inputs. Its interface is analog to that of `torch.autograd.grad`. +- Added a `jac_tensors` parameter to `backward`, allowing to pre-multiply the Jacobian computation + by initial Jacobians. This enables multi-step chain rule computations and is analogous to the + `grad_tensors` parameter in `torch.autograd.backward`. +- Added a `jac_outputs` parameter to `jac`, allowing to pre-multiply the Jacobian computation by + initial Jacobians. This is analogous to the `grad_outputs` parameter in `torch.autograd.grad`. - Added a `scale_mode` parameter to `AlignedMTL` and `AlignedMTLWeighting`, allowing to choose between `"min"`, `"median"`, and `"rmse"` scaling. - Added an attribute `gramian_weighting` to all aggregators that use a gramian-based `Weighting`. From 51ecb4ada4948044b49b99d29feedca6232c87e5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 16 Feb 2026 18:25:10 +0100 Subject: [PATCH 23/23] Apply claude suggestions --- src/torchjd/autojac/_backward.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/torchjd/autojac/_backward.py b/src/torchjd/autojac/_backward.py index 64ae4f95..56751545 100644 --- a/src/torchjd/autojac/_backward.py +++ b/src/torchjd/autojac/_backward.py @@ -30,9 +30,10 @@ def backward( parameter of `torch.autograd.backward`. If provided, it must have the same structure as ``tensors`` and each tensor in ``jac_tensors`` must match the shape of the corresponding tensor in ``tensors``, with an extra leading dimension representing the number of rows of - the resulting Jacobian (e.g. the number of losses). If ``None``, defaults to the identity - matrix. In this case, the standard Jacobian of ``tensors`` is computed, with one row for - each value in the ``tensors``. + the resulting Jacobian (e.g. the number of losses). All tensors in ``jac_tensors`` must + have the same first dimension. If ``None``, defaults to the identity matrix. In this case, + the standard Jacobian of ``tensors`` is computed, with one row for each value in the + ``tensors``. :param inputs: The tensors with respect to which the Jacobians must be computed. These must have their ``requires_grad`` flag set to ``True``. If not provided, defaults to the leaf tensors that were used to compute the ``tensors`` parameter. @@ -70,8 +71,8 @@ def backward( .. admonition:: Example - This is the same example as before, except that we explicitly specify the identity - ``jac_tensors`` (which is equivalent to using the default `None`). 
+ This is the same example as before, except that we explicitly specify ``jac_tensors`` as + the rows of the identity matrix (which is equivalent to using the default ``None``). >>> import torch >>>