101 changes: 101 additions & 0 deletions temporalio/contrib/langgraph/README.md
@@ -143,6 +143,107 @@ await g.ainvoke({...}, context=Context(user_id="alice"))

Your `context` object must be serializable by the configured Temporal payload converter, since it crosses the Activity boundary.

## Streaming

When `streaming_topic` is set on `LangGraphPlugin`, calls to `langgraph.config.get_stream_writer()` inside a node publish to the named topic on the workflow's [`WorkflowStream`](https://github.com/temporalio/sdk-python/tree/main/temporalio/contrib/workflow_streams). Activity-side nodes publish via `WorkflowStreamClient`, which delivers items to the workflow in batched signals (the batching window is controlled by `streaming_batch_interval`); workflow-side nodes publish synchronously to the in-workflow stream, with no signal involved. External subscribers consume the stream with `WorkflowStreamClient.create(...).topic(...).subscribe(...)`.

The workflow **must** construct `WorkflowStream()` in its `@workflow.init` (i.e. `__init__`) method:

```python
from datetime import timedelta
from typing import Any

from langgraph.config import get_stream_writer
from langgraph.graph import START, StateGraph
from typing_extensions import TypedDict

from temporalio import workflow
from temporalio.client import Client
from temporalio.contrib.langgraph import LangGraphPlugin, graph
from temporalio.contrib.workflow_streams import WorkflowStream, WorkflowStreamClient
from temporalio.worker import Worker


class State(TypedDict):
value: str


async def token_node(state: State) -> dict[str, str]:
writer = get_stream_writer()
for token in ["hello", " ", "world"]:
writer({"token": token})
writer({"done": True})
return {"value": "hello world"}


@workflow.defn
class StreamingWorkflow:
def __init__(self) -> None:
# Required when streaming_topic is set on the plugin.
_ = WorkflowStream()
self.app = graph("streaming").compile()

@workflow.run
async def run(self) -> str:
result = await self.app.ainvoke({"value": ""})
return result["value"]


async def main(client: Client) -> None:
g = StateGraph(State)
g.add_node("token_node", token_node, metadata={"execute_in": "activity"})
g.add_edge(START, "token_node")

async with Worker(
client,
task_queue="streaming-tq",
workflows=[StreamingWorkflow],
plugins=[
LangGraphPlugin(
graphs={"streaming": g},
default_activity_options={
"start_to_close_timeout": timedelta(seconds=10)
},
streaming_topic="tokens",
)
],
):
handle = await client.start_workflow(
StreamingWorkflow.run, id="streaming-wf", task_queue="streaming-tq"
)

ws_client = WorkflowStreamClient.create(client, handle.id)
async for item in ws_client.topic("tokens", type=dict).subscribe(from_offset=0):
print(item.data)
if item.data.get("done"):
break

print(await handle.result())
```

### What's covered, and what isn't

`streaming_topic` wires up exactly **one** LangGraph stream mode: `stream_mode="custom"`, i.e. values written through `get_stream_writer()`. The other modes — `"messages"`, `"values"`, `"updates"`, `"debug"` — are **not** captured by `streaming_topic`. They aren't produced by node-side writers; LangGraph's orchestrator emits them as it walks the graph. The documented pattern is to **bridge `astream()` in the workflow** and republish each yielded chunk to a `WorkflowStream` topic yourself:

```python
@workflow.defn
class AstreamBridge:
def __init__(self) -> None:
self.stream = WorkflowStream()
self.app = graph("g").compile()

@workflow.run
async def run(self) -> None:
topic = self.stream.topic("astream")
async for chunk in self.app.astream({...}, stream_mode="messages"):
topic.publish(chunk)
topic.publish({"done": True})
```

### Retry semantics

Streaming delivery is **at-least-once**. When an activity-wrapped node retries (transient failure, worker crash, etc.), the user function re-runs from scratch and re-publishes its writes; earlier publishes from the failed attempt are not rolled back. Subscribers should be prepared to see duplicates and recover idempotently, e.g. by deduping on a sequence id included in each chunk, or by treating the stream as advisory and relying on the workflow's final result for state.
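
A minimal sketch of the sequence-id approach, building on the streaming example above (`State`, `get_stream_writer`, and `ws_client` are reused from there; the `node`/`seq` fields are an illustrative convention, not part of the plugin API):

```python
async def token_node(state: State) -> dict[str, str]:
    writer = get_stream_writer()
    # Tag every chunk with a node-scoped sequence id. A retried attempt
    # re-runs from seq 0, so duplicates repeat (node, seq) pairs exactly.
    for seq, token in enumerate(["hello", " ", "world"]):
        writer({"node": "token_node", "seq": seq, "token": token})
    writer({"node": "token_node", "seq": 3, "done": True})
    return {"value": "hello world"}


# Subscriber side: drop any (node, seq) pair that was already processed.
seen: set[tuple[str, int]] = set()
async for item in ws_client.topic("tokens", type=dict).subscribe(from_offset=0):
    key = (item.data["node"], item.data["seq"])
    if key in seen:
        continue  # duplicate from a retried activity attempt
    seen.add(key)
    if item.data.get("done"):
        break
    print(item.data["token"])
```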

## Tracing

We recommend the [Temporal LangSmith Plugin](https://github.com/temporalio/sdk-python/tree/main/temporalio/contrib/langsmith) to trace your LangGraph Workflows and Activities.
2 changes: 1 addition & 1 deletion temporalio/contrib/langgraph/__init__.py
@@ -19,7 +19,7 @@

 __all__ = [
     "LangGraphPlugin",
-    "entrypoint",
     "cache",
+    "entrypoint",
     "graph",
 ]
64 changes: 46 additions & 18 deletions temporalio/contrib/langgraph/_activity.py
@@ -1,7 +1,9 @@
"""Activity wrappers for executing LangGraph nodes and tasks."""

import asyncio
from collections.abc import Awaitable
from dataclasses import dataclass
from datetime import timedelta
from inspect import iscoroutinefunction, signature
from typing import Any, Callable

Expand All @@ -19,6 +21,7 @@
cache_lookup,
cache_put,
)
from temporalio.contrib.workflow_streams import WorkflowStreamClient

# Per-run dedupe so we only warn once when a user passes a Store via
# graph.compile(store=...) / @entrypoint(store=...). Cleared by
@@ -51,28 +54,53 @@ class ActivityOutput:

def wrap_activity(
func: Callable,
*,
streaming_topic: str | None = None,
streaming_batch_interval: timedelta = timedelta(milliseconds=100),
) -> Callable[[ActivityInput], Awaitable[ActivityOutput]]:
"""Wrap a function as a Temporal activity that handles LangGraph config and interrupts."""
# Graph nodes declare `runtime: Runtime[Ctx]` in their signature; tasks
# don't and instead reach for Runtime via get_runtime(). We re-inject the
# reconstructed Runtime only when the user function asks.
accepts_runtime = "runtime" in signature(func).parameters

async def wrapper(input: ActivityInput) -> ActivityOutput:
runtime = set_langgraph_config(input.langgraph_config)
kwargs = dict(input.kwargs)
if accepts_runtime:
kwargs["runtime"] = runtime
try:
if iscoroutinefunction(func):
result = await func(*input.args, **kwargs)
else:
result = func(*input.args, **kwargs)
if isinstance(result, Command):
return ActivityOutput(langgraph_command=result)
return ActivityOutput(result=result)
except GraphInterrupt as e:
return ActivityOutput(langgraph_interrupts=e.args[0])
async def run(stream_writer: Callable[[Any], None] | None) -> ActivityOutput:
# Sync funcs run on a thread (so the loop keeps flushing the
# stream client mid-execution); marshal writer calls back to
# the loop thread because the client's flush event is an
# asyncio.Event and isn't safe to set off-thread.
effective_writer = stream_writer
if not iscoroutinefunction(func) and stream_writer is not None:
loop = asyncio.get_running_loop()
inner_writer = stream_writer

def thread_safe_writer(value: Any) -> None:
loop.call_soon_threadsafe(inner_writer, value)

effective_writer = thread_safe_writer

runtime = set_langgraph_config(
input.langgraph_config, stream_writer=effective_writer
)
kwargs = dict(input.kwargs)
if "runtime" in signature(func).parameters:
kwargs["runtime"] = runtime

try:
if iscoroutinefunction(func):
result = await func(*input.args, **kwargs)
else:
result = await asyncio.to_thread(func, *input.args, **kwargs)
if isinstance(result, Command):
return ActivityOutput(langgraph_command=result)
return ActivityOutput(result=result)
except GraphInterrupt as e:
return ActivityOutput(langgraph_interrupts=e.args[0])

if streaming_topic is None:
return await run(stream_writer=None)
async with WorkflowStreamClient.from_within_activity(
batch_interval=streaming_batch_interval,
) as client:
topic = client.topic(streaming_topic)
return await run(stream_writer=topic.publish)

return wrapper

16 changes: 16 additions & 0 deletions temporalio/contrib/langgraph/_interceptor.py
@@ -11,6 +11,7 @@

from temporalio import workflow
from temporalio.contrib.langgraph._activity import clear_store_warning
from temporalio.contrib.workflow_streams._stream import _PUBLISH_SIGNAL
from temporalio.worker import (
ExecuteWorkflowInput,
Interceptor,
@@ -30,17 +31,20 @@ def __init__(
self,
graphs: dict[str, StateGraph[Any, Any, Any, Any]],
entrypoints: dict[str, Pregel[Any, Any, Any, Any]],
streaming_topic: str | None = None,
) -> None:
"""Initialize with the graphs and entrypoints to scope to each workflow run."""
self._graphs = graphs
self._entrypoints = entrypoints
self._streaming_topic = streaming_topic

def workflow_interceptor_class(
self, input: WorkflowInterceptorClassInput
) -> type[WorkflowInboundInterceptor]:
"""Return the inbound interceptor class used to scope graphs per run."""
graphs = self._graphs
entrypoints = self._entrypoints
streaming_topic = self._streaming_topic

class Inbound(WorkflowInboundInterceptor):
def init(self, outbound: WorkflowOutboundInterceptor) -> None:
@@ -50,6 +54,18 @@ def init(self, outbound: WorkflowOutboundInterceptor) -> None:
super().init(outbound)

async def execute_workflow(self, input: ExecuteWorkflowInput) -> Any:
if (
streaming_topic is not None
and workflow.get_signal_handler(_PUBLISH_SIGNAL) is None
):
raise RuntimeError(
f"LangGraphPlugin was configured with "
f"streaming_topic={streaming_topic!r}, but workflow "
f"{workflow.info().workflow_type!r} did not register a "
f"WorkflowStream. Construct WorkflowStream() in the "
f"workflow's @workflow.init (i.e. __init__) method so "
f"streaming activities can publish to it."
)
try:
return await self.next.execute_workflow(input)
finally:
10 changes: 7 additions & 3 deletions temporalio/contrib/langgraph/_langgraph_config.py
@@ -3,7 +3,7 @@
# pyright: reportMissingTypeStubs=false

import dataclasses
-from typing import Any
+from typing import Any, Callable

from langchain_core.runnables.config import var_child_runnable_config
from langgraph._internal._constants import (
@@ -93,7 +93,11 @@ def get_langgraph_config() -> dict[str, Any]:
}


-def set_langgraph_config(config: dict[str, Any]) -> Runtime:
+def set_langgraph_config(
+    config: dict[str, Any],
+    *,
+    stream_writer: Callable[[Any], None] | None = None,
+) -> Runtime:
"""Restore a LangGraph runnable config from a serialized dict.

Returns the reconstructed Runtime so callers can re-inject it into the
@@ -112,7 +116,7 @@ def get_null_resume(consume: bool = False) -> Any:
execution_info_dict = config.get("execution_info")
runtime = Runtime(
context=config.get("context"),
-        stream_writer=lambda _: None,
+        stream_writer=stream_writer or (lambda _: None),
previous=config.get("previous"),
execution_info=(
ExecutionInfo(**execution_info_dict) if execution_info_dict else None