Compare commits

...

12 Commits

Author SHA1 Message Date
Eugene Yurtsev
3c92e986f6 x 2025-10-06 21:52:18 -04:00
Eugene Yurtsev
40b4c69a5a x 2025-10-06 17:38:01 -04:00
Eugene Yurtsev
076c6f6b41 x 2025-10-06 17:35:27 -04:00
Eugene Yurtsev
db58bfa543 rename to continue 2025-10-06 17:32:04 -04:00
Eugene Yurtsev
ba9ec6d895 x 2025-10-06 16:56:16 -04:00
Eugene Yurtsev
fa533c44b7 x 2025-10-06 16:54:00 -04:00
Eugene Yurtsev
4f53ed3e9a x 2025-10-06 16:51:24 -04:00
Eugene Yurtsev
def2f147ae x 2025-10-06 16:16:19 -04:00
Eugene Yurtsev
65e073e85c x 2025-10-06 16:08:09 -04:00
Eugene Yurtsev
a9ff8e0b67 x 2025-10-06 15:31:16 -04:00
Eugene Yurtsev
0927ae4be1 x 2025-10-06 15:26:43 -04:00
Eugene Yurtsev
06ce94ca06 x 2025-10-06 00:00:02 -04:00
7 changed files with 1619 additions and 45 deletions

View File

@@ -7,6 +7,10 @@ from .planning import PlanningMiddleware
from .prompt_caching import AnthropicPromptCachingMiddleware
from .summarization import SummarizationMiddleware
from .tool_call_limit import ToolCallLimitMiddleware
from .tool_error_handling import (
ErrorToMessageMiddleware,
ToolRetryMiddleware,
)
from .tool_selection import LLMToolSelectorMiddleware
from .types import (
AgentMiddleware,
@@ -24,6 +28,7 @@ __all__ = [
"AgentState",
# should move to langchain-anthropic if we decide to keep it
"AnthropicPromptCachingMiddleware",
"ErrorToMessageMiddleware",
"HumanInTheLoopMiddleware",
"LLMToolSelectorMiddleware",
"ModelFallbackMiddleware",
@@ -31,6 +36,7 @@ __all__ = [
"PIIDetectionError",
"PIIMiddleware",
"PlanningMiddleware",
"ToolRetryMiddleware",
"SummarizationMiddleware",
"ToolCallLimitMiddleware",
"after_model",

View File

@@ -0,0 +1,406 @@
"""Middleware for handling tool execution errors in agents.
This module provides composable middleware for error handling, retries,
and error-to-message conversion in tool execution workflows.
"""
from __future__ import annotations
import inspect
import logging
import time
from types import UnionType
from typing import TYPE_CHECKING, Union, get_args, get_origin, get_type_hints
from langchain_core.messages import ToolMessage
from langchain.agents.middleware.types import AgentMiddleware
# Import ToolCallResponse locally to avoid circular import
from langchain.tools.tool_node import ToolCallRequest, ToolCallResponse
if TYPE_CHECKING:
from collections.abc import Callable, Generator
logger = logging.getLogger(__name__)
# Default retriable exception types - transient errors that may succeed on retry.
# Used as the default for ToolRetryMiddleware(retry_on=...).
DEFAULT_RETRIABLE_EXCEPTIONS = (
    # Network and connection errors
    ConnectionError,
    TimeoutError,
    # HTTP client errors are typically not retriable, but these are exceptions:
    # - 429: Rate limit (temporary)
    # - 503: Service unavailable (temporary)
    # Note: Specific HTTP libraries may define their own exception types
)
def _infer_retriable_types(
predicate: Callable[[Exception], bool],
) -> tuple[type[Exception], ...]:
"""Infer exception types from a retry predicate function's type annotations.
Analyzes the type annotations of a predicate function to determine which
exception types it's designed to handle for retry decisions.
Args:
predicate: A callable that takes an exception and returns whether to retry.
The first parameter should be type-annotated with exception type(s).
Returns:
Tuple of exception types that the predicate handles. Returns (Exception,)
if no specific type information is available.
Raises:
ValueError: If the predicate's annotation contains non-Exception types.
"""
sig = inspect.signature(predicate)
params = list(sig.parameters.values())
if params:
# Skip self/cls if it's a method
if params[0].name in ["self", "cls"] and len(params) == 2:
first_param = params[1]
else:
first_param = params[0]
type_hints = get_type_hints(predicate)
if first_param.name in type_hints:
origin = get_origin(first_param.annotation)
# Handle Union types
if origin in [Union, UnionType]:
args = get_args(first_param.annotation)
if all(isinstance(arg, type) and issubclass(arg, Exception) for arg in args):
return tuple(args)
msg = (
"All types in retry predicate annotation must be Exception types. "
"For example, `def should_retry(e: Union[TimeoutError, "
"ConnectionError]) -> bool`. "
f"Got '{first_param.annotation}' instead."
)
raise ValueError(msg)
# Handle single exception type
exception_type = type_hints[first_param.name]
if isinstance(exception_type, type) and issubclass(exception_type, Exception):
return (exception_type,)
msg = (
"Retry predicate must be annotated with Exception type(s). "
"For example, `def should_retry(e: TimeoutError) -> bool` or "
"`def should_retry(e: Union[TimeoutError, ConnectionError]) -> bool`. "
f"Got '{exception_type}' instead."
)
raise ValueError(msg)
# No type information - return Exception for backward compatibility
return (Exception,)
class ToolRetryMiddleware(AgentMiddleware):
    """Retry failed tool calls with constant delay.

    This middleware catches tool execution errors and retries them up to a maximum
    number of attempts with a constant delay between retries. It operates at the
    outermost layer of middleware composition to catch all errors.

    Examples:
        Retry only network errors:

        ```python
        from langchain.agents.middleware import ToolRetryMiddleware

        middleware = ToolRetryMiddleware(
            max_retries=3,
            delay=2.0,
            retry_on=(TimeoutError, ConnectionError),
        )

        agent = create_agent(
            model="openai:gpt-4o",
            tools=[my_tool],
            middleware=[middleware],
        )
        ```

        Use predicate function for custom retry logic:

        ```python
        from langchain.tools.tool_node import ToolInvocationError


        def should_retry(e: Exception) -> bool:
            # Don't retry validation errors from LLM
            if isinstance(e, ToolInvocationError):
                return False
            # Retry network errors
            if isinstance(e, (TimeoutError, ConnectionError)):
                return True
            return False


        middleware = ToolRetryMiddleware(
            max_retries=3,
            retry_on=should_retry,
        )
        ```

        Compose with error conversion:

        ```python
        from langchain.agents.middleware import (
            ToolRetryMiddleware,
            ErrorToMessageMiddleware,
        )

        agent = create_agent(
            model="openai:gpt-4o",
            tools=[my_tool],
            middleware=[
                # Outer: retry network errors
                ToolRetryMiddleware(
                    max_retries=3,
                    delay=2.0,
                    retry_on=(TimeoutError, ConnectionError),
                ),
                # Inner: convert validation errors to messages
                ErrorToMessageMiddleware(
                    exception_types=(ValidationError,),
                ),
            ],
        )
        ```
    """

    def __init__(
        self,
        *,
        max_retries: int = 3,
        delay: float = 1.0,
        retry_on: type[Exception]
        | tuple[type[Exception], ...]
        | Callable[[Exception], bool] = DEFAULT_RETRIABLE_EXCEPTIONS,
    ) -> None:
        """Initialize retry middleware.

        Args:
            max_retries: Maximum number of retry attempts. Total attempts will be
                max_retries + 1 (initial attempt plus retries).
            delay: Constant delay in seconds between retry attempts.
            retry_on: Specifies which exceptions should be retried. Can be:

                - **type[Exception]**: Retry only this exception type
                - **tuple[type[Exception], ...]**: Retry these exception types
                - **Callable[[Exception], bool]**: Predicate function that returns
                  True if the exception should be retried. Type annotations on the
                  callable are used to filter which exceptions are passed to it.

                Defaults to ``DEFAULT_RETRIABLE_EXCEPTIONS`` (ConnectionError, TimeoutError).

        Raises:
            ValueError: If max_retries or delay is negative, or if retry_on is not
                an Exception type, non-empty tuple of Exception types, or callable.
        """
        super().__init__()
        if max_retries < 0:
            msg = "max_retries must be non-negative"
            raise ValueError(msg)
        if delay < 0:
            msg = "delay must be non-negative"
            raise ValueError(msg)
        self.max_retries = max_retries
        self.delay = delay
        self._retry_on = retry_on
        # Determine which exception types to check. Normalizes the three accepted
        # retry_on forms into (self._retriable_types, self._retry_predicate).
        if isinstance(retry_on, type) and issubclass(retry_on, Exception):
            self._retriable_types = (retry_on,)
            self._retry_predicate = None
        elif isinstance(retry_on, tuple):
            if not retry_on:
                msg = "retry_on tuple must not be empty"
                raise ValueError(msg)
            if not all(isinstance(t, type) and issubclass(t, Exception) for t in retry_on):
                msg = "All elements in retry_on tuple must be Exception types"
                raise ValueError(msg)
            self._retriable_types = retry_on
            self._retry_predicate = None
        elif callable(retry_on):
            # Predicate form: its annotations pre-filter which exception types
            # are even offered to the predicate.
            self._retriable_types = _infer_retriable_types(retry_on)
            self._retry_predicate = retry_on
        else:
            msg = (
                "retry_on must be an Exception type, tuple of Exception types, "
                f"or callable. Got {type(retry_on)}"
            )
            raise ValueError(msg)

    def on_tool_call(
        self, request: ToolCallRequest, state, runtime
    ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]:
        """Retry tool execution on failures.

        Yields the (unmodified) request up to max_retries + 1 times, sleeping
        ``self.delay`` seconds between attempts, until a successful response or a
        non-retriable error is received.
        """
        for attempt in range(1, self.max_retries + 2):  # +1 for initial, +1 for inclusive
            response = yield request
            # Success - return immediately
            if response.action == "continue":
                return response
            # Error - check if we should retry
            if response.action == "raise":
                exception = response.exception
                if exception is None:
                    msg = "ToolCallResponse with action='raise' must have an exception"
                    raise ValueError(msg)
                # Check if this exception type is retriable
                if not isinstance(exception, self._retriable_types):
                    logger.debug(
                        "Exception %s is not retriable for tool %s",
                        type(exception).__name__,
                        request.tool_call["name"],
                    )
                    return response
                # If predicate is provided, check if we should retry
                if self._retry_predicate is not None and not self._retry_predicate(exception):
                    logger.debug(
                        "Retry predicate returned False for %s in tool %s",
                        type(exception).__name__,
                        request.tool_call["name"],
                    )
                    return response
                # Last attempt - return error
                if attempt > self.max_retries:
                    logger.debug(
                        "Max retries (%d) reached for tool %s",
                        self.max_retries,
                        request.tool_call["name"],
                    )
                    return response
                # Retry - log and delay
                logger.debug(
                    "Retrying tool %s (attempt %d/%d) after error: %s",
                    request.tool_call["name"],
                    attempt,
                    self.max_retries + 1,
                    type(exception).__name__,
                )
                time.sleep(self.delay)
                continue
        # Should never reach here: every iteration either returns or continues,
        # and the final iteration always returns.
        msg = f"Unexpected control flow in ToolRetryMiddleware for tool {request.tool_call['name']}"
        raise RuntimeError(msg)
class ErrorToMessageMiddleware(AgentMiddleware):
    """Convert specific exception types to ToolMessages.

    Intercepts errors raised during tool execution and turns them into
    ToolMessages sent back to the model as feedback. Useful for errors caused by
    invalid model inputs, where the model needs feedback to correct itself.

    Examples:
        Convert validation errors to messages:

        ```python
        from langchain.agents.middleware import ErrorToMessageMiddleware
        from langchain.tools.tool_node import ToolInvocationError

        middleware = ErrorToMessageMiddleware(
            exception_types=(ToolInvocationError,),
            message_template="Invalid arguments: {error}. Please fix and try again.",
        )

        agent = create_agent(
            model="openai:gpt-4o",
            tools=[my_tool],
            middleware=[middleware],
        )
        ```

        Compose with retry for network errors:

        ```python
        from langchain.agents.middleware import (
            ToolRetryMiddleware,
            ErrorToMessageMiddleware,
        )

        agent = create_agent(
            model="openai:gpt-4o",
            tools=[my_tool],
            middleware=[
                # Outer: retry all errors
                ToolRetryMiddleware(max_retries=3),
                # Inner: convert validation errors to messages
                ErrorToMessageMiddleware(
                    exception_types=(ValidationError,),
                ),
            ],
        )
        ```
    """

    def __init__(
        self,
        *,
        exception_types: tuple[type[Exception], ...],
        message_template: str = "Error: {error}",
    ) -> None:
        """Initialize error conversion middleware.

        Args:
            exception_types: Tuple of exception types to convert to messages.
            message_template: Template string for error messages. Can use ``{error}``
                placeholder for the exception string representation.

        Raises:
            ValueError: If exception_types is empty.
        """
        super().__init__()
        if not exception_types:
            msg = "exception_types must not be empty"
            raise ValueError(msg)
        self.exception_types = exception_types
        self.message_template = message_template

    def on_tool_call(
        self, request: ToolCallRequest, state, runtime
    ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]:
        """Convert matching errors to ToolMessages."""
        response = yield request
        # Anything that is not an error propagates untouched.
        if response.action != "raise":
            return response
        err = response.exception
        if err is None:
            msg = "ToolCallResponse with action='raise' must have an exception"
            raise ValueError(msg)
        # Only convert the configured exception types; everything else propagates.
        if not isinstance(err, self.exception_types):
            return response
        logger.debug(
            "Converting %s to ToolMessage for tool %s",
            type(err).__name__,
            request.tool_call["name"],
        )
        feedback = ToolMessage(
            content=self.message_template.format(error=str(err)),
            name=request.tool_call["name"],
            tool_call_id=request.tool_call["id"],
            status="error",
        )
        return ToolCallResponse(
            action="continue",
            result=feedback,
            exception=err,  # Preserve for logging/debugging
        )

View File

@@ -19,7 +19,7 @@ from typing import (
from langchain_core.runnables import run_in_executor
if TYPE_CHECKING:
from collections.abc import Awaitable
from collections.abc import Awaitable, Generator
# needed as top level import for pydantic schema generation on AgentState
from langchain_core.messages import AnyMessage # noqa: TC002
@@ -36,6 +36,7 @@ if TYPE_CHECKING:
from langgraph.types import Command
from langchain.agents.structured_output import ResponseFormat
from langchain.tools.tool_node import ToolCallRequest, ToolCallResponse
__all__ = [
"AgentMiddleware",
@@ -215,6 +216,48 @@ class AgentMiddleware(Generic[StateT, ContextT]):
None, self.retry_model_request, error, request, state, runtime, attempt
)
    def on_tool_call(
        self,
        request: ToolCallRequest,
        state: StateT,
        runtime: Runtime[ContextT],
    ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]:
        """Intercept tool execution to implement retry logic, monitoring, or request modification.

        Provides generator-based control over the complete tool execution lifecycle.
        Multiple middleware can define this hook; they compose automatically with
        outer middleware wrapping inner middleware (first defined = outermost layer).

        Generator Protocol:
            1. Yield a ToolCallRequest (potentially modified from the input)
            2. Receive a ToolCallResponse via .send()
            3. Optionally yield again to retry
            4. Return the final ToolCallResponse to propagate

        Args:
            request: Tool invocation details including the tool_call dict and the
                tool instance that will be invoked.
            state: Current agent state (readonly context).
            runtime: LangGraph runtime for accessing user context (readonly context).

        Returns:
            Generator for request/response interception.

        Example:
            Retry on rate limit with exponential backoff:

            ```python
            def on_tool_call(self, request, state, runtime):
                for attempt in range(3):
                    response = yield request
                    if response.action == "continue":
                        return response
                    if "rate limit" in str(response.exception):
                        time.sleep(2**attempt)
                        continue
                    return response
            ```
        """
        # Intentionally a no-op: only subclasses that actually override this hook
        # are wired into the tool-call handler chain.
class _CallableWithStateAndRuntime(Protocol[StateT_contra, ContextT]):
"""Callable with AgentState and Runtime as arguments."""

View File

@@ -1,7 +1,7 @@
"""Middleware agent implementation."""
import itertools
from collections.abc import Callable, Sequence
from collections.abc import Callable, Generator, Sequence
from typing import Annotated, Any, cast, get_args, get_origin, get_type_hints
from langchain_core.language_models.chat_models import BaseChatModel
@@ -35,12 +35,99 @@ from langchain.agents.structured_output import (
)
from langchain.chat_models import init_chat_model
from langchain.tools import ToolNode
from langchain.tools.tool_node import ToolCallHandler, ToolCallRequest, ToolCallResponse
STRUCTURED_OUTPUT_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes."
ResponseT = TypeVar("ResponseT")
def _chain_tool_call_handlers(
    handlers: list[ToolCallHandler],
) -> ToolCallHandler | None:
    """Compose multiple tool call handlers into a single middleware stack.

    Each handler is a generator that yields requests and receives responses.
    Composition drives the outer generator, forwarding each request it yields
    through the inner generator, whose yields in turn propagate to the caller
    (ultimately the tool itself).

    Args:
        handlers: Handlers in middleware order (first = outermost layer).

    Returns:
        Single composed handler, or None if handlers is empty.
    """
    if not handlers:
        return None
    if len(handlers) == 1:
        return handlers[0]

    def _extract_return_value(stop_iteration: StopIteration) -> ToolCallResponse:
        """Extract ToolCallResponse from StopIteration, validating protocol compliance."""
        # A generator that ends with a bare `return` (or falls off the end) has
        # StopIteration.value == None, which violates the handler protocol.
        if stop_iteration.value is None:
            msg = "on_tool_call handler must explicitly return a ToolCallResponse"
            raise ValueError(msg)
        return stop_iteration.value

    def compose_two(outer: ToolCallHandler, inner: ToolCallHandler) -> ToolCallHandler:
        """Compose two handlers where outer wraps inner."""

        def composed(
            request: ToolCallRequest,
            state: Any,
            runtime: Any,
        ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]:
            outer_gen = outer(request, state, runtime)
            # Initialize outer generator; it may return without yielding.
            try:
                outer_request = next(outer_gen)
            except StopIteration as e:
                return _extract_return_value(e)
            # Outer retry loop: one iteration per request the outer handler yields.
            while True:
                # A fresh inner generator per outer attempt.
                inner_gen = inner(outer_request, state, runtime)
                # Initialize inner generator
                try:
                    inner_request = next(inner_gen)
                except StopIteration as e:
                    # Inner returned immediately - send to outer
                    inner_response = _extract_return_value(e)
                    try:
                        outer_request = outer_gen.send(inner_response)
                        continue  # Outer retrying
                    except StopIteration as e:
                        return _extract_return_value(e)
                # Inner retry loop - yield to next layer (or tool)
                while True:
                    tool_response = yield inner_request
                    try:
                        inner_request = inner_gen.send(tool_response)
                        # Inner retrying - continue inner loop
                    except StopIteration as e:
                        # Inner done - send response to outer
                        inner_response = _extract_return_value(e)
                        break
                # Send inner's final response to outer
                try:
                    outer_request = outer_gen.send(inner_response)
                    # Outer retrying - continue outer loop
                except StopIteration as e:
                    # Outer done - return final response
                    return _extract_return_value(e)

        return composed

    # Compose right-to-left: handlers[0](handlers[1](...(handlers[-1](tool))))
    result = handlers[-1]
    for handler in reversed(handlers[:-1]):
        result = compose_two(handler, result)
    return result
def _resolve_schema(schemas: set[type], schema_name: str, omit_flag: str | None = None) -> type:
"""Resolve schema by merging schemas and optionally respecting OmitFromSchema annotations.
@@ -226,6 +313,20 @@ def create_agent( # noqa: PLR0915
structured_output_tools[structured_tool_info.tool.name] = structured_tool_info
middleware_tools = [t for m in middleware for t in getattr(m, "tools", [])]
# Validate middleware and collect handlers
assert len({m.name for m in middleware}) == len(middleware), ( # noqa: S101
"Please remove duplicate middleware instances."
)
middleware_w_on_tool_call = [
m for m in middleware if m.__class__.on_tool_call is not AgentMiddleware.on_tool_call
]
# Chain all on_tool_call handlers into a single composed handler
on_tool_call_handler = None
if middleware_w_on_tool_call:
handlers = [m.on_tool_call for m in middleware_w_on_tool_call]
on_tool_call_handler = _chain_tool_call_handlers(handlers)
# Setup tools
tool_node: ToolNode | None = None
if isinstance(tools, list):
@@ -237,7 +338,11 @@ def create_agent( # noqa: PLR0915
available_tools = middleware_tools + regular_tools
# Only create ToolNode if we have client-side tools
tool_node = ToolNode(tools=available_tools) if available_tools else None
tool_node = (
ToolNode(tools=available_tools, on_tool_call=on_tool_call_handler)
if available_tools
else None
)
# Default tools for ModelRequest initialization
# Include built-ins and regular tools (can be changed dynamically by middleware)
@@ -248,7 +353,7 @@ def create_agent( # noqa: PLR0915
if tool_node:
# Add middleware tools to existing ToolNode
available_tools = list(tool_node.tools_by_name.values()) + middleware_tools
tool_node = ToolNode(available_tools)
tool_node = ToolNode(available_tools, on_tool_call=on_tool_call_handler)
# default_tools includes all client-side tools (no built-ins or structured tools)
default_tools = available_tools
@@ -256,10 +361,6 @@ def create_agent( # noqa: PLR0915
# No tools provided, only middleware_tools available
default_tools = middleware_tools
# validate middleware
assert len({m.name for m in middleware}) == len(middleware), ( # noqa: S101
"Please remove duplicate middleware instances."
)
middleware_w_before = [
m
for m in middleware

View File

@@ -38,21 +38,22 @@ from __future__ import annotations
import asyncio
import inspect
import json
from collections.abc import Callable, Generator, Sequence
from copy import copy, deepcopy
from dataclasses import replace
from dataclasses import dataclass, replace
from types import UnionType
from typing import (
TYPE_CHECKING,
Annotated,
Any,
Literal,
Optional,
Union,
cast,
get_args,
get_origin,
get_type_hints,
)
from typing import Optional as Optional
from langchain_core.messages import (
AIMessage,
@@ -75,12 +76,11 @@ from langchain_core.tools.base import (
from langgraph._internal._runnable import RunnableCallable
from langgraph.errors import GraphBubbleUp
from langgraph.graph.message import REMOVE_ALL_MESSAGES
from langgraph.runtime import get_runtime
from langgraph.types import Command, Send
from pydantic import BaseModel, ValidationError
if TYPE_CHECKING:
from collections.abc import Callable, Sequence
from langchain_core.runnables import RunnableConfig
from langgraph.store.base import BaseStore
@@ -100,6 +100,62 @@ TOOL_INVOCATION_ERROR_TEMPLATE = (
)
@dataclass()
class ToolCallRequest:
    """Request passed to on_tool_call handler before tool execution.

    Attributes:
        tool_call: The tool call dict containing name, args, and id.
        tool: The BaseTool instance that will be invoked.

    Note:
        tool_call["args"] can be mutated directly to modify arguments.
    """

    # The tool call dict (name/args/id) produced by the model.
    tool_call: ToolCall
    # The resolved tool instance that will execute this call.
    tool: BaseTool
@dataclass()
class ToolCallResponse:
    """Response returned from on_tool_call handler after tool execution.

    The action field determines control flow:

    - "continue": Handler completed successfully, use result
    - "raise": Handler wants to propagate the exception

    Attributes:
        action: Control flow directive ("continue" or "raise").
        result: ToolMessage or Command when action="continue".
        exception: The exception when action="raise", or for logging when
            action="continue" with an error ToolMessage.
    """

    action: Literal["continue", "raise"]
    result: ToolMessage | Command | None = None
    exception: Exception | None = None

    def __post_init__(self) -> None:
        """Validate that required fields are present based on action."""
        # Each action mandates exactly one payload field; reject if it's absent.
        missing = None
        if self.action == "continue" and self.result is None:
            missing = "a result"
        elif self.action == "raise" and self.exception is None:
            missing = "an exception"
        if missing is not None:
            msg = f"action='{self.action}' requires {missing}"
            raise ValueError(msg)
ToolCallHandler = Callable[
[ToolCallRequest, Any, Any], Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]
]
"""Generator-based handler that intercepts tool execution.
Receives a ToolCallRequest, state, and runtime; yields modified ToolCallRequests;
receives ToolCallResponses; and returns a final ToolCallResponse. Supports multiple
yields for retry logic.
"""
def msg_content_output(output: Any) -> str | list[dict]:
"""Convert tool output to valid message content format.
@@ -156,7 +212,7 @@ class ToolInvocationError(Exception):
super().__init__(self.message)
def _default_handle_tool_errors(e: Exception) -> str:
def _default_handle_tool_errors(e: ToolInvocationError) -> str:
"""Default error handler for tool errors.
If the tool is a tool invocation error, return its message.
@@ -300,8 +356,8 @@ class ToolNode(RunnableCallable):
Output format depends on input type and tool behavior:
**For Regular tools**:
- Dict input ``{"messages": [ToolMessage(...)]}``
- List input ``[ToolMessage(...)]``
- Dict input -> ``{"messages": [ToolMessage(...)]}``
- List input -> ``[ToolMessage(...)]``
**For Command tools**:
- Returns ``[Command(...)]`` or mixed list with regular tool outputs
@@ -335,6 +391,12 @@ class ToolNode(RunnableCallable):
- catches tool invocation errors (due to invalid arguments provided by the model) and returns a descriptive error message
- ignores tool execution errors (they will be re-raised)
on_tool_call: Optional handler to intercept tool execution. Receives
``ToolCallRequest``, yields potentially modified requests, receives
``ToolCallResponse`` via ``.send()``, and returns final ``ToolCallResponse``.
Enables retries, argument modification, and custom error handling.
Defaults to ``None``.
messages_key: The key in the state dictionary that contains the message list.
This same key will be used for the output ToolMessages.
Defaults to "messages".
@@ -378,6 +440,23 @@ class ToolNode(RunnableCallable):
tool_node = ToolNode([my_tool], handle_tool_errors=handle_errors)
```
Intercepting tool calls:
```python
from langchain.tools.tool_node import ToolCallRequest, ToolCallResponse
def retry_handler(request):
\"\"\"Retry failed tool calls up to 3 times.\"\"\"
for attempt in range(3):
response = yield request
if response.action == "continue":
return response
# Retry on error
return response # Final attempt
tool_node = ToolNode([my_tool], on_tool_call=retry_handler)
```
""" # noqa: E501
name: str = "tools"
@@ -393,6 +472,7 @@ class ToolNode(RunnableCallable):
| Callable[..., str]
| type[Exception]
| tuple[type[Exception], ...] = _default_handle_tool_errors,
on_tool_call: ToolCallHandler | None = None,
messages_key: str = "messages",
) -> None:
"""Initialize the ToolNode with the provided tools and configuration.
@@ -402,6 +482,7 @@ class ToolNode(RunnableCallable):
name: Node name for graph identification.
tags: Optional metadata tags.
handle_tool_errors: Error handling configuration.
on_tool_call: Optional handler to intercept tool execution.
messages_key: State key containing messages.
"""
super().__init__(self._func, self._afunc, name=name, tags=tags, trace=False)
@@ -409,6 +490,7 @@ class ToolNode(RunnableCallable):
self._tool_to_state_args: dict[str, dict[str, str | None]] = {}
self._tool_to_store_arg: dict[str, str | None] = {}
self._handle_tool_errors = handle_tool_errors
self._on_tool_call = on_tool_call
self._messages_key = messages_key
for tool in tools:
if not isinstance(tool, BaseTool):
@@ -429,13 +511,24 @@ class ToolNode(RunnableCallable):
input: list[AnyMessage] | dict[str, Any] | BaseModel,
config: RunnableConfig,
*,
# Optional[BaseStore] should not change to BaseStore | None
# until we support injection of store using `BaseStore | None` annotation
store: Optional[BaseStore], # noqa: UP045
) -> Any:
try:
runtime = get_runtime()
except RuntimeError:
# Running outside of the LangGraph runtime context (e.g., unit-tests)
runtime = None
tool_calls, input_type = self._parse_input(input, store)
config_list = get_config_list(config, len(tool_calls))
input_types = [input_type] * len(tool_calls)
inputs = [input] * len(tool_calls)
runtimes = [runtime] * len(tool_calls)
with get_executor_for_config(config) as executor:
outputs = [*executor.map(self._run_one, tool_calls, input_types, config_list)]
outputs = [
*executor.map(self._run_one, tool_calls, input_types, config_list, inputs, runtimes)
]
return self._combine_tool_outputs(outputs, input_type)
@@ -444,11 +537,18 @@ class ToolNode(RunnableCallable):
input: list[AnyMessage] | dict[str, Any] | BaseModel,
config: RunnableConfig,
*,
# Optional[BaseStore] should not change to BaseStore | None
# until we support injection of store using `BaseStore | None` annotation
store: Optional[BaseStore], # noqa: UP045
) -> Any:
try:
runtime = get_runtime()
except RuntimeError:
# Running outside of the LangGraph runtime context (e.g., unit-tests)
runtime = None
tool_calls, input_type = self._parse_input(input, store)
outputs = await asyncio.gather(
*(self._arun_one(call, input_type, config) for call in tool_calls)
*(self._arun_one(call, input_type, config, input, runtime) for call in tool_calls)
)
return self._combine_tool_outputs(outputs, input_type)
@@ -495,20 +595,19 @@ class ToolNode(RunnableCallable):
combined_outputs.append(parent_command)
return combined_outputs
def _run_one(
self,
call: ToolCall,
input_type: Literal["list", "dict", "tool_calls"],
config: RunnableConfig,
) -> ToolMessage | Command:
"""Run a single tool call synchronously."""
if invalid_tool_message := self._validate_tool_call(call):
return invalid_tool_message
def _execute_tool_sync(
self, request: ToolCallRequest, input_type: Literal["list", "dict", "tool_calls"], config: RunnableConfig
) -> ToolCallResponse:
"""Execute tool and return response.
Applies handle_tool_errors configuration. When on_tool_call is configured,
unhandled errors return action="raise" instead of raising immediately.
"""
call = request.tool_call
tool = request.tool
call_args = {**call, "type": "tool_call"}
try:
call_args = {**call, "type": "tool_call"}
tool = self.tools_by_name[call["name"]]
try:
response = tool.invoke(call_args, config)
except ValidationError as exc:
@@ -541,40 +640,110 @@ class ToolNode(RunnableCallable):
# default behavior is catching all exceptions
handled_types = (Exception,)
# Unhandled
# Check if error is handled
if not self._handle_tool_errors or not isinstance(e, handled_types):
# Error is not handled
if self._on_tool_call is not None:
# If handler exists, return action="raise" so handler can decide
return ToolCallResponse(action="raise", exception=e)
# No handler - maintain backward compatibility by raising immediately
raise
# Handled
# Error is handled - create error ToolMessage
content = _handle_tool_error(e, flag=self._handle_tool_errors)
return ToolMessage(
error_message = ToolMessage(
content=content,
name=call["name"],
tool_call_id=call["id"],
status="error",
)
return ToolCallResponse(action="continue", result=error_message, exception=e)
# Process successful response
if isinstance(response, Command):
return self._validate_tool_command(response, call, input_type)
# Validate Command before returning to handler
validated_command = self._validate_tool_command(response, request.tool_call, input_type)
return ToolCallResponse(action="continue", result=validated_command)
if isinstance(response, ToolMessage):
response.content = cast("str | list", msg_content_output(response.content))
return response
return ToolCallResponse(action="continue", result=response)
msg = f"Tool {call['name']} returned unexpected type: {type(response)}"
raise TypeError(msg)
async def _arun_one(
def _run_one(
self,
call: ToolCall,
input_type: Literal["list", "dict", "tool_calls"],
config: RunnableConfig,
input: list[AnyMessage] | dict[str, Any] | BaseModel,
runtime: Any,
) -> ToolMessage | Command:
"""Run a single tool call asynchronously."""
"""Run a single tool call synchronously."""
if invalid_tool_message := self._validate_tool_call(call):
return invalid_tool_message
try:
call_args = {**call, "type": "tool_call"}
tool = self.tools_by_name[call["name"]]
tool = self.tools_by_name[call["name"]]
# Create the tool request
tool_request = ToolCallRequest(
tool_call=call,
tool=tool,
)
if self._on_tool_call is None:
tool_response = self._execute_tool_sync(tool_request, input_type, config)
else:
# Generator protocol: start generator, send responses, receive requests
gen = self._on_tool_call(tool_request, input, runtime)
try:
request = next(gen)
except StopIteration:
msg = "on_tool_call handler must yield at least once before returning"
raise ValueError(msg)
while True:
tool_response = self._execute_tool_sync(request, input_type, config)
try:
request = gen.send(tool_response)
except StopIteration as e:
if e.value is None:
msg = (
"on_tool_call handler must explicitly return a ToolCallResponse. "
"Ensure your handler ends with 'return response'."
)
raise ValueError(msg)
tool_response = e.value
break
# Apply action directive
if tool_response.action == "raise":
if tool_response.exception is None:
msg = "ToolCallResponse with action='raise' must have an exception"
raise ValueError(msg)
raise tool_response.exception
result = tool_response.result
if result is None:
msg = "ToolCallResponse with action='continue' must have a result"
raise ValueError(msg)
return result
async def _execute_tool_async(
self, request: ToolCallRequest, input_type: Literal["list", "dict", "tool_calls"], config: RunnableConfig
) -> ToolCallResponse:
"""Execute tool asynchronously and return response.
Applies handle_tool_errors configuration. When on_tool_call is configured,
unhandled errors return action="raise" instead of raising immediately.
"""
call = request.tool_call
tool = request.tool
call_args = {**call, "type": "tool_call"}
try:
try:
response = await tool.ainvoke(call_args, config)
except ValidationError as exc:
@@ -607,27 +776,97 @@ class ToolNode(RunnableCallable):
# default behavior is catching all exceptions
handled_types = (Exception,)
# Unhandled
# Check if error is handled
if not self._handle_tool_errors or not isinstance(e, handled_types):
# Error is not handled
if self._on_tool_call is not None:
# If handler exists, return action="raise" so handler can decide
return ToolCallResponse(action="raise", exception=e)
# No handler - maintain backward compatibility by raising immediately
raise
# Handled
content = _handle_tool_error(e, flag=self._handle_tool_errors)
return ToolMessage(
# Error is handled - create error ToolMessage
content = _handle_tool_error(e, flag=self._handle_tool_errors)
error_message = ToolMessage(
content=content,
name=call["name"],
tool_call_id=call["id"],
status="error",
)
return ToolCallResponse(action="continue", result=error_message, exception=e)
# Process successful response
if isinstance(response, Command):
return self._validate_tool_command(response, call, input_type)
# Validate Command before returning to handler
validated_command = self._validate_tool_command(response, request.tool_call, input_type)
return ToolCallResponse(action="continue", result=validated_command)
if isinstance(response, ToolMessage):
response.content = cast("str | list", msg_content_output(response.content))
return response
return ToolCallResponse(action="continue", result=response)
msg = f"Tool {call['name']} returned unexpected type: {type(response)}"
raise TypeError(msg)
async def _arun_one(
    self,
    call: ToolCall,
    input_type: Literal["list", "dict", "tool_calls"],
    config: RunnableConfig,
    input: list[AnyMessage] | dict[str, Any] | BaseModel,
    runtime: Any,
) -> ToolMessage | Command:
    """Run a single tool call asynchronously.

    Args:
        call: The tool call to execute.
        input_type: Shape of the node input ("list", "dict", or "tool_calls").
        config: Runnable config forwarded to tool execution.
        input: Raw node input, passed through to the on_tool_call handler.
        runtime: Runtime context, passed through to the on_tool_call handler.

    Returns:
        The ToolMessage or Command produced by the tool (possibly replaced
        by the on_tool_call handler).

    Raises:
        ValueError: If the on_tool_call generator violates the protocol
            (never yields, or finishes without returning a ToolCallResponse),
            or if the final ToolCallResponse is inconsistent with its action.
    """
    if invalid_tool_message := self._validate_tool_call(call):
        return invalid_tool_message
    tool = self.tools_by_name[call["name"]]
    # Create the tool request
    tool_request = ToolCallRequest(
        tool_call=call,
        tool=tool,
    )
    if self._on_tool_call is None:
        tool_response = await self._execute_tool_async(tool_request, input_type, config)
    else:
        # Generator protocol: handler is a sync generator; tool execution is async.
        gen = self._on_tool_call(tool_request, input, runtime)
        try:
            request = next(gen)
        except StopIteration:
            msg = "on_tool_call handler must yield at least once before returning"
            # `from None` suppresses the irrelevant StopIteration context.
            raise ValueError(msg) from None
        while True:
            tool_response = await self._execute_tool_async(request, input_type, config)
            try:
                request = gen.send(tool_response)
            except StopIteration as e:
                if e.value is None:
                    msg = (
                        "on_tool_call handler must explicitly return a ToolCallResponse. "
                        "Ensure your handler ends with 'return response'."
                    )
                    # `from None` suppresses the irrelevant StopIteration context.
                    raise ValueError(msg) from None
                tool_response = e.value
                break
    # Apply action directive
    if tool_response.action == "raise":
        if tool_response.exception is None:
            msg = "ToolCallResponse with action='raise' must have an exception"
            raise ValueError(msg)
        raise tool_response.exception
    result = tool_response.result
    if result is None:
        msg = "ToolCallResponse with action='continue' must have a result"
        raise ValueError(msg)
    return result
def _parse_input(
self,
input: list[AnyMessage] | dict[str, Any] | BaseModel,

View File

@@ -0,0 +1,396 @@
"""Unit tests for on_tool_call middleware hook."""
from collections.abc import Generator
from typing import Any, Literal, Union
import typing
from pydantic import BaseModel
import pytest
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool, tool
from langchain.agents.middleware.types import AgentMiddleware
from langchain.agents.middleware_agent import create_agent
from langchain.tools.tool_node import ToolCallRequest, ToolCallResponse
class FakeModel(GenericFakeChatModel):
    """Fake chat model for testing.

    Replays a scripted iterator of messages and implements just enough of
    ``bind_tools`` to accept tool specs in either OpenAI or Anthropic shape.
    """

    # Which simplified tool-spec format bind_tools emits.
    tool_style: Literal["openai", "anthropic"] = "openai"

    def bind_tools(
        self,
        tools: typing.Sequence[Union[dict[str, Any], type[BaseModel], typing.Callable, BaseTool]],
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, BaseMessage]:
        """Bind a non-empty sequence of tools (BaseTool or dict only)."""
        if not tools:
            msg = "Must provide at least one tool"
            raise ValueError(msg)
        tool_dicts = []
        # `tool_` avoids shadowing the `tool` decorator imported at module level.
        for tool_ in tools:
            if isinstance(tool_, dict):
                tool_dicts.append(tool_)
                continue
            if not isinstance(tool_, BaseTool):
                msg = "Only BaseTool and dict is supported by FakeModel.bind_tools"
                raise TypeError(msg)
            # NOTE: this is a simplified tool spec for testing purposes only
            if self.tool_style == "openai":
                tool_dicts.append(
                    {
                        "type": "function",
                        "function": {
                            "name": tool_.name,
                        },
                    }
                )
            elif self.tool_style == "anthropic":
                tool_dicts.append(
                    {
                        "name": tool_.name,
                    }
                )
        return self.bind(tools=tool_dicts)
@tool
def add_tool(x: int, y: int) -> int:
    """Add two numbers."""
    total = x + y
    return total
@tool
def failing_tool(x: int) -> int:
    """Tool that raises an error."""
    # Always fails; used to exercise error-handling paths.
    raise ValueError("Intentional failure")
def test_single_middleware_on_tool_call():
    """A single middleware's on_tool_call hook wraps tool execution."""
    call_log = []

    class LoggingMiddleware(AgentMiddleware):
        """Records an entry on either side of the tool invocation."""

        def on_tool_call(
            self, request: ToolCallRequest, state, runtime
        ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]:
            tool_name = request.tool.name
            call_log.append(f"before_{tool_name}")
            response = yield request
            call_log.append(f"after_{tool_name}")
            return response

    scripted_replies = [
        AIMessage(
            content="",
            tool_calls=[{"name": "add_tool", "args": {"x": 2, "y": 3}, "id": "1"}],
        ),
        AIMessage(content="Done"),
    ]
    agent = create_agent(
        model=FakeModel(messages=iter(scripted_replies)),
        tools=[add_tool],
        middleware=[LoggingMiddleware()],
    )
    result = agent.compile().invoke({"messages": [HumanMessage("Add 2 and 3")]})

    # The hook fired around the tool call, in order.
    assert "before_add_tool" in call_log
    assert "after_add_tool" in call_log
    assert call_log.index("before_add_tool") < call_log.index("after_add_tool")

    # The tool itself still executed successfully.
    tool_messages = [m for m in result["messages"] if isinstance(m, ToolMessage)]
    assert len(tool_messages) == 1
    assert tool_messages[0].content == "5"
def test_multiple_middleware_chaining():
    """Test that multiple middleware chain correctly (outer wraps inner).

    Middleware listed first wraps middleware listed later, so the expected
    trace is outer_start -> inner_start -> tool -> inner_end -> outer_end.
    """
    call_order = []

    class OuterMiddleware(AgentMiddleware):
        """Outer middleware in the chain."""

        def on_tool_call(
            self, request: ToolCallRequest, state, runtime
        ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]:
            call_order.append("outer_start")
            # yield delegates to the next (inner) middleware / tool execution.
            response = yield request
            call_order.append("outer_end")
            return response

    class InnerMiddleware(AgentMiddleware):
        """Inner middleware in the chain."""

        def on_tool_call(
            self, request: ToolCallRequest, state, runtime
        ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]:
            call_order.append("inner_start")
            response = yield request
            call_order.append("inner_end")
            return response

    model = FakeModel(
        messages=iter(
            [
                AIMessage(
                    content="",
                    tool_calls=[{"name": "add_tool", "args": {"x": 1, "y": 1}, "id": "1"}],
                ),
                AIMessage(content="Done"),
            ]
        )
    )
    agent = create_agent(
        model=model,
        tools=[add_tool],
        # Order matters: OuterMiddleware wraps InnerMiddleware.
        middleware=[OuterMiddleware(), InnerMiddleware()],
    )
    agent.compile().invoke({"messages": [HumanMessage("Add 1 and 1")]})
    # Verify order: outer_start -> inner_start -> tool -> inner_end -> outer_end
    assert call_order == ["outer_start", "inner_start", "inner_end", "outer_end"]
def test_middleware_retry_logic():
    """Test that middleware can retry tool calls.

    ``failing_tool`` always raises, so both yields observe action="raise";
    the middleware then substitutes an error ToolMessage for the exception
    instead of letting it propagate.
    """
    attempt_count = 0

    class RetryMiddleware(AgentMiddleware):
        """Middleware that retries on failure."""

        def on_tool_call(
            self, request: ToolCallRequest, state, runtime
        ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]:
            nonlocal attempt_count
            max_retries = 2
            for attempt in range(max_retries):
                attempt_count += 1
                # Each yield re-executes the tool with the same request.
                response = yield request
                if response.action == "continue":
                    return response
                if response.action == "raise" and attempt < max_retries - 1:
                    # Retry
                    continue
                # Convert error to success message
                return ToolCallResponse(
                    action="continue",
                    result=ToolMessage(
                        content=f"Failed after {max_retries} attempts",
                        name=request.tool_call["name"],
                        tool_call_id=request.tool_call["id"],
                        status="error",
                    ),
                )
            raise AssertionError("Unreachable")

    model = FakeModel(
        messages=iter(
            [
                AIMessage(
                    content="",
                    tool_calls=[{"name": "failing_tool", "args": {"x": 1}, "id": "1"}],
                ),
                AIMessage(content="Done"),
            ]
        )
    )
    agent = create_agent(
        model=model,
        tools=[failing_tool],
        middleware=[RetryMiddleware()],
    )
    result = agent.compile().invoke({"messages": [HumanMessage("Test retry")]})
    # Should have attempted twice
    assert attempt_count == 2
    # Check that we got an error message
    tool_messages = [m for m in result["messages"] if isinstance(m, ToolMessage)]
    assert len(tool_messages) == 1
    assert "Failed after 2 attempts" in tool_messages[0].content
def test_middleware_request_modification():
    """Middleware can rewrite a tool call's arguments before execution."""

    class RequestModifierMiddleware(AgentMiddleware):
        """Doubles both arguments of the incoming tool call."""

        def on_tool_call(
            self, request: ToolCallRequest, state, runtime
        ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]:
            args = request.tool_call["args"]
            doubled_call = {
                **request.tool_call,
                "args": {"x": args["x"] * 2, "y": args["y"] * 2},
            }
            response = yield ToolCallRequest(
                tool_call=doubled_call,
                tool=request.tool,
            )
            return response

    scripted_replies = [
        AIMessage(
            content="",
            tool_calls=[{"name": "add_tool", "args": {"x": 1, "y": 2}, "id": "1"}],
        ),
        AIMessage(content="Done"),
    ]
    agent = create_agent(
        model=FakeModel(messages=iter(scripted_replies)),
        tools=[add_tool],
        middleware=[RequestModifierMiddleware()],
    )
    result = agent.compile().invoke({"messages": [HumanMessage("Add 1 and 2")]})

    # Original call was 1 + 2 = 3; after doubling it becomes 2 + 4 = 6.
    tool_messages = [m for m in result["messages"] if isinstance(m, ToolMessage)]
    assert len(tool_messages) == 1
    assert tool_messages[0].content == "6"
def test_multiple_middleware_with_retry():
    """Test complex scenario with multiple middleware and retry logic.

    The outer monitoring middleware wraps the inner retry middleware; the
    tool succeeds on the first attempt, so the retry path is not taken.
    """
    call_log = []

    class MonitoringMiddleware(AgentMiddleware):
        """Outer middleware for monitoring."""

        def on_tool_call(
            self, request: ToolCallRequest, state, runtime
        ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]:
            call_log.append("monitoring_start")
            response = yield request
            call_log.append("monitoring_end")
            return response

    class RetryMiddleware(AgentMiddleware):
        """Inner middleware for retries."""

        def on_tool_call(
            self, request: ToolCallRequest, state, runtime
        ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]:
            call_log.append("retry_start")
            for attempt in range(2):
                call_log.append(f"retry_attempt_{attempt + 1}")
                # Each yield re-executes the tool with the same request.
                response = yield request
                if response.action == "continue":
                    call_log.append("retry_success")
                    return response
                if attempt == 0:  # Retry once
                    call_log.append("retry_retry")
                    continue
                call_log.append("retry_failed")
                return response

    model = FakeModel(
        messages=iter(
            [
                AIMessage(
                    content="",
                    tool_calls=[{"name": "add_tool", "args": {"x": 5, "y": 7}, "id": "1"}],
                ),
                AIMessage(content="Done"),
            ]
        )
    )
    agent = create_agent(
        model=model,
        tools=[add_tool],
        # MonitoringMiddleware (outer) wraps RetryMiddleware (inner).
        middleware=[MonitoringMiddleware(), RetryMiddleware()],
    )
    agent.compile().invoke({"messages": [HumanMessage("Add 5 and 7")]})
    # Verify the call sequence
    assert call_log[0] == "monitoring_start"
    assert call_log[1] == "retry_start"
    assert "retry_attempt_1" in call_log
    assert "retry_success" in call_log
    assert call_log[-1] == "monitoring_end"
def test_mixed_middleware():
    """A middleware may implement both before_model and on_tool_call hooks."""
    call_log = []

    class MixedMiddleware(AgentMiddleware):
        """Implements a model hook and a tool-call hook on one class."""

        def before_model(self, state, runtime):
            call_log.append("before_model")
            return None

        def on_tool_call(
            self, request: ToolCallRequest, state, runtime
        ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]:
            call_log.append("on_tool_call_start")
            response = yield request
            call_log.append("on_tool_call_end")
            return response

    scripted_replies = [
        AIMessage(
            content="",
            tool_calls=[{"name": "add_tool", "args": {"x": 10, "y": 20}, "id": "1"}],
        ),
        AIMessage(content="Done"),
    ]
    agent = create_agent(
        model=FakeModel(messages=iter(scripted_replies)),
        tools=[add_tool],
        middleware=[MixedMiddleware()],
    )
    agent.compile().invoke({"messages": [HumanMessage("Add 10 and 20")]})

    # Both hooks fire, and the model hook precedes the tool-call hook.
    assert "before_model" in call_log
    assert "on_tool_call_start" in call_log
    assert "on_tool_call_end" in call_log
    assert call_log.index("before_model") < call_log.index("on_tool_call_start")

View File

@@ -0,0 +1,383 @@
"""Tests for on_tool_call handler functionality."""
from collections.abc import Generator
from typing import Any
import pytest
from langchain_core.messages import AIMessage, ToolMessage
from langchain_core.tools import tool
from langchain.tools import ToolNode
from langchain.tools.tool_node import ToolCallRequest, ToolCallResponse
# Test tools
@tool
def success_tool(x: int) -> int:
    """A tool that always succeeds."""
    doubled = 2 * x
    return doubled
@tool
def error_tool(x: int) -> int:
    """A tool that always raises ValueError."""
    # Always fails; the message embeds the input for assertion matching.
    raise ValueError(f"Error with value: {x}")
@tool
def rate_limit_tool(x: int) -> int:
    """A tool that simulates rate limit errors.

    State lives in a ``_call_count`` attribute on the tool object so it
    persists across invocations; tests reset it explicitly between runs.
    """
    # Lazily initialize the counter on first call.
    if not hasattr(rate_limit_tool, "_call_count"):
        rate_limit_tool._call_count = 0
    rate_limit_tool._call_count += 1
    if rate_limit_tool._call_count < 3:  # Fail first 2 times
        msg = "Rate limit exceeded"
        raise ValueError(msg)
    return x * 2
def test_on_tool_call_passthrough() -> None:
    """A handler that forwards the request unchanged leaves results intact."""

    def passthrough_handler(
        request: ToolCallRequest, _state: Any, _runtime: Any
    ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]:
        """Simply pass through without modification."""
        return (yield request)

    node = ToolNode([success_tool], on_tool_call=passthrough_handler)
    call = {"name": "success_tool", "args": {"x": 5}, "id": "1"}
    result = node.invoke({"messages": [AIMessage("", tool_calls=[call])]})

    messages = result["messages"]
    assert len(messages) == 1
    tool_message: ToolMessage = messages[0]
    assert tool_message.content == "10"
    assert tool_message.status != "error"
def test_on_tool_call_retry_success() -> None:
    """Test that retry handler can recover from transient errors.

    ``rate_limit_tool`` fails its first two invocations, so the handler's
    third yield succeeds and the final ToolMessage carries the real result.
    """
    # Reset counter so this test is independent of earlier invocations.
    if hasattr(rate_limit_tool, "_call_count"):
        rate_limit_tool._call_count = 0

    def retry_handler(
        request: ToolCallRequest, _state: Any, _runtime: Any
    ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]:
        """Retry up to 3 times."""
        max_retries = 3
        for attempt in range(max_retries):
            # Each yield re-executes the tool with the same request.
            response = yield request
            if response.action == "continue":
                return response
            # Retry on error
            if attempt < max_retries - 1:
                continue
            # Final attempt failed - convert to error message
            return ToolCallResponse(
                action="continue",
                result=ToolMessage(
                    content=f"Failed after {max_retries} attempts",
                    name=request.tool_call["name"],
                    tool_call_id=request.tool_call["id"],
                    status="error",
                ),
            )
        msg = "Unreachable code"
        raise AssertionError(msg)

    # handle_tool_errors=False so failures reach the handler as action="raise".
    tool_node = ToolNode([rate_limit_tool], on_tool_call=retry_handler, handle_tool_errors=False)
    result = tool_node.invoke(
        {
            "messages": [
                AIMessage(
                    "",
                    tool_calls=[{"name": "rate_limit_tool", "args": {"x": 5}, "id": "1"}],
                )
            ]
        }
    )
    assert len(result["messages"]) == 1
    tool_message: ToolMessage = result["messages"][0]
    assert tool_message.content == "10"  # Should succeed on 3rd attempt
    assert tool_message.status != "error"
def test_on_tool_call_convert_error_to_message() -> None:
    """A handler can turn a raised tool error into an error ToolMessage."""

    def error_to_message_handler(
        request: ToolCallRequest, _state: Any, _runtime: Any
    ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]:
        """Convert any error to a user-friendly message."""
        response = yield request
        if response.action != "raise":
            return response
        friendly = ToolMessage(
            content=f"Tool failed: {response.exception}",
            name=request.tool_call["name"],
            tool_call_id=request.tool_call["id"],
            status="error",
        )
        return ToolCallResponse(
            action="continue",
            result=friendly,
            exception=response.exception,
        )

    node = ToolNode(
        [error_tool], on_tool_call=error_to_message_handler, handle_tool_errors=False
    )
    call = {"name": "error_tool", "args": {"x": 5}, "id": "1"}
    result = node.invoke({"messages": [AIMessage("", tool_calls=[call])]})

    messages = result["messages"]
    assert len(messages) == 1
    tool_message: ToolMessage = messages[0]
    assert "Tool failed" in tool_message.content
    assert "Error with value: 5" in tool_message.content
    assert tool_message.status == "error"
def test_on_tool_call_let_error_raise() -> None:
    """Returning the raise-response unchanged lets the error propagate."""

    def let_raise_handler(
        request: ToolCallRequest, _state: Any, _runtime: Any
    ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]:
        """Just return the response as-is, letting errors raise."""
        outcome = yield request
        return outcome

    node = ToolNode([error_tool], on_tool_call=let_raise_handler, handle_tool_errors=False)
    payload = {
        "messages": [
            AIMessage("", tool_calls=[{"name": "error_tool", "args": {"x": 5}, "id": "1"}])
        ]
    }
    with pytest.raises(ValueError, match=r"Error with value: 5"):
        node.invoke(payload)
def test_on_tool_call_with_handled_errors() -> None:
    """Test interaction between on_tool_call and handle_tool_errors.

    With handle_tool_errors=True the error is converted to a ToolMessage
    during tool execution, so the handler observes action="continue" and
    runs exactly once (no retry loop is triggered).
    """
    call_count = {"count": 0}

    def counting_handler(
        request: ToolCallRequest, _state: Any, _runtime: Any
    ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]:
        """Count how many times we're called."""
        call_count["count"] += 1
        response = yield request
        return response

    # When handle_tool_errors=True, errors are converted to ToolMessages
    # so handler sees action="continue"
    tool_node = ToolNode([error_tool], on_tool_call=counting_handler, handle_tool_errors=True)
    result = tool_node.invoke(
        {
            "messages": [
                AIMessage(
                    "",
                    tool_calls=[{"name": "error_tool", "args": {"x": 5}, "id": "1"}],
                )
            ]
        }
    )
    assert call_count["count"] == 1
    assert len(result["messages"]) == 1
    tool_message: ToolMessage = result["messages"][0]
    assert tool_message.status == "error"
    # Default handled-error text produced by the tool node.
    assert "Please fix your mistakes" in tool_message.content
def test_on_tool_call_must_return_value() -> None:
    """Handlers that fall off the end (returning None) trigger a ValueError."""

    def no_return_handler(
        request: ToolCallRequest, _state: Any, _runtime: Any
    ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]:
        """Handler that doesn't return anything."""
        _ = yield request
        # Implicit return None

    node = ToolNode([success_tool], on_tool_call=no_return_handler)
    payload = {
        "messages": [
            AIMessage("", tool_calls=[{"name": "success_tool", "args": {"x": 5}, "id": "1"}])
        ]
    }
    with pytest.raises(ValueError, match=r"must explicitly return a ToolCallResponse"):
        node.invoke(payload)
def test_on_tool_call_request_modification() -> None:
    """A handler may rewrite tool-call arguments before execution."""

    def double_input_handler(
        request: ToolCallRequest, _state: Any, _runtime: Any
    ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]:
        """Double the input value."""
        original_args = request.tool_call["args"]
        doubled_call = {
            **request.tool_call,
            "args": {**original_args, "x": original_args["x"] * 2},
        }
        response = yield ToolCallRequest(
            tool_call=doubled_call,
            tool=request.tool,
        )
        return response

    node = ToolNode([success_tool], on_tool_call=double_input_handler)
    call = {"name": "success_tool", "args": {"x": 5}, "id": "1"}
    result = node.invoke({"messages": [AIMessage("", tool_calls=[call])]})

    messages = result["messages"]
    assert len(messages) == 1
    # Input 5 is doubled to 10 by the handler; the tool then doubles it to 20.
    tool_message: ToolMessage = messages[0]
    assert tool_message.content == "20"
def test_on_tool_call_response_validation() -> None:
    """ToolCallResponse enforces the fields each action requires."""
    # "continue" must carry a result payload.
    with pytest.raises(ValueError, match=r"action='continue' requires a result"):
        ToolCallResponse(action="continue")
    # "raise" must carry the exception to propagate.
    with pytest.raises(ValueError, match=r"action='raise' requires an exception"):
        ToolCallResponse(action="raise")
    # Well-formed responses construct without error.
    ok_message = ToolMessage(content="test", tool_call_id="1", name="test")
    ToolCallResponse(action="continue", result=ok_message)
    ToolCallResponse(action="raise", exception=ValueError("test"))
def test_on_tool_call_without_handler_backward_compat() -> None:
    """Test that tools work without on_tool_call handler (backward compatibility).

    Covers three scenarios: plain success, unhandled error (exception
    propagates), and handled error (converted to an error ToolMessage).
    """
    # Success case
    tool_node = ToolNode([success_tool])
    result = tool_node.invoke(
        {
            "messages": [
                AIMessage(
                    "",
                    tool_calls=[{"name": "success_tool", "args": {"x": 5}, "id": "1"}],
                )
            ]
        }
    )
    assert result["messages"][0].content == "10"
    # Error case with handle_tool_errors=False: the exception propagates.
    tool_node_error = ToolNode([error_tool], handle_tool_errors=False)
    with pytest.raises(ValueError, match=r"Error with value: 5"):
        tool_node_error.invoke(
            {
                "messages": [
                    AIMessage(
                        "",
                        tool_calls=[{"name": "error_tool", "args": {"x": 5}, "id": "1"}],
                    )
                ]
            }
        )
    # Error case with handle_tool_errors=True: error becomes a ToolMessage.
    tool_node_handled = ToolNode([error_tool], handle_tool_errors=True)
    result = tool_node_handled.invoke(
        {
            "messages": [
                AIMessage(
                    "",
                    tool_calls=[{"name": "error_tool", "args": {"x": 5}, "id": "1"}],
                )
            ]
        }
    )
    assert result["messages"][0].status == "error"
def test_on_tool_call_multiple_yields() -> None:
    """A handler may yield repeatedly; each yield re-executes the tool."""
    attempts = {"count": 0}

    def multi_yield_handler(
        request: ToolCallRequest, _state: Any, _runtime: Any
    ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]:
        """Yield multiple times to track attempts."""
        max_attempts = 3
        response = None
        for _ in range(max_attempts):
            attempts["count"] += 1
            response = yield request
            if response.action == "continue":
                break
        # All attempts exhausted (or an early success) - return last response.
        return response

    node = ToolNode([error_tool], on_tool_call=multi_yield_handler, handle_tool_errors=False)
    payload = {
        "messages": [
            AIMessage("", tool_calls=[{"name": "error_tool", "args": {"x": 5}, "id": "1"}])
        ]
    }
    with pytest.raises(ValueError, match=r"Error with value: 5"):
        node.invoke(payload)
    assert attempts["count"] == 3