From e90a03e0caa09e0d36b88808947595d7effcc1ae Mon Sep 17 00:00:00 2001 From: Daniel Miller Date: Tue, 23 Jun 2026 13:07:04 -0400 Subject: [PATCH] fix(tracing): fail open temporal span activities --- src/agentex/lib/adk/_modules/tracing.py | 70 +++++++-- .../lib/sdk/state_machine/state_machine.py | 3 +- tests/lib/adk/test_tracing_module.py | 145 +++++++++++++++++- tests/lib/test_state_machine.py | 68 ++++++++ 4 files changed, 268 insertions(+), 18 deletions(-) create mode 100644 tests/lib/test_state_machine.py diff --git a/src/agentex/lib/adk/_modules/tracing.py b/src/agentex/lib/adk/_modules/tracing.py index 8694c2078..94bf741e4 100644 --- a/src/agentex/lib/adk/_modules/tracing.py +++ b/src/agentex/lib/adk/_modules/tracing.py @@ -6,7 +6,9 @@ from datetime import timedelta from typing import Any +from temporalio import workflow from temporalio.common import RetryPolicy +from temporalio.exceptions import ActivityError, TimeoutError as TemporalTimeoutError, is_cancelled_exception from agentex import AsyncAgentex # noqa: F401 from agentex.lib.adk.utils._modules.client import create_async_agentex_client @@ -26,6 +28,18 @@ logger = make_logger(__name__) DEFAULT_RETRY_POLICY = RetryPolicy(maximum_attempts=1) +TEMPORAL_SPAN_ACTIVITY_DROPPED_METRIC = "agentex.tracing.temporal_span_activity.dropped" + + +def _record_temporal_span_activity_dropped(event_type: str) -> None: + try: + workflow.metric_meter().create_counter( + TEMPORAL_SPAN_ACTIVITY_DROPPED_METRIC, + description="Temporal tracing span activities dropped after fail-open", + unit="1", + ).add(1, {"event_type": event_type}) + except Exception: + pass class TracingModule: @@ -180,14 +194,26 @@ async def start_span( task_id=task_id, ) if in_temporal_workflow(): - return await ActivityHelpers.execute_activity( - activity_name=TracingActivityName.START_SPAN, - request=params, - response_type=Span, - start_to_close_timeout=start_to_close_timeout, - retry_policy=retry_policy, - heartbeat_timeout=heartbeat_timeout, - ) + try: + return await ActivityHelpers.execute_activity( + activity_name=TracingActivityName.START_SPAN, + request=params, + response_type=Span, + start_to_close_timeout=start_to_close_timeout, + retry_policy=retry_policy, + heartbeat_timeout=heartbeat_timeout, + ) + except (ActivityError, TemporalTimeoutError) as err: + if is_cancelled_exception(err): + raise + workflow.logger.warning( + "Failed to start tracing span %r for trace_id=%r; continuing without tracing", + name, + trace_id, + exc_info=True, + ) + _record_temporal_span_activity_dropped("start") + return None else: return await self._tracing_service.start_span( trace_id=trace_id, @@ -224,14 +250,26 @@ async def end_span( span=span, ) if in_temporal_workflow(): - return await ActivityHelpers.execute_activity( - activity_name=TracingActivityName.END_SPAN, - request=params, - response_type=Span, - start_to_close_timeout=start_to_close_timeout, - retry_policy=retry_policy, - heartbeat_timeout=heartbeat_timeout, - ) + try: + return await ActivityHelpers.execute_activity( + activity_name=TracingActivityName.END_SPAN, + request=params, + response_type=Span, + start_to_close_timeout=start_to_close_timeout, + retry_policy=retry_policy, + heartbeat_timeout=heartbeat_timeout, + ) + except (ActivityError, TemporalTimeoutError) as err: + if is_cancelled_exception(err): + raise + workflow.logger.warning( + "Failed to end tracing span %r for trace_id=%r; continuing without closing trace", + span.id, + trace_id, + exc_info=True, + ) + _record_temporal_span_activity_dropped("end") + return span else: return await self._tracing_service.end_span( trace_id=trace_id, diff --git a/src/agentex/lib/sdk/state_machine/state_machine.py b/src/agentex/lib/sdk/state_machine/state_machine.py index f1e5c4239..5679a6bd8 100644 --- a/src/agentex/lib/sdk/state_machine/state_machine.py +++ b/src/agentex/lib/sdk/state_machine/state_machine.py @@ -113,6 +113,7 @@ async def reset_to_initial_state(self): """ Reset the state machine to its initial state. """ + span = None if self._trace_transitions: if self._task_id is None: raise ValueError( @@ -126,7 +127,7 @@ async def reset_to_initial_state(self): await self.transition(self._initial_state) - if self._trace_transitions: + if self._trace_transitions and span is not None: span.output = {"output_state": self._initial_state} # type: ignore[assignment,union-attr] await adk.tracing.end_span(trace_id=self._task_id, span=span) diff --git a/tests/lib/adk/test_tracing_module.py b/tests/lib/adk/test_tracing_module.py index 52d5d3f82..58d5d4a85 100644 --- a/tests/lib/adk/test_tracing_module.py +++ b/tests/lib/adk/test_tracing_module.py @@ -1,7 +1,10 @@ from __future__ import annotations from datetime import datetime, timezone -from unittest.mock import AsyncMock, patch +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from temporalio.exceptions import ActivityError import agentex.lib.adk._modules.tracing as _tracing_mod from agentex.types.span import Span @@ -26,6 +29,24 @@ def _make_module() -> tuple[AsyncMock, TracingModule]: return mock_service, module +def _make_activity_error() -> ActivityError: + return ActivityError( + "activity timed out", + scheduled_event_id=1, + started_event_id=2, + identity="worker-1", + activity_type="start-span", + activity_id="activity-1", + retry_state=None, + ) + + +def _make_metric_meter() -> MagicMock: + mock_meter = MagicMock() + mock_meter.create_counter.return_value = MagicMock() + return mock_meter + + class TestStartSpan: async def test_start_span_with_task_id(self): mock_service, module = _make_module() @@ -87,6 +108,128 @@ async def test_end_span_preserves_task_id(self): mock_service.end_span.assert_called_once_with(trace_id="trace-123", span=span) +class TestTracingModuleTemporalPath: + async def test_start_span_in_workflow_returns_none_when_activity_fails(self): + mock_service, module = _make_module() + mock_meter = _make_metric_meter() + + with patch.object(_tracing_mod, "in_temporal_workflow", return_value=True), \ + patch.object(_tracing_mod, "ActivityHelpers") as mock_helpers, \ + patch.object(_tracing_mod.workflow, "logger") as mock_logger, \ + patch.object(_tracing_mod.workflow, "metric_meter", return_value=mock_meter): + mock_helpers.execute_activity = AsyncMock(side_effect=_make_activity_error()) + result = await module.start_span(trace_id="trace-123", name="test-span") + + assert result is None + mock_logger.warning.assert_called_once() + mock_meter.create_counter.assert_called_once_with( + _tracing_mod.TEMPORAL_SPAN_ACTIVITY_DROPPED_METRIC, + description="Temporal tracing span activities dropped after fail-open", + unit="1", + ) + mock_meter.create_counter.return_value.add.assert_called_once_with( + 1, {"event_type": "start"} + ) + mock_helpers.execute_activity.assert_called_once() + mock_service.start_span.assert_not_called() + + async def test_end_span_in_workflow_returns_span_when_activity_fails(self): + mock_service, module = _make_module() + span = _make_span() + mock_meter = _make_metric_meter() + + with patch.object(_tracing_mod, "in_temporal_workflow", return_value=True), \ + patch.object(_tracing_mod, "ActivityHelpers") as mock_helpers, \ + patch.object(_tracing_mod.workflow, "logger") as mock_logger, \ + patch.object(_tracing_mod.workflow, "metric_meter", return_value=mock_meter): + mock_helpers.execute_activity = AsyncMock(side_effect=_make_activity_error()) + result = await module.end_span(trace_id="trace-123", span=span) + + assert result == span + mock_logger.warning.assert_called_once() + mock_meter.create_counter.assert_called_once_with( + _tracing_mod.TEMPORAL_SPAN_ACTIVITY_DROPPED_METRIC, + description="Temporal tracing span activities dropped after fail-open", + unit="1", + ) + mock_meter.create_counter.return_value.add.assert_called_once_with( + 1, {"event_type": "end"} + ) + mock_helpers.execute_activity.assert_called_once() + mock_service.end_span.assert_not_called() + + async def test_context_manager_skips_end_when_temporal_start_fails(self): + mock_service, module = _make_module() + + with patch.object(_tracing_mod, "in_temporal_workflow", return_value=True), \ + patch.object(_tracing_mod, "ActivityHelpers") as mock_helpers, \ + patch.object(_tracing_mod.workflow, "logger"): + mock_helpers.execute_activity = AsyncMock(side_effect=_make_activity_error()) + async with module.span(trace_id="trace-123", name="test-span") as span: + assert span is None + + mock_helpers.execute_activity.assert_called_once() + mock_service.start_span.assert_not_called() + mock_service.end_span.assert_not_called() + + async def test_start_span_in_workflow_propagates_unexpected_errors(self): + mock_service, module = _make_module() + + with patch.object(_tracing_mod, "in_temporal_workflow", return_value=True), \ + patch.object(_tracing_mod, "ActivityHelpers") as mock_helpers: + mock_helpers.execute_activity = AsyncMock(side_effect=RuntimeError("bad response shape")) + try: + await module.start_span(trace_id="trace-123", name="test-span") + except RuntimeError as exc: + assert str(exc) == "bad response shape" + else: + raise AssertionError("Expected unexpected errors to propagate") + + mock_helpers.execute_activity.assert_called_once() + mock_service.start_span.assert_not_called() + + async def test_start_span_in_workflow_propagates_cancellation(self): + mock_service, module = _make_module() + activity_error = _make_activity_error() + mock_meter = _make_metric_meter() + + with patch.object(_tracing_mod, "in_temporal_workflow", return_value=True), \ + patch.object(_tracing_mod, "ActivityHelpers") as mock_helpers, \ + patch.object(_tracing_mod, "is_cancelled_exception", return_value=True), \ + patch.object(_tracing_mod.workflow, "logger") as mock_logger, \ + patch.object(_tracing_mod.workflow, "metric_meter", return_value=mock_meter): + mock_helpers.execute_activity = AsyncMock(side_effect=activity_error) + + with pytest.raises(ActivityError): + await module.start_span(trace_id="trace-123", name="test-span") + + mock_logger.warning.assert_not_called() + mock_meter.create_counter.assert_not_called() + mock_helpers.execute_activity.assert_called_once() + mock_service.start_span.assert_not_called() + + async def test_end_span_in_workflow_propagates_cancellation(self): + mock_service, module = _make_module() + span = _make_span() + activity_error = _make_activity_error() + mock_meter = _make_metric_meter() + + with patch.object(_tracing_mod, "in_temporal_workflow", return_value=True), \ + patch.object(_tracing_mod, "ActivityHelpers") as mock_helpers, \ + patch.object(_tracing_mod, "is_cancelled_exception", return_value=True), \ + patch.object(_tracing_mod.workflow, "logger") as mock_logger, \ + patch.object(_tracing_mod.workflow, "metric_meter", return_value=mock_meter): + mock_helpers.execute_activity = AsyncMock(side_effect=activity_error) + + with pytest.raises(ActivityError): + await module.end_span(trace_id="trace-123", span=span) + + mock_logger.warning.assert_not_called() + mock_meter.create_counter.assert_not_called() + mock_helpers.execute_activity.assert_called_once() + mock_service.end_span.assert_not_called() + + class TestSpanContextManager: async def test_span_context_manager_forwards_task_id(self): mock_service, module = _make_module() diff --git a/tests/lib/test_state_machine.py b/tests/lib/test_state_machine.py new file mode 100644 index 000000000..ce32ba9f0 --- /dev/null +++ b/tests/lib/test_state_machine.py @@ -0,0 +1,68 @@ +from __future__ import annotations + +from typing import override +from unittest.mock import AsyncMock, patch + +from agentex.lib.sdk.state_machine import State, StateMachine, StateWorkflow +from agentex.lib.utils.model_utils import BaseModel + + +class ExampleData(BaseModel): + value: int = 0 + + +class InitialWorkflow(StateWorkflow): + transitions = ["next"] + + @override + async def execute(self, state_machine, state_machine_data=None): + return "next" + + +class NextWorkflow(StateWorkflow): + transitions = ["initial"] + + @override + async def execute(self, state_machine, state_machine_data=None): + return "initial" + + +class ExampleStateMachine(StateMachine[ExampleData]): + @override + async def terminal_condition(self): + return False + + +def _make_state_machine() -> ExampleStateMachine: + return ExampleStateMachine( + initial_state="initial", + states=[ + State(name="initial", workflow=InitialWorkflow()), + State(name="next", workflow=NextWorkflow()), + ], + task_id="task-123", + state_machine_data=ExampleData(value=1), + trace_transitions=True, + ) + + +async def test_reset_to_initial_state_skips_end_span_when_start_span_fails_open(): + state_machine = _make_state_machine() + await state_machine.transition("next") + + with patch( + "agentex.lib.sdk.state_machine.state_machine.adk.tracing.start_span", + new=AsyncMock(return_value=None), + ) as start_span, patch( + "agentex.lib.sdk.state_machine.state_machine.adk.tracing.end_span", + new=AsyncMock(), + ) as end_span: + await state_machine.reset_to_initial_state() + + assert state_machine.get_current_state() == "initial" + start_span.assert_awaited_once_with( + trace_id="task-123", + name="state_transition_reset", + input={"input_state": "next"}, + ) + end_span.assert_not_awaited()