4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -16,6 +16,9 @@ changelog does not include internal changes that do not affect the user.
- Added a `jac_tensors` parameter to `backward`, allowing the Jacobian computation to be pre-multiplied
by initial Jacobians. This enables multi-step chain rule computations and is analogous to the
`grad_tensors` parameter in `torch.autograd.backward`.
- Added a `grad_tensors` parameter to `mtl_backward`, allowing the use of non-scalar `losses` (now
renamed to `tensors`). This is analogous to the `grad_tensors` parameter of
`torch.autograd.backward`. When using scalar losses, the usage does not change.
- Added a `jac_outputs` parameter to `jac`, allowing the Jacobian computation to be pre-multiplied by
initial Jacobians. This is analogous to the `grad_outputs` parameter in `torch.autograd.grad`.
- Added a `scale_mode` parameter to `AlignedMTL` and `AlignedMTLWeighting`, allowing one to choose
@@ -50,6 +53,7 @@ changelog does not include internal changes that do not affect the user.
mtl_backward(losses, features)
jac_to_grad(shared_module.parameters(), aggregator)
```
- **BREAKING**: Renamed the `losses` parameter of `mtl_backward` to `tensors`.

- Removed an unnecessary memory duplication. This should significantly improve the memory efficiency
of `autojac`.
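The "initial Jacobians" mentioned in the changelog amount to left-multiplying each tensor's Jacobian, just as `grad_outputs`/`grad_tensors` do row by row in plain PyTorch. A minimal plain-torch sketch of the quantity being described (it does not call torchjd, so it makes no assumptions about the new signatures):

```python
import torch

x = torch.randn(4, requires_grad=True)
y = x ** 2                                   # non-scalar intermediate tensor, shape (4,)

# Two "initial Jacobian" rows for y, shape (2, 4); each row plays the role of a
# grad_outputs vector in torch.autograd.grad.
J_init = torch.stack([torch.ones(4), torch.arange(4.0)])

# dy/dx is diagonal with entries 2 * x, so the chained Jacobian is J_init @ (dy/dx).
chained = J_init @ torch.diag(2 * x)

# The same rows, obtained with standard autograd, one initial row at a time.
rows = [torch.autograd.grad(y, x, grad_outputs=row, retain_graph=True)[0] for row in J_init]
assert torch.allclose(chained, torch.stack(rows))
```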
2 changes: 1 addition & 1 deletion docs/source/examples/amp.rst
@@ -48,7 +48,7 @@ following example shows the resulting code for a multi-task learning use-case.
loss2 = loss_fn(output2, target2)

scaled_losses = scaler.scale([loss1, loss2])
mtl_backward(losses=scaled_losses, features=features)
mtl_backward(tensors=scaled_losses, features=features)
jac_to_grad(shared_module.parameters(), aggregator)
scaler.step(optimizer)
scaler.update()
2 changes: 1 addition & 1 deletion docs/source/examples/lightning_integration.rst
@@ -43,7 +43,7 @@ The following code example demonstrates a basic multi-task learning setup using
loss2 = mse_loss(output2, target2)

opt = self.optimizers()
mtl_backward(losses=[loss1, loss2], features=features)
mtl_backward(tensors=[loss1, loss2], features=features)
jac_to_grad(self.feature_extractor.parameters(), UPGrad())
opt.step()
opt.zero_grad()
2 changes: 1 addition & 1 deletion docs/source/examples/monitoring.rst
@@ -63,7 +63,7 @@ they have a negative inner product).
loss1 = loss_fn(output1, target1)
loss2 = loss_fn(output2, target2)

mtl_backward(losses=[loss1, loss2], features=features)
mtl_backward(tensors=[loss1, loss2], features=features)
jac_to_grad(shared_module.parameters(), aggregator)
optimizer.step()
optimizer.zero_grad()
2 changes: 1 addition & 1 deletion docs/source/examples/mtl.rst
@@ -52,7 +52,7 @@ vectors of dimension 10, and their corresponding scalar labels for both tasks.
loss1 = loss_fn(output1, target1)
loss2 = loss_fn(output2, target2)

mtl_backward(losses=[loss1, loss2], features=features)
mtl_backward(tensors=[loss1, loss2], features=features)
jac_to_grad(shared_module.parameters(), aggregator)
optimizer.step()
optimizer.zero_grad()
6 changes: 3 additions & 3 deletions src/torchjd/autojac/_backward.py
@@ -6,8 +6,8 @@
from ._utils import (
as_checked_ordered_set,
check_consistent_first_dimension,
check_matching_jac_shapes,
check_matching_length,
check_matching_shapes,
check_optional_positive_chunk_size,
get_leaf_tensors,
)
@@ -109,7 +109,7 @@ def backward(
tensors_ = as_checked_ordered_set(tensors, "tensors")

if len(tensors_) == 0:
raise ValueError("`tensors` cannot be empty")
raise ValueError("`tensors` cannot be empty.")

if inputs is None:
inputs_ = get_leaf_tensors(tensors=tensors_, excluded=set())
@@ -140,7 +140,7 @@ def _create_jac_tensors_dict(
return (diag << init)({})
jac_tensors = [opt_jac_tensors] if isinstance(opt_jac_tensors, Tensor) else opt_jac_tensors
check_matching_length(jac_tensors, tensors, "jac_tensors", "tensors")
check_matching_shapes(jac_tensors, tensors, "jac_tensors", "tensors")
check_matching_jac_shapes(jac_tensors, tensors, "jac_tensors", "tensors")
check_consistent_first_dimension(jac_tensors, "jac_tensors")
return dict(zip(tensors, jac_tensors, strict=True))

6 changes: 3 additions & 3 deletions src/torchjd/autojac/_jac.py
@@ -10,8 +10,8 @@
from torchjd.autojac._utils import (
as_checked_ordered_set,
check_consistent_first_dimension,
check_matching_jac_shapes,
check_matching_length,
check_matching_shapes,
check_optional_positive_chunk_size,
get_leaf_tensors,
)
@@ -152,7 +152,7 @@ def jac(

outputs_ = as_checked_ordered_set(outputs, "outputs")
if len(outputs_) == 0:
raise ValueError("`outputs` cannot be empty")
raise ValueError("`outputs` cannot be empty.")

if inputs is None:
inputs_ = get_leaf_tensors(tensors=outputs_, excluded=set())
@@ -186,7 +186,7 @@ def _create_jac_outputs_dict(
return (diag << init)({})
jac_outputs = [opt_jac_outputs] if isinstance(opt_jac_outputs, Tensor) else opt_jac_outputs
check_matching_length(jac_outputs, outputs, "jac_outputs", "outputs")
check_matching_shapes(jac_outputs, outputs, "jac_outputs", "outputs")
check_matching_jac_shapes(jac_outputs, outputs, "jac_outputs", "outputs")
check_consistent_first_dimension(jac_outputs, "jac_outputs")
return dict(zip(outputs, jac_outputs, strict=True))

110 changes: 71 additions & 39 deletions src/torchjd/autojac/_mtl_backward.py
@@ -13,12 +13,19 @@
Stack,
Transform,
)
from ._utils import as_checked_ordered_set, check_optional_positive_chunk_size, get_leaf_tensors
from ._utils import (
as_checked_ordered_set,
check_matching_grad_shapes,
check_matching_length,
check_optional_positive_chunk_size,
get_leaf_tensors,
)


def mtl_backward(
losses: Sequence[Tensor],
tensors: Sequence[Tensor],
features: Sequence[Tensor] | Tensor,
grad_tensors: Sequence[Tensor] | None = None,
tasks_params: Sequence[Iterable[Tensor]] | None = None,
shared_params: Iterable[Tensor] | None = None,
retain_graph: bool = False,
@@ -28,17 +35,28 @@ def mtl_backward(
In the context of Multi-Task Learning (MTL), we often have a shared feature extractor followed
by several task-specific heads. A loss can then be computed for each task.

This function computes the gradient of each task-specific loss with respect to its task-specific
parameters and accumulates it in their ``.grad`` fields. Then, it computes the Jacobian of all
losses with respect to the shared parameters and accumulates it in their ``.jac`` fields.
This function computes the gradient of each task-specific tensor with respect to its
task-specific parameters and accumulates it in their ``.grad`` fields. It also computes the
Jacobian of all tensors with respect to the shared parameters and accumulates it in their
``.jac`` fields. These Jacobians have one row per task.

If the ``tensors`` are non-scalar, ``mtl_backward`` requires some initial gradients in
``grad_tensors``. This makes it possible to compose ``mtl_backward`` with another function that computes
the gradients with respect to the tensors (chain rule).

:param losses: The task losses. The Jacobians will have one row per loss.
:param tensors: The task-specific tensors. If these are scalar (e.g. the loss of each task), no
``grad_tensors`` are needed. If they are non-scalar, providing ``grad_tensors`` is necessary.
:param features: The last shared representation used for all tasks, as given by the feature
extractor. Should be non-empty.
:param grad_tensors: The initial gradients to backpropagate, analogous to the ``grad_tensors``
parameter of ``torch.autograd.backward``. If any of the ``tensors`` is non-scalar,
``grad_tensors`` must be provided, with the same length and shapes as ``tensors``.
Otherwise, this parameter is not needed and defaults to scalar gradients equal to 1.
:param tasks_params: The parameters of each task-specific head. Their ``requires_grad`` flags
must be set to ``True``. If not provided, the parameters considered for each task will
default to the leaf tensors that are in the computation graph of its loss, but that were not
used to compute the ``features``.
default to the leaf tensors that are in the computation graph of its tensor, but that were
not used to compute the ``features``.
:param shared_params: The parameters of the shared feature extractor. Their ``requires_grad``
flags must be set to ``True``. If not provided, defaults to the leaf tensors that are in the
computation graph of the ``features``.
@@ -73,43 +91,58 @@

check_optional_positive_chunk_size(parallel_chunk_size)

losses_ = as_checked_ordered_set(losses, "losses")
tensors_ = as_checked_ordered_set(tensors, "tensors")
features_ = as_checked_ordered_set(features, "features")

if shared_params is None:
shared_params_ = get_leaf_tensors(tensors=features_, excluded=[])
else:
shared_params_ = OrderedSet(shared_params)
if tasks_params is None:
tasks_params_ = [get_leaf_tensors(tensors=[loss], excluded=features_) for loss in losses_]
tasks_params_ = [get_leaf_tensors(tensors=[t], excluded=features_) for t in tensors_]
else:
tasks_params_ = [OrderedSet(task_params) for task_params in tasks_params]

if len(features_) == 0:
raise ValueError("`features` cannot be empty.")

_check_no_overlap(shared_params_, tasks_params_)
_check_losses_are_scalar(losses_)
if len(tensors_) == 0:
raise ValueError("`tensors` cannot be empty.")
if len(tensors_) != len(tasks_params_):
raise ValueError("`tensors` and `tasks_params` should have the same size.")

if len(losses_) == 0:
raise ValueError("`losses` cannot be empty")
if len(losses_) != len(tasks_params_):
raise ValueError("`losses` and `tasks_params` should have the same size.")
_check_no_overlap(shared_params_, tasks_params_)

grad_tensors_dict = _create_grad_tensors_dict(tensors_, grad_tensors)
backward_transform = _create_transform(
losses=losses_,
tensors=tensors_,
features=features_,
tasks_params=tasks_params_,
shared_params=shared_params_,
retain_graph=retain_graph,
parallel_chunk_size=parallel_chunk_size,
)

backward_transform({})
backward_transform(grad_tensors_dict)


def _create_grad_tensors_dict(
tensors: OrderedSet[Tensor],
opt_grad_tensors: Sequence[Tensor] | None,
) -> dict[Tensor, Tensor]:
if opt_grad_tensors is None:
_check_tensors_are_scalar(tensors)
grad_tensors_dict = Init(tensors)({})
else:
check_matching_length(opt_grad_tensors, tensors, "grad_tensors", "tensors")
check_matching_grad_shapes(opt_grad_tensors, tensors, "grad_tensors", "tensors")
grad_tensors_dict = dict(zip(tensors, opt_grad_tensors, strict=True))

return grad_tensors_dict


def _create_transform(
losses: OrderedSet[Tensor],
tensors: OrderedSet[Tensor],
features: OrderedSet[Tensor],
tasks_params: list[OrderedSet[Tensor]],
shared_params: OrderedSet[Tensor],
@@ -123,23 +156,23 @@ def _create_transform(
"""

# Task-specific transforms. Each of them computes and accumulates the gradient of the task's
# loss w.r.t. the task's specific parameters, and computes and backpropagates the gradient of
# the losses w.r.t. the shared representations.
# tensor w.r.t. the task's specific parameters, and computes and backpropagates the gradient of
# the tensor w.r.t. the shared representations.
task_transforms = [
_create_task_transform(
features,
task_params,
OrderedSet([loss]),
OrderedSet([t]),
retain_graph,
)
for task_params, loss in zip(tasks_params, losses, strict=True)
for task_params, t in zip(tasks_params, tensors, strict=True)
]

# Transform that stacks the gradients of the losses w.r.t. the shared representations into a
# Transform that stacks the gradients of the tensors w.r.t. the shared representations into a
# Jacobian.
stack = Stack(task_transforms)

# Transform that computes the Jacobians of the losses w.r.t. the shared parameters.
# Transform that computes the Jacobians of the tensors w.r.t. the shared parameters.
jac = Jac(features, shared_params, parallel_chunk_size, retain_graph)

# Transform that accumulates the result in the .jac field of the shared parameters.
Expand All @@ -151,36 +184,35 @@ def _create_transform(
def _create_task_transform(
features: OrderedSet[Tensor],
task_params: OrderedSet[Tensor],
loss: OrderedSet[Tensor], # contains a single scalar loss
tensor: OrderedSet[Tensor], # contains a single tensor
retain_graph: bool,
) -> Transform:
# Tensors with respect to which we compute the gradients.
to_differentiate = task_params + features

# Transform that initializes the gradient output to 1.
init = Init(loss)

# Transform that computes the gradients of the loss w.r.t. the task-specific parameters and
# Transform that computes the gradients of the tensor w.r.t. the task-specific parameters and
# the features.
grad = Grad(loss, to_differentiate, retain_graph)
grad = Grad(tensor, to_differentiate, retain_graph)

# Transform that accumulates the gradients w.r.t. the task-specific parameters into their
# .grad fields.
accumulate = AccumulateGrad() << Select(task_params)

# Transform that backpropagates the gradients of the losses w.r.t. the features.
# Transform that backpropagates the gradients of the tensor w.r.t. the features.
backpropagate = Select(features)

# Transform that accumulates the gradient of the losses w.r.t. the task-specific parameters into
# their .grad fields and backpropagates the gradient of the losses w.r.t. to the features.
backward_task = (backpropagate | accumulate) << grad << init
# Transform that accumulates the gradient of the tensor w.r.t. the task-specific parameters into
# their .grad fields and backpropagates the gradient of the tensor w.r.t. the features.
backward_task = (backpropagate | accumulate) << grad << Select(tensor)
return backward_task


def _check_losses_are_scalar(losses: Iterable[Tensor]) -> None:
for loss in losses:
if loss.ndim > 0:
raise ValueError("`losses` should contain only scalars.")
def _check_tensors_are_scalar(tensors: Iterable[Tensor]) -> None:
for t in tensors:
if t.ndim > 0:
raise ValueError(
"When `tensors` are non-scalar, the `grad_tensors` parameter must be provided."
)


def _check_no_overlap(
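To make the new docstring concrete, here is a hedged sketch of `mtl_backward` with non-scalar per-sample losses and matching `grad_tensors`. The import paths, the `UPGrad` aggregator, and the `jac_to_grad` call are taken from the changelog and documentation examples in this PR; treat them as assumptions rather than an official example.

```python
import torch
from torch.nn import Linear, MSELoss, ReLU, Sequential

from torchjd.aggregation import UPGrad                 # import path assumed
from torchjd.autojac import jac_to_grad, mtl_backward  # import path assumed

shared_module = Sequential(Linear(10, 5), ReLU())
task1_head, task2_head = Linear(5, 1), Linear(5, 1)
loss_fn = MSELoss(reduction="none")  # per-sample losses, i.e. non-scalar tensors

inputs = torch.randn(8, 10)
target1, target2 = torch.randn(8, 1), torch.randn(8, 1)

features = shared_module(inputs)
loss1 = loss_fn(task1_head(features), target1)  # shape (8, 1), not a scalar
loss2 = loss_fn(task2_head(features), target2)

# With non-scalar tensors, grad_tensors must match their length and shapes.
# Uniform weights of 1/n reproduce the usual mean reduction.
w1 = torch.full_like(loss1, 1 / loss1.numel())
w2 = torch.full_like(loss2, 1 / loss2.numel())
mtl_backward(tensors=[loss1, loss2], features=features, grad_tensors=[w1, w2])

# Aggregate the per-task Jacobians stored on the shared parameters into .grad fields,
# following the migration example shown in the changelog.
jac_to_grad(shared_module.parameters(), UPGrad())
```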
26 changes: 25 additions & 1 deletion src/torchjd/autojac/_utils.py
@@ -53,7 +53,7 @@ def check_matching_length(
)


def check_matching_shapes(
def check_matching_jac_shapes(
jacobians: Iterable[Tensor],
tensors: Iterable[Tensor],
jacobian_variable_name: str,
@@ -78,6 +78,30 @@
)


def check_matching_grad_shapes(
gradients: Iterable[Tensor],
tensors: Iterable[Tensor],
gradient_variable_name: str,
tensor_variable_name: str,
) -> None:
"""
Checks that the shape of each gradient matches the corresponding tensor shape.

:param gradients: Sequence of gradient tensors to validate.
:param tensors: Sequence of tensors whose shapes should match.
:param gradient_variable_name: Name of the gradient variable for error messages.
:param tensor_variable_name: Name of the tensor variable for error messages.
"""
for i, (gradient, tensor) in enumerate(zip(gradients, tensors, strict=True)):
if gradient.shape != tensor.shape:
raise ValueError(
f"Shape mismatch: `{gradient_variable_name}[{i}]` has shape {tuple(gradient.shape)} "
f"but `{tensor_variable_name}[{i}]` has shape {tuple(tensor.shape)}. "
f"The shape of `{gradient_variable_name}[{i}]` should match the shape of "
f"`{tensor_variable_name}[{i}]`.",
)


def check_consistent_first_dimension(
jacobians: Sequence[Tensor],
variable_name: str,
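A small sketch of the error this new helper is meant to raise. The helper lives in the private module `torchjd.autojac._utils` (path taken from this diff), so importing it directly is shown for illustration only.

```python
import torch

from torchjd.autojac._utils import check_matching_grad_shapes  # private module, path from this diff

tensors = [torch.zeros(3, 2)]
grad_tensors = [torch.zeros(3)]  # deliberately mismatched shape

try:
    check_matching_grad_shapes(grad_tensors, tensors, "grad_tensors", "tensors")
except ValueError as error:
    print(error)
    # Shape mismatch: `grad_tensors[0]` has shape (3,) but `tensors[0]` has shape (3, 2).
    # The shape of `grad_tensors[0]` should match the shape of `tensors[0]`.
```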
8 changes: 4 additions & 4 deletions tests/doc/test_rst.py
@@ -44,7 +44,7 @@ def test_amp():
loss2 = loss_fn(output2, target2)

scaled_losses = scaler.scale([loss1, loss2])
mtl_backward(losses=scaled_losses, features=features)
mtl_backward(tensors=scaled_losses, features=features)
jac_to_grad(shared_module.parameters(), aggregator)
scaler.step(optimizer)
scaler.update()
@@ -250,7 +250,7 @@ def training_step(self, batch, batch_idx) -> None:

opt = self.optimizers()

mtl_backward(losses=[loss1, loss2], features=features)
mtl_backward(tensors=[loss1, loss2], features=features)
jac_to_grad(self.feature_extractor.parameters(), UPGrad())
opt.step()
opt.zero_grad()
@@ -325,7 +325,7 @@ def print_gd_similarity(_, inputs: tuple[torch.Tensor, ...], aggregation: torch.
loss1 = loss_fn(output1, target1)
loss2 = loss_fn(output2, target2)

mtl_backward(losses=[loss1, loss2], features=features)
mtl_backward(tensors=[loss1, loss2], features=features)
jac_to_grad(shared_module.parameters(), aggregator)
optimizer.step()
optimizer.zero_grad()
@@ -363,7 +363,7 @@ def test_mtl():
loss1 = loss_fn(output1, target1)
loss2 = loss_fn(output2, target2)

mtl_backward(losses=[loss1, loss2], features=features)
mtl_backward(tensors=[loss1, loss2], features=features)
jac_to_grad(shared_module.parameters(), aggregator)
optimizer.step()
optimizer.zero_grad()