mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
feat(langchain): register stream transformers on middleware (#37591)
Adds a `transformers` attribute to `AgentMiddleware` so middleware can
declare scope-aware `StreamTransformer` factories alongside their
`tools` and lifecycle hooks. `create_agent` merges middleware-registered
factories with any caller-supplied ones at compile time.
## API
```python
class MyMiddleware(AgentMiddleware):
transformers = (MyTransformer,) # factory: (scope,) -> StreamTransformer
```
When the agent compiles, the final transformer order on the run mux is:
1. Built-in ``ToolCallTransformer``
2. Middleware-registered factories, in middleware order
3. Caller-supplied ``transformers=`` from ``create_agent``
This ordering keeps the built-in tool-call projection in front of any
consumer transformers and gives caller-supplied entries the final word.
This commit is contained in:
@@ -81,6 +81,7 @@ if TYPE_CHECKING:
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
from langgraph.runtime import Runtime
|
||||
from langgraph.store.base import BaseStore
|
||||
from langgraph.stream._mux import TransformerFactory
|
||||
from langgraph.types import Checkpointer
|
||||
|
||||
from langchain.agents.middleware.types import ToolCallWrapper
|
||||
@@ -708,7 +709,7 @@ def create_agent(
|
||||
debug: bool = False,
|
||||
name: str | None = None,
|
||||
cache: BaseCache[Any] | None = None,
|
||||
transformers: Sequence[Callable[[tuple[str, ...]], Any]] | None = None,
|
||||
transformers: Sequence[TransformerFactory] | None = None,
|
||||
) -> CompiledStateGraph[
|
||||
AgentState[ResponseT], ContextT, _InputAgentState, _OutputAgentState[ResponseT]
|
||||
]:
|
||||
@@ -806,9 +807,11 @@ def create_agent(
|
||||
cache: An optional `BaseCache` instance to enable caching of graph execution.
|
||||
transformers: Optional sequence of scope-aware `StreamTransformer`
|
||||
factories to register on the compiled graph in addition to
|
||||
the agent defaults. Each factory is invoked per-scope
|
||||
(`factory(scope)`) so subgraph mini-muxes get fresh
|
||||
instances. Appended after the built-in `ToolCallTransformer`.
|
||||
the agent defaults. Each factory is invoked as `factory(scope)`
|
||||
so every invocation receives a fresh instance. The final order
|
||||
on the compiled graph is: `ToolCallTransformer`, then any
|
||||
factories declared by middleware via
|
||||
`AgentMiddleware.transformers`, then any factories supplied here.
|
||||
|
||||
Returns:
|
||||
A compiled `StateGraph` that can be used for chat interactions.
|
||||
@@ -1662,6 +1665,8 @@ def create_agent(
|
||||
if name:
|
||||
config["metadata"]["lc_agent_name"] = name
|
||||
|
||||
middleware_transformers = [t for m in middleware for t in getattr(m, "transformers", ())]
|
||||
|
||||
return graph.compile(
|
||||
checkpointer=checkpointer,
|
||||
store=store,
|
||||
@@ -1670,7 +1675,11 @@ def create_agent(
|
||||
debug=debug,
|
||||
name=name,
|
||||
cache=cache,
|
||||
transformers=[ToolCallTransformer, *(transformers or ())],
|
||||
transformers=[
|
||||
ToolCallTransformer,
|
||||
*middleware_transformers,
|
||||
*(transformers or ()),
|
||||
],
|
||||
).with_config(config)
|
||||
|
||||
|
||||
|
||||
@@ -40,6 +40,7 @@ if TYPE_CHECKING:
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.tools import BaseTool
|
||||
from langgraph.runtime import Runtime
|
||||
from langgraph.stream._mux import TransformerFactory
|
||||
from langgraph.types import Command
|
||||
|
||||
from langchain.agents.structured_output import ResponseFormat
|
||||
@@ -397,6 +398,16 @@ class AgentMiddleware(Generic[StateT, ContextT, ResponseT]):
|
||||
tools: Sequence[BaseTool]
|
||||
"""Additional tools registered by the middleware."""
|
||||
|
||||
transformers: Sequence[TransformerFactory] = ()
|
||||
"""Stream transformer factories registered by the middleware.
|
||||
|
||||
Each entry is a scope-aware factory invoked as `factory(scope)` so every
|
||||
invocation receives a fresh instance. Factories are merged with the
|
||||
`transformers` argument of [`create_agent`][langchain.agents.create_agent]
|
||||
at graph compile time, after the `ToolCallTransformer` and before any
|
||||
user-supplied entries.
|
||||
"""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""The name of the middleware instance.
|
||||
|
||||
@@ -0,0 +1,143 @@
|
||||
"""Tests for middleware-registered stream transformers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langgraph.prebuilt import ToolCallTransformer
|
||||
from langgraph.stream import StreamChannel, StreamTransformer
|
||||
|
||||
from langchain.agents.factory import create_agent
|
||||
from langchain.agents.middleware.types import AgentMiddleware
|
||||
from tests.unit_tests.agents.model import FakeToolCallingModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langgraph.stream._types import ProtocolEvent
|
||||
|
||||
|
||||
class _MiddlewareMarker(StreamTransformer):
|
||||
"""Marker transformer used to assert registration order."""
|
||||
|
||||
required_stream_modes = ()
|
||||
|
||||
def __init__(self, scope: tuple[str, ...] = ()) -> None:
|
||||
super().__init__(scope)
|
||||
self._log: StreamChannel[int] = StreamChannel()
|
||||
|
||||
def init(self) -> dict[str, Any]:
|
||||
return {"middleware_marker": self._log}
|
||||
|
||||
def process(self, event: ProtocolEvent) -> bool:
|
||||
del event
|
||||
return True
|
||||
|
||||
|
||||
class _UserMarker(StreamTransformer):
|
||||
"""Second marker to verify user-supplied transformers append last."""
|
||||
|
||||
required_stream_modes = ()
|
||||
|
||||
def __init__(self, scope: tuple[str, ...] = ()) -> None:
|
||||
super().__init__(scope)
|
||||
self._log: StreamChannel[int] = StreamChannel()
|
||||
|
||||
def init(self) -> dict[str, Any]:
|
||||
return {"user_marker": self._log}
|
||||
|
||||
def process(self, event: ProtocolEvent) -> bool:
|
||||
del event
|
||||
return True
|
||||
|
||||
|
||||
def test_middleware_transformer_registered_on_compiled_graph() -> None:
|
||||
"""A `transformers` factory declared on middleware is wired into the run mux."""
|
||||
|
||||
class _Middleware(AgentMiddleware):
|
||||
transformers = (_MiddlewareMarker,)
|
||||
|
||||
agent = create_agent(model=FakeToolCallingModel(), tools=[], middleware=[_Middleware()])
|
||||
|
||||
run = agent.stream_events({"messages": [HumanMessage("hi")]}, version="v3")
|
||||
|
||||
assert "middleware_marker" in run._mux.extensions # type: ignore[attr-defined]
|
||||
# Drain to close the run cleanly.
|
||||
list(run.tool_calls)
|
||||
|
||||
|
||||
def test_middleware_and_user_transformers_compose_in_order() -> None:
|
||||
"""Order is: built-in `ToolCallTransformer` → middleware → user-supplied."""
|
||||
|
||||
class _Middleware(AgentMiddleware):
|
||||
transformers = (_MiddlewareMarker,)
|
||||
|
||||
agent = create_agent(
|
||||
model=FakeToolCallingModel(),
|
||||
tools=[],
|
||||
middleware=[_Middleware()],
|
||||
transformers=[_UserMarker],
|
||||
)
|
||||
|
||||
run = agent.stream_events({"messages": [HumanMessage("hi")]}, version="v3")
|
||||
|
||||
transformers = run._mux._transformers # type: ignore[attr-defined]
|
||||
tool_call_idx = next(
|
||||
i for i, t in enumerate(transformers) if isinstance(t, ToolCallTransformer)
|
||||
)
|
||||
middleware_idx = next(i for i, t in enumerate(transformers) if isinstance(t, _MiddlewareMarker))
|
||||
user_idx = next(i for i, t in enumerate(transformers) if isinstance(t, _UserMarker))
|
||||
|
||||
assert tool_call_idx < middleware_idx < user_idx, (
|
||||
"transformers must register as: built-in, then middleware, then user-supplied"
|
||||
)
|
||||
|
||||
list(run.tool_calls)
|
||||
|
||||
|
||||
def test_transformers_from_multiple_middleware_preserve_middleware_order() -> None:
|
||||
"""Transformers across middleware register in middleware-list order."""
|
||||
|
||||
class _MarkerA(_MiddlewareMarker):
|
||||
def init(self) -> dict[str, Any]:
|
||||
return {"marker_a": self._log}
|
||||
|
||||
class _MarkerB(_MiddlewareMarker):
|
||||
def init(self) -> dict[str, Any]:
|
||||
return {"marker_b": self._log}
|
||||
|
||||
class _MwA(AgentMiddleware):
|
||||
transformers = (_MarkerA,)
|
||||
|
||||
class _MwB(AgentMiddleware):
|
||||
transformers = (_MarkerB,)
|
||||
|
||||
agent = create_agent(
|
||||
model=FakeToolCallingModel(),
|
||||
tools=[],
|
||||
middleware=[_MwA(), _MwB()],
|
||||
)
|
||||
|
||||
run = agent.stream_events({"messages": [HumanMessage("hi")]}, version="v3")
|
||||
|
||||
transformers = run._mux._transformers # type: ignore[attr-defined]
|
||||
idx_a = next(i for i, t in enumerate(transformers) if isinstance(t, _MarkerA))
|
||||
idx_b = next(i for i, t in enumerate(transformers) if isinstance(t, _MarkerB))
|
||||
assert idx_a < idx_b
|
||||
|
||||
list(run.tool_calls)
|
||||
|
||||
|
||||
def test_middleware_without_transformers_does_not_affect_registry() -> None:
|
||||
"""Middleware that omits `transformers` leaves the default registry intact."""
|
||||
|
||||
class _Middleware(AgentMiddleware):
|
||||
pass
|
||||
|
||||
agent = create_agent(model=FakeToolCallingModel(), tools=[], middleware=[_Middleware()])
|
||||
run = agent.stream_events({"messages": [HumanMessage("hi")]}, version="v3")
|
||||
|
||||
transformers = run._mux._transformers # type: ignore[attr-defined]
|
||||
assert any(isinstance(t, ToolCallTransformer) for t in transformers)
|
||||
assert not any(isinstance(t, _MiddlewareMarker) for t in transformers)
|
||||
|
||||
list(run.tool_calls)
|
||||
Reference in New Issue
Block a user