diff --git a/CHANGELOG.md b/CHANGELOG.md index 7540c5d4..1d1f0147 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -53,8 +53,21 @@ changelog does not include internal changes that do not affect the user. mtl_backward(losses, features) jac_to_grad(shared_module.parameters(), aggregator) ``` -- **BREAKING**: Renamed the `losses` parameter of `mtl_backward` to `tensors`. - +- **BREAKING**: Made some parameters of the public interface of `torchjd` positional-only or + keyword-only: + - `backward`: The `tensors` parameter is now positional-only. Suggested change: + `backward(tensors=losses)` => `backward(losses)`. All other parameters are now keyword-only. + - `mtl_backward`: The `tensors` parameter (previously named `losses`) is now positional-only. + Suggested change: `mtl_backward(losses=losses, features=features)` => + `mtl_backward(losses, features=features)`. The `features` parameter remains usable as positional + or keyword. All other parameters are now keyword-only. + - `Aggregator.__call__`: The `matrix` parameter is now positional-only. Suggested change: + `aggregator(matrix=matrix)` => `aggregator(matrix)`. + - `Weighting.__call__`: The `stat` parameter is now positional-only. Suggested change: + `weighting(stat=gramian)` => `weighting(gramian)`. + - `GeneralizedWeighting.__call__`: The `generalized_gramian` parameter is now positional-only. + Suggested change: `generalized_weighting(generalized_gramian=generalized_gramian)` => + `generalized_weighting(generalized_gramian)`. - Removed an unnecessary memory duplication. This should significantly improve the memory efficiency of `autojac`. - Removed an unnecessary internal cloning of gradient. This should slightly improve the memory diff --git a/docs/source/examples/amp.rst b/docs/source/examples/amp.rst index f1cbe5fa..2de486c1 100644 --- a/docs/source/examples/amp.rst +++ b/docs/source/examples/amp.rst @@ -48,7 +48,7 @@ following example shows the resulting code for a multi-task learning use-case. loss2 = loss_fn(output2, target2) scaled_losses = scaler.scale([loss1, loss2]) - mtl_backward(tensors=scaled_losses, features=features) + mtl_backward(scaled_losses, features=features) jac_to_grad(shared_module.parameters(), aggregator) scaler.step(optimizer) scaler.update() diff --git a/docs/source/examples/lightning_integration.rst b/docs/source/examples/lightning_integration.rst index 9cc74904..c8416083 100644 --- a/docs/source/examples/lightning_integration.rst +++ b/docs/source/examples/lightning_integration.rst @@ -43,7 +43,7 @@ The following code example demonstrates a basic multi-task learning setup using loss2 = mse_loss(output2, target2) opt = self.optimizers() - mtl_backward(tensors=[loss1, loss2], features=features) + mtl_backward([loss1, loss2], features=features) jac_to_grad(self.feature_extractor.parameters(), UPGrad()) opt.step() opt.zero_grad() diff --git a/docs/source/examples/monitoring.rst b/docs/source/examples/monitoring.rst index fda6af2d..784eea4d 100644 --- a/docs/source/examples/monitoring.rst +++ b/docs/source/examples/monitoring.rst @@ -63,7 +63,7 @@ they have a negative inner product).
loss1 = loss_fn(output1, target1) loss2 = loss_fn(output2, target2) - mtl_backward(tensors=[loss1, loss2], features=features) + mtl_backward([loss1, loss2], features=features) jac_to_grad(shared_module.parameters(), aggregator) optimizer.step() optimizer.zero_grad() diff --git a/docs/source/examples/mtl.rst b/docs/source/examples/mtl.rst index e991770f..147a999b 100644 --- a/docs/source/examples/mtl.rst +++ b/docs/source/examples/mtl.rst @@ -52,7 +52,7 @@ vectors of dimension 10, and their corresponding scalar labels for both tasks. loss1 = loss_fn(output1, target1) loss2 = loss_fn(output2, target2) - mtl_backward(tensors=[loss1, loss2], features=features) + mtl_backward([loss1, loss2], features=features) jac_to_grad(shared_module.parameters(), aggregator) optimizer.step() optimizer.zero_grad() diff --git a/src/torchjd/aggregation/_aggregator_bases.py b/src/torchjd/aggregation/_aggregator_bases.py index 5a656f67..78168eae 100644 --- a/src/torchjd/aggregation/_aggregator_bases.py +++ b/src/torchjd/aggregation/_aggregator_bases.py @@ -25,10 +25,10 @@ def _check_is_matrix(matrix: Tensor) -> None: ) @abstractmethod - def forward(self, matrix: Matrix) -> Tensor: + def forward(self, matrix: Matrix, /) -> Tensor: """Computes the aggregation from the input matrix.""" - def __call__(self, matrix: Tensor) -> Tensor: + def __call__(self, matrix: Tensor, /) -> Tensor: """Computes the aggregation from the input matrix and applies all registered hooks.""" Aggregator._check_is_matrix(matrix) return super().__call__(matrix) @@ -62,7 +62,7 @@ def combine(matrix: Matrix, weights: Tensor) -> Tensor: vector = weights @ matrix return vector - def forward(self, matrix: Matrix) -> Tensor: + def forward(self, matrix: Matrix, /) -> Tensor: weights = self.weighting(matrix) vector = self.combine(matrix, weights) return vector diff --git a/src/torchjd/aggregation/_config.py b/src/torchjd/aggregation/_config.py index 980b93b4..24866c41 100644 --- a/src/torchjd/aggregation/_config.py +++ b/src/torchjd/aggregation/_config.py @@ -58,7 +58,7 @@ def __init__(self, pref_vector: Tensor | None = None): # This prevents computing gradients that can be very wrong. self.register_full_backward_pre_hook(raise_non_differentiable_error) - def forward(self, matrix: Matrix) -> Tensor: + def forward(self, matrix: Matrix, /) -> Tensor: weights = self.weighting(matrix) units = torch.nan_to_num((matrix / (matrix.norm(dim=1)).unsqueeze(1)), 0.0) best_direction = torch.linalg.pinv(units) @ weights diff --git a/src/torchjd/aggregation/_flattening.py b/src/torchjd/aggregation/_flattening.py index 8db6027b..208db3ec 100644 --- a/src/torchjd/aggregation/_flattening.py +++ b/src/torchjd/aggregation/_flattening.py @@ -24,7 +24,7 @@ def __init__(self, weighting: Weighting): super().__init__() self.weighting = weighting - def forward(self, generalized_gramian: PSDTensor) -> Tensor: + def forward(self, generalized_gramian: PSDTensor, /) -> Tensor: k = generalized_gramian.ndim // 2 shape = generalized_gramian.shape[:k] square_gramian = flatten(generalized_gramian) diff --git a/src/torchjd/aggregation/_graddrop.py b/src/torchjd/aggregation/_graddrop.py index f8a39426..afa16451 100644 --- a/src/torchjd/aggregation/_graddrop.py +++ b/src/torchjd/aggregation/_graddrop.py @@ -40,7 +40,7 @@ def __init__(self, f: Callable = _identity, leak: Tensor | None = None): # This prevents computing gradients that can be very wrong. 
self.register_full_backward_pre_hook(raise_non_differentiable_error) - def forward(self, matrix: Matrix) -> Tensor: + def forward(self, matrix: Matrix, /) -> Tensor: self._check_matrix_has_enough_rows(matrix) if matrix.shape[0] == 0 or matrix.shape[1] == 0: diff --git a/src/torchjd/aggregation/_trimmed_mean.py b/src/torchjd/aggregation/_trimmed_mean.py index 07ed055a..77d33c41 100644 --- a/src/torchjd/aggregation/_trimmed_mean.py +++ b/src/torchjd/aggregation/_trimmed_mean.py @@ -24,7 +24,7 @@ def __init__(self, trim_number: int): ) self.trim_number = trim_number - def forward(self, matrix: Tensor) -> Tensor: + def forward(self, matrix: Tensor, /) -> Tensor: self._check_matrix_has_enough_rows(matrix) n_rows = matrix.shape[0] diff --git a/src/torchjd/aggregation/_weighting_bases.py b/src/torchjd/aggregation/_weighting_bases.py index 3b037891..dd7c53ee 100644 --- a/src/torchjd/aggregation/_weighting_bases.py +++ b/src/torchjd/aggregation/_weighting_bases.py @@ -27,7 +27,7 @@ def __init__(self): def forward(self, stat: _T, /) -> Tensor: """Computes the vector of weights from the input stat.""" - def __call__(self, stat: Tensor) -> Tensor: + def __call__(self, stat: Tensor, /) -> Tensor: """Computes the vector of weights from the input stat and applies all registered hooks.""" # The value of _T (e.g. PSDMatrix) is not public, so we need the user-facing type hint of @@ -67,10 +67,10 @@ def __init__(self): super().__init__() @abstractmethod - def forward(self, generalized_gramian: PSDTensor) -> Tensor: + def forward(self, generalized_gramian: PSDTensor, /) -> Tensor: """Computes the vector of weights from the input generalized Gramian.""" - def __call__(self, generalized_gramian: Tensor) -> Tensor: + def __call__(self, generalized_gramian: Tensor, /) -> Tensor: """ Computes the tensor of weights from the input generalized Gramian and applies all registered hooks. diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index f8248c0c..610b3753 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -235,7 +235,7 @@ def _check_module_is_compatible(self, module: nn.Module) -> None: ) # Currently, the type PSDMatrix is hidden from users, so Tensor is correct. - def compute_gramian(self, output: Tensor) -> Tensor: + def compute_gramian(self, output: Tensor, /) -> Tensor: r""" Computes the Gramian of the Jacobian of ``output`` with respect to the direct parameters of all ``modules``. 
diff --git a/src/torchjd/autojac/_backward.py b/src/torchjd/autojac/_backward.py index 206e5b92..5a5b0139 100644 --- a/src/torchjd/autojac/_backward.py +++ b/src/torchjd/autojac/_backward.py @@ -15,6 +15,8 @@ def backward( tensors: Sequence[Tensor] | Tensor, + /, + *, jac_tensors: Sequence[Tensor] | Tensor | None = None, inputs: Iterable[Tensor] | None = None, retain_graph: bool = False, diff --git a/src/torchjd/autojac/_jac.py b/src/torchjd/autojac/_jac.py index de8aa200..78c10df2 100644 --- a/src/torchjd/autojac/_jac.py +++ b/src/torchjd/autojac/_jac.py @@ -20,6 +20,7 @@ def jac( outputs: Sequence[Tensor] | Tensor, inputs: Iterable[Tensor] | None = None, + *, jac_outputs: Sequence[Tensor] | Tensor | None = None, retain_graph: bool = False, parallel_chunk_size: int | None = None, diff --git a/src/torchjd/autojac/_jac_to_grad.py b/src/torchjd/autojac/_jac_to_grad.py index 63f97dc6..3f1d9a5f 100644 --- a/src/torchjd/autojac/_jac_to_grad.py +++ b/src/torchjd/autojac/_jac_to_grad.py @@ -11,7 +11,9 @@ def jac_to_grad( tensors: Iterable[Tensor], + /, aggregator: Aggregator, + *, retain_jac: bool = False, ) -> None: r""" diff --git a/src/torchjd/autojac/_mtl_backward.py b/src/torchjd/autojac/_mtl_backward.py index 7389be33..281ddfed 100644 --- a/src/torchjd/autojac/_mtl_backward.py +++ b/src/torchjd/autojac/_mtl_backward.py @@ -24,7 +24,9 @@ def mtl_backward( tensors: Sequence[Tensor], + /, features: Sequence[Tensor] | Tensor, + *, grad_tensors: Sequence[Tensor] | None = None, tasks_params: Sequence[Iterable[Tensor]] | None = None, shared_params: Iterable[Tensor] | None = None, diff --git a/tests/doc/test_rst.py b/tests/doc/test_rst.py index 02dda25a..ac4ac060 100644 --- a/tests/doc/test_rst.py +++ b/tests/doc/test_rst.py @@ -44,7 +44,7 @@ def test_amp(): loss2 = loss_fn(output2, target2) scaled_losses = scaler.scale([loss1, loss2]) - mtl_backward(tensors=scaled_losses, features=features) + mtl_backward(scaled_losses, features=features) jac_to_grad(shared_module.parameters(), aggregator) scaler.step(optimizer) scaler.update() @@ -250,7 +250,7 @@ def training_step(self, batch, batch_idx) -> None: opt = self.optimizers() - mtl_backward(tensors=[loss1, loss2], features=features) + mtl_backward([loss1, loss2], features=features) jac_to_grad(self.feature_extractor.parameters(), UPGrad()) opt.step() opt.zero_grad() @@ -325,7 +325,7 @@ def print_gd_similarity(_, inputs: tuple[torch.Tensor, ...], aggregation: torch. 
loss1 = loss_fn(output1, target1) loss2 = loss_fn(output2, target2) - mtl_backward(tensors=[loss1, loss2], features=features) + mtl_backward([loss1, loss2], features=features) jac_to_grad(shared_module.parameters(), aggregator) optimizer.step() optimizer.zero_grad() @@ -363,7 +363,7 @@ def test_mtl(): loss1 = loss_fn(output1, target1) loss2 = loss_fn(output2, target2) - mtl_backward(tensors=[loss1, loss2], features=features) + mtl_backward([loss1, loss2], features=features) jac_to_grad(shared_module.parameters(), aggregator) optimizer.step() optimizer.zero_grad() diff --git a/tests/unit/autojac/test_backward.py b/tests/unit/autojac/test_backward.py index 4d75d382..a0398c42 100644 --- a/tests/unit/autojac/test_backward.py +++ b/tests/unit/autojac/test_backward.py @@ -310,7 +310,7 @@ def test_input_retaining_grad_fails(): # backward itself doesn't raise the error, but it fills b.grad with a BatchedTensor # (and it also fills b.jac with the correct Jacobian) - backward(tensors=y, inputs=[b]) + backward(y, inputs=[b]) with raises(RuntimeError): # Using such a BatchedTensor should result in an error @@ -329,7 +329,7 @@ def test_non_input_retaining_grad_fails(): y = 3 * b # backward itself doesn't raise the error, but it fills b.grad with a BatchedTensor - backward(tensors=y, inputs=[a]) + backward(y, inputs=[a]) with raises(RuntimeError): # Using such a BatchedTensor should result in an error diff --git a/tests/unit/autojac/test_jac.py b/tests/unit/autojac/test_jac.py index 28e1dcb1..9b108be5 100644 --- a/tests/unit/autojac/test_jac.py +++ b/tests/unit/autojac/test_jac.py @@ -309,7 +309,7 @@ def test_input_retaining_grad_fails(): # jac itself doesn't raise the error, but it fills b.grad with a BatchedTensor (and it also # returns the correct Jacobian) - jac(outputs=y, inputs=[b]) + jac(y, inputs=[b]) with raises(RuntimeError): # Using such a BatchedTensor should result in an error @@ -328,7 +328,7 @@ def test_non_input_retaining_grad_fails(): y = 3 * b # jac itself doesn't raise the error, but it fills b.grad with a BatchedTensor - jac(outputs=y, inputs=[a]) + jac(y, inputs=[a]) with raises(RuntimeError): # Using such a BatchedTensor should result in an error diff --git a/tests/unit/autojac/test_mtl_backward.py b/tests/unit/autojac/test_mtl_backward.py index ea69a20b..9f15e6fd 100644 --- a/tests/unit/autojac/test_mtl_backward.py +++ b/tests/unit/autojac/test_mtl_backward.py @@ -60,7 +60,7 @@ def test_shape_is_correct(): y1 = f1 * p1[0] + f2 * p1[1] y2 = f1 * p2[0] + f2 * p2[1] - mtl_backward(tensors=[y1, y2], features=[f1, f2]) + mtl_backward([y1, y2], features=[f1, f2]) assert_has_jac(p0) for p in [p1, p2]: @@ -101,7 +101,7 @@ def test_value_is_correct( tasks_params = [[p1], [p2], [p3]] if manually_specify_tasks_params else None mtl_backward( - tensors=[y1, y2, y3], + [y1, y2, y3], features=f, tasks_params=tasks_params, shared_params=shared_params, @@ -125,7 +125,7 @@ def test_empty_tasks_fails(): f2 = (p0**2).sum() + p0.norm() with raises(ValueError): - mtl_backward(tensors=[], features=[f1, f2]) + mtl_backward([], features=[f1, f2]) def test_single_task(): @@ -138,7 +138,7 @@ def test_single_task(): f2 = (p0**2).sum() + p0.norm() y1 = f1 * p1[0] + f2 * p1[1] - mtl_backward(tensors=[y1], features=[f1, f2]) + mtl_backward([y1], features=[f1, f2]) assert_has_jac(p0) assert_has_grad(p1) @@ -161,14 +161,14 @@ def test_incoherent_task_number_fails(): with raises(ValueError): mtl_backward( - tensors=[y1, y2], + [y1, y2], features=[f1, f2], tasks_params=[[p1]], # Wrong shared_params=[p0], ) with 
raises(ValueError): mtl_backward( - tensors=[y1], # Wrong + [y1], # Wrong features=[f1, f2], tasks_params=[[p1], [p2]], shared_params=[p0], @@ -188,7 +188,7 @@ def test_empty_params(): y2 = f1 * p2[0] + f2 * p2[1] mtl_backward( - tensors=[y1, y2], + [y1, y2], features=[f1, f2], tasks_params=[[], []], shared_params=[], @@ -214,7 +214,7 @@ def test_multiple_params_per_task(): y1 = f1 * p1_a + (f2 * p1_b).sum() + (f1 * p1_c).sum() y2 = f1 * p2_a * (f2 * p2_b).sum() - mtl_backward(tensors=[y1, y2], features=[f1, f2]) + mtl_backward([y1, y2], features=[f1, f2]) assert_has_jac(p0) for p in [p1_a, p1_b, p1_c, p2_a, p2_b]: @@ -246,7 +246,7 @@ def test_various_shared_params(shared_params_shapes: list[tuple[int]]): y2 = torch.stack([f.sum() ** 2 for f in features]).sum() mtl_backward( - tensors=[y1, y2], + [y1, y2], features=features, tasks_params=[[p1], [p2]], # Enforce differentiation w.r.t. params that haven't been used shared_params=shared_params, @@ -274,7 +274,7 @@ def test_partial_params(): y2 = f1 * p2[0] + f2 * p2[1] mtl_backward( - tensors=[y1, y2], + [y1, y2], features=[f1, f2], tasks_params=[[p1], []], shared_params=[p0], @@ -298,7 +298,7 @@ def test_empty_features_fails(): y2 = f1 * p2[0] + f2 * p2[1] with raises(ValueError): - mtl_backward(tensors=[y1, y2], features=[]) + mtl_backward([y1, y2], features=[]) @mark.parametrize( @@ -322,7 +322,7 @@ def test_various_single_features(shape: tuple[int, ...]): y1 = (f * p1[0]).sum() + (f * p1[1]).sum() y2 = (f * p2[0]).sum() * (f * p2[1]).sum() - mtl_backward(tensors=[y1, y2], features=f) + mtl_backward([y1, y2], features=f) assert_has_jac(p0) for p in [p1, p2]: @@ -354,7 +354,7 @@ def test_various_feature_lists(shapes: list[tuple[int]]): y1 = sum([(f * p).sum() for f, p in zip(features, p1, strict=True)]) y2 = (features[0] * p2).sum() - mtl_backward(tensors=[y1, y2], features=features) + mtl_backward([y1, y2], features=features) assert_has_jac(p0) for p in [p1, p2]: @@ -374,7 +374,7 @@ def test_non_scalar_loss_fails(): y2 = f1 * p2[0] + f2 * p2[1] with raises(ValueError): - mtl_backward(tensors=[y1, y2], features=[f1, f2]) + mtl_backward([y1, y2], features=[f1, f2]) @mark.parametrize("chunk_size", [None, 1, 2, 4]) @@ -391,7 +391,7 @@ def test_various_valid_chunk_sizes(chunk_size): y2 = f1 * p2[0] + f2 * p2[1] mtl_backward( - tensors=[y1, y2], + [y1, y2], features=[f1, f2], parallel_chunk_size=chunk_size, ) @@ -416,7 +416,7 @@ def test_non_positive_chunk_size_fails(chunk_size: int): with raises(ValueError): mtl_backward( - tensors=[y1, y2], + [y1, y2], features=[f1, f2], parallel_chunk_size=chunk_size, ) @@ -440,7 +440,7 @@ def test_shared_param_retaining_grad_fails(): # mtl_backward itself doesn't raise the error, but it fills a.grad with a BatchedTensor mtl_backward( - tensors=[y1, y2], + [y1, y2], features=[f], tasks_params=[[p1], [p2]], shared_params=[a, p0], @@ -469,7 +469,7 @@ def test_shared_activation_retaining_grad_fails(): # mtl_backward itself doesn't raise the error, but it fills a.grad with a BatchedTensor mtl_backward( - tensors=[y1, y2], + [y1, y2], features=[f], tasks_params=[[p1], [p2]], shared_params=[p0], @@ -492,7 +492,7 @@ def test_tasks_params_overlap(): y1 = f * p1 * p12 y2 = f * p2 * p12 - mtl_backward(tensors=[y1, y2], features=[f]) + mtl_backward([y1, y2], features=[f]) assert_grad_close(p2, f * p12) assert_grad_close(p1, f * p12) @@ -512,7 +512,7 @@ def test_tasks_params_are_the_same(): y1 = f * p1 y2 = f + p1 - mtl_backward(tensors=[y1, y2], features=[f]) + mtl_backward([y1, y2], features=[f]) assert_grad_close(p1, 
f + 1) @@ -534,7 +534,7 @@ def test_task_params_is_subset_of_other_task_params(): y1 = f * p1 y2 = y1 * p2 - mtl_backward(tensors=[y1, y2], features=[f], retain_graph=True) + mtl_backward([y1, y2], features=[f], retain_graph=True) assert_grad_close(p2, y1) assert_grad_close(p1, p2 * f + f) @@ -559,7 +559,7 @@ def test_shared_params_overlapping_with_tasks_params_fails(): with raises(ValueError): mtl_backward( - tensors=[y1, y2], + [y1, y2], features=[f], tasks_params=[[p1], [p0, p2]], # Problem: p0 is also shared shared_params=[p0], @@ -582,7 +582,7 @@ def test_default_shared_params_overlapping_with_default_tasks_params_fails(): with raises(ValueError): mtl_backward( - tensors=[y1, y2], + [y1, y2], features=[f], ) @@ -607,7 +607,7 @@ def test_repeated_losses(): with raises(ValueError): losses = [y1, y1, y2] - mtl_backward(tensors=losses, features=[f1, f2], retain_graph=True) + mtl_backward(losses, features=[f1, f2], retain_graph=True) def test_repeated_features(): @@ -630,7 +630,7 @@ def test_repeated_features(): with raises(ValueError): features = [f1, f1, f2] - mtl_backward(tensors=[y1, y2], features=features) + mtl_backward([y1, y2], features=features) def test_repeated_shared_params(): @@ -654,7 +654,7 @@ def test_repeated_shared_params(): g2 = grad([y2], [p2], retain_graph=True)[0] shared_params = [p0, p0] - mtl_backward(tensors=[y1, y2], features=[f1, f2], shared_params=shared_params) + mtl_backward([y1, y2], features=[f1, f2], shared_params=shared_params) assert_jac_close(p0, J0) assert_grad_close(p1, g1) @@ -682,7 +682,7 @@ def test_repeated_task_params(): g2 = grad([y2], [p2], retain_graph=True)[0] tasks_params = [[p1, p1], [p2]] - mtl_backward(tensors=[y1, y2], features=[f1, f2], tasks_params=tasks_params) + mtl_backward([y1, y2], features=[f1, f2], tasks_params=tasks_params) assert_jac_close(p0, J0) assert_grad_close(p1, g1) @@ -711,7 +711,7 @@ def test_grad_tensors_value_is_correct(): grad_y3 = torch.randn_like(y3) mtl_backward( - tensors=[y1, y2, y3], + [y1, y2, y3], features=f, grad_tensors=[grad_y1, grad_y2, grad_y3], ) @@ -741,7 +741,7 @@ def test_grad_tensors_length_mismatch(): match=r"`grad_tensors` should have the same length as `tensors`\. \(got 1 and 2\)", ): mtl_backward( - tensors=[y1, y2], + [y1, y2], features=f, grad_tensors=[torch.ones_like(y1)], ) @@ -767,7 +767,7 @@ def test_grad_tensors_shape_mismatch(): match=r"Shape mismatch: `grad_tensors\[0\]` has shape .* but `tensors\[0\]` has shape .*\.", ): mtl_backward( - tensors=[y1, y2], + [y1, y2], features=f, grad_tensors=[ones_(2), torch.ones_like(y2)], # shape (2,) != shape () of y1 )
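For readers migrating call sites, here is a minimal end-to-end sketch (not part of the patch) of the new calling conventions, modelled on the documentation examples touched above. The import paths and the toy two-head model are assumptions for illustration only.

```python
# Sketch of the post-change call conventions; import paths are assumed to match the
# documented examples, and the toy two-head model below is purely illustrative.
import torch
from torch.nn import Linear, MSELoss

from torchjd import jac_to_grad, mtl_backward  # assumed public import path
from torchjd.aggregation import UPGrad

shared_module = Linear(10, 5)
head1, head2 = Linear(5, 1), Linear(5, 1)
params = [*shared_module.parameters(), *head1.parameters(), *head2.parameters()]
optimizer = torch.optim.SGD(params, lr=0.1)
loss_fn = MSELoss()
aggregator = UPGrad()

inputs = torch.randn(8, 10)
target1, target2 = torch.randn(8, 1), torch.randn(8, 1)

features = shared_module(inputs)
loss1 = loss_fn(head1(features), target1)
loss2 = loss_fn(head2(features), target2)

# Before: mtl_backward(tensors=[loss1, loss2], features=features)
# After: `tensors` is positional-only; `features` may stay a keyword argument.
mtl_backward([loss1, loss2], features=features)
jac_to_grad(shared_module.parameters(), aggregator)
optimizer.step()
optimizer.zero_grad()
```

The same pattern applies to `backward(losses)` and to calling an aggregator directly as `aggregator(matrix)`: the leading tensor argument becomes positional-only, while configuration-style parameters must be passed by keyword.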