diff --git a/docs/changelog.rst b/docs/changelog.rst index 7df8b2e..4df8d04 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -11,6 +11,8 @@ Unreleased - :ref:`ASYNC300 ` no longer triggers when the result of ``asyncio.create_task()`` is returned from a function. `(issue #398) `_ - Add :ref:`ASYNC126 ` exceptiongroup-subclass-missing-derive. `(issue #334) `_ - :ref:`ASYNC102 ` no longer warns on ``await trio.aclose_forcefully(...)`` / ``await anyio.aclose_forcefully(...)``, which are designed for cleanup and cancel immediately by design. `(issue #446) `_ +- :ref:`ASYNC101 ` now also triggers for common third-party context managers that open internal cancel scopes, nurseries, or task groups: ``trio_websocket.{open_websocket, open_websocket_url, serve_websocket}``, ``trio_asyncio.open_loop``, ``trio_parallel.open_worker_context``, ``trio_util.{move_on_when, run_and_cancelling}``, ``qtrio.{open_emissions_nursery, enter_emissions_channel}``, ``anyio.from_thread.{BlockingPortal, start_blocking_portal}``, ``asgi_lifespan.LifespanManager``, ``apscheduler.AsyncScheduler``, ``mcp.client.streamable_http.streamablehttp_client``, and ``mcp.client.sse.sse_client``. `(issue #350) `_ +- :ref:`ASYNC102 ` and :ref:`ASYNC120 ` no longer trigger on ``await trio.lowlevel.cancel_shielded_checkpoint()`` (or the ``anyio.lowlevel`` equivalent), which is explicitly a schedule-but-not-cancel point and therefore safe inside ``finally`` / cancelled ``except`` / ``__aexit__``. 25.7.1 ====== diff --git a/docs/rules.rst b/docs/rules.rst index 81685c9..ba25950 100644 --- a/docs/rules.rst +++ b/docs/rules.rst @@ -20,6 +20,7 @@ _`ASYNC101` : yield-in-cancel-scope ``yield`` inside a :ref:`taskgroup_nursery` or :ref:`timeout_context` is only safe when implementing a context manager - otherwise, it breaks exception handling. See `this thread `_ for discussion of a future PEP. This has substantial overlap with :ref:`ASYNC119 `, which will warn on almost all instances of ASYNC101, but ASYNC101 is about a conceptually different problem that will not get resolved by :pep:`533`. + Also triggered on common third-party context managers that open internal cancel scopes, nurseries, or task groups: ``trio_websocket.{open_websocket, open_websocket_url, serve_websocket}``, ``trio_asyncio.open_loop``, ``trio_parallel.open_worker_context``, ``trio_util.{move_on_when, run_and_cancelling}``, ``qtrio.{open_emissions_nursery, enter_emissions_channel}``, ``anyio.from_thread.{BlockingPortal, start_blocking_portal}``, ``asgi_lifespan.LifespanManager``, ``apscheduler.AsyncScheduler``, ``mcp.client.streamable_http.streamablehttp_client``, and ``mcp.client.sse.sse_client``. _`ASYNC102` : await-in-finally-or-cancelled ``await`` inside ``finally``, :ref:`cancelled-catching ` ``except:``, or ``__aexit__`` must have shielded :ref:`cancel scope ` with timeout. diff --git a/flake8_async/visitors/helpers.py b/flake8_async/visitors/helpers.py index 5764e16..c9d4f54 100644 --- a/flake8_async/visitors/helpers.py +++ b/flake8_async/visitors/helpers.py @@ -6,6 +6,7 @@ from __future__ import annotations import ast +from collections import defaultdict from collections.abc import Sized from dataclasses import dataclass from fnmatch import fnmatch @@ -355,7 +356,7 @@ def build_cst_matcher(attr: str) -> m.BaseExpression: """Build a cst matcher structure with attributes&names matching a string `a.b.c`.""" if "." not in attr: return m.Name(value=attr) - body, tail = attr.rsplit(".") + body, tail = attr.rsplit(".", 1) return m.Attribute(value=build_cst_matcher(body), attr=m.Name(value=tail)) @@ -424,6 +425,24 @@ def with_has_call( return res_list +def calls_any_of(node: cst.With, *qualnames: str) -> bool: + """Return True if `node` contains a withitem matching any of `qualnames`. + + Each `qualname` is a dotted string like ``"trio.open_nursery"`` or + ``"mcp.client.sse.sse_client"``: everything before the final dot is the + base, the final component is the function/class name. + """ + by_base: dict[str, list[str]] = defaultdict(list) + for qn in qualnames: + base, _, name = qn.rpartition(".") + assert base, f"{qn!r} is not a dotted qualname" + assert name, f"{qn!r} is not a dotted qualname" + by_base[base].append(name) + return any( + with_has_call(node, *names, base=base) for base, names in by_base.items() + ) + + def func_has_decorator(func: cst.FunctionDef, *names: str) -> bool: return any( list_contains( diff --git a/flake8_async/visitors/visitor101.py b/flake8_async/visitors/visitor101.py index 59d0463..fe03f94 100644 --- a/flake8_async/visitors/visitor101.py +++ b/flake8_async/visitors/visitor101.py @@ -10,10 +10,41 @@ from .flake8asyncvisitor import Flake8AsyncVisitor_cst from .helpers import ( + calls_any_of, cancel_scope_names, error_class_cst, func_has_decorator, - with_has_call, +) + +# Qualified names of context managers that open a nursery / task group / cancel +# scope. `yield`ing inside any of these breaks exception handling unless the +# enclosing function is itself a context manager (see ASYNC101 docs). +_CANCEL_SCOPE_CMS: tuple[str, ...] = ( + # nursery/taskgroup + "trio.open_nursery", + "anyio.create_task_group", + "asyncio.TaskGroup", + # stdlib cancel scopes + "asyncio.timeout", + "asyncio.timeout_at", + # trio/anyio share the same cancel-scope spelling + *(f"{lib}.{name}" for lib in ("trio", "anyio") for name in cancel_scope_names), + # 3rd-party CMs with internal cancel scopes / nurseries. See issue #350. + "trio_websocket.open_websocket", + "trio_websocket.open_websocket_url", + "trio_websocket.serve_websocket", + "trio_asyncio.open_loop", + "trio_parallel.open_worker_context", + "trio_util.move_on_when", + "trio_util.run_and_cancelling", + "qtrio.open_emissions_nursery", + "qtrio.enter_emissions_channel", + "anyio.from_thread.BlockingPortal", + "anyio.from_thread.start_blocking_portal", + "asgi_lifespan.LifespanManager", + "apscheduler.AsyncScheduler", + "mcp.client.streamable_http.streamablehttp_client", + "mcp.client.sse.sse_client", ) if TYPE_CHECKING: @@ -45,17 +76,7 @@ def visit_With(self, node: cst.With): self._yield_is_error = ( not self._safe_decorator and not self._yield_is_error - # It's not strictly necessary to specify the base, as raising errors on - # e.g. anyio.open_nursery isn't much of a problem. - and bool( - # nursery/taskgroup - with_has_call(node, "open_nursery", base="trio") - or with_has_call(node, "create_task_group", base="anyio") - or with_has_call(node, "TaskGroup", base="asyncio") - # cancel scopes - or with_has_call(node, "timeout", "timeout_at", base="asyncio") - or with_has_call(node, *cancel_scope_names, base=("trio", "anyio")) - ) + and calls_any_of(node, *_CANCEL_SCOPE_CMS) ) def leave_With( diff --git a/flake8_async/visitors/visitor102_120.py b/flake8_async/visitors/visitor102_120.py index f191403..83055e1 100644 --- a/flake8_async/visitors/visitor102_120.py +++ b/flake8_async/visitors/visitor102_120.py @@ -98,9 +98,27 @@ def is_safe_aclose_call(self, node: ast.Await) -> bool: # which are specifically designed for cleanup and cancel immediately by design return get_matching_call(node.value, "aclose_forcefully") is not None + # trio.lowlevel.cancel_shielded_checkpoint (and the anyio equivalent) are + # explicitly a schedule-but-not-cancel point, so they're safe to await + # inside a finally / cancelled except / __aexit__. + def is_safe_shielded_checkpoint(self, node: ast.Await) -> bool: + return ( + isinstance(node.value, ast.Call) + and not node.value.args + and not node.value.keywords + and ast.unparse(node.value.func) + in ( + "trio.lowlevel.cancel_shielded_checkpoint", + "anyio.lowlevel.cancel_shielded_checkpoint", + ) + ) + def visit_Await(self, node: ast.Await): - # allow calls to `.aclose()` and `[trio/anyio].aclose_forcefully(...)` - if not (self.is_safe_aclose_call(node)): + # allow calls to `.aclose()`, `[trio/anyio].aclose_forcefully(...)`, and + # `[trio/anyio].lowlevel.cancel_shielded_checkpoint()` + if not ( + self.is_safe_aclose_call(node) or self.is_safe_shielded_checkpoint(node) + ): self.async_call_checker(node) visit_AsyncFor = async_call_checker diff --git a/tests/eval_files/async101_third_party.py b/tests/eval_files/async101_third_party.py new file mode 100644 index 0000000..26c2638 --- /dev/null +++ b/tests/eval_files/async101_third_party.py @@ -0,0 +1,116 @@ +# 3rd-party context managers with internal cancel scopes / nurseries / +# task groups. `yield`ing inside any of them breaks exception handling. +# +# Package names like `trio_websocket` / `qtrio` survive the test framework's +# `trio` -> `anyio` / `asyncio` library substitution because that substitution +# only matches at word boundaries. +from contextlib import asynccontextmanager + +import anyio.from_thread +import apscheduler +import asgi_lifespan +import mcp.client.sse +import mcp.client.streamable_http +import qtrio +import trio_asyncio +import trio_parallel +import trio_util +import trio_websocket + + +# trio_websocket +async def foo_open_websocket(): + async with trio_websocket.open_websocket("h", 80, "/", use_ssl=False) as _: + yield 1 # error: 8 + + +async def foo_open_websocket_url(): + async with trio_websocket.open_websocket_url("ws://x") as _: + yield 1 # error: 8 + + +async def foo_serve_websocket(): + async with trio_websocket.serve_websocket( + lambda *_: None, "h", 80, ssl_context=None + ) as _: + yield 1 # error: 8 + + +@asynccontextmanager +async def foo_trio_websocket_safe(): + async with trio_websocket.open_websocket_url("ws://x") as _: + yield 1 # safe + + +# trio_asyncio +async def foo_open_loop(): + async with trio_asyncio.open_loop() as _: + yield 1 # error: 8 + + +# trio_parallel +async def foo_open_worker_context(): + async with trio_parallel.open_worker_context() as _: + yield 1 # error: 8 + + +# trio_util +async def foo_move_on_when(): + async with trio_util.move_on_when(lambda: None) as _: + yield 1 # error: 8 + + +async def foo_run_and_cancelling(): + async with trio_util.run_and_cancelling(lambda: None) as _: + yield 1 # error: 8 + + +# qtrio +async def foo_open_emissions_nursery(): + async with qtrio.open_emissions_nursery() as _: + yield 1 # error: 8 + + +async def foo_enter_emissions_channel(): + async with qtrio.enter_emissions_channel(signals=()) as _: + yield 1 # error: 8 + + +# asgi_lifespan +async def foo_lifespan_manager(): + async with asgi_lifespan.LifespanManager(None) as _: + yield 1 # error: 8 + + +@asynccontextmanager +async def foo_lifespan_manager_safe(): + async with asgi_lifespan.LifespanManager(None) as _: + yield 1 # safe + + +# apscheduler v4 +async def foo_async_scheduler(): + async with apscheduler.AsyncScheduler() as _: + yield 1 # error: 8 + + +# anyio.from_thread +async def foo_blocking_portal(): + with anyio.from_thread.BlockingPortal() as _: + yield 1 # error: 8 + + +async def foo_start_blocking_portal(): + with anyio.from_thread.start_blocking_portal() as _: + yield 1 # error: 8 + + +# MCP SDK +async def foo_streamablehttp_client(): + async with mcp.client.streamable_http.streamablehttp_client("http://x") as _: + yield 1 # error: 8 + + +async def foo_sse_client(): + async with mcp.client.sse.sse_client("http://x") as _: + yield 1 # error: 8 diff --git a/tests/eval_files/async102.py b/tests/eval_files/async102.py index 3d52a28..bac389b 100644 --- a/tests/eval_files/async102.py +++ b/tests/eval_files/async102.py @@ -364,3 +364,25 @@ async def foo_aclose_forcefully(): ... finally: await aclose_forcefully(x) # ASYNC102: 8, Statement("try/finally", lineno-3) + + +# exclude `await *.lowlevel.cancel_shielded_checkpoint()`, which is +# explicitly a schedule-but-not-cancel point. +async def foo_cancel_shielded_checkpoint(): + try: + ... + except BaseException: + await trio.lowlevel.cancel_shielded_checkpoint() + finally: + await trio.lowlevel.cancel_shielded_checkpoint() + + +# still raise errors if there are args, or a different name +# fmt: off +async def foo_cancel_shielded_checkpoint_bad(): + try: + ... + finally: + await trio.lowlevel.cancel_shielded_checkpoint(foo) # ASYNC102: 8, Statement("try/finally", lineno-3) + await trio.lowlevel.checkpoint() # ASYNC102: 8, Statement("try/finally", lineno-4) +# fmt: on diff --git a/tests/test_flake8_async.py b/tests/test_flake8_async.py index aa00f49..c85e564 100644 --- a/tests/test_flake8_async.py +++ b/tests/test_flake8_async.py @@ -104,11 +104,25 @@ def diff_strings(first: str, second: str, /) -> str: # make sure only single newline at end of file -# replaces all instances of `original` with `new` in string -# unless it's preceded by a `-`, which indicates it's part of a command-line flag +# replaces all instances of `original` with `new` in string, matching at word +# boundaries so e.g. "trio" doesn't rewrite "trio_websocket" or "qtrio", and +# skipping occurrences preceded by a `-` (which would be part of a CLI flag). def replace_library(string: str, original: str = "trio", new: str = "anyio") -> str: def replace_str(string: str, original: str, new: str) -> str: - return re.sub(rf"(?