feat(langchain): register stream transformers on middleware (#37591)

Adds a `transformers` attribute to `AgentMiddleware` so middleware can
declare scope-aware `StreamTransformer` factories alongside their
`tools` and lifecycle hooks. `create_agent` merges middleware-registered
factories with any caller-supplied ones at compile time.

## API

```python
class MyMiddleware(AgentMiddleware):
    transformers = (MyTransformer,)  # factory: (scope,) -> StreamTransformer
```

When the agent compiles, the final transformer order on the run mux is:

1. Built-in `ToolCallTransformer`
2. Middleware-registered factories, in middleware order
3. Caller-supplied `transformers=` from `create_agent`

This ordering keeps the built-in tool-call projection in front of any
consumer transformers and gives caller-supplied entries the final word.
This commit is contained in:
Nick Hollon
2026-05-21 12:08:54 -04:00
committed by GitHub
parent d2931d878f
commit 1aa4496fb4
3 changed files with 168 additions and 5 deletions

View File

@@ -81,6 +81,7 @@ if TYPE_CHECKING:
from langgraph.graph.state import CompiledStateGraph from langgraph.graph.state import CompiledStateGraph
from langgraph.runtime import Runtime from langgraph.runtime import Runtime
from langgraph.store.base import BaseStore from langgraph.store.base import BaseStore
from langgraph.stream._mux import TransformerFactory
from langgraph.types import Checkpointer from langgraph.types import Checkpointer
from langchain.agents.middleware.types import ToolCallWrapper from langchain.agents.middleware.types import ToolCallWrapper
@@ -708,7 +709,7 @@ def create_agent(
debug: bool = False, debug: bool = False,
name: str | None = None, name: str | None = None,
cache: BaseCache[Any] | None = None, cache: BaseCache[Any] | None = None,
transformers: Sequence[Callable[[tuple[str, ...]], Any]] | None = None, transformers: Sequence[TransformerFactory] | None = None,
) -> CompiledStateGraph[ ) -> CompiledStateGraph[
AgentState[ResponseT], ContextT, _InputAgentState, _OutputAgentState[ResponseT] AgentState[ResponseT], ContextT, _InputAgentState, _OutputAgentState[ResponseT]
]: ]:
@@ -806,9 +807,11 @@ def create_agent(
cache: An optional `BaseCache` instance to enable caching of graph execution. cache: An optional `BaseCache` instance to enable caching of graph execution.
transformers: Optional sequence of scope-aware `StreamTransformer` transformers: Optional sequence of scope-aware `StreamTransformer`
factories to register on the compiled graph in addition to factories to register on the compiled graph in addition to
the agent defaults. Each factory is invoked per-scope the agent defaults. Each factory is invoked as `factory(scope)`
(`factory(scope)`) so subgraph mini-muxes get fresh so every invocation receives a fresh instance. The final order
instances. Appended after the built-in `ToolCallTransformer`. on the compiled graph is: `ToolCallTransformer`, then any
factories declared by middleware via
`AgentMiddleware.transformers`, then any factories supplied here.
Returns: Returns:
A compiled `StateGraph` that can be used for chat interactions. A compiled `StateGraph` that can be used for chat interactions.
@@ -1662,6 +1665,8 @@ def create_agent(
if name: if name:
config["metadata"]["lc_agent_name"] = name config["metadata"]["lc_agent_name"] = name
middleware_transformers = [t for m in middleware for t in getattr(m, "transformers", ())]
return graph.compile( return graph.compile(
checkpointer=checkpointer, checkpointer=checkpointer,
store=store, store=store,
@@ -1670,7 +1675,11 @@ def create_agent(
debug=debug, debug=debug,
name=name, name=name,
cache=cache, cache=cache,
transformers=[ToolCallTransformer, *(transformers or ())], transformers=[
ToolCallTransformer,
*middleware_transformers,
*(transformers or ()),
],
).with_config(config) ).with_config(config)

View File

@@ -40,6 +40,7 @@ if TYPE_CHECKING:
from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.tools import BaseTool from langchain_core.tools import BaseTool
from langgraph.runtime import Runtime from langgraph.runtime import Runtime
from langgraph.stream._mux import TransformerFactory
from langgraph.types import Command from langgraph.types import Command
from langchain.agents.structured_output import ResponseFormat from langchain.agents.structured_output import ResponseFormat
@@ -397,6 +398,16 @@ class AgentMiddleware(Generic[StateT, ContextT, ResponseT]):
tools: Sequence[BaseTool] tools: Sequence[BaseTool]
"""Additional tools registered by the middleware.""" """Additional tools registered by the middleware."""
transformers: Sequence[TransformerFactory] = ()
"""Stream transformer factories registered by the middleware.
Each entry is a scope-aware factory invoked as `factory(scope)` so every
invocation receives a fresh instance. Factories are merged with the
`transformers` argument of [`create_agent`][langchain.agents.create_agent]
at graph compile time, after the `ToolCallTransformer` and before any
user-supplied entries.
"""
@property @property
def name(self) -> str: def name(self) -> str:
"""The name of the middleware instance. """The name of the middleware instance.

View File

@@ -0,0 +1,143 @@
"""Tests for middleware-registered stream transformers."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from langchain_core.messages import HumanMessage
from langgraph.prebuilt import ToolCallTransformer
from langgraph.stream import StreamChannel, StreamTransformer
from langchain.agents.factory import create_agent
from langchain.agents.middleware.types import AgentMiddleware
from tests.unit_tests.agents.model import FakeToolCallingModel
if TYPE_CHECKING:
from langgraph.stream._types import ProtocolEvent
class _MiddlewareMarker(StreamTransformer):
    """No-op transformer the tests register through middleware.

    Tests look for instances of this class (and for its `middleware_marker`
    key) to check that middleware-declared factories were registered and in
    which position.
    """

    required_stream_modes = ()

    def __init__(self, scope: tuple[str, ...] = ()) -> None:
        """Forward `scope` to the base class and create the marker channel."""
        super().__init__(scope)
        self._log: StreamChannel[int] = StreamChannel()

    def init(self) -> dict[str, Any]:
        """Return the marker channel under the `middleware_marker` key."""
        exposed: dict[str, Any] = {}
        exposed["middleware_marker"] = self._log
        return exposed

    def process(self, event: ProtocolEvent) -> bool:
        """Ignore `event` and return `True`."""
        del event  # deliberately unused; this transformer only marks its presence
        return True
class _UserMarker(StreamTransformer):
    """No-op transformer the tests pass directly to `create_agent`.

    A second, distinct marker type so ordering tests can tell caller-supplied
    transformers apart from middleware-registered ones.
    """

    required_stream_modes = ()

    def __init__(self, scope: tuple[str, ...] = ()) -> None:
        """Forward `scope` to the base class and create the marker channel."""
        super().__init__(scope)
        self._log: StreamChannel[int] = StreamChannel()

    def init(self) -> dict[str, Any]:
        """Return the marker channel under the `user_marker` key."""
        exposed: dict[str, Any] = {}
        exposed["user_marker"] = self._log
        return exposed

    def process(self, event: ProtocolEvent) -> bool:
        """Ignore `event` and return `True`."""
        del event  # deliberately unused; this transformer only marks its presence
        return True
def test_middleware_transformer_registered_on_compiled_graph() -> None:
    """A `transformers` factory declared on middleware is wired into the run mux."""

    class _Middleware(AgentMiddleware):
        transformers = (_MiddlewareMarker,)

    agent = create_agent(model=FakeToolCallingModel(), tools=[], middleware=[_Middleware()])
    run = agent.stream_events({"messages": [HumanMessage("hi")]}, version="v3")
    try:
        # The marker's `init()` key shows up on the mux only if its factory ran.
        assert "middleware_marker" in run._mux.extensions  # type: ignore[attr-defined]
    finally:
        # Drain to close the run cleanly, even when the assertion above fails.
        list(run.tool_calls)
def test_middleware_and_user_transformers_compose_in_order() -> None:
    """Order is: built-in `ToolCallTransformer` → middleware → user-supplied."""

    class _Middleware(AgentMiddleware):
        transformers = (_MiddlewareMarker,)

    agent = create_agent(
        model=FakeToolCallingModel(),
        tools=[],
        middleware=[_Middleware()],
        transformers=[_UserMarker],
    )
    run = agent.stream_events({"messages": [HumanMessage("hi")]}, version="v3")
    try:
        transformers = run._mux._transformers  # type: ignore[attr-defined]

        def _first_index(kind: type) -> int:
            """Return the position of the first registered instance of `kind`."""
            for position, transformer in enumerate(transformers):
                if isinstance(transformer, kind):
                    return position
            # Fail with a readable message instead of a bare `StopIteration`.
            msg = f"{kind.__name__} was not registered on the run mux"
            raise AssertionError(msg)

        tool_call_idx = _first_index(ToolCallTransformer)
        middleware_idx = _first_index(_MiddlewareMarker)
        user_idx = _first_index(_UserMarker)

        assert tool_call_idx < middleware_idx < user_idx, (
            "transformers must register as: built-in, then middleware, then user-supplied"
        )
    finally:
        # Drain to close the run cleanly, even when a check above fails.
        list(run.tool_calls)
def test_transformers_from_multiple_middleware_preserve_middleware_order() -> None:
    """Transformers across middleware register in middleware-list order."""

    # Two distinct marker subclasses so each middleware's transformer can be
    # told apart by type; each exposes its channel under its own key.
    class _MarkerA(_MiddlewareMarker):
        def init(self) -> dict[str, Any]:
            return {"marker_a": self._log}

    class _MarkerB(_MiddlewareMarker):
        def init(self) -> dict[str, Any]:
            return {"marker_b": self._log}

    class _MwA(AgentMiddleware):
        transformers = (_MarkerA,)

    class _MwB(AgentMiddleware):
        transformers = (_MarkerB,)

    agent = create_agent(
        model=FakeToolCallingModel(),
        tools=[],
        middleware=[_MwA(), _MwB()],
    )
    run = agent.stream_events({"messages": [HumanMessage("hi")]}, version="v3")
    try:
        transformers = run._mux._transformers  # type: ignore[attr-defined]

        def _first_index(kind: type) -> int:
            """Return the position of the first registered instance of `kind`."""
            for position, transformer in enumerate(transformers):
                if isinstance(transformer, kind):
                    return position
            # Fail with a readable message instead of a bare `StopIteration`.
            msg = f"{kind.__name__} was not registered on the run mux"
            raise AssertionError(msg)

        idx_a = _first_index(_MarkerA)
        idx_b = _first_index(_MarkerB)
        assert idx_a < idx_b
    finally:
        # Drain to close the run cleanly, even when a check above fails.
        list(run.tool_calls)
def test_middleware_without_transformers_does_not_affect_registry() -> None:
    """Middleware that omits `transformers` leaves the default registry intact."""

    class _Middleware(AgentMiddleware):
        pass

    agent = create_agent(model=FakeToolCallingModel(), tools=[], middleware=[_Middleware()])
    run = agent.stream_events({"messages": [HumanMessage("hi")]}, version="v3")
    try:
        transformers = run._mux._transformers  # type: ignore[attr-defined]
        # The built-in transformer is still present...
        assert any(isinstance(t, ToolCallTransformer) for t in transformers)
        # ...and nothing middleware-related was added.
        assert not any(isinstance(t, _MiddlewareMarker) for t in transformers)
    finally:
        # Drain to close the run cleanly, even when an assertion above fails.
        list(run.tool_calls)