Compare commits

...

5 Commits

Author SHA1 Message Date
Sydney Runkle
fd3acabe9d run in executor and middleware signatures 2025-09-30 16:41:36 -07:00
Sydney Runkle
348075987f adding tests 2025-09-30 13:48:57 -07:00
Sydney Runkle
ea5d6f2cfa correct handling for sync / async table 2025-09-30 13:10:25 -07:00
Sydney Runkle
cd9a12cc9b conditions for finding middleware 2025-09-30 12:50:01 -07:00
Sydney Runkle
33b11630fe another pass at async 2025-09-30 12:19:14 -07:00
4 changed files with 666 additions and 39 deletions

View File

@@ -3,7 +3,7 @@
from __future__ import annotations
from dataclasses import dataclass, field
from inspect import signature
from inspect import iscoroutinefunction, signature
from typing import (
TYPE_CHECKING,
Annotated,
@@ -18,6 +18,11 @@ from typing import (
overload,
)
from langchain_core.runnables import run_in_executor
if TYPE_CHECKING:
from collections.abc import Awaitable
# needed as top level import for pydantic schema generation on AgentState
from langchain_core.messages import AnyMessage # noqa: TC002
from langgraph.channels.ephemeral_value import EphemeralValue
@@ -129,6 +134,11 @@ class AgentMiddleware(Generic[StateT, ContextT]):
def before_model(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
"""Logic to run before the model is called."""
async def abefore_model(
self, state: StateT, runtime: Runtime[ContextT]
) -> dict[str, Any] | None:
"""Async logic to run before the model is called."""
def modify_model_request(
self,
request: ModelRequest,
@@ -138,14 +148,35 @@ class AgentMiddleware(Generic[StateT, ContextT]):
"""Logic to modify request kwargs before the model is called."""
return request
async def amodify_model_request(
self,
request: ModelRequest,
state: StateT,
runtime: Runtime[ContextT],
) -> ModelRequest:
"""Async logic to modify request kwargs before the model is called."""
# Try calling sync version with runtime first, fall back to without runtime
try:
return await run_in_executor(None, self.modify_model_request, request, state, runtime)
except TypeError:
# Sync version doesn't accept runtime, call without it
return await run_in_executor(None, self.modify_model_request, request, state)
def after_model(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
"""Logic to run after the model is called."""
async def aafter_model(
self, state: StateT, runtime: Runtime[ContextT]
) -> dict[str, Any] | None:
"""Async logic to run after the model is called."""
class _CallableWithState(Protocol[StateT_contra]):
"""Callable with AgentState as argument."""
def __call__(self, state: StateT_contra) -> dict[str, Any] | Command | None:
def __call__(
self, state: StateT_contra
) -> dict[str, Any] | Command | None | Awaitable[dict[str, Any] | Command | None]:
"""Perform some logic with the state."""
...
@@ -155,7 +186,7 @@ class _CallableWithStateAndRuntime(Protocol[StateT_contra, ContextT]):
def __call__(
self, state: StateT_contra, runtime: Runtime[ContextT]
) -> dict[str, Any] | Command | None:
) -> dict[str, Any] | Command | None | Awaitable[dict[str, Any] | Command | None]:
"""Perform some logic with the state and runtime."""
...
@@ -163,7 +194,9 @@ class _CallableWithStateAndRuntime(Protocol[StateT_contra, ContextT]):
class _CallableWithModelRequestAndState(Protocol[StateT_contra]):
"""Callable with ModelRequest and AgentState as arguments."""
def __call__(self, request: ModelRequest, state: StateT_contra) -> ModelRequest:
def __call__(
self, request: ModelRequest, state: StateT_contra
) -> ModelRequest | Awaitable[ModelRequest]:
"""Perform some logic with the model request and state."""
...
@@ -173,7 +206,7 @@ class _CallableWithModelRequestAndStateAndRuntime(Protocol[StateT_contra, Contex
def __call__(
self, request: ModelRequest, state: StateT_contra, runtime: Runtime[ContextT]
) -> ModelRequest:
) -> ModelRequest | Awaitable[ModelRequest]:
"""Perform some logic with the model request, state, and runtime."""
...
@@ -278,14 +311,53 @@ def before_model(
"""
def decorator(func: _NodeSignature[StateT, ContextT]) -> AgentMiddleware[StateT, ContextT]:
if is_callable_with_runtime(func):
is_async = iscoroutinefunction(func)
uses_runtime = is_callable_with_runtime(func)
if is_async:
if uses_runtime:
async def async_wrapped_with_runtime(
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
state: StateT,
runtime: Runtime[ContextT],
) -> dict[str, Any] | Command | None:
return await func(state, runtime) # type: ignore[misc]
async_wrapped = async_wrapped_with_runtime
else:
async def async_wrapped_without_runtime(
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
state: StateT,
) -> dict[str, Any] | Command | None:
return await func(state) # type: ignore[call-arg,misc]
async_wrapped = async_wrapped_without_runtime # type: ignore[assignment]
middleware_name = name or cast(
"str", getattr(func, "__name__", "BeforeModelMiddleware")
)
return type(
middleware_name,
(AgentMiddleware,),
{
"state_schema": state_schema or AgentState,
"tools": tools or [],
"before_model_jump_to": jump_to or [],
"abefore_model": async_wrapped,
},
)()
if uses_runtime:
def wrapped_with_runtime(
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
state: StateT,
runtime: Runtime[ContextT],
) -> dict[str, Any] | Command | None:
return func(state, runtime)
return func(state, runtime) # type: ignore[return-value]
wrapped = wrapped_with_runtime
else:
@@ -298,7 +370,6 @@ def before_model(
wrapped = wrapped_without_runtime # type: ignore[assignment]
# Use function name as default if no name provided
middleware_name = name or cast("str", getattr(func, "__name__", "BeforeModelMiddleware"))
return type(
@@ -394,7 +465,47 @@ def modify_model_request(
def decorator(
func: _ModelRequestSignature[StateT, ContextT],
) -> AgentMiddleware[StateT, ContextT]:
if is_callable_with_runtime_and_request(func):
is_async = iscoroutinefunction(func)
uses_runtime = is_callable_with_runtime_and_request(func)
if is_async:
if uses_runtime:
async def async_wrapped_with_runtime(
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
request: ModelRequest,
state: StateT,
runtime: Runtime[ContextT],
) -> ModelRequest:
return await func(request, state, runtime) # type: ignore[misc]
async_wrapped = async_wrapped_with_runtime
else:
async def async_wrapped_without_runtime(
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
request: ModelRequest,
state: StateT,
) -> ModelRequest:
return await func(request, state) # type: ignore[call-arg,misc]
async_wrapped = async_wrapped_without_runtime # type: ignore[assignment]
middleware_name = name or cast(
"str", getattr(func, "__name__", "ModifyModelRequestMiddleware")
)
return type(
middleware_name,
(AgentMiddleware,),
{
"state_schema": state_schema or AgentState,
"tools": tools or [],
"amodify_model_request": async_wrapped,
},
)()
if uses_runtime:
def wrapped_with_runtime(
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
@@ -402,7 +513,7 @@ def modify_model_request(
state: StateT,
runtime: Runtime[ContextT],
) -> ModelRequest:
return func(request, state, runtime)
return func(request, state, runtime) # type: ignore[return-value]
wrapped = wrapped_with_runtime
else:
@@ -416,7 +527,6 @@ def modify_model_request(
wrapped = wrapped_without_runtime # type: ignore[assignment]
# Use function name as default if no name provided
middleware_name = name or cast(
"str", getattr(func, "__name__", "ModifyModelRequestMiddleware")
)
@@ -504,14 +614,51 @@ def after_model(
"""
def decorator(func: _NodeSignature[StateT, ContextT]) -> AgentMiddleware[StateT, ContextT]:
if is_callable_with_runtime(func):
is_async = iscoroutinefunction(func)
uses_runtime = is_callable_with_runtime(func)
if is_async:
if uses_runtime:
async def async_wrapped_with_runtime(
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
state: StateT,
runtime: Runtime[ContextT],
) -> dict[str, Any] | Command | None:
return await func(state, runtime) # type: ignore[misc]
async_wrapped = async_wrapped_with_runtime
else:
async def async_wrapped_without_runtime(
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
state: StateT,
) -> dict[str, Any] | Command | None:
return await func(state) # type: ignore[call-arg,misc]
async_wrapped = async_wrapped_without_runtime # type: ignore[assignment]
middleware_name = name or cast("str", getattr(func, "__name__", "AfterModelMiddleware"))
return type(
middleware_name,
(AgentMiddleware,),
{
"state_schema": state_schema or AgentState,
"tools": tools or [],
"after_model_jump_to": jump_to or [],
"aafter_model": async_wrapped,
},
)()
if uses_runtime:
def wrapped_with_runtime(
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
state: StateT,
runtime: Runtime[ContextT],
) -> dict[str, Any] | Command | None:
return func(state, runtime)
return func(state, runtime) # type: ignore[return-value]
wrapped = wrapped_with_runtime
else:
@@ -524,7 +671,6 @@ def after_model(
wrapped = wrapped_without_runtime # type: ignore[assignment]
# Use function name as default if no name provided
middleware_name = name or cast("str", getattr(func, "__name__", "AfterModelMiddleware"))
return type(

View File

@@ -2,8 +2,9 @@
import itertools
from collections.abc import Callable, Sequence
from dataclasses import dataclass
from inspect import signature
from typing import Annotated, Any, cast, get_args, get_origin, get_type_hints
from typing import Annotated, Any, Generic, cast, get_args, get_origin, get_type_hints
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, AnyMessage, SystemMessage, ToolMessage
@@ -38,6 +39,27 @@ from langchain.tools import ToolNode
STRUCTURED_OUTPUT_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes."
ResponseT = TypeVar("ResponseT")
@dataclass
class MiddlewareSignature(Generic[ResponseT, ContextT]):
"""Structured metadata for a middleware's hook implementations.
Attributes:
middleware: The middleware instance.
has_sync: Whether the middleware implements a sync version of the hook.
has_async: Whether the middleware implements an async version of the hook.
sync_uses_runtime: Whether the sync hook accepts a runtime argument.
async_uses_runtime: Whether the async hook accepts a runtime argument.
"""
middleware: AgentMiddleware[AgentState[ResponseT], ContextT]
has_sync: bool
has_async: bool
sync_uses_runtime: bool
async_uses_runtime: bool
def _resolve_schema(schemas: set[type], schema_name: str, omit_flag: str | None = None) -> type:
"""Resolve schema by merging schemas and optionally respecting OmitFromSchema annotations.
@@ -130,9 +152,6 @@ def _handle_structured_output_error(
return False, ""
ResponseT = TypeVar("ResponseT")
def create_agent( # noqa: PLR0915
*,
model: str | BaseChatModel,
@@ -212,15 +231,22 @@ def create_agent( # noqa: PLR0915
"Please remove duplicate middleware instances."
)
middleware_w_before = [
m for m in middleware if m.__class__.before_model is not AgentMiddleware.before_model
m
for m in middleware
if m.__class__.before_model is not AgentMiddleware.before_model
or m.__class__.abefore_model is not AgentMiddleware.abefore_model
]
middleware_w_modify_model_request = [
m
for m in middleware
if m.__class__.modify_model_request is not AgentMiddleware.modify_model_request
or m.__class__.amodify_model_request is not AgentMiddleware.amodify_model_request
]
middleware_w_after = [
m for m in middleware if m.__class__.after_model is not AgentMiddleware.after_model
m
for m in middleware
if m.__class__.after_model is not AgentMiddleware.after_model
or m.__class__.aafter_model is not AgentMiddleware.aafter_model
]
state_schemas = {m.state_schema for m in middleware}
@@ -346,12 +372,32 @@ def create_agent( # noqa: PLR0915
)
return request.model.bind(**request.model_settings)
model_request_signatures: list[
tuple[bool, AgentMiddleware[AgentState[ResponseT], ContextT]]
] = [
("runtime" in signature(m.modify_model_request).parameters, m)
for m in middleware_w_modify_model_request
]
# Build signatures for modify_model_request middleware with async support
model_request_signatures: list[MiddlewareSignature[ResponseT, ContextT]] = []
for m in middleware_w_modify_model_request:
# Check if async version is implemented (not the default)
has_sync = m.__class__.modify_model_request is not AgentMiddleware.modify_model_request
has_async = m.__class__.amodify_model_request is not AgentMiddleware.amodify_model_request
# Check runtime usage for each implementation
sync_uses_runtime = (
"runtime" in signature(m.modify_model_request).parameters if has_sync else False
)
# If async is implemented, check its signature; otherwise default is True
# (the default async implementation always expects runtime)
async_uses_runtime = (
"runtime" in signature(m.amodify_model_request).parameters if has_async else True
)
model_request_signatures.append(
MiddlewareSignature(
middleware=m,
has_sync=has_sync,
has_async=has_async,
sync_uses_runtime=sync_uses_runtime,
async_uses_runtime=async_uses_runtime,
)
)
def model_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
"""Sync model request handler with sequential middleware processing."""
@@ -365,11 +411,20 @@ def create_agent( # noqa: PLR0915
)
# Apply modify_model_request middleware in sequence
for use_runtime, m in model_request_signatures:
if use_runtime:
m.modify_model_request(request, state, runtime)
else:
m.modify_model_request(request, state) # type: ignore[call-arg]
for sig in model_request_signatures:
if sig.has_sync:
if sig.sync_uses_runtime:
sig.middleware.modify_model_request(request, state, runtime)
else:
sig.middleware.modify_model_request(request, state) # type: ignore[call-arg]
elif sig.has_async:
msg = (
f"No synchronous function provided for "
f'{sig.middleware.__class__.__name__}.amodify_model_request".'
"\nEither initialize with a synchronous function or invoke"
" via the async API (ainvoke, astream, etc.)"
)
raise TypeError(msg)
# Get the final model and messages
model_ = _get_bound_model(request)
@@ -393,11 +448,13 @@ def create_agent( # noqa: PLR0915
)
# Apply modify_model_request middleware in sequence
for use_runtime, m in model_request_signatures:
if use_runtime:
m.modify_model_request(request, state, runtime)
for sig in model_request_signatures:
# If async is overridden and doesn't use runtime, call without it
if sig.has_async and not sig.async_uses_runtime:
await sig.middleware.amodify_model_request(request, state) # type: ignore[call-arg]
# Otherwise call async with runtime (default implementation handles sync delegation)
else:
m.modify_model_request(request, state) # type: ignore[call-arg]
await sig.middleware.amodify_model_request(request, state, runtime)
# Get the final model and messages
model_ = _get_bound_model(request)
@@ -419,14 +476,46 @@ def create_agent( # noqa: PLR0915
# Add middleware nodes
for m in middleware:
if m.__class__.before_model is not AgentMiddleware.before_model:
if (
m.__class__.before_model is not AgentMiddleware.before_model
or m.__class__.abefore_model is not AgentMiddleware.abefore_model
):
# Use RunnableCallable to support both sync and async
# Pass None for sync if not overridden to avoid signature conflicts
sync_before = (
m.before_model
if m.__class__.before_model is not AgentMiddleware.before_model
else None
)
async_before = (
m.abefore_model
if m.__class__.abefore_model is not AgentMiddleware.abefore_model
else None
)
before_node = RunnableCallable(sync_before, async_before)
graph.add_node(
f"{m.__class__.__name__}.before_model", m.before_model, input_schema=state_schema
f"{m.__class__.__name__}.before_model", before_node, input_schema=state_schema
)
if m.__class__.after_model is not AgentMiddleware.after_model:
if (
m.__class__.after_model is not AgentMiddleware.after_model
or m.__class__.aafter_model is not AgentMiddleware.aafter_model
):
# Use RunnableCallable to support both sync and async
# Pass None for sync if not overridden to avoid signature conflicts
sync_after = (
m.after_model
if m.__class__.after_model is not AgentMiddleware.after_model
else None
)
async_after = (
m.aafter_model
if m.__class__.aafter_model is not AgentMiddleware.aafter_model
else None
)
after_node = RunnableCallable(sync_after, async_after)
graph.add_node(
f"{m.__class__.__name__}.after_model", m.after_model, input_schema=state_schema
f"{m.__class__.__name__}.after_model", after_node, input_schema=state_schema
)
# add start edge

View File

@@ -1661,3 +1661,236 @@ def test_planning_middleware_custom_system_prompt() -> None:
assert result["todos"] == [{"content": "Custom task", "status": "pending"}]
# assert custom system prompt is in the first AI message
assert "call the write_todos tool" in result["messages"][1].content
# Async Middleware Tests
async def test_create_agent_async_invoke() -> None:
"""Test async invoke with async middleware hooks."""
calls = []
class AsyncMiddleware(AgentMiddleware):
async def abefore_model(self, state) -> None:
calls.append("AsyncMiddleware.abefore_model")
async def amodify_model_request(self, request, state) -> ModelRequest:
calls.append("AsyncMiddleware.amodify_model_request")
request.messages.append(HumanMessage("async middleware message"))
return request
async def aafter_model(self, state) -> None:
calls.append("AsyncMiddleware.aafter_model")
@tool
def my_tool(input: str) -> str:
"""A great tool"""
calls.append("my_tool")
return input.upper()
agent = create_agent(
model=FakeToolCallingModel(
tool_calls=[
[{"args": {"input": "yo"}, "id": "1", "name": "my_tool"}],
[],
]
),
tools=[my_tool],
system_prompt="You are a helpful assistant.",
middleware=[AsyncMiddleware()],
).compile()
result = await agent.ainvoke({"messages": [HumanMessage("hello")]})
# Should have:
# 1. Original hello message
# 2. Async middleware message (first invoke)
# 3. AI message with tool call
# 4. Tool message
# 5. Async middleware message (second invoke)
# 6. Final AI message
assert len(result["messages"]) == 6
assert result["messages"][0].content == "hello"
assert result["messages"][1].content == "async middleware message"
assert calls == [
"AsyncMiddleware.abefore_model",
"AsyncMiddleware.amodify_model_request",
"AsyncMiddleware.aafter_model",
"my_tool",
"AsyncMiddleware.abefore_model",
"AsyncMiddleware.amodify_model_request",
"AsyncMiddleware.aafter_model",
]
async def test_create_agent_async_invoke_multiple_middleware() -> None:
"""Test async invoke with multiple async middleware hooks."""
calls = []
class AsyncMiddlewareOne(AgentMiddleware):
async def abefore_model(self, state) -> None:
calls.append("AsyncMiddlewareOne.abefore_model")
async def amodify_model_request(self, request, state) -> ModelRequest:
calls.append("AsyncMiddlewareOne.amodify_model_request")
return request
async def aafter_model(self, state) -> None:
calls.append("AsyncMiddlewareOne.aafter_model")
class AsyncMiddlewareTwo(AgentMiddleware):
async def abefore_model(self, state) -> None:
calls.append("AsyncMiddlewareTwo.abefore_model")
async def amodify_model_request(self, request, state) -> ModelRequest:
calls.append("AsyncMiddlewareTwo.amodify_model_request")
return request
async def aafter_model(self, state) -> None:
calls.append("AsyncMiddlewareTwo.aafter_model")
agent = create_agent(
model=FakeToolCallingModel(),
tools=[],
system_prompt="You are a helpful assistant.",
middleware=[AsyncMiddlewareOne(), AsyncMiddlewareTwo()],
).compile()
result = await agent.ainvoke({"messages": [HumanMessage("hello")]})
assert calls == [
"AsyncMiddlewareOne.abefore_model",
"AsyncMiddlewareTwo.abefore_model",
"AsyncMiddlewareOne.amodify_model_request",
"AsyncMiddlewareTwo.amodify_model_request",
"AsyncMiddlewareTwo.aafter_model",
"AsyncMiddlewareOne.aafter_model",
]
async def test_create_agent_async_jump() -> None:
"""Test async invoke with async middleware using jump_to."""
calls = []
class AsyncMiddlewareOne(AgentMiddleware):
async def abefore_model(self, state) -> None:
calls.append("AsyncMiddlewareOne.abefore_model")
class AsyncMiddlewareTwo(AgentMiddleware):
before_model_jump_to = ["end"]
async def abefore_model(self, state) -> dict[str, Any]:
calls.append("AsyncMiddlewareTwo.abefore_model")
return {"jump_to": "end"}
agent = create_agent(
model=FakeToolCallingModel(),
tools=[],
system_prompt="You are a helpful assistant.",
middleware=[AsyncMiddlewareOne(), AsyncMiddlewareTwo()],
).compile()
result = await agent.ainvoke({"messages": []})
assert result == {"messages": []}
assert calls == ["AsyncMiddlewareOne.abefore_model", "AsyncMiddlewareTwo.abefore_model"]
async def test_create_agent_mixed_sync_async_middleware() -> None:
"""Test async invoke with mixed sync and async middleware."""
calls = []
class SyncMiddleware(AgentMiddleware):
def before_model(self, state) -> None:
calls.append("SyncMiddleware.before_model")
def modify_model_request(self, request, state) -> ModelRequest:
calls.append("SyncMiddleware.modify_model_request")
return request
def after_model(self, state) -> None:
calls.append("SyncMiddleware.after_model")
class AsyncMiddleware(AgentMiddleware):
async def abefore_model(self, state) -> None:
calls.append("AsyncMiddleware.abefore_model")
async def amodify_model_request(self, request, state) -> ModelRequest:
calls.append("AsyncMiddleware.amodify_model_request")
return request
async def aafter_model(self, state) -> None:
calls.append("AsyncMiddleware.aafter_model")
agent = create_agent(
model=FakeToolCallingModel(),
tools=[],
system_prompt="You are a helpful assistant.",
middleware=[SyncMiddleware(), AsyncMiddleware()],
).compile()
result = await agent.ainvoke({"messages": [HumanMessage("hello")]})
# In async mode, both sync and async middleware should work
assert calls == [
"SyncMiddleware.before_model",
"AsyncMiddleware.abefore_model",
"SyncMiddleware.modify_model_request",
"AsyncMiddleware.amodify_model_request",
"AsyncMiddleware.aafter_model",
"SyncMiddleware.after_model",
]
def test_create_agent_sync_invoke_with_only_async_middleware_raises_error() -> None:
"""Test that sync invoke with only async middleware raises TypeError."""
class AsyncOnlyMiddleware(AgentMiddleware):
async def amodify_model_request(self, request, state) -> ModelRequest:
return request
agent = create_agent(
model=FakeToolCallingModel(),
tools=[],
system_prompt="You are a helpful assistant.",
middleware=[AsyncOnlyMiddleware()],
).compile()
with pytest.raises(
TypeError,
match=r"No synchronous function provided for AsyncOnlyMiddleware\.amodify_model_request",
):
agent.invoke({"messages": [HumanMessage("hello")]})
def test_create_agent_sync_invoke_with_mixed_middleware() -> None:
"""Test that sync invoke works with mixed sync/async middleware when sync versions exist."""
calls = []
class MixedMiddleware(AgentMiddleware):
def before_model(self, state) -> None:
calls.append("MixedMiddleware.before_model")
async def abefore_model(self, state) -> None:
calls.append("MixedMiddleware.abefore_model")
def modify_model_request(self, request, state) -> ModelRequest:
calls.append("MixedMiddleware.modify_model_request")
return request
async def amodify_model_request(self, request, state) -> ModelRequest:
calls.append("MixedMiddleware.amodify_model_request")
return request
agent = create_agent(
model=FakeToolCallingModel(),
tools=[],
system_prompt="You are a helpful assistant.",
middleware=[MixedMiddleware()],
).compile()
result = agent.invoke({"messages": [HumanMessage("hello")]})
# In sync mode, only sync methods should be called
assert calls == [
"MixedMiddleware.before_model",
"MixedMiddleware.modify_model_request",
]

View File

@@ -1,5 +1,6 @@
"""Consolidated tests for middleware decorators: before_model, after_model, and modify_model_request."""
import pytest
from typing import Any
from typing_extensions import NotRequired
@@ -150,3 +151,161 @@ def test_decorators_use_function_names_as_default() -> None:
assert my_before_hook.__class__.__name__ == "my_before_hook"
assert my_modify_hook.__class__.__name__ == "my_modify_hook"
assert my_after_hook.__class__.__name__ == "my_after_hook"
# Async Decorator Tests
def test_async_before_model_decorator() -> None:
"""Test before_model decorator with async function."""
@before_model(state_schema=CustomState, tools=[test_tool], name="AsyncBeforeModel")
async def async_before_model(state: CustomState) -> dict[str, Any]:
return {"custom_field": "async_value"}
assert isinstance(async_before_model, AgentMiddleware)
assert async_before_model.state_schema == CustomState
assert async_before_model.tools == [test_tool]
assert async_before_model.__class__.__name__ == "AsyncBeforeModel"
def test_async_after_model_decorator() -> None:
"""Test after_model decorator with async function."""
@after_model(state_schema=CustomState, tools=[test_tool], name="AsyncAfterModel")
async def async_after_model(state: CustomState) -> dict[str, Any]:
return {"custom_field": "async_value"}
assert isinstance(async_after_model, AgentMiddleware)
assert async_after_model.state_schema == CustomState
assert async_after_model.tools == [test_tool]
assert async_after_model.__class__.__name__ == "AsyncAfterModel"
def test_async_modify_model_request_decorator() -> None:
"""Test modify_model_request decorator with async function."""
@modify_model_request(state_schema=CustomState, tools=[test_tool], name="AsyncModifyRequest")
async def async_modify_request(request: ModelRequest, state: CustomState) -> ModelRequest:
request.system_prompt = "Modified async"
return request
assert isinstance(async_modify_request, AgentMiddleware)
assert async_modify_request.state_schema == CustomState
assert async_modify_request.tools == [test_tool]
assert async_modify_request.__class__.__name__ == "AsyncModifyRequest"
def test_mixed_sync_async_decorators() -> None:
"""Test decorators with both sync and async functions."""
@before_model(name="MixedBeforeModel")
def sync_before(state: AgentState) -> None:
return None
@before_model(name="MixedBeforeModel")
async def async_before(state: AgentState) -> None:
return None
@modify_model_request(name="MixedModifyRequest")
def sync_modify(request: ModelRequest, state: AgentState) -> ModelRequest:
return request
@modify_model_request(name="MixedModifyRequest")
async def async_modify(request: ModelRequest, state: AgentState) -> ModelRequest:
return request
# Both should create valid middleware instances
assert isinstance(sync_before, AgentMiddleware)
assert isinstance(async_before, AgentMiddleware)
assert isinstance(sync_modify, AgentMiddleware)
assert isinstance(async_modify, AgentMiddleware)
@pytest.mark.asyncio
async def test_async_decorators_integration() -> None:
"""Test async decorators working together in an agent."""
call_order = []
@before_model
async def track_async_before(state: AgentState) -> None:
call_order.append("async_before")
return None
@modify_model_request
async def track_async_modify(request: ModelRequest, state: AgentState) -> ModelRequest:
call_order.append("async_modify")
return request
@after_model
async def track_async_after(state: AgentState) -> None:
call_order.append("async_after")
return None
agent = create_agent(
model=FakeToolCallingModel(),
middleware=[track_async_before, track_async_modify, track_async_after],
)
agent = agent.compile()
await agent.ainvoke({"messages": [HumanMessage("Hello")]})
assert call_order == ["async_before", "async_modify", "async_after"]
@pytest.mark.asyncio
async def test_mixed_sync_async_decorators_integration() -> None:
"""Test mixed sync/async decorators working together in an agent."""
call_order = []
@before_model
def track_sync_before(state: AgentState) -> None:
call_order.append("sync_before")
return None
@before_model
async def track_async_before(state: AgentState) -> None:
call_order.append("async_before")
return None
@modify_model_request
def track_sync_modify(request: ModelRequest, state: AgentState) -> ModelRequest:
call_order.append("sync_modify")
return request
@modify_model_request
async def track_async_modify(request: ModelRequest, state: AgentState) -> ModelRequest:
call_order.append("async_modify")
return request
@after_model
async def track_async_after(state: AgentState) -> None:
call_order.append("async_after")
return None
@after_model
def track_sync_after(state: AgentState) -> None:
call_order.append("sync_after")
return None
agent = create_agent(
model=FakeToolCallingModel(),
middleware=[
track_sync_before,
track_async_before,
track_sync_modify,
track_async_modify,
track_async_after,
track_sync_after,
],
)
agent = agent.compile()
await agent.ainvoke({"messages": [HumanMessage("Hello")]})
assert call_order == [
"sync_before",
"async_before",
"sync_modify",
"async_modify",
"sync_after",
"async_after",
]