mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 18:50:33 +00:00
feat(langchain): register stream transformers on middleware (#37591)
Adds a `transformers` attribute to `AgentMiddleware` so middleware can
declare scope-aware `StreamTransformer` factories alongside their
`tools` and lifecycle hooks. `create_agent` merges middleware-registered
factories with any caller-supplied ones at compile time.
## API
```python
class MyMiddleware(AgentMiddleware):
transformers = (MyTransformer,) # factory: (scope,) -> StreamTransformer
```
When the agent compiles, the final transformer order on the run mux is:
1. Built-in ``ToolCallTransformer``
2. Middleware-registered factories, in middleware order
3. Caller-supplied ``transformers=`` from ``create_agent``
This ordering keeps the built-in tool-call projection in front of any
consumer transformers and gives caller-supplied entries the final word.
This commit is contained in:
@@ -81,6 +81,7 @@ if TYPE_CHECKING:
|
|||||||
from langgraph.graph.state import CompiledStateGraph
|
from langgraph.graph.state import CompiledStateGraph
|
||||||
from langgraph.runtime import Runtime
|
from langgraph.runtime import Runtime
|
||||||
from langgraph.store.base import BaseStore
|
from langgraph.store.base import BaseStore
|
||||||
|
from langgraph.stream._mux import TransformerFactory
|
||||||
from langgraph.types import Checkpointer
|
from langgraph.types import Checkpointer
|
||||||
|
|
||||||
from langchain.agents.middleware.types import ToolCallWrapper
|
from langchain.agents.middleware.types import ToolCallWrapper
|
||||||
@@ -708,7 +709,7 @@ def create_agent(
|
|||||||
debug: bool = False,
|
debug: bool = False,
|
||||||
name: str | None = None,
|
name: str | None = None,
|
||||||
cache: BaseCache[Any] | None = None,
|
cache: BaseCache[Any] | None = None,
|
||||||
transformers: Sequence[Callable[[tuple[str, ...]], Any]] | None = None,
|
transformers: Sequence[TransformerFactory] | None = None,
|
||||||
) -> CompiledStateGraph[
|
) -> CompiledStateGraph[
|
||||||
AgentState[ResponseT], ContextT, _InputAgentState, _OutputAgentState[ResponseT]
|
AgentState[ResponseT], ContextT, _InputAgentState, _OutputAgentState[ResponseT]
|
||||||
]:
|
]:
|
||||||
@@ -806,9 +807,11 @@ def create_agent(
|
|||||||
cache: An optional `BaseCache` instance to enable caching of graph execution.
|
cache: An optional `BaseCache` instance to enable caching of graph execution.
|
||||||
transformers: Optional sequence of scope-aware `StreamTransformer`
|
transformers: Optional sequence of scope-aware `StreamTransformer`
|
||||||
factories to register on the compiled graph in addition to
|
factories to register on the compiled graph in addition to
|
||||||
the agent defaults. Each factory is invoked per-scope
|
the agent defaults. Each factory is invoked as `factory(scope)`
|
||||||
(`factory(scope)`) so subgraph mini-muxes get fresh
|
so every invocation receives a fresh instance. The final order
|
||||||
instances. Appended after the built-in `ToolCallTransformer`.
|
on the compiled graph is: `ToolCallTransformer`, then any
|
||||||
|
factories declared by middleware via
|
||||||
|
`AgentMiddleware.transformers`, then any factories supplied here.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A compiled `StateGraph` that can be used for chat interactions.
|
A compiled `StateGraph` that can be used for chat interactions.
|
||||||
@@ -1662,6 +1665,8 @@ def create_agent(
|
|||||||
if name:
|
if name:
|
||||||
config["metadata"]["lc_agent_name"] = name
|
config["metadata"]["lc_agent_name"] = name
|
||||||
|
|
||||||
|
middleware_transformers = [t for m in middleware for t in getattr(m, "transformers", ())]
|
||||||
|
|
||||||
return graph.compile(
|
return graph.compile(
|
||||||
checkpointer=checkpointer,
|
checkpointer=checkpointer,
|
||||||
store=store,
|
store=store,
|
||||||
@@ -1670,7 +1675,11 @@ def create_agent(
|
|||||||
debug=debug,
|
debug=debug,
|
||||||
name=name,
|
name=name,
|
||||||
cache=cache,
|
cache=cache,
|
||||||
transformers=[ToolCallTransformer, *(transformers or ())],
|
transformers=[
|
||||||
|
ToolCallTransformer,
|
||||||
|
*middleware_transformers,
|
||||||
|
*(transformers or ()),
|
||||||
|
],
|
||||||
).with_config(config)
|
).with_config(config)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -40,6 +40,7 @@ if TYPE_CHECKING:
|
|||||||
from langchain_core.language_models.chat_models import BaseChatModel
|
from langchain_core.language_models.chat_models import BaseChatModel
|
||||||
from langchain_core.tools import BaseTool
|
from langchain_core.tools import BaseTool
|
||||||
from langgraph.runtime import Runtime
|
from langgraph.runtime import Runtime
|
||||||
|
from langgraph.stream._mux import TransformerFactory
|
||||||
from langgraph.types import Command
|
from langgraph.types import Command
|
||||||
|
|
||||||
from langchain.agents.structured_output import ResponseFormat
|
from langchain.agents.structured_output import ResponseFormat
|
||||||
@@ -397,6 +398,16 @@ class AgentMiddleware(Generic[StateT, ContextT, ResponseT]):
|
|||||||
tools: Sequence[BaseTool]
|
tools: Sequence[BaseTool]
|
||||||
"""Additional tools registered by the middleware."""
|
"""Additional tools registered by the middleware."""
|
||||||
|
|
||||||
|
transformers: Sequence[TransformerFactory] = ()
|
||||||
|
"""Stream transformer factories registered by the middleware.
|
||||||
|
|
||||||
|
Each entry is a scope-aware factory invoked as `factory(scope)` so every
|
||||||
|
invocation receives a fresh instance. Factories are merged with the
|
||||||
|
`transformers` argument of [`create_agent`][langchain.agents.create_agent]
|
||||||
|
at graph compile time, after the `ToolCallTransformer` and before any
|
||||||
|
user-supplied entries.
|
||||||
|
"""
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
"""The name of the middleware instance.
|
"""The name of the middleware instance.
|
||||||
|
|||||||
@@ -0,0 +1,143 @@
|
|||||||
|
"""Tests for middleware-registered stream transformers."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
from langchain_core.messages import HumanMessage
|
||||||
|
from langgraph.prebuilt import ToolCallTransformer
|
||||||
|
from langgraph.stream import StreamChannel, StreamTransformer
|
||||||
|
|
||||||
|
from langchain.agents.factory import create_agent
|
||||||
|
from langchain.agents.middleware.types import AgentMiddleware
|
||||||
|
from tests.unit_tests.agents.model import FakeToolCallingModel
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from langgraph.stream._types import ProtocolEvent
|
||||||
|
|
||||||
|
|
||||||
|
class _MiddlewareMarker(StreamTransformer):
|
||||||
|
"""Marker transformer used to assert registration order."""
|
||||||
|
|
||||||
|
required_stream_modes = ()
|
||||||
|
|
||||||
|
def __init__(self, scope: tuple[str, ...] = ()) -> None:
|
||||||
|
super().__init__(scope)
|
||||||
|
self._log: StreamChannel[int] = StreamChannel()
|
||||||
|
|
||||||
|
def init(self) -> dict[str, Any]:
|
||||||
|
return {"middleware_marker": self._log}
|
||||||
|
|
||||||
|
def process(self, event: ProtocolEvent) -> bool:
|
||||||
|
del event
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
class _UserMarker(StreamTransformer):
|
||||||
|
"""Second marker to verify user-supplied transformers append last."""
|
||||||
|
|
||||||
|
required_stream_modes = ()
|
||||||
|
|
||||||
|
def __init__(self, scope: tuple[str, ...] = ()) -> None:
|
||||||
|
super().__init__(scope)
|
||||||
|
self._log: StreamChannel[int] = StreamChannel()
|
||||||
|
|
||||||
|
def init(self) -> dict[str, Any]:
|
||||||
|
return {"user_marker": self._log}
|
||||||
|
|
||||||
|
def process(self, event: ProtocolEvent) -> bool:
|
||||||
|
del event
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def test_middleware_transformer_registered_on_compiled_graph() -> None:
|
||||||
|
"""A `transformers` factory declared on middleware is wired into the run mux."""
|
||||||
|
|
||||||
|
class _Middleware(AgentMiddleware):
|
||||||
|
transformers = (_MiddlewareMarker,)
|
||||||
|
|
||||||
|
agent = create_agent(model=FakeToolCallingModel(), tools=[], middleware=[_Middleware()])
|
||||||
|
|
||||||
|
run = agent.stream_events({"messages": [HumanMessage("hi")]}, version="v3")
|
||||||
|
|
||||||
|
assert "middleware_marker" in run._mux.extensions # type: ignore[attr-defined]
|
||||||
|
# Drain to close the run cleanly.
|
||||||
|
list(run.tool_calls)
|
||||||
|
|
||||||
|
|
||||||
|
def test_middleware_and_user_transformers_compose_in_order() -> None:
|
||||||
|
"""Order is: built-in `ToolCallTransformer` → middleware → user-supplied."""
|
||||||
|
|
||||||
|
class _Middleware(AgentMiddleware):
|
||||||
|
transformers = (_MiddlewareMarker,)
|
||||||
|
|
||||||
|
agent = create_agent(
|
||||||
|
model=FakeToolCallingModel(),
|
||||||
|
tools=[],
|
||||||
|
middleware=[_Middleware()],
|
||||||
|
transformers=[_UserMarker],
|
||||||
|
)
|
||||||
|
|
||||||
|
run = agent.stream_events({"messages": [HumanMessage("hi")]}, version="v3")
|
||||||
|
|
||||||
|
transformers = run._mux._transformers # type: ignore[attr-defined]
|
||||||
|
tool_call_idx = next(
|
||||||
|
i for i, t in enumerate(transformers) if isinstance(t, ToolCallTransformer)
|
||||||
|
)
|
||||||
|
middleware_idx = next(i for i, t in enumerate(transformers) if isinstance(t, _MiddlewareMarker))
|
||||||
|
user_idx = next(i for i, t in enumerate(transformers) if isinstance(t, _UserMarker))
|
||||||
|
|
||||||
|
assert tool_call_idx < middleware_idx < user_idx, (
|
||||||
|
"transformers must register as: built-in, then middleware, then user-supplied"
|
||||||
|
)
|
||||||
|
|
||||||
|
list(run.tool_calls)
|
||||||
|
|
||||||
|
|
||||||
|
def test_transformers_from_multiple_middleware_preserve_middleware_order() -> None:
|
||||||
|
"""Transformers across middleware register in middleware-list order."""
|
||||||
|
|
||||||
|
class _MarkerA(_MiddlewareMarker):
|
||||||
|
def init(self) -> dict[str, Any]:
|
||||||
|
return {"marker_a": self._log}
|
||||||
|
|
||||||
|
class _MarkerB(_MiddlewareMarker):
|
||||||
|
def init(self) -> dict[str, Any]:
|
||||||
|
return {"marker_b": self._log}
|
||||||
|
|
||||||
|
class _MwA(AgentMiddleware):
|
||||||
|
transformers = (_MarkerA,)
|
||||||
|
|
||||||
|
class _MwB(AgentMiddleware):
|
||||||
|
transformers = (_MarkerB,)
|
||||||
|
|
||||||
|
agent = create_agent(
|
||||||
|
model=FakeToolCallingModel(),
|
||||||
|
tools=[],
|
||||||
|
middleware=[_MwA(), _MwB()],
|
||||||
|
)
|
||||||
|
|
||||||
|
run = agent.stream_events({"messages": [HumanMessage("hi")]}, version="v3")
|
||||||
|
|
||||||
|
transformers = run._mux._transformers # type: ignore[attr-defined]
|
||||||
|
idx_a = next(i for i, t in enumerate(transformers) if isinstance(t, _MarkerA))
|
||||||
|
idx_b = next(i for i, t in enumerate(transformers) if isinstance(t, _MarkerB))
|
||||||
|
assert idx_a < idx_b
|
||||||
|
|
||||||
|
list(run.tool_calls)
|
||||||
|
|
||||||
|
|
||||||
|
def test_middleware_without_transformers_does_not_affect_registry() -> None:
|
||||||
|
"""Middleware that omits `transformers` leaves the default registry intact."""
|
||||||
|
|
||||||
|
class _Middleware(AgentMiddleware):
|
||||||
|
pass
|
||||||
|
|
||||||
|
agent = create_agent(model=FakeToolCallingModel(), tools=[], middleware=[_Middleware()])
|
||||||
|
run = agent.stream_events({"messages": [HumanMessage("hi")]}, version="v3")
|
||||||
|
|
||||||
|
transformers = run._mux._transformers # type: ignore[attr-defined]
|
||||||
|
assert any(isinstance(t, ToolCallTransformer) for t in transformers)
|
||||||
|
assert not any(isinstance(t, _MiddlewareMarker) for t in transformers)
|
||||||
|
|
||||||
|
list(run.tool_calls)
|
||||||
Reference in New Issue
Block a user