Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 54 additions & 16 deletions src/agentex/lib/adk/_modules/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
from datetime import timedelta
from typing import Any

from temporalio import workflow
from temporalio.common import RetryPolicy
from temporalio.exceptions import ActivityError, TimeoutError as TemporalTimeoutError, is_cancelled_exception

from agentex import AsyncAgentex # noqa: F401
from agentex.lib.adk.utils._modules.client import create_async_agentex_client
Expand All @@ -26,6 +28,18 @@
# Module-level logger named after this module.
logger = make_logger(__name__)

# Retry policy capped at a single attempt (maximum_attempts=1).
DEFAULT_RETRY_POLICY = RetryPolicy(maximum_attempts=1)
# Name of the counter incremented by _record_temporal_span_activity_dropped.
TEMPORAL_SPAN_ACTIVITY_DROPPED_METRIC = "agentex.tracing.temporal_span_activity.dropped"


def _record_temporal_span_activity_dropped(event_type: str) -> None:
    """Increment the dropped-span counter by one, ignoring any failure.

    Args:
        event_type: Value recorded under the ``"event_type"`` attribute of
            the increment; callers in this module pass ``"start"`` or
            ``"end"``.
    """
    try:
        dropped_counter = workflow.metric_meter().create_counter(
            TEMPORAL_SPAN_ACTIVITY_DROPPED_METRIC,
            description="Temporal tracing span activities dropped after fail-open",
            unit="1",
        )
        dropped_counter.add(1, {"event_type": event_type})
    except Exception:
        # Best-effort only: a problem while emitting the metric must never
        # surface to the caller, so every Exception is deliberately dropped.
        return


class TracingModule:
Expand Down Expand Up @@ -180,14 +194,26 @@ async def start_span(
task_id=task_id,
)
if in_temporal_workflow():
return await ActivityHelpers.execute_activity(
activity_name=TracingActivityName.START_SPAN,
request=params,
response_type=Span,
start_to_close_timeout=start_to_close_timeout,
retry_policy=retry_policy,
heartbeat_timeout=heartbeat_timeout,
)
try:
return await ActivityHelpers.execute_activity(
activity_name=TracingActivityName.START_SPAN,
request=params,
response_type=Span,
start_to_close_timeout=start_to_close_timeout,
retry_policy=retry_policy,
heartbeat_timeout=heartbeat_timeout,
)
except (ActivityError, TemporalTimeoutError) as err:
if is_cancelled_exception(err):
raise
workflow.logger.warning(
"Failed to start tracing span %r for trace_id=%r; continuing without tracing",
name,
trace_id,
exc_info=True,
)
_record_temporal_span_activity_dropped("start")
return None
Comment thread
danielmillerp marked this conversation as resolved.
else:
return await self._tracing_service.start_span(
trace_id=trace_id,
Expand Down Expand Up @@ -224,14 +250,26 @@ async def end_span(
span=span,
)
if in_temporal_workflow():
return await ActivityHelpers.execute_activity(
activity_name=TracingActivityName.END_SPAN,
request=params,
response_type=Span,
start_to_close_timeout=start_to_close_timeout,
retry_policy=retry_policy,
heartbeat_timeout=heartbeat_timeout,
)
try:
return await ActivityHelpers.execute_activity(
activity_name=TracingActivityName.END_SPAN,
request=params,
response_type=Span,
start_to_close_timeout=start_to_close_timeout,
retry_policy=retry_policy,
heartbeat_timeout=heartbeat_timeout,
)
except (ActivityError, TemporalTimeoutError) as err:
if is_cancelled_exception(err):
raise
workflow.logger.warning(
"Failed to end tracing span %r for trace_id=%r; continuing without closing trace",
span.id,
trace_id,
exc_info=True,
)
_record_temporal_span_activity_dropped("end")
return span
else:
return await self._tracing_service.end_span(
trace_id=trace_id,
Expand Down
3 changes: 2 additions & 1 deletion src/agentex/lib/sdk/state_machine/state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ async def reset_to_initial_state(self):
"""
Reset the state machine to its initial state.
"""
span = None
if self._trace_transitions:
if self._task_id is None:
raise ValueError(
Expand All @@ -126,7 +127,7 @@ async def reset_to_initial_state(self):

await self.transition(self._initial_state)

if self._trace_transitions:
if self._trace_transitions and span is not None:
span.output = {"output_state": self._initial_state} # type: ignore[assignment,union-attr]
await adk.tracing.end_span(trace_id=self._task_id, span=span)

Expand Down
145 changes: 144 additions & 1 deletion tests/lib/adk/test_tracing_module.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from __future__ import annotations

from datetime import datetime, timezone
from unittest.mock import AsyncMock, patch
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from temporalio.exceptions import ActivityError

import agentex.lib.adk._modules.tracing as _tracing_mod
from agentex.types.span import Span
Expand All @@ -26,6 +29,24 @@ def _make_module() -> tuple[AsyncMock, TracingModule]:
return mock_service, module


def _make_activity_error() -> ActivityError:
    """Build the ``ActivityError`` that tests use as a mocked activity failure."""
    failure_details = {
        "scheduled_event_id": 1,
        "started_event_id": 2,
        "identity": "worker-1",
        "activity_type": "start-span",
        "activity_id": "activity-1",
        "retry_state": None,
    }
    return ActivityError("activity timed out", **failure_details)


def _make_metric_meter() -> MagicMock:
mock_meter = MagicMock()
mock_meter.create_counter.return_value = MagicMock()
return mock_meter


class TestStartSpan:
async def test_start_span_with_task_id(self):
mock_service, module = _make_module()
Expand Down Expand Up @@ -87,6 +108,128 @@ async def test_end_span_preserves_task_id(self):
mock_service.end_span.assert_called_once_with(trace_id="trace-123", span=span)


class TestTracingModuleTemporalPath:
    """Exercise ``TracingModule`` span calls on the Temporal-workflow code path.

    Every test patches ``in_temporal_workflow`` to return ``True`` and replaces
    ``ActivityHelpers`` with a mock, so span calls go through
    ``ActivityHelpers.execute_activity`` instead of the tracing service.
    """

    async def test_start_span_in_workflow_returns_none_when_activity_fails(self):
        """A failing start-span activity yields ``None``, one warning and one metric increment."""
        mock_service, module = _make_module()
        mock_meter = _make_metric_meter()

        with patch.object(_tracing_mod, "in_temporal_workflow", return_value=True), \
                patch.object(_tracing_mod, "ActivityHelpers") as mock_helpers, \
                patch.object(_tracing_mod.workflow, "logger") as mock_logger, \
                patch.object(_tracing_mod.workflow, "metric_meter", return_value=mock_meter):
            mock_helpers.execute_activity = AsyncMock(side_effect=_make_activity_error())
            result = await module.start_span(trace_id="trace-123", name="test-span")

        assert result is None
        mock_logger.warning.assert_called_once()
        mock_meter.create_counter.assert_called_once_with(
            _tracing_mod.TEMPORAL_SPAN_ACTIVITY_DROPPED_METRIC,
            description="Temporal tracing span activities dropped after fail-open",
            unit="1",
        )
        mock_meter.create_counter.return_value.add.assert_called_once_with(
            1, {"event_type": "start"}
        )
        mock_helpers.execute_activity.assert_called_once()
        mock_service.start_span.assert_not_called()

    async def test_end_span_in_workflow_returns_span_when_activity_fails(self):
        """A failing end-span activity hands back the input span, warns once and counts one drop."""
        mock_service, module = _make_module()
        span = _make_span()
        mock_meter = _make_metric_meter()

        with patch.object(_tracing_mod, "in_temporal_workflow", return_value=True), \
                patch.object(_tracing_mod, "ActivityHelpers") as mock_helpers, \
                patch.object(_tracing_mod.workflow, "logger") as mock_logger, \
                patch.object(_tracing_mod.workflow, "metric_meter", return_value=mock_meter):
            mock_helpers.execute_activity = AsyncMock(side_effect=_make_activity_error())
            result = await module.end_span(trace_id="trace-123", span=span)

        assert result == span
        mock_logger.warning.assert_called_once()
        mock_meter.create_counter.assert_called_once_with(
            _tracing_mod.TEMPORAL_SPAN_ACTIVITY_DROPPED_METRIC,
            description="Temporal tracing span activities dropped after fail-open",
            unit="1",
        )
        mock_meter.create_counter.return_value.add.assert_called_once_with(
            1, {"event_type": "end"}
        )
        mock_helpers.execute_activity.assert_called_once()
        mock_service.end_span.assert_not_called()

    async def test_context_manager_skips_end_when_temporal_start_fails(self):
        """``module.span`` yields ``None`` and runs only one activity when the start fails."""
        mock_service, module = _make_module()

        # NOTE(review): ``workflow.metric_meter`` is not patched here, so this
        # test presumably relies on the metric helper swallowing its own
        # errors outside a real workflow -- confirm if that helper changes.
        with patch.object(_tracing_mod, "in_temporal_workflow", return_value=True), \
                patch.object(_tracing_mod, "ActivityHelpers") as mock_helpers, \
                patch.object(_tracing_mod.workflow, "logger"):
            mock_helpers.execute_activity = AsyncMock(side_effect=_make_activity_error())
            async with module.span(trace_id="trace-123", name="test-span") as span:
                assert span is None

        # Exactly one activity call (the failed start); no end-span activity.
        mock_helpers.execute_activity.assert_called_once()
        mock_service.start_span.assert_not_called()
        mock_service.end_span.assert_not_called()

    async def test_start_span_in_workflow_propagates_unexpected_errors(self):
        """Errors other than the handled Temporal failures are re-raised unchanged."""
        mock_service, module = _make_module()

        with patch.object(_tracing_mod, "in_temporal_workflow", return_value=True), \
                patch.object(_tracing_mod, "ActivityHelpers") as mock_helpers:
            mock_helpers.execute_activity = AsyncMock(side_effect=RuntimeError("bad response shape"))
            # Same ``pytest.raises`` idiom as the cancellation tests below; the
            # exact message is still checked via ``exc_info``.
            with pytest.raises(RuntimeError) as exc_info:
                await module.start_span(trace_id="trace-123", name="test-span")

        assert str(exc_info.value) == "bad response shape"
        mock_helpers.execute_activity.assert_called_once()
        mock_service.start_span.assert_not_called()

    async def test_start_span_in_workflow_propagates_cancellation(self):
        """When ``is_cancelled_exception`` is true, start_span re-raises without warning or metric."""
        mock_service, module = _make_module()
        activity_error = _make_activity_error()
        mock_meter = _make_metric_meter()

        with patch.object(_tracing_mod, "in_temporal_workflow", return_value=True), \
                patch.object(_tracing_mod, "ActivityHelpers") as mock_helpers, \
                patch.object(_tracing_mod, "is_cancelled_exception", return_value=True), \
                patch.object(_tracing_mod.workflow, "logger") as mock_logger, \
                patch.object(_tracing_mod.workflow, "metric_meter", return_value=mock_meter):
            mock_helpers.execute_activity = AsyncMock(side_effect=activity_error)

            with pytest.raises(ActivityError):
                await module.start_span(trace_id="trace-123", name="test-span")

        mock_logger.warning.assert_not_called()
        mock_meter.create_counter.assert_not_called()
        mock_helpers.execute_activity.assert_called_once()
        mock_service.start_span.assert_not_called()

    async def test_end_span_in_workflow_propagates_cancellation(self):
        """When ``is_cancelled_exception`` is true, end_span re-raises without warning or metric."""
        mock_service, module = _make_module()
        span = _make_span()
        activity_error = _make_activity_error()
        mock_meter = _make_metric_meter()

        with patch.object(_tracing_mod, "in_temporal_workflow", return_value=True), \
                patch.object(_tracing_mod, "ActivityHelpers") as mock_helpers, \
                patch.object(_tracing_mod, "is_cancelled_exception", return_value=True), \
                patch.object(_tracing_mod.workflow, "logger") as mock_logger, \
                patch.object(_tracing_mod.workflow, "metric_meter", return_value=mock_meter):
            mock_helpers.execute_activity = AsyncMock(side_effect=activity_error)

            with pytest.raises(ActivityError):
                await module.end_span(trace_id="trace-123", span=span)

        mock_logger.warning.assert_not_called()
        mock_meter.create_counter.assert_not_called()
        mock_helpers.execute_activity.assert_called_once()
        mock_service.end_span.assert_not_called()


class TestSpanContextManager:
async def test_span_context_manager_forwards_task_id(self):
mock_service, module = _make_module()
Expand Down
68 changes: 68 additions & 0 deletions tests/lib/test_state_machine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from __future__ import annotations

from typing import override
from unittest.mock import AsyncMock, patch

from agentex.lib.sdk.state_machine import State, StateMachine, StateWorkflow
from agentex.lib.utils.model_utils import BaseModel


class ExampleData(BaseModel):
    # Minimal payload model passed to the state machine under test as
    # ``state_machine_data``; a single integer field defaulting to 0.
    value: int = 0


class InitialWorkflow(StateWorkflow):
    """Workflow attached to the ``"initial"`` state; always hands off to ``"next"``."""

    transitions = ["next"]

    @override
    async def execute(self, state_machine, state_machine_data=None):
        # Both arguments are ignored; the follow-up state is fixed.
        return "next"


class NextWorkflow(StateWorkflow):
    """Workflow attached to the ``"next"`` state; always hands back to ``"initial"``."""

    transitions = ["initial"]

    @override
    async def execute(self, state_machine, state_machine_data=None):
        # Both arguments are ignored; the follow-up state is fixed.
        return "initial"


class ExampleStateMachine(StateMachine[ExampleData]):
    """State machine over ``ExampleData`` whose terminal condition is never met."""

    @override
    async def terminal_condition(self):
        # Always report "not terminal".
        return False


def _make_state_machine() -> ExampleStateMachine:
    """Build a two-state machine (``"initial"`` and ``"next"``) with transition tracing enabled."""
    states = [
        State(name="initial", workflow=InitialWorkflow()),
        State(name="next", workflow=NextWorkflow()),
    ]
    return ExampleStateMachine(
        initial_state="initial",
        states=states,
        task_id="task-123",
        state_machine_data=ExampleData(value=1),
        trace_transitions=True,
    )


async def test_reset_to_initial_state_skips_end_span_when_start_span_fails_open():
    """Resetting must not call ``end_span`` when ``start_span`` returned ``None``."""
    machine = _make_state_machine()
    # Move off the initial state so the reset has a transition to perform.
    await machine.transition("next")

    # ``start_span`` resolving to None simulates tracing having failed open.
    start_span = AsyncMock(return_value=None)
    end_span = AsyncMock()

    with patch(
        "agentex.lib.sdk.state_machine.state_machine.adk.tracing.start_span",
        new=start_span,
    ), patch(
        "agentex.lib.sdk.state_machine.state_machine.adk.tracing.end_span",
        new=end_span,
    ):
        await machine.reset_to_initial_state()

    assert machine.get_current_state() == "initial"
    expected_start_kwargs = {
        "trace_id": "task-123",
        "name": "state_transition_reset",
        "input": {"input_state": "next"},
    }
    start_span.assert_awaited_once_with(**expected_start_kwargs)
    end_span.assert_not_awaited()
Loading