diff --git a/libs/langchain_v1/langchain/agents/middleware/types.py b/libs/langchain_v1/langchain/agents/middleware/types.py index c84f4954ed1..58f4beb5c0b 100644 --- a/libs/langchain_v1/langchain/agents/middleware/types.py +++ b/libs/langchain_v1/langchain/agents/middleware/types.py @@ -4,6 +4,7 @@ from __future__ import annotations from dataclasses import dataclass, field from typing import TYPE_CHECKING, Annotated, Any, Generic, Literal, cast +from collections.abc import Sequence # needed as top level import for pydantic schema generation on AgentState from langchain_core.messages import AnyMessage # noqa: TC002 @@ -68,6 +69,11 @@ class AgentMiddleware(Generic[StateT]): tools: list[BaseTool] """Additional tools registered by the middleware.""" + # Nested middleware that should run before this middleware. + # These will be expanded into the main agent's middleware list + # and executed in order prior to this middleware. + middleware: Sequence["AgentMiddleware"] = () + def before_model(self, state: StateT) -> dict[str, Any] | None: """Logic to run before the model is called.""" diff --git a/libs/langchain_v1/langchain/agents/middleware_agent.py b/libs/langchain_v1/langchain/agents/middleware_agent.py index 4caa6c8af62..9f7f83bf225 100644 --- a/libs/langchain_v1/langchain/agents/middleware_agent.py +++ b/libs/langchain_v1/langchain/agents/middleware_agent.py @@ -151,6 +151,27 @@ def create_agent( # noqa: PLR0915 native_output_binding = ProviderStrategyBinding.from_schema_spec( response_format.schema_spec ) + # Expand nested middleware so that any middleware declared within a middleware + # is run before the declaring middleware. + def _expand_middleware(mw: Sequence[AgentMiddleware]) -> list[AgentMiddleware]: + expanded: list[AgentMiddleware] = [] + seen: set[str] = set() + + def add(m: AgentMiddleware) -> None: + # Add children first (depth-first), then the parent + for child in getattr(m, "middleware", ()) or (): + add(child) + name = m.__class__.__name__ + if name not in seen: + expanded.append(m) + seen.add(name) + + for root in mw: + add(root) + return expanded + + middleware = tuple(_expand_middleware(middleware)) + middleware_tools = [t for m in middleware for t in getattr(m, "tools", [])] # Setup tools @@ -183,7 +204,7 @@ def create_agent( # noqa: PLR0915 list(structured_output_tools.values()) if structured_output_tools else [] ) + middleware_tools - # validate middleware + # validate middleware (after expansion) assert len({m.__class__.__name__ for m in middleware}) == len(middleware), ( # noqa: S101 "Please remove duplicate middleware instances." ) diff --git a/libs/langchain_v1/tests/unit_tests/agents/test_middleware_agent.py b/libs/langchain_v1/tests/unit_tests/agents/test_middleware_agent.py index 58138ac5227..2b99ebb18f0 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/test_middleware_agent.py +++ b/libs/langchain_v1/tests/unit_tests/agents/test_middleware_agent.py @@ -495,6 +495,55 @@ def test_human_in_the_loop_middleware_single_tool_response() -> None: assert result["messages"][0].tool_call_id == "1" +def test_nested_middleware_runs_before_parent() -> None: + calls: list[str] = [] + + class Child(AgentMiddleware): + def before_model(self, state): + calls.append("child.before") + + def modify_model_request(self, request: ModelRequest, state): + calls.append("child.modify") + return request + + def after_model(self, state): + calls.append("child.after") + + class Parent(AgentMiddleware): + def __init__(self) -> None: + # Include child middleware which should run before this one + self.middleware = [Child()] + + def before_model(self, state): + calls.append("parent.before") + + def modify_model_request(self, request: ModelRequest, state): + calls.append("parent.modify") + return request + + def after_model(self, state): + calls.append("parent.after") + + agent = create_agent( + model=FakeToolCallingModel(), + tools=[], + middleware=[Parent()], + system_prompt="sys", + ).compile() + + res = agent.invoke({"messages": ["hi"]}) + assert "messages" in res and len(res["messages"]) >= 2 + # Order: before -> child then parent; modify -> child then parent; after -> parent then child + assert calls == [ + "child.before", + "parent.before", + "child.modify", + "parent.modify", + "parent.after", + "child.after", + ] + + def test_human_in_the_loop_middleware_multiple_tools_mixed_responses() -> None: """Test HumanInTheLoopMiddleware with multiple tools and mixed response types."""