From 934e8f380a31e51e4c0c20eea68d2f56ade63838 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Wed, 18 Feb 2026 15:19:51 +0100 Subject: [PATCH 01/14] Make `output` parameter of `Engine.compute_gramian` positional-only. --- src/torchjd/autogram/_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index f8248c0c..610b3753 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -235,7 +235,7 @@ def _check_module_is_compatible(self, module: nn.Module) -> None: ) # Currently, the type PSDMatrix is hidden from users, so Tensor is correct. - def compute_gramian(self, output: Tensor) -> Tensor: + def compute_gramian(self, output: Tensor, /) -> Tensor: r""" Computes the Gramian of the Jacobian of ``output`` with respect to the direct parameters of all ``modules``. From 95e691f19803b77151ac50693415433c0fdb380b Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Wed, 18 Feb 2026 15:24:26 +0100 Subject: [PATCH 02/14] make `__call__` in `Aggregator`, `Weighting` and `GeneralizedWeighting` take a positional-only parameter. --- src/torchjd/aggregation/_aggregator_bases.py | 2 +- src/torchjd/aggregation/_weighting_bases.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/torchjd/aggregation/_aggregator_bases.py b/src/torchjd/aggregation/_aggregator_bases.py index 5a656f67..82418e9a 100644 --- a/src/torchjd/aggregation/_aggregator_bases.py +++ b/src/torchjd/aggregation/_aggregator_bases.py @@ -28,7 +28,7 @@ def _check_is_matrix(matrix: Tensor) -> None: def forward(self, matrix: Matrix) -> Tensor: """Computes the aggregation from the input matrix.""" - def __call__(self, matrix: Tensor) -> Tensor: + def __call__(self, matrix: Tensor, /) -> Tensor: """Computes the aggregation from the input matrix and applies all registered hooks.""" Aggregator._check_is_matrix(matrix) return super().__call__(matrix) diff --git a/src/torchjd/aggregation/_weighting_bases.py b/src/torchjd/aggregation/_weighting_bases.py index 3b037891..1e87707e 100644 --- a/src/torchjd/aggregation/_weighting_bases.py +++ b/src/torchjd/aggregation/_weighting_bases.py @@ -27,7 +27,7 @@ def __init__(self): def forward(self, stat: _T, /) -> Tensor: """Computes the vector of weights from the input stat.""" - def __call__(self, stat: Tensor) -> Tensor: + def __call__(self, stat: Tensor, /) -> Tensor: """Computes the vector of weights from the input stat and applies all registered hooks.""" # The value of _T (e.g. PSDMatrix) is not public, so we need the user-facing type hint of @@ -70,7 +70,7 @@ def __init__(self): def forward(self, generalized_gramian: PSDTensor) -> Tensor: """Computes the vector of weights from the input generalized Gramian.""" - def __call__(self, generalized_gramian: Tensor) -> Tensor: + def __call__(self, generalized_gramian: Tensor, /) -> Tensor: """ Computes the tensor of weights from the input generalized Gramian and applies all registered hooks. From a44c7ccfd566292cbee9f5158ae9542d6db92da0 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Wed, 18 Feb 2026 15:26:54 +0100 Subject: [PATCH 03/14] (Private interface) make `GeneralizedWeighting.forward` take a positional-only argument. 
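As a quick illustration of what the `/` marker in the two patches above means for callers, here is a minimal, self-contained sketch (plain Python, no torchjd import; `StubAggregator` is a hypothetical stand-in, not the real class): everything to the left of `/` can only be passed positionally.

```
# Minimal sketch of the `/` (positional-only) marker introduced in patches 01 and 02.
# StubAggregator is a hypothetical stand-in, not torchjd's Aggregator.
class StubAggregator:
    def __call__(self, matrix, /):
        # `matrix` can only be bound positionally.
        return sum(matrix)

aggregator = StubAggregator()
print(aggregator([1.0, 2.0, 3.0]))  # OK: positional call, prints 6.0
try:
    aggregator(matrix=[1.0, 2.0, 3.0])  # Rejected: positional-only parameter passed as keyword
except TypeError as error:
    print(error)
```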
--- src/torchjd/aggregation/_flattening.py | 2 +- src/torchjd/aggregation/_weighting_bases.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/torchjd/aggregation/_flattening.py b/src/torchjd/aggregation/_flattening.py index 8db6027b..208db3ec 100644 --- a/src/torchjd/aggregation/_flattening.py +++ b/src/torchjd/aggregation/_flattening.py @@ -24,7 +24,7 @@ def __init__(self, weighting: Weighting): super().__init__() self.weighting = weighting - def forward(self, generalized_gramian: PSDTensor) -> Tensor: + def forward(self, generalized_gramian: PSDTensor, /) -> Tensor: k = generalized_gramian.ndim // 2 shape = generalized_gramian.shape[:k] square_gramian = flatten(generalized_gramian) diff --git a/src/torchjd/aggregation/_weighting_bases.py b/src/torchjd/aggregation/_weighting_bases.py index 1e87707e..dd7c53ee 100644 --- a/src/torchjd/aggregation/_weighting_bases.py +++ b/src/torchjd/aggregation/_weighting_bases.py @@ -67,7 +67,7 @@ def __init__(self): super().__init__() @abstractmethod - def forward(self, generalized_gramian: PSDTensor) -> Tensor: + def forward(self, generalized_gramian: PSDTensor, /) -> Tensor: """Computes the vector of weights from the input generalized Gramian.""" def __call__(self, generalized_gramian: Tensor, /) -> Tensor: From c2c4c52786d57659f11c034bbcb97fb67c73b967 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Wed, 18 Feb 2026 15:31:06 +0100 Subject: [PATCH 04/14] (Private interface) make `Aggregator.forward` take a single positional only argument. --- src/torchjd/aggregation/_aggregator_bases.py | 4 ++-- src/torchjd/aggregation/_config.py | 2 +- src/torchjd/aggregation/_graddrop.py | 2 +- src/torchjd/aggregation/_trimmed_mean.py | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/torchjd/aggregation/_aggregator_bases.py b/src/torchjd/aggregation/_aggregator_bases.py index 82418e9a..78168eae 100644 --- a/src/torchjd/aggregation/_aggregator_bases.py +++ b/src/torchjd/aggregation/_aggregator_bases.py @@ -25,7 +25,7 @@ def _check_is_matrix(matrix: Tensor) -> None: ) @abstractmethod - def forward(self, matrix: Matrix) -> Tensor: + def forward(self, matrix: Matrix, /) -> Tensor: """Computes the aggregation from the input matrix.""" def __call__(self, matrix: Tensor, /) -> Tensor: @@ -62,7 +62,7 @@ def combine(matrix: Matrix, weights: Tensor) -> Tensor: vector = weights @ matrix return vector - def forward(self, matrix: Matrix) -> Tensor: + def forward(self, matrix: Matrix, /) -> Tensor: weights = self.weighting(matrix) vector = self.combine(matrix, weights) return vector diff --git a/src/torchjd/aggregation/_config.py b/src/torchjd/aggregation/_config.py index 980b93b4..24866c41 100644 --- a/src/torchjd/aggregation/_config.py +++ b/src/torchjd/aggregation/_config.py @@ -58,7 +58,7 @@ def __init__(self, pref_vector: Tensor | None = None): # This prevents computing gradients that can be very wrong. 
self.register_full_backward_pre_hook(raise_non_differentiable_error) - def forward(self, matrix: Matrix) -> Tensor: + def forward(self, matrix: Matrix, /) -> Tensor: weights = self.weighting(matrix) units = torch.nan_to_num((matrix / (matrix.norm(dim=1)).unsqueeze(1)), 0.0) best_direction = torch.linalg.pinv(units) @ weights diff --git a/src/torchjd/aggregation/_graddrop.py b/src/torchjd/aggregation/_graddrop.py index f8a39426..afa16451 100644 --- a/src/torchjd/aggregation/_graddrop.py +++ b/src/torchjd/aggregation/_graddrop.py @@ -40,7 +40,7 @@ def __init__(self, f: Callable = _identity, leak: Tensor | None = None): # This prevents computing gradients that can be very wrong. self.register_full_backward_pre_hook(raise_non_differentiable_error) - def forward(self, matrix: Matrix) -> Tensor: + def forward(self, matrix: Matrix, /) -> Tensor: self._check_matrix_has_enough_rows(matrix) if matrix.shape[0] == 0 or matrix.shape[1] == 0: diff --git a/src/torchjd/aggregation/_trimmed_mean.py b/src/torchjd/aggregation/_trimmed_mean.py index 07ed055a..77d33c41 100644 --- a/src/torchjd/aggregation/_trimmed_mean.py +++ b/src/torchjd/aggregation/_trimmed_mean.py @@ -24,7 +24,7 @@ def __init__(self, trim_number: int): ) self.trim_number = trim_number - def forward(self, matrix: Tensor) -> Tensor: + def forward(self, matrix: Tensor, /) -> Tensor: self._check_matrix_has_enough_rows(matrix) n_rows = matrix.shape[0] From 3b158b92d66854351a01aa12f3e6dd5a7046ef1d Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Wed, 18 Feb 2026 15:37:33 +0100 Subject: [PATCH 05/14] Proposal of a minimal separation for backward; it could be `tensors` positional-only and the rest keyword-only, in my opinion. --- src/torchjd/autojac/_backward.py | 2 ++ tests/unit/autojac/test_backward.py | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/torchjd/autojac/_backward.py b/src/torchjd/autojac/_backward.py index 206e5b92..9612d186 100644 --- a/src/torchjd/autojac/_backward.py +++ b/src/torchjd/autojac/_backward.py @@ -15,8 +15,10 @@ def backward( tensors: Sequence[Tensor] | Tensor, + /, jac_tensors: Sequence[Tensor] | Tensor | None = None, inputs: Iterable[Tensor] | None = None, + *, retain_graph: bool = False, parallel_chunk_size: int | None = None, ) -> None: diff --git a/tests/unit/autojac/test_backward.py b/tests/unit/autojac/test_backward.py index 4d75d382..a0398c42 100644 --- a/tests/unit/autojac/test_backward.py +++ b/tests/unit/autojac/test_backward.py @@ -310,7 +310,7 @@ def test_input_retaining_grad_fails(): # backward itself doesn't raise the error, but it fills b.grad with a BatchedTensor # (and it also fills b.jac with the correct Jacobian) - backward(tensors=y, inputs=[b]) + backward(y, inputs=[b]) with raises(RuntimeError): # Using such a BatchedTensor should result in an error @@ -329,7 +329,7 @@ def test_non_input_retaining_grad_fails(): y = 3 * b # backward itself doesn't raise the error, but it fills b.grad with a BatchedTensor - backward(tensors=y, inputs=[a]) + backward(y, inputs=[a]) with raises(RuntimeError): # Using such a BatchedTensor should result in an error From ec0aec1b30fac23411a87cbbbe086679b77d34e9 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Wed, 18 Feb 2026 15:40:33 +0100 Subject: [PATCH 06/14] Minimal separation for mtl_backward; it could be stricter (`tensors` and `features` positional-only and the rest keyword-only). Note that for `tensors` it now makes little sense to pass it by keyword; before, it was named `losses`, so the keyword form made more sense. 
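The diff of patch 05 above splits the `backward` signature into three groups with `/` and `*`. The following self-contained sketch mirrors that layout with a stub (`fake_backward` is not the real `torchjd.autojac.backward`); note that patch 09 later in this series moves `*` up, making the middle group keyword-only.

```
# Sketch of the three-way split that `/` and `*` create in the patch-05 `backward` signature.
# fake_backward is a stub with the same parameter layout, not the real torchjd function.
def fake_backward(tensors, /, jac_tensors=None, inputs=None, *, retain_graph=False, parallel_chunk_size=None):
    return tensors, jac_tensors, inputs, retain_graph, parallel_chunk_size

fake_backward("losses")                                   # `tensors` must be passed positionally
fake_backward("losses", inputs=["p"], retain_graph=True)  # middle group accepts keywords
fake_backward("losses", None, ["p"])                      # ... or positions (until patch 09 tightens this)
try:
    fake_backward(tensors="losses")  # Rejected: `tensors` is positional-only
except TypeError as error:
    print(error)
try:
    fake_backward("losses", None, ["p"], True)  # Rejected: `retain_graph` is keyword-only
except TypeError as error:
    print(error)
```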
--- docs/source/examples/amp.rst | 2 +- .../source/examples/lightning_integration.rst | 2 +- docs/source/examples/monitoring.rst | 2 +- docs/source/examples/mtl.rst | 2 +- src/torchjd/autojac/_mtl_backward.py | 2 + tests/doc/test_rst.py | 8 +-- tests/unit/autojac/test_mtl_backward.py | 60 +++++++++---------- 7 files changed, 40 insertions(+), 38 deletions(-) diff --git a/docs/source/examples/amp.rst b/docs/source/examples/amp.rst index f1cbe5fa..2de486c1 100644 --- a/docs/source/examples/amp.rst +++ b/docs/source/examples/amp.rst @@ -48,7 +48,7 @@ following example shows the resulting code for a multi-task learning use-case. loss2 = loss_fn(output2, target2) scaled_losses = scaler.scale([loss1, loss2]) - mtl_backward(tensors=scaled_losses, features=features) + mtl_backward(scaled_losses, features=features) jac_to_grad(shared_module.parameters(), aggregator) scaler.step(optimizer) scaler.update() diff --git a/docs/source/examples/lightning_integration.rst b/docs/source/examples/lightning_integration.rst index 9cc74904..c8416083 100644 --- a/docs/source/examples/lightning_integration.rst +++ b/docs/source/examples/lightning_integration.rst @@ -43,7 +43,7 @@ The following code example demonstrates a basic multi-task learning setup using loss2 = mse_loss(output2, target2) opt = self.optimizers() - mtl_backward(tensors=[loss1, loss2], features=features) + mtl_backward([loss1, loss2], features=features) jac_to_grad(self.feature_extractor.parameters(), UPGrad()) opt.step() opt.zero_grad() diff --git a/docs/source/examples/monitoring.rst b/docs/source/examples/monitoring.rst index fda6af2d..784eea4d 100644 --- a/docs/source/examples/monitoring.rst +++ b/docs/source/examples/monitoring.rst @@ -63,7 +63,7 @@ they have a negative inner product). loss1 = loss_fn(output1, target1) loss2 = loss_fn(output2, target2) - mtl_backward(tensors=[loss1, loss2], features=features) + mtl_backward([loss1, loss2], features=features) jac_to_grad(shared_module.parameters(), aggregator) optimizer.step() optimizer.zero_grad() diff --git a/docs/source/examples/mtl.rst b/docs/source/examples/mtl.rst index e991770f..147a999b 100644 --- a/docs/source/examples/mtl.rst +++ b/docs/source/examples/mtl.rst @@ -52,7 +52,7 @@ vectors of dimension 10, and their corresponding scalar labels for both tasks. 
loss1 = loss_fn(output1, target1) loss2 = loss_fn(output2, target2) - mtl_backward(tensors=[loss1, loss2], features=features) + mtl_backward([loss1, loss2], features=features) jac_to_grad(shared_module.parameters(), aggregator) optimizer.step() optimizer.zero_grad() diff --git a/src/torchjd/autojac/_mtl_backward.py b/src/torchjd/autojac/_mtl_backward.py index 7389be33..ac3284cc 100644 --- a/src/torchjd/autojac/_mtl_backward.py +++ b/src/torchjd/autojac/_mtl_backward.py @@ -24,10 +24,12 @@ def mtl_backward( tensors: Sequence[Tensor], + /, features: Sequence[Tensor] | Tensor, grad_tensors: Sequence[Tensor] | None = None, tasks_params: Sequence[Iterable[Tensor]] | None = None, shared_params: Iterable[Tensor] | None = None, + *, retain_graph: bool = False, parallel_chunk_size: int | None = None, ) -> None: diff --git a/tests/doc/test_rst.py b/tests/doc/test_rst.py index 02dda25a..ac4ac060 100644 --- a/tests/doc/test_rst.py +++ b/tests/doc/test_rst.py @@ -44,7 +44,7 @@ def test_amp(): loss2 = loss_fn(output2, target2) scaled_losses = scaler.scale([loss1, loss2]) - mtl_backward(tensors=scaled_losses, features=features) + mtl_backward(scaled_losses, features=features) jac_to_grad(shared_module.parameters(), aggregator) scaler.step(optimizer) scaler.update() @@ -250,7 +250,7 @@ def training_step(self, batch, batch_idx) -> None: opt = self.optimizers() - mtl_backward(tensors=[loss1, loss2], features=features) + mtl_backward([loss1, loss2], features=features) jac_to_grad(self.feature_extractor.parameters(), UPGrad()) opt.step() opt.zero_grad() @@ -325,7 +325,7 @@ def print_gd_similarity(_, inputs: tuple[torch.Tensor, ...], aggregation: torch. loss1 = loss_fn(output1, target1) loss2 = loss_fn(output2, target2) - mtl_backward(tensors=[loss1, loss2], features=features) + mtl_backward([loss1, loss2], features=features) jac_to_grad(shared_module.parameters(), aggregator) optimizer.step() optimizer.zero_grad() @@ -363,7 +363,7 @@ def test_mtl(): loss1 = loss_fn(output1, target1) loss2 = loss_fn(output2, target2) - mtl_backward(tensors=[loss1, loss2], features=features) + mtl_backward([loss1, loss2], features=features) jac_to_grad(shared_module.parameters(), aggregator) optimizer.step() optimizer.zero_grad() diff --git a/tests/unit/autojac/test_mtl_backward.py b/tests/unit/autojac/test_mtl_backward.py index ea69a20b..9f15e6fd 100644 --- a/tests/unit/autojac/test_mtl_backward.py +++ b/tests/unit/autojac/test_mtl_backward.py @@ -60,7 +60,7 @@ def test_shape_is_correct(): y1 = f1 * p1[0] + f2 * p1[1] y2 = f1 * p2[0] + f2 * p2[1] - mtl_backward(tensors=[y1, y2], features=[f1, f2]) + mtl_backward([y1, y2], features=[f1, f2]) assert_has_jac(p0) for p in [p1, p2]: @@ -101,7 +101,7 @@ def test_value_is_correct( tasks_params = [[p1], [p2], [p3]] if manually_specify_tasks_params else None mtl_backward( - tensors=[y1, y2, y3], + [y1, y2, y3], features=f, tasks_params=tasks_params, shared_params=shared_params, @@ -125,7 +125,7 @@ def test_empty_tasks_fails(): f2 = (p0**2).sum() + p0.norm() with raises(ValueError): - mtl_backward(tensors=[], features=[f1, f2]) + mtl_backward([], features=[f1, f2]) def test_single_task(): @@ -138,7 +138,7 @@ def test_single_task(): f2 = (p0**2).sum() + p0.norm() y1 = f1 * p1[0] + f2 * p1[1] - mtl_backward(tensors=[y1], features=[f1, f2]) + mtl_backward([y1], features=[f1, f2]) assert_has_jac(p0) assert_has_grad(p1) @@ -161,14 +161,14 @@ def test_incoherent_task_number_fails(): with raises(ValueError): mtl_backward( - tensors=[y1, y2], + [y1, y2], features=[f1, f2], tasks_params=[[p1]], 
# Wrong shared_params=[p0], ) with raises(ValueError): mtl_backward( - tensors=[y1], # Wrong + [y1], # Wrong features=[f1, f2], tasks_params=[[p1], [p2]], shared_params=[p0], @@ -188,7 +188,7 @@ def test_empty_params(): y2 = f1 * p2[0] + f2 * p2[1] mtl_backward( - tensors=[y1, y2], + [y1, y2], features=[f1, f2], tasks_params=[[], []], shared_params=[], @@ -214,7 +214,7 @@ def test_multiple_params_per_task(): y1 = f1 * p1_a + (f2 * p1_b).sum() + (f1 * p1_c).sum() y2 = f1 * p2_a * (f2 * p2_b).sum() - mtl_backward(tensors=[y1, y2], features=[f1, f2]) + mtl_backward([y1, y2], features=[f1, f2]) assert_has_jac(p0) for p in [p1_a, p1_b, p1_c, p2_a, p2_b]: @@ -246,7 +246,7 @@ def test_various_shared_params(shared_params_shapes: list[tuple[int]]): y2 = torch.stack([f.sum() ** 2 for f in features]).sum() mtl_backward( - tensors=[y1, y2], + [y1, y2], features=features, tasks_params=[[p1], [p2]], # Enforce differentiation w.r.t. params that haven't been used shared_params=shared_params, @@ -274,7 +274,7 @@ def test_partial_params(): y2 = f1 * p2[0] + f2 * p2[1] mtl_backward( - tensors=[y1, y2], + [y1, y2], features=[f1, f2], tasks_params=[[p1], []], shared_params=[p0], @@ -298,7 +298,7 @@ def test_empty_features_fails(): y2 = f1 * p2[0] + f2 * p2[1] with raises(ValueError): - mtl_backward(tensors=[y1, y2], features=[]) + mtl_backward([y1, y2], features=[]) @mark.parametrize( @@ -322,7 +322,7 @@ def test_various_single_features(shape: tuple[int, ...]): y1 = (f * p1[0]).sum() + (f * p1[1]).sum() y2 = (f * p2[0]).sum() * (f * p2[1]).sum() - mtl_backward(tensors=[y1, y2], features=f) + mtl_backward([y1, y2], features=f) assert_has_jac(p0) for p in [p1, p2]: @@ -354,7 +354,7 @@ def test_various_feature_lists(shapes: list[tuple[int]]): y1 = sum([(f * p).sum() for f, p in zip(features, p1, strict=True)]) y2 = (features[0] * p2).sum() - mtl_backward(tensors=[y1, y2], features=features) + mtl_backward([y1, y2], features=features) assert_has_jac(p0) for p in [p1, p2]: @@ -374,7 +374,7 @@ def test_non_scalar_loss_fails(): y2 = f1 * p2[0] + f2 * p2[1] with raises(ValueError): - mtl_backward(tensors=[y1, y2], features=[f1, f2]) + mtl_backward([y1, y2], features=[f1, f2]) @mark.parametrize("chunk_size", [None, 1, 2, 4]) @@ -391,7 +391,7 @@ def test_various_valid_chunk_sizes(chunk_size): y2 = f1 * p2[0] + f2 * p2[1] mtl_backward( - tensors=[y1, y2], + [y1, y2], features=[f1, f2], parallel_chunk_size=chunk_size, ) @@ -416,7 +416,7 @@ def test_non_positive_chunk_size_fails(chunk_size: int): with raises(ValueError): mtl_backward( - tensors=[y1, y2], + [y1, y2], features=[f1, f2], parallel_chunk_size=chunk_size, ) @@ -440,7 +440,7 @@ def test_shared_param_retaining_grad_fails(): # mtl_backward itself doesn't raise the error, but it fills a.grad with a BatchedTensor mtl_backward( - tensors=[y1, y2], + [y1, y2], features=[f], tasks_params=[[p1], [p2]], shared_params=[a, p0], @@ -469,7 +469,7 @@ def test_shared_activation_retaining_grad_fails(): # mtl_backward itself doesn't raise the error, but it fills a.grad with a BatchedTensor mtl_backward( - tensors=[y1, y2], + [y1, y2], features=[f], tasks_params=[[p1], [p2]], shared_params=[p0], @@ -492,7 +492,7 @@ def test_tasks_params_overlap(): y1 = f * p1 * p12 y2 = f * p2 * p12 - mtl_backward(tensors=[y1, y2], features=[f]) + mtl_backward([y1, y2], features=[f]) assert_grad_close(p2, f * p12) assert_grad_close(p1, f * p12) @@ -512,7 +512,7 @@ def test_tasks_params_are_the_same(): y1 = f * p1 y2 = f + p1 - mtl_backward(tensors=[y1, y2], features=[f]) + mtl_backward([y1, y2], 
features=[f]) assert_grad_close(p1, f + 1) @@ -534,7 +534,7 @@ def test_task_params_is_subset_of_other_task_params(): y1 = f * p1 y2 = y1 * p2 - mtl_backward(tensors=[y1, y2], features=[f], retain_graph=True) + mtl_backward([y1, y2], features=[f], retain_graph=True) assert_grad_close(p2, y1) assert_grad_close(p1, p2 * f + f) @@ -559,7 +559,7 @@ def test_shared_params_overlapping_with_tasks_params_fails(): with raises(ValueError): mtl_backward( - tensors=[y1, y2], + [y1, y2], features=[f], tasks_params=[[p1], [p0, p2]], # Problem: p0 is also shared shared_params=[p0], @@ -582,7 +582,7 @@ def test_default_shared_params_overlapping_with_default_tasks_params_fails(): with raises(ValueError): mtl_backward( - tensors=[y1, y2], + [y1, y2], features=[f], ) @@ -607,7 +607,7 @@ def test_repeated_losses(): with raises(ValueError): losses = [y1, y1, y2] - mtl_backward(tensors=losses, features=[f1, f2], retain_graph=True) + mtl_backward(losses, features=[f1, f2], retain_graph=True) def test_repeated_features(): @@ -630,7 +630,7 @@ def test_repeated_features(): with raises(ValueError): features = [f1, f1, f2] - mtl_backward(tensors=[y1, y2], features=features) + mtl_backward([y1, y2], features=features) def test_repeated_shared_params(): @@ -654,7 +654,7 @@ def test_repeated_shared_params(): g2 = grad([y2], [p2], retain_graph=True)[0] shared_params = [p0, p0] - mtl_backward(tensors=[y1, y2], features=[f1, f2], shared_params=shared_params) + mtl_backward([y1, y2], features=[f1, f2], shared_params=shared_params) assert_jac_close(p0, J0) assert_grad_close(p1, g1) @@ -682,7 +682,7 @@ def test_repeated_task_params(): g2 = grad([y2], [p2], retain_graph=True)[0] tasks_params = [[p1, p1], [p2]] - mtl_backward(tensors=[y1, y2], features=[f1, f2], tasks_params=tasks_params) + mtl_backward([y1, y2], features=[f1, f2], tasks_params=tasks_params) assert_jac_close(p0, J0) assert_grad_close(p1, g1) @@ -711,7 +711,7 @@ def test_grad_tensors_value_is_correct(): grad_y3 = torch.randn_like(y3) mtl_backward( - tensors=[y1, y2, y3], + [y1, y2, y3], features=f, grad_tensors=[grad_y1, grad_y2, grad_y3], ) @@ -741,7 +741,7 @@ def test_grad_tensors_length_mismatch(): match=r"`grad_tensors` should have the same length as `tensors`\. \(got 1 and 2\)", ): mtl_backward( - tensors=[y1, y2], + [y1, y2], features=f, grad_tensors=[torch.ones_like(y1)], ) @@ -767,7 +767,7 @@ def test_grad_tensors_shape_mismatch(): match=r"Shape mismatch: `grad_tensors\[0\]` has shape .* but `tensors\[0\]` has shape .*\.", ): mtl_backward( - tensors=[y1, y2], + [y1, y2], features=f, grad_tensors=[ones_(2), torch.ones_like(y2)], # shape (2,) != shape () of y1 ) From 70876ab66cee1ed8d23913f42f899d0950e07280 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Wed, 18 Feb 2026 15:42:29 +0100 Subject: [PATCH 07/14] Minimal separation for `jac`, could be a bit stricter with inputs the only positional and named parameter, everything before that positional and after named. 
--- src/torchjd/autojac/_jac.py | 2 ++ tests/unit/autojac/test_jac.py | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/torchjd/autojac/_jac.py b/src/torchjd/autojac/_jac.py index de8aa200..59551cf5 100644 --- a/src/torchjd/autojac/_jac.py +++ b/src/torchjd/autojac/_jac.py @@ -19,8 +19,10 @@ def jac( outputs: Sequence[Tensor] | Tensor, + /, inputs: Iterable[Tensor] | None = None, jac_outputs: Sequence[Tensor] | Tensor | None = None, + *, retain_graph: bool = False, parallel_chunk_size: int | None = None, ) -> tuple[Tensor, ...]: diff --git a/tests/unit/autojac/test_jac.py b/tests/unit/autojac/test_jac.py index 28e1dcb1..9b108be5 100644 --- a/tests/unit/autojac/test_jac.py +++ b/tests/unit/autojac/test_jac.py @@ -309,7 +309,7 @@ def test_input_retaining_grad_fails(): # jac itself doesn't raise the error, but it fills b.grad with a BatchedTensor (and it also # returns the correct Jacobian) - jac(outputs=y, inputs=[b]) + jac(y, inputs=[b]) with raises(RuntimeError): # Using such a BatchedTensor should result in an error @@ -328,7 +328,7 @@ def test_non_input_retaining_grad_fails(): y = 3 * b # jac itself doesn't raise the error, but it fills b.grad with a BatchedTensor - jac(outputs=y, inputs=[a]) + jac(y, inputs=[a]) with raises(RuntimeError): # Using such a BatchedTensor should result in an error From 8886a6275e269366352c55c25063805fe5b352b9 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Wed, 18 Feb 2026 15:46:14 +0100 Subject: [PATCH 08/14] Minimal separation for `jac_to_grad`; `aggregator` could also be positional-only, but this is fine IMO. --- src/torchjd/autojac/_jac_to_grad.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/torchjd/autojac/_jac_to_grad.py b/src/torchjd/autojac/_jac_to_grad.py index 63f97dc6..3f1d9a5f 100644 --- a/src/torchjd/autojac/_jac_to_grad.py +++ b/src/torchjd/autojac/_jac_to_grad.py @@ -11,7 +11,9 @@ def jac_to_grad( tensors: Iterable[Tensor], + /, aggregator: Aggregator, + *, retain_jac: bool = False, ) -> None: r""" From 0e0d16b4fad6ce27871be9fcf5232633edcf47a0 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Wed, 18 Feb 2026 16:25:00 +0100 Subject: [PATCH 09/14] Make backward stricter. --- src/torchjd/autojac/_backward.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchjd/autojac/_backward.py b/src/torchjd/autojac/_backward.py index 9612d186..5a5b0139 100644 --- a/src/torchjd/autojac/_backward.py +++ b/src/torchjd/autojac/_backward.py @@ -16,9 +16,9 @@ def backward( tensors: Sequence[Tensor] | Tensor, /, + *, jac_tensors: Sequence[Tensor] | Tensor | None = None, inputs: Iterable[Tensor] | None = None, - *, retain_graph: bool = False, parallel_chunk_size: int | None = None, ) -> None: From 87c119a51ad13e3e4b04252988b324c9c5af4aab Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Wed, 18 Feb 2026 16:26:05 +0100 Subject: [PATCH 10/14] Make mtl_backward stricter. Making `features` positional-only did not pass the smell test of running the tests: most calls to `mtl_backward` pass `features` by keyword, while we would want to make it positional. So I now vouch for allowing both for this one. 
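Under the `mtl_backward` signature described in patch 10 (subject above, diff below), `tensors` is positional-only, `features` stays usable either way, and the rest is keyword-only. A minimal sketch with a stub (`fake_mtl_backward`, not the real function) of the call forms this accepts and rejects:

```
# Sketch of the call forms accepted/rejected by the patch-10 `mtl_backward` parameter layout.
# fake_mtl_backward is a stub mirroring that layout, not the real torchjd function.
def fake_mtl_backward(tensors, /, features, *, grad_tensors=None, tasks_params=None, shared_params=None, retain_graph=False, parallel_chunk_size=None):
    return tensors, features

fake_mtl_backward(["loss1", "loss2"], ["feat"])           # `features` passed positionally
fake_mtl_backward(["loss1", "loss2"], features=["feat"])  # ... or by keyword; both are accepted
try:
    fake_mtl_backward(tensors=["loss1"], features=["feat"])  # Rejected: `tensors` is positional-only
except TypeError as error:
    print(error)
try:
    fake_mtl_backward(["loss1"], ["feat"], None)  # Rejected: `grad_tensors` is keyword-only
except TypeError as error:
    print(error)
```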
--- src/torchjd/autojac/_mtl_backward.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchjd/autojac/_mtl_backward.py b/src/torchjd/autojac/_mtl_backward.py index ac3284cc..281ddfed 100644 --- a/src/torchjd/autojac/_mtl_backward.py +++ b/src/torchjd/autojac/_mtl_backward.py @@ -26,10 +26,10 @@ def mtl_backward( tensors: Sequence[Tensor], /, features: Sequence[Tensor] | Tensor, + *, grad_tensors: Sequence[Tensor] | None = None, tasks_params: Sequence[Iterable[Tensor]] | None = None, shared_params: Iterable[Tensor] | None = None, - *, retain_graph: bool = False, parallel_chunk_size: int | None = None, ) -> None: From f0a725143278c635ea6ab5f21292991afa61bac2 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Wed, 18 Feb 2026 16:31:24 +0100 Subject: [PATCH 11/14] Make `jac` stricter. Problem with `inputs`: it is not supposed to be optional and should therefore stay positional-or-keyword. --- src/torchjd/autojac/_jac.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchjd/autojac/_jac.py b/src/torchjd/autojac/_jac.py index 59551cf5..75d0eb0b 100644 --- a/src/torchjd/autojac/_jac.py +++ b/src/torchjd/autojac/_jac.py @@ -21,8 +21,8 @@ def jac( outputs: Sequence[Tensor] | Tensor, /, inputs: Iterable[Tensor] | None = None, - jac_outputs: Sequence[Tensor] | Tensor | None = None, *, + jac_outputs: Sequence[Tensor] | Tensor | None = None, retain_graph: bool = False, parallel_chunk_size: int | None = None, ) -> tuple[Tensor, ...]: From 14d5b28e5a78b4778ff55aa45ad357c42257f1fc Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Wed, 18 Feb 2026 17:13:16 +0100 Subject: [PATCH 12/14] Change `outputs` of `jac` --- src/torchjd/autojac/_jac.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/torchjd/autojac/_jac.py b/src/torchjd/autojac/_jac.py index 75d0eb0b..78c10df2 100644 --- a/src/torchjd/autojac/_jac.py +++ b/src/torchjd/autojac/_jac.py @@ -19,7 +19,6 @@ def jac( outputs: Sequence[Tensor] | Tensor, - /, inputs: Iterable[Tensor] | None = None, *, jac_outputs: Sequence[Tensor] | Tensor | None = None, retain_graph: bool = False, parallel_chunk_size: int | None = None, ) -> tuple[Tensor, ...]: From 4f7a389c15a0476963af65039bf9140e9df2c4c0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 18 Feb 2026 17:55:26 +0100 Subject: [PATCH 13/14] Add changelog entry --- CHANGELOG.md | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7540c5d4..acbfc995 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -53,8 +53,12 @@ changelog does not include internal changes that do not affect the user. mtl_backward(losses, features) jac_to_grad(shared_module.parameters(), aggregator) ``` -- **BREAKING**: Renamed the `losses` parameter of `mtl_backward` to `tensors`. - +- **BREAKING**: Made some parameters of the public interface of `torchjd` positional-only or keyword-only: - `backward`: The `tensors` parameter is now positional-only. Suggested change: `backward(tensors=losses)` => `backward(losses)`. All other parameters are now keyword-only. - `mtl_backward`: The `tensors` parameter (previously named `losses`) is now positional-only. Suggested change: `mtl_backward(losses=losses, features=features)` => `mtl_backward(losses, features=features)`. The `features` parameter remains usable as positional or keyword. All other parameters are now keyword-only. - `Aggregator.__call__`: The `matrix` parameter is now positional-only. Suggested change: `aggregator(matrix=matrix)` => `aggregator(matrix)`. - `Weighting.__call__`: The `stat` parameter is now positional-only. 
Suggested change: `weighting(stat=gramian)` => `weighting(gramian)`. - `GeneralizedWeighting.__call__`: The `generalized_gramian` parameter is now positional-only. Suggested change: `generalized_weighting(generalized_gramian=generalized_gramian)` => `generalized_weighting(generalized_gramian)`. - Removed an unnecessary memory duplication. This should significantly improve the memory efficiency of `autojac`. - Removed an unnecessary internal cloning of gradient. This should slightly improve the memory From caec5da63a43ba6a6fda790155bbbcfb7e14a685 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 18 Feb 2026 17:57:38 +0100 Subject: [PATCH 14/14] Fix formatting --- CHANGELOG.md | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index acbfc995..1d1f0147 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -53,12 +53,21 @@ changelog does not include internal changes that do not affect the user. mtl_backward(losses, features) jac_to_grad(shared_module.parameters(), aggregator) ``` -- **BREAKING**: Made some parameters of the public interface of `torchjd` positional-only or keyword-only: - - `backward`: The `tensors` parameter is now positional-only. Suggested change: `backward(tensors=losses)` => `backward(losses)`. All other parameters are now keyword-only. - - `mtl_backward`: The `tensors` parameter (previously named `losses`) is now positional-only. Suggested change: `mtl_backward(losses=losses, features=features)` => `mtl_backward(losses, features=features)`. The `features` parameter remains usable as positional or keyword. All other parameters are now keyword-only. - - `Aggregator.__call__`: The `matrix` parameter is now positional-only. Suggested change: `aggregator(matrix=matrix)` => `aggregator(matrix)`. - - `Weighting.__call__`: The `stat` parameter is now positional-only. Suggested change: `weighting(stat=gramian)` => `weighting(gramian)`. - - `GeneralizedWeighting.__call__`: The `generalized_gramian` parameter is now positional-only. Suggested change: `generalized_weighting(generalized_gramian=generalized_gramian)` => `generalized_weighting(generalized_gramian)`. +- **BREAKING**: Made some parameters of the public interface of `torchjd` positional-only or + keyword-only: + - `backward`: The `tensors` parameter is now positional-only. Suggested change: + `backward(tensors=losses)` => `backward(losses)`. All other parameters are now keyword-only. + - `mtl_backward`: The `tensors` parameter (previously named `losses`) is now positional-only. + Suggested change: `mtl_backward(losses=losses, features=features)` => + `mtl_backward(losses, features=features)`. The `features` parameter remains usable as positional + or keyword. All other parameters are now keyword-only. + - `Aggregator.__call__`: The `matrix` parameter is now positional-only. Suggested change: + `aggregator(matrix=matrix)` => `aggregator(matrix)`. + - `Weighting.__call__`: The `stat` parameter is now positional-only. Suggested change: + `weighting(stat=gramian)` => `weighting(gramian)`. + - `GeneralizedWeighting.__call__`: The `generalized_gramian` parameter is now positional-only. + Suggested change: `generalized_weighting(generalized_gramian=generalized_gramian)` => + `generalized_weighting(generalized_gramian)`. - Removed an unnecessary memory duplication. This should significantly improve the memory efficiency of `autojac`. - Removed an unnecessary internal cloning of gradient. This should slightly improve the memory