Compare commits

...

3 Commits

Author SHA1 Message Date
Sydney Runkle
286039788e linting 2025-09-29 16:12:34 -07:00
Sydney Runkle
41e9d479f5 cleaning up types theoretically 2025-09-29 16:10:06 -07:00
Sydney Runkle
1526c419d8 initial pass at async impl 2025-09-29 14:15:23 -07:00
3 changed files with 894 additions and 94 deletions

View File

@@ -13,7 +13,6 @@ from typing import (
Literal,
Protocol,
TypeAlias,
TypeGuard,
cast,
overload,
)
@@ -23,16 +22,16 @@ from langchain_core.messages import AnyMessage # noqa: TC002
from langgraph.channels.ephemeral_value import EphemeralValue
from langgraph.graph.message import add_messages
from langgraph.runtime import Runtime
from langgraph.types import Command
from langgraph.typing import ContextT
from typing_extensions import NotRequired, Required, TypedDict, TypeVar
if TYPE_CHECKING:
from collections.abc import Callable
from collections.abc import Awaitable, Callable
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.tools import BaseTool
from langgraph.runtime import Runtime
from langgraph.types import Command
from langchain.agents.structured_output import ResponseFormat
@@ -42,7 +41,12 @@ __all__ = [
"ContextT",
"ModelRequest",
"OmitFromSchema",
"PublicAgentState",
"aafter_model",
"abefore_model",
"after_model",
"amodify_model_request",
"before_model",
"modify_model_request",
]
JumpTo = Literal["tools", "model", "end"]
@@ -50,6 +54,9 @@ JumpTo = Literal["tools", "model", "end"]
ResponseT = TypeVar("ResponseT")
NodeReturn: TypeAlias = dict[str, Any] | Command | None
"""Return type for middleware node hooks."""
@dataclass
class ModelRequest:
@@ -126,9 +133,12 @@ class AgentMiddleware(Generic[StateT, ContextT]):
after_model_jump_to: ClassVar[list[JumpTo]] = []
"""Valid jump destinations for after_model hook. Used to establish conditional edges."""
def before_model(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
def before_model(self, state: StateT, runtime: Runtime[ContextT]) -> NodeReturn:
"""Logic to run before the model is called."""
async def abefore_model(self, state: StateT, runtime: Runtime[ContextT]) -> NodeReturn:
"""Async logic to run before the model is called."""
def modify_model_request(
self,
request: ModelRequest,
@@ -138,64 +148,80 @@ class AgentMiddleware(Generic[StateT, ContextT]):
"""Logic to modify request kwargs before the model is called."""
return request
def after_model(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
async def amodify_model_request(
self,
request: ModelRequest,
state: StateT, # noqa: ARG002
runtime: Runtime[ContextT], # noqa: ARG002
) -> ModelRequest:
"""Async logic to modify request kwargs before the model is called."""
return request
def after_model(self, state: StateT, runtime: Runtime[ContextT]) -> NodeReturn:
"""Logic to run after the model is called."""
async def aafter_model(self, state: StateT, runtime: Runtime[ContextT]) -> NodeReturn:
"""Async logic to run after the model is called."""
class _CallableWithState(Protocol[StateT_contra]):
"""Callable with AgentState as argument."""
def __call__(self, state: StateT_contra) -> dict[str, Any] | Command | None:
"""Perform some logic with the state."""
class _NodeCallableWithState(Protocol[StateT_contra]):
"""Callable for before/after model hooks with just state (sync or async)."""
def __call__(self, state: StateT_contra) -> NodeReturn | Awaitable[NodeReturn]:
"""Perform logic with the state."""
...
class _CallableWithStateAndRuntime(Protocol[StateT_contra, ContextT]):
"""Callable with AgentState and Runtime as arguments."""
class _NodeCallableWithStateAndRuntime(Protocol[StateT_contra, ContextT]):
"""Callable for before/after model hooks with state and runtime (sync or async)."""
def __call__(
self, state: StateT_contra, runtime: Runtime[ContextT]
) -> dict[str, Any] | Command | None:
"""Perform some logic with the state and runtime."""
) -> NodeReturn | Awaitable[NodeReturn]:
"""Perform logic with the state and runtime."""
...
class _CallableWithModelRequestAndState(Protocol[StateT_contra]):
"""Callable with ModelRequest and AgentState as arguments."""
class _ModelRequestCallableWithState(Protocol[StateT_contra]):
"""Callable for modify_model_request hook with state (sync or async)."""
def __call__(self, request: ModelRequest, state: StateT_contra) -> ModelRequest:
"""Perform some logic with the model request and state."""
def __call__(
self, request: ModelRequest, state: StateT_contra
) -> ModelRequest | Awaitable[ModelRequest]:
"""Perform logic with the model request and state."""
...
class _CallableWithModelRequestAndStateAndRuntime(Protocol[StateT_contra, ContextT]):
"""Callable with ModelRequest, AgentState, and Runtime as arguments."""
class _ModelRequestCallableWithStateAndRuntime(Protocol[StateT_contra, ContextT]):
"""Callable for modify_model_request hook with state and runtime (sync or async)."""
def __call__(
self, request: ModelRequest, state: StateT_contra, runtime: Runtime[ContextT]
) -> ModelRequest:
"""Perform some logic with the model request, state, and runtime."""
) -> ModelRequest | Awaitable[ModelRequest]:
"""Perform logic with the model request, state, and runtime."""
...
_NodeSignature: TypeAlias = (
_CallableWithState[StateT] | _CallableWithStateAndRuntime[StateT, ContextT]
_NodeCallableWithState[StateT] | _NodeCallableWithStateAndRuntime[StateT, ContextT]
)
_AsyncNodeSignature: TypeAlias = (
_NodeCallableWithState[StateT] | _NodeCallableWithStateAndRuntime[StateT, ContextT]
)
_ModelRequestSignature: TypeAlias = (
_CallableWithModelRequestAndState[StateT]
| _CallableWithModelRequestAndStateAndRuntime[StateT, ContextT]
_ModelRequestCallableWithState[StateT]
| _ModelRequestCallableWithStateAndRuntime[StateT, ContextT]
)
_AsyncModelRequestSignature: TypeAlias = (
_ModelRequestCallableWithState[StateT]
| _ModelRequestCallableWithStateAndRuntime[StateT, ContextT]
)
def is_callable_with_runtime(
func: _NodeSignature[StateT, ContextT],
) -> TypeGuard[_CallableWithStateAndRuntime[StateT, ContextT]]:
return "runtime" in signature(func).parameters
def is_callable_with_runtime_and_request(
func: _ModelRequestSignature[StateT, ContextT],
) -> TypeGuard[_CallableWithModelRequestAndStateAndRuntime[StateT, ContextT]]:
func: _NodeSignature[StateT, ContextT] | _ModelRequestSignature[StateT, ContextT],
) -> bool:
"""Check if callable accepts runtime parameter."""
return "runtime" in signature(func).parameters
@@ -284,8 +310,8 @@ def before_model(
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
state: StateT,
runtime: Runtime[ContextT],
) -> dict[str, Any] | Command | None:
return func(state, runtime)
) -> NodeReturn:
return func(state, runtime) # type: ignore[call-arg,return-value]
wrapped = wrapped_with_runtime
else:
@@ -293,8 +319,8 @@ def before_model(
def wrapped_without_runtime(
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
state: StateT,
) -> dict[str, Any] | Command | None:
return func(state) # type: ignore[call-arg]
) -> NodeReturn:
return func(state) # type: ignore[call-arg,return-value]
wrapped = wrapped_without_runtime # type: ignore[assignment]
@@ -394,7 +420,7 @@ def modify_model_request(
def decorator(
func: _ModelRequestSignature[StateT, ContextT],
) -> AgentMiddleware[StateT, ContextT]:
if is_callable_with_runtime_and_request(func):
if is_callable_with_runtime(func):
def wrapped_with_runtime(
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
@@ -402,7 +428,7 @@ def modify_model_request(
state: StateT,
runtime: Runtime[ContextT],
) -> ModelRequest:
return func(request, state, runtime)
return func(request, state, runtime) # type: ignore[call-arg,return-value]
wrapped = wrapped_with_runtime
else:
@@ -412,7 +438,7 @@ def modify_model_request(
request: ModelRequest,
state: StateT,
) -> ModelRequest:
return func(request, state) # type: ignore[call-arg]
return func(request, state) # type: ignore[call-arg,return-value]
wrapped = wrapped_without_runtime # type: ignore[assignment]
@@ -510,8 +536,8 @@ def after_model(
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
state: StateT,
runtime: Runtime[ContextT],
) -> dict[str, Any] | Command | None:
return func(state, runtime)
) -> NodeReturn:
return func(state, runtime) # type: ignore[call-arg,return-value]
wrapped = wrapped_with_runtime
else:
@@ -519,8 +545,8 @@ def after_model(
def wrapped_without_runtime(
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
state: StateT,
) -> dict[str, Any] | Command | None:
return func(state) # type: ignore[call-arg]
) -> NodeReturn:
return func(state) # type: ignore[call-arg,return-value]
wrapped = wrapped_without_runtime # type: ignore[assignment]
@@ -541,3 +567,339 @@ def after_model(
if func is not None:
return decorator(func)
return decorator
@overload
def abefore_model(
func: _AsyncNodeSignature[StateT, ContextT],
) -> AgentMiddleware[StateT, ContextT]: ...
@overload
def abefore_model(
func: None = None,
*,
state_schema: type[StateT] | None = None,
tools: list[BaseTool] | None = None,
jump_to: list[JumpTo] | None = None,
name: str | None = None,
) -> Callable[[_AsyncNodeSignature[StateT, ContextT]], AgentMiddleware[StateT, ContextT]]: ...
def abefore_model(
func: _AsyncNodeSignature[StateT, ContextT] | None = None,
*,
state_schema: type[StateT] | None = None,
tools: list[BaseTool] | None = None,
jump_to: list[JumpTo] | None = None,
name: str | None = None,
) -> (
Callable[[_AsyncNodeSignature[StateT, ContextT]], AgentMiddleware[StateT, ContextT]]
| AgentMiddleware[StateT, ContextT]
):
"""Decorator used to dynamically create a middleware with the async before_model hook.
Args:
func: The async function to be decorated. Can accept either:
- `state: StateT` - Just the agent state
- `state: StateT, runtime: Runtime[ContextT]` - State and runtime context
state_schema: Optional custom state schema type. If not provided, uses the default
AgentState schema.
tools: Optional list of additional tools to register with this middleware.
jump_to: Optional list of valid jump destinations for conditional edges.
Valid values are: "tools", "model", "end"
name: Optional name for the generated middleware class. If not provided,
uses the decorated function's name.
Returns:
Either an AgentMiddleware instance (if func is provided directly) or a decorator function
that can be applied to a function its wrapping.
The decorated function should return:
- `dict[str, Any]` - State updates to merge into the agent state
- `Command` - A command to control flow (e.g., jump to different node)
- `None` - No state updates or flow control
Examples:
Basic usage with state only:
```python
@abefore_model
async def log_before_model(state: AgentState) -> None:
print(f"About to call model with {len(state['messages'])} messages")
```
Advanced usage with runtime and conditional jumping:
```python
@abefore_model(jump_to=["end"])
async def conditional_before_model(
state: AgentState, runtime: Runtime
) -> dict[str, Any] | None:
if await some_async_condition(state):
return {"jump_to": "end"}
return None
```
"""
def decorator(func: _AsyncNodeSignature[StateT, ContextT]) -> AgentMiddleware[StateT, ContextT]:
if is_callable_with_runtime(func):
async def wrapped_with_runtime(
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
state: StateT,
runtime: Runtime[ContextT],
) -> NodeReturn:
result = func(state, runtime) # type: ignore[call-arg]
return await result # type: ignore[misc]
wrapped = wrapped_with_runtime
else:
async def wrapped_without_runtime(
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
state: StateT,
) -> NodeReturn:
result = func(state) # type: ignore[call-arg]
return await result # type: ignore[misc]
wrapped = wrapped_without_runtime # type: ignore[assignment]
# Use function name as default if no name provided
middleware_name = name or cast(
"str", getattr(func, "__name__", "AsyncBeforeModelMiddleware")
)
return type(
middleware_name,
(AgentMiddleware,),
{
"state_schema": state_schema or AgentState,
"tools": tools or [],
"before_model_jump_to": jump_to or [],
"abefore_model": wrapped,
},
)()
if func is not None:
return decorator(func)
return decorator
@overload
def amodify_model_request(
func: _AsyncModelRequestSignature[StateT, ContextT],
) -> AgentMiddleware[StateT, ContextT]: ...
@overload
def amodify_model_request(
func: None = None,
*,
state_schema: type[StateT] | None = None,
tools: list[BaseTool] | None = None,
name: str | None = None,
) -> Callable[
[_AsyncModelRequestSignature[StateT, ContextT]], AgentMiddleware[StateT, ContextT]
]: ...
def amodify_model_request(
func: _AsyncModelRequestSignature[StateT, ContextT] | None = None,
*,
state_schema: type[StateT] | None = None,
tools: list[BaseTool] | None = None,
name: str | None = None,
) -> (
Callable[[_AsyncModelRequestSignature[StateT, ContextT]], AgentMiddleware[StateT, ContextT]]
| AgentMiddleware[StateT, ContextT]
):
r"""Decorator used to dynamically create a middleware with the async modify_model_request hook.
Args:
func: The async function to be decorated. Can accept either:
- `request: ModelRequest, state: StateT` - Model request and agent state
- `request: ModelRequest, state: StateT, runtime: Runtime[ContextT]` -
Model request, state, and runtime context
state_schema: Optional custom state schema type. If not provided, uses the default
AgentState schema.
tools: Optional list of additional tools to register with this middleware.
name: Optional name for the generated middleware class. If not provided,
uses the decorated function's name.
Returns:
Either an AgentMiddleware instance (if func is provided) or a decorator function
that can be applied to a function.
The decorated function should return:
- `ModelRequest` - The modified model request to be sent to the language model
Examples:
Basic usage to modify system prompt:
```python
@amodify_model_request
async def add_context_to_prompt(request: ModelRequest, state: AgentState) -> ModelRequest:
context = await fetch_async_context(state)
if request.system_prompt:
request.system_prompt += f"\n\nAdditional context: {context}"
else:
request.system_prompt = f"Additional context: {context}"
return request
```
"""
def decorator(
func: _AsyncModelRequestSignature[StateT, ContextT],
) -> AgentMiddleware[StateT, ContextT]:
if is_callable_with_runtime(func):
async def wrapped_with_runtime(
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
request: ModelRequest,
state: StateT,
runtime: Runtime[ContextT],
) -> ModelRequest:
result = func(request, state, runtime) # type: ignore[call-arg]
return await result # type: ignore[misc]
wrapped = wrapped_with_runtime
else:
async def wrapped_without_runtime(
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
request: ModelRequest,
state: StateT,
) -> ModelRequest:
result = func(request, state) # type: ignore[call-arg]
return await result # type: ignore[misc]
wrapped = wrapped_without_runtime # type: ignore[assignment]
# Use function name as default if no name provided
middleware_name = name or cast(
"str", getattr(func, "__name__", "AsyncModifyModelRequestMiddleware")
)
return type(
middleware_name,
(AgentMiddleware,),
{
"state_schema": state_schema or AgentState,
"tools": tools or [],
"amodify_model_request": wrapped,
},
)()
if func is not None:
return decorator(func)
return decorator
@overload
def aafter_model(
func: _AsyncNodeSignature[StateT, ContextT],
) -> AgentMiddleware[StateT, ContextT]: ...
@overload
def aafter_model(
func: None = None,
*,
state_schema: type[StateT] | None = None,
tools: list[BaseTool] | None = None,
jump_to: list[JumpTo] | None = None,
name: str | None = None,
) -> Callable[[_AsyncNodeSignature[StateT, ContextT]], AgentMiddleware[StateT, ContextT]]: ...
def aafter_model(
func: _AsyncNodeSignature[StateT, ContextT] | None = None,
*,
state_schema: type[StateT] | None = None,
tools: list[BaseTool] | None = None,
jump_to: list[JumpTo] | None = None,
name: str | None = None,
) -> (
Callable[[_AsyncNodeSignature[StateT, ContextT]], AgentMiddleware[StateT, ContextT]]
| AgentMiddleware[StateT, ContextT]
):
"""Decorator used to dynamically create a middleware with the async after_model hook.
Args:
func: The async function to be decorated. Can accept either:
- `state: StateT` - Just the agent state (includes model response)
- `state: StateT, runtime: Runtime[ContextT]` - State and runtime context
state_schema: Optional custom state schema type. If not provided, uses the default
AgentState schema.
tools: Optional list of additional tools to register with this middleware.
jump_to: Optional list of valid jump destinations for conditional edges.
Valid values are: "tools", "model", "end"
name: Optional name for the generated middleware class. If not provided,
uses the decorated function's name.
Returns:
Either an AgentMiddleware instance (if func is provided) or a decorator function
that can be applied to a function.
The decorated function should return:
- `dict[str, Any]` - State updates to merge into the agent state
- `Command` - A command to control flow (e.g., jump to different node)
- `None` - No state updates or flow control
Examples:
Basic usage for async logging model responses:
```python
@aafter_model
async def log_latest_message(state: AgentState) -> None:
await async_log(state["messages"][-1].content)
```
With custom state schema:
```python
@aafter_model(state_schema=MyCustomState, name="MyAsyncAfterModelMiddleware")
async def custom_after_model(state: MyCustomState) -> dict[str, Any]:
result = await some_async_operation(state)
return {"custom_field": result}
```
"""
def decorator(func: _AsyncNodeSignature[StateT, ContextT]) -> AgentMiddleware[StateT, ContextT]:
if is_callable_with_runtime(func):
async def wrapped_with_runtime(
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
state: StateT,
runtime: Runtime[ContextT],
) -> NodeReturn:
result = func(state, runtime) # type: ignore[call-arg]
return await result # type: ignore[misc]
wrapped = wrapped_with_runtime
else:
async def wrapped_without_runtime(
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
state: StateT,
) -> NodeReturn:
result = func(state) # type: ignore[call-arg]
return await result # type: ignore[misc]
wrapped = wrapped_without_runtime # type: ignore[assignment]
# Use function name as default if no name provided
middleware_name = name or cast(
"str", getattr(func, "__name__", "AsyncAfterModelMiddleware")
)
return type(
middleware_name,
(AgentMiddleware,),
{
"state_schema": state_schema or AgentState,
"tools": tools or [],
"after_model_jump_to": jump_to or [],
"aafter_model": wrapped,
},
)()
if func is not None:
return decorator(func)
return decorator

View File

@@ -212,15 +212,28 @@ def create_agent( # noqa: PLR0915
"Please remove duplicate middleware instances."
)
middleware_w_before = [
m for m in middleware if m.__class__.before_model is not AgentMiddleware.before_model
m
for m in middleware
if (
m.__class__.before_model is not AgentMiddleware.before_model
or m.__class__.abefore_model is not AgentMiddleware.abefore_model
)
]
middleware_w_modify_model_request = [
m
for m in middleware
if m.__class__.modify_model_request is not AgentMiddleware.modify_model_request
if (
m.__class__.modify_model_request is not AgentMiddleware.modify_model_request
or m.__class__.amodify_model_request is not AgentMiddleware.amodify_model_request
)
]
middleware_w_after = [
m for m in middleware if m.__class__.after_model is not AgentMiddleware.after_model
m
for m in middleware
if (
m.__class__.after_model is not AgentMiddleware.after_model
or m.__class__.aafter_model is not AgentMiddleware.aafter_model
)
]
state_schemas = {m.state_schema for m in middleware}
@@ -346,16 +359,32 @@ def create_agent( # noqa: PLR0915
)
return request.model.bind(**request.model_settings)
model_request_signatures: list[
tuple[bool, AgentMiddleware[AgentState[ResponseT], ContextT]]
] = [
("runtime" in signature(m.modify_model_request).parameters, m)
for m in middleware_w_modify_model_request
]
# Helper functions for middleware processing
def _has_sync_hook(
middleware: AgentMiddleware[AgentState[ResponseT], ContextT], hook_name: str
) -> bool:
"""Check if middleware has a sync implementation of the given hook."""
return getattr(middleware.__class__, hook_name) is not getattr(AgentMiddleware, hook_name)
def model_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
"""Sync model request handler with sequential middleware processing."""
request = ModelRequest(
def _has_async_hook(
middleware: AgentMiddleware[AgentState[ResponseT], ContextT], hook_name: str
) -> bool:
"""Check if middleware has an async implementation of the given hook."""
async_hook_name = f"a{hook_name}"
return getattr(middleware.__class__, async_hook_name) is not getattr(
AgentMiddleware, async_hook_name
)
def _uses_runtime(
middleware: AgentMiddleware[AgentState[ResponseT], ContextT], hook_name: str
) -> bool:
"""Check if the hook uses runtime parameter."""
hook_method = getattr(middleware, hook_name)
return "runtime" in signature(hook_method).parameters
def _create_base_model_request(state: AgentState) -> ModelRequest:
"""Create the base model request from state."""
return ModelRequest(
model=model,
tools=default_tools,
system_prompt=system_prompt,
@@ -364,47 +393,71 @@ def create_agent( # noqa: PLR0915
tool_choice=None,
)
# Apply modify_model_request middleware in sequence
for use_runtime, m in model_request_signatures:
if use_runtime:
m.modify_model_request(request, state, runtime)
else:
m.modify_model_request(request, state) # type: ignore[call-arg]
def _apply_sync_modify_middleware(
request: ModelRequest, state: AgentState, runtime: Runtime[ContextT]
) -> ModelRequest:
"""Apply sync modify_model_request middleware in sequence."""
for m in middleware_w_modify_model_request:
if not _has_sync_hook(m, "modify_model_request"):
msg = (
f"Middleware {m.__class__.__name__} only has async modify_model_request hook. "
"Cannot use in sync context. Either provide a sync version "
"or use async execution."
)
raise ValueError(msg)
# Get the final model and messages
if _uses_runtime(m, "modify_model_request"):
request = m.modify_model_request(request, state, runtime)
else:
request = m.modify_model_request(request, state) # type: ignore[call-arg]
return request
async def _apply_async_modify_middleware(
request: ModelRequest, state: AgentState, runtime: Runtime[ContextT]
) -> ModelRequest:
"""Apply async modify_model_request middleware in sequence."""
for m in middleware_w_modify_model_request:
# Prefer async version if available
if _has_async_hook(m, "modify_model_request"):
if _uses_runtime(m, "amodify_model_request"):
request = await m.amodify_model_request(request, state, runtime)
else:
request = await m.amodify_model_request(request, state) # type: ignore[call-arg]
elif _has_sync_hook(m, "modify_model_request"):
# Fall back to sync version
if _uses_runtime(m, "modify_model_request"):
request = m.modify_model_request(request, state, runtime)
else:
request = m.modify_model_request(request, state) # type: ignore[call-arg]
else:
msg = (
f"Middleware {m.__class__.__name__} has no modify_model_request "
"implementation. This should not happen - middleware validation failed."
)
raise ValueError(msg)
return request
def _finalize_model_request(request: ModelRequest) -> tuple[Runnable, list[AnyMessage]]:
"""Finalize the model request and return the bound model and messages."""
model_ = _get_bound_model(request)
messages = request.messages
if request.system_prompt:
messages = [SystemMessage(request.system_prompt), *messages]
return model_, messages
def model_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
"""Sync model request handler with sequential middleware processing."""
request = _create_base_model_request(state)
request = _apply_sync_modify_middleware(request, state, runtime)
model_, messages = _finalize_model_request(request)
output = model_.invoke(messages)
return _handle_model_output(output)
async def amodel_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
"""Async model request handler with sequential middleware processing."""
# Start with the base model request
request = ModelRequest(
model=model,
tools=default_tools,
system_prompt=system_prompt,
response_format=response_format,
messages=state["messages"],
tool_choice=None,
)
# Apply modify_model_request middleware in sequence
for use_runtime, m in model_request_signatures:
if use_runtime:
m.modify_model_request(request, state, runtime)
else:
m.modify_model_request(request, state) # type: ignore[call-arg]
# Get the final model and messages
model_ = _get_bound_model(request)
messages = request.messages
if request.system_prompt:
messages = [SystemMessage(request.system_prompt), *messages]
request = _create_base_model_request(state)
request = await _apply_async_modify_middleware(request, state, runtime)
model_, messages = _finalize_model_request(request)
output = await model_.ainvoke(messages)
return _handle_model_output(output)
@@ -417,17 +470,36 @@ def create_agent( # noqa: PLR0915
if tool_node is not None:
graph.add_node("tools", tool_node)
def _add_middleware_node(
middleware: AgentMiddleware[AgentState[ResponseT], ContextT], hook_name: str, node_name: str
) -> None:
"""Add a middleware node to the graph with proper sync/async handling."""
has_sync = _has_sync_hook(middleware, hook_name)
has_async = _has_async_hook(middleware, hook_name)
if has_sync and has_async:
# Both sync and async available - use RunnableCallable
sync_method = getattr(middleware, hook_name)
async_method = getattr(middleware, f"a{hook_name}")
graph.add_node(
node_name, RunnableCallable(sync_method, async_method), input_schema=state_schema
)
elif has_async:
# Only async available
async_method = getattr(middleware, f"a{hook_name}")
graph.add_node(node_name, async_method, input_schema=state_schema)
elif has_sync:
# Only sync available
sync_method = getattr(middleware, hook_name)
graph.add_node(node_name, sync_method, input_schema=state_schema)
# Add middleware nodes
for m in middleware:
if m.__class__.before_model is not AgentMiddleware.before_model:
graph.add_node(
f"{m.__class__.__name__}.before_model", m.before_model, input_schema=state_schema
)
if _has_sync_hook(m, "before_model") or _has_async_hook(m, "before_model"):
_add_middleware_node(m, "before_model", f"{m.__class__.__name__}.before_model")
if m.__class__.after_model is not AgentMiddleware.after_model:
graph.add_node(
f"{m.__class__.__name__}.after_model", m.after_model, input_schema=state_schema
)
if _has_sync_hook(m, "after_model") or _has_async_hook(m, "after_model"):
_add_middleware_node(m, "after_model", f"{m.__class__.__name__}.after_model")
# add start edge
first_node = (

View File

@@ -0,0 +1,366 @@
"""Tests for async middleware decorators: abefore_model, aafter_model, and amodify_model_request."""
import asyncio
from typing import Any
from typing_extensions import NotRequired
from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.tools import tool
from langgraph.types import Command
from langchain.agents.middleware.types import (
AgentMiddleware,
AgentState,
ModelRequest,
abefore_model,
aafter_model,
amodify_model_request,
)
from langchain.agents.middleware_agent import create_agent
from .model import FakeToolCallingModel
class CustomState(AgentState):
"""Custom state schema for testing."""
custom_field: NotRequired[str]
@tool
def test_tool(input: str) -> str:
"""A test tool for middleware testing."""
return f"Tool result: {input}"
def test_abefore_model_decorator() -> None:
"""Test abefore_model decorator with all configuration options."""
@abefore_model(
state_schema=CustomState, tools=[test_tool], jump_to=["end"], name="CustomAsyncBeforeModel"
)
async def custom_async_before_model(state: CustomState) -> dict[str, Any]:
await asyncio.sleep(0.001) # Simulate async work
return {"jump_to": "end"}
assert isinstance(custom_async_before_model, AgentMiddleware)
assert custom_async_before_model.state_schema == CustomState
assert custom_async_before_model.tools == [test_tool]
assert custom_async_before_model.before_model_jump_to == ["end"]
assert custom_async_before_model.__class__.__name__ == "CustomAsyncBeforeModel"
# Test that the async method was set
assert hasattr(custom_async_before_model, "abefore_model")
result = asyncio.run(
custom_async_before_model.abefore_model({"messages": [HumanMessage("Hello")]})
)
assert result == {"jump_to": "end"}
def test_aafter_model_decorator() -> None:
"""Test aafter_model decorator with all configuration options."""
@aafter_model(
state_schema=CustomState,
tools=[test_tool],
jump_to=["model", "end"],
name="CustomAsyncAfterModel",
)
async def custom_async_after_model(state: CustomState) -> dict[str, Any]:
await asyncio.sleep(0.001) # Simulate async work
return {"jump_to": "model"}
# Verify all options were applied
assert isinstance(custom_async_after_model, AgentMiddleware)
assert custom_async_after_model.state_schema == CustomState
assert custom_async_after_model.tools == [test_tool]
assert custom_async_after_model.after_model_jump_to == ["model", "end"]
assert custom_async_after_model.__class__.__name__ == "CustomAsyncAfterModel"
# Verify it works
result = asyncio.run(
custom_async_after_model.aafter_model(
{"messages": [HumanMessage("Hello"), AIMessage("Hi!")]}
)
)
assert result == {"jump_to": "model"}
def test_amodify_model_request_decorator() -> None:
"""Test amodify_model_request decorator with all configuration options."""
@amodify_model_request(
state_schema=CustomState, tools=[test_tool], name="CustomAsyncModifyRequest"
)
async def custom_async_modify_request(
request: ModelRequest, state: CustomState
) -> ModelRequest:
await asyncio.sleep(0.001) # Simulate async work
request.system_prompt = "Async Modified"
return request
# Verify all options were applied
assert isinstance(custom_async_modify_request, AgentMiddleware)
assert custom_async_modify_request.state_schema == CustomState
assert custom_async_modify_request.tools == [test_tool]
assert custom_async_modify_request.__class__.__name__ == "CustomAsyncModifyRequest"
# Verify it works
original_request = ModelRequest(
model="test-model",
system_prompt="Original",
messages=[HumanMessage("Hello")],
tool_choice=None,
tools=[],
response_format=None,
)
result = asyncio.run(
custom_async_modify_request.amodify_model_request(
original_request, {"messages": [HumanMessage("Hello")]}
)
)
assert result.system_prompt == "Async Modified"
def test_all_async_decorators_integration() -> None:
"""Test all three async decorators working together in an agent."""
call_order = []
@abefore_model
async def track_async_before(state: AgentState) -> None:
await asyncio.sleep(0.001) # Simulate async work
call_order.append("async_before")
return None
@amodify_model_request
async def track_async_modify(request: ModelRequest, state: AgentState) -> ModelRequest:
await asyncio.sleep(0.001) # Simulate async work
call_order.append("async_modify")
return request
@aafter_model
async def track_async_after(state: AgentState) -> None:
await asyncio.sleep(0.001) # Simulate async work
call_order.append("async_after")
return None
agent = create_agent(
model=FakeToolCallingModel(),
middleware=[track_async_before, track_async_modify, track_async_after],
)
agent = agent.compile()
async def run_test():
result = await agent.ainvoke({"messages": [HumanMessage("Hello")]})
return result
asyncio.run(run_test())
assert call_order == ["async_before", "async_modify", "async_after"]
def test_mixed_sync_async_middleware() -> None:
"""Test mixing sync and async middleware in the same agent."""
call_order = []
# Import sync decorators
from langchain.agents.middleware.types import before_model, modify_model_request, after_model
@before_model
def sync_before(state: AgentState) -> None:
call_order.append("sync_before")
return None
@abefore_model
async def async_before(state: AgentState) -> None:
await asyncio.sleep(0.001) # Simulate async work
call_order.append("async_before")
return None
@modify_model_request
def sync_modify(request: ModelRequest, state: AgentState) -> ModelRequest:
call_order.append("sync_modify")
return request
@amodify_model_request
async def async_modify(request: ModelRequest, state: AgentState) -> ModelRequest:
await asyncio.sleep(0.001) # Simulate async work
call_order.append("async_modify")
return request
@after_model
def sync_after(state: AgentState) -> None:
call_order.append("sync_after")
return None
@aafter_model
async def async_after(state: AgentState) -> None:
await asyncio.sleep(0.001) # Simulate async work
call_order.append("async_after")
return None
agent = create_agent(
model=FakeToolCallingModel(),
middleware=[sync_before, async_before, sync_modify, async_modify, sync_after, async_after],
)
agent = agent.compile()
async def run_test():
result = await agent.ainvoke({"messages": [HumanMessage("Hello")]})
return result
asyncio.run(run_test())
# Both sync and async middleware should have run
# Order: before middlewares (sync_before, async_before), modify middlewares (sync_modify, async_modify),
# after middlewares (async_after, sync_after) - note: after hooks run in reverse order
expected_calls = [
"sync_before",
"async_before",
"sync_modify",
"async_modify",
"async_after",
"sync_after",
]
assert call_order == expected_calls
def test_async_decorators_use_function_names_as_default() -> None:
"""Test that async decorators use function names as default middleware names."""
@abefore_model
async def my_async_before_hook(state: AgentState) -> None:
return None
@amodify_model_request
async def my_async_modify_hook(request: ModelRequest, state: AgentState) -> ModelRequest:
return request
@aafter_model
async def my_async_after_hook(state: AgentState) -> None:
return None
# Verify that function names are used as middleware class names
assert my_async_before_hook.__class__.__name__ == "my_async_before_hook"
assert my_async_modify_hook.__class__.__name__ == "my_async_modify_hook"
assert my_async_after_hook.__class__.__name__ == "my_async_after_hook"
def test_async_with_runtime_context() -> None:
"""Test async decorators that use runtime context."""
@abefore_model
async def async_before_with_runtime(state: AgentState, runtime) -> dict[str, Any]:
await asyncio.sleep(0.001) # Simulate async work
# Use runtime context in some way
context_info = getattr(runtime, "context", {})
return {"custom_field": f"processed_with_runtime_{len(context_info)}"}
@amodify_model_request
async def async_modify_with_runtime(
request: ModelRequest, state: AgentState, runtime
) -> ModelRequest:
await asyncio.sleep(0.001) # Simulate async work
# Modify request based on runtime context
request.system_prompt = f"Runtime context available: {runtime is not None}"
return request
@aafter_model
async def async_after_with_runtime(state: AgentState, runtime) -> None:
await asyncio.sleep(0.001) # Simulate async work
# Process state with runtime
return None
# Test that these can be instantiated (runtime context validation will happen at execution)
assert isinstance(async_before_with_runtime, AgentMiddleware)
assert isinstance(async_modify_with_runtime, AgentMiddleware)
assert isinstance(async_after_with_runtime, AgentMiddleware)
def test_sync_execution_with_async_only_middleware_error() -> None:
"""Test that sync execution properly errors when encountering async-only middleware."""
import pytest
@amodify_model_request
async def async_only_modify(request: ModelRequest, state: AgentState) -> ModelRequest:
await asyncio.sleep(0.001) # Simulate async work
request.system_prompt = "Modified by async-only middleware"
return request
# Create an agent with async-only middleware
agent = create_agent(model=FakeToolCallingModel(), middleware=[async_only_modify])
agent = agent.compile()
# Sync execution should raise an error
with pytest.raises(ValueError, match="only has async modify_model_request hook"):
agent.invoke({"messages": [HumanMessage("Hello")]})
# Async execution should work fine
async def run_async_test():
result = await agent.ainvoke({"messages": [HumanMessage("Hello")]})
return result
# This should not raise an error
result = asyncio.run(run_async_test())
assert result is not None
def test_mixed_middleware_execution_contexts() -> None:
"""Test that mixed sync/async middleware works in both execution contexts."""
call_order = []
# Import sync decorators
from langchain.agents.middleware.types import modify_model_request
@modify_model_request
def sync_only_modify(request: ModelRequest, state: AgentState) -> ModelRequest:
call_order.append("sync_only_modify")
return request
@amodify_model_request
async def async_only_modify(request: ModelRequest, state: AgentState) -> ModelRequest:
await asyncio.sleep(0.001)
call_order.append("async_only_modify")
return request
# Create agent with both sync-only and async-only middleware
agent = create_agent(
model=FakeToolCallingModel(), middleware=[sync_only_modify, async_only_modify]
)
agent = agent.compile()
# Test async execution - should work and call both middleware
async def run_async_test():
call_order.clear()
result = await agent.ainvoke({"messages": [HumanMessage("Hello")]})
return result
asyncio.run(run_async_test())
assert "sync_only_modify" in call_order
assert "async_only_modify" in call_order
# Test sync execution - should fail due to async-only middleware
import pytest
call_order.clear()
with pytest.raises(ValueError, match="only has async modify_model_request hook"):
agent.invoke({"messages": [HumanMessage("Hello")]})
def test_error_handling_for_missing_implementations() -> None:
"""Test error handling when middleware validation fails."""
import pytest
# Create a custom middleware with no implementations
class BrokenMiddleware(AgentMiddleware):
def __init__(self):
self.tools = []
broken_middleware = BrokenMiddleware()
# This should not cause immediate error during agent creation
# The validation logic should properly filter out middleware without implementations
agent = create_agent(model=FakeToolCallingModel(), middleware=[broken_middleware])
agent = agent.compile()
# Execution should work normally since the broken middleware has no hooks
result = agent.invoke({"messages": [HumanMessage("Hello")]})
assert result is not None