Mirror of https://github.com/hwchase17/langchain.git, synced 2026-02-04 08:10:25 +00:00

Compare commits: 5 commits (sr/refine- ... sr/async-p)
| Author | SHA1 | Date |
|---|---|---|
| | fd3acabe9d | |
| | 348075987f | |
| | ea5d6f2cfa | |
| | cd9a12cc9b | |
| | 33b11630fe | |
@@ -3,7 +3,7 @@
 from __future__ import annotations

 from dataclasses import dataclass, field
-from inspect import signature
+from inspect import iscoroutinefunction, signature
 from typing import (
     TYPE_CHECKING,
     Annotated,
@@ -18,6 +18,11 @@ from typing import (
     overload,
 )

+from langchain_core.runnables import run_in_executor
+
 if TYPE_CHECKING:
+    from collections.abc import Awaitable
+
     # needed as top level import for pydantic schema generation on AgentState
     from langchain_core.messages import AnyMessage  # noqa: TC002
+    from langgraph.channels.ephemeral_value import EphemeralValue
@@ -129,6 +134,11 @@ class AgentMiddleware(Generic[StateT, ContextT]):
     def before_model(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
         """Logic to run before the model is called."""

+    async def abefore_model(
+        self, state: StateT, runtime: Runtime[ContextT]
+    ) -> dict[str, Any] | None:
+        """Async logic to run before the model is called."""
+
     def modify_model_request(
         self,
         request: ModelRequest,
@@ -138,14 +148,35 @@ class AgentMiddleware(Generic[StateT, ContextT]):
         """Logic to modify request kwargs before the model is called."""
         return request

+    async def amodify_model_request(
+        self,
+        request: ModelRequest,
+        state: StateT,
+        runtime: Runtime[ContextT],
+    ) -> ModelRequest:
+        """Async logic to modify request kwargs before the model is called."""
+        # Try calling sync version with runtime first, fall back to without runtime
+        try:
+            return await run_in_executor(None, self.modify_model_request, request, state, runtime)
+        except TypeError:
+            # Sync version doesn't accept runtime, call without it
+            return await run_in_executor(None, self.modify_model_request, request, state)
+
     def after_model(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
         """Logic to run after the model is called."""

+    async def aafter_model(
+        self, state: StateT, runtime: Runtime[ContextT]
+    ) -> dict[str, Any] | None:
+        """Async logic to run after the model is called."""
+

 class _CallableWithState(Protocol[StateT_contra]):
     """Callable with AgentState as argument."""

-    def __call__(self, state: StateT_contra) -> dict[str, Any] | Command | None:
+    def __call__(
+        self, state: StateT_contra
+    ) -> dict[str, Any] | Command | None | Awaitable[dict[str, Any] | Command | None]:
         """Perform some logic with the state."""
         ...

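To make the new hook surface concrete, the sketch below shows a user-defined middleware implementing both the sync and async variants, so it works with either `invoke()` or `ainvoke()`. The hook names come from the diff above; the import path and the example class are assumptions for illustration only.

```python
# Illustrative sketch; the import path below is an assumption, not shown in this diff.
from typing import Any

from langchain.agents.middleware import AgentMiddleware, ModelRequest  # assumed path


class BriefAnswersMiddleware(AgentMiddleware):
    """Implements both sync and async hooks, so it works on both execution paths."""

    def before_model(self, state) -> dict[str, Any] | None:
        # Sync hook: used when the agent is run via invoke()/stream()
        return None

    async def abefore_model(self, state) -> dict[str, Any] | None:
        # Async hook: used via ainvoke()/astream(); free to await I/O here
        return None

    async def amodify_model_request(self, request: ModelRequest, state) -> ModelRequest:
        # Async request mutation; if only the sync modify_model_request existed,
        # the default amodify_model_request above would delegate to it via run_in_executor.
        request.system_prompt = "Answer briefly."
        return request
```

A middleware that defines only the async hooks remains usable, but (per the sync code path later in this diff) only through the async entry points.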
@@ -155,7 +186,7 @@ class _CallableWithStateAndRuntime(Protocol[StateT_contra, ContextT]):

     def __call__(
         self, state: StateT_contra, runtime: Runtime[ContextT]
-    ) -> dict[str, Any] | Command | None:
+    ) -> dict[str, Any] | Command | None | Awaitable[dict[str, Any] | Command | None]:
         """Perform some logic with the state and runtime."""
         ...

@@ -163,7 +194,9 @@ class _CallableWithStateAndRuntime(Protocol[StateT_contra, ContextT]):
 class _CallableWithModelRequestAndState(Protocol[StateT_contra]):
     """Callable with ModelRequest and AgentState as arguments."""

-    def __call__(self, request: ModelRequest, state: StateT_contra) -> ModelRequest:
+    def __call__(
+        self, request: ModelRequest, state: StateT_contra
+    ) -> ModelRequest | Awaitable[ModelRequest]:
         """Perform some logic with the model request and state."""
         ...

@@ -173,7 +206,7 @@ class _CallableWithModelRequestAndStateAndRuntime(Protocol[StateT_contra, Contex

     def __call__(
         self, request: ModelRequest, state: StateT_contra, runtime: Runtime[ContextT]
-    ) -> ModelRequest:
+    ) -> ModelRequest | Awaitable[ModelRequest]:
         """Perform some logic with the model request, state, and runtime."""
         ...

@@ -278,14 +311,53 @@ def before_model(
     """

     def decorator(func: _NodeSignature[StateT, ContextT]) -> AgentMiddleware[StateT, ContextT]:
-        if is_callable_with_runtime(func):
+        is_async = iscoroutinefunction(func)
+        uses_runtime = is_callable_with_runtime(func)
+
+        if is_async:
+            if uses_runtime:
+
+                async def async_wrapped_with_runtime(
+                    self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
+                    state: StateT,
+                    runtime: Runtime[ContextT],
+                ) -> dict[str, Any] | Command | None:
+                    return await func(state, runtime)  # type: ignore[misc]
+
+                async_wrapped = async_wrapped_with_runtime
+            else:
+
+                async def async_wrapped_without_runtime(
+                    self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
+                    state: StateT,
+                ) -> dict[str, Any] | Command | None:
+                    return await func(state)  # type: ignore[call-arg,misc]
+
+                async_wrapped = async_wrapped_without_runtime  # type: ignore[assignment]
+
+            middleware_name = name or cast(
+                "str", getattr(func, "__name__", "BeforeModelMiddleware")
+            )
+
+            return type(
+                middleware_name,
+                (AgentMiddleware,),
+                {
+                    "state_schema": state_schema or AgentState,
+                    "tools": tools or [],
+                    "before_model_jump_to": jump_to or [],
+                    "abefore_model": async_wrapped,
+                },
+            )()
+
+        if uses_runtime:

             def wrapped_with_runtime(
                 self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
                 state: StateT,
                 runtime: Runtime[ContextT],
             ) -> dict[str, Any] | Command | None:
-                return func(state, runtime)
+                return func(state, runtime)  # type: ignore[return-value]

             wrapped = wrapped_with_runtime
         else:
@@ -298,7 +370,6 @@ def before_model(

             wrapped = wrapped_without_runtime  # type: ignore[assignment]

-        # Use function name as default if no name provided
         middleware_name = name or cast("str", getattr(func, "__name__", "BeforeModelMiddleware"))

         return type(
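In decorator form, passing an async function to before_model now produces a middleware whose abefore_model is the wrapped coroutine, as the branch above suggests. A usage sketch, with the import path assumed:

```python
# Sketch only; import path assumed.
from typing import Any

from langchain.agents.middleware import AgentState, before_model  # assumed path


@before_model
async def load_user_profile(state: AgentState) -> dict[str, Any] | None:
    # Runs before each model call on the async path; may await a DB or HTTP lookup.
    return None


# load_user_profile is now an AgentMiddleware instance (class name "load_user_profile")
# and can be passed directly to create_agent(middleware=[...]).
```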
@@ -394,7 +465,47 @@ def modify_model_request(
     def decorator(
         func: _ModelRequestSignature[StateT, ContextT],
     ) -> AgentMiddleware[StateT, ContextT]:
-        if is_callable_with_runtime_and_request(func):
+        is_async = iscoroutinefunction(func)
+        uses_runtime = is_callable_with_runtime_and_request(func)
+
+        if is_async:
+            if uses_runtime:
+
+                async def async_wrapped_with_runtime(
+                    self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
+                    request: ModelRequest,
+                    state: StateT,
+                    runtime: Runtime[ContextT],
+                ) -> ModelRequest:
+                    return await func(request, state, runtime)  # type: ignore[misc]
+
+                async_wrapped = async_wrapped_with_runtime
+            else:
+
+                async def async_wrapped_without_runtime(
+                    self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
+                    request: ModelRequest,
+                    state: StateT,
+                ) -> ModelRequest:
+                    return await func(request, state)  # type: ignore[call-arg,misc]
+
+                async_wrapped = async_wrapped_without_runtime  # type: ignore[assignment]
+
+            middleware_name = name or cast(
+                "str", getattr(func, "__name__", "ModifyModelRequestMiddleware")
+            )
+
+            return type(
+                middleware_name,
+                (AgentMiddleware,),
+                {
+                    "state_schema": state_schema or AgentState,
+                    "tools": tools or [],
+                    "amodify_model_request": async_wrapped,
+                },
+            )()
+
+        if uses_runtime:

             def wrapped_with_runtime(
                 self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
@@ -402,7 +513,7 @@ def modify_model_request(
                 state: StateT,
                 runtime: Runtime[ContextT],
             ) -> ModelRequest:
-                return func(request, state, runtime)
+                return func(request, state, runtime)  # type: ignore[return-value]

             wrapped = wrapped_with_runtime
         else:
@@ -416,7 +527,6 @@ def modify_model_request(

             wrapped = wrapped_without_runtime  # type: ignore[assignment]

-        # Use function name as default if no name provided
         middleware_name = name or cast(
             "str", getattr(func, "__name__", "ModifyModelRequestMiddleware")
         )
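The same pattern applies to the request hook; a short sketch (imports assumed as above):

```python
# Sketch only; import path assumed.
from langchain.agents.middleware import AgentState, ModelRequest, modify_model_request


@modify_model_request
async def add_tenant_tone(request: ModelRequest, state: AgentState) -> ModelRequest:
    # Await any lookup needed, then adjust the outgoing request.
    request.system_prompt = "Answer in the tenant's preferred tone."
    return request
```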
@@ -504,14 +614,51 @@ def after_model(
     """

     def decorator(func: _NodeSignature[StateT, ContextT]) -> AgentMiddleware[StateT, ContextT]:
-        if is_callable_with_runtime(func):
+        is_async = iscoroutinefunction(func)
+        uses_runtime = is_callable_with_runtime(func)
+
+        if is_async:
+            if uses_runtime:
+
+                async def async_wrapped_with_runtime(
+                    self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
+                    state: StateT,
+                    runtime: Runtime[ContextT],
+                ) -> dict[str, Any] | Command | None:
+                    return await func(state, runtime)  # type: ignore[misc]
+
+                async_wrapped = async_wrapped_with_runtime
+            else:
+
+                async def async_wrapped_without_runtime(
+                    self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
+                    state: StateT,
+                ) -> dict[str, Any] | Command | None:
+                    return await func(state)  # type: ignore[call-arg,misc]
+
+                async_wrapped = async_wrapped_without_runtime  # type: ignore[assignment]
+
+            middleware_name = name or cast("str", getattr(func, "__name__", "AfterModelMiddleware"))
+
+            return type(
+                middleware_name,
+                (AgentMiddleware,),
+                {
+                    "state_schema": state_schema or AgentState,
+                    "tools": tools or [],
+                    "after_model_jump_to": jump_to or [],
+                    "aafter_model": async_wrapped,
+                },
+            )()
+
+        if uses_runtime:

             def wrapped_with_runtime(
                 self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
                 state: StateT,
                 runtime: Runtime[ContextT],
             ) -> dict[str, Any] | Command | None:
-                return func(state, runtime)
+                return func(state, runtime)  # type: ignore[return-value]

             wrapped = wrapped_with_runtime
         else:
@@ -524,7 +671,6 @@ def after_model(

             wrapped = wrapped_without_runtime  # type: ignore[assignment]

-        # Use function name as default if no name provided
         middleware_name = name or cast("str", getattr(func, "__name__", "AfterModelMiddleware"))

         return type(
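An async after_model hook can also steer control flow. The sketch below assumes the decorator exposes the jump_to list seen in the class dictionary above as a keyword argument, and uses the "end" target that the tests further down rely on:

```python
# Sketch only; import path and the jump_to keyword are assumptions.
from typing import Any

from langchain.agents.middleware import AgentState, after_model  # assumed path


@after_model(jump_to=["end"])
async def stop_on_refusal(state: AgentState) -> dict[str, Any] | None:
    last = state["messages"][-1]
    if "cannot help" in getattr(last, "content", ""):
        # Short-circuit the remaining agent loop.
        return {"jump_to": "end"}
    return None
```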
@@ -2,8 +2,9 @@

 import itertools
 from collections.abc import Callable, Sequence
+from dataclasses import dataclass
 from inspect import signature
-from typing import Annotated, Any, cast, get_args, get_origin, get_type_hints
+from typing import Annotated, Any, Generic, cast, get_args, get_origin, get_type_hints

 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.messages import AIMessage, AnyMessage, SystemMessage, ToolMessage
@@ -38,6 +39,27 @@ from langchain.tools import ToolNode

 STRUCTURED_OUTPUT_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes."

+ResponseT = TypeVar("ResponseT")
+
+
+@dataclass
+class MiddlewareSignature(Generic[ResponseT, ContextT]):
+    """Structured metadata for a middleware's hook implementations.
+
+    Attributes:
+        middleware: The middleware instance.
+        has_sync: Whether the middleware implements a sync version of the hook.
+        has_async: Whether the middleware implements an async version of the hook.
+        sync_uses_runtime: Whether the sync hook accepts a runtime argument.
+        async_uses_runtime: Whether the async hook accepts a runtime argument.
+    """
+
+    middleware: AgentMiddleware[AgentState[ResponseT], ContextT]
+    has_sync: bool
+    has_async: bool
+    sync_uses_runtime: bool
+    async_uses_runtime: bool
+

 def _resolve_schema(schemas: set[type], schema_name: str, omit_flag: str | None = None) -> type:
     """Resolve schema by merging schemas and optionally respecting OmitFromSchema annotations.
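The has_sync/has_async fields reduce to an "is this hook overridden?" identity check against the base class, and the *_uses_runtime fields to signature inspection. A standalone sketch of that idiom, using stand-in classes rather than the real AgentMiddleware:

```python
# Standalone illustration of the detection idiom; Base/Child are stand-ins,
# not classes from this changeset.
from inspect import iscoroutinefunction, signature


class Base:
    def hook(self, state, runtime):
        return None

    async def ahook(self, state, runtime):
        return None


class Child(Base):
    async def ahook(self, state):  # overrides only the async hook, without runtime
        return None


child = Child()
has_sync = Child.hook is not Base.hook               # False: sync hook not overridden
has_async = Child.ahook is not Base.ahook            # True: async hook overridden
is_coroutine = iscoroutinefunction(Child.ahook)      # True
async_uses_runtime = "runtime" in signature(child.ahook).parameters  # False
print(has_sync, has_async, is_coroutine, async_uses_runtime)
```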
@@ -130,9 +152,6 @@ def _handle_structured_output_error(
     return False, ""


-ResponseT = TypeVar("ResponseT")
-
-
 def create_agent(  # noqa: PLR0915
     *,
     model: str | BaseChatModel,
@@ -212,15 +231,22 @@ def create_agent(  # noqa: PLR0915
         "Please remove duplicate middleware instances."
     )
     middleware_w_before = [
-        m for m in middleware if m.__class__.before_model is not AgentMiddleware.before_model
+        m
+        for m in middleware
+        if m.__class__.before_model is not AgentMiddleware.before_model
+        or m.__class__.abefore_model is not AgentMiddleware.abefore_model
     ]
     middleware_w_modify_model_request = [
         m
         for m in middleware
         if m.__class__.modify_model_request is not AgentMiddleware.modify_model_request
+        or m.__class__.amodify_model_request is not AgentMiddleware.amodify_model_request
     ]
     middleware_w_after = [
-        m for m in middleware if m.__class__.after_model is not AgentMiddleware.after_model
+        m
+        for m in middleware
+        if m.__class__.after_model is not AgentMiddleware.after_model
+        or m.__class__.aafter_model is not AgentMiddleware.aafter_model
     ]

     state_schemas = {m.state_schema for m in middleware}
@@ -346,12 +372,32 @@ def create_agent(  # noqa: PLR0915
         )
         return request.model.bind(**request.model_settings)

-    model_request_signatures: list[
-        tuple[bool, AgentMiddleware[AgentState[ResponseT], ContextT]]
-    ] = [
-        ("runtime" in signature(m.modify_model_request).parameters, m)
-        for m in middleware_w_modify_model_request
-    ]
+    # Build signatures for modify_model_request middleware with async support
+    model_request_signatures: list[MiddlewareSignature[ResponseT, ContextT]] = []
+    for m in middleware_w_modify_model_request:
+        # Check if async version is implemented (not the default)
+        has_sync = m.__class__.modify_model_request is not AgentMiddleware.modify_model_request
+        has_async = m.__class__.amodify_model_request is not AgentMiddleware.amodify_model_request
+
+        # Check runtime usage for each implementation
+        sync_uses_runtime = (
+            "runtime" in signature(m.modify_model_request).parameters if has_sync else False
+        )
+        # If async is implemented, check its signature; otherwise default is True
+        # (the default async implementation always expects runtime)
+        async_uses_runtime = (
+            "runtime" in signature(m.amodify_model_request).parameters if has_async else True
+        )
+
+        model_request_signatures.append(
+            MiddlewareSignature(
+                middleware=m,
+                has_sync=has_sync,
+                has_async=has_async,
+                sync_uses_runtime=sync_uses_runtime,
+                async_uses_runtime=async_uses_runtime,
+            )
+        )

     def model_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
         """Sync model request handler with sequential middleware processing."""
@@ -365,11 +411,20 @@ def create_agent(  # noqa: PLR0915
         )

         # Apply modify_model_request middleware in sequence
-        for use_runtime, m in model_request_signatures:
-            if use_runtime:
-                m.modify_model_request(request, state, runtime)
-            else:
-                m.modify_model_request(request, state)  # type: ignore[call-arg]
+        for sig in model_request_signatures:
+            if sig.has_sync:
+                if sig.sync_uses_runtime:
+                    sig.middleware.modify_model_request(request, state, runtime)
+                else:
+                    sig.middleware.modify_model_request(request, state)  # type: ignore[call-arg]
+            elif sig.has_async:
+                msg = (
+                    f"No synchronous function provided for "
+                    f'{sig.middleware.__class__.__name__}.amodify_model_request".'
+                    "\nEither initialize with a synchronous function or invoke"
+                    " via the async API (ainvoke, astream, etc.)"
+                )
+                raise TypeError(msg)

         # Get the final model and messages
         model_ = _get_bound_model(request)
@@ -393,11 +448,13 @@ def create_agent(  # noqa: PLR0915
         )

         # Apply modify_model_request middleware in sequence
-        for use_runtime, m in model_request_signatures:
-            if use_runtime:
-                m.modify_model_request(request, state, runtime)
+        for sig in model_request_signatures:
+            # If async is overridden and doesn't use runtime, call without it
+            if sig.has_async and not sig.async_uses_runtime:
+                await sig.middleware.amodify_model_request(request, state)  # type: ignore[call-arg]
+            # Otherwise call async with runtime (default implementation handles sync delegation)
             else:
-                m.modify_model_request(request, state)  # type: ignore[call-arg]
+                await sig.middleware.amodify_model_request(request, state, runtime)

         # Get the final model and messages
         model_ = _get_bound_model(request)
@@ -419,14 +476,46 @@ def create_agent(  # noqa: PLR0915

     # Add middleware nodes
     for m in middleware:
-        if m.__class__.before_model is not AgentMiddleware.before_model:
+        if (
+            m.__class__.before_model is not AgentMiddleware.before_model
+            or m.__class__.abefore_model is not AgentMiddleware.abefore_model
+        ):
+            # Use RunnableCallable to support both sync and async
+            # Pass None for sync if not overridden to avoid signature conflicts
+            sync_before = (
+                m.before_model
+                if m.__class__.before_model is not AgentMiddleware.before_model
+                else None
+            )
+            async_before = (
+                m.abefore_model
+                if m.__class__.abefore_model is not AgentMiddleware.abefore_model
+                else None
+            )
+            before_node = RunnableCallable(sync_before, async_before)
             graph.add_node(
-                f"{m.__class__.__name__}.before_model", m.before_model, input_schema=state_schema
+                f"{m.__class__.__name__}.before_model", before_node, input_schema=state_schema
             )

-        if m.__class__.after_model is not AgentMiddleware.after_model:
+        if (
+            m.__class__.after_model is not AgentMiddleware.after_model
+            or m.__class__.aafter_model is not AgentMiddleware.aafter_model
+        ):
+            # Use RunnableCallable to support both sync and async
+            # Pass None for sync if not overridden to avoid signature conflicts
+            sync_after = (
+                m.after_model
+                if m.__class__.after_model is not AgentMiddleware.after_model
+                else None
+            )
+            async_after = (
+                m.aafter_model
+                if m.__class__.aafter_model is not AgentMiddleware.aafter_model
+                else None
+            )
+            after_node = RunnableCallable(sync_after, async_after)
             graph.add_node(
-                f"{m.__class__.__name__}.after_model", m.after_model, input_schema=state_schema
+                f"{m.__class__.__name__}.after_model", after_node, input_schema=state_schema
             )

     # add start edge
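The node wiring hands a sync callable and an async callable to RunnableCallable and lets the execution path choose. Conceptually (this is not langgraph's implementation, only a sketch of the dispatch behaviour the wiring relies on), it amounts to something like:

```python
# Conceptual sketch of sync/async dispatch for a graph node; not langgraph's RunnableCallable.
import asyncio
from collections.abc import Awaitable, Callable
from typing import Any


class SyncAsyncNode:
    def __init__(
        self,
        func: Callable[..., Any] | None,
        afunc: Callable[..., Awaitable[Any]] | None,
    ) -> None:
        self.func = func
        self.afunc = afunc

    def invoke(self, *args: Any) -> Any:
        # Sync execution path: requires a sync implementation.
        if self.func is None:
            msg = "No synchronous implementation; use the async API instead."
            raise TypeError(msg)
        return self.func(*args)

    async def ainvoke(self, *args: Any) -> Any:
        # Async execution path: prefer the async implementation,
        # otherwise run the sync one off the event loop.
        if self.afunc is not None:
            return await self.afunc(*args)
        return await asyncio.to_thread(self.func, *args)
```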
@@ -1661,3 +1661,236 @@ def test_planning_middleware_custom_system_prompt() -> None:
     assert result["todos"] == [{"content": "Custom task", "status": "pending"}]
     # assert custom system prompt is in the first AI message
     assert "call the write_todos tool" in result["messages"][1].content
+
+
+# Async Middleware Tests
+async def test_create_agent_async_invoke() -> None:
+    """Test async invoke with async middleware hooks."""
+    calls = []
+
+    class AsyncMiddleware(AgentMiddleware):
+        async def abefore_model(self, state) -> None:
+            calls.append("AsyncMiddleware.abefore_model")
+
+        async def amodify_model_request(self, request, state) -> ModelRequest:
+            calls.append("AsyncMiddleware.amodify_model_request")
+            request.messages.append(HumanMessage("async middleware message"))
+            return request
+
+        async def aafter_model(self, state) -> None:
+            calls.append("AsyncMiddleware.aafter_model")
+
+    @tool
+    def my_tool(input: str) -> str:
+        """A great tool"""
+        calls.append("my_tool")
+        return input.upper()
+
+    agent = create_agent(
+        model=FakeToolCallingModel(
+            tool_calls=[
+                [{"args": {"input": "yo"}, "id": "1", "name": "my_tool"}],
+                [],
+            ]
+        ),
+        tools=[my_tool],
+        system_prompt="You are a helpful assistant.",
+        middleware=[AsyncMiddleware()],
+    ).compile()
+
+    result = await agent.ainvoke({"messages": [HumanMessage("hello")]})
+
+    # Should have:
+    # 1. Original hello message
+    # 2. Async middleware message (first invoke)
+    # 3. AI message with tool call
+    # 4. Tool message
+    # 5. Async middleware message (second invoke)
+    # 6. Final AI message
+    assert len(result["messages"]) == 6
+    assert result["messages"][0].content == "hello"
+    assert result["messages"][1].content == "async middleware message"
+    assert calls == [
+        "AsyncMiddleware.abefore_model",
+        "AsyncMiddleware.amodify_model_request",
+        "AsyncMiddleware.aafter_model",
+        "my_tool",
+        "AsyncMiddleware.abefore_model",
+        "AsyncMiddleware.amodify_model_request",
+        "AsyncMiddleware.aafter_model",
+    ]
+
+
+async def test_create_agent_async_invoke_multiple_middleware() -> None:
+    """Test async invoke with multiple async middleware hooks."""
+    calls = []
+
+    class AsyncMiddlewareOne(AgentMiddleware):
+        async def abefore_model(self, state) -> None:
+            calls.append("AsyncMiddlewareOne.abefore_model")
+
+        async def amodify_model_request(self, request, state) -> ModelRequest:
+            calls.append("AsyncMiddlewareOne.amodify_model_request")
+            return request
+
+        async def aafter_model(self, state) -> None:
+            calls.append("AsyncMiddlewareOne.aafter_model")
+
+    class AsyncMiddlewareTwo(AgentMiddleware):
+        async def abefore_model(self, state) -> None:
+            calls.append("AsyncMiddlewareTwo.abefore_model")
+
+        async def amodify_model_request(self, request, state) -> ModelRequest:
+            calls.append("AsyncMiddlewareTwo.amodify_model_request")
+            return request
+
+        async def aafter_model(self, state) -> None:
+            calls.append("AsyncMiddlewareTwo.aafter_model")
+
+    agent = create_agent(
+        model=FakeToolCallingModel(),
+        tools=[],
+        system_prompt="You are a helpful assistant.",
+        middleware=[AsyncMiddlewareOne(), AsyncMiddlewareTwo()],
+    ).compile()
+
+    result = await agent.ainvoke({"messages": [HumanMessage("hello")]})
+
+    assert calls == [
+        "AsyncMiddlewareOne.abefore_model",
+        "AsyncMiddlewareTwo.abefore_model",
+        "AsyncMiddlewareOne.amodify_model_request",
+        "AsyncMiddlewareTwo.amodify_model_request",
+        "AsyncMiddlewareTwo.aafter_model",
+        "AsyncMiddlewareOne.aafter_model",
+    ]
+
+
+async def test_create_agent_async_jump() -> None:
+    """Test async invoke with async middleware using jump_to."""
+    calls = []
+
+    class AsyncMiddlewareOne(AgentMiddleware):
+        async def abefore_model(self, state) -> None:
+            calls.append("AsyncMiddlewareOne.abefore_model")
+
+    class AsyncMiddlewareTwo(AgentMiddleware):
+        before_model_jump_to = ["end"]
+
+        async def abefore_model(self, state) -> dict[str, Any]:
+            calls.append("AsyncMiddlewareTwo.abefore_model")
+            return {"jump_to": "end"}
+
+    agent = create_agent(
+        model=FakeToolCallingModel(),
+        tools=[],
+        system_prompt="You are a helpful assistant.",
+        middleware=[AsyncMiddlewareOne(), AsyncMiddlewareTwo()],
+    ).compile()
+
+    result = await agent.ainvoke({"messages": []})
+
+    assert result == {"messages": []}
+    assert calls == ["AsyncMiddlewareOne.abefore_model", "AsyncMiddlewareTwo.abefore_model"]
+
+
+async def test_create_agent_mixed_sync_async_middleware() -> None:
+    """Test async invoke with mixed sync and async middleware."""
+    calls = []
+
+    class SyncMiddleware(AgentMiddleware):
+        def before_model(self, state) -> None:
+            calls.append("SyncMiddleware.before_model")
+
+        def modify_model_request(self, request, state) -> ModelRequest:
+            calls.append("SyncMiddleware.modify_model_request")
+            return request
+
+        def after_model(self, state) -> None:
+            calls.append("SyncMiddleware.after_model")
+
+    class AsyncMiddleware(AgentMiddleware):
+        async def abefore_model(self, state) -> None:
+            calls.append("AsyncMiddleware.abefore_model")
+
+        async def amodify_model_request(self, request, state) -> ModelRequest:
+            calls.append("AsyncMiddleware.amodify_model_request")
+            return request
+
+        async def aafter_model(self, state) -> None:
+            calls.append("AsyncMiddleware.aafter_model")
+
+    agent = create_agent(
+        model=FakeToolCallingModel(),
+        tools=[],
+        system_prompt="You are a helpful assistant.",
+        middleware=[SyncMiddleware(), AsyncMiddleware()],
+    ).compile()
+
+    result = await agent.ainvoke({"messages": [HumanMessage("hello")]})
+
+    # In async mode, both sync and async middleware should work
+    assert calls == [
+        "SyncMiddleware.before_model",
+        "AsyncMiddleware.abefore_model",
+        "SyncMiddleware.modify_model_request",
+        "AsyncMiddleware.amodify_model_request",
+        "AsyncMiddleware.aafter_model",
+        "SyncMiddleware.after_model",
+    ]
+
+
+def test_create_agent_sync_invoke_with_only_async_middleware_raises_error() -> None:
+    """Test that sync invoke with only async middleware raises TypeError."""
+
+    class AsyncOnlyMiddleware(AgentMiddleware):
+        async def amodify_model_request(self, request, state) -> ModelRequest:
+            return request
+
+    agent = create_agent(
+        model=FakeToolCallingModel(),
+        tools=[],
+        system_prompt="You are a helpful assistant.",
+        middleware=[AsyncOnlyMiddleware()],
+    ).compile()
+
+    with pytest.raises(
+        TypeError,
+        match=r"No synchronous function provided for AsyncOnlyMiddleware\.amodify_model_request",
+    ):
+        agent.invoke({"messages": [HumanMessage("hello")]})
+
+
+def test_create_agent_sync_invoke_with_mixed_middleware() -> None:
+    """Test that sync invoke works with mixed sync/async middleware when sync versions exist."""
+    calls = []
+
+    class MixedMiddleware(AgentMiddleware):
+        def before_model(self, state) -> None:
+            calls.append("MixedMiddleware.before_model")
+
+        async def abefore_model(self, state) -> None:
+            calls.append("MixedMiddleware.abefore_model")
+
+        def modify_model_request(self, request, state) -> ModelRequest:
+            calls.append("MixedMiddleware.modify_model_request")
+            return request
+
+        async def amodify_model_request(self, request, state) -> ModelRequest:
+            calls.append("MixedMiddleware.amodify_model_request")
+            return request
+
+    agent = create_agent(
+        model=FakeToolCallingModel(),
+        tools=[],
+        system_prompt="You are a helpful assistant.",
+        middleware=[MixedMiddleware()],
+    ).compile()
+
+    result = agent.invoke({"messages": [HumanMessage("hello")]})
+
+    # In sync mode, only sync methods should be called
+    assert calls == [
+        "MixedMiddleware.before_model",
+        "MixedMiddleware.modify_model_request",
+    ]
@@ -1,5 +1,6 @@
 """Consolidated tests for middleware decorators: before_model, after_model, and modify_model_request."""

+import pytest
 from typing import Any
 from typing_extensions import NotRequired

@@ -150,3 +151,161 @@ def test_decorators_use_function_names_as_default() -> None:
     assert my_before_hook.__class__.__name__ == "my_before_hook"
     assert my_modify_hook.__class__.__name__ == "my_modify_hook"
     assert my_after_hook.__class__.__name__ == "my_after_hook"
+
+
+# Async Decorator Tests
+
+
+def test_async_before_model_decorator() -> None:
+    """Test before_model decorator with async function."""
+
+    @before_model(state_schema=CustomState, tools=[test_tool], name="AsyncBeforeModel")
+    async def async_before_model(state: CustomState) -> dict[str, Any]:
+        return {"custom_field": "async_value"}
+
+    assert isinstance(async_before_model, AgentMiddleware)
+    assert async_before_model.state_schema == CustomState
+    assert async_before_model.tools == [test_tool]
+    assert async_before_model.__class__.__name__ == "AsyncBeforeModel"
+
+
+def test_async_after_model_decorator() -> None:
+    """Test after_model decorator with async function."""
+
+    @after_model(state_schema=CustomState, tools=[test_tool], name="AsyncAfterModel")
+    async def async_after_model(state: CustomState) -> dict[str, Any]:
+        return {"custom_field": "async_value"}
+
+    assert isinstance(async_after_model, AgentMiddleware)
+    assert async_after_model.state_schema == CustomState
+    assert async_after_model.tools == [test_tool]
+    assert async_after_model.__class__.__name__ == "AsyncAfterModel"
+
+
+def test_async_modify_model_request_decorator() -> None:
+    """Test modify_model_request decorator with async function."""
+
+    @modify_model_request(state_schema=CustomState, tools=[test_tool], name="AsyncModifyRequest")
+    async def async_modify_request(request: ModelRequest, state: CustomState) -> ModelRequest:
+        request.system_prompt = "Modified async"
+        return request
+
+    assert isinstance(async_modify_request, AgentMiddleware)
+    assert async_modify_request.state_schema == CustomState
+    assert async_modify_request.tools == [test_tool]
+    assert async_modify_request.__class__.__name__ == "AsyncModifyRequest"
+
+
+def test_mixed_sync_async_decorators() -> None:
+    """Test decorators with both sync and async functions."""
+
+    @before_model(name="MixedBeforeModel")
+    def sync_before(state: AgentState) -> None:
+        return None
+
+    @before_model(name="MixedBeforeModel")
+    async def async_before(state: AgentState) -> None:
+        return None
+
+    @modify_model_request(name="MixedModifyRequest")
+    def sync_modify(request: ModelRequest, state: AgentState) -> ModelRequest:
+        return request
+
+    @modify_model_request(name="MixedModifyRequest")
+    async def async_modify(request: ModelRequest, state: AgentState) -> ModelRequest:
+        return request
+
+    # Both should create valid middleware instances
+    assert isinstance(sync_before, AgentMiddleware)
+    assert isinstance(async_before, AgentMiddleware)
+    assert isinstance(sync_modify, AgentMiddleware)
+    assert isinstance(async_modify, AgentMiddleware)
+
+
+@pytest.mark.asyncio
+async def test_async_decorators_integration() -> None:
+    """Test async decorators working together in an agent."""
+    call_order = []
+
+    @before_model
+    async def track_async_before(state: AgentState) -> None:
+        call_order.append("async_before")
+        return None
+
+    @modify_model_request
+    async def track_async_modify(request: ModelRequest, state: AgentState) -> ModelRequest:
+        call_order.append("async_modify")
+        return request
+
+    @after_model
+    async def track_async_after(state: AgentState) -> None:
+        call_order.append("async_after")
+        return None
+
+    agent = create_agent(
+        model=FakeToolCallingModel(),
+        middleware=[track_async_before, track_async_modify, track_async_after],
+    )
+    agent = agent.compile()
+    await agent.ainvoke({"messages": [HumanMessage("Hello")]})
+
+    assert call_order == ["async_before", "async_modify", "async_after"]
+
+
+@pytest.mark.asyncio
+async def test_mixed_sync_async_decorators_integration() -> None:
+    """Test mixed sync/async decorators working together in an agent."""
+    call_order = []
+
+    @before_model
+    def track_sync_before(state: AgentState) -> None:
+        call_order.append("sync_before")
+        return None
+
+    @before_model
+    async def track_async_before(state: AgentState) -> None:
+        call_order.append("async_before")
+        return None
+
+    @modify_model_request
+    def track_sync_modify(request: ModelRequest, state: AgentState) -> ModelRequest:
+        call_order.append("sync_modify")
+        return request
+
+    @modify_model_request
+    async def track_async_modify(request: ModelRequest, state: AgentState) -> ModelRequest:
+        call_order.append("async_modify")
+        return request
+
+    @after_model
+    async def track_async_after(state: AgentState) -> None:
+        call_order.append("async_after")
+        return None
+
+    @after_model
+    def track_sync_after(state: AgentState) -> None:
+        call_order.append("sync_after")
+        return None
+
+    agent = create_agent(
+        model=FakeToolCallingModel(),
+        middleware=[
+            track_sync_before,
+            track_async_before,
+            track_sync_modify,
+            track_async_modify,
+            track_async_after,
+            track_sync_after,
+        ],
+    )
+    agent = agent.compile()
+    await agent.ainvoke({"messages": [HumanMessage("Hello")]})
+
+    assert call_order == [
+        "sync_before",
+        "async_before",
+        "sync_modify",
+        "async_modify",
+        "sync_after",
+        "async_after",
+    ]