From 4d05c6d11a8d0a41be9da17a3c9d913a0377f5a5 Mon Sep 17 00:00:00 2001 From: Claude Date: Fri, 24 Apr 2026 04:44:36 +0000 Subject: [PATCH 1/6] Match rules against canonical qualnames across all import styles Previously the linter only recognised trio/anyio/asyncio-related calls when they appeared exactly as `trio.open_nursery`, `anyio.create_task_group`, etc. Aliased imports (`import trio as t`), `from` imports (`from trio import open_nursery`), and aliased-from imports (`from trio import open_nursery as on`) silently escaped detection. This adds a pair of utility visitors (VisitorImportTracker / _cst) that build a local-name -> canonical-dotted-qualname map, a pair of helpers (resolve_canonical_ast / _cst) and base-class shortcut `canonical_name()`, and threads an `imports=` keyword through the existing matcher helpers (get_matching_call[_cst], fnmatch_qualified_name[_cst], with_has_call, calls_any_of, critical_except). The tracker only records module-level imports, so function-local imports don't leak into sibling scopes. Existing visitors are updated to pass `self.imports` to those helpers, and ASYNC105/ASYNC115/ASYNC118/ASYNC2xx/ASYNC300 etc. now match via canonical qualname instead of the literal spelling. ASYNC106 was a workaround for the old limitation; it's now disabled by default but left in place for projects that still want to enforce the `import trio` style. Closes python-trio/flake8-async#132. https://claude.ai/code/session_018Hc9rcA31SnXcN8Ee5vVwH --- docs/changelog.rst | 2 + docs/rules.rst | 5 +- flake8_async/runner.py | 6 + flake8_async/visitors/flake8asyncvisitor.py | 26 +++ flake8_async/visitors/helpers.py | 173 +++++++++++++++--- flake8_async/visitors/visitor101.py | 2 +- flake8_async/visitors/visitor102_120.py | 22 ++- flake8_async/visitors/visitor103_104.py | 2 +- flake8_async/visitors/visitor105.py | 7 +- flake8_async/visitors/visitor111.py | 12 +- flake8_async/visitors/visitor118.py | 7 + flake8_async/visitors/visitor123.py | 3 + flake8_async/visitors/visitor2xx.py | 65 +++++-- flake8_async/visitors/visitor91x.py | 22 ++- flake8_async/visitors/visitor_utility.py | 132 ++++++++++++- flake8_async/visitors/visitors.py | 76 +++++--- .../exception_suppress_context_manager.py | 9 +- ...exception_suppress_context_manager.py.diff | 8 + tests/eval_files/async110.py | 4 +- tests/eval_files/async111.py | 4 +- tests/eval_files/async112.py | 4 +- .../eval_files/async112_canonical_qualname.py | 30 +++ tests/eval_files/async115.py | 5 +- .../eval_files/async115_canonical_qualname.py | 38 ++++ tests/eval_files/async251.py | 5 +- tests/eval_files/async300.py | 3 +- .../exception_suppress_context_manager.py | 8 +- tests/test_config_and_args.py | 5 +- 28 files changed, 579 insertions(+), 106 deletions(-) create mode 100644 tests/eval_files/async112_canonical_qualname.py create mode 100644 tests/eval_files/async115_canonical_qualname.py diff --git a/docs/changelog.rst b/docs/changelog.rst index 4df8d045..3ad8e3af 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -6,6 +6,8 @@ Changelog Unreleased ========== +- Rules now resolve function/class references against the canonical qualname, so checks fire regardless of how a symbol was imported (``import trio``, ``import trio as t``, ``from trio import open_nursery``, ``from trio import open_nursery as on``, etc.). Only module-level imports are tracked; function-local imports are still considered local. `(issue #132) `_ +- :ref:`ASYNC106 ` is now disabled by default: with the canonical-qualname resolution above, the rule is no longer required for the linter to work correctly. Re-enable it explicitly if you still want to enforce the ``import trio`` style. - Autofix for :ref:`ASYNC910 ` / :ref:`ASYNC911 ` no longer inserts checkpoints inside ``except`` clauses (which would trigger :ref:`ASYNC120 `); instead the checkpoint is added at the top of the function or of the enclosing loop. `(issue #403) `_ - :ref:`ASYNC910 ` and :ref:`ASYNC911 ` now accept ``__aenter__`` / ``__aexit__`` methods when the partner method provides the checkpoint, or when only one of the two is defined on a class that inherits from another class (charitably assuming the partner is inherited and contains a checkpoint). `(issue #441) `_ - :ref:`ASYNC300 ` no longer triggers when the result of ``asyncio.create_task()`` is returned from a function. `(issue #398) `_ diff --git a/docs/rules.rst b/docs/rules.rst index ba25950e..1ee50307 100644 --- a/docs/rules.rst +++ b/docs/rules.rst @@ -42,7 +42,10 @@ ASYNC105 : missing-await This is only supported with trio functions, but you can get similar functionality with a type-checker. ASYNC106 : bad-async-library-import - trio/anyio/asyncio must be imported with ``import xxx`` for the linter to work. + trio/anyio/asyncio should be imported with ``import xxx`` for consistency. + As of the canonical-qualname refactor this rule is no longer required for the + linter to work, and is therefore disabled by default -- enable it explicitly + if you want to enforce the style. ASYNC109 : async-function-with-timeout Async function definition with a ``timeout`` parameter. diff --git a/flake8_async/runner.py b/flake8_async/runner.py index 38ff3f53..e0fbc1e8 100644 --- a/flake8_async/runner.py +++ b/flake8_async/runner.py @@ -37,6 +37,12 @@ class SharedState: library: tuple[str, ...] = () typed_calls: dict[str, str] = field(default_factory=dict[str, str]) variables: dict[str, str] = field(default_factory=dict[str, str]) + # Maps a locally-bound name to its canonical dotted qualname, populated by + # VisitorImportTracker/VisitorImportTracker_cst. Used by helpers so that + # rules can be written against canonical qualnames and match regardless of + # how things were imported (bare `import x`, `import x as y`, + # `from x import y`, or `from x import y as z`). + imports: dict[str, str] = field(default_factory=dict[str, str]) class __CommonRunner: diff --git a/flake8_async/visitors/flake8asyncvisitor.py b/flake8_async/visitors/flake8asyncvisitor.py index 77742814..d6128c95 100644 --- a/flake8_async/visitors/flake8asyncvisitor.py +++ b/flake8_async/visitors/flake8asyncvisitor.py @@ -53,6 +53,19 @@ def variables(self, value: dict[str, str]) -> None: self.__state.variables.clear() self.__state.variables.update(value) + @property + def imports(self) -> dict[str, str]: + return self.__state.imports + + def canonical_name(self, node: ast.AST) -> str | None: + """Resolve `node` to a dotted canonical qualname, consulting imports. + + See ``resolve_canonical_ast`` for semantics. + """ + from .helpers import resolve_canonical_ast + + return resolve_canonical_ast(node, self.__state.imports) + def visit(self, node: ast.AST): """Visit a node.""" # construct visitor for this node type @@ -170,6 +183,19 @@ def __init__(self, shared_state: SharedState): self.options = self.__state.options self.noqas = self.__state.noqas + @property + def imports(self) -> dict[str, str]: + return self.__state.imports + + def canonical_name(self, node: cst.CSTNode) -> str | None: + """Resolve `node` to a dotted canonical qualname, consulting imports. + + See ``resolve_canonical_cst`` for semantics. + """ + from .helpers import resolve_canonical_cst + + return resolve_canonical_cst(node, self.__state.imports) + def get_state(self, *attrs: str, copy: bool = False) -> dict[str, Any]: # require attrs, since we inherit a *ton* of stuff which we don't want to copy assert attrs diff --git a/flake8_async/visitors/helpers.py b/flake8_async/visitors/helpers.py index c9d4f544..26e42098 100644 --- a/flake8_async/visitors/helpers.py +++ b/flake8_async/visitors/helpers.py @@ -26,7 +26,7 @@ ) if TYPE_CHECKING: - from collections.abc import Iterable, Iterator, Sequence + from collections.abc import Iterable, Iterator, Mapping, Sequence from .flake8asyncvisitor import ( Flake8AsyncVisitor, @@ -101,15 +101,24 @@ def has_decorator(node: ast.FunctionDef | ast.AsyncFunctionDef, *names: str): # matches the fully qualified name against fnmatch pattern # used to match decorators and methods to user-supplied patterns # used in 910/911 and 200 -def fnmatch_qualified_name(name_list: list[ast.expr], *patterns: str) -> str | None: +def fnmatch_qualified_name( + name_list: Iterable[ast.expr], + *patterns: str, + imports: Mapping[str, str] | None = None, +) -> str | None: for name in name_list: if isinstance(name, ast.Call): name = name.func - qualified_name = ast.unparse(name) + qualified_names = [ast.unparse(name)] + if imports is not None: + canonical = resolve_canonical_ast(name, imports) + if canonical is not None and canonical not in qualified_names: + qualified_names.append(canonical) for pattern in patterns: # strip leading "@"s for when we're working with decorators - if fnmatch(qualified_name, pattern.lstrip("@")): + stripped = pattern.lstrip("@") + if any(fnmatch(qn, stripped) for qn in qualified_names): return pattern return None @@ -117,13 +126,24 @@ def fnmatch_qualified_name(name_list: list[ast.expr], *patterns: str) -> str | N def fnmatch_qualified_name_cst( name_list: Iterable[cst.Decorator | cst.Call | cst.Attribute | cst.Name], *patterns: str, + imports: Mapping[str, str] | None = None, ) -> str | None: for name in name_list: - qualified_name = get_full_name_for_node_or_raise(name) + qualified_names = [get_full_name_for_node_or_raise(name)] + if imports is not None: + node: cst.CSTNode = name + if isinstance(node, cst.Decorator): + node = node.decorator + if isinstance(node, cst.Call): + node = node.func + canonical = resolve_canonical_cst(node, imports) + if canonical is not None and canonical not in qualified_names: + qualified_names.append(canonical) for pattern in patterns: # strip leading "@"s for when we're working with decorators - if fnmatch(qualified_name, pattern.lstrip("@")): + stripped = pattern.lstrip("@") + if any(fnmatch(qn, stripped) for qn in qualified_names): return pattern return None @@ -240,7 +260,9 @@ def iter_guaranteed_once_cst(iterable: cst.BaseExpression) -> bool: # used in 102, 103 and 104 -def critical_except(node: ast.ExceptHandler) -> Statement | None: +def critical_except( + node: ast.ExceptHandler, imports: Mapping[str, str] | None = None +) -> Statement | None: def has_exception(node: ast.expr) -> str | None: name = ast.unparse(node) if name in ( @@ -253,6 +275,29 @@ def has_exception(node: ast.expr) -> str | None: "CancelledError", ): return name + # Match via canonical qualname, so `import trio as t; except t.Cancelled`, + # `from trio import Cancelled`, `from asyncio import CancelledError as CE`, etc. + # also get picked up. The non-call forms (`except anyio.get_cancelled_exc_class:` + # and `except ...(...)` with args) are type-errors the existing code + # intentionally ignores, so only match zero-arg calls for the dynamic form. + if imports is not None: + if isinstance(node, ast.Call): + if node.args or node.keywords: + return None + canonical = resolve_canonical_ast(node.func, imports) + else: + canonical = resolve_canonical_ast(node, imports) + if canonical == "trio.Cancelled" and not isinstance(node, ast.Call): + return "trio.Cancelled" + if canonical == "anyio.get_cancelled_exc_class" and isinstance( + node, ast.Call + ): + return "anyio.get_cancelled_exc_class()" + if canonical in ( + "asyncio.exceptions.CancelledError", + "asyncio.CancelledError", + ) and not isinstance(node, ast.Call): + return "asyncio.exceptions.CancelledError" return None name: str | None = None @@ -300,38 +345,95 @@ def __str__(self) -> str: return self.base + "." + self.name +# Resolve an ast Name/Attribute to a canonical dotted qualname, using the `imports` +# map (local-name -> canonical dotted qualname). Returns None for non-name nodes +# (e.g. subscripts, calls). If the root-most Name isn't in `imports`, we fall back +# to using the literal identifier text — so `trio.open_nursery()` without any +# imports still resolves to "trio.open_nursery", preserving prior behaviour. +def resolve_canonical_ast(node: ast.AST, imports: Mapping[str, str]) -> str | None: + if isinstance(node, ast.Name): + return imports.get(node.id, node.id) + if isinstance(node, ast.Attribute): + prefix = resolve_canonical_ast(node.value, imports) + if prefix is None: + return None + return f"{prefix}.{node.attr}" + if isinstance(node, ast.Call): + return resolve_canonical_ast(node.func, imports) + return None + + +def resolve_canonical_cst( + node: cst.CSTNode, imports: Mapping[str, str] +) -> str | None: + if isinstance(node, cst.Name): + return imports.get(node.value, node.value) + if isinstance(node, cst.Attribute): + prefix = resolve_canonical_cst(node.value, imports) + if prefix is None: + return None + return f"{prefix}.{node.attr.value}" + if isinstance(node, cst.Call): + return resolve_canonical_cst(node.func, imports) + return None + + # convenience function used in a lot of visitors def get_matching_call( - node: ast.AST, *names: str, base: Iterable[str] = ("trio", "anyio") + node: ast.AST, + *names: str, + base: Iterable[str] = ("trio", "anyio"), + imports: Mapping[str, str] | None = None, ) -> MatchingCall[ast.Call] | None: if isinstance(base, str): base = (base,) + if not isinstance(node, ast.Call): + return None + # Fast path: matches the existing structural check. if ( - isinstance(node, ast.Call) - and isinstance(node.func, ast.Attribute) + isinstance(node.func, ast.Attribute) and isinstance(node.func.value, ast.Name) and node.func.value.id in base and node.func.attr in names ): return MatchingCall(node, node.func.attr, node.func.value.id) + # Canonical-qualname path: works regardless of how things got imported + # (e.g. `import trio as t`, `from trio import open_nursery [as x]`). + if imports is not None: + canonical = resolve_canonical_ast(node.func, imports) + if canonical is None: + return None + for b in base: + for n in names: + if canonical == f"{b}.{n}": + return MatchingCall(node, n, b) return None # ___ CST helpers ___ def get_matching_call_cst( - node: cst.CSTNode, *names: str, base: Iterable[str] = ("trio", "anyio") + node: cst.CSTNode, + *names: str, + base: Iterable[str] = ("trio", "anyio"), + imports: Mapping[str, str] | None = None, ) -> MatchingCall[cst.Call] | None: if isinstance(base, str): base = (base,) - if ( - isinstance(node, cst.Call) - and isinstance(node.func, cst.Attribute) - and node.func.attr.value in names - and isinstance(node.func.value, (cst.Name, cst.Attribute)) - ): - attr_base = identifier_to_string(node.func.value) - if attr_base is not None and attr_base in base: - return MatchingCall(node, node.func.attr.value, attr_base) + if not isinstance(node, cst.Call): + return None + if isinstance(node.func, cst.Attribute) and node.func.attr.value in names: + if isinstance(node.func.value, (cst.Name, cst.Attribute)): + attr_base = identifier_to_string(node.func.value) + if attr_base is not None and attr_base in base: + return MatchingCall(node, node.func.attr.value, attr_base) + if imports is not None: + canonical = resolve_canonical_cst(node.func, imports) + if canonical is None: + return None + for b in base: + for n in names: + if canonical == f"{b}.{n}": + return MatchingCall(node, n, b) return None @@ -377,7 +479,10 @@ def identifier_to_string(node: cst.CSTNode) -> str | None: def with_has_call( - node: cst.With, *names: str, base: Iterable[str] | str = ("trio", "anyio") + node: cst.With, + *names: str, + base: Iterable[str] | str = ("trio", "anyio"), + imports: Mapping[str, str] | None = None, ) -> list[MatchingCall[cst.Call]]: """Check if a with statement has a matching call, returning a list with matches. @@ -392,15 +497,18 @@ def with_has_call( `with_has_call(node, "bar", "bee", base=("foo", "a.b.c")` matches `foo.bar`, `foo.bee`, `a.b.c.bar`, and `a.b.c.bee`. + When `imports` is passed, matches against the canonical qualname so that + aliased/from-imports are detected as well. """ if isinstance(base, str): base = (base,) + base_tuple = tuple(base) # build matcher, using SaveMatchedNode to save the base and the function name. matcher = m.Call( func=m.Attribute( value=m.SaveMatchedNode( - m.OneOf(*(build_cst_matcher(b) for b in base)), name="base" + m.OneOf(*(build_cst_matcher(b) for b in base_tuple)), name="base" ), attr=m.SaveMatchedNode( oneof_names(*names), @@ -422,10 +530,26 @@ def with_has_call( node=item.item, base=base_string, name=res["function"].value ) ) + continue + if imports is None or not isinstance(item.item, cst.Call): + continue + canonical = resolve_canonical_cst(item.item.func, imports) + if canonical is None: + continue + for b in base_tuple: + for n in names: + if canonical == f"{b}.{n}": + res_list.append(MatchingCall(node=item.item, base=b, name=n)) + break + else: + continue + break return res_list -def calls_any_of(node: cst.With, *qualnames: str) -> bool: +def calls_any_of( + node: cst.With, *qualnames: str, imports: Mapping[str, str] | None = None +) -> bool: """Return True if `node` contains a withitem matching any of `qualnames`. Each `qualname` is a dotted string like ``"trio.open_nursery"`` or @@ -439,7 +563,8 @@ def calls_any_of(node: cst.With, *qualnames: str) -> bool: assert name, f"{qn!r} is not a dotted qualname" by_base[base].append(name) return any( - with_has_call(node, *names, base=base) for base, names in by_base.items() + with_has_call(node, *names, base=base, imports=imports) + for base, names in by_base.items() ) diff --git a/flake8_async/visitors/visitor101.py b/flake8_async/visitors/visitor101.py index fe03f94c..0137f891 100644 --- a/flake8_async/visitors/visitor101.py +++ b/flake8_async/visitors/visitor101.py @@ -76,7 +76,7 @@ def visit_With(self, node: cst.With): self._yield_is_error = ( not self._safe_decorator and not self._yield_is_error - and calls_any_of(node, *_CANCEL_SCOPE_CMS) + and calls_any_of(node, *_CANCEL_SCOPE_CMS, imports=self.imports) ) def leave_With( diff --git a/flake8_async/visitors/visitor102_120.py b/flake8_async/visitors/visitor102_120.py index 83055e1e..ad3e20c3 100644 --- a/flake8_async/visitors/visitor102_120.py +++ b/flake8_async/visitors/visitor102_120.py @@ -96,7 +96,10 @@ def is_safe_aclose_call(self, node: ast.Await) -> bool: return True # allow `trio.aclose_forcefully()` / `anyio.aclose_forcefully()`, # which are specifically designed for cleanup and cancel immediately by design - return get_matching_call(node.value, "aclose_forcefully") is not None + return ( + get_matching_call(node.value, "aclose_forcefully", imports=self.imports) + is not None + ) # trio.lowlevel.cancel_shielded_checkpoint (and the anyio equivalent) are # explicitly a schedule-but-not-cancel point, so they're safe to await @@ -106,7 +109,7 @@ def is_safe_shielded_checkpoint(self, node: ast.Await) -> bool: isinstance(node.value, ast.Call) and not node.value.args and not node.value.keywords - and ast.unparse(node.value.func) + and self.canonical_name(node.value.func) in ( "trio.lowlevel.cancel_shielded_checkpoint", "anyio.lowlevel.cancel_shielded_checkpoint", @@ -133,6 +136,7 @@ def visit_With(self, node: ast.With | ast.AsyncWith): "open_nursery", "create_task_group", *cancel_scope_names, + imports=self.imports, ) if call is None: continue @@ -151,9 +155,17 @@ def visit_AsyncWith(self, node: ast.AsyncWith): # asyncio.TaskGroup() appears to be a source of cancellation when exiting. for item in node.items: if not ( - get_matching_call(item.context_expr, "open_nursery", base="trio") + get_matching_call( + item.context_expr, + "open_nursery", + base="trio", + imports=self.imports, + ) or get_matching_call( - item.context_expr, "create_task_group", base="anyio" + item.context_expr, + "create_task_group", + base="anyio", + imports=self.imports, ) ): self.async_call_checker(node) @@ -193,7 +205,7 @@ def visit_ExceptHandler(self, node: ast.ExceptHandler): self._trio_context_managers = [] self._potential_120 = [] - if self.cancelled_caught or (res := critical_except(node)) is None: + if self.cancelled_caught or (res := critical_except(node, self.imports)) is None: self._critical_scope = Statement("except", node.lineno, node.col_offset) else: self._critical_scope = res diff --git a/flake8_async/visitors/visitor103_104.py b/flake8_async/visitors/visitor103_104.py index 3e234e73..ac63cca6 100644 --- a/flake8_async/visitors/visitor103_104.py +++ b/flake8_async/visitors/visitor103_104.py @@ -75,7 +75,7 @@ def __init__(self, *args: Any, **kwargs: Any): # set self.unraised, and if it's still set after visiting child nodes # then there might be a code path that doesn't re-raise. def visit_ExceptHandler(self, node: ast.ExceptHandler): - marker = critical_except(node) + marker = critical_except(node, self.imports) if marker is None: # not a critical exception handler diff --git a/flake8_async/visitors/visitor105.py b/flake8_async/visitors/visitor105.py index c7e0ddfe..fb63a93e 100644 --- a/flake8_async/visitors/visitor105.py +++ b/flake8_async/visitors/visitor105.py @@ -56,8 +56,11 @@ def visit_Call(self, node: ast.Call): if getattr(node, "awaited", False) or "trio" not in self.library: return - if (name := ast.unparse(node.func)) in trio_async_funcs: - self.error(node, name, "function") + canonical = self.canonical_name(node.func) + if canonical in trio_async_funcs: + # report the canonical qualname so the message is stable regardless of + # how the user imported the function. + self.error(node, canonical, "function") elif isinstance(node.func, ast.Attribute) and node.func.attr == "start": var = ast.unparse(node.func.value) diff --git a/flake8_async/visitors/visitor111.py b/flake8_async/visitors/visitor111.py index 72dca1ff..3cb2c6f0 100644 --- a/flake8_async/visitors/visitor111.py +++ b/flake8_async/visitors/visitor111.py @@ -12,11 +12,13 @@ from collections.abc import Mapping -def is_nursery_like(node: ast.expr) -> bool: +def is_nursery_like( + node: ast.expr, imports: Mapping[str, str] | None = None +) -> bool: return bool( - get_matching_call(node, "open_nursery", base="trio") - or get_matching_call(node, "create_task_group", base="anyio") - or get_matching_call(node, "TaskGroup", base="asyncio") + get_matching_call(node, "open_nursery", base="trio", imports=imports) + or get_matching_call(node, "create_task_group", base="anyio", imports=imports) + or get_matching_call(node, "TaskGroup", base="asyncio", imports=imports) ) @@ -56,7 +58,7 @@ def visit_With(self, node: ast.With | ast.AsyncWith): self.TrioContextManager( item.context_expr.lineno, item.optional_vars.id, - is_nursery_like(item.context_expr), + is_nursery_like(item.context_expr, self.imports), ) ) diff --git a/flake8_async/visitors/visitor118.py b/flake8_async/visitors/visitor118.py index 4066f508..a32171f8 100644 --- a/flake8_async/visitors/visitor118.py +++ b/flake8_async/visitors/visitor118.py @@ -29,6 +29,13 @@ class Visitor118(Flake8AsyncVisitor): def visit_Assign(self, node: ast.Assign | ast.AnnAssign): if node.value is None: return + value = node.value + func_node = value.func if isinstance(value, ast.Call) else value + canonical = self.canonical_name(func_node) + if canonical == "anyio.get_cancelled_exc_class": + self.error(node.value) + return + # Fallback to literal matching for un-imported or unusual forms. name = ast.unparse(node.value) if re.fullmatch(r"(anyio.)?get_cancelled_exc_class(\(\))?", name): self.error(node.value) diff --git a/flake8_async/visitors/visitor123.py b/flake8_async/visitors/visitor123.py index 32fd5a06..860dcbf0 100644 --- a/flake8_async/visitors/visitor123.py +++ b/flake8_async/visitors/visitor123.py @@ -69,6 +69,9 @@ def visit_ExceptHandler(self, node: ast.ExceptHandler): "child_exception_names", copy=True, ) + # ExceptionGroup/BaseExceptionGroup are builtins; we match by literal name + # since they're typically used unqualified (and aliasing the builtins + # themselves is very unusual). if node.name is None or ( not self.try_star and (node.type is None or "ExceptionGroup" not in ast.unparse(node.type)) diff --git a/flake8_async/visitors/visitor2xx.py b/flake8_async/visitors/visitor2xx.py index 91df95d5..4ee5887d 100644 --- a/flake8_async/visitors/visitor2xx.py +++ b/flake8_async/visitors/visitor2xx.py @@ -50,7 +50,9 @@ def visit_Call(self, node: ast.Call): def visit_blocking_call(self, node: ast.Call): blocking_calls = self.options.async200_blocking_calls - if key := fnmatch_qualified_name([node.func], *blocking_calls): + if key := fnmatch_qualified_name( + [node.func], *blocking_calls, imports=self.imports + ): self.error(node, key, blocking_calls[key]) @@ -68,17 +70,19 @@ class Visitor21X(Visitor200): def __init__(self, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) - self.imports: set[str] = set() + # tracked specifically for ASYNC211, which wants to know whether urllib3 + # itself has been imported (any form) before flagging `x.request(...)`. + self.urllib3_imports: set[str] = set() def visit_ImportFrom(self, node: ast.ImportFrom): if node.module == "urllib3": - self.imports.add(node.module) + self.urllib3_imports.add(node.module) def visit_Import(self, node: ast.Import): for name in node.names: if name.name == "urllib3": # Could also save the name.asname for matching - self.imports.add(name.name) + self.urllib3_imports.add(name.name) def visit_blocking_call(self, node: ast.Call): http_methods = { @@ -91,12 +95,23 @@ def visit_blocking_call(self, node: ast.Call): "delete", } func_name = ast.unparse(node.func) + canonical = self.canonical_name(node.func) or func_name for http_package in "requests", "httpx": - if get_matching_call(node, *http_methods | {"request"}, base=http_package): + if get_matching_call( + node, + *http_methods | {"request"}, + base=http_package, + imports=self.imports, + ): self.error(node, func_name, error_code="ASYNC210") return - if func_name in ( + if canonical in ( + "urllib3.request", + "urllib.request.urlopen", + "request.urlopen", + "urlopen", + ) or func_name in ( "urllib3.request", "urllib.request.urlopen", "request.urlopen", @@ -105,7 +120,7 @@ def visit_blocking_call(self, node: ast.Call): self.error(node, func_name, error_code="ASYNC210") elif ( - "urllib3" in self.imports + "urllib3" in self.urllib3_imports and isinstance(node.func, ast.Attribute) and node.func.attr == "request" and node.args @@ -209,22 +224,29 @@ def is_p_wait(arg: ast.expr) -> bool: "getstatusoutput", } - func_name = ast.unparse(node.func) + raw_name = ast.unparse(node.func) + canonical = self.canonical_name(node.func) or raw_name + # What we report to the user. Prefer the raw spelling since that's what + # they wrote, but use canonical for matching so `import os as o; o.popen`, + # `from subprocess import run`, etc. also get detected. + func_name = raw_name error_code: str | None = None - if func_name in ("subprocess.Popen", "os.popen"): + if canonical in ("subprocess.Popen", "os.popen"): error_code = "ASYNC220" - elif func_name in ( + elif canonical in ( "os.system", "os.posix_spawn", "os.posix_spawnp", - ) or get_matching_call(node, *subprocess_calls, base="subprocess"): + ) or get_matching_call( + node, *subprocess_calls, base="subprocess", imports=self.imports + ): error_code = "ASYNC221" - elif re.fullmatch("os.wait([34]|(id)|(pid))?", func_name): + elif re.fullmatch("os.wait([34]|(id)|(pid))?", canonical): error_code = "ASYNC222" - elif re.fullmatch("os.spawn[vl]p?e?", func_name): + elif re.fullmatch("os.spawn[vl]p?e?", canonical): error_code = "ASYNC221" # if mode= is given and not [os.]P_WAIT: ASYNC220 @@ -265,8 +287,8 @@ class Visitor23X(Visitor200): } def visit_Call(self, node: ast.Call): - func_name = ast.unparse(node.func) - if re.fullmatch(r"(trio|anyio)\.wrap_file", func_name) and len(node.args) == 1: + canonical = self.canonical_name(node.func) + if canonical in ("trio.wrap_file", "anyio.wrap_file") and len(node.args) == 1: setattr(node.args[0], "wrapped", True) # noqa: B010 super().visit_Call(node) @@ -274,9 +296,10 @@ def visit_blocking_call(self, node: ast.Call): if getattr(node, "wrapped", False): return func_name = ast.unparse(node.func) - if func_name in ("open", "io.open", "io.open_code"): + canonical = self.canonical_name(node.func) or func_name + if canonical in ("builtins.open", "open", "io.open", "io.open_code"): error_code = "ASYNC230" - elif func_name == "os.fdopen": + elif canonical == "os.fdopen": error_code = "ASYNC231" else: return @@ -381,9 +404,10 @@ def visit_Call(self, node: ast.Call): return error_code = "ASYNC240_asyncio" if self.library == ("asyncio",) else "ASYNC240" func_name = ast.unparse(node.func) + canonical = self.canonical_name(node.func) or func_name if func_name in self.imports_from_ospath: self.error(node, func_name, self.library_str, error_code=error_code) - elif (m := re.fullmatch(r"os\.path\.(?P.*)", func_name)) and m.group( + elif (m := re.fullmatch(r"os\.path\.(?P.*)", canonical)) and m.group( "func" ) in self.os_funcs: self.error(node, m.group("func"), self.library_str, error_code=error_code) @@ -410,13 +434,14 @@ def visit_Call(self, node: ast.Call): if not self.async_function: return func_name = ast.unparse(node.func) - if func_name == "input": + canonical = self.canonical_name(node.func) or func_name + if canonical in ("input", "builtins.input"): error_code = "ASYNC250" if len(self.library) == 1: msg_param = wrappers[self.library_str] else: msg_param = "/".join(wrappers[lib] for lib in self.library) - elif func_name == "time.sleep": + elif canonical == "time.sleep": error_code = "ASYNC251" msg_param = self.library_str else: diff --git a/flake8_async/visitors/visitor91x.py b/flake8_async/visitors/visitor91x.py index 123fba7e..9d0a2429 100644 --- a/flake8_async/visitors/visitor91x.py +++ b/flake8_async/visitors/visitor91x.py @@ -131,7 +131,9 @@ def leave_FunctionDef( ) # ignore functions with no_checkpoint_warning_decorators and not fnmatch_qualified_name_cst( - original_node.decorators, *self.options.no_checkpoint_warning_decorators + original_node.decorators, + *self.options.no_checkpoint_warning_decorators, + imports=self.imports, ) ): self.error(original_node) @@ -649,7 +651,9 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> bool: self.async_function = ( node.asynchronous is not None and not fnmatch_qualified_name_cst( - node.decorators, *self.options.no_checkpoint_warning_decorators + node.decorators, + *self.options.no_checkpoint_warning_decorators, + imports=self.imports, ) ) # only visit subnodes if there is an async function defined inside @@ -864,6 +868,7 @@ def _is_exception_suppressing_context_manager(self, node: cst.With) -> bool: "contextlib.suppress", *self.suppress_imported_as, *self.options.exception_suppress_context_managers, + imports=self.imports, ) is not None ) @@ -881,7 +886,7 @@ def _checkpoint_with(self, node: cst.With, entry: bool): return for item in node.items: - if isinstance(item.item, cst.Call) and identifier_to_string( + if isinstance(item.item, cst.Call) and self.canonical_name( item.item.func ) in ( "trio.open_nursery", @@ -919,7 +924,10 @@ def visit_With_body(self, node: cst.With): for withitem in node.items: self.has_checkpoint_stack.append(ContextManager()) if get_matching_call_cst( - withitem.item, "open_nursery", "create_task_group" + withitem.item, + "open_nursery", + "create_task_group", + imports=self.imports, ): if withitem.asname is not None and isinstance( withitem.asname.name, cst.Name @@ -945,6 +953,7 @@ def visit_With_body(self, node: cst.With): "contextlib.suppress", *self.suppress_imported_as, *self.options.exception_suppress_context_managers, + imports=self.imports, ) is not None ): @@ -955,12 +964,15 @@ def visit_With_body(self, node: cst.With): continue if res := ( - get_matching_call_cst(withitem.item, *cancel_scope_names) + get_matching_call_cst( + withitem.item, *cancel_scope_names, imports=self.imports + ) or get_matching_call_cst( withitem.item, "timeout", "timeout_at", base="asyncio", + imports=self.imports, ) ): # typing issue: https://github.com/Instagram/LibCST/issues/1107 diff --git a/flake8_async/visitors/visitor_utility.py b/flake8_async/visitors/visitor_utility.py index 1e70785a..e1c2ea67 100644 --- a/flake8_async/visitors/visitor_utility.py +++ b/flake8_async/visitors/visitor_utility.py @@ -7,16 +7,16 @@ import re from typing import TYPE_CHECKING, Any, cast +import libcst as cst import libcst.matchers as m from libcst.metadata import PositionProvider from .flake8asyncvisitor import Flake8AsyncVisitor, Flake8AsyncVisitor_cst -from .helpers import utility_visitor, utility_visitor_cst +from .helpers import identifier_to_string, utility_visitor, utility_visitor_cst if TYPE_CHECKING: from re import Match - import libcst as cst from libcst.metadata import CodeRange @@ -155,6 +155,134 @@ def visit_Import(self, node: cst.Import): self.add_library(alias.name.value) +# Populates `self.imports` (a map of local-name -> canonical dotted qualname) +# so helpers can resolve call-sites back to their canonical qualname regardless +# of how the user imported things. +# +# Examples: +# import trio -> imports["trio"] = "trio" +# import trio as t -> imports["t"] = "trio" +# import trio.lowlevel -> imports["trio"] = "trio" +# imports["trio.lowlevel"] = "trio.lowlevel" +# import trio.lowlevel as ll -> imports["ll"] = "trio.lowlevel" +# from trio import sleep -> imports["sleep"] = "trio.sleep" +# from trio import sleep as s -> imports["s"] = "trio.sleep" +# from trio.lowlevel import wait_* is treated as "trio.lowlevel.wait_*" +# +# Only top-level (module-level) imports are tracked; function- and class-local +# imports are intentionally skipped so that a local import inside one function +# doesn't leak into sibling scopes. The CST pass runs each utility visitor +# over the whole module in one go before any error visitor, so scope-aware +# tracking would require significantly more plumbing -- and ignoring local +# imports matches what linters typically do in practice. +@utility_visitor +class VisitorImportTracker(Flake8AsyncVisitor): + def __init__(self, *args: Any, **kwargs: Any): + super().__init__(*args, **kwargs) + self._scope_depth = 0 + + def _add_import(self, local: str, canonical: str) -> None: + if self._scope_depth == 0: + self.imports[local] = canonical + + def visit_Import(self, node: ast.Import): + for alias in node.names: + if alias.asname is not None: + self._add_import(alias.asname, alias.name) + else: + top = alias.name.partition(".")[0] + if self._scope_depth == 0 and top not in self.imports: + self.imports[top] = top + if self._scope_depth == 0 and alias.name not in self.imports: + self.imports[alias.name] = alias.name + + def visit_ImportFrom(self, node: ast.ImportFrom): + if node.module is None or node.level: + return + for alias in node.names: + if alias.name == "*": + continue + local = alias.asname if alias.asname is not None else alias.name + self._add_import(local, f"{node.module}.{alias.name}") + + def _enter_scope(self, node: ast.AST): + self.save_state(node, "_scope_depth") + self._scope_depth += 1 + + visit_FunctionDef = _enter_scope + visit_AsyncFunctionDef = _enter_scope + visit_ClassDef = _enter_scope + visit_Lambda = _enter_scope + + +@utility_visitor_cst +class VisitorImportTracker_cst(Flake8AsyncVisitor_cst): + def __init__(self, *args: Any, **kwargs: Any): + super().__init__(*args, **kwargs) + self._scope_depth = 0 + + def _add_import(self, local: str, canonical: str) -> None: + if self._scope_depth == 0: + self.imports[local] = canonical + + def visit_Import(self, node: cst.Import): + for alias in node.names: + full_name = identifier_to_string(alias.name) + if full_name is None: + continue + if alias.asname is not None and isinstance(alias.asname.name, cst.Name): + self._add_import(alias.asname.name.value, full_name) + elif self._scope_depth == 0: + top = full_name.partition(".")[0] + self.imports.setdefault(top, top) + self.imports.setdefault(full_name, full_name) + + def visit_ImportFrom(self, node: cst.ImportFrom): + if node.module is None or node.relative: + return + module = identifier_to_string(node.module) + if module is None: + return + if isinstance(node.names, cst.ImportStar): + return + for alias in node.names: + name = identifier_to_string(alias.name) + if name is None: + continue + if alias.asname is not None and isinstance(alias.asname.name, cst.Name): + local = alias.asname.name.value + else: + local = name + self._add_import(local, f"{module}.{name}") + + def visit_FunctionDef(self, node: cst.FunctionDef): + self._scope_depth += 1 + + def leave_FunctionDef( + self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef + ) -> cst.FunctionDef: + self._scope_depth -= 1 + return updated_node + + def visit_ClassDef(self, node: cst.ClassDef): + self._scope_depth += 1 + + def leave_ClassDef( + self, original_node: cst.ClassDef, updated_node: cst.ClassDef + ) -> cst.ClassDef: + self._scope_depth -= 1 + return updated_node + + def visit_Lambda(self, node: cst.Lambda): + self._scope_depth += 1 + + def leave_Lambda( + self, original_node: cst.Lambda, updated_node: cst.Lambda + ) -> cst.Lambda: + self._scope_depth -= 1 + return updated_node + + # taken from # https://github.com/PyCQA/flake8/blob/d016204366a22d382b5b56dc14b6cbff28ce929e/src/flake8/defaults.py#L27 NOQA_INLINE_REGEXP = re.compile( diff --git a/flake8_async/visitors/visitors.py b/flake8_async/visitors/visitors.py index afc1a308..18a2b51a 100644 --- a/flake8_async/visitors/visitors.py +++ b/flake8_async/visitors/visitors.py @@ -13,7 +13,6 @@ error_class_cst, get_matching_call, has_decorator, - identifier_to_string, ) if TYPE_CHECKING: @@ -25,9 +24,16 @@ @error_class +@disabled_by_default class Visitor106(Flake8AsyncVisitor): + # Historically this enforced `import trio` because the linter couldn't see + # through aliased/from-imports. That limitation is gone (see #132), so + # ASYNC106 is now opt-in for projects that still want to enforce the style. error_codes: Mapping[str, str] = { - "ASYNC106": "{0} must be imported with `import {0}` for the linter to work.", + "ASYNC106": ( + "{0} should be imported with `import {0}` for consistency" + " (historical reasons; no longer required for the linter to work)." + ), } def visit_ImportFrom(self, node: ast.ImportFrom): @@ -76,16 +82,16 @@ def visit_While(self, node: ast.While): and isinstance(node.body[0], ast.Expr) and isinstance(node.body[0].value, ast.Await) and ( - get_matching_call(node.body[0].value.value, "sleep", "sleep_until") + get_matching_call( + node.body[0].value.value, + "sleep", + "sleep_until", + imports=self.imports, + ) or ( - # get_matching_call doesn't (currently) support checking for trio.x.y isinstance(call := node.body[0].value.value, ast.Call) - and isinstance(call.func, ast.Attribute) - and call.func.attr == "checkpoint" - and isinstance(call.func.value, ast.Attribute) - and call.func.value.attr == "lowlevel" - and isinstance(call.func.value.value, ast.Name) - and call.func.value.value.id in ("trio", "anyio") + and self.canonical_name(call.func) + in ("trio.lowlevel.checkpoint", "anyio.lowlevel.checkpoint") ) ) ): @@ -116,15 +122,22 @@ def visit_With(self, node: ast.With | ast.AsyncWith): start_methods: tuple[str, ...] = ("start", "start_soon") # check for trio.open_nursery and anyio.create_task_group - if get_matching_call(item.context_expr, "open_nursery", base="trio"): + if get_matching_call( + item.context_expr, "open_nursery", base="trio", imports=self.imports + ): nursery_type = "nursery" elif get_matching_call( - item.context_expr, "create_task_group", base="anyio" + item.context_expr, + "create_task_group", + base="anyio", + imports=self.imports, ): nursery_type = "task group" # check for asyncio.TaskGroup - elif get_matching_call(item.context_expr, "TaskGroup", base="asyncio"): + elif get_matching_call( + item.context_expr, "TaskGroup", base="asyncio", imports=self.imports + ): nursery_type = "task group" start_methods = ("create_task",) else: @@ -138,6 +151,8 @@ def visit_With(self, node: ast.With | ast.AsyncWith): body_call = cast("ast.Call", body_call) if ( + # start[_soon] is called on the nursery/taskgroup variable, + # not a canonically-resolved name, so we don't pass imports. get_matching_call(body_call, *start_methods, base=var_name) # check for presence of as parameter and not any( @@ -304,7 +319,7 @@ class Visitor115(Flake8AsyncVisitor): } def visit_Call(self, node: ast.Call): - if not (m := get_matching_call(node, "sleep")): + if not (m := get_matching_call(node, "sleep", imports=self.imports)): return if ( len(node.args) == 1 @@ -328,7 +343,7 @@ class Visitor116(Flake8AsyncVisitor): } def visit_Call(self, node: ast.Call): - if not (m := get_matching_call(node, "sleep")): + if not (m := get_matching_call(node, "sleep", imports=self.imports)): return if len(node.args) == 1: arg = node.args[0] @@ -425,11 +440,18 @@ def visit_AsyncWith(self, node: ast.AsyncWith): self.save_state(node, "unsafe_stack", copy=True) for item in node.items: - if get_matching_call(item.context_expr, "open_nursery", base="trio"): + if get_matching_call( + item.context_expr, "open_nursery", base="trio", imports=self.imports + ): self.unsafe_stack.append("nursery") elif get_matching_call( - item.context_expr, "create_task_group", base="anyio" - ) or get_matching_call(item.context_expr, "TaskGroup", base="asyncio"): + item.context_expr, + "create_task_group", + base="anyio", + imports=self.imports, + ) or get_matching_call( + item.context_expr, "TaskGroup", base="asyncio", imports=self.imports + ): self.unsafe_stack.append("task group") def visit_While(self, node: ast.While | ast.For | ast.AsyncFor): @@ -484,7 +506,11 @@ def visit_withitem(self, node: ast.withitem): def visit_Call(self, node: ast.Call): if not self.in_withitem and ( match := get_matching_call( - node, "fail_after", "move_on_after", base=("trio", "anyio") + node, + "fail_after", + "move_on_after", + base=("trio", "anyio"), + imports=self.imports, ) ): self.error(node, str(match)) @@ -510,7 +536,12 @@ def is_constant(value: ast.expr) -> bool: return False match = get_matching_call( - node, "fail_at", "move_on_at", "CancelScope", base=("trio", "anyio") + node, + "fail_at", + "move_on_at", + "CancelScope", + base=("trio", "anyio"), + imports=self.imports, ) if match is None: return @@ -548,7 +579,8 @@ def base_name(base: ast.expr) -> str: # strip generic subscripts like `ExceptionGroup[Foo]` if isinstance(base, ast.Subscript): base = base.value - unparsed = ast.unparse(base) + canonical = self.canonical_name(base) + unparsed = canonical if canonical is not None else ast.unparse(base) return unparsed.rsplit(".", 1)[-1] if not any( @@ -586,7 +618,7 @@ def visit_CompIf(self, node: cst.CSTNode): def visit_Call(self, node: cst.Call): if ( - identifier_to_string(node.func) == "asyncio.create_task" + self.canonical_name(node.func) == "asyncio.create_task" and not self.safe_to_create_task ): self.error(node) diff --git a/tests/autofix_files/exception_suppress_context_manager.py b/tests/autofix_files/exception_suppress_context_manager.py index 0704da20..3d00a643 100644 --- a/tests/autofix_files/exception_suppress_context_manager.py +++ b/tests/autofix_files/exception_suppress_context_manager.py @@ -87,10 +87,15 @@ async def foo_suppress_as(): # ASYNC910: 0, "exit", Statement('function definit # ############################### -# not enabled unless it's imported from contextlib -async def foo_suppress_directly_imported_1(): +# Module-level imports are resolved via their canonical qualname, so the +# `from contextlib import suppress` below is recognised as a suppressing CM +# even though the function is defined above the import statement (Python +# resolves the name at call time, so this more closely matches runtime +# semantics). +async def foo_suppress_directly_imported_1(): # ASYNC910: 0, "exit", Statement('function definition', lineno) with suppress(): await foo() + await trio.lowlevel.checkpoint() from contextlib import suppress diff --git a/tests/autofix_files/exception_suppress_context_manager.py.diff b/tests/autofix_files/exception_suppress_context_manager.py.diff index 0de6726e..713aa51b 100644 --- a/tests/autofix_files/exception_suppress_context_manager.py.diff +++ b/tests/autofix_files/exception_suppress_context_manager.py.diff @@ -50,6 +50,14 @@ # ############################### +@@ x,6 x,7 @@ + async def foo_suppress_directly_imported_1(): # ASYNC910: 0, "exit", Statement('function definition', lineno) + with suppress(): + await foo() ++ await trio.lowlevel.checkpoint() + + + from contextlib import suppress @@ x,6 x,7 @@ async def foo_suppress_directly_imported_2(): # ASYNC910: 0, "exit", Statement('function definition', lineno) with suppress(): diff --git a/tests/eval_files/async110.py b/tests/eval_files/async110.py index 22df3732..8527577b 100644 --- a/tests/eval_files/async110.py +++ b/tests/eval_files/async110.py @@ -38,8 +38,8 @@ async def foo(): await trio.sleep() await trio.sleep_until() - # check library name - while ...: + # aliased-import resolves to canonical qualname, so this now errors too + while ...: # error: 4, "trio" await noerror.sleep() async def sleep(): ... diff --git a/tests/eval_files/async111.py b/tests/eval_files/async111.py index e40d2929..88d24bf5 100644 --- a/tests/eval_files/async111.py +++ b/tests/eval_files/async111.py @@ -77,10 +77,10 @@ async def foo_2(): async with trio.open_process() as bar_2: nursery.start(bar_2) # safe -# specifically check for *trio*.open_nursery +# aliased trio import is now resolved to canonical qualname, so this errors too with noterror.open_nursery() as nursery: with trio.open("") as bar: - nursery.start(bar) + nursery.start(bar) # error: 22, line-1, line-2, "bar", "start" # specifically check for trio.*open_nursery* with trio.open_nurse() as nursery: diff --git a/tests/eval_files/async112.py b/tests/eval_files/async112.py index 8657973d..01a497c6 100644 --- a/tests/eval_files/async112.py +++ b/tests/eval_files/async112.py @@ -86,8 +86,8 @@ async def foo_1(): await n.start(...) -# not *trio*.open_nursery -with noterror.open_nursery(...) as n: +# aliased trio import is now resolved to canonical qualname, so this errors too +with noterror.open_nursery(...) as n: # error: 5, "n", "nursery" n.start(...) # not trio.*open_nursery* diff --git a/tests/eval_files/async112_canonical_qualname.py b/tests/eval_files/async112_canonical_qualname.py new file mode 100644 index 00000000..2e988b79 --- /dev/null +++ b/tests/eval_files/async112_canonical_qualname.py @@ -0,0 +1,30 @@ +# Regression test for https://github.com/python-trio/flake8-async/issues/132: +# detection works regardless of how trio is imported. +# type: ignore +# ASYNCIO_NO_ERROR +# ARG --enable=ASYNC112 + +import trio +import trio as t +from trio import open_nursery +from trio import open_nursery as on + + +# `import trio as t` +with t.open_nursery() as n: # error: 5, "n", "nursery" + n.start(...) + + +# `from trio import open_nursery` +with open_nursery() as n: # error: 5, "n", "nursery" + n.start(...) + + +# `from trio import open_nursery as on` +with on() as n: # error: 5, "n", "nursery" + n.start_soon(...) + + +# canonical name still matches when chained through an ordinary `import trio` +with trio.open_nursery() as n: # error: 5, "n", "nursery" + n.start(...) diff --git a/tests/eval_files/async115.py b/tests/eval_files/async115.py index 05dd7586..70186366 100644 --- a/tests/eval_files/async115.py +++ b/tests/eval_files/async115.py @@ -18,9 +18,10 @@ async def afoo(): trio.sleep(0) # error: 4, "trio" trio.sleep(1) - # don't error on other sleeps + # don't error on unrelated sleeps time.sleep(0) - sleep(0) + # `from trio import sleep` resolves to canonical `trio.sleep`, so this now errors + sleep(0) # error: 4, "trio" # in trio it's called 'seconds', in anyio it's 'delay', but # we don't care about the kwarg name. #382 diff --git a/tests/eval_files/async115_canonical_qualname.py b/tests/eval_files/async115_canonical_qualname.py new file mode 100644 index 00000000..7e7427f9 --- /dev/null +++ b/tests/eval_files/async115_canonical_qualname.py @@ -0,0 +1,38 @@ +# Regression test for https://github.com/python-trio/flake8-async/issues/132: +# rules should fire against the canonical qualname regardless of how a symbol +# is imported -- plain `import trio`, `import ... as ...`, +# `from ... import ...`, or `from ... import ... as ...`. +# type: ignore +# ASYNCIO_NO_ERROR - ASYNC115 is trio/anyio-only +# ARG --enable=ASYNC115 + +import trio as t +import trio.lowlevel as ll +from trio import sleep +from trio import sleep as nap +from trio.lowlevel import checkpoint as cp + + +async def afoo(): + # `import trio as t` + await t.sleep(0) # error: 10, "trio" + + # `from trio import sleep` + await sleep(0) # error: 10, "trio" + + # `from trio import sleep as nap` + await nap(0) # error: 10, "trio" + + # `import trio.lowlevel as ll` still resolves the canonical qualname + # (we only track this for its side-effect on ASYNC110 elsewhere, but it + # must not crash). + ll.checkpoint() + + # `from trio.lowlevel import checkpoint as cp` -- ASYNC115 doesn't match this + # particular qualname, but we just want to show resolution doesn't crash. + cp() + + # non-aliased local name that shadows an import should not falsely match: + # no import binds `sleep_2`, so `sleep_2(0)` is not flagged. + sleep_2 = lambda x: None + sleep_2(0) diff --git a/tests/eval_files/async251.py b/tests/eval_files/async251.py index 9da4a103..00d2742e 100644 --- a/tests/eval_files/async251.py +++ b/tests/eval_files/async251.py @@ -6,6 +6,5 @@ async def foo(): time.sleep(5) # ASYNC251: 4, "trio" time.sleep(5) if 5 else time.sleep(5) # ASYNC251: 4, "trio" # ASYNC251: 28, "trio" - # Not handled due to difficulty tracking imports and not wanting to trigger - # false positives. But could definitely be handled by ruff et al. - sleep(5) + # `from time import sleep` resolves to canonical `time.sleep`, so this now errors + sleep(5) # ASYNC251: 4, "trio" diff --git a/tests/eval_files/async300.py b/tests/eval_files/async300.py index 87ad535c..8fd3f920 100644 --- a/tests/eval_files/async300.py +++ b/tests/eval_files/async300.py @@ -67,7 +67,8 @@ def returner_list(): with asyncio.create_task(*args) as k: # type: ignore[attr-defined] # ASYNC300: 9 ... - # import aliasing is not supported (this would raise ASYNC106 bad-async-library-import) + # module-level imports are resolved to their canonical qualname, but + # function-local imports are not tracked (they'd leak into sibling scopes). from asyncio import create_task create_task(*args) diff --git a/tests/eval_files/exception_suppress_context_manager.py b/tests/eval_files/exception_suppress_context_manager.py index 4b809d7f..d7114cf0 100644 --- a/tests/eval_files/exception_suppress_context_manager.py +++ b/tests/eval_files/exception_suppress_context_manager.py @@ -80,8 +80,12 @@ async def foo_suppress_as(): # ASYNC910: 0, "exit", Statement('function definit # ############################### -# not enabled unless it's imported from contextlib -async def foo_suppress_directly_imported_1(): +# Module-level imports are resolved via their canonical qualname, so the +# `from contextlib import suppress` below is recognised as a suppressing CM +# even though the function is defined above the import statement (Python +# resolves the name at call time, so this more closely matches runtime +# semantics). +async def foo_suppress_directly_imported_1(): # ASYNC910: 0, "exit", Statement('function definition', lineno) with suppress(): await foo() diff --git a/tests/test_config_and_args.py b/tests/test_config_and_args.py index 4e93c713..aa2f3709 100644 --- a/tests/test_config_and_args.py +++ b/tests/test_config_and_args.py @@ -492,8 +492,9 @@ def test_disable_noqa_ast( assert not err assert ( out - == "./example.py:1:1: ASYNC106 trio must be imported with `import trio` for the" - " linter to work.\n" + == "./example.py:1:1: ASYNC106 trio should be imported with `import trio`" + " for consistency (historical reasons; no longer required for the linter" + " to work).\n" ) From f337062077f5d738482973fbf42fba5299406df1 Mon Sep 17 00:00:00 2001 From: Claude Date: Fri, 24 Apr 2026 19:30:36 +0000 Subject: [PATCH 2/6] Tighten canonical-qualname code and comments - Drop redundant canonical_name() docstrings. - Simplify fnmatch_qualified_name[_cst] to build a candidate set inline. - Fold the resolve_canonical_ast recursive arm into one-liners. - Drop ASYNC21X's bespoke urllib3-import set -- consult the shared imports map. - Collapse ASYNC22X's raw_name/canonical/func_name triplet into two locals. - Simplify with_has_call's canonical fallback to a startswith + suffix check. - Consolidate the CST scope-tracker's three visit/leave pairs into shared helpers. - Rewrite the narrative "this change" comments as reader-facing rationale, both in the code and in the eval-file annotations. https://claude.ai/code/session_018Hc9rcA31SnXcN8Ee5vVwH --- docs/changelog.rst | 4 +- docs/rules.rst | 4 +- flake8_async/runner.py | 9 +- flake8_async/visitors/flake8asyncvisitor.py | 8 -- flake8_async/visitors/helpers.py | 126 +++++++----------- flake8_async/visitors/visitor105.py | 4 +- flake8_async/visitors/visitor118.py | 19 ++- flake8_async/visitors/visitor123.py | 5 +- flake8_async/visitors/visitor2xx.py | 45 ++----- flake8_async/visitors/visitor_utility.py | 116 ++++++++-------- flake8_async/visitors/visitors.py | 9 +- .../exception_suppress_context_manager.py | 8 +- tests/eval_files/async110.py | 2 +- tests/eval_files/async111.py | 2 +- tests/eval_files/async112.py | 2 +- .../eval_files/async112_canonical_qualname.py | 6 +- tests/eval_files/async115.py | 4 +- .../eval_files/async115_canonical_qualname.py | 21 +-- tests/eval_files/async251.py | 2 +- tests/eval_files/async300.py | 3 +- .../exception_suppress_context_manager.py | 8 +- tests/test_config_and_args.py | 3 +- 22 files changed, 152 insertions(+), 258 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 3ad8e3af..a95b1fde 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -6,8 +6,8 @@ Changelog Unreleased ========== -- Rules now resolve function/class references against the canonical qualname, so checks fire regardless of how a symbol was imported (``import trio``, ``import trio as t``, ``from trio import open_nursery``, ``from trio import open_nursery as on``, etc.). Only module-level imports are tracked; function-local imports are still considered local. `(issue #132) `_ -- :ref:`ASYNC106 ` is now disabled by default: with the canonical-qualname resolution above, the rule is no longer required for the linter to work correctly. Re-enable it explicitly if you still want to enforce the ``import trio`` style. +- Rules resolve function/class references via the canonical qualname, so checks fire regardless of import style (``import trio``, ``import trio as t``, ``from trio import open_nursery [as on]``, …). Only module-level imports are tracked. `(issue #132) `_ +- :ref:`ASYNC106 ` is now disabled by default; re-enable it to enforce the ``import trio`` style. - Autofix for :ref:`ASYNC910 ` / :ref:`ASYNC911 ` no longer inserts checkpoints inside ``except`` clauses (which would trigger :ref:`ASYNC120 `); instead the checkpoint is added at the top of the function or of the enclosing loop. `(issue #403) `_ - :ref:`ASYNC910 ` and :ref:`ASYNC911 ` now accept ``__aenter__`` / ``__aexit__`` methods when the partner method provides the checkpoint, or when only one of the two is defined on a class that inherits from another class (charitably assuming the partner is inherited and contains a checkpoint). `(issue #441) `_ - :ref:`ASYNC300 ` no longer triggers when the result of ``asyncio.create_task()`` is returned from a function. `(issue #398) `_ diff --git a/docs/rules.rst b/docs/rules.rst index 1ee50307..315b5c4d 100644 --- a/docs/rules.rst +++ b/docs/rules.rst @@ -43,9 +43,7 @@ ASYNC105 : missing-await ASYNC106 : bad-async-library-import trio/anyio/asyncio should be imported with ``import xxx`` for consistency. - As of the canonical-qualname refactor this rule is no longer required for the - linter to work, and is therefore disabled by default -- enable it explicitly - if you want to enforce the style. + Opt-in style check; the linter resolves other import styles correctly. ASYNC109 : async-function-with-timeout Async function definition with a ``timeout`` parameter. diff --git a/flake8_async/runner.py b/flake8_async/runner.py index e0fbc1e8..fb8d7a49 100644 --- a/flake8_async/runner.py +++ b/flake8_async/runner.py @@ -37,11 +37,10 @@ class SharedState: library: tuple[str, ...] = () typed_calls: dict[str, str] = field(default_factory=dict[str, str]) variables: dict[str, str] = field(default_factory=dict[str, str]) - # Maps a locally-bound name to its canonical dotted qualname, populated by - # VisitorImportTracker/VisitorImportTracker_cst. Used by helpers so that - # rules can be written against canonical qualnames and match regardless of - # how things were imported (bare `import x`, `import x as y`, - # `from x import y`, or `from x import y as z`). + # Local name -> canonical dotted qualname, populated by VisitorImportTracker[_cst]. + # Helpers consult this so rules can match the canonical qualname regardless of + # how a symbol was imported (`import x`, `import x as y`, `from x import y`, + # `from x import y as z`). imports: dict[str, str] = field(default_factory=dict[str, str]) diff --git a/flake8_async/visitors/flake8asyncvisitor.py b/flake8_async/visitors/flake8asyncvisitor.py index d6128c95..b683a99b 100644 --- a/flake8_async/visitors/flake8asyncvisitor.py +++ b/flake8_async/visitors/flake8asyncvisitor.py @@ -58,10 +58,6 @@ def imports(self) -> dict[str, str]: return self.__state.imports def canonical_name(self, node: ast.AST) -> str | None: - """Resolve `node` to a dotted canonical qualname, consulting imports. - - See ``resolve_canonical_ast`` for semantics. - """ from .helpers import resolve_canonical_ast return resolve_canonical_ast(node, self.__state.imports) @@ -188,10 +184,6 @@ def imports(self) -> dict[str, str]: return self.__state.imports def canonical_name(self, node: cst.CSTNode) -> str | None: - """Resolve `node` to a dotted canonical qualname, consulting imports. - - See ``resolve_canonical_cst`` for semantics. - """ from .helpers import resolve_canonical_cst return resolve_canonical_cst(node, self.__state.imports) diff --git a/flake8_async/visitors/helpers.py b/flake8_async/visitors/helpers.py index 26e42098..089d2ecc 100644 --- a/flake8_async/visitors/helpers.py +++ b/flake8_async/visitors/helpers.py @@ -109,16 +109,13 @@ def fnmatch_qualified_name( for name in name_list: if isinstance(name, ast.Call): name = name.func - qualified_names = [ast.unparse(name)] - if imports is not None: - canonical = resolve_canonical_ast(name, imports) - if canonical is not None and canonical not in qualified_names: - qualified_names.append(canonical) - + candidates = {ast.unparse(name)} + if imports is not None and (canonical := resolve_canonical_ast(name, imports)): + candidates.add(canonical) for pattern in patterns: # strip leading "@"s for when we're working with decorators stripped = pattern.lstrip("@") - if any(fnmatch(qn, stripped) for qn in qualified_names): + if any(fnmatch(c, stripped) for c in candidates): return pattern return None @@ -129,21 +126,19 @@ def fnmatch_qualified_name_cst( imports: Mapping[str, str] | None = None, ) -> str | None: for name in name_list: - qualified_names = [get_full_name_for_node_or_raise(name)] + candidates = {get_full_name_for_node_or_raise(name)} if imports is not None: - node: cst.CSTNode = name - if isinstance(node, cst.Decorator): - node = node.decorator - if isinstance(node, cst.Call): - node = node.func - canonical = resolve_canonical_cst(node, imports) - if canonical is not None and canonical not in qualified_names: - qualified_names.append(canonical) - + inner: cst.CSTNode = name + if isinstance(inner, cst.Decorator): + inner = inner.decorator + if isinstance(inner, cst.Call): + inner = inner.func + if (canonical := resolve_canonical_cst(inner, imports)) is not None: + candidates.add(canonical) for pattern in patterns: # strip leading "@"s for when we're working with decorators stripped = pattern.lstrip("@") - if any(fnmatch(qn, stripped) for qn in qualified_names): + if any(fnmatch(c, stripped) for c in candidates): return pattern return None @@ -275,29 +270,25 @@ def has_exception(node: ast.expr) -> str | None: "CancelledError", ): return name - # Match via canonical qualname, so `import trio as t; except t.Cancelled`, - # `from trio import Cancelled`, `from asyncio import CancelledError as CE`, etc. - # also get picked up. The non-call forms (`except anyio.get_cancelled_exc_class:` - # and `except ...(...)` with args) are type-errors the existing code - # intentionally ignores, so only match zero-arg calls for the dynamic form. - if imports is not None: - if isinstance(node, ast.Call): - if node.args or node.keywords: - return None - canonical = resolve_canonical_ast(node.func, imports) - else: - canonical = resolve_canonical_ast(node, imports) - if canonical == "trio.Cancelled" and not isinstance(node, ast.Call): - return "trio.Cancelled" - if canonical == "anyio.get_cancelled_exc_class" and isinstance( - node, ast.Call - ): - return "anyio.get_cancelled_exc_class()" - if canonical in ( - "asyncio.exceptions.CancelledError", - "asyncio.CancelledError", - ) and not isinstance(node, ast.Call): - return "asyncio.exceptions.CancelledError" + if imports is None: + return None + # Resolve via canonical qualname for aliased / `from`-imported forms. + # The non-call spellings (`except anyio.get_cancelled_exc_class:`, or a + # Call with arguments) are type-errors that critical_except intentionally + # ignores, so only zero-arg calls count for get_cancelled_exc_class. + is_call = isinstance(node, ast.Call) + if is_call and (node.args or node.keywords): + return None + canonical = resolve_canonical_ast(node.func if is_call else node, imports) + if not is_call and canonical == "trio.Cancelled": + return "trio.Cancelled" + if is_call and canonical == "anyio.get_cancelled_exc_class": + return "anyio.get_cancelled_exc_class()" + if not is_call and canonical in ( + "asyncio.exceptions.CancelledError", + "asyncio.CancelledError", + ): + return "asyncio.exceptions.CancelledError" return None name: str | None = None @@ -345,19 +336,16 @@ def __str__(self) -> str: return self.base + "." + self.name -# Resolve an ast Name/Attribute to a canonical dotted qualname, using the `imports` -# map (local-name -> canonical dotted qualname). Returns None for non-name nodes -# (e.g. subscripts, calls). If the root-most Name isn't in `imports`, we fall back -# to using the literal identifier text — so `trio.open_nursery()` without any -# imports still resolves to "trio.open_nursery", preserving prior behaviour. +# Resolve a Name/Attribute/Call node to a dotted qualname via `imports` +# (local-name -> canonical dotted qualname). The root Name falls back to its own +# identifier, so `trio.open_nursery()` resolves to "trio.open_nursery" even when +# nothing was imported. Returns None for shapes we can't resolve (subscripts, etc.). def resolve_canonical_ast(node: ast.AST, imports: Mapping[str, str]) -> str | None: if isinstance(node, ast.Name): return imports.get(node.id, node.id) if isinstance(node, ast.Attribute): prefix = resolve_canonical_ast(node.value, imports) - if prefix is None: - return None - return f"{prefix}.{node.attr}" + return None if prefix is None else f"{prefix}.{node.attr}" if isinstance(node, ast.Call): return resolve_canonical_ast(node.func, imports) return None @@ -370,9 +358,7 @@ def resolve_canonical_cst( return imports.get(node.value, node.value) if isinstance(node, cst.Attribute): prefix = resolve_canonical_cst(node.value, imports) - if prefix is None: - return None - return f"{prefix}.{node.attr.value}" + return None if prefix is None else f"{prefix}.{node.attr.value}" if isinstance(node, cst.Call): return resolve_canonical_cst(node.func, imports) return None @@ -389,7 +375,6 @@ def get_matching_call( base = (base,) if not isinstance(node, ast.Call): return None - # Fast path: matches the existing structural check. if ( isinstance(node.func, ast.Attribute) and isinstance(node.func.value, ast.Name) @@ -397,12 +382,8 @@ def get_matching_call( and node.func.attr in names ): return MatchingCall(node, node.func.attr, node.func.value.id) - # Canonical-qualname path: works regardless of how things got imported - # (e.g. `import trio as t`, `from trio import open_nursery [as x]`). if imports is not None: canonical = resolve_canonical_ast(node.func, imports) - if canonical is None: - return None for b in base: for n in names: if canonical == f"{b}.{n}": @@ -428,8 +409,6 @@ def get_matching_call_cst( return MatchingCall(node, node.func.attr.value, attr_base) if imports is not None: canonical = resolve_canonical_cst(node.func, imports) - if canonical is None: - return None for b in base: for n in names: if canonical == f"{b}.{n}": @@ -487,7 +466,9 @@ def with_has_call( """Check if a with statement has a matching call, returning a list with matches. `names` specify the names of functions to match, `base` specifies the - library/module(s) the function must be in. + library/module(s) the function must be in. If `imports` is given, matches + are also made against the canonical qualname so aliased / `from`-imports + are detected. The list elements in the return value are named tuples with the matched node, base and function. @@ -497,12 +478,8 @@ def with_has_call( `with_has_call(node, "bar", "bee", base=("foo", "a.b.c")` matches `foo.bar`, `foo.bee`, `a.b.c.bar`, and `a.b.c.bee`. - When `imports` is passed, matches against the canonical qualname so that - aliased/from-imports are detected as well. """ - if isinstance(base, str): - base = (base,) - base_tuple = tuple(base) + base_tuple = (base,) if isinstance(base, str) else tuple(base) # build matcher, using SaveMatchedNode to save the base and the function name. matcher = m.Call( @@ -510,10 +487,7 @@ def with_has_call( value=m.SaveMatchedNode( m.OneOf(*(build_cst_matcher(b) for b in base_tuple)), name="base" ), - attr=m.SaveMatchedNode( - oneof_names(*names), - name="function", - ), + attr=m.SaveMatchedNode(oneof_names(*names), name="function"), ) ) @@ -534,16 +508,14 @@ def with_has_call( if imports is None or not isinstance(item.item, cst.Call): continue canonical = resolve_canonical_cst(item.item.func, imports) - if canonical is None: - continue for b in base_tuple: - for n in names: - if canonical == f"{b}.{n}": - res_list.append(MatchingCall(node=item.item, base=b, name=n)) + if canonical is not None and canonical.startswith(f"{b}."): + suffix = canonical[len(b) + 1 :] + if suffix in names: + res_list.append( + MatchingCall(node=item.item, base=b, name=suffix) + ) break - else: - continue - break return res_list diff --git a/flake8_async/visitors/visitor105.py b/flake8_async/visitors/visitor105.py index fb63a93e..60aea336 100644 --- a/flake8_async/visitors/visitor105.py +++ b/flake8_async/visitors/visitor105.py @@ -58,8 +58,8 @@ def visit_Call(self, node: ast.Call): canonical = self.canonical_name(node.func) if canonical in trio_async_funcs: - # report the canonical qualname so the message is stable regardless of - # how the user imported the function. + # report the canonical qualname (rather than the user's local alias) + # so the message reads consistently. self.error(node, canonical, "function") elif isinstance(node.func, ast.Attribute) and node.func.attr == "start": var = ast.unparse(node.func.value) diff --git a/flake8_async/visitors/visitor118.py b/flake8_async/visitors/visitor118.py index a32171f8..343a0f3d 100644 --- a/flake8_async/visitors/visitor118.py +++ b/flake8_async/visitors/visitor118.py @@ -27,18 +27,17 @@ class Visitor118(Flake8AsyncVisitor): } def visit_Assign(self, node: ast.Assign | ast.AnnAssign): - if node.value is None: - return value = node.value - func_node = value.func if isinstance(value, ast.Call) else value - canonical = self.canonical_name(func_node) - if canonical == "anyio.get_cancelled_exc_class": - self.error(node.value) + if value is None: + return + target = value.func if isinstance(value, ast.Call) else value + if self.canonical_name(target) == "anyio.get_cancelled_exc_class": + self.error(value) return - # Fallback to literal matching for un-imported or unusual forms. - name = ast.unparse(node.value) - if re.fullmatch(r"(anyio.)?get_cancelled_exc_class(\(\))?", name): - self.error(node.value) + # Fallback for code where anyio isn't importable (e.g. stubs or partial + # configs) but the name is still spelled out literally. + if re.fullmatch(r"(anyio.)?get_cancelled_exc_class(\(\))?", ast.unparse(value)): + self.error(value) visit_AnnAssign = visit_Assign diff --git a/flake8_async/visitors/visitor123.py b/flake8_async/visitors/visitor123.py index 860dcbf0..ba8475b1 100644 --- a/flake8_async/visitors/visitor123.py +++ b/flake8_async/visitors/visitor123.py @@ -69,9 +69,8 @@ def visit_ExceptHandler(self, node: ast.ExceptHandler): "child_exception_names", copy=True, ) - # ExceptionGroup/BaseExceptionGroup are builtins; we match by literal name - # since they're typically used unqualified (and aliasing the builtins - # themselves is very unusual). + # [Base]ExceptionGroup are builtins and almost always used unqualified, + # so a substring match on the literal source is sufficient. if node.name is None or ( not self.try_star and (node.type is None or "ExceptionGroup" not in ast.unparse(node.type)) diff --git a/flake8_async/visitors/visitor2xx.py b/flake8_async/visitors/visitor2xx.py index 4ee5887d..024dcf37 100644 --- a/flake8_async/visitors/visitor2xx.py +++ b/flake8_async/visitors/visitor2xx.py @@ -68,32 +68,13 @@ class Visitor21X(Visitor200): ), } - def __init__(self, *args: Any, **kwargs: Any): - super().__init__(*args, **kwargs) - # tracked specifically for ASYNC211, which wants to know whether urllib3 - # itself has been imported (any form) before flagging `x.request(...)`. - self.urllib3_imports: set[str] = set() - - def visit_ImportFrom(self, node: ast.ImportFrom): - if node.module == "urllib3": - self.urllib3_imports.add(node.module) - - def visit_Import(self, node: ast.Import): - for name in node.names: - if name.name == "urllib3": - # Could also save the name.asname for matching - self.urllib3_imports.add(name.name) + def _urllib3_imported(self) -> bool: + return any( + v == "urllib3" or v.startswith("urllib3.") for v in self.imports.values() + ) def visit_blocking_call(self, node: ast.Call): - http_methods = { - "get", - "options", - "head", - "post", - "put", - "patch", - "delete", - } + http_methods = {"get", "options", "head", "post", "put", "patch", "delete"} func_name = ast.unparse(node.func) canonical = self.canonical_name(node.func) or func_name for http_package in "requests", "httpx": @@ -111,16 +92,11 @@ def visit_blocking_call(self, node: ast.Call): "urllib.request.urlopen", "request.urlopen", "urlopen", - ) or func_name in ( - "urllib3.request", - "urllib.request.urlopen", - "request.urlopen", - "urlopen", ): self.error(node, func_name, error_code="ASYNC210") elif ( - "urllib3" in self.urllib3_imports + self._urllib3_imported() and isinstance(node.func, ast.Attribute) and node.func.attr == "request" and node.args @@ -224,12 +200,9 @@ def is_p_wait(arg: ast.expr) -> bool: "getstatusoutput", } - raw_name = ast.unparse(node.func) - canonical = self.canonical_name(node.func) or raw_name - # What we report to the user. Prefer the raw spelling since that's what - # they wrote, but use canonical for matching so `import os as o; o.popen`, - # `from subprocess import run`, etc. also get detected. - func_name = raw_name + # Match against the canonical qualname, but report the user's literal spelling. + func_name = ast.unparse(node.func) + canonical = self.canonical_name(node.func) or func_name error_code: str | None = None if canonical in ("subprocess.Popen", "os.popen"): error_code = "ASYNC220" diff --git a/flake8_async/visitors/visitor_utility.py b/flake8_async/visitors/visitor_utility.py index e1c2ea67..f581027b 100644 --- a/flake8_async/visitors/visitor_utility.py +++ b/flake8_async/visitors/visitor_utility.py @@ -155,55 +155,50 @@ def visit_Import(self, node: cst.Import): self.add_library(alias.name.value) -# Populates `self.imports` (a map of local-name -> canonical dotted qualname) -# so helpers can resolve call-sites back to their canonical qualname regardless -# of how the user imported things. +# Populate `imports` (local-name -> canonical dotted qualname) so helpers can +# resolve call-sites regardless of import style. Examples: +# import trio -> imports["trio"] = "trio" +# import trio as t -> imports["t"] = "trio" +# import trio.lowlevel -> imports["trio"] = "trio" +# imports["trio.lowlevel"] = "trio.lowlevel" +# import trio.lowlevel as ll -> imports["ll"] = "trio.lowlevel" +# from trio import sleep -> imports["sleep"] = "trio.sleep" +# from trio import sleep as s -> imports["s"] = "trio.sleep" # -# Examples: -# import trio -> imports["trio"] = "trio" -# import trio as t -> imports["t"] = "trio" -# import trio.lowlevel -> imports["trio"] = "trio" -# imports["trio.lowlevel"] = "trio.lowlevel" -# import trio.lowlevel as ll -> imports["ll"] = "trio.lowlevel" -# from trio import sleep -> imports["sleep"] = "trio.sleep" -# from trio import sleep as s -> imports["s"] = "trio.sleep" -# from trio.lowlevel import wait_* is treated as "trio.lowlevel.wait_*" -# -# Only top-level (module-level) imports are tracked; function- and class-local -# imports are intentionally skipped so that a local import inside one function -# doesn't leak into sibling scopes. The CST pass runs each utility visitor -# over the whole module in one go before any error visitor, so scope-aware -# tracking would require significantly more plumbing -- and ignoring local -# imports matches what linters typically do in practice. +# Only module-level imports are tracked: function-/class-local imports are +# skipped to keep them out of sibling scopes. A full scope-aware resolver +# would also need to know the call site's position, which isn't justified +# given how uncommon local imports of async APIs are. @utility_visitor class VisitorImportTracker(Flake8AsyncVisitor): def __init__(self, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) self._scope_depth = 0 - def _add_import(self, local: str, canonical: str) -> None: - if self._scope_depth == 0: - self.imports[local] = canonical + def _at_module_level(self) -> bool: + return self._scope_depth == 0 def visit_Import(self, node: ast.Import): + if not self._at_module_level(): + return for alias in node.names: if alias.asname is not None: - self._add_import(alias.asname, alias.name) - else: - top = alias.name.partition(".")[0] - if self._scope_depth == 0 and top not in self.imports: - self.imports[top] = top - if self._scope_depth == 0 and alias.name not in self.imports: - self.imports[alias.name] = alias.name + self.imports[alias.asname] = alias.name + continue + # `import a.b.c` binds `a` and also resolves `a.b.c.` through + # the Attribute chain, so we record both. + top = alias.name.partition(".")[0] + self.imports.setdefault(top, top) + self.imports.setdefault(alias.name, alias.name) def visit_ImportFrom(self, node: ast.ImportFrom): - if node.module is None or node.level: + if node.module is None or node.level or not self._at_module_level(): return for alias in node.names: if alias.name == "*": continue local = alias.asname if alias.asname is not None else alias.name - self._add_import(local, f"{node.module}.{alias.name}") + self.imports[local] = f"{node.module}.{alias.name}" def _enter_scope(self, node: ast.AST): self.save_state(node, "_scope_depth") @@ -221,30 +216,34 @@ def __init__(self, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) self._scope_depth = 0 - def _add_import(self, local: str, canonical: str) -> None: - if self._scope_depth == 0: - self.imports[local] = canonical + def _at_module_level(self) -> bool: + return self._scope_depth == 0 def visit_Import(self, node: cst.Import): + if not self._at_module_level(): + return for alias in node.names: full_name = identifier_to_string(alias.name) if full_name is None: continue if alias.asname is not None and isinstance(alias.asname.name, cst.Name): - self._add_import(alias.asname.name.value, full_name) - elif self._scope_depth == 0: - top = full_name.partition(".")[0] - self.imports.setdefault(top, top) - self.imports.setdefault(full_name, full_name) + self.imports[alias.asname.name.value] = full_name + continue + top = full_name.partition(".")[0] + self.imports.setdefault(top, top) + self.imports.setdefault(full_name, full_name) def visit_ImportFrom(self, node: cst.ImportFrom): - if node.module is None or node.relative: + if ( + node.module is None + or node.relative + or isinstance(node.names, cst.ImportStar) + or not self._at_module_level() + ): return module = identifier_to_string(node.module) if module is None: return - if isinstance(node.names, cst.ImportStar): - return for alias in node.names: name = identifier_to_string(alias.name) if name is None: @@ -253,34 +252,23 @@ def visit_ImportFrom(self, node: cst.ImportFrom): local = alias.asname.name.value else: local = name - self._add_import(local, f"{module}.{name}") + self.imports[local] = f"{module}.{name}" - def visit_FunctionDef(self, node: cst.FunctionDef): + def _enter_scope(self, node: cst.CSTNode) -> None: self._scope_depth += 1 - def leave_FunctionDef( - self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef - ) -> cst.FunctionDef: + def _leave_scope( + self, original_node: cst.CSTNode, updated_node: Any + ) -> Any: self._scope_depth -= 1 return updated_node - def visit_ClassDef(self, node: cst.ClassDef): - self._scope_depth += 1 - - def leave_ClassDef( - self, original_node: cst.ClassDef, updated_node: cst.ClassDef - ) -> cst.ClassDef: - self._scope_depth -= 1 - return updated_node - - def visit_Lambda(self, node: cst.Lambda): - self._scope_depth += 1 - - def leave_Lambda( - self, original_node: cst.Lambda, updated_node: cst.Lambda - ) -> cst.Lambda: - self._scope_depth -= 1 - return updated_node + visit_FunctionDef = _enter_scope + visit_ClassDef = _enter_scope + visit_Lambda = _enter_scope + leave_FunctionDef = _leave_scope + leave_ClassDef = _leave_scope + leave_Lambda = _leave_scope # taken from diff --git a/flake8_async/visitors/visitors.py b/flake8_async/visitors/visitors.py index 18a2b51a..21ba57ea 100644 --- a/flake8_async/visitors/visitors.py +++ b/flake8_async/visitors/visitors.py @@ -26,14 +26,9 @@ @error_class @disabled_by_default class Visitor106(Flake8AsyncVisitor): - # Historically this enforced `import trio` because the linter couldn't see - # through aliased/from-imports. That limitation is gone (see #132), so - # ASYNC106 is now opt-in for projects that still want to enforce the style. + # Opt-in style check; other rules already handle all import styles. error_codes: Mapping[str, str] = { - "ASYNC106": ( - "{0} should be imported with `import {0}` for consistency" - " (historical reasons; no longer required for the linter to work)." - ), + "ASYNC106": "{0} should be imported with `import {0}` for consistency.", } def visit_ImportFrom(self, node: ast.ImportFrom): diff --git a/tests/autofix_files/exception_suppress_context_manager.py b/tests/autofix_files/exception_suppress_context_manager.py index 3d00a643..6f85d2a6 100644 --- a/tests/autofix_files/exception_suppress_context_manager.py +++ b/tests/autofix_files/exception_suppress_context_manager.py @@ -87,11 +87,9 @@ async def foo_suppress_as(): # ASYNC910: 0, "exit", Statement('function definit # ############################### -# Module-level imports are resolved via their canonical qualname, so the -# `from contextlib import suppress` below is recognised as a suppressing CM -# even though the function is defined above the import statement (Python -# resolves the name at call time, so this more closely matches runtime -# semantics). +# Module-level imports are visible to any function body in the same file +# (Python resolves names at call time), so the `from contextlib import suppress` +# further down makes `suppress` a suppressing CM in this function too. async def foo_suppress_directly_imported_1(): # ASYNC910: 0, "exit", Statement('function definition', lineno) with suppress(): await foo() diff --git a/tests/eval_files/async110.py b/tests/eval_files/async110.py index 8527577b..ccda813c 100644 --- a/tests/eval_files/async110.py +++ b/tests/eval_files/async110.py @@ -38,7 +38,7 @@ async def foo(): await trio.sleep() await trio.sleep_until() - # aliased-import resolves to canonical qualname, so this now errors too + # `import trio as noerror` -- resolves to canonical `trio.sleep`. while ...: # error: 4, "trio" await noerror.sleep() diff --git a/tests/eval_files/async111.py b/tests/eval_files/async111.py index 88d24bf5..6160a81d 100644 --- a/tests/eval_files/async111.py +++ b/tests/eval_files/async111.py @@ -77,7 +77,7 @@ async def foo_2(): async with trio.open_process() as bar_2: nursery.start(bar_2) # safe -# aliased trio import is now resolved to canonical qualname, so this errors too +# `import trio as noterror` -- open_nursery resolves to canonical qualname with noterror.open_nursery() as nursery: with trio.open("") as bar: nursery.start(bar) # error: 22, line-1, line-2, "bar", "start" diff --git a/tests/eval_files/async112.py b/tests/eval_files/async112.py index 01a497c6..9e9559ec 100644 --- a/tests/eval_files/async112.py +++ b/tests/eval_files/async112.py @@ -86,7 +86,7 @@ async def foo_1(): await n.start(...) -# aliased trio import is now resolved to canonical qualname, so this errors too +# `import trio as noterror` -- open_nursery resolves to canonical qualname with noterror.open_nursery(...) as n: # error: 5, "n", "nursery" n.start(...) diff --git a/tests/eval_files/async112_canonical_qualname.py b/tests/eval_files/async112_canonical_qualname.py index 2e988b79..8c8ddad0 100644 --- a/tests/eval_files/async112_canonical_qualname.py +++ b/tests/eval_files/async112_canonical_qualname.py @@ -1,5 +1,5 @@ # Regression test for https://github.com/python-trio/flake8-async/issues/132: -# detection works regardless of how trio is imported. +# rules fire against the canonical qualname regardless of import style. # type: ignore # ASYNCIO_NO_ERROR # ARG --enable=ASYNC112 @@ -10,21 +10,17 @@ from trio import open_nursery as on -# `import trio as t` with t.open_nursery() as n: # error: 5, "n", "nursery" n.start(...) -# `from trio import open_nursery` with open_nursery() as n: # error: 5, "n", "nursery" n.start(...) -# `from trio import open_nursery as on` with on() as n: # error: 5, "n", "nursery" n.start_soon(...) -# canonical name still matches when chained through an ordinary `import trio` with trio.open_nursery() as n: # error: 5, "n", "nursery" n.start(...) diff --git a/tests/eval_files/async115.py b/tests/eval_files/async115.py index 70186366..fd5beb20 100644 --- a/tests/eval_files/async115.py +++ b/tests/eval_files/async115.py @@ -18,9 +18,9 @@ async def afoo(): trio.sleep(0) # error: 4, "trio" trio.sleep(1) - # don't error on unrelated sleeps + # unrelated sleeps don't match time.sleep(0) - # `from trio import sleep` resolves to canonical `trio.sleep`, so this now errors + # `from trio import sleep` -- resolves to canonical `trio.sleep` sleep(0) # error: 4, "trio" # in trio it's called 'seconds', in anyio it's 'delay', but diff --git a/tests/eval_files/async115_canonical_qualname.py b/tests/eval_files/async115_canonical_qualname.py index 7e7427f9..550c47cd 100644 --- a/tests/eval_files/async115_canonical_qualname.py +++ b/tests/eval_files/async115_canonical_qualname.py @@ -1,7 +1,5 @@ # Regression test for https://github.com/python-trio/flake8-async/issues/132: -# rules should fire against the canonical qualname regardless of how a symbol -# is imported -- plain `import trio`, `import ... as ...`, -# `from ... import ...`, or `from ... import ... as ...`. +# rules fire against the canonical qualname regardless of import style. # type: ignore # ASYNCIO_NO_ERROR - ASYNC115 is trio/anyio-only # ARG --enable=ASYNC115 @@ -14,25 +12,16 @@ async def afoo(): - # `import trio as t` await t.sleep(0) # error: 10, "trio" - - # `from trio import sleep` await sleep(0) # error: 10, "trio" - - # `from trio import sleep as nap` await nap(0) # error: 10, "trio" - # `import trio.lowlevel as ll` still resolves the canonical qualname - # (we only track this for its side-effect on ASYNC110 elsewhere, but it - # must not crash). + # `import trio.lowlevel as ll` and `from trio.lowlevel import ... as ...` + # are resolvable but aren't matched by ASYNC115 -- we're just asserting + # that resolution doesn't misfire. ll.checkpoint() - - # `from trio.lowlevel import checkpoint as cp` -- ASYNC115 doesn't match this - # particular qualname, but we just want to show resolution doesn't crash. cp() - # non-aliased local name that shadows an import should not falsely match: - # no import binds `sleep_2`, so `sleep_2(0)` is not flagged. + # a local name that shadows nothing imported must not match sleep_2 = lambda x: None sleep_2(0) diff --git a/tests/eval_files/async251.py b/tests/eval_files/async251.py index 00d2742e..b6b86cb3 100644 --- a/tests/eval_files/async251.py +++ b/tests/eval_files/async251.py @@ -6,5 +6,5 @@ async def foo(): time.sleep(5) # ASYNC251: 4, "trio" time.sleep(5) if 5 else time.sleep(5) # ASYNC251: 4, "trio" # ASYNC251: 28, "trio" - # `from time import sleep` resolves to canonical `time.sleep`, so this now errors + # `from time import sleep` -- resolves to canonical `time.sleep` sleep(5) # ASYNC251: 4, "trio" diff --git a/tests/eval_files/async300.py b/tests/eval_files/async300.py index 8fd3f920..13e3f5b2 100644 --- a/tests/eval_files/async300.py +++ b/tests/eval_files/async300.py @@ -67,8 +67,7 @@ def returner_list(): with asyncio.create_task(*args) as k: # type: ignore[attr-defined] # ASYNC300: 9 ... - # module-level imports are resolved to their canonical qualname, but - # function-local imports are not tracked (they'd leak into sibling scopes). + # function-local imports aren't tracked (so they don't leak to siblings) from asyncio import create_task create_task(*args) diff --git a/tests/eval_files/exception_suppress_context_manager.py b/tests/eval_files/exception_suppress_context_manager.py index d7114cf0..9871d526 100644 --- a/tests/eval_files/exception_suppress_context_manager.py +++ b/tests/eval_files/exception_suppress_context_manager.py @@ -80,11 +80,9 @@ async def foo_suppress_as(): # ASYNC910: 0, "exit", Statement('function definit # ############################### -# Module-level imports are resolved via their canonical qualname, so the -# `from contextlib import suppress` below is recognised as a suppressing CM -# even though the function is defined above the import statement (Python -# resolves the name at call time, so this more closely matches runtime -# semantics). +# Module-level imports are visible to any function body in the same file +# (Python resolves names at call time), so the `from contextlib import suppress` +# further down makes `suppress` a suppressing CM in this function too. async def foo_suppress_directly_imported_1(): # ASYNC910: 0, "exit", Statement('function definition', lineno) with suppress(): await foo() diff --git a/tests/test_config_and_args.py b/tests/test_config_and_args.py index aa2f3709..df485c92 100644 --- a/tests/test_config_and_args.py +++ b/tests/test_config_and_args.py @@ -493,8 +493,7 @@ def test_disable_noqa_ast( assert ( out == "./example.py:1:1: ASYNC106 trio should be imported with `import trio`" - " for consistency (historical reasons; no longer required for the linter" - " to work).\n" + " for consistency.\n" ) From 01bd1579fabc590c2d920eaeeb5f3b085859608d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 24 Apr 2026 19:36:07 +0000 Subject: [PATCH 3/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- flake8_async/visitors/helpers.py | 8 ++------ flake8_async/visitors/visitor102_120.py | 5 ++++- flake8_async/visitors/visitor111.py | 4 +--- flake8_async/visitors/visitor91x.py | 1 - flake8_async/visitors/visitor_utility.py | 4 +--- tests/eval_files/async112_canonical_qualname.py | 1 - tests/test_config_and_args.py | 3 +-- 7 files changed, 9 insertions(+), 17 deletions(-) diff --git a/flake8_async/visitors/helpers.py b/flake8_async/visitors/helpers.py index 089d2ecc..4294d0af 100644 --- a/flake8_async/visitors/helpers.py +++ b/flake8_async/visitors/helpers.py @@ -351,9 +351,7 @@ def resolve_canonical_ast(node: ast.AST, imports: Mapping[str, str]) -> str | No return None -def resolve_canonical_cst( - node: cst.CSTNode, imports: Mapping[str, str] -) -> str | None: +def resolve_canonical_cst(node: cst.CSTNode, imports: Mapping[str, str]) -> str | None: if isinstance(node, cst.Name): return imports.get(node.value, node.value) if isinstance(node, cst.Attribute): @@ -512,9 +510,7 @@ def with_has_call( if canonical is not None and canonical.startswith(f"{b}."): suffix = canonical[len(b) + 1 :] if suffix in names: - res_list.append( - MatchingCall(node=item.item, base=b, name=suffix) - ) + res_list.append(MatchingCall(node=item.item, base=b, name=suffix)) break return res_list diff --git a/flake8_async/visitors/visitor102_120.py b/flake8_async/visitors/visitor102_120.py index ad3e20c3..370e0cd7 100644 --- a/flake8_async/visitors/visitor102_120.py +++ b/flake8_async/visitors/visitor102_120.py @@ -205,7 +205,10 @@ def visit_ExceptHandler(self, node: ast.ExceptHandler): self._trio_context_managers = [] self._potential_120 = [] - if self.cancelled_caught or (res := critical_except(node, self.imports)) is None: + if ( + self.cancelled_caught + or (res := critical_except(node, self.imports)) is None + ): self._critical_scope = Statement("except", node.lineno, node.col_offset) else: self._critical_scope = res diff --git a/flake8_async/visitors/visitor111.py b/flake8_async/visitors/visitor111.py index 3cb2c6f0..b2a70d81 100644 --- a/flake8_async/visitors/visitor111.py +++ b/flake8_async/visitors/visitor111.py @@ -12,9 +12,7 @@ from collections.abc import Mapping -def is_nursery_like( - node: ast.expr, imports: Mapping[str, str] | None = None -) -> bool: +def is_nursery_like(node: ast.expr, imports: Mapping[str, str] | None = None) -> bool: return bool( get_matching_call(node, "open_nursery", base="trio", imports=imports) or get_matching_call(node, "create_task_group", base="anyio", imports=imports) diff --git a/flake8_async/visitors/visitor91x.py b/flake8_async/visitors/visitor91x.py index 9d0a2429..4cebdb1b 100644 --- a/flake8_async/visitors/visitor91x.py +++ b/flake8_async/visitors/visitor91x.py @@ -32,7 +32,6 @@ fnmatch_qualified_name_cst, func_has_decorator, get_matching_call_cst, - identifier_to_string, iter_guaranteed_once_cst, ) diff --git a/flake8_async/visitors/visitor_utility.py b/flake8_async/visitors/visitor_utility.py index f581027b..61437926 100644 --- a/flake8_async/visitors/visitor_utility.py +++ b/flake8_async/visitors/visitor_utility.py @@ -257,9 +257,7 @@ def visit_ImportFrom(self, node: cst.ImportFrom): def _enter_scope(self, node: cst.CSTNode) -> None: self._scope_depth += 1 - def _leave_scope( - self, original_node: cst.CSTNode, updated_node: Any - ) -> Any: + def _leave_scope(self, original_node: cst.CSTNode, updated_node: Any) -> Any: self._scope_depth -= 1 return updated_node diff --git a/tests/eval_files/async112_canonical_qualname.py b/tests/eval_files/async112_canonical_qualname.py index 8c8ddad0..6ca47f12 100644 --- a/tests/eval_files/async112_canonical_qualname.py +++ b/tests/eval_files/async112_canonical_qualname.py @@ -9,7 +9,6 @@ from trio import open_nursery from trio import open_nursery as on - with t.open_nursery() as n: # error: 5, "n", "nursery" n.start(...) diff --git a/tests/test_config_and_args.py b/tests/test_config_and_args.py index df485c92..5547e49d 100644 --- a/tests/test_config_and_args.py +++ b/tests/test_config_and_args.py @@ -491,8 +491,7 @@ def test_disable_noqa_ast( out, err = capsys.readouterr() assert not err assert ( - out - == "./example.py:1:1: ASYNC106 trio should be imported with `import trio`" + out == "./example.py:1:1: ASYNC106 trio should be imported with `import trio`" " for consistency.\n" ) From 617d87c0c131871e7d36b656570ec48610643676 Mon Sep 17 00:00:00 2001 From: Claude Date: Fri, 24 Apr 2026 20:21:25 +0000 Subject: [PATCH 4/6] Appease ruff and mypy - Move resolve_canonical_ast/cst into a dedicated _canonical module so the base-class methods can import them at the top level (PLC0415). - Flatten nested isinstance chain in get_matching_call_cst (SIM102). - Reformat the import-tracker example table so ruff stops flagging the continuation line as commented-out code (ERA001). - Inline the isinstance(ast.Call) check in critical_except so mypy's narrowing kicks in (attr-defined on "expr"). - Drop the now-unused identifier_to_string import from visitor91x. https://claude.ai/code/session_018Hc9rcA31SnXcN8Ee5vVwH --- flake8_async/visitors/_canonical.py | 44 ++++++++++++++++ flake8_async/visitors/flake8asyncvisitor.py | 5 +- flake8_async/visitors/helpers.py | 56 +++++++-------------- flake8_async/visitors/visitor_utility.py | 15 +++--- 4 files changed, 70 insertions(+), 50 deletions(-) create mode 100644 flake8_async/visitors/_canonical.py diff --git a/flake8_async/visitors/_canonical.py b/flake8_async/visitors/_canonical.py new file mode 100644 index 00000000..a176d9d3 --- /dev/null +++ b/flake8_async/visitors/_canonical.py @@ -0,0 +1,44 @@ +"""Canonical-qualname resolution for ast / cst nodes. + +Kept in its own module to avoid circular imports between +``flake8asyncvisitor`` (which exposes ``canonical_name`` on the base classes) +and ``helpers`` (which accepts an ``imports`` mapping for matcher functions). +""" + +from __future__ import annotations + +import ast +from typing import TYPE_CHECKING + +import libcst as cst + +if TYPE_CHECKING: + from collections.abc import Mapping + + +# Resolve a Name/Attribute/Call node to a dotted qualname via `imports` +# (local-name -> canonical dotted qualname). The root Name falls back to its own +# identifier, so `trio.open_nursery()` resolves to "trio.open_nursery" even when +# nothing was imported. Returns None for shapes we can't resolve (subscripts, etc.). +def resolve_canonical_ast(node: ast.AST, imports: Mapping[str, str]) -> str | None: + if isinstance(node, ast.Name): + return imports.get(node.id, node.id) + if isinstance(node, ast.Attribute): + prefix = resolve_canonical_ast(node.value, imports) + return None if prefix is None else f"{prefix}.{node.attr}" + if isinstance(node, ast.Call): + return resolve_canonical_ast(node.func, imports) + return None + + +def resolve_canonical_cst( + node: cst.CSTNode, imports: Mapping[str, str] +) -> str | None: + if isinstance(node, cst.Name): + return imports.get(node.value, node.value) + if isinstance(node, cst.Attribute): + prefix = resolve_canonical_cst(node.value, imports) + return None if prefix is None else f"{prefix}.{node.attr.value}" + if isinstance(node, cst.Call): + return resolve_canonical_cst(node.func, imports) + return None diff --git a/flake8_async/visitors/flake8asyncvisitor.py b/flake8_async/visitors/flake8asyncvisitor.py index b683a99b..22f907f1 100644 --- a/flake8_async/visitors/flake8asyncvisitor.py +++ b/flake8_async/visitors/flake8asyncvisitor.py @@ -10,6 +10,7 @@ from libcst.metadata import PositionProvider from ..base import Error, Statement, strip_error_subidentifier +from ._canonical import resolve_canonical_ast, resolve_canonical_cst if TYPE_CHECKING: from collections.abc import Iterable, Mapping @@ -58,8 +59,6 @@ def imports(self) -> dict[str, str]: return self.__state.imports def canonical_name(self, node: ast.AST) -> str | None: - from .helpers import resolve_canonical_ast - return resolve_canonical_ast(node, self.__state.imports) def visit(self, node: ast.AST): @@ -184,8 +183,6 @@ def imports(self) -> dict[str, str]: return self.__state.imports def canonical_name(self, node: cst.CSTNode) -> str | None: - from .helpers import resolve_canonical_cst - return resolve_canonical_cst(node, self.__state.imports) def get_state(self, *attrs: str, copy: bool = False) -> dict[str, Any]: diff --git a/flake8_async/visitors/helpers.py b/flake8_async/visitors/helpers.py index 4294d0af..9b4a5bb4 100644 --- a/flake8_async/visitors/helpers.py +++ b/flake8_async/visitors/helpers.py @@ -24,6 +24,7 @@ utility_visitors, utility_visitors_cst, ) +from ._canonical import resolve_canonical_ast, resolve_canonical_cst if TYPE_CHECKING: from collections.abc import Iterable, Iterator, Mapping, Sequence @@ -276,15 +277,17 @@ def has_exception(node: ast.expr) -> str | None: # The non-call spellings (`except anyio.get_cancelled_exc_class:`, or a # Call with arguments) are type-errors that critical_except intentionally # ignores, so only zero-arg calls count for get_cancelled_exc_class. - is_call = isinstance(node, ast.Call) - if is_call and (node.args or node.keywords): + if isinstance(node, ast.Call): + if node.args or node.keywords: + return None + canonical = resolve_canonical_ast(node.func, imports) + if canonical == "anyio.get_cancelled_exc_class": + return "anyio.get_cancelled_exc_class()" return None - canonical = resolve_canonical_ast(node.func if is_call else node, imports) - if not is_call and canonical == "trio.Cancelled": + canonical = resolve_canonical_ast(node, imports) + if canonical == "trio.Cancelled": return "trio.Cancelled" - if is_call and canonical == "anyio.get_cancelled_exc_class": - return "anyio.get_cancelled_exc_class()" - if not is_call and canonical in ( + if canonical in ( "asyncio.exceptions.CancelledError", "asyncio.CancelledError", ): @@ -336,32 +339,6 @@ def __str__(self) -> str: return self.base + "." + self.name -# Resolve a Name/Attribute/Call node to a dotted qualname via `imports` -# (local-name -> canonical dotted qualname). The root Name falls back to its own -# identifier, so `trio.open_nursery()` resolves to "trio.open_nursery" even when -# nothing was imported. Returns None for shapes we can't resolve (subscripts, etc.). -def resolve_canonical_ast(node: ast.AST, imports: Mapping[str, str]) -> str | None: - if isinstance(node, ast.Name): - return imports.get(node.id, node.id) - if isinstance(node, ast.Attribute): - prefix = resolve_canonical_ast(node.value, imports) - return None if prefix is None else f"{prefix}.{node.attr}" - if isinstance(node, ast.Call): - return resolve_canonical_ast(node.func, imports) - return None - - -def resolve_canonical_cst(node: cst.CSTNode, imports: Mapping[str, str]) -> str | None: - if isinstance(node, cst.Name): - return imports.get(node.value, node.value) - if isinstance(node, cst.Attribute): - prefix = resolve_canonical_cst(node.value, imports) - return None if prefix is None else f"{prefix}.{node.attr.value}" - if isinstance(node, cst.Call): - return resolve_canonical_cst(node.func, imports) - return None - - # convenience function used in a lot of visitors def get_matching_call( node: ast.AST, @@ -400,11 +377,14 @@ def get_matching_call_cst( base = (base,) if not isinstance(node, cst.Call): return None - if isinstance(node.func, cst.Attribute) and node.func.attr.value in names: - if isinstance(node.func.value, (cst.Name, cst.Attribute)): - attr_base = identifier_to_string(node.func.value) - if attr_base is not None and attr_base in base: - return MatchingCall(node, node.func.attr.value, attr_base) + if ( + isinstance(node.func, cst.Attribute) + and node.func.attr.value in names + and isinstance(node.func.value, (cst.Name, cst.Attribute)) + ): + attr_base = identifier_to_string(node.func.value) + if attr_base is not None and attr_base in base: + return MatchingCall(node, node.func.attr.value, attr_base) if imports is not None: canonical = resolve_canonical_cst(node.func, imports) for b in base: diff --git a/flake8_async/visitors/visitor_utility.py b/flake8_async/visitors/visitor_utility.py index 61437926..1011ffb5 100644 --- a/flake8_async/visitors/visitor_utility.py +++ b/flake8_async/visitors/visitor_utility.py @@ -156,14 +156,13 @@ def visit_Import(self, node: cst.Import): # Populate `imports` (local-name -> canonical dotted qualname) so helpers can -# resolve call-sites regardless of import style. Examples: -# import trio -> imports["trio"] = "trio" -# import trio as t -> imports["t"] = "trio" -# import trio.lowlevel -> imports["trio"] = "trio" -# imports["trio.lowlevel"] = "trio.lowlevel" -# import trio.lowlevel as ll -> imports["ll"] = "trio.lowlevel" -# from trio import sleep -> imports["sleep"] = "trio.sleep" -# from trio import sleep as s -> imports["s"] = "trio.sleep" +# resolve call-sites regardless of import style. Mappings produced: +# "import trio" => {"trio": "trio"} +# "import trio as t" => {"t": "trio"} +# "import trio.lowlevel" => {"trio": "trio", "trio.lowlevel": "trio.lowlevel"} +# "import trio.lowlevel as ll" => {"ll": "trio.lowlevel"} +# "from trio import sleep" => {"sleep": "trio.sleep"} +# "from trio import sleep as s" => {"s": "trio.sleep"} # # Only module-level imports are tracked: function-/class-local imports are # skipped to keep them out of sibling scopes. A full scope-aware resolver From 201484cb2fdf7238dcf143957279747ede9f69af Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 24 Apr 2026 20:23:53 +0000 Subject: [PATCH 5/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- flake8_async/visitors/_canonical.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/flake8_async/visitors/_canonical.py b/flake8_async/visitors/_canonical.py index a176d9d3..0a37a243 100644 --- a/flake8_async/visitors/_canonical.py +++ b/flake8_async/visitors/_canonical.py @@ -31,9 +31,7 @@ def resolve_canonical_ast(node: ast.AST, imports: Mapping[str, str]) -> str | No return None -def resolve_canonical_cst( - node: cst.CSTNode, imports: Mapping[str, str] -) -> str | None: +def resolve_canonical_cst(node: cst.CSTNode, imports: Mapping[str, str]) -> str | None: if isinstance(node, cst.Name): return imports.get(node.value, node.value) if isinstance(node, cst.Attribute): From b473acf63b7d33ecd1bfd051c780be5647848cee Mon Sep 17 00:00:00 2001 From: Claude Date: Fri, 24 Apr 2026 22:20:19 +0000 Subject: [PATCH 6/6] Add label to ASYNC106 docs so changelog cross-reference resolves The changelog entry `:ref:`ASYNC106 `` targets a rule that didn't have a Sphinx label, which made readthedocs fail with `undefined label: 'async106'` under `-W`. https://claude.ai/code/session_018Hc9rcA31SnXcN8Ee5vVwH --- docs/rules.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/rules.rst b/docs/rules.rst index 315b5c4d..fda9f79c 100644 --- a/docs/rules.rst +++ b/docs/rules.rst @@ -41,7 +41,7 @@ ASYNC105 : missing-await async trio function called without using ``await``. This is only supported with trio functions, but you can get similar functionality with a type-checker. -ASYNC106 : bad-async-library-import +_`ASYNC106` : bad-async-library-import trio/anyio/asyncio should be imported with ``import xxx`` for consistency. Opt-in style check; the linter resolves other import styles correctly.