diff --git a/docs/changelog.rst b/docs/changelog.rst index 4df8d04..a95b1fd 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -6,6 +6,8 @@ Changelog Unreleased ========== +- Rules resolve function/class references via the canonical qualname, so checks fire regardless of import style (``import trio``, ``import trio as t``, ``from trio import open_nursery [as on]``, …). Only module-level imports are tracked. `(issue #132) `_ +- :ref:`ASYNC106 ` is now disabled by default; re-enable it to enforce the ``import trio`` style. - Autofix for :ref:`ASYNC910 ` / :ref:`ASYNC911 ` no longer inserts checkpoints inside ``except`` clauses (which would trigger :ref:`ASYNC120 `); instead the checkpoint is added at the top of the function or of the enclosing loop. `(issue #403) `_ - :ref:`ASYNC910 ` and :ref:`ASYNC911 ` now accept ``__aenter__`` / ``__aexit__`` methods when the partner method provides the checkpoint, or when only one of the two is defined on a class that inherits from another class (charitably assuming the partner is inherited and contains a checkpoint). `(issue #441) `_ - :ref:`ASYNC300 ` no longer triggers when the result of ``asyncio.create_task()`` is returned from a function. `(issue #398) `_ diff --git a/docs/rules.rst b/docs/rules.rst index ba25950..fda9f79 100644 --- a/docs/rules.rst +++ b/docs/rules.rst @@ -41,8 +41,9 @@ ASYNC105 : missing-await async trio function called without using ``await``. This is only supported with trio functions, but you can get similar functionality with a type-checker. -ASYNC106 : bad-async-library-import - trio/anyio/asyncio must be imported with ``import xxx`` for the linter to work. +_`ASYNC106` : bad-async-library-import + trio/anyio/asyncio should be imported with ``import xxx`` for consistency. + Opt-in style check; the linter resolves other import styles correctly. ASYNC109 : async-function-with-timeout Async function definition with a ``timeout`` parameter. diff --git a/flake8_async/runner.py b/flake8_async/runner.py index 38ff3f5..fb8d7a4 100644 --- a/flake8_async/runner.py +++ b/flake8_async/runner.py @@ -37,6 +37,11 @@ class SharedState: library: tuple[str, ...] = () typed_calls: dict[str, str] = field(default_factory=dict[str, str]) variables: dict[str, str] = field(default_factory=dict[str, str]) + # Local name -> canonical dotted qualname, populated by VisitorImportTracker[_cst]. + # Helpers consult this so rules can match the canonical qualname regardless of + # how a symbol was imported (`import x`, `import x as y`, `from x import y`, + # `from x import y as z`). + imports: dict[str, str] = field(default_factory=dict[str, str]) class __CommonRunner: diff --git a/flake8_async/visitors/_canonical.py b/flake8_async/visitors/_canonical.py new file mode 100644 index 0000000..0a37a24 --- /dev/null +++ b/flake8_async/visitors/_canonical.py @@ -0,0 +1,42 @@ +"""Canonical-qualname resolution for ast / cst nodes. + +Kept in its own module to avoid circular imports between +``flake8asyncvisitor`` (which exposes ``canonical_name`` on the base classes) +and ``helpers`` (which accepts an ``imports`` mapping for matcher functions). +""" + +from __future__ import annotations + +import ast +from typing import TYPE_CHECKING + +import libcst as cst + +if TYPE_CHECKING: + from collections.abc import Mapping + + +# Resolve a Name/Attribute/Call node to a dotted qualname via `imports` +# (local-name -> canonical dotted qualname). The root Name falls back to its own +# identifier, so `trio.open_nursery()` resolves to "trio.open_nursery" even when +# nothing was imported. Returns None for shapes we can't resolve (subscripts, etc.). +def resolve_canonical_ast(node: ast.AST, imports: Mapping[str, str]) -> str | None: + if isinstance(node, ast.Name): + return imports.get(node.id, node.id) + if isinstance(node, ast.Attribute): + prefix = resolve_canonical_ast(node.value, imports) + return None if prefix is None else f"{prefix}.{node.attr}" + if isinstance(node, ast.Call): + return resolve_canonical_ast(node.func, imports) + return None + + +def resolve_canonical_cst(node: cst.CSTNode, imports: Mapping[str, str]) -> str | None: + if isinstance(node, cst.Name): + return imports.get(node.value, node.value) + if isinstance(node, cst.Attribute): + prefix = resolve_canonical_cst(node.value, imports) + return None if prefix is None else f"{prefix}.{node.attr.value}" + if isinstance(node, cst.Call): + return resolve_canonical_cst(node.func, imports) + return None diff --git a/flake8_async/visitors/flake8asyncvisitor.py b/flake8_async/visitors/flake8asyncvisitor.py index 7774281..22f907f 100644 --- a/flake8_async/visitors/flake8asyncvisitor.py +++ b/flake8_async/visitors/flake8asyncvisitor.py @@ -10,6 +10,7 @@ from libcst.metadata import PositionProvider from ..base import Error, Statement, strip_error_subidentifier +from ._canonical import resolve_canonical_ast, resolve_canonical_cst if TYPE_CHECKING: from collections.abc import Iterable, Mapping @@ -53,6 +54,13 @@ def variables(self, value: dict[str, str]) -> None: self.__state.variables.clear() self.__state.variables.update(value) + @property + def imports(self) -> dict[str, str]: + return self.__state.imports + + def canonical_name(self, node: ast.AST) -> str | None: + return resolve_canonical_ast(node, self.__state.imports) + def visit(self, node: ast.AST): """Visit a node.""" # construct visitor for this node type @@ -170,6 +178,13 @@ def __init__(self, shared_state: SharedState): self.options = self.__state.options self.noqas = self.__state.noqas + @property + def imports(self) -> dict[str, str]: + return self.__state.imports + + def canonical_name(self, node: cst.CSTNode) -> str | None: + return resolve_canonical_cst(node, self.__state.imports) + def get_state(self, *attrs: str, copy: bool = False) -> dict[str, Any]: # require attrs, since we inherit a *ton* of stuff which we don't want to copy assert attrs diff --git a/flake8_async/visitors/helpers.py b/flake8_async/visitors/helpers.py index c9d4f54..9b4a5bb 100644 --- a/flake8_async/visitors/helpers.py +++ b/flake8_async/visitors/helpers.py @@ -24,9 +24,10 @@ utility_visitors, utility_visitors_cst, ) +from ._canonical import resolve_canonical_ast, resolve_canonical_cst if TYPE_CHECKING: - from collections.abc import Iterable, Iterator, Sequence + from collections.abc import Iterable, Iterator, Mapping, Sequence from .flake8asyncvisitor import ( Flake8AsyncVisitor, @@ -101,15 +102,21 @@ def has_decorator(node: ast.FunctionDef | ast.AsyncFunctionDef, *names: str): # matches the fully qualified name against fnmatch pattern # used to match decorators and methods to user-supplied patterns # used in 910/911 and 200 -def fnmatch_qualified_name(name_list: list[ast.expr], *patterns: str) -> str | None: +def fnmatch_qualified_name( + name_list: Iterable[ast.expr], + *patterns: str, + imports: Mapping[str, str] | None = None, +) -> str | None: for name in name_list: if isinstance(name, ast.Call): name = name.func - qualified_name = ast.unparse(name) - + candidates = {ast.unparse(name)} + if imports is not None and (canonical := resolve_canonical_ast(name, imports)): + candidates.add(canonical) for pattern in patterns: # strip leading "@"s for when we're working with decorators - if fnmatch(qualified_name, pattern.lstrip("@")): + stripped = pattern.lstrip("@") + if any(fnmatch(c, stripped) for c in candidates): return pattern return None @@ -117,13 +124,22 @@ def fnmatch_qualified_name(name_list: list[ast.expr], *patterns: str) -> str | N def fnmatch_qualified_name_cst( name_list: Iterable[cst.Decorator | cst.Call | cst.Attribute | cst.Name], *patterns: str, + imports: Mapping[str, str] | None = None, ) -> str | None: for name in name_list: - qualified_name = get_full_name_for_node_or_raise(name) - + candidates = {get_full_name_for_node_or_raise(name)} + if imports is not None: + inner: cst.CSTNode = name + if isinstance(inner, cst.Decorator): + inner = inner.decorator + if isinstance(inner, cst.Call): + inner = inner.func + if (canonical := resolve_canonical_cst(inner, imports)) is not None: + candidates.add(canonical) for pattern in patterns: # strip leading "@"s for when we're working with decorators - if fnmatch(qualified_name, pattern.lstrip("@")): + stripped = pattern.lstrip("@") + if any(fnmatch(c, stripped) for c in candidates): return pattern return None @@ -240,7 +256,9 @@ def iter_guaranteed_once_cst(iterable: cst.BaseExpression) -> bool: # used in 102, 103 and 104 -def critical_except(node: ast.ExceptHandler) -> Statement | None: +def critical_except( + node: ast.ExceptHandler, imports: Mapping[str, str] | None = None +) -> Statement | None: def has_exception(node: ast.expr) -> str | None: name = ast.unparse(node) if name in ( @@ -253,6 +271,27 @@ def has_exception(node: ast.expr) -> str | None: "CancelledError", ): return name + if imports is None: + return None + # Resolve via canonical qualname for aliased / `from`-imported forms. + # The non-call spellings (`except anyio.get_cancelled_exc_class:`, or a + # Call with arguments) are type-errors that critical_except intentionally + # ignores, so only zero-arg calls count for get_cancelled_exc_class. + if isinstance(node, ast.Call): + if node.args or node.keywords: + return None + canonical = resolve_canonical_ast(node.func, imports) + if canonical == "anyio.get_cancelled_exc_class": + return "anyio.get_cancelled_exc_class()" + return None + canonical = resolve_canonical_ast(node, imports) + if canonical == "trio.Cancelled": + return "trio.Cancelled" + if canonical in ( + "asyncio.exceptions.CancelledError", + "asyncio.CancelledError", + ): + return "asyncio.exceptions.CancelledError" return None name: str | None = None @@ -302,36 +341,56 @@ def __str__(self) -> str: # convenience function used in a lot of visitors def get_matching_call( - node: ast.AST, *names: str, base: Iterable[str] = ("trio", "anyio") + node: ast.AST, + *names: str, + base: Iterable[str] = ("trio", "anyio"), + imports: Mapping[str, str] | None = None, ) -> MatchingCall[ast.Call] | None: if isinstance(base, str): base = (base,) + if not isinstance(node, ast.Call): + return None if ( - isinstance(node, ast.Call) - and isinstance(node.func, ast.Attribute) + isinstance(node.func, ast.Attribute) and isinstance(node.func.value, ast.Name) and node.func.value.id in base and node.func.attr in names ): return MatchingCall(node, node.func.attr, node.func.value.id) + if imports is not None: + canonical = resolve_canonical_ast(node.func, imports) + for b in base: + for n in names: + if canonical == f"{b}.{n}": + return MatchingCall(node, n, b) return None # ___ CST helpers ___ def get_matching_call_cst( - node: cst.CSTNode, *names: str, base: Iterable[str] = ("trio", "anyio") + node: cst.CSTNode, + *names: str, + base: Iterable[str] = ("trio", "anyio"), + imports: Mapping[str, str] | None = None, ) -> MatchingCall[cst.Call] | None: if isinstance(base, str): base = (base,) + if not isinstance(node, cst.Call): + return None if ( - isinstance(node, cst.Call) - and isinstance(node.func, cst.Attribute) + isinstance(node.func, cst.Attribute) and node.func.attr.value in names and isinstance(node.func.value, (cst.Name, cst.Attribute)) ): attr_base = identifier_to_string(node.func.value) if attr_base is not None and attr_base in base: return MatchingCall(node, node.func.attr.value, attr_base) + if imports is not None: + canonical = resolve_canonical_cst(node.func, imports) + for b in base: + for n in names: + if canonical == f"{b}.{n}": + return MatchingCall(node, n, b) return None @@ -377,12 +436,17 @@ def identifier_to_string(node: cst.CSTNode) -> str | None: def with_has_call( - node: cst.With, *names: str, base: Iterable[str] | str = ("trio", "anyio") + node: cst.With, + *names: str, + base: Iterable[str] | str = ("trio", "anyio"), + imports: Mapping[str, str] | None = None, ) -> list[MatchingCall[cst.Call]]: """Check if a with statement has a matching call, returning a list with matches. `names` specify the names of functions to match, `base` specifies the - library/module(s) the function must be in. + library/module(s) the function must be in. If `imports` is given, matches + are also made against the canonical qualname so aliased / `from`-imports + are detected. The list elements in the return value are named tuples with the matched node, base and function. @@ -393,19 +457,15 @@ def with_has_call( `foo.bar`, `foo.bee`, `a.b.c.bar`, and `a.b.c.bee`. """ - if isinstance(base, str): - base = (base,) + base_tuple = (base,) if isinstance(base, str) else tuple(base) # build matcher, using SaveMatchedNode to save the base and the function name. matcher = m.Call( func=m.Attribute( value=m.SaveMatchedNode( - m.OneOf(*(build_cst_matcher(b) for b in base)), name="base" - ), - attr=m.SaveMatchedNode( - oneof_names(*names), - name="function", + m.OneOf(*(build_cst_matcher(b) for b in base_tuple)), name="base" ), + attr=m.SaveMatchedNode(oneof_names(*names), name="function"), ) ) @@ -422,10 +482,22 @@ def with_has_call( node=item.item, base=base_string, name=res["function"].value ) ) + continue + if imports is None or not isinstance(item.item, cst.Call): + continue + canonical = resolve_canonical_cst(item.item.func, imports) + for b in base_tuple: + if canonical is not None and canonical.startswith(f"{b}."): + suffix = canonical[len(b) + 1 :] + if suffix in names: + res_list.append(MatchingCall(node=item.item, base=b, name=suffix)) + break return res_list -def calls_any_of(node: cst.With, *qualnames: str) -> bool: +def calls_any_of( + node: cst.With, *qualnames: str, imports: Mapping[str, str] | None = None +) -> bool: """Return True if `node` contains a withitem matching any of `qualnames`. Each `qualname` is a dotted string like ``"trio.open_nursery"`` or @@ -439,7 +511,8 @@ def calls_any_of(node: cst.With, *qualnames: str) -> bool: assert name, f"{qn!r} is not a dotted qualname" by_base[base].append(name) return any( - with_has_call(node, *names, base=base) for base, names in by_base.items() + with_has_call(node, *names, base=base, imports=imports) + for base, names in by_base.items() ) diff --git a/flake8_async/visitors/visitor101.py b/flake8_async/visitors/visitor101.py index fe03f94..0137f89 100644 --- a/flake8_async/visitors/visitor101.py +++ b/flake8_async/visitors/visitor101.py @@ -76,7 +76,7 @@ def visit_With(self, node: cst.With): self._yield_is_error = ( not self._safe_decorator and not self._yield_is_error - and calls_any_of(node, *_CANCEL_SCOPE_CMS) + and calls_any_of(node, *_CANCEL_SCOPE_CMS, imports=self.imports) ) def leave_With( diff --git a/flake8_async/visitors/visitor102_120.py b/flake8_async/visitors/visitor102_120.py index 83055e1..370e0cd 100644 --- a/flake8_async/visitors/visitor102_120.py +++ b/flake8_async/visitors/visitor102_120.py @@ -96,7 +96,10 @@ def is_safe_aclose_call(self, node: ast.Await) -> bool: return True # allow `trio.aclose_forcefully()` / `anyio.aclose_forcefully()`, # which are specifically designed for cleanup and cancel immediately by design - return get_matching_call(node.value, "aclose_forcefully") is not None + return ( + get_matching_call(node.value, "aclose_forcefully", imports=self.imports) + is not None + ) # trio.lowlevel.cancel_shielded_checkpoint (and the anyio equivalent) are # explicitly a schedule-but-not-cancel point, so they're safe to await @@ -106,7 +109,7 @@ def is_safe_shielded_checkpoint(self, node: ast.Await) -> bool: isinstance(node.value, ast.Call) and not node.value.args and not node.value.keywords - and ast.unparse(node.value.func) + and self.canonical_name(node.value.func) in ( "trio.lowlevel.cancel_shielded_checkpoint", "anyio.lowlevel.cancel_shielded_checkpoint", @@ -133,6 +136,7 @@ def visit_With(self, node: ast.With | ast.AsyncWith): "open_nursery", "create_task_group", *cancel_scope_names, + imports=self.imports, ) if call is None: continue @@ -151,9 +155,17 @@ def visit_AsyncWith(self, node: ast.AsyncWith): # asyncio.TaskGroup() appears to be a source of cancellation when exiting. for item in node.items: if not ( - get_matching_call(item.context_expr, "open_nursery", base="trio") + get_matching_call( + item.context_expr, + "open_nursery", + base="trio", + imports=self.imports, + ) or get_matching_call( - item.context_expr, "create_task_group", base="anyio" + item.context_expr, + "create_task_group", + base="anyio", + imports=self.imports, ) ): self.async_call_checker(node) @@ -193,7 +205,10 @@ def visit_ExceptHandler(self, node: ast.ExceptHandler): self._trio_context_managers = [] self._potential_120 = [] - if self.cancelled_caught or (res := critical_except(node)) is None: + if ( + self.cancelled_caught + or (res := critical_except(node, self.imports)) is None + ): self._critical_scope = Statement("except", node.lineno, node.col_offset) else: self._critical_scope = res diff --git a/flake8_async/visitors/visitor103_104.py b/flake8_async/visitors/visitor103_104.py index 3e234e7..ac63cca 100644 --- a/flake8_async/visitors/visitor103_104.py +++ b/flake8_async/visitors/visitor103_104.py @@ -75,7 +75,7 @@ def __init__(self, *args: Any, **kwargs: Any): # set self.unraised, and if it's still set after visiting child nodes # then there might be a code path that doesn't re-raise. def visit_ExceptHandler(self, node: ast.ExceptHandler): - marker = critical_except(node) + marker = critical_except(node, self.imports) if marker is None: # not a critical exception handler diff --git a/flake8_async/visitors/visitor105.py b/flake8_async/visitors/visitor105.py index c7e0ddf..60aea33 100644 --- a/flake8_async/visitors/visitor105.py +++ b/flake8_async/visitors/visitor105.py @@ -56,8 +56,11 @@ def visit_Call(self, node: ast.Call): if getattr(node, "awaited", False) or "trio" not in self.library: return - if (name := ast.unparse(node.func)) in trio_async_funcs: - self.error(node, name, "function") + canonical = self.canonical_name(node.func) + if canonical in trio_async_funcs: + # report the canonical qualname (rather than the user's local alias) + # so the message reads consistently. + self.error(node, canonical, "function") elif isinstance(node.func, ast.Attribute) and node.func.attr == "start": var = ast.unparse(node.func.value) diff --git a/flake8_async/visitors/visitor111.py b/flake8_async/visitors/visitor111.py index 72dca1f..b2a70d8 100644 --- a/flake8_async/visitors/visitor111.py +++ b/flake8_async/visitors/visitor111.py @@ -12,11 +12,11 @@ from collections.abc import Mapping -def is_nursery_like(node: ast.expr) -> bool: +def is_nursery_like(node: ast.expr, imports: Mapping[str, str] | None = None) -> bool: return bool( - get_matching_call(node, "open_nursery", base="trio") - or get_matching_call(node, "create_task_group", base="anyio") - or get_matching_call(node, "TaskGroup", base="asyncio") + get_matching_call(node, "open_nursery", base="trio", imports=imports) + or get_matching_call(node, "create_task_group", base="anyio", imports=imports) + or get_matching_call(node, "TaskGroup", base="asyncio", imports=imports) ) @@ -56,7 +56,7 @@ def visit_With(self, node: ast.With | ast.AsyncWith): self.TrioContextManager( item.context_expr.lineno, item.optional_vars.id, - is_nursery_like(item.context_expr), + is_nursery_like(item.context_expr, self.imports), ) ) diff --git a/flake8_async/visitors/visitor118.py b/flake8_async/visitors/visitor118.py index 4066f50..343a0f3 100644 --- a/flake8_async/visitors/visitor118.py +++ b/flake8_async/visitors/visitor118.py @@ -27,11 +27,17 @@ class Visitor118(Flake8AsyncVisitor): } def visit_Assign(self, node: ast.Assign | ast.AnnAssign): - if node.value is None: + value = node.value + if value is None: return - name = ast.unparse(node.value) - if re.fullmatch(r"(anyio.)?get_cancelled_exc_class(\(\))?", name): - self.error(node.value) + target = value.func if isinstance(value, ast.Call) else value + if self.canonical_name(target) == "anyio.get_cancelled_exc_class": + self.error(value) + return + # Fallback for code where anyio isn't importable (e.g. stubs or partial + # configs) but the name is still spelled out literally. + if re.fullmatch(r"(anyio.)?get_cancelled_exc_class(\(\))?", ast.unparse(value)): + self.error(value) visit_AnnAssign = visit_Assign diff --git a/flake8_async/visitors/visitor123.py b/flake8_async/visitors/visitor123.py index 32fd5a0..ba8475b 100644 --- a/flake8_async/visitors/visitor123.py +++ b/flake8_async/visitors/visitor123.py @@ -69,6 +69,8 @@ def visit_ExceptHandler(self, node: ast.ExceptHandler): "child_exception_names", copy=True, ) + # [Base]ExceptionGroup are builtins and almost always used unqualified, + # so a substring match on the literal source is sufficient. if node.name is None or ( not self.try_star and (node.type is None or "ExceptionGroup" not in ast.unparse(node.type)) diff --git a/flake8_async/visitors/visitor2xx.py b/flake8_async/visitors/visitor2xx.py index 91df95d..024dcf3 100644 --- a/flake8_async/visitors/visitor2xx.py +++ b/flake8_async/visitors/visitor2xx.py @@ -50,7 +50,9 @@ def visit_Call(self, node: ast.Call): def visit_blocking_call(self, node: ast.Call): blocking_calls = self.options.async200_blocking_calls - if key := fnmatch_qualified_name([node.func], *blocking_calls): + if key := fnmatch_qualified_name( + [node.func], *blocking_calls, imports=self.imports + ): self.error(node, key, blocking_calls[key]) @@ -66,37 +68,26 @@ class Visitor21X(Visitor200): ), } - def __init__(self, *args: Any, **kwargs: Any): - super().__init__(*args, **kwargs) - self.imports: set[str] = set() - - def visit_ImportFrom(self, node: ast.ImportFrom): - if node.module == "urllib3": - self.imports.add(node.module) - - def visit_Import(self, node: ast.Import): - for name in node.names: - if name.name == "urllib3": - # Could also save the name.asname for matching - self.imports.add(name.name) + def _urllib3_imported(self) -> bool: + return any( + v == "urllib3" or v.startswith("urllib3.") for v in self.imports.values() + ) def visit_blocking_call(self, node: ast.Call): - http_methods = { - "get", - "options", - "head", - "post", - "put", - "patch", - "delete", - } + http_methods = {"get", "options", "head", "post", "put", "patch", "delete"} func_name = ast.unparse(node.func) + canonical = self.canonical_name(node.func) or func_name for http_package in "requests", "httpx": - if get_matching_call(node, *http_methods | {"request"}, base=http_package): + if get_matching_call( + node, + *http_methods | {"request"}, + base=http_package, + imports=self.imports, + ): self.error(node, func_name, error_code="ASYNC210") return - if func_name in ( + if canonical in ( "urllib3.request", "urllib.request.urlopen", "request.urlopen", @@ -105,7 +96,7 @@ def visit_blocking_call(self, node: ast.Call): self.error(node, func_name, error_code="ASYNC210") elif ( - "urllib3" in self.imports + self._urllib3_imported() and isinstance(node.func, ast.Attribute) and node.func.attr == "request" and node.args @@ -209,22 +200,26 @@ def is_p_wait(arg: ast.expr) -> bool: "getstatusoutput", } + # Match against the canonical qualname, but report the user's literal spelling. func_name = ast.unparse(node.func) + canonical = self.canonical_name(node.func) or func_name error_code: str | None = None - if func_name in ("subprocess.Popen", "os.popen"): + if canonical in ("subprocess.Popen", "os.popen"): error_code = "ASYNC220" - elif func_name in ( + elif canonical in ( "os.system", "os.posix_spawn", "os.posix_spawnp", - ) or get_matching_call(node, *subprocess_calls, base="subprocess"): + ) or get_matching_call( + node, *subprocess_calls, base="subprocess", imports=self.imports + ): error_code = "ASYNC221" - elif re.fullmatch("os.wait([34]|(id)|(pid))?", func_name): + elif re.fullmatch("os.wait([34]|(id)|(pid))?", canonical): error_code = "ASYNC222" - elif re.fullmatch("os.spawn[vl]p?e?", func_name): + elif re.fullmatch("os.spawn[vl]p?e?", canonical): error_code = "ASYNC221" # if mode= is given and not [os.]P_WAIT: ASYNC220 @@ -265,8 +260,8 @@ class Visitor23X(Visitor200): } def visit_Call(self, node: ast.Call): - func_name = ast.unparse(node.func) - if re.fullmatch(r"(trio|anyio)\.wrap_file", func_name) and len(node.args) == 1: + canonical = self.canonical_name(node.func) + if canonical in ("trio.wrap_file", "anyio.wrap_file") and len(node.args) == 1: setattr(node.args[0], "wrapped", True) # noqa: B010 super().visit_Call(node) @@ -274,9 +269,10 @@ def visit_blocking_call(self, node: ast.Call): if getattr(node, "wrapped", False): return func_name = ast.unparse(node.func) - if func_name in ("open", "io.open", "io.open_code"): + canonical = self.canonical_name(node.func) or func_name + if canonical in ("builtins.open", "open", "io.open", "io.open_code"): error_code = "ASYNC230" - elif func_name == "os.fdopen": + elif canonical == "os.fdopen": error_code = "ASYNC231" else: return @@ -381,9 +377,10 @@ def visit_Call(self, node: ast.Call): return error_code = "ASYNC240_asyncio" if self.library == ("asyncio",) else "ASYNC240" func_name = ast.unparse(node.func) + canonical = self.canonical_name(node.func) or func_name if func_name in self.imports_from_ospath: self.error(node, func_name, self.library_str, error_code=error_code) - elif (m := re.fullmatch(r"os\.path\.(?P.*)", func_name)) and m.group( + elif (m := re.fullmatch(r"os\.path\.(?P.*)", canonical)) and m.group( "func" ) in self.os_funcs: self.error(node, m.group("func"), self.library_str, error_code=error_code) @@ -410,13 +407,14 @@ def visit_Call(self, node: ast.Call): if not self.async_function: return func_name = ast.unparse(node.func) - if func_name == "input": + canonical = self.canonical_name(node.func) or func_name + if canonical in ("input", "builtins.input"): error_code = "ASYNC250" if len(self.library) == 1: msg_param = wrappers[self.library_str] else: msg_param = "/".join(wrappers[lib] for lib in self.library) - elif func_name == "time.sleep": + elif canonical == "time.sleep": error_code = "ASYNC251" msg_param = self.library_str else: diff --git a/flake8_async/visitors/visitor91x.py b/flake8_async/visitors/visitor91x.py index 123fba7..4cebdb1 100644 --- a/flake8_async/visitors/visitor91x.py +++ b/flake8_async/visitors/visitor91x.py @@ -32,7 +32,6 @@ fnmatch_qualified_name_cst, func_has_decorator, get_matching_call_cst, - identifier_to_string, iter_guaranteed_once_cst, ) @@ -131,7 +130,9 @@ def leave_FunctionDef( ) # ignore functions with no_checkpoint_warning_decorators and not fnmatch_qualified_name_cst( - original_node.decorators, *self.options.no_checkpoint_warning_decorators + original_node.decorators, + *self.options.no_checkpoint_warning_decorators, + imports=self.imports, ) ): self.error(original_node) @@ -649,7 +650,9 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> bool: self.async_function = ( node.asynchronous is not None and not fnmatch_qualified_name_cst( - node.decorators, *self.options.no_checkpoint_warning_decorators + node.decorators, + *self.options.no_checkpoint_warning_decorators, + imports=self.imports, ) ) # only visit subnodes if there is an async function defined inside @@ -864,6 +867,7 @@ def _is_exception_suppressing_context_manager(self, node: cst.With) -> bool: "contextlib.suppress", *self.suppress_imported_as, *self.options.exception_suppress_context_managers, + imports=self.imports, ) is not None ) @@ -881,7 +885,7 @@ def _checkpoint_with(self, node: cst.With, entry: bool): return for item in node.items: - if isinstance(item.item, cst.Call) and identifier_to_string( + if isinstance(item.item, cst.Call) and self.canonical_name( item.item.func ) in ( "trio.open_nursery", @@ -919,7 +923,10 @@ def visit_With_body(self, node: cst.With): for withitem in node.items: self.has_checkpoint_stack.append(ContextManager()) if get_matching_call_cst( - withitem.item, "open_nursery", "create_task_group" + withitem.item, + "open_nursery", + "create_task_group", + imports=self.imports, ): if withitem.asname is not None and isinstance( withitem.asname.name, cst.Name @@ -945,6 +952,7 @@ def visit_With_body(self, node: cst.With): "contextlib.suppress", *self.suppress_imported_as, *self.options.exception_suppress_context_managers, + imports=self.imports, ) is not None ): @@ -955,12 +963,15 @@ def visit_With_body(self, node: cst.With): continue if res := ( - get_matching_call_cst(withitem.item, *cancel_scope_names) + get_matching_call_cst( + withitem.item, *cancel_scope_names, imports=self.imports + ) or get_matching_call_cst( withitem.item, "timeout", "timeout_at", base="asyncio", + imports=self.imports, ) ): # typing issue: https://github.com/Instagram/LibCST/issues/1107 diff --git a/flake8_async/visitors/visitor_utility.py b/flake8_async/visitors/visitor_utility.py index 1e70785..1011ffb 100644 --- a/flake8_async/visitors/visitor_utility.py +++ b/flake8_async/visitors/visitor_utility.py @@ -7,16 +7,16 @@ import re from typing import TYPE_CHECKING, Any, cast +import libcst as cst import libcst.matchers as m from libcst.metadata import PositionProvider from .flake8asyncvisitor import Flake8AsyncVisitor, Flake8AsyncVisitor_cst -from .helpers import utility_visitor, utility_visitor_cst +from .helpers import identifier_to_string, utility_visitor, utility_visitor_cst if TYPE_CHECKING: from re import Match - import libcst as cst from libcst.metadata import CodeRange @@ -155,6 +155,119 @@ def visit_Import(self, node: cst.Import): self.add_library(alias.name.value) +# Populate `imports` (local-name -> canonical dotted qualname) so helpers can +# resolve call-sites regardless of import style. Mappings produced: +# "import trio" => {"trio": "trio"} +# "import trio as t" => {"t": "trio"} +# "import trio.lowlevel" => {"trio": "trio", "trio.lowlevel": "trio.lowlevel"} +# "import trio.lowlevel as ll" => {"ll": "trio.lowlevel"} +# "from trio import sleep" => {"sleep": "trio.sleep"} +# "from trio import sleep as s" => {"s": "trio.sleep"} +# +# Only module-level imports are tracked: function-/class-local imports are +# skipped to keep them out of sibling scopes. A full scope-aware resolver +# would also need to know the call site's position, which isn't justified +# given how uncommon local imports of async APIs are. +@utility_visitor +class VisitorImportTracker(Flake8AsyncVisitor): + def __init__(self, *args: Any, **kwargs: Any): + super().__init__(*args, **kwargs) + self._scope_depth = 0 + + def _at_module_level(self) -> bool: + return self._scope_depth == 0 + + def visit_Import(self, node: ast.Import): + if not self._at_module_level(): + return + for alias in node.names: + if alias.asname is not None: + self.imports[alias.asname] = alias.name + continue + # `import a.b.c` binds `a` and also resolves `a.b.c.` through + # the Attribute chain, so we record both. + top = alias.name.partition(".")[0] + self.imports.setdefault(top, top) + self.imports.setdefault(alias.name, alias.name) + + def visit_ImportFrom(self, node: ast.ImportFrom): + if node.module is None or node.level or not self._at_module_level(): + return + for alias in node.names: + if alias.name == "*": + continue + local = alias.asname if alias.asname is not None else alias.name + self.imports[local] = f"{node.module}.{alias.name}" + + def _enter_scope(self, node: ast.AST): + self.save_state(node, "_scope_depth") + self._scope_depth += 1 + + visit_FunctionDef = _enter_scope + visit_AsyncFunctionDef = _enter_scope + visit_ClassDef = _enter_scope + visit_Lambda = _enter_scope + + +@utility_visitor_cst +class VisitorImportTracker_cst(Flake8AsyncVisitor_cst): + def __init__(self, *args: Any, **kwargs: Any): + super().__init__(*args, **kwargs) + self._scope_depth = 0 + + def _at_module_level(self) -> bool: + return self._scope_depth == 0 + + def visit_Import(self, node: cst.Import): + if not self._at_module_level(): + return + for alias in node.names: + full_name = identifier_to_string(alias.name) + if full_name is None: + continue + if alias.asname is not None and isinstance(alias.asname.name, cst.Name): + self.imports[alias.asname.name.value] = full_name + continue + top = full_name.partition(".")[0] + self.imports.setdefault(top, top) + self.imports.setdefault(full_name, full_name) + + def visit_ImportFrom(self, node: cst.ImportFrom): + if ( + node.module is None + or node.relative + or isinstance(node.names, cst.ImportStar) + or not self._at_module_level() + ): + return + module = identifier_to_string(node.module) + if module is None: + return + for alias in node.names: + name = identifier_to_string(alias.name) + if name is None: + continue + if alias.asname is not None and isinstance(alias.asname.name, cst.Name): + local = alias.asname.name.value + else: + local = name + self.imports[local] = f"{module}.{name}" + + def _enter_scope(self, node: cst.CSTNode) -> None: + self._scope_depth += 1 + + def _leave_scope(self, original_node: cst.CSTNode, updated_node: Any) -> Any: + self._scope_depth -= 1 + return updated_node + + visit_FunctionDef = _enter_scope + visit_ClassDef = _enter_scope + visit_Lambda = _enter_scope + leave_FunctionDef = _leave_scope + leave_ClassDef = _leave_scope + leave_Lambda = _leave_scope + + # taken from # https://github.com/PyCQA/flake8/blob/d016204366a22d382b5b56dc14b6cbff28ce929e/src/flake8/defaults.py#L27 NOQA_INLINE_REGEXP = re.compile( diff --git a/flake8_async/visitors/visitors.py b/flake8_async/visitors/visitors.py index afc1a30..21ba57e 100644 --- a/flake8_async/visitors/visitors.py +++ b/flake8_async/visitors/visitors.py @@ -13,7 +13,6 @@ error_class_cst, get_matching_call, has_decorator, - identifier_to_string, ) if TYPE_CHECKING: @@ -25,9 +24,11 @@ @error_class +@disabled_by_default class Visitor106(Flake8AsyncVisitor): + # Opt-in style check; other rules already handle all import styles. error_codes: Mapping[str, str] = { - "ASYNC106": "{0} must be imported with `import {0}` for the linter to work.", + "ASYNC106": "{0} should be imported with `import {0}` for consistency.", } def visit_ImportFrom(self, node: ast.ImportFrom): @@ -76,16 +77,16 @@ def visit_While(self, node: ast.While): and isinstance(node.body[0], ast.Expr) and isinstance(node.body[0].value, ast.Await) and ( - get_matching_call(node.body[0].value.value, "sleep", "sleep_until") + get_matching_call( + node.body[0].value.value, + "sleep", + "sleep_until", + imports=self.imports, + ) or ( - # get_matching_call doesn't (currently) support checking for trio.x.y isinstance(call := node.body[0].value.value, ast.Call) - and isinstance(call.func, ast.Attribute) - and call.func.attr == "checkpoint" - and isinstance(call.func.value, ast.Attribute) - and call.func.value.attr == "lowlevel" - and isinstance(call.func.value.value, ast.Name) - and call.func.value.value.id in ("trio", "anyio") + and self.canonical_name(call.func) + in ("trio.lowlevel.checkpoint", "anyio.lowlevel.checkpoint") ) ) ): @@ -116,15 +117,22 @@ def visit_With(self, node: ast.With | ast.AsyncWith): start_methods: tuple[str, ...] = ("start", "start_soon") # check for trio.open_nursery and anyio.create_task_group - if get_matching_call(item.context_expr, "open_nursery", base="trio"): + if get_matching_call( + item.context_expr, "open_nursery", base="trio", imports=self.imports + ): nursery_type = "nursery" elif get_matching_call( - item.context_expr, "create_task_group", base="anyio" + item.context_expr, + "create_task_group", + base="anyio", + imports=self.imports, ): nursery_type = "task group" # check for asyncio.TaskGroup - elif get_matching_call(item.context_expr, "TaskGroup", base="asyncio"): + elif get_matching_call( + item.context_expr, "TaskGroup", base="asyncio", imports=self.imports + ): nursery_type = "task group" start_methods = ("create_task",) else: @@ -138,6 +146,8 @@ def visit_With(self, node: ast.With | ast.AsyncWith): body_call = cast("ast.Call", body_call) if ( + # start[_soon] is called on the nursery/taskgroup variable, + # not a canonically-resolved name, so we don't pass imports. get_matching_call(body_call, *start_methods, base=var_name) # check for presence of as parameter and not any( @@ -304,7 +314,7 @@ class Visitor115(Flake8AsyncVisitor): } def visit_Call(self, node: ast.Call): - if not (m := get_matching_call(node, "sleep")): + if not (m := get_matching_call(node, "sleep", imports=self.imports)): return if ( len(node.args) == 1 @@ -328,7 +338,7 @@ class Visitor116(Flake8AsyncVisitor): } def visit_Call(self, node: ast.Call): - if not (m := get_matching_call(node, "sleep")): + if not (m := get_matching_call(node, "sleep", imports=self.imports)): return if len(node.args) == 1: arg = node.args[0] @@ -425,11 +435,18 @@ def visit_AsyncWith(self, node: ast.AsyncWith): self.save_state(node, "unsafe_stack", copy=True) for item in node.items: - if get_matching_call(item.context_expr, "open_nursery", base="trio"): + if get_matching_call( + item.context_expr, "open_nursery", base="trio", imports=self.imports + ): self.unsafe_stack.append("nursery") elif get_matching_call( - item.context_expr, "create_task_group", base="anyio" - ) or get_matching_call(item.context_expr, "TaskGroup", base="asyncio"): + item.context_expr, + "create_task_group", + base="anyio", + imports=self.imports, + ) or get_matching_call( + item.context_expr, "TaskGroup", base="asyncio", imports=self.imports + ): self.unsafe_stack.append("task group") def visit_While(self, node: ast.While | ast.For | ast.AsyncFor): @@ -484,7 +501,11 @@ def visit_withitem(self, node: ast.withitem): def visit_Call(self, node: ast.Call): if not self.in_withitem and ( match := get_matching_call( - node, "fail_after", "move_on_after", base=("trio", "anyio") + node, + "fail_after", + "move_on_after", + base=("trio", "anyio"), + imports=self.imports, ) ): self.error(node, str(match)) @@ -510,7 +531,12 @@ def is_constant(value: ast.expr) -> bool: return False match = get_matching_call( - node, "fail_at", "move_on_at", "CancelScope", base=("trio", "anyio") + node, + "fail_at", + "move_on_at", + "CancelScope", + base=("trio", "anyio"), + imports=self.imports, ) if match is None: return @@ -548,7 +574,8 @@ def base_name(base: ast.expr) -> str: # strip generic subscripts like `ExceptionGroup[Foo]` if isinstance(base, ast.Subscript): base = base.value - unparsed = ast.unparse(base) + canonical = self.canonical_name(base) + unparsed = canonical if canonical is not None else ast.unparse(base) return unparsed.rsplit(".", 1)[-1] if not any( @@ -586,7 +613,7 @@ def visit_CompIf(self, node: cst.CSTNode): def visit_Call(self, node: cst.Call): if ( - identifier_to_string(node.func) == "asyncio.create_task" + self.canonical_name(node.func) == "asyncio.create_task" and not self.safe_to_create_task ): self.error(node) diff --git a/tests/autofix_files/exception_suppress_context_manager.py b/tests/autofix_files/exception_suppress_context_manager.py index 0704da2..6f85d2a 100644 --- a/tests/autofix_files/exception_suppress_context_manager.py +++ b/tests/autofix_files/exception_suppress_context_manager.py @@ -87,10 +87,13 @@ async def foo_suppress_as(): # ASYNC910: 0, "exit", Statement('function definit # ############################### -# not enabled unless it's imported from contextlib -async def foo_suppress_directly_imported_1(): +# Module-level imports are visible to any function body in the same file +# (Python resolves names at call time), so the `from contextlib import suppress` +# further down makes `suppress` a suppressing CM in this function too. +async def foo_suppress_directly_imported_1(): # ASYNC910: 0, "exit", Statement('function definition', lineno) with suppress(): await foo() + await trio.lowlevel.checkpoint() from contextlib import suppress diff --git a/tests/autofix_files/exception_suppress_context_manager.py.diff b/tests/autofix_files/exception_suppress_context_manager.py.diff index 0de6726..713aa51 100644 --- a/tests/autofix_files/exception_suppress_context_manager.py.diff +++ b/tests/autofix_files/exception_suppress_context_manager.py.diff @@ -50,6 +50,14 @@ # ############################### +@@ x,6 x,7 @@ + async def foo_suppress_directly_imported_1(): # ASYNC910: 0, "exit", Statement('function definition', lineno) + with suppress(): + await foo() ++ await trio.lowlevel.checkpoint() + + + from contextlib import suppress @@ x,6 x,7 @@ async def foo_suppress_directly_imported_2(): # ASYNC910: 0, "exit", Statement('function definition', lineno) with suppress(): diff --git a/tests/eval_files/async110.py b/tests/eval_files/async110.py index 22df373..ccda813 100644 --- a/tests/eval_files/async110.py +++ b/tests/eval_files/async110.py @@ -38,8 +38,8 @@ async def foo(): await trio.sleep() await trio.sleep_until() - # check library name - while ...: + # `import trio as noerror` -- resolves to canonical `trio.sleep`. + while ...: # error: 4, "trio" await noerror.sleep() async def sleep(): ... diff --git a/tests/eval_files/async111.py b/tests/eval_files/async111.py index e40d292..6160a81 100644 --- a/tests/eval_files/async111.py +++ b/tests/eval_files/async111.py @@ -77,10 +77,10 @@ async def foo_2(): async with trio.open_process() as bar_2: nursery.start(bar_2) # safe -# specifically check for *trio*.open_nursery +# `import trio as noterror` -- open_nursery resolves to canonical qualname with noterror.open_nursery() as nursery: with trio.open("") as bar: - nursery.start(bar) + nursery.start(bar) # error: 22, line-1, line-2, "bar", "start" # specifically check for trio.*open_nursery* with trio.open_nurse() as nursery: diff --git a/tests/eval_files/async112.py b/tests/eval_files/async112.py index 8657973..9e9559e 100644 --- a/tests/eval_files/async112.py +++ b/tests/eval_files/async112.py @@ -86,8 +86,8 @@ async def foo_1(): await n.start(...) -# not *trio*.open_nursery -with noterror.open_nursery(...) as n: +# `import trio as noterror` -- open_nursery resolves to canonical qualname +with noterror.open_nursery(...) as n: # error: 5, "n", "nursery" n.start(...) # not trio.*open_nursery* diff --git a/tests/eval_files/async112_canonical_qualname.py b/tests/eval_files/async112_canonical_qualname.py new file mode 100644 index 0000000..6ca47f1 --- /dev/null +++ b/tests/eval_files/async112_canonical_qualname.py @@ -0,0 +1,25 @@ +# Regression test for https://github.com/python-trio/flake8-async/issues/132: +# rules fire against the canonical qualname regardless of import style. +# type: ignore +# ASYNCIO_NO_ERROR +# ARG --enable=ASYNC112 + +import trio +import trio as t +from trio import open_nursery +from trio import open_nursery as on + +with t.open_nursery() as n: # error: 5, "n", "nursery" + n.start(...) + + +with open_nursery() as n: # error: 5, "n", "nursery" + n.start(...) + + +with on() as n: # error: 5, "n", "nursery" + n.start_soon(...) + + +with trio.open_nursery() as n: # error: 5, "n", "nursery" + n.start(...) diff --git a/tests/eval_files/async115.py b/tests/eval_files/async115.py index 05dd758..fd5beb2 100644 --- a/tests/eval_files/async115.py +++ b/tests/eval_files/async115.py @@ -18,9 +18,10 @@ async def afoo(): trio.sleep(0) # error: 4, "trio" trio.sleep(1) - # don't error on other sleeps + # unrelated sleeps don't match time.sleep(0) - sleep(0) + # `from trio import sleep` -- resolves to canonical `trio.sleep` + sleep(0) # error: 4, "trio" # in trio it's called 'seconds', in anyio it's 'delay', but # we don't care about the kwarg name. #382 diff --git a/tests/eval_files/async115_canonical_qualname.py b/tests/eval_files/async115_canonical_qualname.py new file mode 100644 index 0000000..550c47c --- /dev/null +++ b/tests/eval_files/async115_canonical_qualname.py @@ -0,0 +1,27 @@ +# Regression test for https://github.com/python-trio/flake8-async/issues/132: +# rules fire against the canonical qualname regardless of import style. +# type: ignore +# ASYNCIO_NO_ERROR - ASYNC115 is trio/anyio-only +# ARG --enable=ASYNC115 + +import trio as t +import trio.lowlevel as ll +from trio import sleep +from trio import sleep as nap +from trio.lowlevel import checkpoint as cp + + +async def afoo(): + await t.sleep(0) # error: 10, "trio" + await sleep(0) # error: 10, "trio" + await nap(0) # error: 10, "trio" + + # `import trio.lowlevel as ll` and `from trio.lowlevel import ... as ...` + # are resolvable but aren't matched by ASYNC115 -- we're just asserting + # that resolution doesn't misfire. + ll.checkpoint() + cp() + + # a local name that shadows nothing imported must not match + sleep_2 = lambda x: None + sleep_2(0) diff --git a/tests/eval_files/async251.py b/tests/eval_files/async251.py index 9da4a10..b6b86cb 100644 --- a/tests/eval_files/async251.py +++ b/tests/eval_files/async251.py @@ -6,6 +6,5 @@ async def foo(): time.sleep(5) # ASYNC251: 4, "trio" time.sleep(5) if 5 else time.sleep(5) # ASYNC251: 4, "trio" # ASYNC251: 28, "trio" - # Not handled due to difficulty tracking imports and not wanting to trigger - # false positives. But could definitely be handled by ruff et al. - sleep(5) + # `from time import sleep` -- resolves to canonical `time.sleep` + sleep(5) # ASYNC251: 4, "trio" diff --git a/tests/eval_files/async300.py b/tests/eval_files/async300.py index 87ad535..13e3f5b 100644 --- a/tests/eval_files/async300.py +++ b/tests/eval_files/async300.py @@ -67,7 +67,7 @@ def returner_list(): with asyncio.create_task(*args) as k: # type: ignore[attr-defined] # ASYNC300: 9 ... - # import aliasing is not supported (this would raise ASYNC106 bad-async-library-import) + # function-local imports aren't tracked (so they don't leak to siblings) from asyncio import create_task create_task(*args) diff --git a/tests/eval_files/exception_suppress_context_manager.py b/tests/eval_files/exception_suppress_context_manager.py index 4b809d7..9871d52 100644 --- a/tests/eval_files/exception_suppress_context_manager.py +++ b/tests/eval_files/exception_suppress_context_manager.py @@ -80,8 +80,10 @@ async def foo_suppress_as(): # ASYNC910: 0, "exit", Statement('function definit # ############################### -# not enabled unless it's imported from contextlib -async def foo_suppress_directly_imported_1(): +# Module-level imports are visible to any function body in the same file +# (Python resolves names at call time), so the `from contextlib import suppress` +# further down makes `suppress` a suppressing CM in this function too. +async def foo_suppress_directly_imported_1(): # ASYNC910: 0, "exit", Statement('function definition', lineno) with suppress(): await foo() diff --git a/tests/test_config_and_args.py b/tests/test_config_and_args.py index 4e93c71..5547e49 100644 --- a/tests/test_config_and_args.py +++ b/tests/test_config_and_args.py @@ -491,9 +491,8 @@ def test_disable_noqa_ast( out, err = capsys.readouterr() assert not err assert ( - out - == "./example.py:1:1: ASYNC106 trio must be imported with `import trio` for the" - " linter to work.\n" + out == "./example.py:1:1: ASYNC106 trio should be imported with `import trio`" + " for consistency.\n" )