17 changes: 15 additions & 2 deletions CHANGELOG.md
@@ -53,8 +53,21 @@ changelog does not include internal changes that do not affect the user.
mtl_backward(losses, features)
jac_to_grad(shared_module.parameters(), aggregator)
```
- **BREAKING**: Renamed the `losses` parameter of `mtl_backward` to `tensors`.

- **BREAKING**: Made some parameters of the public interface of `torchjd` positional-only or
keyword-only:
- `backward`: The `tensors` parameter is now positional-only. Suggested change:
`backward(tensors=losses)` => `backward(losses)`. All other parameters are now keyword-only.
- `mtl_backward`: The `tensors` parameter (previously named `losses`) is now positional-only.
Suggested change: `mtl_backward(losses=losses, features=features)` =>
`mtl_backward(losses, features=features)`. The `features` parameter remains usable as positional
or keyword. All other parameters are now keyword-only.
- `Aggregator.__call__`: The `matrix` parameter is now positional-only. Suggested change:
`aggregator(matrix=matrix)` => `aggregator(matrix)`.
- `Weighting.__call__`: The `stat` parameter is now positional-only. Suggested change:
`weighting(stat=gramian)` => `weighting(gramian)`.
- `GeneralizedWeighting.__call__`: The `generalized_gramian` parameter is now positional-only.
Suggested change: `generalized_weighting(generalized_gramian=generalized_gramian)` =>
`generalized_weighting(generalized_gramian)`.
- Removed an unnecessary memory duplication. This should significantly improve the memory efficiency
of `autojac`.
- Removed an unnecessary internal cloning of gradient. This should slightly improve the memory
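The following is a minimal migration sketch for the breaking changes listed above. It is not part of the diff: the model, losses, and aggregator are placeholders, and the import locations are assumed to match the documentation examples in this PR.

```python
import torch
from torch import nn

# Assumed export locations; adjust to the installed version of torchjd.
from torchjd import jac_to_grad, mtl_backward
from torchjd.aggregation import UPGrad

shared_module = nn.Linear(10, 5)
head1, head2 = nn.Linear(5, 1), nn.Linear(5, 1)
aggregator = UPGrad()

x = torch.randn(8, 10)
features = shared_module(x)
loss1 = head1(features).square().mean()
loss2 = head2(features).square().mean()

# Before: mtl_backward(losses=[loss1, loss2], features=features)
# After: the first parameter (now `tensors`) is positional-only.
mtl_backward([loss1, loss2], features=features)
jac_to_grad(shared_module.parameters(), aggregator)

# Aggregators are likewise called with the matrix as a positional argument.
# Before: aggregator(matrix=matrix)
aggregation = aggregator(torch.randn(2, 5))
```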
2 changes: 1 addition & 1 deletion docs/source/examples/amp.rst
@@ -48,7 +48,7 @@ following example shows the resulting code for a multi-task learning use-case.
loss2 = loss_fn(output2, target2)
scaled_losses = scaler.scale([loss1, loss2])
mtl_backward(tensors=scaled_losses, features=features)
mtl_backward(scaled_losses, features=features)
jac_to_grad(shared_module.parameters(), aggregator)
scaler.step(optimizer)
scaler.update()
2 changes: 1 addition & 1 deletion docs/source/examples/lightning_integration.rst
@@ -43,7 +43,7 @@ The following code example demonstrates a basic multi-task learning setup using
loss2 = mse_loss(output2, target2)
opt = self.optimizers()
mtl_backward(tensors=[loss1, loss2], features=features)
mtl_backward([loss1, loss2], features=features)
jac_to_grad(self.feature_extractor.parameters(), UPGrad())
opt.step()
opt.zero_grad()
2 changes: 1 addition & 1 deletion docs/source/examples/monitoring.rst
@@ -63,7 +63,7 @@ they have a negative inner product).
loss1 = loss_fn(output1, target1)
loss2 = loss_fn(output2, target2)

mtl_backward(tensors=[loss1, loss2], features=features)
mtl_backward([loss1, loss2], features=features)
jac_to_grad(shared_module.parameters(), aggregator)
optimizer.step()
optimizer.zero_grad()
2 changes: 1 addition & 1 deletion docs/source/examples/mtl.rst
@@ -52,7 +52,7 @@ vectors of dimension 10, and their corresponding scalar labels for both tasks.
loss1 = loss_fn(output1, target1)
loss2 = loss_fn(output2, target2)
mtl_backward(tensors=[loss1, loss2], features=features)
mtl_backward([loss1, loss2], features=features)
jac_to_grad(shared_module.parameters(), aggregator)
optimizer.step()
optimizer.zero_grad()
6 changes: 3 additions & 3 deletions src/torchjd/aggregation/_aggregator_bases.py
@@ -25,10 +25,10 @@ def _check_is_matrix(matrix: Tensor) -> None:
)

@abstractmethod
def forward(self, matrix: Matrix) -> Tensor:
def forward(self, matrix: Matrix, /) -> Tensor:
"""Computes the aggregation from the input matrix."""

def __call__(self, matrix: Tensor) -> Tensor:
def __call__(self, matrix: Tensor, /) -> Tensor:
"""Computes the aggregation from the input matrix and applies all registered hooks."""
Aggregator._check_is_matrix(matrix)
return super().__call__(matrix)
@@ -62,7 +62,7 @@ def combine(matrix: Matrix, weights: Tensor) -> Tensor:
vector = weights @ matrix
return vector

def forward(self, matrix: Matrix) -> Tensor:
def forward(self, matrix: Matrix, /) -> Tensor:
weights = self.weighting(matrix)
vector = self.combine(matrix, weights)
return vector
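For custom aggregators, the base class now declares `matrix` as positional-only in both `forward` and `__call__`. A hypothetical subclass (not part of this PR), assuming `Aggregator` stays importable from `torchjd.aggregation` and that overriding `forward` remains the extension point:

```python
import torch
from torch import Tensor
from torchjd.aggregation import Aggregator  # assumed public import path


class RowMean(Aggregator):
    """Toy aggregator that averages the rows of the input matrix."""

    def forward(self, matrix: Tensor, /) -> Tensor:
        # `matrix` is declared positional-only, matching the base class signature.
        return matrix.mean(dim=0)


aggregator = RowMean()
vector = aggregator(torch.randn(3, 4))  # the matrix is passed positionally
```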
2 changes: 1 addition & 1 deletion src/torchjd/aggregation/_config.py
@@ -58,7 +58,7 @@ def __init__(self, pref_vector: Tensor | None = None):
# This prevents computing gradients that can be very wrong.
self.register_full_backward_pre_hook(raise_non_differentiable_error)

def forward(self, matrix: Matrix) -> Tensor:
def forward(self, matrix: Matrix, /) -> Tensor:
weights = self.weighting(matrix)
units = torch.nan_to_num((matrix / (matrix.norm(dim=1)).unsqueeze(1)), 0.0)
best_direction = torch.linalg.pinv(units) @ weights
2 changes: 1 addition & 1 deletion src/torchjd/aggregation/_flattening.py
@@ -24,7 +24,7 @@ def __init__(self, weighting: Weighting):
super().__init__()
self.weighting = weighting

def forward(self, generalized_gramian: PSDTensor) -> Tensor:
def forward(self, generalized_gramian: PSDTensor, /) -> Tensor:
k = generalized_gramian.ndim // 2
shape = generalized_gramian.shape[:k]
square_gramian = flatten(generalized_gramian)
2 changes: 1 addition & 1 deletion src/torchjd/aggregation/_graddrop.py
@@ -40,7 +40,7 @@ def __init__(self, f: Callable = _identity, leak: Tensor | None = None):
# This prevents computing gradients that can be very wrong.
self.register_full_backward_pre_hook(raise_non_differentiable_error)

def forward(self, matrix: Matrix) -> Tensor:
def forward(self, matrix: Matrix, /) -> Tensor:
self._check_matrix_has_enough_rows(matrix)

if matrix.shape[0] == 0 or matrix.shape[1] == 0:
2 changes: 1 addition & 1 deletion src/torchjd/aggregation/_trimmed_mean.py
@@ -24,7 +24,7 @@ def __init__(self, trim_number: int):
)
self.trim_number = trim_number

def forward(self, matrix: Tensor) -> Tensor:
def forward(self, matrix: Tensor, /) -> Tensor:
self._check_matrix_has_enough_rows(matrix)

n_rows = matrix.shape[0]
6 changes: 3 additions & 3 deletions src/torchjd/aggregation/_weighting_bases.py
@@ -27,7 +27,7 @@ def __init__(self):
def forward(self, stat: _T, /) -> Tensor:
"""Computes the vector of weights from the input stat."""

def __call__(self, stat: Tensor) -> Tensor:
def __call__(self, stat: Tensor, /) -> Tensor:
"""Computes the vector of weights from the input stat and applies all registered hooks."""

# The value of _T (e.g. PSDMatrix) is not public, so we need the user-facing type hint of
@@ -67,10 +67,10 @@ def __init__(self):
super().__init__()

@abstractmethod
def forward(self, generalized_gramian: PSDTensor) -> Tensor:
def forward(self, generalized_gramian: PSDTensor, /) -> Tensor:
"""Computes the vector of weights from the input generalized Gramian."""

def __call__(self, generalized_gramian: Tensor) -> Tensor:
def __call__(self, generalized_gramian: Tensor, /) -> Tensor:
"""
Computes the tensor of weights from the input generalized Gramian and applies all registered
hooks.
2 changes: 1 addition & 1 deletion src/torchjd/autogram/_engine.py
@@ -235,7 +235,7 @@ def _check_module_is_compatible(self, module: nn.Module) -> None:
)

# Currently, the type PSDMatrix is hidden from users, so Tensor is correct.
def compute_gramian(self, output: Tensor) -> Tensor:
def compute_gramian(self, output: Tensor, /) -> Tensor:
r"""
Computes the Gramian of the Jacobian of ``output`` with respect to the direct parameters of
all ``modules``.
2 changes: 2 additions & 0 deletions src/torchjd/autojac/_backward.py
@@ -15,6 +15,8 @@

def backward(
tensors: Sequence[Tensor] | Tensor,
/,
*,
jac_tensors: Sequence[Tensor] | Tensor | None = None,
inputs: Iterable[Tensor] | None = None,
retain_graph: bool = False,
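In practice, only the first argument of `backward` can still be passed positionally; per the CHANGELOG entry above, all other parameters are keyword-only. A small sketch, assuming `backward` remains exported from the top-level `torchjd` package as in the docs:

```python
import torch
from torchjd import backward  # assumed top-level export

a = torch.randn(3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
losses = [(a * b).sum(), a.norm()]

# Before: backward(tensors=losses, inputs=[a, b])
# After: `tensors` is positional-only; `inputs`, `jac_tensors`, `retain_graph`, ...
# are keyword-only.
backward(losses, inputs=[a, b])
```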
1 change: 1 addition & 0 deletions src/torchjd/autojac/_jac.py
@@ -20,6 +20,7 @@
def jac(
outputs: Sequence[Tensor] | Tensor,
inputs: Iterable[Tensor] | None = None,
*,
jac_outputs: Sequence[Tensor] | Tensor | None = None,
retain_graph: bool = False,
parallel_chunk_size: int | None = None,
2 changes: 2 additions & 0 deletions src/torchjd/autojac/_jac_to_grad.py
@@ -11,7 +11,9 @@

def jac_to_grad(
tensors: Iterable[Tensor],
/,
aggregator: Aggregator,
*,
retain_jac: bool = False,
) -> None:
r"""
2 changes: 2 additions & 0 deletions src/torchjd/autojac/_mtl_backward.py
@@ -24,7 +24,9 @@

def mtl_backward(
tensors: Sequence[Tensor],
/,
features: Sequence[Tensor] | Tensor,
*,
grad_tensors: Sequence[Tensor] | None = None,
tasks_params: Sequence[Iterable[Tensor]] | None = None,
shared_params: Iterable[Tensor] | None = None,
8 changes: 4 additions & 4 deletions tests/doc/test_rst.py
@@ -44,7 +44,7 @@ def test_amp():
loss2 = loss_fn(output2, target2)

scaled_losses = scaler.scale([loss1, loss2])
mtl_backward(tensors=scaled_losses, features=features)
mtl_backward(scaled_losses, features=features)
jac_to_grad(shared_module.parameters(), aggregator)
scaler.step(optimizer)
scaler.update()
@@ -250,7 +250,7 @@ def training_step(self, batch, batch_idx) -> None:

opt = self.optimizers()

mtl_backward(tensors=[loss1, loss2], features=features)
mtl_backward([loss1, loss2], features=features)
jac_to_grad(self.feature_extractor.parameters(), UPGrad())
opt.step()
opt.zero_grad()
@@ -325,7 +325,7 @@ def print_gd_similarity(_, inputs: tuple[torch.Tensor, ...], aggregation: torch.
loss1 = loss_fn(output1, target1)
loss2 = loss_fn(output2, target2)

mtl_backward(tensors=[loss1, loss2], features=features)
mtl_backward([loss1, loss2], features=features)
jac_to_grad(shared_module.parameters(), aggregator)
optimizer.step()
optimizer.zero_grad()
@@ -363,7 +363,7 @@ def test_mtl():
loss1 = loss_fn(output1, target1)
loss2 = loss_fn(output2, target2)

mtl_backward(tensors=[loss1, loss2], features=features)
mtl_backward([loss1, loss2], features=features)
jac_to_grad(shared_module.parameters(), aggregator)
optimizer.step()
optimizer.zero_grad()
4 changes: 2 additions & 2 deletions tests/unit/autojac/test_backward.py
@@ -310,7 +310,7 @@ def test_input_retaining_grad_fails():

# backward itself doesn't raise the error, but it fills b.grad with a BatchedTensor
# (and it also fills b.jac with the correct Jacobian)
backward(tensors=y, inputs=[b])
backward(y, inputs=[b])

with raises(RuntimeError):
# Using such a BatchedTensor should result in an error
@@ -329,7 +329,7 @@ def test_non_input_retaining_grad_fails():
y = 3 * b

# backward itself doesn't raise the error, but it fills b.grad with a BatchedTensor
backward(tensors=y, inputs=[a])
backward(y, inputs=[a])

with raises(RuntimeError):
# Using such a BatchedTensor should result in an error
4 changes: 2 additions & 2 deletions tests/unit/autojac/test_jac.py
@@ -309,7 +309,7 @@ def test_input_retaining_grad_fails():

# jac itself doesn't raise the error, but it fills b.grad with a BatchedTensor (and it also
# returns the correct Jacobian)
jac(outputs=y, inputs=[b])
jac(y, inputs=[b])

with raises(RuntimeError):
# Using such a BatchedTensor should result in an error
@@ -328,7 +328,7 @@ def test_non_input_retaining_grad_fails():
y = 3 * b

# jac itself doesn't raise the error, but it fills b.grad with a BatchedTensor
jac(outputs=y, inputs=[a])
jac(y, inputs=[a])

with raises(RuntimeError):
# Using such a BatchedTensor should result in an error