99from torchjd .autojac ._transform ._ordered_set import OrderedSet
1010from torchjd .autojac ._utils import (
1111 as_checked_ordered_set ,
12+ check_consistent_first_dimension ,
13+ check_matching_length ,
14+ check_matching_shapes ,
1215 check_optional_positive_chunk_size ,
1316 get_leaf_tensors ,
1417)
1720def jac (
1821 outputs : Sequence [Tensor ] | Tensor ,
1922 inputs : Iterable [Tensor ] | None = None ,
23+ jac_outputs : Sequence [Tensor ] | Tensor | None = None ,
2024 retain_graph : bool = False ,
2125 parallel_chunk_size : int | None = None ,
2226) -> tuple [Tensor , ...]:
2327 r"""
24- Computes the Jacobian of all values in ``outputs`` with respect to all ``inputs``. Returns the
25- result as a tuple, with one Jacobian per input tensor. The returned Jacobian with respect to
26- input ``t`` has shape ``[m] + t.shape``.
28+ Computes the Jacobians of ``outputs`` with respect to ``inputs``, left-multiplied by
29+ ``jac_outputs`` (or identity if ``jac_outputs`` is ``None``), and returns the result as a tuple,
30+ with one Jacobian per input tensor. The returned Jacobian with respect to input ``t`` has shape
31+ ``[m] + t.shape``.
2732
28- :param outputs: The tensor or tensors to differentiate. Should be non-empty. The Jacobians will
29- have one row for each value of each of these tensors.
33+ :param outputs: The tensor or tensors to differentiate. Should be non-empty.
3034 :param inputs: The tensors with respect to which the Jacobian must be computed. These must have
3135 their ``requires_grad`` flag set to ``True``. If not provided, defaults to the leaf tensors
3236 that were used to compute the ``outputs`` parameter.
37+ :param jac_outputs: The initial Jacobians to backpropagate, analog to the ``grad_outputs``
38+ parameter of ``torch.autograd.grad``. If provided, it must have the same structure as
39+ ``outputs`` and each tensor in ``jac_outputs`` must match the shape of the corresponding
40+ tensor in ``outputs``, with an extra leading dimension representing the number of rows of
41+ the resulting Jacobian (e.g. the number of losses). If ``None``, defaults to the identity
42+ matrix. In this case, the standard Jacobian of ``outputs`` is computed, with one row for
43+ each value in the ``outputs``.
3344 :param retain_graph: If ``False``, the graph used to compute the grad will be freed. Defaults to
3445 ``False``.
3546 :param parallel_chunk_size: The number of scalars to differentiate simultaneously in the
@@ -60,7 +71,7 @@ def jac(
6071 >>> jacobians = jac([y1, y2], [param])
6172 >>>
6273 >>> jacobians
63- (tensor([-1., 1.],
74+ (tensor([[ -1., 1.],
6475 [ 2., 4.]]),)
6576
6677 .. admonition::
@@ -99,6 +110,34 @@ def jac(
99110 gradients are exactly orthogonal (they have an inner product of 0), but they conflict with
100111 the third gradient (inner product of -1 and -3).
101112
113+ .. admonition::
114+ Example
115+
116+ This example shows how to apply chain rule using the ``jac_outputs`` parameter to compute
117+ the Jacobian in two steps.
118+
119+ >>> import torch
120+ >>>
121+ >>> from torchjd.autojac import jac
122+ >>>
123+ >>> x = torch.tensor([1., 2.], requires_grad=True)
124+ >>> # Compose functions: x -> h -> y
125+ >>> h = x ** 2
126+ >>> y1 = h.sum()
127+ >>> y2 = torch.tensor([1., -1.]) @ h
128+ >>>
129+ >>> # Step 1: Compute d[y1,y2]/dh
130+ >>> jac_h = jac([y1, y2], [h])[0] # Shape: [2, 2]
131+ >>>
132+ >>> # Step 2: Use chain rule to compute d[y1,y2]/dx = (d[y1,y2]/dh) @ (dh/dx)
133+ >>> jac_x = jac(h, [x], jac_outputs=jac_h)[0]
134+ >>>
135+ >>> jac_x
136+ tensor([[ 2., 4.],
137+ [ 2., -4.]])
138+
139+ This two-step computation is equivalent to directly computing ``jac([y1, y2], [x])``.
140+
102141 .. warning::
103142 To differentiate in parallel, ``jac`` relies on ``torch.vmap``, which has some
104143 limitations: `it does not work on the output of compiled functions
@@ -122,30 +161,40 @@ def jac(
122161 inputs_with_repetition = list (inputs ) # Create a list to avoid emptying generator
123162 inputs_ = OrderedSet (inputs_with_repetition )
124163
125- jac_transform = _create_transform (
126- outputs = outputs_ ,
127- inputs = inputs_ ,
128- retain_graph = retain_graph ,
129- parallel_chunk_size = parallel_chunk_size ,
130- )
131-
132- result = jac_transform ({})
164+ jac_outputs_dict = _create_jac_outputs_dict (outputs_ , jac_outputs )
165+ transform = _create_transform (outputs_ , inputs_ , parallel_chunk_size , retain_graph )
166+ result = transform (jac_outputs_dict )
133167 return tuple (result [input ] for input in inputs_with_repetition )
134168
135169
def _create_jac_outputs_dict(
    outputs: OrderedSet[Tensor],
    opt_jac_outputs: Sequence[Tensor] | Tensor | None,
) -> dict[Tensor, Tensor]:
    """
    Builds the mapping from each output tensor to the initial Jacobian that will be
    backpropagated through it.

    :param outputs: The tensors to differentiate.
    :param opt_jac_outputs: The initial Jacobians to backpropagate. If ``None``, defaults to
        identity.
    """
    if opt_jac_outputs is None:
        # No user-supplied Jacobians: start from gradient outputs made of ones (Init)
        # and diagonalize them (Diagonalize) so that each output value gets its own row.
        identity_transform = Diagonalize(outputs) << Init(outputs)
        return identity_transform({})

    # Normalize a single Tensor into a one-element sequence.
    if isinstance(opt_jac_outputs, Tensor):
        jac_outputs: Sequence[Tensor] = [opt_jac_outputs]
    else:
        jac_outputs = opt_jac_outputs

    # Validate the user-provided Jacobians before pairing them with the outputs.
    check_matching_length(jac_outputs, outputs, "jac_outputs", "outputs")
    check_matching_shapes(jac_outputs, outputs, "jac_outputs", "outputs")
    check_consistent_first_dimension(jac_outputs, "jac_outputs")
    return dict(zip(outputs, jac_outputs, strict=True))
193+
def _create_transform(
    outputs: OrderedSet[Tensor],
    inputs: OrderedSet[Tensor],
    parallel_chunk_size: int | None,
    retain_graph: bool,
) -> Transform:
    """
    Creates the transform that computes the required Jacobians of ``outputs`` with respect to
    ``inputs``.

    :param outputs: The tensors to differentiate.
    :param inputs: The tensors with respect to which the Jacobians are computed.
    :param parallel_chunk_size: The number of scalars to differentiate simultaneously, or ``None``.
    :param retain_graph: Whether to keep the computation graph alive after the backward pass.
    """
    jac_transform = Jac(outputs, inputs, parallel_chunk_size, retain_graph)
    return jac_transform
0 commit comments