mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-21 14:43:07 +00:00
let middleware specify middleware
This commit is contained in:
@@ -4,6 +4,7 @@ from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Annotated, Any, Generic, Literal, cast
|
||||
from collections.abc import Sequence
|
||||
|
||||
# needed as top level import for pydantic schema generation on AgentState
|
||||
from langchain_core.messages import AnyMessage # noqa: TC002
|
||||
@@ -68,6 +69,11 @@ class AgentMiddleware(Generic[StateT]):
|
||||
tools: list[BaseTool]
|
||||
"""Additional tools registered by the middleware."""
|
||||
|
||||
# Nested middleware that should run before this middleware.
|
||||
# These will be expanded into the main agent's middleware list
|
||||
# and executed in order prior to this middleware.
|
||||
middleware: Sequence["AgentMiddleware"] = ()
|
||||
|
||||
def before_model(self, state: StateT) -> dict[str, Any] | None:
|
||||
"""Logic to run before the model is called."""
|
||||
|
||||
|
||||
@@ -151,6 +151,27 @@ def create_agent( # noqa: PLR0915
|
||||
native_output_binding = ProviderStrategyBinding.from_schema_spec(
|
||||
response_format.schema_spec
|
||||
)
|
||||
# Expand nested middleware so that any middleware declared within a middleware
|
||||
# is run before the declaring middleware.
|
||||
def _expand_middleware(mw: Sequence[AgentMiddleware]) -> list[AgentMiddleware]:
|
||||
expanded: list[AgentMiddleware] = []
|
||||
seen: set[str] = set()
|
||||
|
||||
def add(m: AgentMiddleware) -> None:
|
||||
# Add children first (depth-first), then the parent
|
||||
for child in getattr(m, "middleware", ()) or ():
|
||||
add(child)
|
||||
name = m.__class__.__name__
|
||||
if name not in seen:
|
||||
expanded.append(m)
|
||||
seen.add(name)
|
||||
|
||||
for root in mw:
|
||||
add(root)
|
||||
return expanded
|
||||
|
||||
middleware = tuple(_expand_middleware(middleware))
|
||||
|
||||
middleware_tools = [t for m in middleware for t in getattr(m, "tools", [])]
|
||||
|
||||
# Setup tools
|
||||
@@ -183,7 +204,7 @@ def create_agent( # noqa: PLR0915
|
||||
list(structured_output_tools.values()) if structured_output_tools else []
|
||||
) + middleware_tools
|
||||
|
||||
# validate middleware
|
||||
# validate middleware (after expansion)
|
||||
assert len({m.__class__.__name__ for m in middleware}) == len(middleware), ( # noqa: S101
|
||||
"Please remove duplicate middleware instances."
|
||||
)
|
||||
|
||||
@@ -495,6 +495,55 @@ def test_human_in_the_loop_middleware_single_tool_response() -> None:
|
||||
assert result["messages"][0].tool_call_id == "1"
|
||||
|
||||
|
||||
def test_nested_middleware_runs_before_parent() -> None:
|
||||
calls: list[str] = []
|
||||
|
||||
class Child(AgentMiddleware):
|
||||
def before_model(self, state):
|
||||
calls.append("child.before")
|
||||
|
||||
def modify_model_request(self, request: ModelRequest, state):
|
||||
calls.append("child.modify")
|
||||
return request
|
||||
|
||||
def after_model(self, state):
|
||||
calls.append("child.after")
|
||||
|
||||
class Parent(AgentMiddleware):
|
||||
def __init__(self) -> None:
|
||||
# Include child middleware which should run before this one
|
||||
self.middleware = [Child()]
|
||||
|
||||
def before_model(self, state):
|
||||
calls.append("parent.before")
|
||||
|
||||
def modify_model_request(self, request: ModelRequest, state):
|
||||
calls.append("parent.modify")
|
||||
return request
|
||||
|
||||
def after_model(self, state):
|
||||
calls.append("parent.after")
|
||||
|
||||
agent = create_agent(
|
||||
model=FakeToolCallingModel(),
|
||||
tools=[],
|
||||
middleware=[Parent()],
|
||||
system_prompt="sys",
|
||||
).compile()
|
||||
|
||||
res = agent.invoke({"messages": ["hi"]})
|
||||
assert "messages" in res and len(res["messages"]) >= 2
|
||||
# Order: before -> child then parent; modify -> child then parent; after -> parent then child
|
||||
assert calls == [
|
||||
"child.before",
|
||||
"parent.before",
|
||||
"child.modify",
|
||||
"parent.modify",
|
||||
"parent.after",
|
||||
"child.after",
|
||||
]
|
||||
|
||||
|
||||
def test_human_in_the_loop_middleware_multiple_tools_mixed_responses() -> None:
|
||||
"""Test HumanInTheLoopMiddleware with multiple tools and mixed response types."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user