From 159359a7efef8e5b1e70e913302472251a722ecd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 18 Feb 2026 13:07:13 +0100 Subject: [PATCH] chore: Add ARG rule to ruff --- pyproject.toml | 5 +++++ src/torchjd/autogram/_module_hook_manager.py | 12 ++++++------ src/torchjd/autojac/_jac_to_grad.py | 3 +-- src/torchjd/autojac/_transform/_accumulate.py | 4 ++-- src/torchjd/autojac/_transform/_base.py | 6 +++--- src/torchjd/autojac/_transform/_diagonalize.py | 2 +- src/torchjd/autojac/_transform/_differentiate.py | 2 +- src/torchjd/autojac/_transform/_init.py | 4 ++-- src/torchjd/autojac/_transform/_select.py | 2 +- src/torchjd/autojac/_transform/_stack.py | 2 +- tests/unit/autojac/_transform/test_base.py | 4 ++-- tests/unit/autojac/_transform/test_stack.py | 4 ++-- 12 files changed, 27 insertions(+), 23 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a69db94b..86727712 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -133,6 +133,7 @@ select = [ "W", # pycodestyle Warning "I", # isort "UP", # pyupgrade + "ARG", # flake8-unused-arguments "B", # flake8-bugbear "C4", # flake8-comprehensions "FIX", # flake8-fixme @@ -157,6 +158,10 @@ ignore = [ "COM812", # Trailing comma missing (conflicts with formatter, see https://github.com/astral-sh/ruff/issues/9216) ] +[tool.ruff.lint.per-file-ignores] +"**/conftest.py" = ["ARG"] # Can't change argument names in the functions pytest expects +"tests/doc/test_rst.py" = ["ARG"] # For the lightning example + [tool.ruff.lint.isort] combine-as-imports = true diff --git a/src/torchjd/autogram/_module_hook_manager.py b/src/torchjd/autogram/_module_hook_manager.py index ef48b784..e2958d93 100644 --- a/src/torchjd/autogram/_module_hook_manager.py +++ b/src/torchjd/autogram/_module_hook_manager.py @@ -101,7 +101,7 @@ def __init__( def __call__( self, - module: nn.Module, + _module: nn.Module, args: tuple[PyTree, ...], kwargs: dict[str, PyTree], outputs: PyTree, @@ -157,11 +157,11 @@ class AutogramNode(torch.autograd.Function): @staticmethod def forward( - gramian_accumulation_phase: BoolRef, - gramian_computer: GramianComputer, - args: tuple[PyTree, ...], - kwargs: dict[str, PyTree], - gramian_accumulator: GramianAccumulator, + _gramian_accumulation_phase: BoolRef, + _gramian_computer: GramianComputer, + _args: tuple[PyTree, ...], + _kwargs: dict[str, PyTree], + _gramian_accumulator: GramianAccumulator, *rg_tensors: Tensor, ) -> tuple[Tensor, ...]: return tuple(t.detach() for t in rg_tensors) diff --git a/src/torchjd/autojac/_jac_to_grad.py b/src/torchjd/autojac/_jac_to_grad.py index 9e46caa0..63f97dc6 100644 --- a/src/torchjd/autojac/_jac_to_grad.py +++ b/src/torchjd/autojac/_jac_to_grad.py @@ -75,7 +75,7 @@ def jac_to_grad( jacobian_matrix = _unite_jacobians(jacobians) gradient_vector = aggregator(jacobian_matrix) - gradients = _disunite_gradient(gradient_vector, jacobians, tensors_) + gradients = _disunite_gradient(gradient_vector, tensors_) accumulate_grads(tensors_, gradients) @@ -87,7 +87,6 @@ def _unite_jacobians(jacobians: list[Tensor]) -> Tensor: def _disunite_gradient( gradient_vector: Tensor, - jacobians: list[Tensor], tensors: list[TensorWithJac], ) -> list[Tensor]: gradient_vectors = gradient_vector.split([t.numel() for t in tensors]) diff --git a/src/torchjd/autojac/_transform/_accumulate.py b/src/torchjd/autojac/_transform/_accumulate.py index a61c52da..f2ba25f1 100644 --- a/src/torchjd/autojac/_transform/_accumulate.py +++ b/src/torchjd/autojac/_transform/_accumulate.py @@ -18,7 +18,7 @@ def __call__(self, gradients: TensorDict, /) -> TensorDict: accumulate_grads(gradients.keys(), gradients.values()) return {} - def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]: + def check_keys(self, _input_keys: set[Tensor], /) -> set[Tensor]: return set() @@ -35,5 +35,5 @@ def __call__(self, jacobians: TensorDict, /) -> TensorDict: accumulate_jacs(jacobians.keys(), jacobians.values()) return {} - def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]: + def check_keys(self, _input_keys: set[Tensor], /) -> set[Tensor]: return set() diff --git a/src/torchjd/autojac/_transform/_base.py b/src/torchjd/autojac/_transform/_base.py index fbbbdcd6..579b845c 100644 --- a/src/torchjd/autojac/_transform/_base.py +++ b/src/torchjd/autojac/_transform/_base.py @@ -43,7 +43,7 @@ def __call__(self, input: TensorDict, /) -> TensorDict: """Applies the transform to the input.""" @abstractmethod - def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]: + def check_keys(self, input_keys: set[Tensor], /) -> set[Tensor]: """ Checks that the provided input_keys satisfy the transform's requirements and returns the corresponding output keys for recursion. @@ -78,7 +78,7 @@ def __call__(self, input: TensorDict, /) -> TensorDict: intermediate = self.inner(input) return self.outer(intermediate) - def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]: + def check_keys(self, input_keys: set[Tensor], /) -> set[Tensor]: intermediate_keys = self.inner.check_keys(input_keys) output_keys = self.outer.check_keys(intermediate_keys) return output_keys @@ -111,7 +111,7 @@ def __call__(self, tensor_dict: TensorDict, /) -> TensorDict: union |= transform(tensor_dict) return union - def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]: + def check_keys(self, input_keys: set[Tensor], /) -> set[Tensor]: output_keys_list = [key for t in self.transforms for key in t.check_keys(input_keys)] output_keys = set(output_keys_list) diff --git a/src/torchjd/autojac/_transform/_diagonalize.py b/src/torchjd/autojac/_transform/_diagonalize.py index 182306de..7954d7ce 100644 --- a/src/torchjd/autojac/_transform/_diagonalize.py +++ b/src/torchjd/autojac/_transform/_diagonalize.py @@ -69,7 +69,7 @@ def __call__(self, tensors: TensorDict, /) -> TensorDict: } return diagonalized_tensors - def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]: + def check_keys(self, input_keys: set[Tensor], /) -> set[Tensor]: if not set(self.key_order) == input_keys: raise RequirementError( f"The input_keys must match the key_order. Found input_keys {input_keys} and" diff --git a/src/torchjd/autojac/_transform/_differentiate.py b/src/torchjd/autojac/_transform/_differentiate.py index 18117f3c..1ce26438 100644 --- a/src/torchjd/autojac/_transform/_differentiate.py +++ b/src/torchjd/autojac/_transform/_differentiate.py @@ -55,7 +55,7 @@ def _differentiate(self, tensor_outputs: Sequence[Tensor], /) -> tuple[Tensor, . tensor_outputs should be. """ - def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]: + def check_keys(self, input_keys: set[Tensor], /) -> set[Tensor]: outputs = set(self.outputs) if not outputs == input_keys: raise RequirementError( diff --git a/src/torchjd/autojac/_transform/_init.py b/src/torchjd/autojac/_transform/_init.py index 26042979..50833032 100644 --- a/src/torchjd/autojac/_transform/_init.py +++ b/src/torchjd/autojac/_transform/_init.py @@ -16,10 +16,10 @@ class Init(Transform): def __init__(self, values: AbstractSet[Tensor]): self.values = values - def __call__(self, input: TensorDict, /) -> TensorDict: + def __call__(self, _input: TensorDict, /) -> TensorDict: return {value: torch.ones_like(value) for value in self.values} - def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]: + def check_keys(self, input_keys: set[Tensor], /) -> set[Tensor]: if not input_keys == set(): raise RequirementError( f"The input_keys should be the empty set. Found input_keys {input_keys}.", diff --git a/src/torchjd/autojac/_transform/_select.py b/src/torchjd/autojac/_transform/_select.py index 1575ecf6..b2e45caa 100644 --- a/src/torchjd/autojac/_transform/_select.py +++ b/src/torchjd/autojac/_transform/_select.py @@ -19,7 +19,7 @@ def __call__(self, tensor_dict: TensorDict, /) -> TensorDict: output = {key: tensor_dict[key] for key in self.keys} return type(tensor_dict)(output) - def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]: + def check_keys(self, input_keys: set[Tensor], /) -> set[Tensor]: keys = set(self.keys) if not keys.issubset(input_keys): raise RequirementError( diff --git a/src/torchjd/autojac/_transform/_stack.py b/src/torchjd/autojac/_transform/_stack.py index 1a3fc2ad..a4152afc 100644 --- a/src/torchjd/autojac/_transform/_stack.py +++ b/src/torchjd/autojac/_transform/_stack.py @@ -28,7 +28,7 @@ def __call__(self, input: TensorDict, /) -> TensorDict: result = _stack(results) return result - def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]: + def check_keys(self, input_keys: set[Tensor], /) -> set[Tensor]: return {key for transform in self.transforms for key in transform.check_keys(input_keys)} diff --git a/tests/unit/autojac/_transform/test_base.py b/tests/unit/autojac/_transform/test_base.py index 5da475e6..254147bd 100644 --- a/tests/unit/autojac/_transform/test_base.py +++ b/tests/unit/autojac/_transform/test_base.py @@ -17,12 +17,12 @@ def __init__(self, required_keys: set[Tensor], output_keys: set[Tensor]): def __str__(self): return "T" - def __call__(self, input: TensorDict, /) -> TensorDict: + def __call__(self, _input: TensorDict, /) -> TensorDict: # Ignore the input, create a dictionary with the right keys as an output. output_dict = {key: empty_(0) for key in self._output_keys} return output_dict - def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]: + def check_keys(self, input_keys: set[Tensor], /) -> set[Tensor]: # Arbitrary requirement for testing purposes. if not input_keys == self._required_keys: raise RequirementError() diff --git a/tests/unit/autojac/_transform/test_stack.py b/tests/unit/autojac/_transform/test_stack.py index fc2cdf7a..35e617d1 100644 --- a/tests/unit/autojac/_transform/test_stack.py +++ b/tests/unit/autojac/_transform/test_stack.py @@ -15,10 +15,10 @@ class FakeGradientsTransform(Transform): def __init__(self, keys: Iterable[Tensor]): self.keys = set(keys) - def __call__(self, input: TensorDict, /) -> TensorDict: + def __call__(self, _input: TensorDict, /) -> TensorDict: return {key: torch.ones_like(key) for key in self.keys} - def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]: + def check_keys(self, _input_keys: set[Tensor], /) -> set[Tensor]: return self.keys