Skip to content

Commit 18b6df0

Browse files
authored
chore: Add ARG rule to ruff (#577)
* Add ARG rule to ruff * Add [tool.ruff.lint.per-file-ignores] to ignore the ARG rule in `conftest.py` and `test_rst.py` * Add `_` in front of unused parameter names that we have to keep * Remove unused `jacobians` parameter of `_disunite_gradient` * Make the `input_keys` parameter of `check_keys` positional-only, so that we can rename it in subclasses
1 parent 3ff9a8b commit 18b6df0

File tree

12 files changed

+27
-23
lines changed

12 files changed

+27
-23
lines changed

pyproject.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ select = [
133133
"W", # pycodestyle Warning
134134
"I", # isort
135135
"UP", # pyupgrade
136+
"ARG", # flake8-unused-arguments
136137
"B", # flake8-bugbear
137138
"C4", # flake8-comprehensions
138139
"FIX", # flake8-fixme
@@ -157,6 +158,10 @@ ignore = [
157158
"COM812", # Trailing comma missing (conflicts with formatter, see https://github.com/astral-sh/ruff/issues/9216)
158159
]
159160

161+
[tool.ruff.lint.per-file-ignores]
162+
"**/conftest.py" = ["ARG"] # Can't change argument names in the functions pytest expects
163+
"tests/doc/test_rst.py" = ["ARG"] # For the lightning example
164+
160165
[tool.ruff.lint.isort]
161166
combine-as-imports = true
162167

src/torchjd/autogram/_module_hook_manager.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def __init__(
101101

102102
def __call__(
103103
self,
104-
module: nn.Module,
104+
_module: nn.Module,
105105
args: tuple[PyTree, ...],
106106
kwargs: dict[str, PyTree],
107107
outputs: PyTree,
@@ -157,11 +157,11 @@ class AutogramNode(torch.autograd.Function):
157157

158158
@staticmethod
159159
def forward(
160-
gramian_accumulation_phase: BoolRef,
161-
gramian_computer: GramianComputer,
162-
args: tuple[PyTree, ...],
163-
kwargs: dict[str, PyTree],
164-
gramian_accumulator: GramianAccumulator,
160+
_gramian_accumulation_phase: BoolRef,
161+
_gramian_computer: GramianComputer,
162+
_args: tuple[PyTree, ...],
163+
_kwargs: dict[str, PyTree],
164+
_gramian_accumulator: GramianAccumulator,
165165
*rg_tensors: Tensor,
166166
) -> tuple[Tensor, ...]:
167167
return tuple(t.detach() for t in rg_tensors)

src/torchjd/autojac/_jac_to_grad.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def jac_to_grad(
7575

7676
jacobian_matrix = _unite_jacobians(jacobians)
7777
gradient_vector = aggregator(jacobian_matrix)
78-
gradients = _disunite_gradient(gradient_vector, jacobians, tensors_)
78+
gradients = _disunite_gradient(gradient_vector, tensors_)
7979
accumulate_grads(tensors_, gradients)
8080

8181

@@ -87,7 +87,6 @@ def _unite_jacobians(jacobians: list[Tensor]) -> Tensor:
8787

8888
def _disunite_gradient(
8989
gradient_vector: Tensor,
90-
jacobians: list[Tensor],
9190
tensors: list[TensorWithJac],
9291
) -> list[Tensor]:
9392
gradient_vectors = gradient_vector.split([t.numel() for t in tensors])

src/torchjd/autojac/_transform/_accumulate.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def __call__(self, gradients: TensorDict, /) -> TensorDict:
1818
accumulate_grads(gradients.keys(), gradients.values())
1919
return {}
2020

21-
def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
21+
def check_keys(self, _input_keys: set[Tensor], /) -> set[Tensor]:
2222
return set()
2323

2424

@@ -35,5 +35,5 @@ def __call__(self, jacobians: TensorDict, /) -> TensorDict:
3535
accumulate_jacs(jacobians.keys(), jacobians.values())
3636
return {}
3737

38-
def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
38+
def check_keys(self, _input_keys: set[Tensor], /) -> set[Tensor]:
3939
return set()

src/torchjd/autojac/_transform/_base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def __call__(self, input: TensorDict, /) -> TensorDict:
4343
"""Applies the transform to the input."""
4444

4545
@abstractmethod
46-
def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
46+
def check_keys(self, input_keys: set[Tensor], /) -> set[Tensor]:
4747
"""
4848
Checks that the provided input_keys satisfy the transform's requirements and returns the
4949
corresponding output keys for recursion.
@@ -78,7 +78,7 @@ def __call__(self, input: TensorDict, /) -> TensorDict:
7878
intermediate = self.inner(input)
7979
return self.outer(intermediate)
8080

81-
def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
81+
def check_keys(self, input_keys: set[Tensor], /) -> set[Tensor]:
8282
intermediate_keys = self.inner.check_keys(input_keys)
8383
output_keys = self.outer.check_keys(intermediate_keys)
8484
return output_keys
@@ -111,7 +111,7 @@ def __call__(self, tensor_dict: TensorDict, /) -> TensorDict:
111111
union |= transform(tensor_dict)
112112
return union
113113

114-
def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
114+
def check_keys(self, input_keys: set[Tensor], /) -> set[Tensor]:
115115
output_keys_list = [key for t in self.transforms for key in t.check_keys(input_keys)]
116116
output_keys = set(output_keys_list)
117117

src/torchjd/autojac/_transform/_diagonalize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def __call__(self, tensors: TensorDict, /) -> TensorDict:
6969
}
7070
return diagonalized_tensors
7171

72-
def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
72+
def check_keys(self, input_keys: set[Tensor], /) -> set[Tensor]:
7373
if not set(self.key_order) == input_keys:
7474
raise RequirementError(
7575
f"The input_keys must match the key_order. Found input_keys {input_keys} and"

src/torchjd/autojac/_transform/_differentiate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def _differentiate(self, tensor_outputs: Sequence[Tensor], /) -> tuple[Tensor, .
5555
tensor_outputs should be.
5656
"""
5757

58-
def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
58+
def check_keys(self, input_keys: set[Tensor], /) -> set[Tensor]:
5959
outputs = set(self.outputs)
6060
if not outputs == input_keys:
6161
raise RequirementError(

src/torchjd/autojac/_transform/_init.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@ class Init(Transform):
1616
def __init__(self, values: AbstractSet[Tensor]):
1717
self.values = values
1818

19-
def __call__(self, input: TensorDict, /) -> TensorDict:
19+
def __call__(self, _input: TensorDict, /) -> TensorDict:
2020
return {value: torch.ones_like(value) for value in self.values}
2121

22-
def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
22+
def check_keys(self, input_keys: set[Tensor], /) -> set[Tensor]:
2323
if not input_keys == set():
2424
raise RequirementError(
2525
f"The input_keys should be the empty set. Found input_keys {input_keys}.",

src/torchjd/autojac/_transform/_select.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def __call__(self, tensor_dict: TensorDict, /) -> TensorDict:
1919
output = {key: tensor_dict[key] for key in self.keys}
2020
return type(tensor_dict)(output)
2121

22-
def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
22+
def check_keys(self, input_keys: set[Tensor], /) -> set[Tensor]:
2323
keys = set(self.keys)
2424
if not keys.issubset(input_keys):
2525
raise RequirementError(

src/torchjd/autojac/_transform/_stack.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def __call__(self, input: TensorDict, /) -> TensorDict:
2828
result = _stack(results)
2929
return result
3030

31-
def check_keys(self, input_keys: set[Tensor]) -> set[Tensor]:
31+
def check_keys(self, input_keys: set[Tensor], /) -> set[Tensor]:
3232
return {key for transform in self.transforms for key in transform.check_keys(input_keys)}
3333

3434

0 commit comments

Comments
 (0)