mirror of https://github.com/hwchase17/langchain.git
synced 2026-02-09 10:41:52 +00:00

Compare commits: 3 commits (langchain ... sr/async-i)

Commits: 286039788e, 41e9d479f5, 1526c419d8
@@ -13,7 +13,6 @@ from typing import (
    Literal,
    Protocol,
    TypeAlias,
    TypeGuard,
    cast,
    overload,
)
@@ -23,16 +22,16 @@ from langchain_core.messages import AnyMessage  # noqa: TC002
from langgraph.channels.ephemeral_value import EphemeralValue
from langgraph.graph.message import add_messages
from langgraph.runtime import Runtime
from langgraph.types import Command
from langgraph.typing import ContextT
from typing_extensions import NotRequired, Required, TypedDict, TypeVar

if TYPE_CHECKING:
    from collections.abc import Callable
    from collections.abc import Awaitable, Callable

    from langchain_core.language_models.chat_models import BaseChatModel
    from langchain_core.tools import BaseTool
    from langgraph.runtime import Runtime
    from langgraph.types import Command

    from langchain.agents.structured_output import ResponseFormat
@@ -42,7 +41,12 @@ __all__ = [
    "ContextT",
    "ModelRequest",
    "OmitFromSchema",
    "PublicAgentState",
    "aafter_model",
    "abefore_model",
    "after_model",
    "amodify_model_request",
    "before_model",
    "modify_model_request",
]

JumpTo = Literal["tools", "model", "end"]
@@ -50,6 +54,9 @@ JumpTo = Literal["tools", "model", "end"]

ResponseT = TypeVar("ResponseT")

NodeReturn: TypeAlias = dict[str, Any] | Command | None
"""Return type for middleware node hooks."""


@dataclass
class ModelRequest:
@@ -126,9 +133,12 @@ class AgentMiddleware(Generic[StateT, ContextT]):
    after_model_jump_to: ClassVar[list[JumpTo]] = []
    """Valid jump destinations for after_model hook. Used to establish conditional edges."""

    def before_model(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
    def before_model(self, state: StateT, runtime: Runtime[ContextT]) -> NodeReturn:
        """Logic to run before the model is called."""

    async def abefore_model(self, state: StateT, runtime: Runtime[ContextT]) -> NodeReturn:
        """Async logic to run before the model is called."""

    def modify_model_request(
        self,
        request: ModelRequest,
@@ -138,64 +148,80 @@ class AgentMiddleware(Generic[StateT, ContextT]):
        """Logic to modify request kwargs before the model is called."""
        return request

    def after_model(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
    async def amodify_model_request(
        self,
        request: ModelRequest,
        state: StateT,  # noqa: ARG002
        runtime: Runtime[ContextT],  # noqa: ARG002
    ) -> ModelRequest:
        """Async logic to modify request kwargs before the model is called."""
        return request

    def after_model(self, state: StateT, runtime: Runtime[ContextT]) -> NodeReturn:
        """Logic to run after the model is called."""

    async def aafter_model(self, state: StateT, runtime: Runtime[ContextT]) -> NodeReturn:
        """Async logic to run after the model is called."""
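For orientation, a minimal sketch (not part of the diff) of a middleware subclass that provides both the sync and the new async hook, so it works under both `invoke` and `ainvoke`. The import path follows the test file further down; the class name `MessageCountLogger` is hypothetical:

```python
from typing import Any

from langchain.agents.middleware.types import AgentMiddleware, AgentState
from langgraph.runtime import Runtime


class MessageCountLogger(AgentMiddleware):
    """Hypothetical middleware that logs the message count before each model call."""

    def before_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
        # Runs when the agent is driven synchronously.
        print(f"sync path: {len(state['messages'])} messages")
        return None

    async def abefore_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
        # Runs when the agent is driven asynchronously.
        print(f"async path: {len(state['messages'])} messages")
        return None
```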
class _CallableWithState(Protocol[StateT_contra]):
    """Callable with AgentState as argument."""

    def __call__(self, state: StateT_contra) -> dict[str, Any] | Command | None:
        """Perform some logic with the state."""
class _NodeCallableWithState(Protocol[StateT_contra]):
    """Callable for before/after model hooks with just state (sync or async)."""

    def __call__(self, state: StateT_contra) -> NodeReturn | Awaitable[NodeReturn]:
        """Perform logic with the state."""
        ...


class _CallableWithStateAndRuntime(Protocol[StateT_contra, ContextT]):
    """Callable with AgentState and Runtime as arguments."""
class _NodeCallableWithStateAndRuntime(Protocol[StateT_contra, ContextT]):
    """Callable for before/after model hooks with state and runtime (sync or async)."""

    def __call__(
        self, state: StateT_contra, runtime: Runtime[ContextT]
    ) -> dict[str, Any] | Command | None:
        """Perform some logic with the state and runtime."""
    ) -> NodeReturn | Awaitable[NodeReturn]:
        """Perform logic with the state and runtime."""
        ...


class _CallableWithModelRequestAndState(Protocol[StateT_contra]):
    """Callable with ModelRequest and AgentState as arguments."""
class _ModelRequestCallableWithState(Protocol[StateT_contra]):
    """Callable for modify_model_request hook with state (sync or async)."""

    def __call__(self, request: ModelRequest, state: StateT_contra) -> ModelRequest:
        """Perform some logic with the model request and state."""
    def __call__(
        self, request: ModelRequest, state: StateT_contra
    ) -> ModelRequest | Awaitable[ModelRequest]:
        """Perform logic with the model request and state."""
        ...


class _CallableWithModelRequestAndStateAndRuntime(Protocol[StateT_contra, ContextT]):
    """Callable with ModelRequest, AgentState, and Runtime as arguments."""
class _ModelRequestCallableWithStateAndRuntime(Protocol[StateT_contra, ContextT]):
    """Callable for modify_model_request hook with state and runtime (sync or async)."""

    def __call__(
        self, request: ModelRequest, state: StateT_contra, runtime: Runtime[ContextT]
    ) -> ModelRequest:
        """Perform some logic with the model request, state, and runtime."""
    ) -> ModelRequest | Awaitable[ModelRequest]:
        """Perform logic with the model request, state, and runtime."""
        ...


_NodeSignature: TypeAlias = (
    _CallableWithState[StateT] | _CallableWithStateAndRuntime[StateT, ContextT]
    _NodeCallableWithState[StateT] | _NodeCallableWithStateAndRuntime[StateT, ContextT]
)
_AsyncNodeSignature: TypeAlias = (
    _NodeCallableWithState[StateT] | _NodeCallableWithStateAndRuntime[StateT, ContextT]
)
_ModelRequestSignature: TypeAlias = (
    _CallableWithModelRequestAndState[StateT]
    | _CallableWithModelRequestAndStateAndRuntime[StateT, ContextT]
    _ModelRequestCallableWithState[StateT]
    | _ModelRequestCallableWithStateAndRuntime[StateT, ContextT]
)
_AsyncModelRequestSignature: TypeAlias = (
    _ModelRequestCallableWithState[StateT]
    | _ModelRequestCallableWithStateAndRuntime[StateT, ContextT]
)


def is_callable_with_runtime(
    func: _NodeSignature[StateT, ContextT],
) -> TypeGuard[_CallableWithStateAndRuntime[StateT, ContextT]]:
    return "runtime" in signature(func).parameters


def is_callable_with_runtime_and_request(
    func: _ModelRequestSignature[StateT, ContextT],
) -> TypeGuard[_CallableWithModelRequestAndStateAndRuntime[StateT, ContextT]]:
    func: _NodeSignature[StateT, ContextT] | _ModelRequestSignature[StateT, ContextT],
) -> bool:
    """Check if callable accepts runtime parameter."""
    return "runtime" in signature(func).parameters
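The merged helper above keys dispatch purely on the parameter list of the decorated callable. A tiny standalone sketch (not part of the diff) of the same `inspect.signature` check, using hypothetical hook functions:

```python
from inspect import signature


def hook_with_runtime(state, runtime):  # hypothetical hook taking state and runtime
    return None


def hook_without_runtime(state):  # hypothetical hook taking only state
    return None


# The decorators below choose the wrapper variant based on exactly this check.
assert "runtime" in signature(hook_with_runtime).parameters
assert "runtime" not in signature(hook_without_runtime).parameters
```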
@@ -284,8 +310,8 @@ def before_model(
                self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
                state: StateT,
                runtime: Runtime[ContextT],
            ) -> dict[str, Any] | Command | None:
                return func(state, runtime)
            ) -> NodeReturn:
                return func(state, runtime)  # type: ignore[call-arg,return-value]

            wrapped = wrapped_with_runtime
        else:
@@ -293,8 +319,8 @@ def before_model(
            def wrapped_without_runtime(
                self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
                state: StateT,
            ) -> dict[str, Any] | Command | None:
                return func(state)  # type: ignore[call-arg]
            ) -> NodeReturn:
                return func(state)  # type: ignore[call-arg,return-value]

            wrapped = wrapped_without_runtime  # type: ignore[assignment]
@@ -394,7 +420,7 @@ def modify_model_request(
    def decorator(
        func: _ModelRequestSignature[StateT, ContextT],
    ) -> AgentMiddleware[StateT, ContextT]:
        if is_callable_with_runtime_and_request(func):
        if is_callable_with_runtime(func):

            def wrapped_with_runtime(
                self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
@@ -402,7 +428,7 @@ def modify_model_request(
                state: StateT,
                runtime: Runtime[ContextT],
            ) -> ModelRequest:
                return func(request, state, runtime)
                return func(request, state, runtime)  # type: ignore[call-arg,return-value]

            wrapped = wrapped_with_runtime
        else:
@@ -412,7 +438,7 @@ def modify_model_request(
                request: ModelRequest,
                state: StateT,
            ) -> ModelRequest:
                return func(request, state)  # type: ignore[call-arg]
                return func(request, state)  # type: ignore[call-arg,return-value]

            wrapped = wrapped_without_runtime  # type: ignore[assignment]
@@ -510,8 +536,8 @@ def after_model(
                self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
                state: StateT,
                runtime: Runtime[ContextT],
            ) -> dict[str, Any] | Command | None:
                return func(state, runtime)
            ) -> NodeReturn:
                return func(state, runtime)  # type: ignore[call-arg,return-value]

            wrapped = wrapped_with_runtime
        else:
@@ -519,8 +545,8 @@ def after_model(
            def wrapped_without_runtime(
                self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
                state: StateT,
            ) -> dict[str, Any] | Command | None:
                return func(state)  # type: ignore[call-arg]
            ) -> NodeReturn:
                return func(state)  # type: ignore[call-arg,return-value]

            wrapped = wrapped_without_runtime  # type: ignore[assignment]
@@ -541,3 +567,339 @@ def after_model(
    if func is not None:
        return decorator(func)
    return decorator


@overload
def abefore_model(
    func: _AsyncNodeSignature[StateT, ContextT],
) -> AgentMiddleware[StateT, ContextT]: ...


@overload
def abefore_model(
    func: None = None,
    *,
    state_schema: type[StateT] | None = None,
    tools: list[BaseTool] | None = None,
    jump_to: list[JumpTo] | None = None,
    name: str | None = None,
) -> Callable[[_AsyncNodeSignature[StateT, ContextT]], AgentMiddleware[StateT, ContextT]]: ...


def abefore_model(
    func: _AsyncNodeSignature[StateT, ContextT] | None = None,
    *,
    state_schema: type[StateT] | None = None,
    tools: list[BaseTool] | None = None,
    jump_to: list[JumpTo] | None = None,
    name: str | None = None,
) -> (
    Callable[[_AsyncNodeSignature[StateT, ContextT]], AgentMiddleware[StateT, ContextT]]
    | AgentMiddleware[StateT, ContextT]
):
    """Decorator used to dynamically create a middleware with the async before_model hook.

    Args:
        func: The async function to be decorated. Can accept either:
            - `state: StateT` - Just the agent state
            - `state: StateT, runtime: Runtime[ContextT]` - State and runtime context
        state_schema: Optional custom state schema type. If not provided, uses the default
            AgentState schema.
        tools: Optional list of additional tools to register with this middleware.
        jump_to: Optional list of valid jump destinations for conditional edges.
            Valid values are: "tools", "model", "end"
        name: Optional name for the generated middleware class. If not provided,
            uses the decorated function's name.

    Returns:
        Either an AgentMiddleware instance (if func is provided directly) or a decorator function
        that can be applied to a function.

    The decorated function should return:
        - `dict[str, Any]` - State updates to merge into the agent state
        - `Command` - A command to control flow (e.g., jump to different node)
        - `None` - No state updates or flow control

    Examples:
        Basic usage with state only:
        ```python
        @abefore_model
        async def log_before_model(state: AgentState) -> None:
            print(f"About to call model with {len(state['messages'])} messages")
        ```

        Advanced usage with runtime and conditional jumping:
        ```python
        @abefore_model(jump_to=["end"])
        async def conditional_before_model(
            state: AgentState, runtime: Runtime
        ) -> dict[str, Any] | None:
            if await some_async_condition(state):
                return {"jump_to": "end"}
            return None
        ```
    """

    def decorator(func: _AsyncNodeSignature[StateT, ContextT]) -> AgentMiddleware[StateT, ContextT]:
        if is_callable_with_runtime(func):

            async def wrapped_with_runtime(
                self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
                state: StateT,
                runtime: Runtime[ContextT],
            ) -> NodeReturn:
                result = func(state, runtime)  # type: ignore[call-arg]
                return await result  # type: ignore[misc]

            wrapped = wrapped_with_runtime
        else:

            async def wrapped_without_runtime(
                self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
                state: StateT,
            ) -> NodeReturn:
                result = func(state)  # type: ignore[call-arg]
                return await result  # type: ignore[misc]

            wrapped = wrapped_without_runtime  # type: ignore[assignment]

        # Use function name as default if no name provided
        middleware_name = name or cast(
            "str", getattr(func, "__name__", "AsyncBeforeModelMiddleware")
        )

        return type(
            middleware_name,
            (AgentMiddleware,),
            {
                "state_schema": state_schema or AgentState,
                "tools": tools or [],
                "before_model_jump_to": jump_to or [],
                "abefore_model": wrapped,
            },
        )()

    if func is not None:
        return decorator(func)
    return decorator

@overload
def amodify_model_request(
    func: _AsyncModelRequestSignature[StateT, ContextT],
) -> AgentMiddleware[StateT, ContextT]: ...


@overload
def amodify_model_request(
    func: None = None,
    *,
    state_schema: type[StateT] | None = None,
    tools: list[BaseTool] | None = None,
    name: str | None = None,
) -> Callable[
    [_AsyncModelRequestSignature[StateT, ContextT]], AgentMiddleware[StateT, ContextT]
]: ...


def amodify_model_request(
    func: _AsyncModelRequestSignature[StateT, ContextT] | None = None,
    *,
    state_schema: type[StateT] | None = None,
    tools: list[BaseTool] | None = None,
    name: str | None = None,
) -> (
    Callable[[_AsyncModelRequestSignature[StateT, ContextT]], AgentMiddleware[StateT, ContextT]]
    | AgentMiddleware[StateT, ContextT]
):
    r"""Decorator used to dynamically create a middleware with the async modify_model_request hook.

    Args:
        func: The async function to be decorated. Can accept either:
            - `request: ModelRequest, state: StateT` - Model request and agent state
            - `request: ModelRequest, state: StateT, runtime: Runtime[ContextT]` -
              Model request, state, and runtime context
        state_schema: Optional custom state schema type. If not provided, uses the default
            AgentState schema.
        tools: Optional list of additional tools to register with this middleware.
        name: Optional name for the generated middleware class. If not provided,
            uses the decorated function's name.

    Returns:
        Either an AgentMiddleware instance (if func is provided) or a decorator function
        that can be applied to a function.

    The decorated function should return:
        - `ModelRequest` - The modified model request to be sent to the language model

    Examples:
        Basic usage to modify system prompt:
        ```python
        @amodify_model_request
        async def add_context_to_prompt(request: ModelRequest, state: AgentState) -> ModelRequest:
            context = await fetch_async_context(state)
            if request.system_prompt:
                request.system_prompt += f"\n\nAdditional context: {context}"
            else:
                request.system_prompt = f"Additional context: {context}"
            return request
        ```
    """

    def decorator(
        func: _AsyncModelRequestSignature[StateT, ContextT],
    ) -> AgentMiddleware[StateT, ContextT]:
        if is_callable_with_runtime(func):

            async def wrapped_with_runtime(
                self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
                request: ModelRequest,
                state: StateT,
                runtime: Runtime[ContextT],
            ) -> ModelRequest:
                result = func(request, state, runtime)  # type: ignore[call-arg]
                return await result  # type: ignore[misc]

            wrapped = wrapped_with_runtime
        else:

            async def wrapped_without_runtime(
                self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
                request: ModelRequest,
                state: StateT,
            ) -> ModelRequest:
                result = func(request, state)  # type: ignore[call-arg]
                return await result  # type: ignore[misc]

            wrapped = wrapped_without_runtime  # type: ignore[assignment]

        # Use function name as default if no name provided
        middleware_name = name or cast(
            "str", getattr(func, "__name__", "AsyncModifyModelRequestMiddleware")
        )

        return type(
            middleware_name,
            (AgentMiddleware,),
            {
                "state_schema": state_schema or AgentState,
                "tools": tools or [],
                "amodify_model_request": wrapped,
            },
        )()

    if func is not None:
        return decorator(func)
    return decorator

@overload
def aafter_model(
    func: _AsyncNodeSignature[StateT, ContextT],
) -> AgentMiddleware[StateT, ContextT]: ...


@overload
def aafter_model(
    func: None = None,
    *,
    state_schema: type[StateT] | None = None,
    tools: list[BaseTool] | None = None,
    jump_to: list[JumpTo] | None = None,
    name: str | None = None,
) -> Callable[[_AsyncNodeSignature[StateT, ContextT]], AgentMiddleware[StateT, ContextT]]: ...


def aafter_model(
    func: _AsyncNodeSignature[StateT, ContextT] | None = None,
    *,
    state_schema: type[StateT] | None = None,
    tools: list[BaseTool] | None = None,
    jump_to: list[JumpTo] | None = None,
    name: str | None = None,
) -> (
    Callable[[_AsyncNodeSignature[StateT, ContextT]], AgentMiddleware[StateT, ContextT]]
    | AgentMiddleware[StateT, ContextT]
):
    """Decorator used to dynamically create a middleware with the async after_model hook.

    Args:
        func: The async function to be decorated. Can accept either:
            - `state: StateT` - Just the agent state (includes model response)
            - `state: StateT, runtime: Runtime[ContextT]` - State and runtime context
        state_schema: Optional custom state schema type. If not provided, uses the default
            AgentState schema.
        tools: Optional list of additional tools to register with this middleware.
        jump_to: Optional list of valid jump destinations for conditional edges.
            Valid values are: "tools", "model", "end"
        name: Optional name for the generated middleware class. If not provided,
            uses the decorated function's name.

    Returns:
        Either an AgentMiddleware instance (if func is provided) or a decorator function
        that can be applied to a function.

    The decorated function should return:
        - `dict[str, Any]` - State updates to merge into the agent state
        - `Command` - A command to control flow (e.g., jump to different node)
        - `None` - No state updates or flow control

    Examples:
        Basic usage for async logging model responses:
        ```python
        @aafter_model
        async def log_latest_message(state: AgentState) -> None:
            await async_log(state["messages"][-1].content)
        ```

        With custom state schema:
        ```python
        @aafter_model(state_schema=MyCustomState, name="MyAsyncAfterModelMiddleware")
        async def custom_after_model(state: MyCustomState) -> dict[str, Any]:
            result = await some_async_operation(state)
            return {"custom_field": result}
        ```
    """

    def decorator(func: _AsyncNodeSignature[StateT, ContextT]) -> AgentMiddleware[StateT, ContextT]:
        if is_callable_with_runtime(func):

            async def wrapped_with_runtime(
                self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
                state: StateT,
                runtime: Runtime[ContextT],
            ) -> NodeReturn:
                result = func(state, runtime)  # type: ignore[call-arg]
                return await result  # type: ignore[misc]

            wrapped = wrapped_with_runtime
        else:

            async def wrapped_without_runtime(
                self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
                state: StateT,
            ) -> NodeReturn:
                result = func(state)  # type: ignore[call-arg]
                return await result  # type: ignore[misc]

            wrapped = wrapped_without_runtime  # type: ignore[assignment]

        # Use function name as default if no name provided
        middleware_name = name or cast(
            "str", getattr(func, "__name__", "AsyncAfterModelMiddleware")
        )

        return type(
            middleware_name,
            (AgentMiddleware,),
            {
                "state_schema": state_schema or AgentState,
                "tools": tools or [],
                "after_model_jump_to": jump_to or [],
                "aafter_model": wrapped,
            },
        )()

    if func is not None:
        return decorator(func)
    return decorator

@@ -212,15 +212,28 @@ def create_agent(  # noqa: PLR0915
            "Please remove duplicate middleware instances."
        )
    middleware_w_before = [
        m for m in middleware if m.__class__.before_model is not AgentMiddleware.before_model
        m
        for m in middleware
        if (
            m.__class__.before_model is not AgentMiddleware.before_model
            or m.__class__.abefore_model is not AgentMiddleware.abefore_model
        )
    ]
    middleware_w_modify_model_request = [
        m
        for m in middleware
        if m.__class__.modify_model_request is not AgentMiddleware.modify_model_request
        if (
            m.__class__.modify_model_request is not AgentMiddleware.modify_model_request
            or m.__class__.amodify_model_request is not AgentMiddleware.amodify_model_request
        )
    ]
    middleware_w_after = [
        m for m in middleware if m.__class__.after_model is not AgentMiddleware.after_model
        m
        for m in middleware
        if (
            m.__class__.after_model is not AgentMiddleware.after_model
            or m.__class__.aafter_model is not AgentMiddleware.aafter_model
        )
    ]

    state_schemas = {m.state_schema for m in middleware}
@@ -346,16 +359,32 @@ def create_agent(  # noqa: PLR0915
        )
        return request.model.bind(**request.model_settings)

    model_request_signatures: list[
        tuple[bool, AgentMiddleware[AgentState[ResponseT], ContextT]]
    ] = [
        ("runtime" in signature(m.modify_model_request).parameters, m)
        for m in middleware_w_modify_model_request
    ]
    # Helper functions for middleware processing
    def _has_sync_hook(
        middleware: AgentMiddleware[AgentState[ResponseT], ContextT], hook_name: str
    ) -> bool:
        """Check if middleware has a sync implementation of the given hook."""
        return getattr(middleware.__class__, hook_name) is not getattr(AgentMiddleware, hook_name)

    def model_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
        """Sync model request handler with sequential middleware processing."""
        request = ModelRequest(
    def _has_async_hook(
        middleware: AgentMiddleware[AgentState[ResponseT], ContextT], hook_name: str
    ) -> bool:
        """Check if middleware has an async implementation of the given hook."""
        async_hook_name = f"a{hook_name}"
        return getattr(middleware.__class__, async_hook_name) is not getattr(
            AgentMiddleware, async_hook_name
        )

    def _uses_runtime(
        middleware: AgentMiddleware[AgentState[ResponseT], ContextT], hook_name: str
    ) -> bool:
        """Check if the hook uses runtime parameter."""
        hook_method = getattr(middleware, hook_name)
        return "runtime" in signature(hook_method).parameters

    def _create_base_model_request(state: AgentState) -> ModelRequest:
        """Create the base model request from state."""
        return ModelRequest(
            model=model,
            tools=default_tools,
            system_prompt=system_prompt,
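The helpers above decide whether a hook is actually implemented by comparing the attribute on the middleware's class with the attribute on `AgentMiddleware` itself. A small standalone sketch (not part of the diff) of that identity check, using stand-in classes:

```python
class Base:
    """Stand-in for AgentMiddleware: the hook is a no-op default."""

    def before_model(self, state):
        return None


class Custom(Base):
    """Stand-in for a user middleware that overrides the hook."""

    def before_model(self, state):
        return {"note": "overridden"}


class Inherits(Base):
    """Stand-in for a middleware that does not override the hook."""


# Overridden hooks are distinct function objects from the base-class default,
# so they are detected; inherited defaults compare identical and are skipped.
assert Custom.before_model is not Base.before_model
assert Inherits.before_model is Base.before_model
```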
@@ -364,47 +393,71 @@ def create_agent(  # noqa: PLR0915
            tool_choice=None,
        )

        # Apply modify_model_request middleware in sequence
        for use_runtime, m in model_request_signatures:
            if use_runtime:
                m.modify_model_request(request, state, runtime)
            else:
                m.modify_model_request(request, state)  # type: ignore[call-arg]
    def _apply_sync_modify_middleware(
        request: ModelRequest, state: AgentState, runtime: Runtime[ContextT]
    ) -> ModelRequest:
        """Apply sync modify_model_request middleware in sequence."""
        for m in middleware_w_modify_model_request:
            if not _has_sync_hook(m, "modify_model_request"):
                msg = (
                    f"Middleware {m.__class__.__name__} only has async modify_model_request hook. "
                    "Cannot use in sync context. Either provide a sync version "
                    "or use async execution."
                )
                raise ValueError(msg)

        # Get the final model and messages
            if _uses_runtime(m, "modify_model_request"):
                request = m.modify_model_request(request, state, runtime)
            else:
                request = m.modify_model_request(request, state)  # type: ignore[call-arg]
        return request

    async def _apply_async_modify_middleware(
        request: ModelRequest, state: AgentState, runtime: Runtime[ContextT]
    ) -> ModelRequest:
        """Apply async modify_model_request middleware in sequence."""
        for m in middleware_w_modify_model_request:
            # Prefer async version if available
            if _has_async_hook(m, "modify_model_request"):
                if _uses_runtime(m, "amodify_model_request"):
                    request = await m.amodify_model_request(request, state, runtime)
                else:
                    request = await m.amodify_model_request(request, state)  # type: ignore[call-arg]
            elif _has_sync_hook(m, "modify_model_request"):
                # Fall back to sync version
                if _uses_runtime(m, "modify_model_request"):
                    request = m.modify_model_request(request, state, runtime)
                else:
                    request = m.modify_model_request(request, state)  # type: ignore[call-arg]
            else:
                msg = (
                    f"Middleware {m.__class__.__name__} has no modify_model_request "
                    "implementation. This should not happen - middleware validation failed."
                )
                raise ValueError(msg)
        return request

    def _finalize_model_request(request: ModelRequest) -> tuple[Runnable, list[AnyMessage]]:
        """Finalize the model request and return the bound model and messages."""
        model_ = _get_bound_model(request)
        messages = request.messages
        if request.system_prompt:
            messages = [SystemMessage(request.system_prompt), *messages]
        return model_, messages

    def model_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
        """Sync model request handler with sequential middleware processing."""
        request = _create_base_model_request(state)
        request = _apply_sync_modify_middleware(request, state, runtime)
        model_, messages = _finalize_model_request(request)
        output = model_.invoke(messages)
        return _handle_model_output(output)

    async def amodel_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
        """Async model request handler with sequential middleware processing."""
        # Start with the base model request
        request = ModelRequest(
            model=model,
            tools=default_tools,
            system_prompt=system_prompt,
            response_format=response_format,
            messages=state["messages"],
            tool_choice=None,
        )

        # Apply modify_model_request middleware in sequence
        for use_runtime, m in model_request_signatures:
            if use_runtime:
                m.modify_model_request(request, state, runtime)
            else:
                m.modify_model_request(request, state)  # type: ignore[call-arg]

        # Get the final model and messages
        model_ = _get_bound_model(request)
        messages = request.messages
        if request.system_prompt:
            messages = [SystemMessage(request.system_prompt), *messages]

        request = _create_base_model_request(state)
        request = await _apply_async_modify_middleware(request, state, runtime)
        model_, messages = _finalize_model_request(request)
        output = await model_.ainvoke(messages)
        return _handle_model_output(output)
@@ -417,17 +470,36 @@ def create_agent(  # noqa: PLR0915
    if tool_node is not None:
        graph.add_node("tools", tool_node)

    def _add_middleware_node(
        middleware: AgentMiddleware[AgentState[ResponseT], ContextT], hook_name: str, node_name: str
    ) -> None:
        """Add a middleware node to the graph with proper sync/async handling."""
        has_sync = _has_sync_hook(middleware, hook_name)
        has_async = _has_async_hook(middleware, hook_name)

        if has_sync and has_async:
            # Both sync and async available - use RunnableCallable
            sync_method = getattr(middleware, hook_name)
            async_method = getattr(middleware, f"a{hook_name}")
            graph.add_node(
                node_name, RunnableCallable(sync_method, async_method), input_schema=state_schema
            )
        elif has_async:
            # Only async available
            async_method = getattr(middleware, f"a{hook_name}")
            graph.add_node(node_name, async_method, input_schema=state_schema)
        elif has_sync:
            # Only sync available
            sync_method = getattr(middleware, hook_name)
            graph.add_node(node_name, sync_method, input_schema=state_schema)

    # Add middleware nodes
    for m in middleware:
        if m.__class__.before_model is not AgentMiddleware.before_model:
            graph.add_node(
                f"{m.__class__.__name__}.before_model", m.before_model, input_schema=state_schema
            )
        if _has_sync_hook(m, "before_model") or _has_async_hook(m, "before_model"):
            _add_middleware_node(m, "before_model", f"{m.__class__.__name__}.before_model")

        if m.__class__.after_model is not AgentMiddleware.after_model:
            graph.add_node(
                f"{m.__class__.__name__}.after_model", m.after_model, input_schema=state_schema
            )
        if _has_sync_hook(m, "after_model") or _has_async_hook(m, "after_model"):
            _add_middleware_node(m, "after_model", f"{m.__class__.__name__}.after_model")

    # add start edge
    first_node = (
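For intuition on the `RunnableCallable(sync_method, async_method)` pairing used above, a minimal sketch (not part of the diff) of a node that exposes both code paths. The import path is an assumption based on other langgraph components; treat it as illustrative rather than canonical:

```python
from langgraph.utils.runnable import RunnableCallable  # assumed import path


def sync_hook(state):
    # Used when the graph is driven synchronously (invoke/stream).
    return {"path": "sync"}


async def async_hook(state):
    # Used when the graph is driven asynchronously (ainvoke/astream).
    return {"path": "async"}


node = RunnableCallable(sync_hook, async_hook)
# node.invoke({...}) dispatches to sync_hook, while `await node.ainvoke({...})`
# awaits async_hook, which is what lets one middleware node serve both contexts.
```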
@@ -0,0 +1,366 @@
"""Tests for async middleware decorators: abefore_model, aafter_model, and amodify_model_request."""

import asyncio
from typing import Any
from typing_extensions import NotRequired

from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.tools import tool
from langgraph.types import Command

from langchain.agents.middleware.types import (
    AgentMiddleware,
    AgentState,
    ModelRequest,
    abefore_model,
    aafter_model,
    amodify_model_request,
)
from langchain.agents.middleware_agent import create_agent
from .model import FakeToolCallingModel


class CustomState(AgentState):
    """Custom state schema for testing."""

    custom_field: NotRequired[str]


@tool
def test_tool(input: str) -> str:
    """A test tool for middleware testing."""
    return f"Tool result: {input}"

def test_abefore_model_decorator() -> None:
    """Test abefore_model decorator with all configuration options."""

    @abefore_model(
        state_schema=CustomState, tools=[test_tool], jump_to=["end"], name="CustomAsyncBeforeModel"
    )
    async def custom_async_before_model(state: CustomState) -> dict[str, Any]:
        await asyncio.sleep(0.001)  # Simulate async work
        return {"jump_to": "end"}

    assert isinstance(custom_async_before_model, AgentMiddleware)
    assert custom_async_before_model.state_schema == CustomState
    assert custom_async_before_model.tools == [test_tool]
    assert custom_async_before_model.before_model_jump_to == ["end"]
    assert custom_async_before_model.__class__.__name__ == "CustomAsyncBeforeModel"

    # Test that the async method was set
    assert hasattr(custom_async_before_model, "abefore_model")
    result = asyncio.run(
        custom_async_before_model.abefore_model({"messages": [HumanMessage("Hello")]})
    )
    assert result == {"jump_to": "end"}

def test_aafter_model_decorator() -> None:
    """Test aafter_model decorator with all configuration options."""

    @aafter_model(
        state_schema=CustomState,
        tools=[test_tool],
        jump_to=["model", "end"],
        name="CustomAsyncAfterModel",
    )
    async def custom_async_after_model(state: CustomState) -> dict[str, Any]:
        await asyncio.sleep(0.001)  # Simulate async work
        return {"jump_to": "model"}

    # Verify all options were applied
    assert isinstance(custom_async_after_model, AgentMiddleware)
    assert custom_async_after_model.state_schema == CustomState
    assert custom_async_after_model.tools == [test_tool]
    assert custom_async_after_model.after_model_jump_to == ["model", "end"]
    assert custom_async_after_model.__class__.__name__ == "CustomAsyncAfterModel"

    # Verify it works
    result = asyncio.run(
        custom_async_after_model.aafter_model(
            {"messages": [HumanMessage("Hello"), AIMessage("Hi!")]}
        )
    )
    assert result == {"jump_to": "model"}

def test_amodify_model_request_decorator() -> None:
    """Test amodify_model_request decorator with all configuration options."""

    @amodify_model_request(
        state_schema=CustomState, tools=[test_tool], name="CustomAsyncModifyRequest"
    )
    async def custom_async_modify_request(
        request: ModelRequest, state: CustomState
    ) -> ModelRequest:
        await asyncio.sleep(0.001)  # Simulate async work
        request.system_prompt = "Async Modified"
        return request

    # Verify all options were applied
    assert isinstance(custom_async_modify_request, AgentMiddleware)
    assert custom_async_modify_request.state_schema == CustomState
    assert custom_async_modify_request.tools == [test_tool]
    assert custom_async_modify_request.__class__.__name__ == "CustomAsyncModifyRequest"

    # Verify it works
    original_request = ModelRequest(
        model="test-model",
        system_prompt="Original",
        messages=[HumanMessage("Hello")],
        tool_choice=None,
        tools=[],
        response_format=None,
    )
    result = asyncio.run(
        custom_async_modify_request.amodify_model_request(
            original_request, {"messages": [HumanMessage("Hello")]}
        )
    )
    assert result.system_prompt == "Async Modified"

def test_all_async_decorators_integration() -> None:
    """Test all three async decorators working together in an agent."""
    call_order = []

    @abefore_model
    async def track_async_before(state: AgentState) -> None:
        await asyncio.sleep(0.001)  # Simulate async work
        call_order.append("async_before")
        return None

    @amodify_model_request
    async def track_async_modify(request: ModelRequest, state: AgentState) -> ModelRequest:
        await asyncio.sleep(0.001)  # Simulate async work
        call_order.append("async_modify")
        return request

    @aafter_model
    async def track_async_after(state: AgentState) -> None:
        await asyncio.sleep(0.001)  # Simulate async work
        call_order.append("async_after")
        return None

    agent = create_agent(
        model=FakeToolCallingModel(),
        middleware=[track_async_before, track_async_modify, track_async_after],
    )
    agent = agent.compile()

    async def run_test():
        result = await agent.ainvoke({"messages": [HumanMessage("Hello")]})
        return result

    asyncio.run(run_test())
    assert call_order == ["async_before", "async_modify", "async_after"]

def test_mixed_sync_async_middleware() -> None:
    """Test mixing sync and async middleware in the same agent."""
    call_order = []

    # Import sync decorators
    from langchain.agents.middleware.types import before_model, modify_model_request, after_model

    @before_model
    def sync_before(state: AgentState) -> None:
        call_order.append("sync_before")
        return None

    @abefore_model
    async def async_before(state: AgentState) -> None:
        await asyncio.sleep(0.001)  # Simulate async work
        call_order.append("async_before")
        return None

    @modify_model_request
    def sync_modify(request: ModelRequest, state: AgentState) -> ModelRequest:
        call_order.append("sync_modify")
        return request

    @amodify_model_request
    async def async_modify(request: ModelRequest, state: AgentState) -> ModelRequest:
        await asyncio.sleep(0.001)  # Simulate async work
        call_order.append("async_modify")
        return request

    @after_model
    def sync_after(state: AgentState) -> None:
        call_order.append("sync_after")
        return None

    @aafter_model
    async def async_after(state: AgentState) -> None:
        await asyncio.sleep(0.001)  # Simulate async work
        call_order.append("async_after")
        return None

    agent = create_agent(
        model=FakeToolCallingModel(),
        middleware=[sync_before, async_before, sync_modify, async_modify, sync_after, async_after],
    )
    agent = agent.compile()

    async def run_test():
        result = await agent.ainvoke({"messages": [HumanMessage("Hello")]})
        return result

    asyncio.run(run_test())

    # Both sync and async middleware should have run.
    # Order: before middleware (sync_before, async_before), then modify middleware
    # (sync_modify, async_modify), then after middleware (async_after, sync_after) -
    # note that after_model hooks run in reverse registration order.
    expected_calls = [
        "sync_before",
        "async_before",
        "sync_modify",
        "async_modify",
        "async_after",
        "sync_after",
    ]
    assert call_order == expected_calls

def test_async_decorators_use_function_names_as_default() -> None:
    """Test that async decorators use function names as default middleware names."""

    @abefore_model
    async def my_async_before_hook(state: AgentState) -> None:
        return None

    @amodify_model_request
    async def my_async_modify_hook(request: ModelRequest, state: AgentState) -> ModelRequest:
        return request

    @aafter_model
    async def my_async_after_hook(state: AgentState) -> None:
        return None

    # Verify that function names are used as middleware class names
    assert my_async_before_hook.__class__.__name__ == "my_async_before_hook"
    assert my_async_modify_hook.__class__.__name__ == "my_async_modify_hook"
    assert my_async_after_hook.__class__.__name__ == "my_async_after_hook"

def test_async_with_runtime_context() -> None:
    """Test async decorators that use runtime context."""

    @abefore_model
    async def async_before_with_runtime(state: AgentState, runtime) -> dict[str, Any]:
        await asyncio.sleep(0.001)  # Simulate async work
        # Use runtime context in some way
        context_info = getattr(runtime, "context", {})
        return {"custom_field": f"processed_with_runtime_{len(context_info)}"}

    @amodify_model_request
    async def async_modify_with_runtime(
        request: ModelRequest, state: AgentState, runtime
    ) -> ModelRequest:
        await asyncio.sleep(0.001)  # Simulate async work
        # Modify request based on runtime context
        request.system_prompt = f"Runtime context available: {runtime is not None}"
        return request

    @aafter_model
    async def async_after_with_runtime(state: AgentState, runtime) -> None:
        await asyncio.sleep(0.001)  # Simulate async work
        # Process state with runtime
        return None

    # Test that these can be instantiated (runtime context validation will happen at execution)
    assert isinstance(async_before_with_runtime, AgentMiddleware)
    assert isinstance(async_modify_with_runtime, AgentMiddleware)
    assert isinstance(async_after_with_runtime, AgentMiddleware)

def test_sync_execution_with_async_only_middleware_error() -> None:
    """Test that sync execution properly errors when encountering async-only middleware."""
    import pytest

    @amodify_model_request
    async def async_only_modify(request: ModelRequest, state: AgentState) -> ModelRequest:
        await asyncio.sleep(0.001)  # Simulate async work
        request.system_prompt = "Modified by async-only middleware"
        return request

    # Create an agent with async-only middleware
    agent = create_agent(model=FakeToolCallingModel(), middleware=[async_only_modify])
    agent = agent.compile()

    # Sync execution should raise an error
    with pytest.raises(ValueError, match="only has async modify_model_request hook"):
        agent.invoke({"messages": [HumanMessage("Hello")]})

    # Async execution should work fine
    async def run_async_test():
        result = await agent.ainvoke({"messages": [HumanMessage("Hello")]})
        return result

    # This should not raise an error
    result = asyncio.run(run_async_test())
    assert result is not None

def test_mixed_middleware_execution_contexts() -> None:
    """Test that mixed sync/async middleware works in both execution contexts."""
    call_order = []

    # Import sync decorators
    from langchain.agents.middleware.types import modify_model_request

    @modify_model_request
    def sync_only_modify(request: ModelRequest, state: AgentState) -> ModelRequest:
        call_order.append("sync_only_modify")
        return request

    @amodify_model_request
    async def async_only_modify(request: ModelRequest, state: AgentState) -> ModelRequest:
        await asyncio.sleep(0.001)
        call_order.append("async_only_modify")
        return request

    # Create agent with both sync-only and async-only middleware
    agent = create_agent(
        model=FakeToolCallingModel(), middleware=[sync_only_modify, async_only_modify]
    )
    agent = agent.compile()

    # Test async execution - should work and call both middleware
    async def run_async_test():
        call_order.clear()
        result = await agent.ainvoke({"messages": [HumanMessage("Hello")]})
        return result

    asyncio.run(run_async_test())
    assert "sync_only_modify" in call_order
    assert "async_only_modify" in call_order

    # Test sync execution - should fail due to async-only middleware
    import pytest

    call_order.clear()
    with pytest.raises(ValueError, match="only has async modify_model_request hook"):
        agent.invoke({"messages": [HumanMessage("Hello")]})

def test_error_handling_for_missing_implementations() -> None:
    """Test error handling when middleware validation fails."""
    import pytest

    # Create a custom middleware with no implementations
    class BrokenMiddleware(AgentMiddleware):
        def __init__(self):
            self.tools = []

    broken_middleware = BrokenMiddleware()

    # This should not cause immediate error during agent creation
    # The validation logic should properly filter out middleware without implementations
    agent = create_agent(model=FakeToolCallingModel(), middleware=[broken_middleware])
    agent = agent.compile()

    # Execution should work normally since the broken middleware has no hooks
    result = agent.invoke({"messages": [HumanMessage("Hello")]})
    assert result is not None