Compare commits

...

2 Commits

Author SHA1 Message Date
Harrison Chase
721a9f389b cr 2025-10-26 09:10:27 -07:00
Harrison Chase
8eec156479 middleware extend middleware 2025-10-26 09:09:58 -07:00
3 changed files with 75 additions and 0 deletions

View File

@@ -610,6 +610,28 @@ def create_agent( # noqa: PLR0915
if tools is None:
tools = []
# Expand nested middleware: if a middleware declares additional middleware,
# insert them immediately after, recursively (parent -> children -> ...).
def _expand_middleware_sequence(
seq: "Sequence[AgentMiddleware[Any, Any]]",
_seen: set[int] | None = None,
) -> list[AgentMiddleware[Any, Any]]:
expanded: list[AgentMiddleware[Any, Any]] = []
seen = _seen if _seen is not None else set()
for m in seq:
mid = id(m)
if mid in seen:
# Skip already expanded instances to avoid cycles/duplicates
continue
seen.add(mid)
expanded.append(m)
children = getattr(m, "middleware", []) or []
if children:
expanded.extend(_expand_middleware_sequence(children, seen))
return expanded
middleware = tuple(_expand_middleware_sequence(middleware))
# Convert response format and setup structured output tools
# Raw schemas are wrapped in AutoStrategy to preserve auto-detection intent.
# AutoStrategy is converted to ToolStrategy upfront to calculate tools during agent creation,

View File

@@ -201,6 +201,14 @@ class AgentMiddleware(Generic[StateT, ContextT]):
tools: list[BaseTool]
"""Additional tools registered by the middleware."""
# Allow middleware to declare additional middleware to be inserted
# immediately after this middleware. Factored recursively, the final order
# is parent -> children -> grandchildren, etc.
# This enables composition patterns where a middleware can bundle and
# expose other middleware without users having to pass them explicitly.
middleware: list["AgentMiddleware[Any, Any]"]
"""Additional middleware registered by this middleware (optional)."""
@property
def name(self) -> str:
"""The name of the middleware instance.

View File

@@ -0,0 +1,45 @@
from typing import Any
from langchain.agents.factory import create_agent
from langchain.agents.middleware.types import AgentMiddleware, AgentState
from langchain_core.messages import SystemMessage
from ..model import FakeToolCallingModel
class _AppendMiddleware(AgentMiddleware):
def __init__(self, label: str, children: list[AgentMiddleware] | None = None) -> None:
# No tools registered by default
self.tools = []
# Optional nested middleware
self.middleware = children or []
self._label = label
def before_model(self, state: AgentState, runtime) -> dict[str, Any] | None: # type: ignore[override]
return {"messages": [SystemMessage(self._label)]}
def test_nested_middleware_ordering() -> None:
# Build nested chain: X -> Y -> Z
z = _AppendMiddleware("Z")
y = _AppendMiddleware("Y", [z])
x = _AppendMiddleware("X", [y])
# Siblings before and after X
a = _AppendMiddleware("A")
b = _AppendMiddleware("B")
agent = create_agent(
model=FakeToolCallingModel(),
tools=[],
system_prompt=None,
middleware=[a, x, b],
)
result = agent.invoke({"messages": [{"role": "user", "content": "hi"}]})
# The FakeToolCallingModel joins message contents with '-'
final_ai = result["messages"][-1]
content: str = final_ai.content # type: ignore[assignment]
# Ensure correct in-order appearance: A -> X -> Y -> Z -> B
assert content.index("A") < content.index("X") < content.index("Y") < content.index("Z") < content.index("B")