Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ Changelog

Unreleased
==========
- Rules resolve function/class references via the canonical qualname, so checks fire regardless of import style (``import trio``, ``import trio as t``, ``from trio import open_nursery [as on]``, …). Only module-level imports are tracked. `(issue #132) <https://github.com/python-trio/flake8-async/issues/132>`_
- :ref:`ASYNC106 <async106>` is now disabled by default; re-enable it to enforce the ``import trio`` style.
- Autofix for :ref:`ASYNC910 <async910>` / :ref:`ASYNC911 <async911>` no longer inserts checkpoints inside ``except`` clauses (which would trigger :ref:`ASYNC120 <async120>`); instead the checkpoint is added at the top of the function or of the enclosing loop. `(issue #403) <https://github.com/python-trio/flake8-async/issues/403>`_
- :ref:`ASYNC910 <async910>` and :ref:`ASYNC911 <async911>` now accept ``__aenter__`` / ``__aexit__`` methods when the partner method provides the checkpoint, or when only one of the two is defined on a class that inherits from another class (charitably assuming the partner is inherited and contains a checkpoint). `(issue #441) <https://github.com/python-trio/flake8-async/issues/441>`_
- :ref:`ASYNC300 <async300>` no longer triggers when the result of ``asyncio.create_task()`` is returned from a function. `(issue #398) <https://github.com/python-trio/flake8-async/issues/398>`_
Expand Down
5 changes: 3 additions & 2 deletions docs/rules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,9 @@ ASYNC105 : missing-await
async trio function called without using ``await``.
This is only supported with trio functions, but you can get similar functionality with a type-checker.

ASYNC106 : bad-async-library-import
trio/anyio/asyncio must be imported with ``import xxx`` for the linter to work.
_`ASYNC106` : bad-async-library-import
trio/anyio/asyncio should be imported with ``import xxx`` for consistency.
Opt-in style check; the linter resolves other import styles correctly.

ASYNC109 : async-function-with-timeout
Async function definition with a ``timeout`` parameter.
Expand Down
5 changes: 5 additions & 0 deletions flake8_async/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@ class SharedState:
library: tuple[str, ...] = ()
typed_calls: dict[str, str] = field(default_factory=dict[str, str])
variables: dict[str, str] = field(default_factory=dict[str, str])
# Local name -> canonical dotted qualname, populated by VisitorImportTracker[_cst].
# Helpers consult this so rules can match the canonical qualname regardless of
# how a symbol was imported (`import x`, `import x as y`, `from x import y`,
# `from x import y as z`).
imports: dict[str, str] = field(default_factory=dict[str, str])


class __CommonRunner:
Expand Down
42 changes: 42 additions & 0 deletions flake8_async/visitors/_canonical.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
"""Canonical-qualname resolution for ast / cst nodes.

Kept in its own module to avoid circular imports between
``flake8asyncvisitor`` (which exposes ``canonical_name`` on the base classes)
and ``helpers`` (which accepts an ``imports`` mapping for matcher functions).
"""

from __future__ import annotations

import ast
from typing import TYPE_CHECKING

import libcst as cst

if TYPE_CHECKING:
from collections.abc import Mapping


# Resolve a Name/Attribute/Call node to a dotted qualname via `imports`
# (local-name -> canonical dotted qualname). The root Name falls back to its own
# identifier, so `trio.open_nursery()` resolves to "trio.open_nursery" even when
# nothing was imported. Returns None for shapes we can't resolve (subscripts, etc.).
def resolve_canonical_ast(node: ast.AST, imports: Mapping[str, str]) -> str | None:
if isinstance(node, ast.Name):
return imports.get(node.id, node.id)
if isinstance(node, ast.Attribute):
prefix = resolve_canonical_ast(node.value, imports)
return None if prefix is None else f"{prefix}.{node.attr}"
if isinstance(node, ast.Call):
return resolve_canonical_ast(node.func, imports)
return None


def resolve_canonical_cst(node: cst.CSTNode, imports: Mapping[str, str]) -> str | None:
if isinstance(node, cst.Name):
return imports.get(node.value, node.value)
if isinstance(node, cst.Attribute):
prefix = resolve_canonical_cst(node.value, imports)
return None if prefix is None else f"{prefix}.{node.attr.value}"
if isinstance(node, cst.Call):
return resolve_canonical_cst(node.func, imports)
return None
15 changes: 15 additions & 0 deletions flake8_async/visitors/flake8asyncvisitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from libcst.metadata import PositionProvider

from ..base import Error, Statement, strip_error_subidentifier
from ._canonical import resolve_canonical_ast, resolve_canonical_cst

if TYPE_CHECKING:
from collections.abc import Iterable, Mapping
Expand Down Expand Up @@ -53,6 +54,13 @@ def variables(self, value: dict[str, str]) -> None:
self.__state.variables.clear()
self.__state.variables.update(value)

@property
def imports(self) -> dict[str, str]:
return self.__state.imports

def canonical_name(self, node: ast.AST) -> str | None:
return resolve_canonical_ast(node, self.__state.imports)

def visit(self, node: ast.AST):
"""Visit a node."""
# construct visitor for this node type
Expand Down Expand Up @@ -170,6 +178,13 @@ def __init__(self, shared_state: SharedState):
self.options = self.__state.options
self.noqas = self.__state.noqas

@property
def imports(self) -> dict[str, str]:
return self.__state.imports

def canonical_name(self, node: cst.CSTNode) -> str | None:
return resolve_canonical_cst(node, self.__state.imports)

def get_state(self, *attrs: str, copy: bool = False) -> dict[str, Any]:
# require attrs, since we inherit a *ton* of stuff which we don't want to copy
assert attrs
Expand Down
125 changes: 99 additions & 26 deletions flake8_async/visitors/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@
utility_visitors,
utility_visitors_cst,
)
from ._canonical import resolve_canonical_ast, resolve_canonical_cst

if TYPE_CHECKING:
from collections.abc import Iterable, Iterator, Sequence
from collections.abc import Iterable, Iterator, Mapping, Sequence

from .flake8asyncvisitor import (
Flake8AsyncVisitor,
Expand Down Expand Up @@ -101,29 +102,44 @@ def has_decorator(node: ast.FunctionDef | ast.AsyncFunctionDef, *names: str):
# matches the fully qualified name against fnmatch pattern
# used to match decorators and methods to user-supplied patterns
# used in 910/911 and 200
def fnmatch_qualified_name(name_list: list[ast.expr], *patterns: str) -> str | None:
def fnmatch_qualified_name(
name_list: Iterable[ast.expr],
*patterns: str,
imports: Mapping[str, str] | None = None,
) -> str | None:
for name in name_list:
if isinstance(name, ast.Call):
name = name.func
qualified_name = ast.unparse(name)

candidates = {ast.unparse(name)}
if imports is not None and (canonical := resolve_canonical_ast(name, imports)):
candidates.add(canonical)
for pattern in patterns:
# strip leading "@"s for when we're working with decorators
if fnmatch(qualified_name, pattern.lstrip("@")):
stripped = pattern.lstrip("@")
if any(fnmatch(c, stripped) for c in candidates):
return pattern
return None


def fnmatch_qualified_name_cst(
name_list: Iterable[cst.Decorator | cst.Call | cst.Attribute | cst.Name],
*patterns: str,
imports: Mapping[str, str] | None = None,
) -> str | None:
for name in name_list:
qualified_name = get_full_name_for_node_or_raise(name)

candidates = {get_full_name_for_node_or_raise(name)}
if imports is not None:
inner: cst.CSTNode = name
if isinstance(inner, cst.Decorator):
inner = inner.decorator
if isinstance(inner, cst.Call):
inner = inner.func
if (canonical := resolve_canonical_cst(inner, imports)) is not None:
candidates.add(canonical)
for pattern in patterns:
# strip leading "@"s for when we're working with decorators
if fnmatch(qualified_name, pattern.lstrip("@")):
stripped = pattern.lstrip("@")
if any(fnmatch(c, stripped) for c in candidates):
return pattern
return None

Expand Down Expand Up @@ -240,7 +256,9 @@ def iter_guaranteed_once_cst(iterable: cst.BaseExpression) -> bool:


# used in 102, 103 and 104
def critical_except(node: ast.ExceptHandler) -> Statement | None:
def critical_except(
node: ast.ExceptHandler, imports: Mapping[str, str] | None = None
) -> Statement | None:
def has_exception(node: ast.expr) -> str | None:
name = ast.unparse(node)
if name in (
Expand All @@ -253,6 +271,27 @@ def has_exception(node: ast.expr) -> str | None:
"CancelledError",
):
return name
if imports is None:
return None
# Resolve via canonical qualname for aliased / `from`-imported forms.
# The non-call spellings (`except anyio.get_cancelled_exc_class:`, or a
# Call with arguments) are type-errors that critical_except intentionally
# ignores, so only zero-arg calls count for get_cancelled_exc_class.
if isinstance(node, ast.Call):
if node.args or node.keywords:
return None
canonical = resolve_canonical_ast(node.func, imports)
if canonical == "anyio.get_cancelled_exc_class":
return "anyio.get_cancelled_exc_class()"
return None
canonical = resolve_canonical_ast(node, imports)
if canonical == "trio.Cancelled":
return "trio.Cancelled"
if canonical in (
"asyncio.exceptions.CancelledError",
"asyncio.CancelledError",
):
return "asyncio.exceptions.CancelledError"
return None

name: str | None = None
Expand Down Expand Up @@ -302,36 +341,56 @@ def __str__(self) -> str:

# convenience function used in a lot of visitors
def get_matching_call(
node: ast.AST, *names: str, base: Iterable[str] = ("trio", "anyio")
node: ast.AST,
*names: str,
base: Iterable[str] = ("trio", "anyio"),
imports: Mapping[str, str] | None = None,
) -> MatchingCall[ast.Call] | None:
if isinstance(base, str):
base = (base,)
if not isinstance(node, ast.Call):
return None
if (
isinstance(node, ast.Call)
and isinstance(node.func, ast.Attribute)
isinstance(node.func, ast.Attribute)
and isinstance(node.func.value, ast.Name)
and node.func.value.id in base
and node.func.attr in names
):
return MatchingCall(node, node.func.attr, node.func.value.id)
if imports is not None:
canonical = resolve_canonical_ast(node.func, imports)
for b in base:
for n in names:
if canonical == f"{b}.{n}":
return MatchingCall(node, n, b)
return None


# ___ CST helpers ___
def get_matching_call_cst(
node: cst.CSTNode, *names: str, base: Iterable[str] = ("trio", "anyio")
node: cst.CSTNode,
*names: str,
base: Iterable[str] = ("trio", "anyio"),
imports: Mapping[str, str] | None = None,
) -> MatchingCall[cst.Call] | None:
if isinstance(base, str):
base = (base,)
if not isinstance(node, cst.Call):
return None
if (
isinstance(node, cst.Call)
and isinstance(node.func, cst.Attribute)
isinstance(node.func, cst.Attribute)
and node.func.attr.value in names
and isinstance(node.func.value, (cst.Name, cst.Attribute))
):
attr_base = identifier_to_string(node.func.value)
if attr_base is not None and attr_base in base:
return MatchingCall(node, node.func.attr.value, attr_base)
if imports is not None:
canonical = resolve_canonical_cst(node.func, imports)
for b in base:
for n in names:
if canonical == f"{b}.{n}":
return MatchingCall(node, n, b)
return None


Expand Down Expand Up @@ -377,12 +436,17 @@ def identifier_to_string(node: cst.CSTNode) -> str | None:


def with_has_call(
node: cst.With, *names: str, base: Iterable[str] | str = ("trio", "anyio")
node: cst.With,
*names: str,
base: Iterable[str] | str = ("trio", "anyio"),
imports: Mapping[str, str] | None = None,
) -> list[MatchingCall[cst.Call]]:
"""Check if a with statement has a matching call, returning a list with matches.

`names` specify the names of functions to match, `base` specifies the
library/module(s) the function must be in.
library/module(s) the function must be in. If `imports` is given, matches
are also made against the canonical qualname so aliased / `from`-imports
are detected.
The list elements in the return value are named tuples with the matched node,
base and function.

Expand All @@ -393,19 +457,15 @@ def with_has_call(
`foo.bar`, `foo.bee`, `a.b.c.bar`, and `a.b.c.bee`.

"""
if isinstance(base, str):
base = (base,)
base_tuple = (base,) if isinstance(base, str) else tuple(base)

# build matcher, using SaveMatchedNode to save the base and the function name.
matcher = m.Call(
func=m.Attribute(
value=m.SaveMatchedNode(
m.OneOf(*(build_cst_matcher(b) for b in base)), name="base"
),
attr=m.SaveMatchedNode(
oneof_names(*names),
name="function",
m.OneOf(*(build_cst_matcher(b) for b in base_tuple)), name="base"
),
attr=m.SaveMatchedNode(oneof_names(*names), name="function"),
)
)

Expand All @@ -422,10 +482,22 @@ def with_has_call(
node=item.item, base=base_string, name=res["function"].value
)
)
continue
if imports is None or not isinstance(item.item, cst.Call):
continue
canonical = resolve_canonical_cst(item.item.func, imports)
for b in base_tuple:
if canonical is not None and canonical.startswith(f"{b}."):
suffix = canonical[len(b) + 1 :]
if suffix in names:
res_list.append(MatchingCall(node=item.item, base=b, name=suffix))
break
return res_list


def calls_any_of(node: cst.With, *qualnames: str) -> bool:
def calls_any_of(
node: cst.With, *qualnames: str, imports: Mapping[str, str] | None = None
) -> bool:
"""Return True if `node` contains a withitem matching any of `qualnames`.

Each `qualname` is a dotted string like ``"trio.open_nursery"`` or
Expand All @@ -439,7 +511,8 @@ def calls_any_of(node: cst.With, *qualnames: str) -> bool:
assert name, f"{qn!r} is not a dotted qualname"
by_base[base].append(name)
return any(
with_has_call(node, *names, base=base) for base, names in by_base.items()
with_has_call(node, *names, base=base, imports=imports)
for base, names in by_base.items()
)


Expand Down
2 changes: 1 addition & 1 deletion flake8_async/visitors/visitor101.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def visit_With(self, node: cst.With):
self._yield_is_error = (
not self._safe_decorator
and not self._yield_is_error
and calls_any_of(node, *_CANCEL_SCOPE_CMS)
and calls_any_of(node, *_CANCEL_SCOPE_CMS, imports=self.imports)
)

def leave_With(
Expand Down
Loading
Loading