Compare commits

...

6 Commits

Author SHA1 Message Date
Sydney Runkle
5b97dada4d removal 2025-10-05 22:13:05 -04:00
Sydney Runkle
b276d34a0b fix tests 2025-10-05 22:09:35 -04:00
Sydney Runkle
5441475ac9 semi supervised refactor 2025-10-05 22:04:54 -04:00
Sydney Runkle
22fb405b45 feat(langchain_v1): improving tracing (#33283)
Forgo tracing `RunnableCallable` outer calls + ensure we're not showing
structured response on input fields
2025-10-05 21:16:03 -04:00
Sydney Runkle
6e07fc7982 renaming 2025-10-05 17:38:14 -04:00
Sydney Runkle
983d84ade8 first pass at before and after agent 2025-10-05 17:02:28 -04:00
7 changed files with 1395 additions and 322 deletions

View File

@@ -12,7 +12,9 @@ from .types import (
AgentMiddleware,
AgentState,
ModelRequest,
after_agent,
after_model,
before_agent,
before_model,
dynamic_prompt,
hook_config,
@@ -33,7 +35,9 @@ __all__ = [
"PlanningMiddleware",
"SummarizationMiddleware",
"ToolCallLimitMiddleware",
"after_agent",
"after_model",
"before_agent",
"before_model",
"dynamic_prompt",
"hook_config",

View File

@@ -4,6 +4,7 @@ from __future__ import annotations
from collections.abc import Callable
from dataclasses import dataclass, field
from enum import Enum
from inspect import iscoroutinefunction
from typing import (
TYPE_CHECKING,
@@ -41,11 +42,18 @@ __all__ = [
"AgentMiddleware",
"AgentState",
"ContextT",
"HookImplementation",
"MiddlewareHookInfo",
"ModelRequest",
"OmitFromSchema",
"PublicAgentState",
"after_agent",
"after_model",
"before_agent",
"before_model",
"dynamic_prompt",
"hook_config",
"modify_model_request",
]
JumpTo = Literal["tools", "model", "end"]
@@ -54,6 +62,15 @@ JumpTo = Literal["tools", "model", "end"]
ResponseT = TypeVar("ResponseT")
class HookImplementation(str, Enum):
    """Tracks which implementation variants exist for a middleware hook."""

    NONE = "none"  # neither the sync nor the async variant is overridden
    SYNC_ONLY = "sync"  # only the sync variant is overridden
    ASYNC_ONLY = "async"  # only the async variant is overridden
    BOTH = "both"  # both the sync and async variants are overridden
@dataclass
class ModelRequest:
"""Model request information for the agent."""
@@ -93,7 +110,7 @@ class AgentState(TypedDict, Generic[ResponseT]):
messages: Required[Annotated[list[AnyMessage], add_messages]]
jump_to: NotRequired[Annotated[JumpTo | None, EphemeralValue, PrivateStateAttr]]
structured_response: NotRequired[ResponseT]
structured_response: NotRequired[Annotated[ResponseT, OmitFromInput]]
thread_model_call_count: NotRequired[Annotated[int, PrivateStateAttr]]
run_model_call_count: NotRequired[Annotated[int, UntrackedValue, PrivateStateAttr]]
@@ -112,6 +129,56 @@ StateT = TypeVar("StateT", bound=AgentState, default=AgentState)
StateT_contra = TypeVar("StateT_contra", bound=AgentState, contravariant=True)
@dataclass
class MiddlewareHookInfo:
    """Information about a specific middleware hook implementation.

    Encapsulates metadata about how a middleware implements a particular hook:
    the actual hook callables (sync and/or async) plus jump configuration.
    """

    # Name of the middleware that implements this hook.
    middleware_name: str
    # Name of the hook (e.g., 'before_model', 'after_agent').
    hook_name: str
    # Synchronous hook function, or None if not implemented.
    sync_fn: Callable[..., Any] | None
    # Asynchronous hook function, or None if not implemented.
    async_fn: Callable[..., Any] | None
    # Valid jump destinations for this hook.
    can_jump_to: list[JumpTo]

    @property
    def node_name(self) -> str:
        """The graph node name for this hook."""
        return ".".join((self.middleware_name, self.hook_name))

    @property
    def has_sync(self) -> bool:
        """Whether this hook has a sync implementation."""
        return self.sync_fn is not None

    @property
    def has_async(self) -> bool:
        """Whether this hook has an async implementation."""
        return self.async_fn is not None

    @property
    def implementation(self) -> HookImplementation:
        """Which variants (sync/async/both) are implemented."""
        # Dispatch on the (sync, async) presence pair instead of an if-chain.
        return {
            (True, True): HookImplementation.BOTH,
            (True, False): HookImplementation.SYNC_ONLY,
            (False, True): HookImplementation.ASYNC_ONLY,
            (False, False): HookImplementation.NONE,
        }[(self.has_sync, self.has_async)]
class AgentMiddleware(Generic[StateT, ContextT]):
"""Base middleware class for an agent.
@@ -133,6 +200,14 @@ class AgentMiddleware(Generic[StateT, ContextT]):
"""
return self.__class__.__name__
def before_agent(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
    """Logic to run before the agent execution starts.

    Default implementation is a no-op (implicitly returns None).
    Override in a subclass to return state updates to apply before the agent loop.
    """
async def abefore_agent(
    self, state: StateT, runtime: Runtime[ContextT]
) -> dict[str, Any] | None:
    """Async logic to run before the agent execution starts.

    Default implementation is a no-op (implicitly returns None).
    Override in a subclass to return state updates to apply before the agent loop.
    """
def before_model(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
    """Logic to run before the model is called.

    Default implementation is a no-op (implicitly returns None).
    Override in a subclass to return state updates applied before each model call.
    """
@@ -215,6 +290,102 @@ class AgentMiddleware(Generic[StateT, ContextT]):
None, self.retry_model_request, error, request, state, runtime, attempt
)
def after_agent(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
    """Logic to run after the agent execution completes.

    Default implementation is a no-op (implicitly returns None).
    Override in a subclass to return state updates applied once the agent finishes.
    """
async def aafter_agent(
    self, state: StateT, runtime: Runtime[ContextT]
) -> dict[str, Any] | None:
    """Async logic to run after the agent execution completes.

    Default implementation is a no-op (implicitly returns None).
    Override in a subclass to return state updates applied once the agent finishes.
    """
def hook_info(self, hook_name: str) -> MiddlewareHookInfo | None:
    """Get information about this middleware's implementation of a specific hook.

    Args:
        hook_name: The name of the hook to inspect (e.g., 'before_model', 'after_agent').

    Returns:
        MiddlewareHookInfo if the hook is implemented, None otherwise.

    Example:
        >>> middleware = MyMiddleware()
        >>> info = middleware.hook_info("before_model")
        >>> if info:
        ...     print(f"Has sync: {info.has_sync}, Has async: {info.has_async}")
    """
    base_class = AgentMiddleware
    middleware_class = self.__class__

    # Check sync and async variants (async variant is named with an 'a' prefix).
    sync_name = hook_name
    async_name = f"a{hook_name}"

    base_sync_method = getattr(base_class, sync_name, None)
    base_async_method = getattr(base_class, async_name, None)
    middleware_sync_method = getattr(middleware_class, sync_name, None)
    middleware_async_method = getattr(middleware_class, async_name, None)

    # A hook only counts as implemented when the subclass overrides the base method.
    has_custom_sync = middleware_sync_method is not base_sync_method
    has_custom_async = middleware_async_method is not base_async_method

    if not has_custom_sync and not has_custom_async:
        return None

    # Get the actual bound methods - only include customized implementations
    sync_fn = getattr(self, sync_name) if has_custom_sync else None
    async_fn = getattr(self, async_name) if has_custom_async else None

    # Get can_jump_to from whichever overridden variant declares it. Fall back to
    # the async variant when the sync variant is overridden but carries no
    # __can_jump_to__ metadata (an 'elif' here would silently drop metadata
    # declared only on the async variant when both variants are overridden).
    can_jump_to: list[JumpTo] = []
    if has_custom_sync:
        can_jump_to = getattr(middleware_sync_method, "__can_jump_to__", [])
    if not can_jump_to and has_custom_async:
        can_jump_to = getattr(middleware_async_method, "__can_jump_to__", [])

    return MiddlewareHookInfo(
        middleware_name=self.name,
        hook_name=hook_name,
        sync_fn=sync_fn,
        async_fn=async_fn,
        can_jump_to=can_jump_to,
    )
def all_hook_info(self) -> dict[str, MiddlewareHookInfo]:
    """Get information about all hooks implemented by this middleware.

    Returns:
        Dictionary mapping hook names to their MiddlewareHookInfo, in the
        canonical hook order.

    Example:
        >>> middleware = MyMiddleware()
        >>> for hook_name, info in middleware.all_hook_info().items():
        ...     print(f"{hook_name}: sync={info.has_sync}, async={info.has_async}")
    """
    known_hooks = (
        "before_agent",
        "before_model",
        "modify_model_request",
        "after_model",
        "after_agent",
        "retry_model_request",
    )
    implemented: dict[str, MiddlewareHookInfo] = {}
    for hook in known_hooks:
        info = self.hook_info(hook)
        if info is not None:
            implemented[hook] = info
    return implemented
@property
def implemented_hooks(self) -> list[str]:
    """List of hook names this middleware implements.

    Returns:
        List of hook names that are overridden from the base class.

    Example:
        >>> middleware = MyMiddleware()
        >>> print(middleware.implemented_hooks)
        ['before_model', 'after_model']
    """
    # Iterating a dict yields its keys, preserving insertion order.
    return [hook for hook in self.all_hook_info()]
class _CallableWithStateAndRuntime(Protocol[StateT_contra, ContextT]):
"""Callable with AgentState and Runtime as arguments."""
@@ -707,6 +878,279 @@ def after_model(
return decorator
@overload
def before_agent(
    func: _CallableWithStateAndRuntime[StateT, ContextT],
) -> AgentMiddleware[StateT, ContextT]: ...


@overload
def before_agent(
    func: None = None,
    *,
    state_schema: type[StateT] | None = None,
    tools: list[BaseTool] | None = None,
    can_jump_to: list[JumpTo] | None = None,
    name: str | None = None,
) -> Callable[
    [_CallableWithStateAndRuntime[StateT, ContextT]], AgentMiddleware[StateT, ContextT]
]: ...


def before_agent(
    func: _CallableWithStateAndRuntime[StateT, ContextT] | None = None,
    *,
    state_schema: type[StateT] | None = None,
    tools: list[BaseTool] | None = None,
    can_jump_to: list[JumpTo] | None = None,
    name: str | None = None,
) -> (
    Callable[[_CallableWithStateAndRuntime[StateT, ContextT]], AgentMiddleware[StateT, ContextT]]
    | AgentMiddleware[StateT, ContextT]
):
    """Decorator used to dynamically create a middleware with the before_agent hook.

    Supports both sync and async functions; async functions are installed as the
    ``abefore_agent`` hook of the generated middleware.

    Args:
        func: The function to be decorated. Must accept:
            `state: StateT, runtime: Runtime[ContextT]` - State and runtime context
        state_schema: Optional custom state schema type. If not provided, uses the default
            AgentState schema.
        tools: Optional list of additional tools to register with this middleware.
        can_jump_to: Optional list of valid jump destinations for conditional edges.
            Valid values are: "tools", "model", "end"
        name: Optional name for the generated middleware class. If not provided,
            uses the decorated function's name.

    Returns:
        Either an AgentMiddleware instance (if func is provided directly) or a decorator
        function that can be applied to the function it wraps.

    The decorated function should return:
        - `dict[str, Any]` - State updates to merge into the agent state
        - `Command` - A command to control flow (e.g., jump to different node)
        - `None` - No state updates or flow control

    Examples:
        Basic usage:
        ```python
        @before_agent
        def log_before_agent(state: AgentState, runtime: Runtime) -> None:
            print(f"Starting agent with {len(state['messages'])} messages")
        ```

        With conditional jumping:
        ```python
        @before_agent(can_jump_to=["end"])
        def conditional_before_agent(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
            if some_condition(state):
                return {"jump_to": "end"}
            return None
        ```

        With custom state schema:
        ```python
        @before_agent(state_schema=MyCustomState)
        def custom_before_agent(state: MyCustomState, runtime: Runtime) -> dict[str, Any]:
            return {"custom_field": "initialized_value"}
        ```
    """

    def decorator(
        func: _CallableWithStateAndRuntime[StateT, ContextT],
    ) -> AgentMiddleware[StateT, ContextT]:
        is_async = iscoroutinefunction(func)
        # Explicit can_jump_to takes precedence over metadata already on the function.
        func_can_jump_to = (
            can_jump_to if can_jump_to is not None else getattr(func, "__can_jump_to__", [])
        )

        if is_async:

            async def async_wrapped(
                self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
                state: StateT,
                runtime: Runtime[ContextT],
            ) -> dict[str, Any] | Command | None:
                return await func(state, runtime)  # type: ignore[misc]

            # Preserve can_jump_to metadata on the wrapped function
            if func_can_jump_to:
                async_wrapped.__can_jump_to__ = func_can_jump_to  # type: ignore[attr-defined]

            middleware_name = name or cast(
                "str", getattr(func, "__name__", "BeforeAgentMiddleware")
            )

            # Build a one-off AgentMiddleware subclass carrying the async hook.
            return type(
                middleware_name,
                (AgentMiddleware,),
                {
                    "state_schema": state_schema or AgentState,
                    "tools": tools or [],
                    "abefore_agent": async_wrapped,
                },
            )()

        def wrapped(
            self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
            state: StateT,
            runtime: Runtime[ContextT],
        ) -> dict[str, Any] | Command | None:
            return func(state, runtime)  # type: ignore[return-value]

        # Preserve can_jump_to metadata on the wrapped function
        if func_can_jump_to:
            wrapped.__can_jump_to__ = func_can_jump_to  # type: ignore[attr-defined]

        # Use function name as default if no name provided
        middleware_name = name or cast("str", getattr(func, "__name__", "BeforeAgentMiddleware"))

        return type(
            middleware_name,
            (AgentMiddleware,),
            {
                "state_schema": state_schema or AgentState,
                "tools": tools or [],
                "before_agent": wrapped,
            },
        )()

    # Bare-decorator form: @before_agent without parentheses.
    if func is not None:
        return decorator(func)
    return decorator
@overload
def after_agent(
    func: _CallableWithStateAndRuntime[StateT, ContextT],
) -> AgentMiddleware[StateT, ContextT]: ...


@overload
def after_agent(
    func: None = None,
    *,
    state_schema: type[StateT] | None = None,
    tools: list[BaseTool] | None = None,
    can_jump_to: list[JumpTo] | None = None,
    name: str | None = None,
) -> Callable[
    [_CallableWithStateAndRuntime[StateT, ContextT]], AgentMiddleware[StateT, ContextT]
]: ...


def after_agent(
    func: _CallableWithStateAndRuntime[StateT, ContextT] | None = None,
    *,
    state_schema: type[StateT] | None = None,
    tools: list[BaseTool] | None = None,
    can_jump_to: list[JumpTo] | None = None,
    name: str | None = None,
) -> (
    Callable[[_CallableWithStateAndRuntime[StateT, ContextT]], AgentMiddleware[StateT, ContextT]]
    | AgentMiddleware[StateT, ContextT]
):
    """Decorator used to dynamically create a middleware with the after_agent hook.

    Supports both sync and async functions; async functions are installed as the
    ``aafter_agent`` hook of the generated middleware.

    Args:
        func: The function to be decorated. Must accept:
            `state: StateT, runtime: Runtime[ContextT]` - State and runtime context
        state_schema: Optional custom state schema type. If not provided, uses the default
            AgentState schema.
        tools: Optional list of additional tools to register with this middleware.
        can_jump_to: Optional list of valid jump destinations for conditional edges.
            Valid values are: "tools", "model", "end"
        name: Optional name for the generated middleware class. If not provided,
            uses the decorated function's name.

    Returns:
        Either an AgentMiddleware instance (if func is provided) or a decorator function
        that can be applied to a function.

    The decorated function should return:
        - `dict[str, Any]` - State updates to merge into the agent state
        - `Command` - A command to control flow (e.g., jump to different node)
        - `None` - No state updates or flow control

    Examples:
        Basic usage for logging agent completion:
        ```python
        @after_agent
        def log_completion(state: AgentState, runtime: Runtime) -> None:
            print(f"Agent completed with {len(state['messages'])} messages")
        ```

        With custom state schema:
        ```python
        @after_agent(state_schema=MyCustomState, name="MyAfterAgentMiddleware")
        def custom_after_agent(state: MyCustomState, runtime: Runtime) -> dict[str, Any]:
            return {"custom_field": "finalized_value"}
        ```
    """

    def decorator(
        func: _CallableWithStateAndRuntime[StateT, ContextT],
    ) -> AgentMiddleware[StateT, ContextT]:
        is_async = iscoroutinefunction(func)
        # Extract can_jump_to from decorator parameter or from function metadata
        func_can_jump_to = (
            can_jump_to if can_jump_to is not None else getattr(func, "__can_jump_to__", [])
        )

        if is_async:

            async def async_wrapped(
                self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
                state: StateT,
                runtime: Runtime[ContextT],
            ) -> dict[str, Any] | Command | None:
                return await func(state, runtime)  # type: ignore[misc]

            # Preserve can_jump_to metadata on the wrapped function
            if func_can_jump_to:
                async_wrapped.__can_jump_to__ = func_can_jump_to  # type: ignore[attr-defined]

            middleware_name = name or cast("str", getattr(func, "__name__", "AfterAgentMiddleware"))

            # Build a one-off AgentMiddleware subclass carrying the async hook.
            return type(
                middleware_name,
                (AgentMiddleware,),
                {
                    "state_schema": state_schema or AgentState,
                    "tools": tools or [],
                    "aafter_agent": async_wrapped,
                },
            )()

        def wrapped(
            self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
            state: StateT,
            runtime: Runtime[ContextT],
        ) -> dict[str, Any] | Command | None:
            return func(state, runtime)  # type: ignore[return-value]

        # Preserve can_jump_to metadata on the wrapped function
        if func_can_jump_to:
            wrapped.__can_jump_to__ = func_can_jump_to  # type: ignore[attr-defined]

        # Use function name as default if no name provided
        middleware_name = name or cast("str", getattr(func, "__name__", "AfterAgentMiddleware"))

        return type(
            middleware_name,
            (AgentMiddleware,),
            {
                "state_schema": state_schema or AgentState,
                "tools": tools or [],
                "after_agent": wrapped,
            },
        )()

    # Bare-decorator form: @after_agent without parentheses.
    if func is not None:
        return decorator(func)
    return decorator
@overload
def dynamic_prompt(
func: _CallableReturningPromptString[StateT, ContextT],

View File

@@ -1,13 +1,14 @@
"""Middleware agent implementation."""
import itertools
from collections.abc import Callable, Sequence
from typing import Annotated, Any, cast, get_args, get_origin, get_type_hints
from dataclasses import dataclass
from typing import Annotated, Any, Generic, cast, get_args, get_origin, get_type_hints
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, AnyMessage, SystemMessage, ToolMessage
from langchain_core.runnables import Runnable
from langchain_core.runnables import Runnable, run_in_executor
from langchain_core.tools import BaseTool
from langgraph._internal._runnable import RunnableCallable
from langgraph.constants import END, START
from langgraph.graph.state import StateGraph
from langgraph.runtime import Runtime
@@ -19,6 +20,7 @@ from langchain.agents.middleware.types import (
AgentMiddleware,
AgentState,
JumpTo,
MiddlewareHookInfo,
ModelRequest,
OmitFromSchema,
PublicAgentState,
@@ -41,6 +43,167 @@ STRUCTURED_OUTPUT_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes."
ResponseT = TypeVar("ResponseT")
# ============================================================================
# Data Structures for Agent Graph Construction
# ============================================================================
@dataclass
class MiddlewareHooks:
    """Middleware hooks categorized by type for graph construction."""

    # Hooks that run once before the agent starts.
    before_agent: list[MiddlewareHookInfo]
    # Hooks that run before each model call in the agent loop.
    before_model: list[MiddlewareHookInfo]
    # Hooks that modify the model request before calling the model.
    modify_model_request: list[MiddlewareHookInfo]
    # Hooks that run after each model call in the agent loop.
    after_model: list[MiddlewareHookInfo]
    # Hooks that run once after the agent completes.
    after_agent: list[MiddlewareHookInfo]
    # Hooks that handle model invocation errors and optionally retry.
    retry: list[MiddlewareHookInfo]

    @classmethod
    def from_middleware_list(
        cls,
        middleware: Sequence[AgentMiddleware[AgentState[ResponseT], ContextT]],
    ) -> "MiddlewareHooks":
        """Extract and categorize all hooks from middleware instances.

        Args:
            middleware: Sequence of middleware instances to analyze.

        Returns:
            MiddlewareHooks with all hooks organized by type.
        """
        # (hook method name, target attribute) pairs; only retry_model_request
        # maps to a differently-named attribute.
        categories = (
            ("before_agent", "before_agent"),
            ("before_model", "before_model"),
            ("modify_model_request", "modify_model_request"),
            ("after_model", "after_model"),
            ("after_agent", "after_agent"),
            ("retry_model_request", "retry"),
        )
        hooks = cls(
            before_agent=[],
            before_model=[],
            modify_model_request=[],
            after_model=[],
            after_agent=[],
            retry=[],
        )
        # Outer loop over middleware preserves registration order per category.
        for mw in middleware:
            for hook_name, attr in categories:
                info = mw.hook_info(hook_name)
                if info:
                    getattr(hooks, attr).append(info)
        return hooks
@dataclass
class AgentComponents:
    """Core components and configuration for agent construction.

    Produced once during agent setup and threaded through the node- and
    edge-building helpers so they all see the same validated configuration.
    """

    model: BaseChatModel
    """The language model to use for the agent."""

    tool_node: ToolNode | None
    """The tool execution node, or None if no tools are available."""

    middleware_hooks: MiddlewareHooks
    """Middleware hooks organized by type."""

    structured_output_tools: dict[str, OutputToolBinding]
    """Tools used for structured output parsing."""

    default_tools: list[BaseTool | dict]
    """Default tools available to the agent (regular tools + middleware tools + built-ins)."""

    initial_response_format: ResponseFormat | None
    """The initial response format configuration."""

    system_prompt: str | None
    """The system prompt for the agent."""
@dataclass
class GraphTopology:
    """Key nodes in the graph topology defining the execution flow.

    The agent graph has the following structure:

        START -> entry_node -> [loop: loop_entry_node -> model -> loop_exit_node -> tools]
              -> exit_node -> END

    - entry_node: Runs once at the start (before_agent hooks)
    - loop_entry_node: Beginning of agent loop (before_model hooks)
    - loop_exit_node: End of each loop iteration (after_model hooks)
    - exit_node: Runs once at the end (after_agent hooks) or END
    """

    # The first node executed (START -> entry_node).
    entry_node: str
    # Where the agent loop begins (where tools loop back to).
    loop_entry_node: str
    # The last node in each loop iteration.
    loop_exit_node: str
    # The final node before END (or END itself).
    exit_node: str

    @classmethod
    def compute(cls, hooks: MiddlewareHooks) -> "GraphTopology":
        """Compute graph topology from middleware hook configuration.

        Args:
            hooks: The categorized middleware hooks.

        Returns:
            GraphTopology describing the flow through the graph.
        """
        # Loop boundaries first; both default to the model node when no hooks exist.
        loop_entry = hooks.before_model[0].node_name if hooks.before_model else "model_request"
        loop_exit = hooks.after_model[0].node_name if hooks.after_model else "model_request"
        # Entry runs once at start: before_agent if present, else fall through to
        # the loop entry (before_model -> model_request).
        entry = hooks.before_agent[0].node_name if hooks.before_agent else loop_entry
        # Exit runs once at end: last after_agent hook, or END directly.
        final = hooks.after_agent[-1].node_name if hooks.after_agent else END
        return cls(
            entry_node=entry,
            loop_entry_node=loop_entry,
            loop_exit_node=loop_exit,
            exit_node=final,
        )
def _resolve_schema(schemas: set[type], schema_name: str, omit_flag: str | None = None) -> type:
"""Resolve schema by merging schemas and optionally respecting OmitFromSchema annotations.
@@ -86,39 +249,112 @@ def _extract_metadata(type_: type) -> list:
return []
def _get_can_jump_to(middleware: AgentMiddleware[Any, Any], hook_name: str) -> list[JumpTo]:
"""Get the can_jump_to list from either sync or async hook methods.
# ============================================================================
# Setup and Initialization Functions
# ============================================================================
def _setup_components(
model: str | BaseChatModel,
tools: Sequence[BaseTool | Callable | dict[str, Any]] | ToolNode | None,
middleware: Sequence[AgentMiddleware[AgentState[ResponseT], ContextT]],
response_format: ResponseFormat[ResponseT] | type[ResponseT] | None,
system_prompt: str | None,
) -> AgentComponents:
"""Setup and validate agent components.
Args:
middleware: The middleware instance to inspect.
hook_name: The name of the hook ('before_model' or 'after_model').
model: Model name or instance.
tools: Tools for the agent.
middleware: Middleware instances.
response_format: Response format configuration.
system_prompt: System prompt for the agent.
Returns:
List of jump destinations, or empty list if not configured.
AgentComponents with all components configured and validated.
"""
# Get the base class method for comparison
base_sync_method = getattr(AgentMiddleware, hook_name, None)
base_async_method = getattr(AgentMiddleware, f"a{hook_name}", None)
# Initialize chat model
if isinstance(model, str):
model = init_chat_model(model)
# Try sync method first - only if it's overridden from base class
sync_method = getattr(middleware.__class__, hook_name, None)
if (
sync_method
and sync_method is not base_sync_method
and hasattr(sync_method, "__can_jump_to__")
):
return sync_method.__can_jump_to__
# Handle tools being None or empty
if tools is None:
tools = []
# Try async method - only if it's overridden from base class
async_method = getattr(middleware.__class__, f"a{hook_name}", None)
if (
async_method
and async_method is not base_async_method
and hasattr(async_method, "__can_jump_to__")
):
return async_method.__can_jump_to__
# Convert response format and setup structured output tools
initial_response_format: ToolStrategy | ProviderStrategy | AutoStrategy | None
if response_format is None:
initial_response_format = None
elif isinstance(response_format, (ToolStrategy, ProviderStrategy, AutoStrategy)):
initial_response_format = response_format
else:
# Raw schema - wrap in AutoStrategy to enable auto-detection
initial_response_format = AutoStrategy(schema=response_format)
return []
# For AutoStrategy, convert to ToolStrategy to setup tools upfront
tool_strategy_for_setup: ToolStrategy | None = None
if isinstance(initial_response_format, AutoStrategy):
tool_strategy_for_setup = ToolStrategy(schema=initial_response_format.schema)
elif isinstance(initial_response_format, ToolStrategy):
tool_strategy_for_setup = initial_response_format
structured_output_tools: dict[str, OutputToolBinding] = {}
if tool_strategy_for_setup:
for response_schema in tool_strategy_for_setup.schema_specs:
structured_tool_info = OutputToolBinding.from_schema_spec(response_schema)
structured_output_tools[structured_tool_info.tool.name] = structured_tool_info
middleware_tools = [t for m in middleware for t in getattr(m, "tools", [])]
# Setup tools
tool_node: ToolNode | None = None
default_tools: list[BaseTool | dict[str, Any]]
if isinstance(tools, list):
# Extract built-in provider tools (dict format) and regular tools (BaseTool)
built_in_tools = [t for t in tools if isinstance(t, dict)]
regular_tools = [t for t in tools if not isinstance(t, dict)]
# Tools that require client-side execution
available_tools = middleware_tools + regular_tools
# Only create ToolNode if we have client-side tools
tool_node = ToolNode(tools=available_tools) if available_tools else None
# Default tools for ModelRequest initialization
default_tools = regular_tools + middleware_tools + built_in_tools
elif isinstance(tools, ToolNode):
tool_node = tools
if tool_node:
# Add middleware tools to existing ToolNode
available_tools = list(tool_node.tools_by_name.values()) + middleware_tools
tool_node = ToolNode(available_tools)
# default_tools includes all client-side tools
default_tools = available_tools
else:
default_tools = middleware_tools
else:
# No tools provided, only middleware_tools available
default_tools = middleware_tools
# Validate middleware
assert len({m.name for m in middleware}) == len(middleware), ( # noqa: S101
"Please remove duplicate middleware instances."
)
# Categorize middleware by hooks
middleware_hooks = MiddlewareHooks.from_middleware_list(middleware)
return AgentComponents(
model=model,
tool_node=tool_node,
middleware_hooks=middleware_hooks,
structured_output_tools=structured_output_tools,
default_tools=default_tools,
initial_response_format=initial_response_format,
system_prompt=system_prompt,
)
def _supports_provider_strategy(model: str | BaseChatModel) -> bool:
@@ -144,6 +380,58 @@ def _supports_provider_strategy(model: str | BaseChatModel) -> bool:
)
# ============================================================================
# Node Building Functions
# ============================================================================
def _create_hook_node(hook_info: MiddlewareHookInfo) -> RunnableCallable:
    """Create a graph node for a middleware hook.

    Args:
        hook_info: Information about the hook to create a node for.

    Returns:
        RunnableCallable that supports both sync and async execution.
        Either fn may be None when the middleware implements only one variant.
    """
    # trace=False: the hook functions are traced themselves; the outer
    # RunnableCallable wrapper is deliberately excluded from tracing.
    return RunnableCallable(hook_info.sync_fn, hook_info.async_fn, trace=False)
def _add_middleware_nodes(
    graph: StateGraph[AgentState, ContextT, PublicAgentState, PublicAgentState],
    components: AgentComponents,
    state_schema: type,
) -> None:
    """Add all middleware hook nodes to the graph.

    Each before_agent, before_model, after_model, and after_agent hook becomes
    one graph node named ``<middleware>.<hook>``. The modify_model_request and
    retry hooks are not node-backed and are deliberately not added here.

    Args:
        graph: The state graph to add nodes to.
        components: Agent components with middleware hooks.
        state_schema: The state schema for input validation.
    """
    hooks = components.middleware_hooks
    # All four node-backed hook categories get identical treatment, so iterate
    # them in one pass instead of four copy-pasted loops.
    node_backed = itertools.chain(
        hooks.before_agent,
        hooks.before_model,
        hooks.after_model,
        hooks.after_agent,
    )
    for hook_info in node_backed:
        node = _create_hook_node(hook_info)
        graph.add_node(hook_info.node_name, node, input_schema=state_schema)
def _handle_structured_output_error(
exception: Exception,
response_format: ResponseFormat,
@@ -174,6 +462,157 @@ def _handle_structured_output_error(
return False, ""
# ============================================================================
# Edge Building Functions
# ============================================================================
def _connect_entry_edges(
    graph: StateGraph[AgentState, ContextT, PublicAgentState, PublicAgentState],
    topology: GraphTopology,
) -> None:
    """Connect the entry edge from START to the entry node.

    Args:
        graph: The state graph to add edges to.
        topology: Graph topology configuration (entry_node is the first
            before_agent hook, first before_model hook, or "model_request").
    """
    graph.add_edge(START, topology.entry_node)
def _connect_loop_edges(
    graph: StateGraph[AgentState, ContextT, PublicAgentState, PublicAgentState],
    topology: GraphTopology,
    components: AgentComponents,
) -> None:
    """Connect conditional edges for the agent loop (tools <-> model).

    Args:
        graph: The state graph to add edges to.
        topology: Graph topology configuration.
        components: Agent components with tool configuration.
    """
    tool_node = components.tool_node
    structured_output_tools = components.structured_output_tools

    if tool_node is None:
        # No tools - connect loop_exit directly to exit_node
        if topology.loop_exit_node == "model_request":
            graph.add_edge(topology.loop_exit_node, topology.exit_node)
        else:
            # We have after_model but no tools: loop_exit_node is the first
            # after_model hook, so honor its can_jump_to configuration.
            _add_middleware_edge(
                graph,
                topology.loop_exit_node,
                topology.exit_node,
                topology.loop_entry_node,
                can_jump_to=components.middleware_hooks.after_model[0].can_jump_to,
            )
        return

    # Add conditional edge from tools back to model or exit
    graph.add_conditional_edges(
        "tools",
        _make_tools_to_model_edge(
            tool_node, topology.loop_entry_node, structured_output_tools, topology.exit_node
        ),
        [topology.loop_entry_node, topology.exit_node],
    )

    # Add conditional edge from model to tools or exit
    graph.add_conditional_edges(
        topology.loop_exit_node,
        _make_model_to_tools_edge(
            topology.loop_entry_node, structured_output_tools, tool_node, topology.exit_node
        ),
        [topology.loop_entry_node, "tools", topology.exit_node],
    )
def _connect_middleware_chains(
    graph: StateGraph[AgentState, ContextT, PublicAgentState, PublicAgentState],
    components: AgentComponents,
    topology: GraphTopology,
) -> None:
    """Connect middleware hooks in chains.

    before_* hooks run in registration order; after_* hooks are wired in
    reverse so the first-registered middleware's after hook runs last
    (onion-style ordering).

    Args:
        graph: The state graph to add edges to.
        components: Agent components with middleware hooks.
        topology: Graph topology configuration.
    """
    hooks = components.middleware_hooks

    # Connect before_agent chain
    if hooks.before_agent:
        for i in range(len(hooks.before_agent) - 1):
            _add_middleware_edge(
                graph,
                hooks.before_agent[i].node_name,
                hooks.before_agent[i + 1].node_name,
                topology.loop_entry_node,
                can_jump_to=hooks.before_agent[i].can_jump_to,
            )
        # Connect last before_agent to loop_entry_node
        _add_middleware_edge(
            graph,
            hooks.before_agent[-1].node_name,
            topology.loop_entry_node,
            topology.loop_entry_node,
            can_jump_to=hooks.before_agent[-1].can_jump_to,
        )

    # Connect before_model chain
    if hooks.before_model:
        for i in range(len(hooks.before_model) - 1):
            _add_middleware_edge(
                graph,
                hooks.before_model[i].node_name,
                hooks.before_model[i + 1].node_name,
                topology.loop_entry_node,
                can_jump_to=hooks.before_model[i].can_jump_to,
            )
        # Connect last before_model to model_request
        _add_middleware_edge(
            graph,
            hooks.before_model[-1].node_name,
            "model_request",
            topology.loop_entry_node,
            can_jump_to=hooks.before_model[-1].can_jump_to,
        )

    # Connect after_model chain (reverse order):
    # model_request -> after_model[-1] -> ... -> after_model[0]
    if hooks.after_model:
        graph.add_edge("model_request", hooks.after_model[-1].node_name)
        for i in range(len(hooks.after_model) - 1, 0, -1):
            _add_middleware_edge(
                graph,
                hooks.after_model[i].node_name,
                hooks.after_model[i - 1].node_name,
                topology.loop_entry_node,
                can_jump_to=hooks.after_model[i].can_jump_to,
            )

    # Connect after_agent chain (reverse order):
    # exit_node (= after_agent[-1]) -> ... -> after_agent[0] -> END
    if hooks.after_agent:
        for i in range(len(hooks.after_agent) - 1, 0, -1):
            _add_middleware_edge(
                graph,
                hooks.after_agent[i].node_name,
                hooks.after_agent[i - 1].node_name,
                topology.loop_entry_node,
                can_jump_to=hooks.after_agent[i].can_jump_to,
            )
        # Connect first after_agent to END
        _add_middleware_edge(
            graph,
            hooks.after_agent[0].node_name,
            END,
            topology.loop_entry_node,
            can_jump_to=hooks.after_agent[0].can_jump_to,
        )
def create_agent( # noqa: PLR0915
*,
model: str | BaseChatModel,
@@ -185,106 +624,23 @@ def create_agent( # noqa: PLR0915
) -> StateGraph[
AgentState[ResponseT], ContextT, PublicAgentState[ResponseT], PublicAgentState[ResponseT]
]:
"""Create a middleware agent graph."""
# init chat model
if isinstance(model, str):
model = init_chat_model(model)
"""Create a middleware agent graph.
# Handle tools being None or empty
if tools is None:
tools = []
Args:
model: Model name or BaseChatModel instance.
tools: Tools for the agent to use.
system_prompt: System prompt for the agent.
middleware: Middleware instances to customize agent behavior.
response_format: Response format configuration for structured outputs.
context_schema: Context schema for the graph runtime.
# Convert response format and setup structured output tools
# Raw schemas are wrapped in AutoStrategy to preserve auto-detection intent.
# AutoStrategy is converted to ToolStrategy upfront to calculate tools during agent creation,
# but may be replaced with ProviderStrategy later based on model capabilities.
initial_response_format: ToolStrategy | ProviderStrategy | AutoStrategy | None
if response_format is None:
initial_response_format = None
elif isinstance(response_format, (ToolStrategy, ProviderStrategy)):
# Preserve explicitly requested strategies
initial_response_format = response_format
elif isinstance(response_format, AutoStrategy):
# AutoStrategy provided - preserve it for later auto-detection
initial_response_format = response_format
else:
# Raw schema - wrap in AutoStrategy to enable auto-detection
initial_response_format = AutoStrategy(schema=response_format)
# For AutoStrategy, convert to ToolStrategy to setup tools upfront
# (may be replaced with ProviderStrategy later based on model)
tool_strategy_for_setup: ToolStrategy | None = None
if isinstance(initial_response_format, AutoStrategy):
tool_strategy_for_setup = ToolStrategy(schema=initial_response_format.schema)
elif isinstance(initial_response_format, ToolStrategy):
tool_strategy_for_setup = initial_response_format
structured_output_tools: dict[str, OutputToolBinding] = {}
if tool_strategy_for_setup:
for response_schema in tool_strategy_for_setup.schema_specs:
structured_tool_info = OutputToolBinding.from_schema_spec(response_schema)
structured_output_tools[structured_tool_info.tool.name] = structured_tool_info
middleware_tools = [t for m in middleware for t in getattr(m, "tools", [])]
# Setup tools
tool_node: ToolNode | None = None
if isinstance(tools, list):
# Extract built-in provider tools (dict format) and regular tools (BaseTool)
built_in_tools = [t for t in tools if isinstance(t, dict)]
regular_tools = [t for t in tools if not isinstance(t, dict)]
# Tools that require client-side execution (must be in ToolNode)
available_tools = middleware_tools + regular_tools
# Only create ToolNode if we have client-side tools
tool_node = ToolNode(tools=available_tools) if available_tools else None
# Default tools for ModelRequest initialization
# Include built-ins and regular tools (can be changed dynamically by middleware)
# Structured tools are NOT included - they're added dynamically based on response_format
default_tools = regular_tools + middleware_tools + built_in_tools
elif isinstance(tools, ToolNode):
tool_node = tools
if tool_node:
# Add middleware tools to existing ToolNode
available_tools = list(tool_node.tools_by_name.values()) + middleware_tools
tool_node = ToolNode(available_tools)
# default_tools includes all client-side tools (no built-ins or structured tools)
default_tools = available_tools
else:
# No tools provided, only middleware_tools available
default_tools = middleware_tools
# validate middleware
assert len({m.name for m in middleware}) == len(middleware), ( # noqa: S101
"Please remove duplicate middleware instances."
)
middleware_w_before = [
m
for m in middleware
if m.__class__.before_model is not AgentMiddleware.before_model
or m.__class__.abefore_model is not AgentMiddleware.abefore_model
]
middleware_w_modify_model_request = [
m
for m in middleware
if m.__class__.modify_model_request is not AgentMiddleware.modify_model_request
or m.__class__.amodify_model_request is not AgentMiddleware.amodify_model_request
]
middleware_w_after = [
m
for m in middleware
if m.__class__.after_model is not AgentMiddleware.after_model
or m.__class__.aafter_model is not AgentMiddleware.aafter_model
]
middleware_w_retry = [
m
for m in middleware
if m.__class__.retry_model_request is not AgentMiddleware.retry_model_request
or m.__class__.aretry_model_request is not AgentMiddleware.aretry_model_request
]
Returns:
StateGraph configured with all nodes and edges.
"""
# Phase 1: Setup and validate components
components = _setup_components(model, tools, middleware, response_format, system_prompt)
# Phase 2: Create schemas
state_schemas = {m.state_schema for m in middleware}
state_schemas.add(AgentState)
@@ -292,7 +648,7 @@ def create_agent( # noqa: PLR0915
input_schema = _resolve_schema(state_schemas, "InputSchema", "input")
output_schema = _resolve_schema(state_schemas, "OutputSchema", "output")
# create graph, add nodes
# Phase 3: Create graph
graph: StateGraph[
AgentState[ResponseT], ContextT, PublicAgentState[ResponseT], PublicAgentState[ResponseT]
] = StateGraph(
@@ -302,6 +658,15 @@ def create_agent( # noqa: PLR0915
context_schema=context_schema,
)
# Phase 4: Define model request handlers (need access to components via closure)
# These are inner functions because they need access to components
structured_output_tools = components.structured_output_tools
default_tools = components.default_tools
initial_response_format = components.initial_response_format
model_instance = components.model
middleware_w_modify_model_request = components.middleware_hooks.modify_model_request
middleware_w_retry = components.middleware_hooks.retry
def _handle_model_output(
output: AIMessage, effective_response_format: ResponseFormat | None
) -> dict[str, Any]:
@@ -511,29 +876,28 @@ def create_agent( # noqa: PLR0915
def model_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
"""Sync model request handler with sequential middleware processing."""
request = ModelRequest(
model=model,
model=model_instance,
tools=default_tools,
system_prompt=system_prompt,
system_prompt=components.system_prompt,
response_format=initial_response_format,
messages=state["messages"],
tool_choice=None,
)
# Apply modify_model_request middleware in sequence
for m in middleware_w_modify_model_request:
if m.__class__.modify_model_request is not AgentMiddleware.modify_model_request:
m.modify_model_request(request, state, runtime)
for hook_info in middleware_w_modify_model_request:
if hook_info.sync_fn:
hook_info.sync_fn(request, state, runtime)
else:
msg = (
f"No synchronous function provided for "
f'{m.__class__.__name__}.amodify_model_request".'
f"{hook_info.middleware_name}.amodify_model_request"
"\nEither initialize with a synchronous function or invoke"
" via the async API (ainvoke, astream, etc.)"
)
raise TypeError(msg)
# Retry loop for model invocation with error handling
# Hard limit of 100 attempts to prevent infinite loops from buggy middleware
max_attempts = 100
for attempt in range(1, max_attempts + 1):
try:
@@ -551,18 +915,17 @@ def create_agent( # noqa: PLR0915
}
except Exception as error:
# Try retry_model_request on each middleware
for m in middleware_w_retry:
if m.__class__.retry_model_request is not AgentMiddleware.retry_model_request:
if retry_request := m.retry_model_request(
for hook_info in middleware_w_retry:
if hook_info.sync_fn:
if retry_request := hook_info.sync_fn(
error, request, state, runtime, attempt
):
# Break on first middleware that wants to retry
request = retry_request
break
else:
msg = (
f"No synchronous function provided for "
f'{m.__class__.__name__}.aretry_model_request".'
f"{hook_info.middleware_name}.aretry_model_request"
"\nEither initialize with a synchronous function or invoke"
" via the async API (ainvoke, astream, etc.)"
)
@@ -577,20 +940,26 @@ def create_agent( # noqa: PLR0915
async def amodel_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
"""Async model request handler with sequential middleware processing."""
request = ModelRequest(
model=model,
model=model_instance,
tools=default_tools,
system_prompt=system_prompt,
system_prompt=components.system_prompt,
response_format=initial_response_format,
messages=state["messages"],
tool_choice=None,
)
# Apply modify_model_request middleware in sequence
for m in middleware_w_modify_model_request:
await m.amodify_model_request(request, state, runtime)
for hook_info in middleware_w_modify_model_request:
if hook_info.async_fn:
await hook_info.async_fn(request, state, runtime)
elif hook_info.sync_fn:
# Fallback to sync if only sync is implemented
await run_in_executor(None, hook_info.sync_fn, request, state, runtime)
else:
msg = f"No function provided for {hook_info.middleware_name}.modify_model_request"
raise RuntimeError(msg)
# Retry loop for model invocation with error handling
# Hard limit of 100 attempts to prevent infinite loops from buggy middleware
max_attempts = 100
for attempt in range(1, max_attempts + 1):
try:
@@ -608,135 +977,41 @@ def create_agent( # noqa: PLR0915
}
except Exception as error:
# Try retry_model_request on each middleware
for m in middleware_w_retry:
if retry_request := await m.aretry_model_request(
error, request, state, runtime, attempt
):
# Break on first middleware that wants to retry
for hook_info in middleware_w_retry:
retry_request = None
if hook_info.async_fn:
retry_request = await hook_info.async_fn(
error, request, state, runtime, attempt
)
elif hook_info.sync_fn:
# Fallback to sync if only sync is implemented
retry_request = await run_in_executor(
None, hook_info.sync_fn, error, request, state, runtime, attempt
)
if retry_request:
request = retry_request
break
else:
# If no middleware wants to retry, re-raise the error
raise
# If we exit the loop, max attempts exceeded
msg = f"Maximum retry attempts ({max_attempts}) exceeded"
raise RuntimeError(msg)
# Use sync or async based on model capabilities
from langgraph._internal._runnable import RunnableCallable
# Phase 5: Add nodes to graph
graph.add_node("model_request", RunnableCallable(model_request, amodel_request, trace=False))
if components.tool_node is not None:
graph.add_node("tools", components.tool_node)
_add_middleware_nodes(graph, components, state_schema)
graph.add_node("model_request", RunnableCallable(model_request, amodel_request))
# Phase 6: Compute graph topology
topology = GraphTopology.compute(components.middleware_hooks)
# Only add tools node if we have tools
if tool_node is not None:
graph.add_node("tools", tool_node)
# Add middleware nodes
for m in middleware:
if (
m.__class__.before_model is not AgentMiddleware.before_model
or m.__class__.abefore_model is not AgentMiddleware.abefore_model
):
# Use RunnableCallable to support both sync and async
# Pass None for sync if not overridden to avoid signature conflicts
sync_before = (
m.before_model
if m.__class__.before_model is not AgentMiddleware.before_model
else None
)
async_before = (
m.abefore_model
if m.__class__.abefore_model is not AgentMiddleware.abefore_model
else None
)
before_node = RunnableCallable(sync_before, async_before)
graph.add_node(f"{m.name}.before_model", before_node, input_schema=state_schema)
if (
m.__class__.after_model is not AgentMiddleware.after_model
or m.__class__.aafter_model is not AgentMiddleware.aafter_model
):
# Use RunnableCallable to support both sync and async
# Pass None for sync if not overridden to avoid signature conflicts
sync_after = (
m.after_model
if m.__class__.after_model is not AgentMiddleware.after_model
else None
)
async_after = (
m.aafter_model
if m.__class__.aafter_model is not AgentMiddleware.aafter_model
else None
)
after_node = RunnableCallable(sync_after, async_after)
graph.add_node(f"{m.name}.after_model", after_node, input_schema=state_schema)
# add start edge
first_node = (
f"{middleware_w_before[0].name}.before_model" if middleware_w_before else "model_request"
)
last_node = (
f"{middleware_w_after[0].name}.after_model" if middleware_w_after else "model_request"
)
graph.add_edge(START, first_node)
# add conditional edges only if tools exist
if tool_node is not None:
graph.add_conditional_edges(
"tools",
_make_tools_to_model_edge(tool_node, first_node, structured_output_tools),
[first_node, END],
)
graph.add_conditional_edges(
last_node,
_make_model_to_tools_edge(first_node, structured_output_tools, tool_node),
[first_node, "tools", END],
)
elif last_node == "model_request":
# If no tools, just go to END from model
graph.add_edge(last_node, END)
else:
# If after_model, then need to check for can_jump_to
_add_middleware_edge(
graph,
f"{middleware_w_after[0].name}.after_model",
END,
first_node,
can_jump_to=_get_can_jump_to(middleware_w_after[0], "after_model"),
)
# Add middleware edges (same as before)
if middleware_w_before:
for m1, m2 in itertools.pairwise(middleware_w_before):
_add_middleware_edge(
graph,
f"{m1.name}.before_model",
f"{m2.name}.before_model",
first_node,
can_jump_to=_get_can_jump_to(m1, "before_model"),
)
# Go directly to model_request after the last before_model
_add_middleware_edge(
graph,
f"{middleware_w_before[-1].name}.before_model",
"model_request",
first_node,
can_jump_to=_get_can_jump_to(middleware_w_before[-1], "before_model"),
)
if middleware_w_after:
graph.add_edge("model_request", f"{middleware_w_after[-1].name}.after_model")
for idx in range(len(middleware_w_after) - 1, 0, -1):
m1 = middleware_w_after[idx]
m2 = middleware_w_after[idx - 1]
_add_middleware_edge(
graph,
f"{m1.name}.after_model",
f"{m2.name}.after_model",
first_node,
can_jump_to=_get_can_jump_to(m1, "after_model"),
)
# Phase 7: Connect edges
_connect_entry_edges(graph, topology)
_connect_loop_edges(graph, topology, components)
_connect_middleware_chains(graph, components, topology)
return graph
@@ -768,7 +1043,10 @@ def _fetch_last_ai_and_tool_messages(
def _make_model_to_tools_edge(
first_node: str, structured_output_tools: dict[str, OutputToolBinding], tool_node: ToolNode
first_node: str,
structured_output_tools: dict[str, OutputToolBinding],
tool_node: ToolNode,
exit_node: str,
) -> Callable[[dict[str, Any]], str | list[Send] | None]:
def model_to_tools(state: dict[str, Any]) -> str | list[Send] | None:
# 1. if there's an explicit jump_to in the state, use it
@@ -778,10 +1056,10 @@ def _make_model_to_tools_edge(
last_ai_message, tool_messages = _fetch_last_ai_and_tool_messages(state["messages"])
tool_message_ids = [m.tool_call_id for m in tool_messages]
# 2. if the model hasn't called any tools, jump to END
# 2. if the model hasn't called any tools, exit the loop
# this is the classic exit condition for an agent loop
if len(last_ai_message.tool_calls) == 0:
return END
return exit_node
pending_tool_calls = [
c
@@ -804,7 +1082,10 @@ def _make_model_to_tools_edge(
def _make_tools_to_model_edge(
tool_node: ToolNode, next_node: str, structured_output_tools: dict[str, OutputToolBinding]
tool_node: ToolNode,
next_node: str,
structured_output_tools: dict[str, OutputToolBinding],
exit_node: str,
) -> Callable[[dict[str, Any]], str | None]:
def tools_to_model(state: dict[str, Any]) -> str | None:
last_ai_message, tool_messages = _fetch_last_ai_and_tool_messages(state["messages"])
@@ -814,10 +1095,10 @@ def _make_tools_to_model_edge(
for c in last_ai_message.tool_calls
if c["name"] in tool_node.tools_by_name
):
return END
return exit_node
if any(t.name in structured_output_tools for t in tool_messages):
return END
return exit_node
return next_node

View File

@@ -73,8 +73,8 @@
__start__([<p>__start__</p>]):::first
model_request(model_request)
NoopTen\2ebefore_model(NoopTen.before_model)
NoopTen\2eafter_model(NoopTen.after_model)
NoopEleven\2ebefore_model(NoopEleven.before_model)
NoopTen\2eafter_model(NoopTen.after_model)
NoopEleven\2eafter_model(NoopEleven.after_model)
__end__([<p>__end__</p>]):::last
NoopEleven\2eafter_model --> NoopTen\2eafter_model;
@@ -240,8 +240,8 @@
__start__([<p>__start__</p>]):::first
model_request(model_request)
NoopSeven\2ebefore_model(NoopSeven.before_model)
NoopSeven\2eafter_model(NoopSeven.after_model)
NoopEight\2ebefore_model(NoopEight.before_model)
NoopSeven\2eafter_model(NoopSeven.after_model)
NoopEight\2eafter_model(NoopEight.after_model)
__end__([<p>__end__</p>]):::last
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
@@ -267,10 +267,10 @@
__start__([<p>__start__</p>]):::first
model_request(model_request)
NoopSeven\2ebefore_model(NoopSeven.before_model)
NoopSeven\2eafter_model(NoopSeven.after_model)
NoopEight\2ebefore_model(NoopEight.before_model)
NoopEight\2eafter_model(NoopEight.after_model)
NoopNine\2ebefore_model(NoopNine.before_model)
NoopSeven\2eafter_model(NoopSeven.after_model)
NoopEight\2eafter_model(NoopEight.after_model)
NoopNine\2eafter_model(NoopNine.after_model)
__end__([<p>__end__</p>]):::last
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
@@ -299,8 +299,8 @@
model_request(model_request)
tools(tools)
NoopSeven\2ebefore_model(NoopSeven.before_model)
NoopSeven\2eafter_model(NoopSeven.after_model)
NoopEight\2ebefore_model(NoopEight.before_model)
NoopSeven\2eafter_model(NoopSeven.after_model)
NoopEight\2eafter_model(NoopEight.after_model)
__end__([<p>__end__</p>]):::last
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
@@ -332,8 +332,8 @@
model_request(model_request)
tools(tools)
NoopSeven\2ebefore_model(NoopSeven.before_model)
NoopSeven\2eafter_model(NoopSeven.after_model)
NoopEight\2ebefore_model(NoopEight.before_model)
NoopSeven\2eafter_model(NoopSeven.after_model)
NoopEight\2eafter_model(NoopEight.after_model)
__end__([<p>__end__</p>]):::last
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
@@ -365,8 +365,8 @@
model_request(model_request)
tools(tools)
NoopSeven\2ebefore_model(NoopSeven.before_model)
NoopSeven\2eafter_model(NoopSeven.after_model)
NoopEight\2ebefore_model(NoopEight.before_model)
NoopSeven\2eafter_model(NoopSeven.after_model)
NoopEight\2eafter_model(NoopEight.after_model)
__end__([<p>__end__</p>]):::last
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
@@ -398,8 +398,8 @@
model_request(model_request)
tools(tools)
NoopSeven\2ebefore_model(NoopSeven.before_model)
NoopSeven\2eafter_model(NoopSeven.after_model)
NoopEight\2ebefore_model(NoopEight.before_model)
NoopSeven\2eafter_model(NoopSeven.after_model)
NoopEight\2eafter_model(NoopEight.after_model)
__end__([<p>__end__</p>]):::last
NoopEight\2eafter_model --> NoopSeven\2eafter_model;
@@ -431,8 +431,8 @@
model_request(model_request)
tools(tools)
NoopSeven\2ebefore_model(NoopSeven.before_model)
NoopSeven\2eafter_model(NoopSeven.after_model)
NoopEight\2ebefore_model(NoopEight.before_model)
NoopSeven\2eafter_model(NoopSeven.after_model)
NoopEight\2eafter_model(NoopEight.after_model)
__end__([<p>__end__</p>]):::last
NoopEight\2eafter_model --> NoopSeven\2eafter_model;

View File

@@ -0,0 +1,388 @@
"""Unit tests for before_agent and after_agent middleware hooks."""
from typing import Any
from langchain.agents import create_agent
from langchain.agents.middleware import (
AgentMiddleware,
AgentState,
after_agent,
before_agent,
)
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.tools import tool
from langgraph.runtime import Runtime
from ..model import FakeToolCallingModel
@tool
def sample_tool(query: str) -> str:
    """A sample tool for testing."""
    # NOTE: the docstring above is consumed by the @tool decorator as the
    # tool's description, so it is kept verbatim.
    prefix = "Result for: "
    return prefix + query
class TestBeforeAgentBasic:
    """Exercises the ``before_agent`` hook: sync/async decorators and subclasses."""

    def test_sync_before_agent_execution(self) -> None:
        """A sync ``@before_agent`` function runs and observes the input state."""
        events: list[str] = []

        @before_agent
        def log_before_agent(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
            events.append("before_agent_called")
            events.append(f"message_count: {len(state['messages'])}")
            return None

        fake_model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello!")]))
        agent = create_agent(model=fake_model, tools=[], middleware=[log_before_agent])
        agent.invoke({"messages": [HumanMessage("Hi")]})

        assert "before_agent_called" in events
        assert "message_count: 1" in events

    async def test_async_before_agent_execution(self) -> None:
        """An async ``@before_agent`` function runs under the async API."""
        events: list[str] = []

        @before_agent
        async def async_log_before_agent(
            state: AgentState, runtime: Runtime
        ) -> dict[str, Any] | None:
            events.append("async_before_agent_called")
            events.append(f"message_count: {len(state['messages'])}")
            return None

        fake_model = GenericFakeChatModel(messages=iter([AIMessage(content="Hello!")]))
        agent = create_agent(model=fake_model, tools=[], middleware=[async_log_before_agent])
        await agent.ainvoke({"messages": [HumanMessage("Hi")]})

        assert "async_before_agent_called" in events
        assert "message_count: 1" in events

    def test_before_agent_state_modification(self) -> None:
        """State returned from ``before_agent`` is merged into the run."""

        @before_agent
        def add_metadata(state: AgentState, runtime: Runtime) -> dict[str, Any]:
            return {"messages": [HumanMessage("Injected by middleware")]}

        fake_model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")]))
        agent = create_agent(model=fake_model, tools=[], middleware=[add_metadata])
        result = agent.invoke({"messages": [HumanMessage("Original")]})

        # Original + injected message (plus the model reply) must all survive.
        assert len(result["messages"]) >= 2
        contents = [message.content for message in result["messages"]]
        assert "Injected by middleware" in contents

    def test_before_agent_with_class_inheritance(self) -> None:
        """Overriding ``before_agent`` on an AgentMiddleware subclass works."""
        events: list[str] = []

        class CustomBeforeAgentMiddleware(AgentMiddleware):
            def before_agent(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
                events.append("class_before_agent_called")
                return None

        fake_model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")]))
        agent = create_agent(
            model=fake_model, tools=[], middleware=[CustomBeforeAgentMiddleware()]
        )
        agent.invoke({"messages": [HumanMessage("Test")]})

        assert "class_before_agent_called" in events

    async def test_before_agent_with_async_class_inheritance(self) -> None:
        """Overriding ``abefore_agent`` on an AgentMiddleware subclass works."""
        events: list[str] = []

        class CustomAsyncBeforeAgentMiddleware(AgentMiddleware):
            async def abefore_agent(
                self, state: AgentState, runtime: Runtime
            ) -> dict[str, Any] | None:
                events.append("async_class_before_agent_called")
                return None

        fake_model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")]))
        agent = create_agent(
            model=fake_model, tools=[], middleware=[CustomAsyncBeforeAgentMiddleware()]
        )
        await agent.ainvoke({"messages": [HumanMessage("Test")]})

        assert "async_class_before_agent_called" in events
class TestAfterAgentBasic:
    """Exercises the ``after_agent`` hook: sync/async decorators and subclasses."""

    def test_sync_after_agent_execution(self) -> None:
        """A sync ``@after_agent`` function runs once the loop has finished."""
        events: list[str] = []

        @after_agent
        def log_after_agent(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
            events.append("after_agent_called")
            events.append(f"final_message_count: {len(state['messages'])}")
            return None

        fake_model = GenericFakeChatModel(messages=iter([AIMessage(content="Final response")]))
        agent = create_agent(model=fake_model, tools=[], middleware=[log_after_agent])
        agent.invoke({"messages": [HumanMessage("Hi")]})

        assert "after_agent_called" in events
        assert any("final_message_count:" in entry for entry in events)

    async def test_async_after_agent_execution(self) -> None:
        """An async ``@after_agent`` function runs under the async API."""
        events: list[str] = []

        @after_agent
        async def async_log_after_agent(
            state: AgentState, runtime: Runtime
        ) -> dict[str, Any] | None:
            events.append("async_after_agent_called")
            return None

        fake_model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")]))
        agent = create_agent(model=fake_model, tools=[], middleware=[async_log_after_agent])
        await agent.ainvoke({"messages": [HumanMessage("Hi")]})

        assert "async_after_agent_called" in events

    def test_after_agent_state_modification(self) -> None:
        """State returned from ``after_agent`` is merged into the final result."""

        @after_agent
        def add_final_message(state: AgentState, runtime: Runtime) -> dict[str, Any]:
            return {"messages": [AIMessage("Added by after_agent")]}

        fake_model = GenericFakeChatModel(messages=iter([AIMessage(content="Model response")]))
        agent = create_agent(model=fake_model, tools=[], middleware=[add_final_message])
        result = agent.invoke({"messages": [HumanMessage("Test")]})

        contents = [message.content for message in result["messages"]]
        assert "Added by after_agent" in contents

    def test_after_agent_with_class_inheritance(self) -> None:
        """Overriding ``after_agent`` on an AgentMiddleware subclass works."""
        events: list[str] = []

        class CustomAfterAgentMiddleware(AgentMiddleware):
            def after_agent(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
                events.append("class_after_agent_called")
                return None

        fake_model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")]))
        agent = create_agent(model=fake_model, tools=[], middleware=[CustomAfterAgentMiddleware()])
        agent.invoke({"messages": [HumanMessage("Test")]})

        assert "class_after_agent_called" in events

    async def test_after_agent_with_async_class_inheritance(self) -> None:
        """Overriding ``aafter_agent`` on an AgentMiddleware subclass works."""
        events: list[str] = []

        class CustomAsyncAfterAgentMiddleware(AgentMiddleware):
            async def aafter_agent(
                self, state: AgentState, runtime: Runtime
            ) -> dict[str, Any] | None:
                events.append("async_class_after_agent_called")
                return None

        fake_model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")]))
        agent = create_agent(
            model=fake_model, tools=[], middleware=[CustomAsyncAfterAgentMiddleware()]
        )
        await agent.ainvoke({"messages": [HumanMessage("Test")]})

        assert "async_class_after_agent_called" in events
class TestBeforeAndAfterAgentCombined:
    """``before_agent`` and ``after_agent`` cooperating within a single run."""

    def test_execution_order(self) -> None:
        """``before_agent`` always fires before ``after_agent``."""
        events: list[str] = []

        @before_agent
        def log_before(state: AgentState, runtime: Runtime) -> None:
            events.append("before")

        @after_agent
        def log_after(state: AgentState, runtime: Runtime) -> None:
            events.append("after")

        fake_model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")]))
        agent = create_agent(model=fake_model, tools=[], middleware=[log_before, log_after])
        agent.invoke({"messages": [HumanMessage("Test")]})

        assert events == ["before", "after"]

    async def test_async_execution_order(self) -> None:
        """Async hooks preserve the before-then-after ordering."""
        events: list[str] = []

        @before_agent
        async def async_log_before(state: AgentState, runtime: Runtime) -> None:
            events.append("async_before")

        @after_agent
        async def async_log_after(state: AgentState, runtime: Runtime) -> None:
            events.append("async_after")

        fake_model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")]))
        agent = create_agent(
            model=fake_model, tools=[], middleware=[async_log_before, async_log_after]
        )
        await agent.ainvoke({"messages": [HumanMessage("Test")]})

        assert events == ["async_before", "async_after"]

    def test_state_passthrough(self) -> None:
        """Updates made in ``before_agent`` are visible to ``after_agent``."""
        captured: dict[str, Any] = {}

        @before_agent
        def modify_in_before(state: AgentState, runtime: Runtime) -> dict[str, Any]:
            return {"messages": [HumanMessage("Modified by before_agent")]}

        @after_agent
        def capture_in_after(state: AgentState, runtime: Runtime) -> None:
            captured["messages"] = state["messages"]

        fake_model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")]))
        agent = create_agent(
            model=fake_model, tools=[], middleware=[modify_in_before, capture_in_after]
        )
        agent.invoke({"messages": [HumanMessage("Original")]})

        contents = [message.content for message in captured["messages"]]
        assert "Modified by before_agent" in contents

    def test_multiple_middleware_instances(self) -> None:
        """Several before/after middleware instances all get invoked."""
        events: list[str] = []

        @before_agent
        def before_one(state: AgentState, runtime: Runtime) -> None:
            events.append("before_1")

        @before_agent
        def before_two(state: AgentState, runtime: Runtime) -> None:
            events.append("before_2")

        @after_agent
        def after_one(state: AgentState, runtime: Runtime) -> None:
            events.append("after_1")

        @after_agent
        def after_two(state: AgentState, runtime: Runtime) -> None:
            events.append("after_2")

        fake_model = GenericFakeChatModel(messages=iter([AIMessage(content="Response")]))
        agent = create_agent(
            model=fake_model,
            tools=[],
            middleware=[before_one, before_two, after_one, after_two],
        )
        agent.invoke({"messages": [HumanMessage("Test")]})

        for expected in ("before_1", "before_2", "after_1", "after_2"):
            assert expected in events

    def test_agent_hooks_run_once_with_multiple_model_calls(self) -> None:
        """Agent-level hooks fire once per run even when tools trigger extra model calls."""
        events: list[str] = []

        @before_agent
        def log_before_agent(state: AgentState, runtime: Runtime) -> None:
            events.append("before_agent")

        @after_agent
        def log_after_agent(state: AgentState, runtime: Runtime) -> None:
            events.append("after_agent")

        # First call triggers a tool, second call produces the final answer.
        fake_model = FakeToolCallingModel(
            tool_calls=[
                [{"name": "sample_tool", "args": {"query": "test"}, "id": "1"}],
                [],
            ]
        )
        agent = create_agent(
            model=fake_model,
            tools=[sample_tool],
            middleware=[log_before_agent, log_after_agent],
        )
        agent.invoke({"messages": [HumanMessage("Test")]})

        # Exactly one firing each, bracketing the whole run.
        assert events.count("before_agent") == 1
        assert events.count("after_agent") == 1
        assert events[0] == "before_agent"
        assert events[-1] == "after_agent"
class TestDecoratorParameters:
    """Naming behavior of the ``before_agent`` / ``after_agent`` decorators."""

    def test_before_agent_with_custom_name(self) -> None:
        """An explicit ``name=`` overrides the function-derived default."""

        @before_agent(name="CustomBeforeAgentMiddleware")
        def custom_named_before(state: AgentState, runtime: Runtime) -> None:
            pass

        assert custom_named_before.name == "CustomBeforeAgentMiddleware"

    def test_after_agent_with_custom_name(self) -> None:
        """An explicit ``name=`` overrides the function-derived default."""

        @after_agent(name="CustomAfterAgentMiddleware")
        def custom_named_after(state: AgentState, runtime: Runtime) -> None:
            pass

        assert custom_named_after.name == "CustomAfterAgentMiddleware"

    def test_before_agent_default_name(self) -> None:
        """Without ``name=``, the decorated function's name is used."""

        @before_agent
        def my_before_agent_function(state: AgentState, runtime: Runtime) -> None:
            pass

        assert my_before_agent_function.name == "my_before_agent_function"

    def test_after_agent_default_name(self) -> None:
        """Without ``name=``, the decorated function's name is used."""

        @after_agent
        def my_after_agent_function(state: AgentState, runtime: Runtime) -> None:
            pass

        assert my_after_agent_function.name == "my_after_agent_function"

View File

@@ -8,8 +8,8 @@ from syrupy.assertion import SnapshotAssertion
from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.tools import tool
from langgraph.runtime import Runtime
from langgraph.types import Command
from langchain.agents.middleware_agent import create_agent
from langchain.agents.middleware.types import (
AgentMiddleware,
AgentState,
@@ -20,7 +20,6 @@ from langchain.agents.middleware.types import (
modify_model_request,
hook_config,
)
from langchain.agents.middleware_agent import create_agent, _get_can_jump_to
from .model import FakeToolCallingModel
@@ -470,49 +469,6 @@ async def test_async_can_jump_to_integration() -> None:
assert len(result["messages"]) > 1
def test_get_can_jump_to_no_false_positives() -> None:
    """Test that _get_can_jump_to doesn't return false positives for base class methods."""
    # Middleware with no overridden methods should return empty list
    # NOTE(review): this test targets the private helper `_get_can_jump_to`,
    # which this diff removes from middleware_agent — the test is deleted with it.
    class EmptyMiddleware(AgentMiddleware):
        pass
    empty_middleware = EmptyMiddleware()
    # Bare subclass declares no tools; set the attribute explicitly so the
    # instance is usable without further setup.
    empty_middleware.tools = []
    # Should not return any jump destinations for base class methods
    assert _get_can_jump_to(empty_middleware, "before_model") == []
    assert _get_can_jump_to(empty_middleware, "after_model") == []
def test_get_can_jump_to_only_overridden_methods() -> None:
    """Test that _get_can_jump_to only checks overridden methods."""
    # NOTE(review): this test targets the private helper `_get_can_jump_to`,
    # which this diff removes from middleware_agent — the test is deleted with it.
    # Middleware with only sync method overridden
    class SyncOnlyMiddleware(AgentMiddleware):
        @hook_config(can_jump_to=["end"])
        def before_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
            return None
    sync_middleware = SyncOnlyMiddleware()
    sync_middleware.tools = []
    # Should return can_jump_to from overridden sync method
    assert _get_can_jump_to(sync_middleware, "before_model") == ["end"]
    # Middleware with only async method overridden
    class AsyncOnlyMiddleware(AgentMiddleware):
        @hook_config(can_jump_to=["model"])
        async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
            return None
    async_middleware = AsyncOnlyMiddleware()
    async_middleware.tools = []
    # Should return can_jump_to from overridden async method
    # (jump metadata declared via @hook_config on the async variant is honored)
    assert _get_can_jump_to(async_middleware, "after_model") == ["model"]
def test_async_middleware_with_can_jump_to_graph_snapshot(snapshot: SnapshotAssertion) -> None:
"""Test that async middleware with can_jump_to creates correct graph structure with conditional edges."""