diff --git a/CHANGELOG.md b/CHANGELOG.md index 1a4c6d1f..e59893ff 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,11 @@ changelog does not include internal changes that do not affect the user. - Added the function `torchjd.autojac.jac`. It's the same as `torchjd.autojac.backward` except that it returns the Jacobians as a tuple instead of storing them in the `.jac` fields of the inputs. Its interface is analog to that of `torch.autograd.grad`. +- Added a `jac_tensors` parameter to `backward`, allowing to pre-multiply the Jacobian computation + by initial Jacobians. This enables multi-step chain rule computations and is analogous to the + `grad_tensors` parameter in `torch.autograd.backward`. +- Added a `jac_outputs` parameter to `jac`, allowing to pre-multiply the Jacobian computation by + initial Jacobians. This is analogous to the `grad_outputs` parameter in `torch.autograd.grad`. - Added a `scale_mode` parameter to `AlignedMTL` and `AlignedMTLWeighting`, allowing to choose between `"min"`, `"median"`, and `"rmse"` scaling. - Added an attribute `gramian_weighting` to all aggregators that use a gramian-based `Weighting`. diff --git a/src/torchjd/autojac/_backward.py b/src/torchjd/autojac/_backward.py index 3b2fda6c..56751545 100644 --- a/src/torchjd/autojac/_backward.py +++ b/src/torchjd/autojac/_backward.py @@ -3,21 +3,37 @@ from torch import Tensor from ._transform import AccumulateJac, Diagonalize, Init, Jac, OrderedSet, Transform -from ._utils import as_checked_ordered_set, check_optional_positive_chunk_size, get_leaf_tensors +from ._utils import ( + as_checked_ordered_set, + check_consistent_first_dimension, + check_matching_length, + check_matching_shapes, + check_optional_positive_chunk_size, + get_leaf_tensors, +) def backward( tensors: Sequence[Tensor] | Tensor, + jac_tensors: Sequence[Tensor] | Tensor | None = None, inputs: Iterable[Tensor] | None = None, retain_graph: bool = False, parallel_chunk_size: int | None = None, ) -> None: r""" - Computes the Jacobians of all values in ``tensors`` with respect to all ``inputs`` and - accumulates them in the ``.jac`` fields of the ``inputs``. - - :param tensors: The tensor or tensors to differentiate. Should be non-empty. The Jacobians will - have one row for each value of each of these tensors. + Computes the Jacobians of ``tensors`` with respect to ``inputs``, left-multiplied by + ``jac_tensors`` (or identity if ``jac_tensors`` is ``None``), and accumulates the results in the + ``.jac`` fields of the ``inputs``. + + :param tensors: The tensor or tensors to differentiate. Should be non-empty. + :param jac_tensors: The initial Jacobians to backpropagate, analog to the `grad_tensors` + parameter of `torch.autograd.backward`. If provided, it must have the same structure as + ``tensors`` and each tensor in ``jac_tensors`` must match the shape of the corresponding + tensor in ``tensors``, with an extra leading dimension representing the number of rows of + the resulting Jacobian (e.g. the number of losses). All tensors in ``jac_tensors`` must + have the same first dimension. If ``None``, defaults to the identity matrix. In this case, + the standard Jacobian of ``tensors`` is computed, with one row for each value in the + ``tensors``. :param inputs: The tensors with respect to which the Jacobians must be computed. These must have their ``requires_grad`` flag set to ``True``. If not provided, defaults to the leaf tensors that were used to compute the ``tensors`` parameter. @@ -32,7 +48,7 @@ def backward( .. 
admonition:: Example - The following code snippet showcases a simple usage of ``backward``. + This example shows a simple usage of ``backward``. >>> import torch >>> @@ -52,6 +68,33 @@ def backward( The ``.jac`` field of ``param`` now contains the Jacobian of :math:`\begin{bmatrix}y_1 \\ y_2\end{bmatrix}` with respect to ``param``. + .. admonition:: + Example + + This is the same example as before, except that we explicitly specify ``jac_tensors`` as + the rows of the identity matrix (which is equivalent to using the default ``None``). + + >>> import torch + >>> + >>> from torchjd.autojac import backward + >>> + >>> param = torch.tensor([1., 2.], requires_grad=True) + >>> # Compute arbitrary quantities that are function of param + >>> y1 = torch.tensor([-1., 1.]) @ param + >>> y2 = (param ** 2).sum() + >>> + >>> J1 = torch.tensor([1.0, 0.0]) + >>> J2 = torch.tensor([0.0, 1.0]) + >>> + >>> backward([y1, y2], jac_tensors=[J1, J2]) + >>> + >>> param.jac + tensor([[-1., 1.], + [ 2., 4.]]) + + Instead of using the identity ``jac_tensors``, you can backpropagate some Jacobians obtained + by a call to :func:`torchjd.autojac.jac` on a later part of the computation graph. + .. warning:: To differentiate in parallel, ``backward`` relies on ``torch.vmap``, which has some limitations: `it does not work on the output of compiled functions @@ -73,34 +116,44 @@ def backward( else: inputs_ = OrderedSet(inputs) - backward_transform = _create_transform( - tensors=tensors_, - inputs=inputs_, - retain_graph=retain_graph, - parallel_chunk_size=parallel_chunk_size, - ) + jac_tensors_dict = _create_jac_tensors_dict(tensors_, jac_tensors) + transform = _create_transform(tensors_, inputs_, parallel_chunk_size, retain_graph) + transform(jac_tensors_dict) + + +def _create_jac_tensors_dict( + tensors: OrderedSet[Tensor], + opt_jac_tensors: Sequence[Tensor] | Tensor | None, +) -> dict[Tensor, Tensor]: + """ + Creates a dictionary mapping tensors to their corresponding Jacobians. - backward_transform({}) + :param tensors: The tensors to differentiate. + :param opt_jac_tensors: The initial Jacobians to backpropagate. If ``None``, defaults to + identity. + """ + if opt_jac_tensors is None: + # Transform that creates gradient outputs containing only ones. + init = Init(tensors) + # Transform that turns the gradients into Jacobians. + diag = Diagonalize(tensors) + return (diag << init)({}) + jac_tensors = [opt_jac_tensors] if isinstance(opt_jac_tensors, Tensor) else opt_jac_tensors + check_matching_length(jac_tensors, tensors, "jac_tensors", "tensors") + check_matching_shapes(jac_tensors, tensors, "jac_tensors", "tensors") + check_consistent_first_dimension(jac_tensors, "jac_tensors") + return dict(zip(tensors, jac_tensors, strict=True)) def _create_transform( tensors: OrderedSet[Tensor], inputs: OrderedSet[Tensor], - retain_graph: bool, parallel_chunk_size: int | None, + retain_graph: bool, ) -> Transform: - """Creates the backward transform.""" - - # Transform that creates gradient outputs containing only ones. - init = Init(tensors) - - # Transform that turns the gradients into Jacobians. - diag = Diagonalize(tensors) - + """Creates the backward transform that computes and accumulates Jacobians.""" # Transform that computes the required Jacobians. jac = Jac(tensors, inputs, parallel_chunk_size, retain_graph) - # Transform that accumulates the result in the .jac field of the inputs. 
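Outside the patch itself, here is a minimal sketch of the two-step workflow hinted at in the docstring above: computing a Jacobian with ``jac`` on the later part of the graph and backpropagating it through the earlier part via ``jac_tensors``. It assumes the ``jac``/``backward`` behaviour exactly as documented in this diff; the printed values are only illustrative.

import torch

from torchjd.autojac import backward, jac

param = torch.tensor([1.0, 2.0], requires_grad=True)
h = param ** 2                        # earlier part of the graph: param -> h
y1 = h.sum()                          # later part of the graph: h -> (y1, y2)
y2 = torch.tensor([1.0, -1.0]) @ h

# Step 1: Jacobian of [y1, y2] with respect to the intermediate tensor h.
jac_h = jac([y1, y2], inputs=[h])[0]  # shape [2, 2]

# Step 2: backpropagate that Jacobian through the earlier part of the graph.
backward(h, jac_tensors=jac_h, inputs=[param])

# param.jac now holds the same Jacobian that backward([y1, y2], inputs=[param])
# would accumulate:
# tensor([[ 2.,  4.],
#         [ 2., -4.]])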
accumulate = AccumulateJac() - - return accumulate << jac << diag << init + return accumulate << jac diff --git a/src/torchjd/autojac/_jac.py b/src/torchjd/autojac/_jac.py index 1c809d2d..1d2a78b9 100644 --- a/src/torchjd/autojac/_jac.py +++ b/src/torchjd/autojac/_jac.py @@ -9,6 +9,9 @@ from torchjd.autojac._transform._ordered_set import OrderedSet from torchjd.autojac._utils import ( as_checked_ordered_set, + check_consistent_first_dimension, + check_matching_length, + check_matching_shapes, check_optional_positive_chunk_size, get_leaf_tensors, ) @@ -17,19 +20,27 @@ def jac( outputs: Sequence[Tensor] | Tensor, inputs: Iterable[Tensor] | None = None, + jac_outputs: Sequence[Tensor] | Tensor | None = None, retain_graph: bool = False, parallel_chunk_size: int | None = None, ) -> tuple[Tensor, ...]: r""" - Computes the Jacobian of all values in ``outputs`` with respect to all ``inputs``. Returns the - result as a tuple, with one Jacobian per input tensor. The returned Jacobian with respect to - input ``t`` has shape ``[m] + t.shape``. + Computes the Jacobians of ``outputs`` with respect to ``inputs``, left-multiplied by + ``jac_outputs`` (or identity if ``jac_outputs`` is ``None``), and returns the result as a tuple, + with one Jacobian per input tensor. The returned Jacobian with respect to input ``t`` has shape + ``[m] + t.shape``. - :param outputs: The tensor or tensors to differentiate. Should be non-empty. The Jacobians will - have one row for each value of each of these tensors. + :param outputs: The tensor or tensors to differentiate. Should be non-empty. :param inputs: The tensors with respect to which the Jacobian must be computed. These must have their ``requires_grad`` flag set to ``True``. If not provided, defaults to the leaf tensors that were used to compute the ``outputs`` parameter. + :param jac_outputs: The initial Jacobians to backpropagate, analog to the ``grad_outputs`` + parameter of ``torch.autograd.grad``. If provided, it must have the same structure as + ``outputs`` and each tensor in ``jac_outputs`` must match the shape of the corresponding + tensor in ``outputs``, with an extra leading dimension representing the number of rows of + the resulting Jacobian (e.g. the number of losses). If ``None``, defaults to the identity + matrix. In this case, the standard Jacobian of ``outputs`` is computed, with one row for + each value in the ``outputs``. :param retain_graph: If ``False``, the graph used to compute the grad will be freed. Defaults to ``False``. :param parallel_chunk_size: The number of scalars to differentiate simultaneously in the @@ -60,7 +71,7 @@ def jac( >>> jacobians = jac([y1, y2], [param]) >>> >>> jacobians - (tensor([-1., 1.], + (tensor([[-1., 1.], [ 2., 4.]]),) .. admonition:: @@ -99,6 +110,34 @@ def jac( gradients are exactly orthogonal (they have an inner product of 0), but they conflict with the third gradient (inner product of -1 and -3). + .. admonition:: + Example + + This example shows how to apply chain rule using the ``jac_outputs`` parameter to compute + the Jacobian in two steps. 
+ + >>> import torch + >>> + >>> from torchjd.autojac import jac + >>> + >>> x = torch.tensor([1., 2.], requires_grad=True) + >>> # Compose functions: x -> h -> y + >>> h = x ** 2 + >>> y1 = h.sum() + >>> y2 = torch.tensor([1., -1.]) @ h + >>> + >>> # Step 1: Compute d[y1,y2]/dh + >>> jac_h = jac([y1, y2], [h])[0] # Shape: [2, 2] + >>> + >>> # Step 2: Use chain rule to compute d[y1,y2]/dx = (d[y1,y2]/dh) @ (dh/dx) + >>> jac_x = jac(h, [x], jac_outputs=jac_h)[0] + >>> + >>> jac_x + tensor([[ 2., 4.], + [ 2., -4.]]) + + This two-step computation is equivalent to directly computing ``jac([y1, y2], [x])``. + .. warning:: To differentiate in parallel, ``jac`` relies on ``torch.vmap``, which has some limitations: `it does not work on the output of compiled functions @@ -122,30 +161,40 @@ def jac( inputs_with_repetition = list(inputs) # Create a list to avoid emptying generator inputs_ = OrderedSet(inputs_with_repetition) - jac_transform = _create_transform( - outputs=outputs_, - inputs=inputs_, - retain_graph=retain_graph, - parallel_chunk_size=parallel_chunk_size, - ) - - result = jac_transform({}) + jac_outputs_dict = _create_jac_outputs_dict(outputs_, jac_outputs) + transform = _create_transform(outputs_, inputs_, parallel_chunk_size, retain_graph) + result = transform(jac_outputs_dict) return tuple(result[input] for input in inputs_with_repetition) +def _create_jac_outputs_dict( + outputs: OrderedSet[Tensor], + opt_jac_outputs: Sequence[Tensor] | Tensor | None, +) -> dict[Tensor, Tensor]: + """ + Creates a dictionary mapping outputs to their corresponding Jacobians. + + :param outputs: The tensors to differentiate. + :param opt_jac_outputs: The initial Jacobians to backpropagate. If ``None``, defaults to + identity. + """ + if opt_jac_outputs is None: + # Transform that creates gradient outputs containing only ones. + init = Init(outputs) + # Transform that turns the gradients into Jacobians. + diag = Diagonalize(outputs) + return (diag << init)({}) + jac_outputs = [opt_jac_outputs] if isinstance(opt_jac_outputs, Tensor) else opt_jac_outputs + check_matching_length(jac_outputs, outputs, "jac_outputs", "outputs") + check_matching_shapes(jac_outputs, outputs, "jac_outputs", "outputs") + check_consistent_first_dimension(jac_outputs, "jac_outputs") + return dict(zip(outputs, jac_outputs, strict=True)) + + def _create_transform( outputs: OrderedSet[Tensor], inputs: OrderedSet[Tensor], - retain_graph: bool, parallel_chunk_size: int | None, + retain_graph: bool, ) -> Transform: - # Transform that creates gradient outputs containing only ones. - init = Init(outputs) - - # Transform that turns the gradients into Jacobians. - diag = Diagonalize(outputs) - - # Transform that computes the required Jacobians. 
- jac = Jac(outputs, inputs, parallel_chunk_size, retain_graph) - - return jac << diag << init + return Jac(outputs, inputs, parallel_chunk_size, retain_graph) diff --git a/src/torchjd/autojac/_jac_to_grad.py b/src/torchjd/autojac/_jac_to_grad.py index 06c9c47d..9e46caa0 100644 --- a/src/torchjd/autojac/_jac_to_grad.py +++ b/src/torchjd/autojac/_jac_to_grad.py @@ -6,6 +6,7 @@ from torchjd.aggregation import Aggregator from ._accumulation import TensorWithJac, accumulate_grads, is_tensor_with_jac +from ._utils import check_consistent_first_dimension def jac_to_grad( @@ -67,8 +68,7 @@ def jac_to_grad( jacobians = [t.jac for t in tensors_] - if not all(jacobian.shape[0] == jacobians[0].shape[0] for jacobian in jacobians[1:]): - raise ValueError("All Jacobians should have the same number of rows.") + check_consistent_first_dimension(jacobians, "tensors.jac") if not retain_jac: _free_jacs(tensors_) diff --git a/src/torchjd/autojac/_utils.py b/src/torchjd/autojac/_utils.py index 2004c483..6109db11 100644 --- a/src/torchjd/autojac/_utils.py +++ b/src/torchjd/autojac/_utils.py @@ -1,5 +1,5 @@ from collections import deque -from collections.abc import Iterable, Sequence +from collections.abc import Iterable, Sequence, Sized from typing import cast from torch import Tensor @@ -32,6 +32,68 @@ def as_checked_ordered_set( return OrderedSet(tensors) +def check_matching_length( + seq1: Sized, + seq2: Sized, + variable_name1: str, + variable_name2: str, +) -> None: + """ + Checks that two sequences have the same length. + + :param seq1: First sequence to validate. + :param seq2: Second sequence to validate. + :param variable_name1: Name of the first variable to include in the error message. + :param variable_name2: Name of the second variable to include in the error message. + """ + if len(seq1) != len(seq2): + raise ValueError( + f"`{variable_name1}` should have the same length as `{variable_name2}`. " + f"(got {len(seq1)} and {len(seq2)})", + ) + + +def check_matching_shapes( + jacobians: Iterable[Tensor], + tensors: Iterable[Tensor], + jacobian_variable_name: str, + tensor_variable_name: str, +) -> None: + """ + Checks that the shape of each Jacobian (excluding first dimension) matches the corresponding + tensor shape. + + :param jacobians: Sequence of Jacobian tensors to validate. + :param tensors: Sequence of tensors whose shapes should match. + :param jacobian_variable_name: Name of the Jacobian variable for error messages. + :param tensor_variable_name: Name of the tensor variable for error messages. + """ + for i, (jacobian, tensor) in enumerate(zip(jacobians, tensors, strict=True)): + if jacobian.shape[1:] != tensor.shape: + raise ValueError( + f"Shape mismatch: `{jacobian_variable_name}[{i}]` has shape {tuple(jacobian.shape)} " + f"but `{tensor_variable_name}[{i}]` has shape {tuple(tensor.shape)}. " + f"The shape of `{jacobian_variable_name}[{i}]` (excluding the first dimension) " + f"should match the shape of `{tensor_variable_name}[{i}]`.", + ) + + +def check_consistent_first_dimension( + jacobians: Sequence[Tensor], + variable_name: str, +) -> None: + """ + Checks that all Jacobians have the same first dimension (number of rows). + + :param jacobians: Sequence of Jacobian tensors to validate. + :param variable_name: Name of the variable to include in the error message. 
+ """ + if len(jacobians) > 0 and not all( + jacobian.shape[0] == jacobians[0].shape[0] for jacobian in jacobians[1:] + ): + raise ValueError(f"All Jacobians in `{variable_name}` should have the same number of rows.") + + def get_leaf_tensors(tensors: Iterable[Tensor], excluded: Iterable[Tensor]) -> OrderedSet[Tensor]: """ Gets the leaves of the autograd graph of all specified ``tensors``. diff --git a/tests/doc/test_backward.py b/tests/doc/test_backward.py index 032f902d..2416210e 100644 --- a/tests/doc/test_backward.py +++ b/tests/doc/test_backward.py @@ -19,3 +19,21 @@ def test_backward(): backward([y1, y2]) assert_jac_close(param, torch.tensor([[-1.0, 1.0], [2.0, 4.0]]), rtol=0.0, atol=1e-04) + + +def test_backward2(): + import torch + + from torchjd.autojac import backward + + param = torch.tensor([1.0, 2.0], requires_grad=True) + # Compute arbitrary quantities that are function of param + y1 = torch.tensor([-1.0, 1.0]) @ param + y2 = (param**2).sum() + + J1 = torch.tensor([1.0, 0.0]) + J2 = torch.tensor([0.0, 1.0]) + + backward([y1, y2], jac_tensors=[J1, J2]) + + assert_jac_close(param, torch.tensor([[-1.0, 1.0], [2.0, 4.0]]), rtol=0.0, atol=1e-04) diff --git a/tests/doc/test_jac.py b/tests/doc/test_jac.py index f422545a..92e8745c 100644 --- a/tests/doc/test_jac.py +++ b/tests/doc/test_jac.py @@ -42,3 +42,21 @@ def test_jac_2(): rtol=0.0, atol=1e-04, ) + + +def test_jac_3(): + import torch + + from torchjd.autojac import jac + + x = torch.tensor([1.0, 2.0], requires_grad=True) + # Compose functions: x -> h -> y + h = x**2 + y1 = h.sum() + y2 = torch.tensor([1.0, -1.0]) @ h + # Step 1: Compute d[y1,y2]/dh + jac_h = jac([y1, y2], [h])[0] # Shape: [2, 2] + # Step 2: Use jac_outputs to compute d[y1,y2]/dx = (d[y1,y2]/dh) @ (dh/dx) + jac_x = jac(h, [x], jac_outputs=jac_h)[0] + + assert_close(jac_x, torch.tensor([[2.0, 4.0], [2.0, -4.0]]), rtol=0.0, atol=1e-04) diff --git a/tests/unit/autojac/test_backward.py b/tests/unit/autojac/test_backward.py index 80bd06e3..4d75d382 100644 --- a/tests/unit/autojac/test_backward.py +++ b/tests/unit/autojac/test_backward.py @@ -1,14 +1,15 @@ import torch from pytest import mark, raises from utils.asserts import assert_has_jac, assert_has_no_jac, assert_jac_close -from utils.tensors import randn_, tensor_ +from utils.tensors import eye_, randn_, tensor_ from torchjd.autojac import backward -from torchjd.autojac._backward import _create_transform +from torchjd.autojac._backward import _create_jac_tensors_dict, _create_transform from torchjd.autojac._transform import OrderedSet -def test_check_create_transform(): +@mark.parametrize("default_jac_tensors", [True, False]) +def test_check_create_transform(default_jac_tensors: bool): """Tests that _create_transform creates a valid Transform.""" a1 = tensor_([1.0, 2.0], requires_grad=True) @@ -17,6 +18,14 @@ def test_check_create_transform(): y1 = tensor_([-1.0, 1.0]) @ a1 + a2.sum() y2 = (a1**2).sum() + a2.norm() + optional_jac_tensors = ( + None if default_jac_tensors else [tensor_([1.0, 0.0]), tensor_([0.0, 1.0])] + ) + + jac_tensors = _create_jac_tensors_dict( + tensors=OrderedSet([y1, y2]), + opt_jac_tensors=optional_jac_tensors, + ) transform = _create_transform( tensors=OrderedSet([y1, y2]), inputs=OrderedSet([a1, a2]), @@ -24,7 +33,7 @@ def test_check_create_transform(): parallel_chunk_size=None, ) - output_keys = transform.check_keys(set()) + output_keys = transform.check_keys(set(jac_tensors.keys())) assert output_keys == set() @@ -71,6 +80,116 @@ def test_value_is_correct( assert_jac_close(input, J) 
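As a side note (not part of the patch), the equivalence between the two-step computation and the direct Jacobian, stated at the end of the ``jac_outputs`` example and exercised by the doc tests above, can be checked numerically with a sketch like the following:

import torch

from torchjd.autojac import jac

x = torch.tensor([1.0, 2.0], requires_grad=True)
h = x ** 2
y1 = h.sum()
y2 = torch.tensor([1.0, -1.0]) @ h

# Direct Jacobian of [y1, y2] with respect to x.
direct = jac([y1, y2], inputs=[x], retain_graph=True)[0]

# Two-step computation through the intermediate tensor h.
jac_h = jac([y1, y2], inputs=[h], retain_graph=True)[0]
two_step = jac(h, inputs=[x], jac_outputs=jac_h)[0]

assert torch.allclose(direct, two_step)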
+@mark.parametrize("rows", [1, 2, 5]) +def test_jac_tensors_value_is_correct(rows: int): + """ + Tests that backward correctly computes the product of jac_tensors and the Jacobian. + result = jac_tensors @ Jacobian(tensors, inputs). + """ + input_size = 4 + output_size = 3 + + J_model = randn_((output_size, input_size)) + + input = randn_([input_size], requires_grad=True) + tensor = J_model @ input + + J_init = randn_((rows, output_size)) + + backward( + tensor, + jac_tensors=J_init, + inputs=[input], + ) + + expected_jac = J_init @ J_model + assert_jac_close(input, expected_jac) + + +@mark.parametrize("rows", [1, 3]) +def test_jac_tensors_multiple_components(rows: int): + """ + Tests that jac_tensors works correctly when tensors is a list of multiple tensors. The + jac_tensors must match the structure of tensors. + """ + input_len = 2 + input = randn_([input_len], requires_grad=True) + + y1 = input * 2 + y2 = torch.cat([input, input[:1]]) + + J1 = randn_((rows, 2)) + J2 = randn_((rows, 3)) + + backward([y1, y2], jac_tensors=[J1, J2], inputs=[input]) + + jac_y1 = eye_(2) * 2 + + jac_y2 = tensor_([[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]]) + + expected = J1 @ jac_y1 + J2 @ jac_y2 + + assert_jac_close(input, expected) + + +def test_jac_tensors_length_mismatch(): + """Tests that backward raises a ValueError early if len(jac_tensors) != len(tensors).""" + x = tensor_([1.0, 2.0], requires_grad=True) + y1 = x * 2 + y2 = x * 3 + + J1 = randn_((2, 2)) + + with raises( + ValueError, + match=r"`jac_tensors` should have the same length as `tensors`\. \(got 1 and 2\)", + ): + backward([y1, y2], jac_tensors=[J1], inputs=[x]) + + +def test_jac_tensors_shape_mismatch(): + """ + Tests that backward raises a ValueError early if the shape of a tensor in jac_tensors is + incompatible with the corresponding tensor. + """ + x = tensor_([1.0, 2.0], requires_grad=True) + y = x * 2 + + J_bad = randn_((3, 5)) + + with raises( + ValueError, + match=r"Shape mismatch: `jac_tensors\[0\]` has shape .* but `tensors\[0\]` has shape .*\.", + ): + backward(y, jac_tensors=J_bad, inputs=[x]) + + +@mark.parametrize( + "rows_y1, rows_y2", + [ + (3, 5), + (1, 2), + ], +) +def test_jac_tensors_inconsistent_first_dimension(rows_y1: int, rows_y2: int): + """ + Tests that backward raises a ValueError early when the provided jac_tensors have inconsistent + first dimensions. + """ + x = tensor_([1.0, 2.0], requires_grad=True) + + y1 = x * 2 + y2 = x.sum() + + j1 = randn_((rows_y1, 2)) + j2 = randn_((rows_y2,)) + + with raises( + ValueError, match=r"All Jacobians in `jac_tensors` should have the same number of rows\." 
+ ): + backward([y1, y2], jac_tensors=[j1, j2], inputs=[x]) + + def test_empty_inputs(): """Tests that backward does not fill the .jac values if no input is specified.""" diff --git a/tests/unit/autojac/test_jac.py b/tests/unit/autojac/test_jac.py index 3d776a8d..28e1dcb1 100644 --- a/tests/unit/autojac/test_jac.py +++ b/tests/unit/autojac/test_jac.py @@ -1,14 +1,15 @@ import torch from pytest import mark, raises from torch.testing import assert_close -from utils.tensors import randn_, tensor_ +from utils.tensors import eye_, randn_, tensor_ from torchjd.autojac import jac -from torchjd.autojac._jac import _create_transform +from torchjd.autojac._jac import _create_jac_outputs_dict, _create_transform from torchjd.autojac._transform import OrderedSet -def test_check_create_transform(): +@mark.parametrize("default_jac_outputs", [True, False]) +def test_check_create_transform(default_jac_outputs: bool): """Tests that _create_transform creates a valid Transform.""" a1 = tensor_([1.0, 2.0], requires_grad=True) @@ -17,14 +18,22 @@ def test_check_create_transform(): y1 = tensor_([-1.0, 1.0]) @ a1 + a2.sum() y2 = (a1**2).sum() + a2.norm() + optional_jac_outputs = ( + None if default_jac_outputs else [tensor_([1.0, 0.0]), tensor_([0.0, 1.0])] + ) + + jac_outputs = _create_jac_outputs_dict( + outputs=OrderedSet([y1, y2]), + opt_jac_outputs=optional_jac_outputs, + ) transform = _create_transform( outputs=OrderedSet([y1, y2]), inputs=OrderedSet([a1, a2]), - retain_graph=False, parallel_chunk_size=None, + retain_graph=False, ) - output_keys = transform.check_keys(set()) + output_keys = transform.check_keys(set(jac_outputs.keys())) assert output_keys == {a1, a2} @@ -76,6 +85,116 @@ def test_value_is_correct( assert_close(jacobians[0], J) +@mark.parametrize("rows", [1, 2, 5]) +def test_jac_outputs_value_is_correct(rows: int): + """ + Tests that jac correctly computes the product of jac_outputs and the Jacobian. + result = jac_outputs @ Jacobian(outputs, inputs). + """ + input_size = 4 + output_size = 3 + + J_model = randn_((output_size, input_size)) + + input = randn_([input_size], requires_grad=True) + output = J_model @ input + + J_init = randn_((rows, output_size)) + + jacobians = jac( + output, + inputs=[input], + jac_outputs=J_init, + ) + + expected_jac = J_init @ J_model + assert_close(jacobians[0], expected_jac) + + +@mark.parametrize("rows", [1, 3]) +def test_jac_outputs_multiple_components(rows: int): + """ + Tests that jac_outputs works correctly when outputs is a list of multiple tensors. The + jac_outputs must match the structure of outputs. + """ + input_len = 2 + input = randn_([input_len], requires_grad=True) + + y1 = input * 2 + y2 = torch.cat([input, input[:1]]) + + J1 = randn_((rows, 2)) + J2 = randn_((rows, 3)) + + jacobians = jac([y1, y2], inputs=[input], jac_outputs=[J1, J2]) + + jac_y1 = eye_(2) * 2 + + jac_y2 = tensor_([[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]]) + + expected = J1 @ jac_y1 + J2 @ jac_y2 + + assert_close(jacobians[0], expected) + + +def test_jac_outputs_length_mismatch(): + """Tests that jac raises a ValueError early if len(jac_outputs) != len(outputs).""" + x = tensor_([1.0, 2.0], requires_grad=True) + y1 = x * 2 + y2 = x * 3 + + J1 = randn_((2, 2)) + + with raises( + ValueError, + match=r"`jac_outputs` should have the same length as `outputs`\. 
\(got 1 and 2\)", + ): + jac([y1, y2], inputs=[x], jac_outputs=[J1]) + + +def test_jac_outputs_shape_mismatch(): + """ + Tests that jac raises a ValueError early if the shape of a tensor in jac_outputs is + incompatible with the corresponding output tensor. + """ + x = tensor_([1.0, 2.0], requires_grad=True) + y = x * 2 + + J_bad = randn_((3, 5)) + + with raises( + ValueError, + match=r"Shape mismatch: `jac_outputs\[0\]` has shape .* but `outputs\[0\]` has shape .*\.", + ): + jac(y, inputs=[x], jac_outputs=J_bad) + + +@mark.parametrize( + "rows_y1, rows_y2", + [ + (3, 5), + (1, 2), + ], +) +def test_jac_outputs_inconsistent_first_dimension(rows_y1: int, rows_y2: int): + """ + Tests that jac raises a ValueError early when the provided jac_outputs have inconsistent first + dimensions. + """ + x = tensor_([1.0, 2.0], requires_grad=True) + + y1 = x * 2 + y2 = x.sum() + + j1 = randn_((rows_y1, 2)) + j2 = randn_((rows_y2,)) + + with raises( + ValueError, match=r"All Jacobians in `jac_outputs` should have the same number of rows\." + ): + jac([y1, y2], inputs=[x], jac_outputs=[j1, j2]) + + def test_empty_inputs(): """Tests that jac does not return any jacobian no input is specified."""
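Finally, a short illustrative sketch (again not taken from the patch) of a non-identity ``jac_outputs`` whose number of rows differs from the number of output values, which is precisely the case the new shape and first-dimension checks are written to allow:

import torch

from torchjd.autojac import jac

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = x * x  # 3 output values

# Two rows of initial Jacobian -> the result has 2 rows, independently of the
# 3 output values.
J_init = torch.tensor([[1.0, 0.0, 0.0],
                       [0.5, 0.5, 0.5]])

(jac_x,) = jac(y, inputs=[x], jac_outputs=J_init)

print(jac_x.shape)  # torch.Size([2, 3])
print(jac_x)        # J_init @ diag(2 * x): tensor([[2., 0., 0.], [1., 2., 3.]])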