Skip to content

Commit db59ec6

Browse files
refactor!: Made some parameters keyword-only or positional-only (#578)
* (Private interface) make `GeneralizedWeighting.forward` take a positional-only argument. * (Private interface) make `Aggregator.forward` take a single positional-only argument. * Change many parameters of the public interface to be either positional-only or keyword-only. Refer to the changelog change for the exact list. * Add changelog entry. --------- Co-authored-by: Valérian Rey <valerian.rey@gmail.com>
1 parent f872551 commit db59ec6

20 files changed

Lines changed: 75 additions & 55 deletions

CHANGELOG.md

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,21 @@ changelog does not include internal changes that do not affect the user.
5353
mtl_backward(losses, features)
5454
jac_to_grad(shared_module.parameters(), aggregator)
5555
```
56-
- **BREAKING**: Renamed the `losses` parameter of `mtl_backward` to `tensors`.
57-
56+
- **BREAKING**: Made some parameters of the public interface of `torchjd` positional-only or
57+
keyword-only:
58+
- `backward`: The `tensors` parameter is now positional-only. Suggested change:
59+
`backward(tensors=losses)` => `backward(losses)`. All other parameters are now keyword-only.
60+
- `mtl_backward`: The `tensors` parameter (previously named `losses`) is now positional-only.
61+
Suggested change: `mtl_backward(losses=losses, features=features)` =>
62+
`mtl_backward(losses, features=features)`. The `features` parameter remains usable as positional
63+
or keyword. All other parameters are now keyword-only.
64+
- `Aggregator.__call__`: The `matrix` parameter is now positional-only. Suggested change:
65+
`aggregator(matrix=matrix)` => `aggregator(matrix)`.
66+
- `Weighting.__call__`: The `stat` parameter is now positional-only. Suggested change:
67+
`weighting(stat=gramian)` => `weighting(gramian)`.
68+
- `GeneralizedWeighting.__call__`: The `generalized_gramian` parameter is now positional-only.
69+
Suggested change: `generalized_weighting(generalized_gramian=generalized_gramian)` =>
70+
`generalized_weighting(generalized_gramian)`.
5871
- Removed an unnecessary memory duplication. This should significantly improve the memory efficiency
5972
of `autojac`.
6073
- Removed an unnecessary internal cloning of gradient. This should slightly improve the memory

docs/source/examples/amp.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ following example shows the resulting code for a multi-task learning use-case.
4848
loss2 = loss_fn(output2, target2)
4949
5050
scaled_losses = scaler.scale([loss1, loss2])
51-
mtl_backward(tensors=scaled_losses, features=features)
51+
mtl_backward(scaled_losses, features=features)
5252
jac_to_grad(shared_module.parameters(), aggregator)
5353
scaler.step(optimizer)
5454
scaler.update()

docs/source/examples/lightning_integration.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ The following code example demonstrates a basic multi-task learning setup using
4343
loss2 = mse_loss(output2, target2)
4444
4545
opt = self.optimizers()
46-
mtl_backward(tensors=[loss1, loss2], features=features)
46+
mtl_backward([loss1, loss2], features=features)
4747
jac_to_grad(self.feature_extractor.parameters(), UPGrad())
4848
opt.step()
4949
opt.zero_grad()

docs/source/examples/monitoring.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ they have a negative inner product).
6363
loss1 = loss_fn(output1, target1)
6464
loss2 = loss_fn(output2, target2)
6565
66-
mtl_backward(tensors=[loss1, loss2], features=features)
66+
mtl_backward([loss1, loss2], features=features)
6767
jac_to_grad(shared_module.parameters(), aggregator)
6868
optimizer.step()
6969
optimizer.zero_grad()

docs/source/examples/mtl.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ vectors of dimension 10, and their corresponding scalar labels for both tasks.
5252
loss1 = loss_fn(output1, target1)
5353
loss2 = loss_fn(output2, target2)
5454
55-
mtl_backward(tensors=[loss1, loss2], features=features)
55+
mtl_backward([loss1, loss2], features=features)
5656
jac_to_grad(shared_module.parameters(), aggregator)
5757
optimizer.step()
5858
optimizer.zero_grad()

src/torchjd/aggregation/_aggregator_bases.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,10 @@ def _check_is_matrix(matrix: Tensor) -> None:
2525
)
2626

2727
@abstractmethod
28-
def forward(self, matrix: Matrix) -> Tensor:
28+
def forward(self, matrix: Matrix, /) -> Tensor:
2929
"""Computes the aggregation from the input matrix."""
3030

31-
def __call__(self, matrix: Tensor) -> Tensor:
31+
def __call__(self, matrix: Tensor, /) -> Tensor:
3232
"""Computes the aggregation from the input matrix and applies all registered hooks."""
3333
Aggregator._check_is_matrix(matrix)
3434
return super().__call__(matrix)
@@ -62,7 +62,7 @@ def combine(matrix: Matrix, weights: Tensor) -> Tensor:
6262
vector = weights @ matrix
6363
return vector
6464

65-
def forward(self, matrix: Matrix) -> Tensor:
65+
def forward(self, matrix: Matrix, /) -> Tensor:
6666
weights = self.weighting(matrix)
6767
vector = self.combine(matrix, weights)
6868
return vector

src/torchjd/aggregation/_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def __init__(self, pref_vector: Tensor | None = None):
5858
# This prevents computing gradients that can be very wrong.
5959
self.register_full_backward_pre_hook(raise_non_differentiable_error)
6060

61-
def forward(self, matrix: Matrix) -> Tensor:
61+
def forward(self, matrix: Matrix, /) -> Tensor:
6262
weights = self.weighting(matrix)
6363
units = torch.nan_to_num((matrix / (matrix.norm(dim=1)).unsqueeze(1)), 0.0)
6464
best_direction = torch.linalg.pinv(units) @ weights

src/torchjd/aggregation/_flattening.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def __init__(self, weighting: Weighting):
2424
super().__init__()
2525
self.weighting = weighting
2626

27-
def forward(self, generalized_gramian: PSDTensor) -> Tensor:
27+
def forward(self, generalized_gramian: PSDTensor, /) -> Tensor:
2828
k = generalized_gramian.ndim // 2
2929
shape = generalized_gramian.shape[:k]
3030
square_gramian = flatten(generalized_gramian)

src/torchjd/aggregation/_graddrop.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def __init__(self, f: Callable = _identity, leak: Tensor | None = None):
4040
# This prevents computing gradients that can be very wrong.
4141
self.register_full_backward_pre_hook(raise_non_differentiable_error)
4242

43-
def forward(self, matrix: Matrix) -> Tensor:
43+
def forward(self, matrix: Matrix, /) -> Tensor:
4444
self._check_matrix_has_enough_rows(matrix)
4545

4646
if matrix.shape[0] == 0 or matrix.shape[1] == 0:

src/torchjd/aggregation/_trimmed_mean.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def __init__(self, trim_number: int):
2424
)
2525
self.trim_number = trim_number
2626

27-
def forward(self, matrix: Tensor) -> Tensor:
27+
def forward(self, matrix: Tensor, /) -> Tensor:
2828
self._check_matrix_has_enough_rows(matrix)
2929

3030
n_rows = matrix.shape[0]

0 commit comments

Comments
 (0)