diff --git a/libs/langchain_v1/langchain/agents/factory.py b/libs/langchain_v1/langchain/agents/factory.py index a245a4d4bef..de69ef505c2 100644 --- a/libs/langchain_v1/langchain/agents/factory.py +++ b/libs/langchain_v1/langchain/agents/factory.py @@ -81,6 +81,7 @@ if TYPE_CHECKING: from langgraph.graph.state import CompiledStateGraph from langgraph.runtime import Runtime from langgraph.store.base import BaseStore + from langgraph.stream._mux import TransformerFactory from langgraph.types import Checkpointer from langchain.agents.middleware.types import ToolCallWrapper @@ -708,7 +709,7 @@ def create_agent( debug: bool = False, name: str | None = None, cache: BaseCache[Any] | None = None, - transformers: Sequence[Callable[[tuple[str, ...]], Any]] | None = None, + transformers: Sequence[TransformerFactory] | None = None, ) -> CompiledStateGraph[ AgentState[ResponseT], ContextT, _InputAgentState, _OutputAgentState[ResponseT] ]: @@ -806,9 +807,11 @@ def create_agent( cache: An optional `BaseCache` instance to enable caching of graph execution. transformers: Optional sequence of scope-aware `StreamTransformer` factories to register on the compiled graph in addition to - the agent defaults. Each factory is invoked per-scope - (`factory(scope)`) so subgraph mini-muxes get fresh - instances. Appended after the built-in `ToolCallTransformer`. + the agent defaults. Each factory is invoked as `factory(scope)` + so every invocation receives a fresh instance. The final order + on the compiled graph is: `ToolCallTransformer`, then any + factories declared by middleware via + `AgentMiddleware.transformers`, then any factories supplied here. Returns: A compiled `StateGraph` that can be used for chat interactions. @@ -1662,6 +1665,8 @@ def create_agent( if name: config["metadata"]["lc_agent_name"] = name + middleware_transformers = [t for m in middleware for t in getattr(m, "transformers", ())] + return graph.compile( checkpointer=checkpointer, store=store, @@ -1670,7 +1675,11 @@ def create_agent( debug=debug, name=name, cache=cache, - transformers=[ToolCallTransformer, *(transformers or ())], + transformers=[ + ToolCallTransformer, + *middleware_transformers, + *(transformers or ()), + ], ).with_config(config) diff --git a/libs/langchain_v1/langchain/agents/middleware/types.py b/libs/langchain_v1/langchain/agents/middleware/types.py index c249d17465f..1570ebd9402 100644 --- a/libs/langchain_v1/langchain/agents/middleware/types.py +++ b/libs/langchain_v1/langchain/agents/middleware/types.py @@ -40,6 +40,7 @@ if TYPE_CHECKING: from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.tools import BaseTool from langgraph.runtime import Runtime + from langgraph.stream._mux import TransformerFactory from langgraph.types import Command from langchain.agents.structured_output import ResponseFormat @@ -397,6 +398,16 @@ class AgentMiddleware(Generic[StateT, ContextT, ResponseT]): tools: Sequence[BaseTool] """Additional tools registered by the middleware.""" + transformers: Sequence[TransformerFactory] = () + """Stream transformer factories registered by the middleware. + + Each entry is a scope-aware factory invoked as `factory(scope)` so every + invocation receives a fresh instance. Factories are merged with the + `transformers` argument of [`create_agent`][langchain.agents.create_agent] + at graph compile time, after the `ToolCallTransformer` and before any + user-supplied entries. + """ + @property def name(self) -> str: """The name of the middleware instance. diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/core/test_transformers.py b/libs/langchain_v1/tests/unit_tests/agents/middleware/core/test_transformers.py new file mode 100644 index 00000000000..c660b8bb13c --- /dev/null +++ b/libs/langchain_v1/tests/unit_tests/agents/middleware/core/test_transformers.py @@ -0,0 +1,143 @@ +"""Tests for middleware-registered stream transformers.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from langchain_core.messages import HumanMessage +from langgraph.prebuilt import ToolCallTransformer +from langgraph.stream import StreamChannel, StreamTransformer + +from langchain.agents.factory import create_agent +from langchain.agents.middleware.types import AgentMiddleware +from tests.unit_tests.agents.model import FakeToolCallingModel + +if TYPE_CHECKING: + from langgraph.stream._types import ProtocolEvent + + +class _MiddlewareMarker(StreamTransformer): + """Marker transformer used to assert registration order.""" + + required_stream_modes = () + + def __init__(self, scope: tuple[str, ...] = ()) -> None: + super().__init__(scope) + self._log: StreamChannel[int] = StreamChannel() + + def init(self) -> dict[str, Any]: + return {"middleware_marker": self._log} + + def process(self, event: ProtocolEvent) -> bool: + del event + return True + + +class _UserMarker(StreamTransformer): + """Second marker to verify user-supplied transformers append last.""" + + required_stream_modes = () + + def __init__(self, scope: tuple[str, ...] = ()) -> None: + super().__init__(scope) + self._log: StreamChannel[int] = StreamChannel() + + def init(self) -> dict[str, Any]: + return {"user_marker": self._log} + + def process(self, event: ProtocolEvent) -> bool: + del event + return True + + +def test_middleware_transformer_registered_on_compiled_graph() -> None: + """A `transformers` factory declared on middleware is wired into the run mux.""" + + class _Middleware(AgentMiddleware): + transformers = (_MiddlewareMarker,) + + agent = create_agent(model=FakeToolCallingModel(), tools=[], middleware=[_Middleware()]) + + run = agent.stream_events({"messages": [HumanMessage("hi")]}, version="v3") + + assert "middleware_marker" in run._mux.extensions # type: ignore[attr-defined] + # Drain to close the run cleanly. + list(run.tool_calls) + + +def test_middleware_and_user_transformers_compose_in_order() -> None: + """Order is: built-in `ToolCallTransformer` → middleware → user-supplied.""" + + class _Middleware(AgentMiddleware): + transformers = (_MiddlewareMarker,) + + agent = create_agent( + model=FakeToolCallingModel(), + tools=[], + middleware=[_Middleware()], + transformers=[_UserMarker], + ) + + run = agent.stream_events({"messages": [HumanMessage("hi")]}, version="v3") + + transformers = run._mux._transformers # type: ignore[attr-defined] + tool_call_idx = next( + i for i, t in enumerate(transformers) if isinstance(t, ToolCallTransformer) + ) + middleware_idx = next(i for i, t in enumerate(transformers) if isinstance(t, _MiddlewareMarker)) + user_idx = next(i for i, t in enumerate(transformers) if isinstance(t, _UserMarker)) + + assert tool_call_idx < middleware_idx < user_idx, ( + "transformers must register as: built-in, then middleware, then user-supplied" + ) + + list(run.tool_calls) + + +def test_transformers_from_multiple_middleware_preserve_middleware_order() -> None: + """Transformers across middleware register in middleware-list order.""" + + class _MarkerA(_MiddlewareMarker): + def init(self) -> dict[str, Any]: + return {"marker_a": self._log} + + class _MarkerB(_MiddlewareMarker): + def init(self) -> dict[str, Any]: + return {"marker_b": self._log} + + class _MwA(AgentMiddleware): + transformers = (_MarkerA,) + + class _MwB(AgentMiddleware): + transformers = (_MarkerB,) + + agent = create_agent( + model=FakeToolCallingModel(), + tools=[], + middleware=[_MwA(), _MwB()], + ) + + run = agent.stream_events({"messages": [HumanMessage("hi")]}, version="v3") + + transformers = run._mux._transformers # type: ignore[attr-defined] + idx_a = next(i for i, t in enumerate(transformers) if isinstance(t, _MarkerA)) + idx_b = next(i for i, t in enumerate(transformers) if isinstance(t, _MarkerB)) + assert idx_a < idx_b + + list(run.tool_calls) + + +def test_middleware_without_transformers_does_not_affect_registry() -> None: + """Middleware that omits `transformers` leaves the default registry intact.""" + + class _Middleware(AgentMiddleware): + pass + + agent = create_agent(model=FakeToolCallingModel(), tools=[], middleware=[_Middleware()]) + run = agent.stream_events({"messages": [HumanMessage("hi")]}, version="v3") + + transformers = run._mux._transformers # type: ignore[attr-defined] + assert any(isinstance(t, ToolCallTransformer) for t in transformers) + assert not any(isinstance(t, _MiddlewareMarker) for t in transformers) + + list(run.tool_calls)