Compare commits

...

2 Commits

Author SHA1 Message Date
Mason Daugherty
f3ec1a8420 Merge branch 'master' into harrison/middleware-middleware 2025-09-24 16:40:53 -04:00
Harrison Chase
d2fc91c079 let middleware specify middleware 2025-09-18 10:56:04 -04:00
3 changed files with 78 additions and 21 deletions

View File

@@ -3,20 +3,8 @@
from __future__ import annotations
from dataclasses import dataclass, field
from inspect import signature
from typing import (
TYPE_CHECKING,
Annotated,
Any,
ClassVar,
Generic,
Literal,
Protocol,
TypeAlias,
TypeGuard,
cast,
overload,
)
from typing import TYPE_CHECKING, Annotated, Any, Generic, Literal, cast
from collections.abc import Sequence
# needed as top level import for pydantic schema generation on AgentState
from langchain_core.messages import AnyMessage # noqa: TC002
@@ -120,13 +108,12 @@ class AgentMiddleware(Generic[StateT, ContextT]):
tools: list[BaseTool]
"""Additional tools registered by the middleware."""
before_model_jump_to: ClassVar[list[JumpTo]] = []
"""Valid jump destinations for before_model hook. Used to establish conditional edges."""
# Nested middleware that should run before this middleware.
# These will be expanded into the main agent's middleware list
# and executed in order prior to this middleware.
middleware: Sequence["AgentMiddleware"] = ()
after_model_jump_to: ClassVar[list[JumpTo]] = []
"""Valid jump destinations for after_model hook. Used to establish conditional edges."""
def before_model(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
def before_model(self, state: StateT) -> dict[str, Any] | None:
"""Logic to run before the model is called."""
def modify_model_request(

View File

@@ -175,6 +175,27 @@ def create_agent( # noqa: PLR0915
native_output_binding = ProviderStrategyBinding.from_schema_spec(
response_format.schema_spec
)
# Expand nested middleware so that any middleware declared within a middleware
# is run before the declaring middleware.
def _expand_middleware(mw: Sequence[AgentMiddleware]) -> list[AgentMiddleware]:
expanded: list[AgentMiddleware] = []
seen: set[str] = set()
def add(m: AgentMiddleware) -> None:
# Add children first (depth-first), then the parent
for child in getattr(m, "middleware", ()) or ():
add(child)
name = m.__class__.__name__
if name not in seen:
expanded.append(m)
seen.add(name)
for root in mw:
add(root)
return expanded
middleware = tuple(_expand_middleware(middleware))
middleware_tools = [t for m in middleware for t in getattr(m, "tools", [])]
# Setup tools
@@ -207,7 +228,7 @@ def create_agent( # noqa: PLR0915
list(structured_output_tools.values()) if structured_output_tools else []
) + middleware_tools
# validate middleware
# validate middleware (after expansion)
assert len({m.__class__.__name__ for m in middleware}) == len(middleware), ( # noqa: S101
"Please remove duplicate middleware instances."
)

View File

@@ -517,6 +517,55 @@ def test_human_in_the_loop_middleware_single_tool_response() -> None:
assert result["messages"][0].tool_call_id == "1"
def test_nested_middleware_runs_before_parent() -> None:
calls: list[str] = []
class Child(AgentMiddleware):
def before_model(self, state):
calls.append("child.before")
def modify_model_request(self, request: ModelRequest, state):
calls.append("child.modify")
return request
def after_model(self, state):
calls.append("child.after")
class Parent(AgentMiddleware):
def __init__(self) -> None:
# Include child middleware which should run before this one
self.middleware = [Child()]
def before_model(self, state):
calls.append("parent.before")
def modify_model_request(self, request: ModelRequest, state):
calls.append("parent.modify")
return request
def after_model(self, state):
calls.append("parent.after")
agent = create_agent(
model=FakeToolCallingModel(),
tools=[],
middleware=[Parent()],
system_prompt="sys",
).compile()
res = agent.invoke({"messages": ["hi"]})
assert "messages" in res and len(res["messages"]) >= 2
# Order: before -> child then parent; modify -> child then parent; after -> parent then child
assert calls == [
"child.before",
"parent.before",
"child.modify",
"parent.modify",
"parent.after",
"child.after",
]
def test_human_in_the_loop_middleware_multiple_tools_mixed_responses() -> None:
"""Test HumanInTheLoopMiddleware with multiple tools and mixed response types."""