mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-09 02:33:34 +00:00
Compare commits
6 Commits
langchain-
...
sr/refacto
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5b97dada4d | ||
|
|
b276d34a0b | ||
|
|
5441475ac9 | ||
|
|
22fb405b45 | ||
|
|
6e07fc7982 | ||
|
|
983d84ade8 |
@@ -12,7 +12,9 @@ from .types import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
ModelRequest,
|
||||
after_agent,
|
||||
after_model,
|
||||
before_agent,
|
||||
before_model,
|
||||
dynamic_prompt,
|
||||
hook_config,
|
||||
@@ -33,7 +35,9 @@ __all__ = [
|
||||
"PlanningMiddleware",
|
||||
"SummarizationMiddleware",
|
||||
"ToolCallLimitMiddleware",
|
||||
"after_agent",
|
||||
"after_model",
|
||||
"before_agent",
|
||||
"before_model",
|
||||
"dynamic_prompt",
|
||||
"hook_config",
|
||||
|
||||
@@ -4,6 +4,7 @@ from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from inspect import iscoroutinefunction
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
@@ -41,11 +42,18 @@ __all__ = [
|
||||
"AgentMiddleware",
|
||||
"AgentState",
|
||||
"ContextT",
|
||||
"HookImplementation",
|
||||
"MiddlewareHookInfo",
|
||||
"ModelRequest",
|
||||
"OmitFromSchema",
|
||||
"PublicAgentState",
|
||||
"after_agent",
|
||||
"after_model",
|
||||
"before_agent",
|
||||
"before_model",
|
||||
"dynamic_prompt",
|
||||
"hook_config",
|
||||
"modify_model_request",
|
||||
]
|
||||
|
||||
JumpTo = Literal["tools", "model", "end"]
|
||||
@@ -54,6 +62,15 @@ JumpTo = Literal["tools", "model", "end"]
|
||||
ResponseT = TypeVar("ResponseT")
|
||||
|
||||
|
||||
class HookImplementation(str, Enum):
|
||||
"""Tracks which implementation variants exist for a middleware hook."""
|
||||
|
||||
NONE = "none"
|
||||
SYNC_ONLY = "sync"
|
||||
ASYNC_ONLY = "async"
|
||||
BOTH = "both"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelRequest:
|
||||
"""Model request information for the agent."""
|
||||
@@ -93,7 +110,7 @@ class AgentState(TypedDict, Generic[ResponseT]):
|
||||
|
||||
messages: Required[Annotated[list[AnyMessage], add_messages]]
|
||||
jump_to: NotRequired[Annotated[JumpTo | None, EphemeralValue, PrivateStateAttr]]
|
||||
structured_response: NotRequired[ResponseT]
|
||||
structured_response: NotRequired[Annotated[ResponseT, OmitFromInput]]
|
||||
thread_model_call_count: NotRequired[Annotated[int, PrivateStateAttr]]
|
||||
run_model_call_count: NotRequired[Annotated[int, UntrackedValue, PrivateStateAttr]]
|
||||
|
||||
@@ -112,6 +129,56 @@ StateT = TypeVar("StateT", bound=AgentState, default=AgentState)
|
||||
StateT_contra = TypeVar("StateT_contra", bound=AgentState, contravariant=True)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MiddlewareHookInfo:
|
||||
"""Information about a specific middleware hook implementation.
|
||||
|
||||
This class encapsulates metadata about how a middleware implements a particular hook,
|
||||
including the actual hook functions and jump configuration.
|
||||
"""
|
||||
|
||||
middleware_name: str
|
||||
"""The name of the middleware that implements this hook."""
|
||||
|
||||
hook_name: str
|
||||
"""The name of the hook (e.g., 'before_model', 'after_agent')."""
|
||||
|
||||
sync_fn: Callable[..., Any] | None
|
||||
"""The synchronous hook function, or None if not implemented."""
|
||||
|
||||
async_fn: Callable[..., Any] | None
|
||||
"""The asynchronous hook function, or None if not implemented."""
|
||||
|
||||
can_jump_to: list[JumpTo]
|
||||
"""Valid jump destinations for this hook."""
|
||||
|
||||
@property
|
||||
def node_name(self) -> str:
|
||||
"""The graph node name for this hook."""
|
||||
return f"{self.middleware_name}.{self.hook_name}"
|
||||
|
||||
@property
|
||||
def has_sync(self) -> bool:
|
||||
"""Whether this hook has a sync implementation."""
|
||||
return self.sync_fn is not None
|
||||
|
||||
@property
|
||||
def has_async(self) -> bool:
|
||||
"""Whether this hook has an async implementation."""
|
||||
return self.async_fn is not None
|
||||
|
||||
@property
|
||||
def implementation(self) -> HookImplementation:
|
||||
"""Which variants (sync/async/both) are implemented."""
|
||||
if self.has_sync and self.has_async:
|
||||
return HookImplementation.BOTH
|
||||
if self.has_sync:
|
||||
return HookImplementation.SYNC_ONLY
|
||||
if self.has_async:
|
||||
return HookImplementation.ASYNC_ONLY
|
||||
return HookImplementation.NONE
|
||||
|
||||
|
||||
class AgentMiddleware(Generic[StateT, ContextT]):
|
||||
"""Base middleware class for an agent.
|
||||
|
||||
@@ -133,6 +200,14 @@ class AgentMiddleware(Generic[StateT, ContextT]):
|
||||
"""
|
||||
return self.__class__.__name__
|
||||
|
||||
def before_agent(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
|
||||
"""Logic to run before the agent execution starts."""
|
||||
|
||||
async def abefore_agent(
|
||||
self, state: StateT, runtime: Runtime[ContextT]
|
||||
) -> dict[str, Any] | None:
|
||||
"""Async logic to run before the agent execution starts."""
|
||||
|
||||
def before_model(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
|
||||
"""Logic to run before the model is called."""
|
||||
|
||||
@@ -215,6 +290,102 @@ class AgentMiddleware(Generic[StateT, ContextT]):
|
||||
None, self.retry_model_request, error, request, state, runtime, attempt
|
||||
)
|
||||
|
||||
def after_agent(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
|
||||
"""Logic to run after the agent execution completes."""
|
||||
|
||||
async def aafter_agent(
|
||||
self, state: StateT, runtime: Runtime[ContextT]
|
||||
) -> dict[str, Any] | None:
|
||||
"""Async logic to run after the agent execution completes."""
|
||||
|
||||
def hook_info(self, hook_name: str) -> MiddlewareHookInfo | None:
|
||||
"""Get information about this middleware's implementation of a specific hook.
|
||||
|
||||
Args:
|
||||
hook_name: The name of the hook to inspect (e.g., 'before_model', 'after_agent').
|
||||
|
||||
Returns:
|
||||
MiddlewareHookInfo if the hook is implemented, None otherwise.
|
||||
|
||||
Example:
|
||||
>>> middleware = MyMiddleware()
|
||||
>>> info = middleware.hook_info("before_model")
|
||||
>>> if info:
|
||||
... print(f"Has sync: {info.has_sync}, Has async: {info.has_async}")
|
||||
"""
|
||||
base_class = AgentMiddleware
|
||||
middleware_class = self.__class__
|
||||
|
||||
# Check sync and async variants
|
||||
sync_name = hook_name
|
||||
async_name = f"a{hook_name}"
|
||||
|
||||
base_sync_method = getattr(base_class, sync_name, None)
|
||||
base_async_method = getattr(base_class, async_name, None)
|
||||
|
||||
middleware_sync_method = getattr(middleware_class, sync_name, None)
|
||||
middleware_async_method = getattr(middleware_class, async_name, None)
|
||||
|
||||
has_custom_sync = middleware_sync_method is not base_sync_method
|
||||
has_custom_async = middleware_async_method is not base_async_method
|
||||
|
||||
if not has_custom_sync and not has_custom_async:
|
||||
return None
|
||||
|
||||
# Get the actual bound methods - only include customized implementations
|
||||
sync_fn = getattr(self, sync_name) if has_custom_sync else None
|
||||
async_fn = getattr(self, async_name) if has_custom_async else None
|
||||
|
||||
# Get can_jump_to from either sync or async variant
|
||||
can_jump_to: list[JumpTo] = []
|
||||
if has_custom_sync:
|
||||
can_jump_to = getattr(middleware_sync_method, "__can_jump_to__", [])
|
||||
elif has_custom_async:
|
||||
can_jump_to = getattr(middleware_async_method, "__can_jump_to__", [])
|
||||
|
||||
return MiddlewareHookInfo(
|
||||
middleware_name=self.name,
|
||||
hook_name=hook_name,
|
||||
sync_fn=sync_fn,
|
||||
async_fn=async_fn,
|
||||
can_jump_to=can_jump_to,
|
||||
)
|
||||
|
||||
def all_hook_info(self) -> dict[str, MiddlewareHookInfo]:
|
||||
"""Get information about all hooks implemented by this middleware.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping hook names to their MiddlewareHookInfo.
|
||||
|
||||
Example:
|
||||
>>> middleware = MyMiddleware()
|
||||
>>> for hook_name, info in middleware.all_hook_info().items():
|
||||
... print(f"{hook_name}: sync={info.has_sync}, async={info.has_async}")
|
||||
"""
|
||||
hook_names = [
|
||||
"before_agent",
|
||||
"before_model",
|
||||
"modify_model_request",
|
||||
"after_model",
|
||||
"after_agent",
|
||||
"retry_model_request",
|
||||
]
|
||||
return {name: info for name in hook_names if (info := self.hook_info(name)) is not None}
|
||||
|
||||
@property
|
||||
def implemented_hooks(self) -> list[str]:
|
||||
"""List of hook names this middleware implements.
|
||||
|
||||
Returns:
|
||||
List of hook names that are overridden from the base class.
|
||||
|
||||
Example:
|
||||
>>> middleware = MyMiddleware()
|
||||
>>> print(middleware.implemented_hooks)
|
||||
['before_model', 'after_model']
|
||||
"""
|
||||
return list(self.all_hook_info().keys())
|
||||
|
||||
|
||||
class _CallableWithStateAndRuntime(Protocol[StateT_contra, ContextT]):
|
||||
"""Callable with AgentState and Runtime as arguments."""
|
||||
@@ -707,6 +878,279 @@ def after_model(
|
||||
return decorator
|
||||
|
||||
|
||||
@overload
|
||||
def before_agent(
|
||||
func: _CallableWithStateAndRuntime[StateT, ContextT],
|
||||
) -> AgentMiddleware[StateT, ContextT]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def before_agent(
|
||||
func: None = None,
|
||||
*,
|
||||
state_schema: type[StateT] | None = None,
|
||||
tools: list[BaseTool] | None = None,
|
||||
can_jump_to: list[JumpTo] | None = None,
|
||||
name: str | None = None,
|
||||
) -> Callable[
|
||||
[_CallableWithStateAndRuntime[StateT, ContextT]], AgentMiddleware[StateT, ContextT]
|
||||
]: ...
|
||||
|
||||
|
||||
def before_agent(
|
||||
func: _CallableWithStateAndRuntime[StateT, ContextT] | None = None,
|
||||
*,
|
||||
state_schema: type[StateT] | None = None,
|
||||
tools: list[BaseTool] | None = None,
|
||||
can_jump_to: list[JumpTo] | None = None,
|
||||
name: str | None = None,
|
||||
) -> (
|
||||
Callable[[_CallableWithStateAndRuntime[StateT, ContextT]], AgentMiddleware[StateT, ContextT]]
|
||||
| AgentMiddleware[StateT, ContextT]
|
||||
):
|
||||
"""Decorator used to dynamically create a middleware with the before_agent hook.
|
||||
|
||||
Args:
|
||||
func: The function to be decorated. Must accept:
|
||||
`state: StateT, runtime: Runtime[ContextT]` - State and runtime context
|
||||
state_schema: Optional custom state schema type. If not provided, uses the default
|
||||
AgentState schema.
|
||||
tools: Optional list of additional tools to register with this middleware.
|
||||
can_jump_to: Optional list of valid jump destinations for conditional edges.
|
||||
Valid values are: "tools", "model", "end"
|
||||
name: Optional name for the generated middleware class. If not provided,
|
||||
uses the decorated function's name.
|
||||
|
||||
Returns:
|
||||
Either an AgentMiddleware instance (if func is provided directly) or a decorator function
|
||||
that can be applied to a function its wrapping.
|
||||
|
||||
The decorated function should return:
|
||||
- `dict[str, Any]` - State updates to merge into the agent state
|
||||
- `Command` - A command to control flow (e.g., jump to different node)
|
||||
- `None` - No state updates or flow control
|
||||
|
||||
Examples:
|
||||
Basic usage:
|
||||
```python
|
||||
@before_agent
|
||||
def log_before_agent(state: AgentState, runtime: Runtime) -> None:
|
||||
print(f"Starting agent with {len(state['messages'])} messages")
|
||||
```
|
||||
|
||||
With conditional jumping:
|
||||
```python
|
||||
@before_agent(can_jump_to=["end"])
|
||||
def conditional_before_agent(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
|
||||
if some_condition(state):
|
||||
return {"jump_to": "end"}
|
||||
return None
|
||||
```
|
||||
|
||||
With custom state schema:
|
||||
```python
|
||||
@before_agent(state_schema=MyCustomState)
|
||||
def custom_before_agent(state: MyCustomState, runtime: Runtime) -> dict[str, Any]:
|
||||
return {"custom_field": "initialized_value"}
|
||||
```
|
||||
"""
|
||||
|
||||
def decorator(
|
||||
func: _CallableWithStateAndRuntime[StateT, ContextT],
|
||||
) -> AgentMiddleware[StateT, ContextT]:
|
||||
is_async = iscoroutinefunction(func)
|
||||
|
||||
func_can_jump_to = (
|
||||
can_jump_to if can_jump_to is not None else getattr(func, "__can_jump_to__", [])
|
||||
)
|
||||
|
||||
if is_async:
|
||||
|
||||
async def async_wrapped(
|
||||
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
|
||||
state: StateT,
|
||||
runtime: Runtime[ContextT],
|
||||
) -> dict[str, Any] | Command | None:
|
||||
return await func(state, runtime) # type: ignore[misc]
|
||||
|
||||
# Preserve can_jump_to metadata on the wrapped function
|
||||
if func_can_jump_to:
|
||||
async_wrapped.__can_jump_to__ = func_can_jump_to # type: ignore[attr-defined]
|
||||
|
||||
middleware_name = name or cast(
|
||||
"str", getattr(func, "__name__", "BeforeAgentMiddleware")
|
||||
)
|
||||
|
||||
return type(
|
||||
middleware_name,
|
||||
(AgentMiddleware,),
|
||||
{
|
||||
"state_schema": state_schema or AgentState,
|
||||
"tools": tools or [],
|
||||
"abefore_agent": async_wrapped,
|
||||
},
|
||||
)()
|
||||
|
||||
def wrapped(
|
||||
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
|
||||
state: StateT,
|
||||
runtime: Runtime[ContextT],
|
||||
) -> dict[str, Any] | Command | None:
|
||||
return func(state, runtime) # type: ignore[return-value]
|
||||
|
||||
# Preserve can_jump_to metadata on the wrapped function
|
||||
if func_can_jump_to:
|
||||
wrapped.__can_jump_to__ = func_can_jump_to # type: ignore[attr-defined]
|
||||
|
||||
# Use function name as default if no name provided
|
||||
middleware_name = name or cast("str", getattr(func, "__name__", "BeforeAgentMiddleware"))
|
||||
|
||||
return type(
|
||||
middleware_name,
|
||||
(AgentMiddleware,),
|
||||
{
|
||||
"state_schema": state_schema or AgentState,
|
||||
"tools": tools or [],
|
||||
"before_agent": wrapped,
|
||||
},
|
||||
)()
|
||||
|
||||
if func is not None:
|
||||
return decorator(func)
|
||||
return decorator
|
||||
|
||||
|
||||
@overload
|
||||
def after_agent(
|
||||
func: _CallableWithStateAndRuntime[StateT, ContextT],
|
||||
) -> AgentMiddleware[StateT, ContextT]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def after_agent(
|
||||
func: None = None,
|
||||
*,
|
||||
state_schema: type[StateT] | None = None,
|
||||
tools: list[BaseTool] | None = None,
|
||||
can_jump_to: list[JumpTo] | None = None,
|
||||
name: str | None = None,
|
||||
) -> Callable[
|
||||
[_CallableWithStateAndRuntime[StateT, ContextT]], AgentMiddleware[StateT, ContextT]
|
||||
]: ...
|
||||
|
||||
|
||||
def after_agent(
|
||||
func: _CallableWithStateAndRuntime[StateT, ContextT] | None = None,
|
||||
*,
|
||||
state_schema: type[StateT] | None = None,
|
||||
tools: list[BaseTool] | None = None,
|
||||
can_jump_to: list[JumpTo] | None = None,
|
||||
name: str | None = None,
|
||||
) -> (
|
||||
Callable[[_CallableWithStateAndRuntime[StateT, ContextT]], AgentMiddleware[StateT, ContextT]]
|
||||
| AgentMiddleware[StateT, ContextT]
|
||||
):
|
||||
"""Decorator used to dynamically create a middleware with the after_agent hook.
|
||||
|
||||
Args:
|
||||
func: The function to be decorated. Must accept:
|
||||
`state: StateT, runtime: Runtime[ContextT]` - State and runtime context
|
||||
state_schema: Optional custom state schema type. If not provided, uses the default
|
||||
AgentState schema.
|
||||
tools: Optional list of additional tools to register with this middleware.
|
||||
can_jump_to: Optional list of valid jump destinations for conditional edges.
|
||||
Valid values are: "tools", "model", "end"
|
||||
name: Optional name for the generated middleware class. If not provided,
|
||||
uses the decorated function's name.
|
||||
|
||||
Returns:
|
||||
Either an AgentMiddleware instance (if func is provided) or a decorator function
|
||||
that can be applied to a function.
|
||||
|
||||
The decorated function should return:
|
||||
- `dict[str, Any]` - State updates to merge into the agent state
|
||||
- `Command` - A command to control flow (e.g., jump to different node)
|
||||
- `None` - No state updates or flow control
|
||||
|
||||
Examples:
|
||||
Basic usage for logging agent completion:
|
||||
```python
|
||||
@after_agent
|
||||
def log_completion(state: AgentState, runtime: Runtime) -> None:
|
||||
print(f"Agent completed with {len(state['messages'])} messages")
|
||||
```
|
||||
|
||||
With custom state schema:
|
||||
```python
|
||||
@after_agent(state_schema=MyCustomState, name="MyAfterAgentMiddleware")
|
||||
def custom_after_agent(state: MyCustomState, runtime: Runtime) -> dict[str, Any]:
|
||||
return {"custom_field": "finalized_value"}
|
||||
```
|
||||
"""
|
||||
|
||||
def decorator(
|
||||
func: _CallableWithStateAndRuntime[StateT, ContextT],
|
||||
) -> AgentMiddleware[StateT, ContextT]:
|
||||
is_async = iscoroutinefunction(func)
|
||||
# Extract can_jump_to from decorator parameter or from function metadata
|
||||
func_can_jump_to = (
|
||||
can_jump_to if can_jump_to is not None else getattr(func, "__can_jump_to__", [])
|
||||
)
|
||||
|
||||
if is_async:
|
||||
|
||||
async def async_wrapped(
|
||||
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
|
||||
state: StateT,
|
||||
runtime: Runtime[ContextT],
|
||||
) -> dict[str, Any] | Command | None:
|
||||
return await func(state, runtime) # type: ignore[misc]
|
||||
|
||||
# Preserve can_jump_to metadata on the wrapped function
|
||||
if func_can_jump_to:
|
||||
async_wrapped.__can_jump_to__ = func_can_jump_to # type: ignore[attr-defined]
|
||||
|
||||
middleware_name = name or cast("str", getattr(func, "__name__", "AfterAgentMiddleware"))
|
||||
|
||||
return type(
|
||||
middleware_name,
|
||||
(AgentMiddleware,),
|
||||
{
|
||||
"state_schema": state_schema or AgentState,
|
||||
"tools": tools or [],
|
||||
"aafter_agent": async_wrapped,
|
||||
},
|
||||
)()
|
||||
|
||||
def wrapped(
|
||||
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
|
||||
state: StateT,
|
||||
runtime: Runtime[ContextT],
|
||||
) -> dict[str, Any] | Command | None:
|
||||
return func(state, runtime) # type: ignore[return-value]
|
||||
|
||||
# Preserve can_jump_to metadata on the wrapped function
|
||||
if func_can_jump_to:
|
||||
wrapped.__can_jump_to__ = func_can_jump_to # type: ignore[attr-defined]
|
||||
|
||||
# Use function name as default if no name provided
|
||||
middleware_name = name or cast("str", getattr(func, "__name__", "AfterAgentMiddleware"))
|
||||
|
||||
return type(
|
||||
middleware_name,
|
||||
(AgentMiddleware,),
|
||||
{
|
||||
"state_schema": state_schema or AgentState,
|
||||
"tools": tools or [],
|
||||
"after_agent": wrapped,
|
||||
},
|
||||
)()
|
||||
|
||||
if func is not None:
|
||||
return decorator(func)
|
||||
return decorator
|
||||
|
||||
|
||||
@overload
|
||||
def dynamic_prompt(
|
||||
func: _CallableReturningPromptString[StateT, ContextT],
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
"""Middleware agent implementation."""
|
||||
|
||||
import itertools
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import Annotated, Any, cast, get_args, get_origin, get_type_hints
|
||||
from dataclasses import dataclass
|
||||
from typing import Annotated, Any, Generic, cast, get_args, get_origin, get_type_hints
|
||||
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessage, AnyMessage, SystemMessage, ToolMessage
|
||||
from langchain_core.runnables import Runnable
|
||||
from langchain_core.runnables import Runnable, run_in_executor
|
||||
from langchain_core.tools import BaseTool
|
||||
from langgraph._internal._runnable import RunnableCallable
|
||||
from langgraph.constants import END, START
|
||||
from langgraph.graph.state import StateGraph
|
||||
from langgraph.runtime import Runtime
|
||||
@@ -19,6 +20,7 @@ from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
JumpTo,
|
||||
MiddlewareHookInfo,
|
||||
ModelRequest,
|
||||
OmitFromSchema,
|
||||
PublicAgentState,
|
||||
@@ -41,6 +43,167 @@ STRUCTURED_OUTPUT_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes."
|
||||
ResponseT = TypeVar("ResponseT")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Data Structures for Agent Graph Construction
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@dataclass
|
||||
class MiddlewareHooks:
|
||||
"""Middleware hooks categorized by type for graph construction."""
|
||||
|
||||
before_agent: list[MiddlewareHookInfo]
|
||||
"""Hooks that run once before the agent starts."""
|
||||
|
||||
before_model: list[MiddlewareHookInfo]
|
||||
"""Hooks that run before each model call in the agent loop."""
|
||||
|
||||
modify_model_request: list[MiddlewareHookInfo]
|
||||
"""Hooks that modify the model request before calling the model."""
|
||||
|
||||
after_model: list[MiddlewareHookInfo]
|
||||
"""Hooks that run after each model call in the agent loop."""
|
||||
|
||||
after_agent: list[MiddlewareHookInfo]
|
||||
"""Hooks that run once after the agent completes."""
|
||||
|
||||
retry: list[MiddlewareHookInfo]
|
||||
"""Hooks that handle model invocation errors and optionally retry."""
|
||||
|
||||
@classmethod
|
||||
def from_middleware_list(
|
||||
cls,
|
||||
middleware: Sequence[AgentMiddleware[AgentState[ResponseT], ContextT]],
|
||||
) -> "MiddlewareHooks":
|
||||
"""Extract and categorize all hooks from middleware instances.
|
||||
|
||||
Args:
|
||||
middleware: Sequence of middleware instances to analyze.
|
||||
|
||||
Returns:
|
||||
MiddlewareHooks with all hooks organized by type.
|
||||
"""
|
||||
hooks_by_type: dict[str, list[MiddlewareHookInfo]] = {
|
||||
"before_agent": [],
|
||||
"before_model": [],
|
||||
"modify_model_request": [],
|
||||
"after_model": [],
|
||||
"after_agent": [],
|
||||
"retry": [],
|
||||
}
|
||||
|
||||
# Map hook names to their category
|
||||
hook_name_mapping = {
|
||||
"before_agent": "before_agent",
|
||||
"before_model": "before_model",
|
||||
"modify_model_request": "modify_model_request",
|
||||
"after_model": "after_model",
|
||||
"after_agent": "after_agent",
|
||||
"retry_model_request": "retry",
|
||||
}
|
||||
|
||||
for m in middleware:
|
||||
for hook_name, category in hook_name_mapping.items():
|
||||
if hook_info := m.hook_info(hook_name):
|
||||
hooks_by_type[category].append(hook_info)
|
||||
|
||||
return cls(
|
||||
before_agent=hooks_by_type["before_agent"],
|
||||
before_model=hooks_by_type["before_model"],
|
||||
modify_model_request=hooks_by_type["modify_model_request"],
|
||||
after_model=hooks_by_type["after_model"],
|
||||
after_agent=hooks_by_type["after_agent"],
|
||||
retry=hooks_by_type["retry"],
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentComponents:
|
||||
"""Core components and configuration for agent construction."""
|
||||
|
||||
model: BaseChatModel
|
||||
"""The language model to use for the agent."""
|
||||
|
||||
tool_node: ToolNode | None
|
||||
"""The tool execution node, or None if no tools are available."""
|
||||
|
||||
middleware_hooks: MiddlewareHooks
|
||||
"""Middleware hooks organized by type."""
|
||||
|
||||
structured_output_tools: dict[str, OutputToolBinding]
|
||||
"""Tools used for structured output parsing."""
|
||||
|
||||
default_tools: list[BaseTool | dict]
|
||||
"""Default tools available to the agent (regular tools + middleware tools + built-ins)."""
|
||||
|
||||
initial_response_format: ResponseFormat | None
|
||||
"""The initial response format configuration."""
|
||||
|
||||
system_prompt: str | None
|
||||
"""The system prompt for the agent."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class GraphTopology:
|
||||
"""Key nodes in the graph topology defining the execution flow.
|
||||
|
||||
The agent graph has the following structure:
|
||||
START -> entry_node -> [loop: loop_entry_node -> model -> loop_exit_node -> tools]
|
||||
-> exit_node -> END
|
||||
|
||||
- entry_node: Runs once at the start (before_agent hooks)
|
||||
- loop_entry_node: Beginning of agent loop (before_model hooks)
|
||||
- loop_exit_node: End of each loop iteration (after_model hooks)
|
||||
- exit_node: Runs once at the end (after_agent hooks) or END
|
||||
"""
|
||||
|
||||
entry_node: str
|
||||
"""The first node executed (START -> entry_node)."""
|
||||
|
||||
loop_entry_node: str
|
||||
"""Where the agent loop begins (where tools loop back to)."""
|
||||
|
||||
loop_exit_node: str
|
||||
"""The last node in each loop iteration."""
|
||||
|
||||
exit_node: str
|
||||
"""The final node before END (or END itself)."""
|
||||
|
||||
@classmethod
|
||||
def compute(cls, hooks: MiddlewareHooks) -> "GraphTopology":
|
||||
"""Compute graph topology from middleware hook configuration.
|
||||
|
||||
Args:
|
||||
hooks: The categorized middleware hooks.
|
||||
|
||||
Returns:
|
||||
GraphTopology describing the flow through the graph.
|
||||
"""
|
||||
# Entry node (runs once at start): before_agent -> before_model -> model_request
|
||||
if hooks.before_agent:
|
||||
entry_node = hooks.before_agent[0].node_name
|
||||
elif hooks.before_model:
|
||||
entry_node = hooks.before_model[0].node_name
|
||||
else:
|
||||
entry_node = "model_request"
|
||||
|
||||
# Loop entry node (beginning of agent loop, excludes before_agent)
|
||||
loop_entry_node = hooks.before_model[0].node_name if hooks.before_model else "model_request"
|
||||
|
||||
# Loop exit node (end of each iteration, excludes after_agent)
|
||||
loop_exit_node = hooks.after_model[0].node_name if hooks.after_model else "model_request"
|
||||
|
||||
# Exit node (runs once at end): after_agent or END
|
||||
exit_node = hooks.after_agent[-1].node_name if hooks.after_agent else END
|
||||
|
||||
return cls(
|
||||
entry_node=entry_node,
|
||||
loop_entry_node=loop_entry_node,
|
||||
loop_exit_node=loop_exit_node,
|
||||
exit_node=exit_node,
|
||||
)
|
||||
|
||||
|
||||
def _resolve_schema(schemas: set[type], schema_name: str, omit_flag: str | None = None) -> type:
|
||||
"""Resolve schema by merging schemas and optionally respecting OmitFromSchema annotations.
|
||||
|
||||
@@ -86,39 +249,112 @@ def _extract_metadata(type_: type) -> list:
|
||||
return []
|
||||
|
||||
|
||||
def _get_can_jump_to(middleware: AgentMiddleware[Any, Any], hook_name: str) -> list[JumpTo]:
|
||||
"""Get the can_jump_to list from either sync or async hook methods.
|
||||
# ============================================================================
|
||||
# Setup and Initialization Functions
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def _setup_components(
|
||||
model: str | BaseChatModel,
|
||||
tools: Sequence[BaseTool | Callable | dict[str, Any]] | ToolNode | None,
|
||||
middleware: Sequence[AgentMiddleware[AgentState[ResponseT], ContextT]],
|
||||
response_format: ResponseFormat[ResponseT] | type[ResponseT] | None,
|
||||
system_prompt: str | None,
|
||||
) -> AgentComponents:
|
||||
"""Setup and validate agent components.
|
||||
|
||||
Args:
|
||||
middleware: The middleware instance to inspect.
|
||||
hook_name: The name of the hook ('before_model' or 'after_model').
|
||||
model: Model name or instance.
|
||||
tools: Tools for the agent.
|
||||
middleware: Middleware instances.
|
||||
response_format: Response format configuration.
|
||||
system_prompt: System prompt for the agent.
|
||||
|
||||
Returns:
|
||||
List of jump destinations, or empty list if not configured.
|
||||
AgentComponents with all components configured and validated.
|
||||
"""
|
||||
# Get the base class method for comparison
|
||||
base_sync_method = getattr(AgentMiddleware, hook_name, None)
|
||||
base_async_method = getattr(AgentMiddleware, f"a{hook_name}", None)
|
||||
# Initialize chat model
|
||||
if isinstance(model, str):
|
||||
model = init_chat_model(model)
|
||||
|
||||
# Try sync method first - only if it's overridden from base class
|
||||
sync_method = getattr(middleware.__class__, hook_name, None)
|
||||
if (
|
||||
sync_method
|
||||
and sync_method is not base_sync_method
|
||||
and hasattr(sync_method, "__can_jump_to__")
|
||||
):
|
||||
return sync_method.__can_jump_to__
|
||||
# Handle tools being None or empty
|
||||
if tools is None:
|
||||
tools = []
|
||||
|
||||
# Try async method - only if it's overridden from base class
|
||||
async_method = getattr(middleware.__class__, f"a{hook_name}", None)
|
||||
if (
|
||||
async_method
|
||||
and async_method is not base_async_method
|
||||
and hasattr(async_method, "__can_jump_to__")
|
||||
):
|
||||
return async_method.__can_jump_to__
|
||||
# Convert response format and setup structured output tools
|
||||
initial_response_format: ToolStrategy | ProviderStrategy | AutoStrategy | None
|
||||
if response_format is None:
|
||||
initial_response_format = None
|
||||
elif isinstance(response_format, (ToolStrategy, ProviderStrategy, AutoStrategy)):
|
||||
initial_response_format = response_format
|
||||
else:
|
||||
# Raw schema - wrap in AutoStrategy to enable auto-detection
|
||||
initial_response_format = AutoStrategy(schema=response_format)
|
||||
|
||||
return []
|
||||
# For AutoStrategy, convert to ToolStrategy to setup tools upfront
|
||||
tool_strategy_for_setup: ToolStrategy | None = None
|
||||
if isinstance(initial_response_format, AutoStrategy):
|
||||
tool_strategy_for_setup = ToolStrategy(schema=initial_response_format.schema)
|
||||
elif isinstance(initial_response_format, ToolStrategy):
|
||||
tool_strategy_for_setup = initial_response_format
|
||||
|
||||
structured_output_tools: dict[str, OutputToolBinding] = {}
|
||||
if tool_strategy_for_setup:
|
||||
for response_schema in tool_strategy_for_setup.schema_specs:
|
||||
structured_tool_info = OutputToolBinding.from_schema_spec(response_schema)
|
||||
structured_output_tools[structured_tool_info.tool.name] = structured_tool_info
|
||||
|
||||
middleware_tools = [t for m in middleware for t in getattr(m, "tools", [])]
|
||||
|
||||
# Setup tools
|
||||
tool_node: ToolNode | None = None
|
||||
default_tools: list[BaseTool | dict[str, Any]]
|
||||
|
||||
if isinstance(tools, list):
|
||||
# Extract built-in provider tools (dict format) and regular tools (BaseTool)
|
||||
built_in_tools = [t for t in tools if isinstance(t, dict)]
|
||||
regular_tools = [t for t in tools if not isinstance(t, dict)]
|
||||
|
||||
# Tools that require client-side execution
|
||||
available_tools = middleware_tools + regular_tools
|
||||
|
||||
# Only create ToolNode if we have client-side tools
|
||||
tool_node = ToolNode(tools=available_tools) if available_tools else None
|
||||
|
||||
# Default tools for ModelRequest initialization
|
||||
default_tools = regular_tools + middleware_tools + built_in_tools
|
||||
elif isinstance(tools, ToolNode):
|
||||
tool_node = tools
|
||||
if tool_node:
|
||||
# Add middleware tools to existing ToolNode
|
||||
available_tools = list(tool_node.tools_by_name.values()) + middleware_tools
|
||||
tool_node = ToolNode(available_tools)
|
||||
|
||||
# default_tools includes all client-side tools
|
||||
default_tools = available_tools
|
||||
else:
|
||||
default_tools = middleware_tools
|
||||
else:
|
||||
# No tools provided, only middleware_tools available
|
||||
default_tools = middleware_tools
|
||||
|
||||
# Validate middleware
|
||||
assert len({m.name for m in middleware}) == len(middleware), ( # noqa: S101
|
||||
"Please remove duplicate middleware instances."
|
||||
)
|
||||
|
||||
# Categorize middleware by hooks
|
||||
middleware_hooks = MiddlewareHooks.from_middleware_list(middleware)
|
||||
|
||||
return AgentComponents(
|
||||
model=model,
|
||||
tool_node=tool_node,
|
||||
middleware_hooks=middleware_hooks,
|
||||
structured_output_tools=structured_output_tools,
|
||||
default_tools=default_tools,
|
||||
initial_response_format=initial_response_format,
|
||||
system_prompt=system_prompt,
|
||||
)
|
||||
|
||||
|
||||
def _supports_provider_strategy(model: str | BaseChatModel) -> bool:
|
||||
@@ -144,6 +380,58 @@ def _supports_provider_strategy(model: str | BaseChatModel) -> bool:
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Node Building Functions
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def _create_hook_node(hook_info: MiddlewareHookInfo) -> RunnableCallable:
|
||||
"""Create a graph node for a middleware hook.
|
||||
|
||||
Args:
|
||||
hook_info: Information about the hook to create a node for.
|
||||
|
||||
Returns:
|
||||
RunnableCallable that supports both sync and async execution.
|
||||
"""
|
||||
return RunnableCallable(hook_info.sync_fn, hook_info.async_fn, trace=False)
|
||||
|
||||
|
||||
def _add_middleware_nodes(
|
||||
graph: StateGraph[AgentState, ContextT, PublicAgentState, PublicAgentState],
|
||||
components: AgentComponents,
|
||||
state_schema: type,
|
||||
) -> None:
|
||||
"""Add all middleware hook nodes to the graph.
|
||||
|
||||
Args:
|
||||
graph: The state graph to add nodes to.
|
||||
components: Agent components with middleware hooks.
|
||||
state_schema: The state schema for input validation.
|
||||
"""
|
||||
hooks = components.middleware_hooks
|
||||
|
||||
# Add before_agent nodes
|
||||
for hook_info in hooks.before_agent:
|
||||
node = _create_hook_node(hook_info)
|
||||
graph.add_node(hook_info.node_name, node, input_schema=state_schema)
|
||||
|
||||
# Add before_model nodes
|
||||
for hook_info in hooks.before_model:
|
||||
node = _create_hook_node(hook_info)
|
||||
graph.add_node(hook_info.node_name, node, input_schema=state_schema)
|
||||
|
||||
# Add after_model nodes
|
||||
for hook_info in hooks.after_model:
|
||||
node = _create_hook_node(hook_info)
|
||||
graph.add_node(hook_info.node_name, node, input_schema=state_schema)
|
||||
|
||||
# Add after_agent nodes
|
||||
for hook_info in hooks.after_agent:
|
||||
node = _create_hook_node(hook_info)
|
||||
graph.add_node(hook_info.node_name, node, input_schema=state_schema)
|
||||
|
||||
|
||||
def _handle_structured_output_error(
|
||||
exception: Exception,
|
||||
response_format: ResponseFormat,
|
||||
@@ -174,6 +462,157 @@ def _handle_structured_output_error(
|
||||
return False, ""
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Edge Building Functions
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def _connect_entry_edges(
|
||||
graph: StateGraph[AgentState, ContextT, PublicAgentState, PublicAgentState],
|
||||
topology: GraphTopology,
|
||||
) -> None:
|
||||
"""Connect the entry edge from START to the entry node.
|
||||
|
||||
Args:
|
||||
graph: The state graph to add edges to.
|
||||
topology: Graph topology configuration.
|
||||
"""
|
||||
graph.add_edge(START, topology.entry_node)
|
||||
|
||||
|
||||
def _connect_loop_edges(
|
||||
graph: StateGraph[AgentState, ContextT, PublicAgentState, PublicAgentState],
|
||||
topology: GraphTopology,
|
||||
components: AgentComponents,
|
||||
) -> None:
|
||||
"""Connect conditional edges for the agent loop (tools <-> model).
|
||||
|
||||
Args:
|
||||
graph: The state graph to add edges to.
|
||||
topology: Graph topology configuration.
|
||||
components: Agent components with tool configuration.
|
||||
"""
|
||||
tool_node = components.tool_node
|
||||
structured_output_tools = components.structured_output_tools
|
||||
|
||||
if tool_node is None:
|
||||
# No tools - connect loop_exit directly to exit_node
|
||||
if topology.loop_exit_node == "model_request":
|
||||
graph.add_edge(topology.loop_exit_node, topology.exit_node)
|
||||
else:
|
||||
# We have after_model but no tools
|
||||
_add_middleware_edge(
|
||||
graph,
|
||||
topology.loop_exit_node,
|
||||
topology.exit_node,
|
||||
topology.loop_entry_node,
|
||||
can_jump_to=components.middleware_hooks.after_model[0].can_jump_to,
|
||||
)
|
||||
return
|
||||
|
||||
# Add conditional edge from tools back to model or exit
|
||||
graph.add_conditional_edges(
|
||||
"tools",
|
||||
_make_tools_to_model_edge(
|
||||
tool_node, topology.loop_entry_node, structured_output_tools, topology.exit_node
|
||||
),
|
||||
[topology.loop_entry_node, topology.exit_node],
|
||||
)
|
||||
|
||||
# Add conditional edge from model to tools or exit
|
||||
graph.add_conditional_edges(
|
||||
topology.loop_exit_node,
|
||||
_make_model_to_tools_edge(
|
||||
topology.loop_entry_node, structured_output_tools, tool_node, topology.exit_node
|
||||
),
|
||||
[topology.loop_entry_node, "tools", topology.exit_node],
|
||||
)
|
||||
|
||||
|
||||
def _connect_middleware_chains(
|
||||
graph: StateGraph[AgentState, ContextT, PublicAgentState, PublicAgentState],
|
||||
components: AgentComponents,
|
||||
topology: GraphTopology,
|
||||
) -> None:
|
||||
"""Connect middleware hooks in chains.
|
||||
|
||||
Args:
|
||||
graph: The state graph to add edges to.
|
||||
components: Agent components with middleware hooks.
|
||||
topology: Graph topology configuration.
|
||||
"""
|
||||
hooks = components.middleware_hooks
|
||||
|
||||
# Connect before_agent chain
|
||||
if hooks.before_agent:
|
||||
for i in range(len(hooks.before_agent) - 1):
|
||||
_add_middleware_edge(
|
||||
graph,
|
||||
hooks.before_agent[i].node_name,
|
||||
hooks.before_agent[i + 1].node_name,
|
||||
topology.loop_entry_node,
|
||||
can_jump_to=hooks.before_agent[i].can_jump_to,
|
||||
)
|
||||
# Connect last before_agent to loop_entry_node
|
||||
_add_middleware_edge(
|
||||
graph,
|
||||
hooks.before_agent[-1].node_name,
|
||||
topology.loop_entry_node,
|
||||
topology.loop_entry_node,
|
||||
can_jump_to=hooks.before_agent[-1].can_jump_to,
|
||||
)
|
||||
|
||||
# Connect before_model chain
|
||||
if hooks.before_model:
|
||||
for i in range(len(hooks.before_model) - 1):
|
||||
_add_middleware_edge(
|
||||
graph,
|
||||
hooks.before_model[i].node_name,
|
||||
hooks.before_model[i + 1].node_name,
|
||||
topology.loop_entry_node,
|
||||
can_jump_to=hooks.before_model[i].can_jump_to,
|
||||
)
|
||||
# Connect last before_model to model_request
|
||||
_add_middleware_edge(
|
||||
graph,
|
||||
hooks.before_model[-1].node_name,
|
||||
"model_request",
|
||||
topology.loop_entry_node,
|
||||
can_jump_to=hooks.before_model[-1].can_jump_to,
|
||||
)
|
||||
|
||||
# Connect after_model chain (reverse order)
|
||||
if hooks.after_model:
|
||||
graph.add_edge("model_request", hooks.after_model[-1].node_name)
|
||||
for i in range(len(hooks.after_model) - 1, 0, -1):
|
||||
_add_middleware_edge(
|
||||
graph,
|
||||
hooks.after_model[i].node_name,
|
||||
hooks.after_model[i - 1].node_name,
|
||||
topology.loop_entry_node,
|
||||
can_jump_to=hooks.after_model[i].can_jump_to,
|
||||
)
|
||||
|
||||
# Connect after_agent chain (reverse order)
|
||||
if hooks.after_agent:
|
||||
for i in range(len(hooks.after_agent) - 1, 0, -1):
|
||||
_add_middleware_edge(
|
||||
graph,
|
||||
hooks.after_agent[i].node_name,
|
||||
hooks.after_agent[i - 1].node_name,
|
||||
topology.loop_entry_node,
|
||||
can_jump_to=hooks.after_agent[i].can_jump_to,
|
||||
)
|
||||
# Connect first after_agent to END
|
||||
_add_middleware_edge(
|
||||
graph,
|
||||
hooks.after_agent[0].node_name,
|
||||
END,
|
||||
topology.loop_entry_node,
|
||||
can_jump_to=hooks.after_agent[0].can_jump_to,
|
||||
)
|
||||
|
||||
|
||||
def create_agent( # noqa: PLR0915
|
||||
*,
|
||||
model: str | BaseChatModel,
|
||||
@@ -185,106 +624,23 @@ def create_agent( # noqa: PLR0915
|
||||
) -> StateGraph[
|
||||
AgentState[ResponseT], ContextT, PublicAgentState[ResponseT], PublicAgentState[ResponseT]
|
||||
]:
|
||||
"""Create a middleware agent graph."""
|
||||
# init chat model
|
||||
if isinstance(model, str):
|
||||
model = init_chat_model(model)
|
||||
"""Create a middleware agent graph.
|
||||
|
||||
# Handle tools being None or empty
|
||||
if tools is None:
|
||||
tools = []
|
||||
Args:
|
||||
model: Model name or BaseChatModel instance.
|
||||
tools: Tools for the agent to use.
|
||||
system_prompt: System prompt for the agent.
|
||||
middleware: Middleware instances to customize agent behavior.
|
||||
response_format: Response format configuration for structured outputs.
|
||||
context_schema: Context schema for the graph runtime.
|
||||
|
||||
# Convert response format and setup structured output tools
|
||||
# Raw schemas are wrapped in AutoStrategy to preserve auto-detection intent.
|
||||
# AutoStrategy is converted to ToolStrategy upfront to calculate tools during agent creation,
|
||||
# but may be replaced with ProviderStrategy later based on model capabilities.
|
||||
initial_response_format: ToolStrategy | ProviderStrategy | AutoStrategy | None
|
||||
if response_format is None:
|
||||
initial_response_format = None
|
||||
elif isinstance(response_format, (ToolStrategy, ProviderStrategy)):
|
||||
# Preserve explicitly requested strategies
|
||||
initial_response_format = response_format
|
||||
elif isinstance(response_format, AutoStrategy):
|
||||
# AutoStrategy provided - preserve it for later auto-detection
|
||||
initial_response_format = response_format
|
||||
else:
|
||||
# Raw schema - wrap in AutoStrategy to enable auto-detection
|
||||
initial_response_format = AutoStrategy(schema=response_format)
|
||||
|
||||
# For AutoStrategy, convert to ToolStrategy to setup tools upfront
|
||||
# (may be replaced with ProviderStrategy later based on model)
|
||||
tool_strategy_for_setup: ToolStrategy | None = None
|
||||
if isinstance(initial_response_format, AutoStrategy):
|
||||
tool_strategy_for_setup = ToolStrategy(schema=initial_response_format.schema)
|
||||
elif isinstance(initial_response_format, ToolStrategy):
|
||||
tool_strategy_for_setup = initial_response_format
|
||||
|
||||
structured_output_tools: dict[str, OutputToolBinding] = {}
|
||||
if tool_strategy_for_setup:
|
||||
for response_schema in tool_strategy_for_setup.schema_specs:
|
||||
structured_tool_info = OutputToolBinding.from_schema_spec(response_schema)
|
||||
structured_output_tools[structured_tool_info.tool.name] = structured_tool_info
|
||||
middleware_tools = [t for m in middleware for t in getattr(m, "tools", [])]
|
||||
|
||||
# Setup tools
|
||||
tool_node: ToolNode | None = None
|
||||
if isinstance(tools, list):
|
||||
# Extract built-in provider tools (dict format) and regular tools (BaseTool)
|
||||
built_in_tools = [t for t in tools if isinstance(t, dict)]
|
||||
regular_tools = [t for t in tools if not isinstance(t, dict)]
|
||||
|
||||
# Tools that require client-side execution (must be in ToolNode)
|
||||
available_tools = middleware_tools + regular_tools
|
||||
|
||||
# Only create ToolNode if we have client-side tools
|
||||
tool_node = ToolNode(tools=available_tools) if available_tools else None
|
||||
|
||||
# Default tools for ModelRequest initialization
|
||||
# Include built-ins and regular tools (can be changed dynamically by middleware)
|
||||
# Structured tools are NOT included - they're added dynamically based on response_format
|
||||
default_tools = regular_tools + middleware_tools + built_in_tools
|
||||
elif isinstance(tools, ToolNode):
|
||||
tool_node = tools
|
||||
if tool_node:
|
||||
# Add middleware tools to existing ToolNode
|
||||
available_tools = list(tool_node.tools_by_name.values()) + middleware_tools
|
||||
tool_node = ToolNode(available_tools)
|
||||
|
||||
# default_tools includes all client-side tools (no built-ins or structured tools)
|
||||
default_tools = available_tools
|
||||
else:
|
||||
# No tools provided, only middleware_tools available
|
||||
default_tools = middleware_tools
|
||||
|
||||
# validate middleware
|
||||
assert len({m.name for m in middleware}) == len(middleware), ( # noqa: S101
|
||||
"Please remove duplicate middleware instances."
|
||||
)
|
||||
middleware_w_before = [
|
||||
m
|
||||
for m in middleware
|
||||
if m.__class__.before_model is not AgentMiddleware.before_model
|
||||
or m.__class__.abefore_model is not AgentMiddleware.abefore_model
|
||||
]
|
||||
middleware_w_modify_model_request = [
|
||||
m
|
||||
for m in middleware
|
||||
if m.__class__.modify_model_request is not AgentMiddleware.modify_model_request
|
||||
or m.__class__.amodify_model_request is not AgentMiddleware.amodify_model_request
|
||||
]
|
||||
middleware_w_after = [
|
||||
m
|
||||
for m in middleware
|
||||
if m.__class__.after_model is not AgentMiddleware.after_model
|
||||
or m.__class__.aafter_model is not AgentMiddleware.aafter_model
|
||||
]
|
||||
middleware_w_retry = [
|
||||
m
|
||||
for m in middleware
|
||||
if m.__class__.retry_model_request is not AgentMiddleware.retry_model_request
|
||||
or m.__class__.aretry_model_request is not AgentMiddleware.aretry_model_request
|
||||
]
|
||||
Returns:
|
||||
StateGraph configured with all nodes and edges.
|
||||
"""
|
||||
# Phase 1: Setup and validate components
|
||||
components = _setup_components(model, tools, middleware, response_format, system_prompt)
|
||||
|
||||
# Phase 2: Create schemas
|
||||
state_schemas = {m.state_schema for m in middleware}
|
||||
state_schemas.add(AgentState)
|
||||
|
||||
@@ -292,7 +648,7 @@ def create_agent( # noqa: PLR0915
|
||||
input_schema = _resolve_schema(state_schemas, "InputSchema", "input")
|
||||
output_schema = _resolve_schema(state_schemas, "OutputSchema", "output")
|
||||
|
||||
# create graph, add nodes
|
||||
# Phase 3: Create graph
|
||||
graph: StateGraph[
|
||||
AgentState[ResponseT], ContextT, PublicAgentState[ResponseT], PublicAgentState[ResponseT]
|
||||
] = StateGraph(
|
||||
@@ -302,6 +658,15 @@ def create_agent( # noqa: PLR0915
|
||||
context_schema=context_schema,
|
||||
)
|
||||
|
||||
# Phase 4: Define model request handlers (need access to components via closure)
|
||||
# These are inner functions because they need access to components
|
||||
structured_output_tools = components.structured_output_tools
|
||||
default_tools = components.default_tools
|
||||
initial_response_format = components.initial_response_format
|
||||
model_instance = components.model
|
||||
middleware_w_modify_model_request = components.middleware_hooks.modify_model_request
|
||||
middleware_w_retry = components.middleware_hooks.retry
|
||||
|
||||
def _handle_model_output(
|
||||
output: AIMessage, effective_response_format: ResponseFormat | None
|
||||
) -> dict[str, Any]:
|
||||
@@ -511,29 +876,28 @@ def create_agent( # noqa: PLR0915
|
||||
def model_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
|
||||
"""Sync model request handler with sequential middleware processing."""
|
||||
request = ModelRequest(
|
||||
model=model,
|
||||
model=model_instance,
|
||||
tools=default_tools,
|
||||
system_prompt=system_prompt,
|
||||
system_prompt=components.system_prompt,
|
||||
response_format=initial_response_format,
|
||||
messages=state["messages"],
|
||||
tool_choice=None,
|
||||
)
|
||||
|
||||
# Apply modify_model_request middleware in sequence
|
||||
for m in middleware_w_modify_model_request:
|
||||
if m.__class__.modify_model_request is not AgentMiddleware.modify_model_request:
|
||||
m.modify_model_request(request, state, runtime)
|
||||
for hook_info in middleware_w_modify_model_request:
|
||||
if hook_info.sync_fn:
|
||||
hook_info.sync_fn(request, state, runtime)
|
||||
else:
|
||||
msg = (
|
||||
f"No synchronous function provided for "
|
||||
f'{m.__class__.__name__}.amodify_model_request".'
|
||||
f"{hook_info.middleware_name}.amodify_model_request"
|
||||
"\nEither initialize with a synchronous function or invoke"
|
||||
" via the async API (ainvoke, astream, etc.)"
|
||||
)
|
||||
raise TypeError(msg)
|
||||
|
||||
# Retry loop for model invocation with error handling
|
||||
# Hard limit of 100 attempts to prevent infinite loops from buggy middleware
|
||||
max_attempts = 100
|
||||
for attempt in range(1, max_attempts + 1):
|
||||
try:
|
||||
@@ -551,18 +915,17 @@ def create_agent( # noqa: PLR0915
|
||||
}
|
||||
except Exception as error:
|
||||
# Try retry_model_request on each middleware
|
||||
for m in middleware_w_retry:
|
||||
if m.__class__.retry_model_request is not AgentMiddleware.retry_model_request:
|
||||
if retry_request := m.retry_model_request(
|
||||
for hook_info in middleware_w_retry:
|
||||
if hook_info.sync_fn:
|
||||
if retry_request := hook_info.sync_fn(
|
||||
error, request, state, runtime, attempt
|
||||
):
|
||||
# Break on first middleware that wants to retry
|
||||
request = retry_request
|
||||
break
|
||||
else:
|
||||
msg = (
|
||||
f"No synchronous function provided for "
|
||||
f'{m.__class__.__name__}.aretry_model_request".'
|
||||
f"{hook_info.middleware_name}.aretry_model_request"
|
||||
"\nEither initialize with a synchronous function or invoke"
|
||||
" via the async API (ainvoke, astream, etc.)"
|
||||
)
|
||||
@@ -577,20 +940,26 @@ def create_agent( # noqa: PLR0915
|
||||
async def amodel_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
|
||||
"""Async model request handler with sequential middleware processing."""
|
||||
request = ModelRequest(
|
||||
model=model,
|
||||
model=model_instance,
|
||||
tools=default_tools,
|
||||
system_prompt=system_prompt,
|
||||
system_prompt=components.system_prompt,
|
||||
response_format=initial_response_format,
|
||||
messages=state["messages"],
|
||||
tool_choice=None,
|
||||
)
|
||||
|
||||
# Apply modify_model_request middleware in sequence
|
||||
for m in middleware_w_modify_model_request:
|
||||
await m.amodify_model_request(request, state, runtime)
|
||||
for hook_info in middleware_w_modify_model_request:
|
||||
if hook_info.async_fn:
|
||||
await hook_info.async_fn(request, state, runtime)
|
||||
elif hook_info.sync_fn:
|
||||
# Fallback to sync if only sync is implemented
|
||||
await run_in_executor(None, hook_info.sync_fn, request, state, runtime)
|
||||
else:
|
||||
msg = f"No function provided for {hook_info.middleware_name}.modify_model_request"
|
||||
raise RuntimeError(msg)
|
||||
|
||||
# Retry loop for model invocation with error handling
|
||||
# Hard limit of 100 attempts to prevent infinite loops from buggy middleware
|
||||
max_attempts = 100
|
||||
for attempt in range(1, max_attempts + 1):
|
||||
try:
|
||||
@@ -608,135 +977,41 @@ def create_agent( # noqa: PLR0915
|
||||
}
|
||||
except Exception as error:
|
||||
# Try retry_model_request on each middleware
|
||||
for m in middleware_w_retry:
|
||||
if retry_request := await m.aretry_model_request(
|
||||
error, request, state, runtime, attempt
|
||||
):
|
||||
# Break on first middleware that wants to retry
|
||||
for hook_info in middleware_w_retry:
|
||||
retry_request = None
|
||||
if hook_info.async_fn:
|
||||
retry_request = await hook_info.async_fn(
|
||||
error, request, state, runtime, attempt
|
||||
)
|
||||
elif hook_info.sync_fn:
|
||||
# Fallback to sync if only sync is implemented
|
||||
retry_request = await run_in_executor(
|
||||
None, hook_info.sync_fn, error, request, state, runtime, attempt
|
||||
)
|
||||
|
||||
if retry_request:
|
||||
request = retry_request
|
||||
break
|
||||
else:
|
||||
# If no middleware wants to retry, re-raise the error
|
||||
raise
|
||||
|
||||
# If we exit the loop, max attempts exceeded
|
||||
msg = f"Maximum retry attempts ({max_attempts}) exceeded"
|
||||
raise RuntimeError(msg)
|
||||
|
||||
# Use sync or async based on model capabilities
|
||||
from langgraph._internal._runnable import RunnableCallable
|
||||
# Phase 5: Add nodes to graph
|
||||
graph.add_node("model_request", RunnableCallable(model_request, amodel_request, trace=False))
|
||||
if components.tool_node is not None:
|
||||
graph.add_node("tools", components.tool_node)
|
||||
_add_middleware_nodes(graph, components, state_schema)
|
||||
|
||||
graph.add_node("model_request", RunnableCallable(model_request, amodel_request))
|
||||
# Phase 6: Compute graph topology
|
||||
topology = GraphTopology.compute(components.middleware_hooks)
|
||||
|
||||
# Only add tools node if we have tools
|
||||
if tool_node is not None:
|
||||
graph.add_node("tools", tool_node)
|
||||
|
||||
# Add middleware nodes
|
||||
for m in middleware:
|
||||
if (
|
||||
m.__class__.before_model is not AgentMiddleware.before_model
|
||||
or m.__class__.abefore_model is not AgentMiddleware.abefore_model
|
||||
):
|
||||
# Use RunnableCallable to support both sync and async
|
||||
# Pass None for sync if not overridden to avoid signature conflicts
|
||||
sync_before = (
|
||||
m.before_model
|
||||
if m.__class__.before_model is not AgentMiddleware.before_model
|
||||
else None
|
||||
)
|
||||
async_before = (
|
||||
m.abefore_model
|
||||
if m.__class__.abefore_model is not AgentMiddleware.abefore_model
|
||||
else None
|
||||
)
|
||||
before_node = RunnableCallable(sync_before, async_before)
|
||||
graph.add_node(f"{m.name}.before_model", before_node, input_schema=state_schema)
|
||||
|
||||
if (
|
||||
m.__class__.after_model is not AgentMiddleware.after_model
|
||||
or m.__class__.aafter_model is not AgentMiddleware.aafter_model
|
||||
):
|
||||
# Use RunnableCallable to support both sync and async
|
||||
# Pass None for sync if not overridden to avoid signature conflicts
|
||||
sync_after = (
|
||||
m.after_model
|
||||
if m.__class__.after_model is not AgentMiddleware.after_model
|
||||
else None
|
||||
)
|
||||
async_after = (
|
||||
m.aafter_model
|
||||
if m.__class__.aafter_model is not AgentMiddleware.aafter_model
|
||||
else None
|
||||
)
|
||||
after_node = RunnableCallable(sync_after, async_after)
|
||||
graph.add_node(f"{m.name}.after_model", after_node, input_schema=state_schema)
|
||||
|
||||
# add start edge
|
||||
first_node = (
|
||||
f"{middleware_w_before[0].name}.before_model" if middleware_w_before else "model_request"
|
||||
)
|
||||
last_node = (
|
||||
f"{middleware_w_after[0].name}.after_model" if middleware_w_after else "model_request"
|
||||
)
|
||||
graph.add_edge(START, first_node)
|
||||
|
||||
# add conditional edges only if tools exist
|
||||
if tool_node is not None:
|
||||
graph.add_conditional_edges(
|
||||
"tools",
|
||||
_make_tools_to_model_edge(tool_node, first_node, structured_output_tools),
|
||||
[first_node, END],
|
||||
)
|
||||
graph.add_conditional_edges(
|
||||
last_node,
|
||||
_make_model_to_tools_edge(first_node, structured_output_tools, tool_node),
|
||||
[first_node, "tools", END],
|
||||
)
|
||||
elif last_node == "model_request":
|
||||
# If no tools, just go to END from model
|
||||
graph.add_edge(last_node, END)
|
||||
else:
|
||||
# If after_model, then need to check for can_jump_to
|
||||
_add_middleware_edge(
|
||||
graph,
|
||||
f"{middleware_w_after[0].name}.after_model",
|
||||
END,
|
||||
first_node,
|
||||
can_jump_to=_get_can_jump_to(middleware_w_after[0], "after_model"),
|
||||
)
|
||||
|
||||
# Add middleware edges (same as before)
|
||||
if middleware_w_before:
|
||||
for m1, m2 in itertools.pairwise(middleware_w_before):
|
||||
_add_middleware_edge(
|
||||
graph,
|
||||
f"{m1.name}.before_model",
|
||||
f"{m2.name}.before_model",
|
||||
first_node,
|
||||
can_jump_to=_get_can_jump_to(m1, "before_model"),
|
||||
)
|
||||
# Go directly to model_request after the last before_model
|
||||
_add_middleware_edge(
|
||||
graph,
|
||||
f"{middleware_w_before[-1].name}.before_model",
|
||||
"model_request",
|
||||
first_node,
|
||||
can_jump_to=_get_can_jump_to(middleware_w_before[-1], "before_model"),
|
||||
)
|
||||
|
||||
if middleware_w_after:
|
||||
graph.add_edge("model_request", f"{middleware_w_after[-1].name}.after_model")
|
||||
for idx in range(len(middleware_w_after) - 1, 0, -1):
|
||||
m1 = middleware_w_after[idx]
|
||||
m2 = middleware_w_after[idx - 1]
|
||||
_add_middleware_edge(
|
||||
graph,
|
||||
f"{m1.name}.after_model",
|
||||
f"{m2.name}.after_model",
|
||||
first_node,
|
||||
can_jump_to=_get_can_jump_to(m1, "after_model"),
|
||||
)
|
||||
# Phase 7: Connect edges
|
||||
_connect_entry_edges(graph, topology)
|
||||
_connect_loop_edges(graph, topology, components)
|
||||
_connect_middleware_chains(graph, components, topology)
|
||||
|
||||
return graph
|
||||
|
||||
@@ -768,7 +1043,10 @@ def _fetch_last_ai_and_tool_messages(
|
||||
|
||||
|
||||
def _make_model_to_tools_edge(
|
||||
first_node: str, structured_output_tools: dict[str, OutputToolBinding], tool_node: ToolNode
|
||||
first_node: str,
|
||||
structured_output_tools: dict[str, OutputToolBinding],
|
||||
tool_node: ToolNode,
|
||||
exit_node: str,
|
||||
) -> Callable[[dict[str, Any]], str | list[Send] | None]:
|
||||
def model_to_tools(state: dict[str, Any]) -> str | list[Send] | None:
|
||||
# 1. if there's an explicit jump_to in the state, use it
|
||||
@@ -778,10 +1056,10 @@ def _make_model_to_tools_edge(
|
||||
last_ai_message, tool_messages = _fetch_last_ai_and_tool_messages(state["messages"])
|
||||
tool_message_ids = [m.tool_call_id for m in tool_messages]
|
||||
|
||||
# 2. if the model hasn't called any tools, jump to END
|
||||
# 2. if the model hasn't called any tools, exit the loop
|
||||
# this is the classic exit condition for an agent loop
|
||||
if len(last_ai_message.tool_calls) == 0:
|
||||
return END
|
||||
return exit_node
|
||||
|
||||
pending_tool_calls = [
|
||||
c
|
||||
@@ -804,7 +1082,10 @@ def _make_model_to_tools_edge(
|
||||
|
||||
|
||||
def _make_tools_to_model_edge(
|
||||
tool_node: ToolNode, next_node: str, structured_output_tools: dict[str, OutputToolBinding]
|
||||
tool_node: ToolNode,
|
||||
next_node: str,
|
||||
structured_output_tools: dict[str, OutputToolBinding],
|
||||
exit_node: str,
|
||||
) -> Callable[[dict[str, Any]], str | None]:
|
||||
def tools_to_model(state: dict[str, Any]) -> str | None:
|
||||
last_ai_message, tool_messages = _fetch_last_ai_and_tool_messages(state["messages"])
|
||||
@@ -814,10 +1095,10 @@ def _make_tools_to_model_edge(
|
||||
for c in last_ai_message.tool_calls
|
||||
if c["name"] in tool_node.tools_by_name
|
||||
):
|
||||
return END
|
||||
return exit_node
|
||||
|
||||
if any(t.name in structured_output_tools for t in tool_messages):
|
||||
return END
|
||||
return exit_node
|
||||
|
||||
return next_node
|
||||
|
||||
|
||||
@@ -73,8 +73,8 @@
|
||||
__start__([<p>__start__</p>]):::first
|
||||
model_request(model_request)
|
||||
NoopTen\2ebefore_model(NoopTen.before_model)
|
||||
NoopTen\2eafter_model(NoopTen.after_model)
|
||||
NoopEleven\2ebefore_model(NoopEleven.before_model)
|
||||
NoopTen\2eafter_model(NoopTen.after_model)
|
||||
NoopEleven\2eafter_model(NoopEleven.after_model)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
NoopEleven\2eafter_model --> NoopTen\2eafter_model;
|
||||
@@ -240,8 +240,8 @@
|
||||
__start__([<p>__start__</p>]):::first
|
||||
model_request(model_request)
|
||||
NoopSeven\2ebefore_model(NoopSeven.before_model)
|
||||
NoopSeven\2eafter_model(NoopSeven.after_model)
|
||||
NoopEight\2ebefore_model(NoopEight.before_model)
|
||||
NoopSeven\2eafter_model(NoopSeven.after_model)
|
||||
NoopEight\2eafter_model(NoopEight.after_model)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
|
||||
@@ -267,10 +267,10 @@
|
||||
__start__([<p>__start__</p>]):::first
|
||||
model_request(model_request)
|
||||
NoopSeven\2ebefore_model(NoopSeven.before_model)
|
||||
NoopSeven\2eafter_model(NoopSeven.after_model)
|
||||
NoopEight\2ebefore_model(NoopEight.before_model)
|
||||
NoopEight\2eafter_model(NoopEight.after_model)
|
||||
NoopNine\2ebefore_model(NoopNine.before_model)
|
||||
NoopSeven\2eafter_model(NoopSeven.after_model)
|
||||
NoopEight\2eafter_model(NoopEight.after_model)
|
||||
NoopNine\2eafter_model(NoopNine.after_model)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
|
||||
@@ -299,8 +299,8 @@
|
||||
model_request(model_request)
|
||||
tools(tools)
|
||||
NoopSeven\2ebefore_model(NoopSeven.before_model)
|
||||
NoopSeven\2eafter_model(NoopSeven.after_model)
|
||||
NoopEight\2ebefore_model(NoopEight.before_model)
|
||||
NoopSeven\2eafter_model(NoopSeven.after_model)
|
||||
NoopEight\2eafter_model(NoopEight.after_model)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
|
||||
@@ -332,8 +332,8 @@
|
||||
model_request(model_request)
|
||||
tools(tools)
|
||||
NoopSeven\2ebefore_model(NoopSeven.before_model)
|
||||
NoopSeven\2eafter_model(NoopSeven.after_model)
|
||||
NoopEight\2ebefore_model(NoopEight.before_model)
|
||||
NoopSeven\2eafter_model(NoopSeven.after_model)
|
||||
NoopEight\2eafter_model(NoopEight.after_model)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
|
||||
@@ -365,8 +365,8 @@
|
||||
model_request(model_request)
|
||||
tools(tools)
|
||||
NoopSeven\2ebefore_model(NoopSeven.before_model)
|
||||
NoopSeven\2eafter_model(NoopSeven.after_model)
|
||||
NoopEight\2ebefore_model(NoopEight.before_model)
|
||||
NoopSeven\2eafter_model(NoopSeven.after_model)
|
||||
NoopEight\2eafter_model(NoopEight.after_model)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
|
||||
@@ -398,8 +398,8 @@
|
||||
model_request(model_request)
|
||||
tools(tools)
|
||||
NoopSeven\2ebefore_model(NoopSeven.before_model)
|
||||
NoopSeven\2eafter_model(NoopSeven.after_model)
|
||||
NoopEight\2ebefore_model(NoopEight.before_model)
|
||||
NoopSeven\2eafter_model(NoopSeven.after_model)
|
||||
NoopEight\2eafter_model(NoopEight.after_model)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
|
||||
@@ -431,8 +431,8 @@
|
||||
model_request(model_request)
|
||||
tools(tools)
|
||||
NoopSeven\2ebefore_model(NoopSeven.before_model)
|
||||
NoopSeven\2eafter_model(NoopSeven.after_model)
|
||||
NoopEight\2ebefore_model(NoopEight.before_model)
|
||||
NoopSeven\2eafter_model(NoopSeven.after_model)
|
||||
NoopEight\2eafter_model(NoopEight.after_model)
|
||||
__end__([<p>__end__</p>]):::last
|
||||
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
|
||||
|
||||
@@ -0,0 +1,388 @@
|
||||
"""Unit tests for before_agent and after_agent middleware hooks."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from langchain.agents import create_agent
|
||||
from langchain.agents.middleware import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
after_agent,
|
||||
before_agent,
|
||||
)
|
||||
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from ..model import FakeToolCallingModel
|
||||
|
||||
|
||||
@tool
|
||||
def sample_tool(query: str) -> str:
|
||||
"""A sample tool for testing."""
|
||||
return f"Result for: {query}"
|
||||
|
||||
|
||||
class TestBeforeAgentBasic:
|
||||
"""Test basic before_agent functionality."""
|
||||
|
||||
def test_sync_before_agent_execution(self) -> None:
|
||||
"""Test that before_agent hook is called synchronously."""
|
||||
execution_log = []
|
||||
|
||||
@before_agent
|
||||
def log_before_agent(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
|
||||
execution_log.append("before_agent_called")
|
||||
execution_log.append(f"message_count: {len(state['messages'])}")
|
||||
return None
|
||||
|
||||
model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello!")]))
|
||||
|
||||
agent = create_agent(model=model, tools=[], middleware=[log_before_agent])
|
||||
|
||||
agent.invoke({"messages": [HumanMessage("Hi")]})
|
||||
|
||||
assert "before_agent_called" in execution_log
|
||||
assert "message_count: 1" in execution_log
|
||||
|
||||
async def test_async_before_agent_execution(self) -> None:
|
||||
"""Test that before_agent hook is called asynchronously."""
|
||||
execution_log = []
|
||||
|
||||
@before_agent
|
||||
async def async_log_before_agent(
|
||||
state: AgentState, runtime: Runtime
|
||||
) -> dict[str, Any] | None:
|
||||
execution_log.append("async_before_agent_called")
|
||||
execution_log.append(f"message_count: {len(state['messages'])}")
|
||||
return None
|
||||
|
||||
model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello!")]))
|
||||
|
||||
agent = create_agent(model=model, tools=[], middleware=[async_log_before_agent])
|
||||
|
||||
await agent.ainvoke({"messages": [HumanMessage("Hi")]})
|
||||
|
||||
assert "async_before_agent_called" in execution_log
|
||||
assert "message_count: 1" in execution_log
|
||||
|
||||
def test_before_agent_state_modification(self) -> None:
|
||||
"""Test that before_agent can modify state."""
|
||||
|
||||
@before_agent
|
||||
def add_metadata(state: AgentState, runtime: Runtime) -> dict[str, Any]:
|
||||
return {"messages": [HumanMessage("Injected by middleware")]}
|
||||
|
||||
model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")]))
|
||||
|
||||
agent = create_agent(model=model, tools=[], middleware=[add_metadata])
|
||||
|
||||
result = agent.invoke({"messages": [HumanMessage("Original")]})
|
||||
|
||||
# Should have original + injected + AI response
|
||||
assert len(result["messages"]) >= 2
|
||||
message_contents = [msg.content for msg in result["messages"]]
|
||||
assert "Injected by middleware" in message_contents
|
||||
|
||||
def test_before_agent_with_class_inheritance(self) -> None:
|
||||
"""Test before_agent using class inheritance."""
|
||||
execution_log = []
|
||||
|
||||
class CustomBeforeAgentMiddleware(AgentMiddleware):
|
||||
def before_agent(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
|
||||
execution_log.append("class_before_agent_called")
|
||||
return None
|
||||
|
||||
model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")]))
|
||||
|
||||
agent = create_agent(model=model, tools=[], middleware=[CustomBeforeAgentMiddleware()])
|
||||
|
||||
agent.invoke({"messages": [HumanMessage("Test")]})
|
||||
|
||||
assert "class_before_agent_called" in execution_log
|
||||
|
||||
async def test_before_agent_with_async_class_inheritance(self) -> None:
|
||||
"""Test async before_agent using class inheritance."""
|
||||
execution_log = []
|
||||
|
||||
class CustomAsyncBeforeAgentMiddleware(AgentMiddleware):
|
||||
async def abefore_agent(
|
||||
self, state: AgentState, runtime: Runtime
|
||||
) -> dict[str, Any] | None:
|
||||
execution_log.append("async_class_before_agent_called")
|
||||
return None
|
||||
|
||||
model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")]))
|
||||
|
||||
agent = create_agent(model=model, tools=[], middleware=[CustomAsyncBeforeAgentMiddleware()])
|
||||
|
||||
await agent.ainvoke({"messages": [HumanMessage("Test")]})
|
||||
|
||||
assert "async_class_before_agent_called" in execution_log
|
||||
|
||||
|
||||
class TestAfterAgentBasic:
|
||||
"""Test basic after_agent functionality."""
|
||||
|
||||
def test_sync_after_agent_execution(self) -> None:
|
||||
"""Test that after_agent hook is called synchronously."""
|
||||
execution_log = []
|
||||
|
||||
@after_agent
|
||||
def log_after_agent(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
|
||||
execution_log.append("after_agent_called")
|
||||
execution_log.append(f"final_message_count: {len(state['messages'])}")
|
||||
return None
|
||||
|
||||
model = GenericFakeChatModel(messages=iter([AIMessage(content="Final response")]))
|
||||
|
||||
agent = create_agent(model=model, tools=[], middleware=[log_after_agent])
|
||||
|
||||
agent.invoke({"messages": [HumanMessage("Hi")]})
|
||||
|
||||
assert "after_agent_called" in execution_log
|
||||
assert any("final_message_count:" in log for log in execution_log)
|
||||
|
||||
async def test_async_after_agent_execution(self) -> None:
|
||||
"""Test that after_agent hook is called asynchronously."""
|
||||
execution_log = []
|
||||
|
||||
@after_agent
|
||||
async def async_log_after_agent(
|
||||
state: AgentState, runtime: Runtime
|
||||
) -> dict[str, Any] | None:
|
||||
execution_log.append("async_after_agent_called")
|
||||
return None
|
||||
|
||||
model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")]))
|
||||
|
||||
agent = create_agent(model=model, tools=[], middleware=[async_log_after_agent])
|
||||
|
||||
await agent.ainvoke({"messages": [HumanMessage("Hi")]})
|
||||
|
||||
assert "async_after_agent_called" in execution_log
|
||||
|
||||
def test_after_agent_state_modification(self) -> None:
|
||||
"""Test that after_agent can modify state."""
|
||||
|
||||
@after_agent
|
||||
def add_final_message(state: AgentState, runtime: Runtime) -> dict[str, Any]:
|
||||
return {"messages": [AIMessage("Added by after_agent")]}
|
||||
|
||||
model = GenericFakeChatModel(messages=iter([AIMessage(content="Model response")]))
|
||||
|
||||
agent = create_agent(model=model, tools=[], middleware=[add_final_message])
|
||||
|
||||
result = agent.invoke({"messages": [HumanMessage("Test")]})
|
||||
|
||||
message_contents = [msg.content for msg in result["messages"]]
|
||||
assert "Added by after_agent" in message_contents
|
||||
|
||||
def test_after_agent_with_class_inheritance(self) -> None:
|
||||
"""Test after_agent using class inheritance."""
|
||||
execution_log = []
|
||||
|
||||
class CustomAfterAgentMiddleware(AgentMiddleware):
|
||||
def after_agent(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
|
||||
execution_log.append("class_after_agent_called")
|
||||
return None
|
||||
|
||||
model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")]))
|
||||
|
||||
agent = create_agent(model=model, tools=[], middleware=[CustomAfterAgentMiddleware()])
|
||||
|
||||
agent.invoke({"messages": [HumanMessage("Test")]})
|
||||
|
||||
assert "class_after_agent_called" in execution_log
|
||||
|
||||
async def test_after_agent_with_async_class_inheritance(self) -> None:
|
||||
"""Test async after_agent using class inheritance."""
|
||||
execution_log = []
|
||||
|
||||
class CustomAsyncAfterAgentMiddleware(AgentMiddleware):
|
||||
async def aafter_agent(
|
||||
self, state: AgentState, runtime: Runtime
|
||||
) -> dict[str, Any] | None:
|
||||
execution_log.append("async_class_after_agent_called")
|
||||
return None
|
||||
|
||||
model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")]))
|
||||
|
||||
agent = create_agent(model=model, tools=[], middleware=[CustomAsyncAfterAgentMiddleware()])
|
||||
|
||||
await agent.ainvoke({"messages": [HumanMessage("Test")]})
|
||||
|
||||
assert "async_class_after_agent_called" in execution_log
|
||||
|
||||
|
||||
class TestBeforeAndAfterAgentCombined:
|
||||
"""Test before_agent and after_agent hooks working together."""
|
||||
|
||||
def test_execution_order(self) -> None:
|
||||
"""Test that before_agent executes before after_agent."""
|
||||
execution_log = []
|
||||
|
||||
@before_agent
|
||||
def log_before(state: AgentState, runtime: Runtime) -> None:
|
||||
execution_log.append("before")
|
||||
|
||||
@after_agent
|
||||
def log_after(state: AgentState, runtime: Runtime) -> None:
|
||||
execution_log.append("after")
|
||||
|
||||
model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")]))
|
||||
|
||||
agent = create_agent(model=model, tools=[], middleware=[log_before, log_after])
|
||||
|
||||
agent.invoke({"messages": [HumanMessage("Test")]})
|
||||
|
||||
assert execution_log == ["before", "after"]
|
||||
|
||||
async def test_async_execution_order(self) -> None:
|
||||
"""Test async execution order of before_agent and after_agent."""
|
||||
execution_log = []
|
||||
|
||||
@before_agent
|
||||
async def async_log_before(state: AgentState, runtime: Runtime) -> None:
|
||||
execution_log.append("async_before")
|
||||
|
||||
@after_agent
|
||||
async def async_log_after(state: AgentState, runtime: Runtime) -> None:
|
||||
execution_log.append("async_after")
|
||||
|
||||
model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")]))
|
||||
|
||||
agent = create_agent(model=model, tools=[], middleware=[async_log_before, async_log_after])
|
||||
|
||||
await agent.ainvoke({"messages": [HumanMessage("Test")]})
|
||||
|
||||
assert execution_log == ["async_before", "async_after"]
|
||||
|
||||
def test_state_passthrough(self) -> None:
|
||||
"""Test that state modifications in before_agent are visible to after_agent."""
|
||||
collected_states = {}
|
||||
|
||||
@before_agent
|
||||
def modify_in_before(state: AgentState, runtime: Runtime) -> dict[str, Any]:
|
||||
return {"messages": [HumanMessage("Modified by before_agent")]}
|
||||
|
||||
@after_agent
|
||||
def capture_in_after(state: AgentState, runtime: Runtime) -> None:
|
||||
collected_states["messages"] = state["messages"]
|
||||
|
||||
model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")]))
|
||||
|
||||
agent = create_agent(model=model, tools=[], middleware=[modify_in_before, capture_in_after])
|
||||
|
||||
agent.invoke({"messages": [HumanMessage("Original")]})
|
||||
|
||||
message_contents = [msg.content for msg in collected_states["messages"]]
|
||||
assert "Modified by before_agent" in message_contents
|
||||
|
||||
def test_multiple_middleware_instances(self) -> None:
|
||||
"""Test multiple before_agent and after_agent middleware instances."""
|
||||
execution_log = []
|
||||
|
||||
@before_agent
|
||||
def before_one(state: AgentState, runtime: Runtime) -> None:
|
||||
execution_log.append("before_1")
|
||||
|
||||
@before_agent
|
||||
def before_two(state: AgentState, runtime: Runtime) -> None:
|
||||
execution_log.append("before_2")
|
||||
|
||||
@after_agent
|
||||
def after_one(state: AgentState, runtime: Runtime) -> None:
|
||||
execution_log.append("after_1")
|
||||
|
||||
@after_agent
|
||||
def after_two(state: AgentState, runtime: Runtime) -> None:
|
||||
execution_log.append("after_2")
|
||||
|
||||
model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")]))
|
||||
|
||||
agent = create_agent(
|
||||
model=model, tools=[], middleware=[before_one, before_two, after_one, after_two]
|
||||
)
|
||||
|
||||
agent.invoke({"messages": [HumanMessage("Test")]})
|
||||
|
||||
assert "before_1" in execution_log
|
||||
assert "before_2" in execution_log
|
||||
assert "after_1" in execution_log
|
||||
assert "after_2" in execution_log
|
||||
|
||||
def test_agent_hooks_run_once_with_multiple_model_calls(self) -> None:
|
||||
"""Test that before_agent and after_agent run only once even with tool calls."""
|
||||
execution_log = []
|
||||
|
||||
@before_agent
|
||||
def log_before_agent(state: AgentState, runtime: Runtime) -> None:
|
||||
execution_log.append("before_agent")
|
||||
|
||||
@after_agent
|
||||
def log_after_agent(state: AgentState, runtime: Runtime) -> None:
|
||||
execution_log.append("after_agent")
|
||||
|
||||
# Model will call a tool once, then respond with final answer
|
||||
model = FakeToolCallingModel(
|
||||
tool_calls=[
|
||||
[{"name": "sample_tool", "args": {"query": "test"}, "id": "1"}],
|
||||
[], # Second call returns no tool calls (final answer)
|
||||
]
|
||||
)
|
||||
|
||||
agent = create_agent(
|
||||
model=model,
|
||||
tools=[sample_tool],
|
||||
middleware=[log_before_agent, log_after_agent],
|
||||
)
|
||||
|
||||
agent.invoke({"messages": [HumanMessage("Test")]})
|
||||
|
||||
# before_agent and after_agent should run exactly once
|
||||
assert execution_log.count("before_agent") == 1
|
||||
assert execution_log.count("after_agent") == 1
|
||||
# before_agent should run first, after_agent should run last
|
||||
assert execution_log[0] == "before_agent"
|
||||
assert execution_log[-1] == "after_agent"
|
||||
|
||||
|
||||
class TestDecoratorParameters:
|
||||
"""Test decorator parameters for before_agent and after_agent."""
|
||||
|
||||
def test_before_agent_with_custom_name(self) -> None:
|
||||
"""Test before_agent with custom middleware name."""
|
||||
|
||||
@before_agent(name="CustomBeforeAgentMiddleware")
|
||||
def custom_named_before(state: AgentState, runtime: Runtime) -> None:
|
||||
pass
|
||||
|
||||
assert custom_named_before.name == "CustomBeforeAgentMiddleware"
|
||||
|
||||
def test_after_agent_with_custom_name(self) -> None:
|
||||
"""Test after_agent with custom middleware name."""
|
||||
|
||||
@after_agent(name="CustomAfterAgentMiddleware")
|
||||
def custom_named_after(state: AgentState, runtime: Runtime) -> None:
|
||||
pass
|
||||
|
||||
assert custom_named_after.name == "CustomAfterAgentMiddleware"
|
||||
|
||||
def test_before_agent_default_name(self) -> None:
|
||||
"""Test that before_agent uses function name by default."""
|
||||
|
||||
@before_agent
|
||||
def my_before_agent_function(state: AgentState, runtime: Runtime) -> None:
|
||||
pass
|
||||
|
||||
assert my_before_agent_function.name == "my_before_agent_function"
|
||||
|
||||
def test_after_agent_default_name(self) -> None:
|
||||
"""Test that after_agent uses function name by default."""
|
||||
|
||||
@after_agent
|
||||
def my_after_agent_function(state: AgentState, runtime: Runtime) -> None:
|
||||
pass
|
||||
|
||||
assert my_after_agent_function.name == "my_after_agent_function"
|
||||
@@ -8,8 +8,8 @@ from syrupy.assertion import SnapshotAssertion
|
||||
from langchain_core.messages import HumanMessage, AIMessage
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.runtime import Runtime
|
||||
from langgraph.types import Command
|
||||
|
||||
from langchain.agents.middleware_agent import create_agent
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
@@ -20,7 +20,6 @@ from langchain.agents.middleware.types import (
|
||||
modify_model_request,
|
||||
hook_config,
|
||||
)
|
||||
from langchain.agents.middleware_agent import create_agent, _get_can_jump_to
|
||||
from .model import FakeToolCallingModel
|
||||
|
||||
|
||||
@@ -470,49 +469,6 @@ async def test_async_can_jump_to_integration() -> None:
|
||||
assert len(result["messages"]) > 1
|
||||
|
||||
|
||||
def test_get_can_jump_to_no_false_positives() -> None:
|
||||
"""Test that _get_can_jump_to doesn't return false positives for base class methods."""
|
||||
|
||||
# Middleware with no overridden methods should return empty list
|
||||
class EmptyMiddleware(AgentMiddleware):
|
||||
pass
|
||||
|
||||
empty_middleware = EmptyMiddleware()
|
||||
empty_middleware.tools = []
|
||||
|
||||
# Should not return any jump destinations for base class methods
|
||||
assert _get_can_jump_to(empty_middleware, "before_model") == []
|
||||
assert _get_can_jump_to(empty_middleware, "after_model") == []
|
||||
|
||||
|
||||
def test_get_can_jump_to_only_overridden_methods() -> None:
|
||||
"""Test that _get_can_jump_to only checks overridden methods."""
|
||||
|
||||
# Middleware with only sync method overridden
|
||||
class SyncOnlyMiddleware(AgentMiddleware):
|
||||
@hook_config(can_jump_to=["end"])
|
||||
def before_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
|
||||
return None
|
||||
|
||||
sync_middleware = SyncOnlyMiddleware()
|
||||
sync_middleware.tools = []
|
||||
|
||||
# Should return can_jump_to from overridden sync method
|
||||
assert _get_can_jump_to(sync_middleware, "before_model") == ["end"]
|
||||
|
||||
# Middleware with only async method overridden
|
||||
class AsyncOnlyMiddleware(AgentMiddleware):
|
||||
@hook_config(can_jump_to=["model"])
|
||||
async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
|
||||
return None
|
||||
|
||||
async_middleware = AsyncOnlyMiddleware()
|
||||
async_middleware.tools = []
|
||||
|
||||
# Should return can_jump_to from overridden async method
|
||||
assert _get_can_jump_to(async_middleware, "after_model") == ["model"]
|
||||
|
||||
|
||||
def test_async_middleware_with_can_jump_to_graph_snapshot(snapshot: SnapshotAssertion) -> None:
|
||||
"""Test that async middleware with can_jump_to creates correct graph structure with conditional edges."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user