mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-11 03:30:09 +00:00
Compare commits
12 Commits
sr/dont-du
...
eugene/on_
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3c92e986f6 | ||
|
|
40b4c69a5a | ||
|
|
076c6f6b41 | ||
|
|
db58bfa543 | ||
|
|
ba9ec6d895 | ||
|
|
fa533c44b7 | ||
|
|
4f53ed3e9a | ||
|
|
def2f147ae | ||
|
|
65e073e85c | ||
|
|
a9ff8e0b67 | ||
|
|
0927ae4be1 | ||
|
|
06ce94ca06 |
@@ -7,6 +7,10 @@ from .planning import PlanningMiddleware
|
||||
from .prompt_caching import AnthropicPromptCachingMiddleware
|
||||
from .summarization import SummarizationMiddleware
|
||||
from .tool_call_limit import ToolCallLimitMiddleware
|
||||
from .tool_error_handling import (
|
||||
ErrorToMessageMiddleware,
|
||||
ToolRetryMiddleware,
|
||||
)
|
||||
from .tool_selection import LLMToolSelectorMiddleware
|
||||
from .types import (
|
||||
AgentMiddleware,
|
||||
@@ -24,6 +28,7 @@ __all__ = [
|
||||
"AgentState",
|
||||
# should move to langchain-anthropic if we decide to keep it
|
||||
"AnthropicPromptCachingMiddleware",
|
||||
"ErrorToMessageMiddleware",
|
||||
"HumanInTheLoopMiddleware",
|
||||
"LLMToolSelectorMiddleware",
|
||||
"ModelFallbackMiddleware",
|
||||
@@ -31,6 +36,7 @@ __all__ = [
|
||||
"PIIDetectionError",
|
||||
"PIIMiddleware",
|
||||
"PlanningMiddleware",
|
||||
"ToolRetryMiddleware",
|
||||
"SummarizationMiddleware",
|
||||
"ToolCallLimitMiddleware",
|
||||
"after_model",
|
||||
|
||||
@@ -0,0 +1,406 @@
|
||||
"""Middleware for handling tool execution errors in agents.
|
||||
|
||||
This module provides composable middleware for error handling, retries,
|
||||
and error-to-message conversion in tool execution workflows.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import logging
|
||||
import time
|
||||
from types import UnionType
|
||||
from typing import TYPE_CHECKING, Union, get_args, get_origin, get_type_hints
|
||||
|
||||
from langchain_core.messages import ToolMessage
|
||||
|
||||
from langchain.agents.middleware.types import AgentMiddleware
|
||||
|
||||
# Import ToolCallResponse locally to avoid circular import
|
||||
from langchain.tools.tool_node import ToolCallRequest, ToolCallResponse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable, Generator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Default retriable exception types - transient errors that may succeed on retry
|
||||
DEFAULT_RETRIABLE_EXCEPTIONS = (
|
||||
# Network and connection errors
|
||||
ConnectionError,
|
||||
TimeoutError,
|
||||
# HTTP client errors are typically not retriable, but these are exceptions:
|
||||
# - 429: Rate limit (temporary)
|
||||
# - 503: Service unavailable (temporary)
|
||||
# Note: Specific HTTP libraries may define their own exception types
|
||||
)
|
||||
|
||||
|
||||
def _infer_retriable_types(
|
||||
predicate: Callable[[Exception], bool],
|
||||
) -> tuple[type[Exception], ...]:
|
||||
"""Infer exception types from a retry predicate function's type annotations.
|
||||
|
||||
Analyzes the type annotations of a predicate function to determine which
|
||||
exception types it's designed to handle for retry decisions.
|
||||
|
||||
Args:
|
||||
predicate: A callable that takes an exception and returns whether to retry.
|
||||
The first parameter should be type-annotated with exception type(s).
|
||||
|
||||
Returns:
|
||||
Tuple of exception types that the predicate handles. Returns (Exception,)
|
||||
if no specific type information is available.
|
||||
|
||||
Raises:
|
||||
ValueError: If the predicate's annotation contains non-Exception types.
|
||||
"""
|
||||
sig = inspect.signature(predicate)
|
||||
params = list(sig.parameters.values())
|
||||
if params:
|
||||
# Skip self/cls if it's a method
|
||||
if params[0].name in ["self", "cls"] and len(params) == 2:
|
||||
first_param = params[1]
|
||||
else:
|
||||
first_param = params[0]
|
||||
|
||||
type_hints = get_type_hints(predicate)
|
||||
if first_param.name in type_hints:
|
||||
origin = get_origin(first_param.annotation)
|
||||
# Handle Union types
|
||||
if origin in [Union, UnionType]:
|
||||
args = get_args(first_param.annotation)
|
||||
if all(isinstance(arg, type) and issubclass(arg, Exception) for arg in args):
|
||||
return tuple(args)
|
||||
msg = (
|
||||
"All types in retry predicate annotation must be Exception types. "
|
||||
"For example, `def should_retry(e: Union[TimeoutError, "
|
||||
"ConnectionError]) -> bool`. "
|
||||
f"Got '{first_param.annotation}' instead."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
# Handle single exception type
|
||||
exception_type = type_hints[first_param.name]
|
||||
if isinstance(exception_type, type) and issubclass(exception_type, Exception):
|
||||
return (exception_type,)
|
||||
msg = (
|
||||
"Retry predicate must be annotated with Exception type(s). "
|
||||
"For example, `def should_retry(e: TimeoutError) -> bool` or "
|
||||
"`def should_retry(e: Union[TimeoutError, ConnectionError]) -> bool`. "
|
||||
f"Got '{exception_type}' instead."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
# No type information - return Exception for backward compatibility
|
||||
return (Exception,)
|
||||
|
||||
|
||||
class ToolRetryMiddleware(AgentMiddleware):
|
||||
"""Retry failed tool calls with constant delay.
|
||||
|
||||
This middleware catches tool execution errors and retries them up to a maximum
|
||||
number of attempts with a constant delay between retries. It operates at the
|
||||
outermost layer of middleware composition to catch all errors.
|
||||
|
||||
Examples:
|
||||
Retry only network errors:
|
||||
|
||||
```python
|
||||
from langchain.agents.middleware import ToolRetryMiddleware
|
||||
|
||||
middleware = ToolRetryMiddleware(
|
||||
max_retries=3,
|
||||
delay=2.0,
|
||||
retry_on=(TimeoutError, ConnectionError),
|
||||
)
|
||||
|
||||
agent = create_agent(
|
||||
model="openai:gpt-4o",
|
||||
tools=[my_tool],
|
||||
middleware=[middleware],
|
||||
)
|
||||
```
|
||||
|
||||
Use predicate function for custom retry logic:
|
||||
|
||||
```python
|
||||
from langchain.tools.tool_node import ToolInvocationError
|
||||
|
||||
|
||||
def should_retry(e: Exception) -> bool:
|
||||
# Don't retry validation errors from LLM
|
||||
if isinstance(e, ToolInvocationError):
|
||||
return False
|
||||
# Retry network errors
|
||||
if isinstance(e, (TimeoutError, ConnectionError)):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
middleware = ToolRetryMiddleware(
|
||||
max_retries=3,
|
||||
retry_on=should_retry,
|
||||
)
|
||||
```
|
||||
|
||||
Compose with error conversion:
|
||||
|
||||
```python
|
||||
from langchain.agents.middleware import (
|
||||
ToolRetryMiddleware,
|
||||
ErrorToMessageMiddleware,
|
||||
)
|
||||
|
||||
agent = create_agent(
|
||||
model="openai:gpt-4o",
|
||||
tools=[my_tool],
|
||||
middleware=[
|
||||
# Outer: retry network errors
|
||||
ToolRetryMiddleware(
|
||||
max_retries=3,
|
||||
delay=2.0,
|
||||
retry_on=(TimeoutError, ConnectionError),
|
||||
),
|
||||
# Inner: convert validation errors to messages
|
||||
ErrorToMessageMiddleware(
|
||||
exception_types=(ValidationError,),
|
||||
),
|
||||
],
|
||||
)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
max_retries: int = 3,
|
||||
delay: float = 1.0,
|
||||
retry_on: type[Exception]
|
||||
| tuple[type[Exception], ...]
|
||||
| Callable[[Exception], bool] = DEFAULT_RETRIABLE_EXCEPTIONS,
|
||||
) -> None:
|
||||
"""Initialize retry middleware.
|
||||
|
||||
Args:
|
||||
max_retries: Maximum number of retry attempts. Total attempts will be
|
||||
max_retries + 1 (initial attempt plus retries).
|
||||
delay: Constant delay in seconds between retry attempts.
|
||||
retry_on: Specifies which exceptions should be retried. Can be:
|
||||
- **type[Exception]**: Retry only this exception type
|
||||
- **tuple[type[Exception], ...]**: Retry these exception types
|
||||
- **Callable[[Exception], bool]**: Predicate function that returns
|
||||
True if the exception should be retried. Type annotations on the
|
||||
callable are used to filter which exceptions are passed to it.
|
||||
Defaults to ``DEFAULT_RETRIABLE_EXCEPTIONS`` (ConnectionError, TimeoutError).
|
||||
"""
|
||||
super().__init__()
|
||||
if max_retries < 0:
|
||||
msg = "max_retries must be non-negative"
|
||||
raise ValueError(msg)
|
||||
if delay < 0:
|
||||
msg = "delay must be non-negative"
|
||||
raise ValueError(msg)
|
||||
|
||||
self.max_retries = max_retries
|
||||
self.delay = delay
|
||||
self._retry_on = retry_on
|
||||
|
||||
# Determine which exception types to check
|
||||
if isinstance(retry_on, type) and issubclass(retry_on, Exception):
|
||||
self._retriable_types = (retry_on,)
|
||||
self._retry_predicate = None
|
||||
elif isinstance(retry_on, tuple):
|
||||
if not retry_on:
|
||||
msg = "retry_on tuple must not be empty"
|
||||
raise ValueError(msg)
|
||||
if not all(isinstance(t, type) and issubclass(t, Exception) for t in retry_on):
|
||||
msg = "All elements in retry_on tuple must be Exception types"
|
||||
raise ValueError(msg)
|
||||
self._retriable_types = retry_on
|
||||
self._retry_predicate = None
|
||||
elif callable(retry_on):
|
||||
self._retriable_types = _infer_retriable_types(retry_on)
|
||||
self._retry_predicate = retry_on
|
||||
else:
|
||||
msg = (
|
||||
"retry_on must be an Exception type, tuple of Exception types, "
|
||||
f"or callable. Got {type(retry_on)}"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
def on_tool_call(
|
||||
self, request: ToolCallRequest, state, runtime
|
||||
) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]:
|
||||
"""Retry tool execution on failures."""
|
||||
for attempt in range(1, self.max_retries + 2): # +1 for initial, +1 for inclusive
|
||||
response = yield request
|
||||
|
||||
# Success - return immediately
|
||||
if response.action == "continue":
|
||||
return response
|
||||
|
||||
# Error - check if we should retry
|
||||
if response.action == "raise":
|
||||
exception = response.exception
|
||||
if exception is None:
|
||||
msg = "ToolCallResponse with action='raise' must have an exception"
|
||||
raise ValueError(msg)
|
||||
|
||||
# Check if this exception type is retriable
|
||||
if not isinstance(exception, self._retriable_types):
|
||||
logger.debug(
|
||||
"Exception %s is not retriable for tool %s",
|
||||
type(exception).__name__,
|
||||
request.tool_call["name"],
|
||||
)
|
||||
return response
|
||||
|
||||
# If predicate is provided, check if we should retry
|
||||
if self._retry_predicate is not None and not self._retry_predicate(exception):
|
||||
logger.debug(
|
||||
"Retry predicate returned False for %s in tool %s",
|
||||
type(exception).__name__,
|
||||
request.tool_call["name"],
|
||||
)
|
||||
return response
|
||||
|
||||
# Last attempt - return error
|
||||
if attempt > self.max_retries:
|
||||
logger.debug(
|
||||
"Max retries (%d) reached for tool %s",
|
||||
self.max_retries,
|
||||
request.tool_call["name"],
|
||||
)
|
||||
return response
|
||||
|
||||
# Retry - log and delay
|
||||
logger.debug(
|
||||
"Retrying tool %s (attempt %d/%d) after error: %s",
|
||||
request.tool_call["name"],
|
||||
attempt,
|
||||
self.max_retries + 1,
|
||||
type(exception).__name__,
|
||||
)
|
||||
time.sleep(self.delay)
|
||||
continue
|
||||
|
||||
# Should never reach here
|
||||
msg = f"Unexpected control flow in ToolRetryMiddleware for tool {request.tool_call['name']}"
|
||||
raise RuntimeError(msg)
|
||||
|
||||
|
||||
class ErrorToMessageMiddleware(AgentMiddleware):
|
||||
"""Convert specific exception types to ToolMessages.
|
||||
|
||||
This middleware intercepts errors and converts them into ToolMessages that
|
||||
can be sent back to the model as feedback. This is useful for errors caused
|
||||
by invalid model inputs where the model needs feedback to correct its behavior.
|
||||
|
||||
Examples:
|
||||
Convert validation errors to messages:
|
||||
|
||||
```python
|
||||
from langchain.agents.middleware import ErrorToMessageMiddleware
|
||||
from langchain.tools.tool_node import ToolInvocationError
|
||||
|
||||
middleware = ErrorToMessageMiddleware(
|
||||
exception_types=(ToolInvocationError,),
|
||||
message_template="Invalid arguments: {error}. Please fix and try again.",
|
||||
)
|
||||
|
||||
agent = create_agent(
|
||||
model="openai:gpt-4o",
|
||||
tools=[my_tool],
|
||||
middleware=[middleware],
|
||||
)
|
||||
```
|
||||
|
||||
Compose with retry for network errors:
|
||||
|
||||
```python
|
||||
from langchain.agents.middleware import (
|
||||
ToolRetryMiddleware,
|
||||
ErrorToMessageMiddleware,
|
||||
)
|
||||
|
||||
agent = create_agent(
|
||||
model="openai:gpt-4o",
|
||||
tools=[my_tool],
|
||||
middleware=[
|
||||
# Outer: retry all errors
|
||||
ToolRetryMiddleware(max_retries=3),
|
||||
# Inner: convert validation errors to messages
|
||||
ErrorToMessageMiddleware(
|
||||
exception_types=(ValidationError,),
|
||||
),
|
||||
],
|
||||
)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
exception_types: tuple[type[Exception], ...],
|
||||
message_template: str = "Error: {error}",
|
||||
) -> None:
|
||||
"""Initialize error conversion middleware.
|
||||
|
||||
Args:
|
||||
exception_types: Tuple of exception types to convert to messages.
|
||||
message_template: Template string for error messages. Can use ``{error}``
|
||||
placeholder for the exception string representation.
|
||||
"""
|
||||
super().__init__()
|
||||
if not exception_types:
|
||||
msg = "exception_types must not be empty"
|
||||
raise ValueError(msg)
|
||||
|
||||
self.exception_types = exception_types
|
||||
self.message_template = message_template
|
||||
|
||||
def on_tool_call(
|
||||
self, request: ToolCallRequest, state, runtime
|
||||
) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]:
|
||||
"""Convert matching errors to ToolMessages."""
|
||||
response = yield request
|
||||
|
||||
# Success - pass through
|
||||
if response.action == "continue":
|
||||
return response
|
||||
|
||||
# Error - check if we should convert
|
||||
if response.action == "raise":
|
||||
exception = response.exception
|
||||
if exception is None:
|
||||
msg = "ToolCallResponse with action='raise' must have an exception"
|
||||
raise ValueError(msg)
|
||||
|
||||
# Check if exception type matches
|
||||
if not isinstance(exception, self.exception_types):
|
||||
return response
|
||||
|
||||
# Convert to ToolMessage
|
||||
logger.debug(
|
||||
"Converting %s to ToolMessage for tool %s",
|
||||
type(exception).__name__,
|
||||
request.tool_call["name"],
|
||||
)
|
||||
|
||||
error_message = self.message_template.format(error=str(exception))
|
||||
tool_message = ToolMessage(
|
||||
content=error_message,
|
||||
name=request.tool_call["name"],
|
||||
tool_call_id=request.tool_call["id"],
|
||||
status="error",
|
||||
)
|
||||
|
||||
return ToolCallResponse(
|
||||
action="continue",
|
||||
result=tool_message,
|
||||
exception=exception, # Preserve for logging/debugging
|
||||
)
|
||||
|
||||
return response
|
||||
@@ -19,7 +19,7 @@ from typing import (
|
||||
from langchain_core.runnables import run_in_executor
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Awaitable
|
||||
from collections.abc import Awaitable, Generator
|
||||
|
||||
# needed as top level import for pydantic schema generation on AgentState
|
||||
from langchain_core.messages import AnyMessage # noqa: TC002
|
||||
@@ -36,6 +36,7 @@ if TYPE_CHECKING:
|
||||
from langgraph.types import Command
|
||||
|
||||
from langchain.agents.structured_output import ResponseFormat
|
||||
from langchain.tools.tool_node import ToolCallRequest, ToolCallResponse
|
||||
|
||||
__all__ = [
|
||||
"AgentMiddleware",
|
||||
@@ -215,6 +216,48 @@ class AgentMiddleware(Generic[StateT, ContextT]):
|
||||
None, self.retry_model_request, error, request, state, runtime, attempt
|
||||
)
|
||||
|
||||
def on_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
state: StateT,
|
||||
runtime: Runtime[ContextT],
|
||||
) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]:
|
||||
"""Intercept tool execution to implement retry logic, monitoring, or request modification.
|
||||
|
||||
Provides generator-based control over the complete tool execution lifecycle.
|
||||
Multiple middleware can define this hook; they compose automatically with
|
||||
outer middleware wrapping inner middleware (first defined = outermost layer).
|
||||
|
||||
Generator Protocol:
|
||||
1. Yield a ToolCallRequest (potentially modified from the input)
|
||||
2. Receive a ToolCallResponse via .send()
|
||||
3. Optionally yield again to retry
|
||||
4. Return the final ToolCallResponse to propagate
|
||||
|
||||
Args:
|
||||
request: Tool invocation details including tool_call, tool instance, and config.
|
||||
state: Current agent state (readonly context).
|
||||
runtime: LangGraph runtime for accessing user context (readonly context).
|
||||
|
||||
Returns:
|
||||
Generator for request/response interception.
|
||||
|
||||
Example:
|
||||
Retry on rate limit with exponential backoff:
|
||||
|
||||
```python
|
||||
def on_tool_call(self, request, state, runtime):
|
||||
for attempt in range(3):
|
||||
response = yield request
|
||||
if response.action == "continue":
|
||||
return response
|
||||
if "rate limit" in str(response.exception):
|
||||
time.sleep(2**attempt)
|
||||
continue
|
||||
return response
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
class _CallableWithStateAndRuntime(Protocol[StateT_contra, ContextT]):
|
||||
"""Callable with AgentState and Runtime as arguments."""
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Middleware agent implementation."""
|
||||
|
||||
import itertools
|
||||
from collections.abc import Callable, Sequence
|
||||
from collections.abc import Callable, Generator, Sequence
|
||||
from typing import Annotated, Any, cast, get_args, get_origin, get_type_hints
|
||||
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
@@ -35,12 +35,99 @@ from langchain.agents.structured_output import (
|
||||
)
|
||||
from langchain.chat_models import init_chat_model
|
||||
from langchain.tools import ToolNode
|
||||
from langchain.tools.tool_node import ToolCallHandler, ToolCallRequest, ToolCallResponse
|
||||
|
||||
STRUCTURED_OUTPUT_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes."
|
||||
|
||||
ResponseT = TypeVar("ResponseT")
|
||||
|
||||
|
||||
def _chain_tool_call_handlers(
|
||||
handlers: list[ToolCallHandler],
|
||||
) -> ToolCallHandler | None:
|
||||
"""Compose multiple tool call handlers into a single middleware stack.
|
||||
|
||||
Args:
|
||||
handlers: Handlers in middleware order (first = outermost layer).
|
||||
|
||||
Returns:
|
||||
Single composed handler, or None if handlers is empty.
|
||||
"""
|
||||
if not handlers:
|
||||
return None
|
||||
|
||||
if len(handlers) == 1:
|
||||
return handlers[0]
|
||||
|
||||
def _extract_return_value(stop_iteration: StopIteration) -> ToolCallResponse:
|
||||
"""Extract ToolCallResponse from StopIteration, validating protocol compliance."""
|
||||
if stop_iteration.value is None:
|
||||
msg = "on_tool_call handler must explicitly return a ToolCallResponse"
|
||||
raise ValueError(msg)
|
||||
return stop_iteration.value
|
||||
|
||||
def compose_two(outer: ToolCallHandler, inner: ToolCallHandler) -> ToolCallHandler:
|
||||
"""Compose two handlers where outer wraps inner."""
|
||||
|
||||
def composed(
|
||||
request: ToolCallRequest,
|
||||
state: Any,
|
||||
runtime: Any,
|
||||
) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]:
|
||||
outer_gen = outer(request, state, runtime)
|
||||
|
||||
# Initialize outer generator
|
||||
try:
|
||||
outer_request = next(outer_gen)
|
||||
except StopIteration as e:
|
||||
return _extract_return_value(e)
|
||||
|
||||
# Outer retry loop
|
||||
while True:
|
||||
inner_gen = inner(outer_request, state, runtime)
|
||||
|
||||
# Initialize inner generator
|
||||
try:
|
||||
inner_request = next(inner_gen)
|
||||
except StopIteration as e:
|
||||
# Inner returned immediately - send to outer
|
||||
inner_response = _extract_return_value(e)
|
||||
try:
|
||||
outer_request = outer_gen.send(inner_response)
|
||||
continue # Outer retrying
|
||||
except StopIteration as e:
|
||||
return _extract_return_value(e)
|
||||
|
||||
# Inner retry loop - yield to next layer (or tool)
|
||||
while True:
|
||||
tool_response = yield inner_request
|
||||
|
||||
try:
|
||||
inner_request = inner_gen.send(tool_response)
|
||||
# Inner retrying - continue inner loop
|
||||
except StopIteration as e:
|
||||
# Inner done - send response to outer
|
||||
inner_response = _extract_return_value(e)
|
||||
break
|
||||
|
||||
# Send inner's final response to outer
|
||||
try:
|
||||
outer_request = outer_gen.send(inner_response)
|
||||
# Outer retrying - continue outer loop
|
||||
except StopIteration as e:
|
||||
# Outer done - return final response
|
||||
return _extract_return_value(e)
|
||||
|
||||
return composed
|
||||
|
||||
# Compose right-to-left: handlers[0](handlers[1](...(handlers[-1](tool))))
|
||||
result = handlers[-1]
|
||||
for handler in reversed(handlers[:-1]):
|
||||
result = compose_two(handler, result)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _resolve_schema(schemas: set[type], schema_name: str, omit_flag: str | None = None) -> type:
|
||||
"""Resolve schema by merging schemas and optionally respecting OmitFromSchema annotations.
|
||||
|
||||
@@ -226,6 +313,20 @@ def create_agent( # noqa: PLR0915
|
||||
structured_output_tools[structured_tool_info.tool.name] = structured_tool_info
|
||||
middleware_tools = [t for m in middleware for t in getattr(m, "tools", [])]
|
||||
|
||||
# Validate middleware and collect handlers
|
||||
assert len({m.name for m in middleware}) == len(middleware), ( # noqa: S101
|
||||
"Please remove duplicate middleware instances."
|
||||
)
|
||||
middleware_w_on_tool_call = [
|
||||
m for m in middleware if m.__class__.on_tool_call is not AgentMiddleware.on_tool_call
|
||||
]
|
||||
|
||||
# Chain all on_tool_call handlers into a single composed handler
|
||||
on_tool_call_handler = None
|
||||
if middleware_w_on_tool_call:
|
||||
handlers = [m.on_tool_call for m in middleware_w_on_tool_call]
|
||||
on_tool_call_handler = _chain_tool_call_handlers(handlers)
|
||||
|
||||
# Setup tools
|
||||
tool_node: ToolNode | None = None
|
||||
if isinstance(tools, list):
|
||||
@@ -237,7 +338,11 @@ def create_agent( # noqa: PLR0915
|
||||
available_tools = middleware_tools + regular_tools
|
||||
|
||||
# Only create ToolNode if we have client-side tools
|
||||
tool_node = ToolNode(tools=available_tools) if available_tools else None
|
||||
tool_node = (
|
||||
ToolNode(tools=available_tools, on_tool_call=on_tool_call_handler)
|
||||
if available_tools
|
||||
else None
|
||||
)
|
||||
|
||||
# Default tools for ModelRequest initialization
|
||||
# Include built-ins and regular tools (can be changed dynamically by middleware)
|
||||
@@ -248,7 +353,7 @@ def create_agent( # noqa: PLR0915
|
||||
if tool_node:
|
||||
# Add middleware tools to existing ToolNode
|
||||
available_tools = list(tool_node.tools_by_name.values()) + middleware_tools
|
||||
tool_node = ToolNode(available_tools)
|
||||
tool_node = ToolNode(available_tools, on_tool_call=on_tool_call_handler)
|
||||
|
||||
# default_tools includes all client-side tools (no built-ins or structured tools)
|
||||
default_tools = available_tools
|
||||
@@ -256,10 +361,6 @@ def create_agent( # noqa: PLR0915
|
||||
# No tools provided, only middleware_tools available
|
||||
default_tools = middleware_tools
|
||||
|
||||
# validate middleware
|
||||
assert len({m.name for m in middleware}) == len(middleware), ( # noqa: S101
|
||||
"Please remove duplicate middleware instances."
|
||||
)
|
||||
middleware_w_before = [
|
||||
m
|
||||
for m in middleware
|
||||
|
||||
@@ -38,21 +38,22 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
from collections.abc import Callable, Generator, Sequence
|
||||
from copy import copy, deepcopy
|
||||
from dataclasses import replace
|
||||
from dataclasses import dataclass, replace
|
||||
from types import UnionType
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Annotated,
|
||||
Any,
|
||||
Literal,
|
||||
Optional,
|
||||
Union,
|
||||
cast,
|
||||
get_args,
|
||||
get_origin,
|
||||
get_type_hints,
|
||||
)
|
||||
from typing import Optional as Optional
|
||||
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
@@ -75,12 +76,11 @@ from langchain_core.tools.base import (
|
||||
from langgraph._internal._runnable import RunnableCallable
|
||||
from langgraph.errors import GraphBubbleUp
|
||||
from langgraph.graph.message import REMOVE_ALL_MESSAGES
|
||||
from langgraph.runtime import get_runtime
|
||||
from langgraph.types import Command, Send
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable, Sequence
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.store.base import BaseStore
|
||||
|
||||
@@ -100,6 +100,62 @@ TOOL_INVOCATION_ERROR_TEMPLATE = (
|
||||
)
|
||||
|
||||
|
||||
@dataclass()
|
||||
class ToolCallRequest:
|
||||
"""Request passed to on_tool_call handler before tool execution.
|
||||
|
||||
Attributes:
|
||||
tool_call: The tool call dict containing name, args, and id.
|
||||
tool: The BaseTool instance that will be invoked.
|
||||
|
||||
Note:
|
||||
tool_call["args"] can be mutated directly to modify arguments.
|
||||
"""
|
||||
|
||||
tool_call: ToolCall
|
||||
tool: BaseTool
|
||||
|
||||
|
||||
@dataclass()
|
||||
class ToolCallResponse:
|
||||
"""Response returned from on_tool_call handler after tool execution.
|
||||
|
||||
The action field determines control flow:
|
||||
- "continue": Handler completed successfully, use result
|
||||
- "raise": Handler wants to propagate the exception
|
||||
|
||||
Attributes:
|
||||
action: Control flow directive ("continue" or "raise").
|
||||
result: ToolMessage or Command when action="continue".
|
||||
exception: The exception when action="raise", or for logging when
|
||||
action="continue" with an error ToolMessage.
|
||||
"""
|
||||
|
||||
action: Literal["continue", "raise"]
|
||||
result: ToolMessage | Command | None = None
|
||||
exception: Exception | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Validate that required fields are present based on action."""
|
||||
if self.action == "continue" and self.result is None:
|
||||
msg = "action='continue' requires a result"
|
||||
raise ValueError(msg)
|
||||
if self.action == "raise" and self.exception is None:
|
||||
msg = "action='raise' requires an exception"
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
ToolCallHandler = Callable[
|
||||
[ToolCallRequest, Any, Any], Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]
|
||||
]
|
||||
"""Generator-based handler that intercepts tool execution.
|
||||
|
||||
Receives a ToolCallRequest, state, and runtime; yields modified ToolCallRequests;
|
||||
receives ToolCallResponses; and returns a final ToolCallResponse. Supports multiple
|
||||
yields for retry logic.
|
||||
"""
|
||||
|
||||
|
||||
def msg_content_output(output: Any) -> str | list[dict]:
|
||||
"""Convert tool output to valid message content format.
|
||||
|
||||
@@ -156,7 +212,7 @@ class ToolInvocationError(Exception):
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
def _default_handle_tool_errors(e: Exception) -> str:
|
||||
def _default_handle_tool_errors(e: ToolInvocationError) -> str:
|
||||
"""Default error handler for tool errors.
|
||||
|
||||
If the tool is a tool invocation error, return its message.
|
||||
@@ -300,8 +356,8 @@ class ToolNode(RunnableCallable):
|
||||
Output format depends on input type and tool behavior:
|
||||
|
||||
**For Regular tools**:
|
||||
- Dict input → ``{"messages": [ToolMessage(...)]}``
|
||||
- List input → ``[ToolMessage(...)]``
|
||||
- Dict input -> ``{"messages": [ToolMessage(...)]}``
|
||||
- List input -> ``[ToolMessage(...)]``
|
||||
|
||||
**For Command tools**:
|
||||
- Returns ``[Command(...)]`` or mixed list with regular tool outputs
|
||||
@@ -335,6 +391,12 @@ class ToolNode(RunnableCallable):
|
||||
- catches tool invocation errors (due to invalid arguments provided by the model) and returns a descriptive error message
|
||||
- ignores tool execution errors (they will be re-raised)
|
||||
|
||||
on_tool_call: Optional handler to intercept tool execution. Receives
|
||||
``ToolCallRequest``, yields potentially modified requests, receives
|
||||
``ToolCallResponse`` via ``.send()``, and returns final ``ToolCallResponse``.
|
||||
Enables retries, argument modification, and custom error handling.
|
||||
Defaults to ``None``.
|
||||
|
||||
messages_key: The key in the state dictionary that contains the message list.
|
||||
This same key will be used for the output ToolMessages.
|
||||
Defaults to "messages".
|
||||
@@ -378,6 +440,23 @@ class ToolNode(RunnableCallable):
|
||||
|
||||
tool_node = ToolNode([my_tool], handle_tool_errors=handle_errors)
|
||||
```
|
||||
|
||||
Intercepting tool calls:
|
||||
|
||||
```python
|
||||
from langchain.tools.tool_node import ToolCallRequest, ToolCallResponse
|
||||
|
||||
def retry_handler(request):
|
||||
\"\"\"Retry failed tool calls up to 3 times.\"\"\"
|
||||
for attempt in range(3):
|
||||
response = yield request
|
||||
if response.action == "continue":
|
||||
return response
|
||||
# Retry on error
|
||||
return response # Final attempt
|
||||
|
||||
tool_node = ToolNode([my_tool], on_tool_call=retry_handler)
|
||||
```
|
||||
""" # noqa: E501
|
||||
|
||||
name: str = "tools"
|
||||
@@ -393,6 +472,7 @@ class ToolNode(RunnableCallable):
|
||||
| Callable[..., str]
|
||||
| type[Exception]
|
||||
| tuple[type[Exception], ...] = _default_handle_tool_errors,
|
||||
on_tool_call: ToolCallHandler | None = None,
|
||||
messages_key: str = "messages",
|
||||
) -> None:
|
||||
"""Initialize the ToolNode with the provided tools and configuration.
|
||||
@@ -402,6 +482,7 @@ class ToolNode(RunnableCallable):
|
||||
name: Node name for graph identification.
|
||||
tags: Optional metadata tags.
|
||||
handle_tool_errors: Error handling configuration.
|
||||
on_tool_call: Optional handler to intercept tool execution.
|
||||
messages_key: State key containing messages.
|
||||
"""
|
||||
super().__init__(self._func, self._afunc, name=name, tags=tags, trace=False)
|
||||
@@ -409,6 +490,7 @@ class ToolNode(RunnableCallable):
|
||||
self._tool_to_state_args: dict[str, dict[str, str | None]] = {}
|
||||
self._tool_to_store_arg: dict[str, str | None] = {}
|
||||
self._handle_tool_errors = handle_tool_errors
|
||||
self._on_tool_call = on_tool_call
|
||||
self._messages_key = messages_key
|
||||
for tool in tools:
|
||||
if not isinstance(tool, BaseTool):
|
||||
@@ -429,13 +511,24 @@ class ToolNode(RunnableCallable):
|
||||
input: list[AnyMessage] | dict[str, Any] | BaseModel,
|
||||
config: RunnableConfig,
|
||||
*,
|
||||
# Optional[BaseStore] should not change to BaseStore | None
|
||||
# until we support injection of store using `BaseStore | None` annotation
|
||||
store: Optional[BaseStore], # noqa: UP045
|
||||
) -> Any:
|
||||
try:
|
||||
runtime = get_runtime()
|
||||
except RuntimeError:
|
||||
# Running outside of the LangGrah runtime context (e.g., unit-tests)
|
||||
runtime = None
|
||||
tool_calls, input_type = self._parse_input(input, store)
|
||||
config_list = get_config_list(config, len(tool_calls))
|
||||
input_types = [input_type] * len(tool_calls)
|
||||
inputs = [input] * len(tool_calls)
|
||||
runtimes = [runtime] * len(tool_calls)
|
||||
with get_executor_for_config(config) as executor:
|
||||
outputs = [*executor.map(self._run_one, tool_calls, input_types, config_list)]
|
||||
outputs = [
|
||||
*executor.map(self._run_one, tool_calls, input_types, config_list, inputs, runtimes)
|
||||
]
|
||||
|
||||
return self._combine_tool_outputs(outputs, input_type)
|
||||
|
||||
@@ -444,11 +537,18 @@ class ToolNode(RunnableCallable):
|
||||
input: list[AnyMessage] | dict[str, Any] | BaseModel,
|
||||
config: RunnableConfig,
|
||||
*,
|
||||
# Optional[BaseStore] should not change to BaseStore | None
|
||||
# until we support injection of store using `BaseStore | None` annotation
|
||||
store: Optional[BaseStore], # noqa: UP045
|
||||
) -> Any:
|
||||
try:
|
||||
runtime = get_runtime()
|
||||
except RuntimeError:
|
||||
# Running outside of the LangGrah runtime context (e.g., unit-tests)
|
||||
runtime = None
|
||||
tool_calls, input_type = self._parse_input(input, store)
|
||||
outputs = await asyncio.gather(
|
||||
*(self._arun_one(call, input_type, config) for call in tool_calls)
|
||||
*(self._arun_one(call, input_type, config, input, runtime) for call in tool_calls)
|
||||
)
|
||||
|
||||
return self._combine_tool_outputs(outputs, input_type)
|
||||
@@ -495,20 +595,19 @@ class ToolNode(RunnableCallable):
|
||||
combined_outputs.append(parent_command)
|
||||
return combined_outputs
|
||||
|
||||
def _run_one(
|
||||
self,
|
||||
call: ToolCall,
|
||||
input_type: Literal["list", "dict", "tool_calls"],
|
||||
config: RunnableConfig,
|
||||
) -> ToolMessage | Command:
|
||||
"""Run a single tool call synchronously."""
|
||||
if invalid_tool_message := self._validate_tool_call(call):
|
||||
return invalid_tool_message
|
||||
def _execute_tool_sync(
|
||||
self, request: ToolCallRequest, input_type: Literal["list", "dict", "tool_calls"], config: RunnableConfig
|
||||
) -> ToolCallResponse:
|
||||
"""Execute tool and return response.
|
||||
|
||||
Applies handle_tool_errors configuration. When on_tool_call is configured,
|
||||
unhandled errors return action="raise" instead of raising immediately.
|
||||
"""
|
||||
call = request.tool_call
|
||||
tool = request.tool
|
||||
call_args = {**call, "type": "tool_call"}
|
||||
|
||||
try:
|
||||
call_args = {**call, "type": "tool_call"}
|
||||
tool = self.tools_by_name[call["name"]]
|
||||
|
||||
try:
|
||||
response = tool.invoke(call_args, config)
|
||||
except ValidationError as exc:
|
||||
@@ -541,40 +640,110 @@ class ToolNode(RunnableCallable):
|
||||
# default behavior is catching all exceptions
|
||||
handled_types = (Exception,)
|
||||
|
||||
# Unhandled
|
||||
# Check if error is handled
|
||||
if not self._handle_tool_errors or not isinstance(e, handled_types):
|
||||
# Error is not handled
|
||||
if self._on_tool_call is not None:
|
||||
# If handler exists, return action="raise" so handler can decide
|
||||
return ToolCallResponse(action="raise", exception=e)
|
||||
# No handler - maintain backward compatibility by raising immediately
|
||||
raise
|
||||
# Handled
|
||||
|
||||
# Error is handled - create error ToolMessage
|
||||
content = _handle_tool_error(e, flag=self._handle_tool_errors)
|
||||
return ToolMessage(
|
||||
error_message = ToolMessage(
|
||||
content=content,
|
||||
name=call["name"],
|
||||
tool_call_id=call["id"],
|
||||
status="error",
|
||||
)
|
||||
return ToolCallResponse(action="continue", result=error_message, exception=e)
|
||||
|
||||
# Process successful response
|
||||
if isinstance(response, Command):
|
||||
return self._validate_tool_command(response, call, input_type)
|
||||
# Validate Command before returning to handler
|
||||
validated_command = self._validate_tool_command(response, request.tool_call, input_type)
|
||||
return ToolCallResponse(action="continue", result=validated_command)
|
||||
if isinstance(response, ToolMessage):
|
||||
response.content = cast("str | list", msg_content_output(response.content))
|
||||
return response
|
||||
return ToolCallResponse(action="continue", result=response)
|
||||
|
||||
msg = f"Tool {call['name']} returned unexpected type: {type(response)}"
|
||||
raise TypeError(msg)
|
||||
|
||||
async def _arun_one(
|
||||
def _run_one(
|
||||
self,
|
||||
call: ToolCall,
|
||||
input_type: Literal["list", "dict", "tool_calls"],
|
||||
config: RunnableConfig,
|
||||
input: list[AnyMessage] | dict[str, Any] | BaseModel,
|
||||
runtime: Any,
|
||||
) -> ToolMessage | Command:
|
||||
"""Run a single tool call asynchronously."""
|
||||
"""Run a single tool call synchronously."""
|
||||
if invalid_tool_message := self._validate_tool_call(call):
|
||||
return invalid_tool_message
|
||||
|
||||
try:
|
||||
call_args = {**call, "type": "tool_call"}
|
||||
tool = self.tools_by_name[call["name"]]
|
||||
tool = self.tools_by_name[call["name"]]
|
||||
|
||||
# Create the tool request
|
||||
tool_request = ToolCallRequest(
|
||||
tool_call=call,
|
||||
tool=tool,
|
||||
)
|
||||
|
||||
if self._on_tool_call is None:
|
||||
tool_response = self._execute_tool_sync(tool_request, input_type, config)
|
||||
else:
|
||||
# Generator protocol: start generator, send responses, receive requests
|
||||
gen = self._on_tool_call(tool_request, input, runtime)
|
||||
|
||||
try:
|
||||
request = next(gen)
|
||||
except StopIteration:
|
||||
msg = "on_tool_call handler must yield at least once before returning"
|
||||
raise ValueError(msg)
|
||||
|
||||
while True:
|
||||
tool_response = self._execute_tool_sync(request, input_type, config)
|
||||
try:
|
||||
request = gen.send(tool_response)
|
||||
except StopIteration as e:
|
||||
if e.value is None:
|
||||
msg = (
|
||||
"on_tool_call handler must explicitly return a ToolCallResponse. "
|
||||
"Ensure your handler ends with 'return response'."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
tool_response = e.value
|
||||
break
|
||||
|
||||
# Apply action directive
|
||||
if tool_response.action == "raise":
|
||||
if tool_response.exception is None:
|
||||
msg = "ToolCallResponse with action='raise' must have an exception"
|
||||
raise ValueError(msg)
|
||||
raise tool_response.exception
|
||||
|
||||
result = tool_response.result
|
||||
if result is None:
|
||||
msg = "ToolCallResponse with action='continue' must have a result"
|
||||
raise ValueError(msg)
|
||||
|
||||
return result
|
||||
|
||||
async def _execute_tool_async(
|
||||
self, request: ToolCallRequest, input_type: Literal["list", "dict", "tool_calls"], config: RunnableConfig
|
||||
) -> ToolCallResponse:
|
||||
"""Execute tool asynchronously and return response.
|
||||
|
||||
Applies handle_tool_errors configuration. When on_tool_call is configured,
|
||||
unhandled errors return action="raise" instead of raising immediately.
|
||||
"""
|
||||
call = request.tool_call
|
||||
tool = request.tool
|
||||
call_args = {**call, "type": "tool_call"}
|
||||
|
||||
try:
|
||||
try:
|
||||
response = await tool.ainvoke(call_args, config)
|
||||
except ValidationError as exc:
|
||||
@@ -607,27 +776,97 @@ class ToolNode(RunnableCallable):
|
||||
# default behavior is catching all exceptions
|
||||
handled_types = (Exception,)
|
||||
|
||||
# Unhandled
|
||||
# Check if error is handled
|
||||
if not self._handle_tool_errors or not isinstance(e, handled_types):
|
||||
# Error is not handled
|
||||
if self._on_tool_call is not None:
|
||||
# If handler exists, return action="raise" so handler can decide
|
||||
return ToolCallResponse(action="raise", exception=e)
|
||||
# No handler - maintain backward compatibility by raising immediately
|
||||
raise
|
||||
# Handled
|
||||
content = _handle_tool_error(e, flag=self._handle_tool_errors)
|
||||
|
||||
return ToolMessage(
|
||||
# Error is handled - create error ToolMessage
|
||||
content = _handle_tool_error(e, flag=self._handle_tool_errors)
|
||||
error_message = ToolMessage(
|
||||
content=content,
|
||||
name=call["name"],
|
||||
tool_call_id=call["id"],
|
||||
status="error",
|
||||
)
|
||||
return ToolCallResponse(action="continue", result=error_message, exception=e)
|
||||
|
||||
# Process successful response
|
||||
if isinstance(response, Command):
|
||||
return self._validate_tool_command(response, call, input_type)
|
||||
# Validate Command before returning to handler
|
||||
validated_command = self._validate_tool_command(response, request.tool_call, input_type)
|
||||
return ToolCallResponse(action="continue", result=validated_command)
|
||||
if isinstance(response, ToolMessage):
|
||||
response.content = cast("str | list", msg_content_output(response.content))
|
||||
return response
|
||||
return ToolCallResponse(action="continue", result=response)
|
||||
|
||||
msg = f"Tool {call['name']} returned unexpected type: {type(response)}"
|
||||
raise TypeError(msg)
|
||||
|
||||
async def _arun_one(
|
||||
self,
|
||||
call: ToolCall,
|
||||
input_type: Literal["list", "dict", "tool_calls"],
|
||||
config: RunnableConfig,
|
||||
input: list[AnyMessage] | dict[str, Any] | BaseModel,
|
||||
runtime: Any,
|
||||
) -> ToolMessage | Command:
|
||||
"""Run a single tool call asynchronously."""
|
||||
if invalid_tool_message := self._validate_tool_call(call):
|
||||
return invalid_tool_message
|
||||
|
||||
tool = self.tools_by_name[call["name"]]
|
||||
|
||||
# Create the tool request
|
||||
tool_request = ToolCallRequest(
|
||||
tool_call=call,
|
||||
tool=tool,
|
||||
)
|
||||
|
||||
if self._on_tool_call is None:
|
||||
tool_response = await self._execute_tool_async(tool_request, input_type, config)
|
||||
else:
|
||||
# Generator protocol: handler is sync generator, tool execution is async
|
||||
gen = self._on_tool_call(tool_request, input, runtime)
|
||||
|
||||
try:
|
||||
request = next(gen)
|
||||
except StopIteration:
|
||||
msg = "on_tool_call handler must yield at least once before returning"
|
||||
raise ValueError(msg)
|
||||
|
||||
while True:
|
||||
tool_response = await self._execute_tool_async(request, input_type, config)
|
||||
try:
|
||||
request = gen.send(tool_response)
|
||||
except StopIteration as e:
|
||||
if e.value is None:
|
||||
msg = (
|
||||
"on_tool_call handler must explicitly return a ToolCallResponse. "
|
||||
"Ensure your handler ends with 'return response'."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
tool_response = e.value
|
||||
break
|
||||
|
||||
# Apply action directive
|
||||
if tool_response.action == "raise":
|
||||
if tool_response.exception is None:
|
||||
msg = "ToolCallResponse with action='raise' must have an exception"
|
||||
raise ValueError(msg)
|
||||
raise tool_response.exception
|
||||
|
||||
result = tool_response.result
|
||||
if result is None:
|
||||
msg = "ToolCallResponse with action='continue' must have a result"
|
||||
raise ValueError(msg)
|
||||
|
||||
return result
|
||||
|
||||
def _parse_input(
|
||||
self,
|
||||
input: list[AnyMessage] | dict[str, Any] | BaseModel,
|
||||
|
||||
@@ -0,0 +1,396 @@
|
||||
"""Unit tests for on_tool_call middleware hook."""
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Literal, Union
|
||||
import typing
|
||||
|
||||
from pydantic import BaseModel
|
||||
import pytest
|
||||
from langchain_core.language_models import LanguageModelInput
|
||||
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage
|
||||
from langchain_core.runnables import Runnable
|
||||
from langchain_core.tools import BaseTool, tool
|
||||
|
||||
from langchain.agents.middleware.types import AgentMiddleware
|
||||
from langchain.agents.middleware_agent import create_agent
|
||||
from langchain.tools.tool_node import ToolCallRequest, ToolCallResponse
|
||||
|
||||
|
||||
class FakeModel(GenericFakeChatModel):
|
||||
"""Fake chat model for testing."""
|
||||
|
||||
tool_style: Literal["openai", "anthropic"] = "openai"
|
||||
|
||||
def bind_tools(
|
||||
self,
|
||||
tools: typing.Sequence[Union[dict[str, Any], type[BaseModel], typing.Callable, BaseTool]],
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, BaseMessage]:
|
||||
if len(tools) == 0:
|
||||
msg = "Must provide at least one tool"
|
||||
raise ValueError(msg)
|
||||
|
||||
tool_dicts = []
|
||||
for tool in tools:
|
||||
if isinstance(tool, dict):
|
||||
tool_dicts.append(tool)
|
||||
continue
|
||||
if not isinstance(tool, BaseTool):
|
||||
msg = "Only BaseTool and dict is supported by FakeModel.bind_tools"
|
||||
raise TypeError(msg)
|
||||
|
||||
# NOTE: this is a simplified tool spec for testing purposes only
|
||||
if self.tool_style == "openai":
|
||||
tool_dicts.append(
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool.name,
|
||||
},
|
||||
}
|
||||
)
|
||||
elif self.tool_style == "anthropic":
|
||||
tool_dicts.append(
|
||||
{
|
||||
"name": tool.name,
|
||||
}
|
||||
)
|
||||
|
||||
return self.bind(tools=tool_dicts)
|
||||
|
||||
|
||||
@tool
|
||||
def add_tool(x: int, y: int) -> int:
|
||||
"""Add two numbers."""
|
||||
return x + y
|
||||
|
||||
|
||||
@tool
|
||||
def failing_tool(x: int) -> int:
|
||||
"""Tool that raises an error."""
|
||||
msg = "Intentional failure"
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
def test_single_middleware_on_tool_call():
|
||||
"""Test that a single middleware can intercept tool calls."""
|
||||
call_log = []
|
||||
|
||||
class LoggingMiddleware(AgentMiddleware):
|
||||
"""Middleware that logs tool calls."""
|
||||
|
||||
def on_tool_call(
|
||||
self, request: ToolCallRequest, state, runtime
|
||||
) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]:
|
||||
call_log.append(f"before_{request.tool.name}")
|
||||
response = yield request
|
||||
call_log.append(f"after_{request.tool.name}")
|
||||
return response
|
||||
|
||||
model = FakeModel(
|
||||
messages=iter(
|
||||
[
|
||||
AIMessage(
|
||||
content="",
|
||||
tool_calls=[{"name": "add_tool", "args": {"x": 2, "y": 3}, "id": "1"}],
|
||||
),
|
||||
AIMessage(content="Done"),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
agent = create_agent(
|
||||
model=model,
|
||||
tools=[add_tool],
|
||||
middleware=[LoggingMiddleware()],
|
||||
)
|
||||
|
||||
result = agent.compile().invoke({"messages": [HumanMessage("Add 2 and 3")]})
|
||||
|
||||
assert "before_add_tool" in call_log
|
||||
assert "after_add_tool" in call_log
|
||||
assert call_log.index("before_add_tool") < call_log.index("after_add_tool")
|
||||
|
||||
# Check that tool executed successfully
|
||||
tool_messages = [m for m in result["messages"] if isinstance(m, ToolMessage)]
|
||||
assert len(tool_messages) == 1
|
||||
assert tool_messages[0].content == "5"
|
||||
|
||||
|
||||
def test_multiple_middleware_chaining():
|
||||
"""Test that multiple middleware chain correctly (outer wraps inner)."""
|
||||
call_order = []
|
||||
|
||||
class OuterMiddleware(AgentMiddleware):
|
||||
"""Outer middleware in the chain."""
|
||||
|
||||
def on_tool_call(
|
||||
self, request: ToolCallRequest, state, runtime
|
||||
) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]:
|
||||
call_order.append("outer_start")
|
||||
response = yield request
|
||||
call_order.append("outer_end")
|
||||
return response
|
||||
|
||||
class InnerMiddleware(AgentMiddleware):
|
||||
"""Inner middleware in the chain."""
|
||||
|
||||
def on_tool_call(
|
||||
self, request: ToolCallRequest, state, runtime
|
||||
) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]:
|
||||
call_order.append("inner_start")
|
||||
response = yield request
|
||||
call_order.append("inner_end")
|
||||
return response
|
||||
|
||||
model = FakeModel(
|
||||
messages=iter(
|
||||
[
|
||||
AIMessage(
|
||||
content="",
|
||||
tool_calls=[{"name": "add_tool", "args": {"x": 1, "y": 1}, "id": "1"}],
|
||||
),
|
||||
AIMessage(content="Done"),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
agent = create_agent(
|
||||
model=model,
|
||||
tools=[add_tool],
|
||||
middleware=[OuterMiddleware(), InnerMiddleware()],
|
||||
)
|
||||
|
||||
agent.compile().invoke({"messages": [HumanMessage("Add 1 and 1")]})
|
||||
|
||||
# Verify order: outer_start -> inner_start -> tool -> inner_end -> outer_end
|
||||
assert call_order == ["outer_start", "inner_start", "inner_end", "outer_end"]
|
||||
|
||||
|
||||
def test_middleware_retry_logic():
|
||||
"""Test that middleware can retry tool calls."""
|
||||
attempt_count = 0
|
||||
|
||||
class RetryMiddleware(AgentMiddleware):
|
||||
"""Middleware that retries on failure."""
|
||||
|
||||
def on_tool_call(
|
||||
self, request: ToolCallRequest, state, runtime
|
||||
) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]:
|
||||
nonlocal attempt_count
|
||||
max_retries = 2
|
||||
|
||||
for attempt in range(max_retries):
|
||||
attempt_count += 1
|
||||
response = yield request
|
||||
|
||||
if response.action == "continue":
|
||||
return response
|
||||
|
||||
if response.action == "raise" and attempt < max_retries - 1:
|
||||
# Retry
|
||||
continue
|
||||
|
||||
# Convert error to success message
|
||||
return ToolCallResponse(
|
||||
action="continue",
|
||||
result=ToolMessage(
|
||||
content=f"Failed after {max_retries} attempts",
|
||||
name=request.tool_call["name"],
|
||||
tool_call_id=request.tool_call["id"],
|
||||
status="error",
|
||||
),
|
||||
)
|
||||
|
||||
raise AssertionError("Unreachable")
|
||||
|
||||
model = FakeModel(
|
||||
messages=iter(
|
||||
[
|
||||
AIMessage(
|
||||
content="",
|
||||
tool_calls=[{"name": "failing_tool", "args": {"x": 1}, "id": "1"}],
|
||||
),
|
||||
AIMessage(content="Done"),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
agent = create_agent(
|
||||
model=model,
|
||||
tools=[failing_tool],
|
||||
middleware=[RetryMiddleware()],
|
||||
)
|
||||
|
||||
result = agent.compile().invoke({"messages": [HumanMessage("Test retry")]})
|
||||
|
||||
# Should have attempted twice
|
||||
assert attempt_count == 2
|
||||
|
||||
# Check that we got an error message
|
||||
tool_messages = [m for m in result["messages"] if isinstance(m, ToolMessage)]
|
||||
assert len(tool_messages) == 1
|
||||
assert "Failed after 2 attempts" in tool_messages[0].content
|
||||
|
||||
|
||||
def test_middleware_request_modification():
|
||||
"""Test that middleware can modify tool requests."""
|
||||
|
||||
class RequestModifierMiddleware(AgentMiddleware):
|
||||
"""Middleware that doubles the input."""
|
||||
|
||||
def on_tool_call(
|
||||
self, request: ToolCallRequest, state, runtime
|
||||
) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]:
|
||||
# Modify the arguments
|
||||
modified_tool_call = {
|
||||
**request.tool_call,
|
||||
"args": {
|
||||
"x": request.tool_call["args"]["x"] * 2,
|
||||
"y": request.tool_call["args"]["y"] * 2,
|
||||
},
|
||||
}
|
||||
modified_request = ToolCallRequest(
|
||||
tool_call=modified_tool_call,
|
||||
tool=request.tool,
|
||||
)
|
||||
response = yield modified_request
|
||||
return response
|
||||
|
||||
model = FakeModel(
|
||||
messages=iter(
|
||||
[
|
||||
AIMessage(
|
||||
content="",
|
||||
tool_calls=[{"name": "add_tool", "args": {"x": 1, "y": 2}, "id": "1"}],
|
||||
),
|
||||
AIMessage(content="Done"),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
agent = create_agent(
|
||||
model=model,
|
||||
tools=[add_tool],
|
||||
middleware=[RequestModifierMiddleware()],
|
||||
)
|
||||
|
||||
result = agent.compile().invoke({"messages": [HumanMessage("Add 1 and 2")]})
|
||||
|
||||
# Original: 1 + 2 = 3, Modified: 2 + 4 = 6
|
||||
tool_messages = [m for m in result["messages"] if isinstance(m, ToolMessage)]
|
||||
assert len(tool_messages) == 1
|
||||
assert tool_messages[0].content == "6"
|
||||
|
||||
|
||||
def test_multiple_middleware_with_retry():
|
||||
"""Test complex scenario with multiple middleware and retry logic."""
|
||||
call_log = []
|
||||
|
||||
class MonitoringMiddleware(AgentMiddleware):
|
||||
"""Outer middleware for monitoring."""
|
||||
|
||||
def on_tool_call(
|
||||
self, request: ToolCallRequest, state, runtime
|
||||
) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]:
|
||||
call_log.append("monitoring_start")
|
||||
response = yield request
|
||||
call_log.append("monitoring_end")
|
||||
return response
|
||||
|
||||
class RetryMiddleware(AgentMiddleware):
|
||||
"""Inner middleware for retries."""
|
||||
|
||||
def on_tool_call(
|
||||
self, request: ToolCallRequest, state, runtime
|
||||
) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]:
|
||||
call_log.append("retry_start")
|
||||
for attempt in range(2):
|
||||
call_log.append(f"retry_attempt_{attempt + 1}")
|
||||
response = yield request
|
||||
|
||||
if response.action == "continue":
|
||||
call_log.append("retry_success")
|
||||
return response
|
||||
|
||||
if attempt == 0: # Retry once
|
||||
call_log.append("retry_retry")
|
||||
continue
|
||||
|
||||
call_log.append("retry_failed")
|
||||
return response
|
||||
|
||||
model = FakeModel(
|
||||
messages=iter(
|
||||
[
|
||||
AIMessage(
|
||||
content="",
|
||||
tool_calls=[{"name": "add_tool", "args": {"x": 5, "y": 7}, "id": "1"}],
|
||||
),
|
||||
AIMessage(content="Done"),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
agent = create_agent(
|
||||
model=model,
|
||||
tools=[add_tool],
|
||||
middleware=[MonitoringMiddleware(), RetryMiddleware()],
|
||||
)
|
||||
|
||||
agent.compile().invoke({"messages": [HumanMessage("Add 5 and 7")]})
|
||||
|
||||
# Verify the call sequence
|
||||
assert call_log[0] == "monitoring_start"
|
||||
assert call_log[1] == "retry_start"
|
||||
assert "retry_attempt_1" in call_log
|
||||
assert "retry_success" in call_log
|
||||
assert call_log[-1] == "monitoring_end"
|
||||
|
||||
|
||||
def test_mixed_middleware():
|
||||
"""Test middleware with both before_model and on_tool_call hooks."""
|
||||
call_log = []
|
||||
|
||||
class MixedMiddleware(AgentMiddleware):
|
||||
"""Middleware with multiple hooks."""
|
||||
|
||||
def before_model(self, state, runtime):
|
||||
call_log.append("before_model")
|
||||
return None
|
||||
|
||||
def on_tool_call(
|
||||
self, request: ToolCallRequest, state, runtime
|
||||
) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]:
|
||||
call_log.append("on_tool_call_start")
|
||||
response = yield request
|
||||
call_log.append("on_tool_call_end")
|
||||
return response
|
||||
|
||||
model = FakeModel(
|
||||
messages=iter(
|
||||
[
|
||||
AIMessage(
|
||||
content="",
|
||||
tool_calls=[{"name": "add_tool", "args": {"x": 10, "y": 20}, "id": "1"}],
|
||||
),
|
||||
AIMessage(content="Done"),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
agent = create_agent(
|
||||
model=model,
|
||||
tools=[add_tool],
|
||||
middleware=[MixedMiddleware()],
|
||||
)
|
||||
|
||||
agent.compile().invoke({"messages": [HumanMessage("Add 10 and 20")]})
|
||||
|
||||
# Both hooks should have been called
|
||||
assert "before_model" in call_log
|
||||
assert "on_tool_call_start" in call_log
|
||||
assert "on_tool_call_end" in call_log
|
||||
# before_model runs before on_tool_call
|
||||
assert call_log.index("before_model") < call_log.index("on_tool_call_start")
|
||||
383
libs/langchain_v1/tests/unit_tests/tools/test_on_tool_call.py
Normal file
383
libs/langchain_v1/tests/unit_tests/tools/test_on_tool_call.py
Normal file
@@ -0,0 +1,383 @@
|
||||
"""Tests for on_tool_call handler functionality."""
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage, ToolMessage
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from langchain.tools import ToolNode
|
||||
from langchain.tools.tool_node import ToolCallRequest, ToolCallResponse
|
||||
|
||||
|
||||
# Test tools
|
||||
@tool
|
||||
def success_tool(x: int) -> int:
|
||||
"""A tool that always succeeds."""
|
||||
return x * 2
|
||||
|
||||
|
||||
@tool
|
||||
def error_tool(x: int) -> int:
|
||||
"""A tool that always raises ValueError."""
|
||||
msg = f"Error with value: {x}"
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
@tool
|
||||
def rate_limit_tool(x: int) -> int:
|
||||
"""A tool that simulates rate limit errors."""
|
||||
if not hasattr(rate_limit_tool, "_call_count"):
|
||||
rate_limit_tool._call_count = 0
|
||||
rate_limit_tool._call_count += 1
|
||||
|
||||
if rate_limit_tool._call_count < 3: # Fail first 2 times
|
||||
msg = "Rate limit exceeded"
|
||||
raise ValueError(msg)
|
||||
return x * 2
|
||||
|
||||
|
||||
def test_on_tool_call_passthrough() -> None:
|
||||
"""Test that a simple passthrough handler works."""
|
||||
|
||||
def passthrough_handler(
|
||||
request: ToolCallRequest, _state: Any, _runtime: Any
|
||||
) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]:
|
||||
"""Simply pass through without modification."""
|
||||
response = yield request
|
||||
return response
|
||||
|
||||
tool_node = ToolNode([success_tool], on_tool_call=passthrough_handler)
|
||||
result = tool_node.invoke(
|
||||
{
|
||||
"messages": [
|
||||
AIMessage(
|
||||
"",
|
||||
tool_calls=[{"name": "success_tool", "args": {"x": 5}, "id": "1"}],
|
||||
)
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
assert len(result["messages"]) == 1
|
||||
tool_message: ToolMessage = result["messages"][0]
|
||||
assert tool_message.content == "10"
|
||||
assert tool_message.status != "error"
|
||||
|
||||
|
||||
def test_on_tool_call_retry_success() -> None:
|
||||
"""Test that retry handler can recover from transient errors."""
|
||||
# Reset counter
|
||||
if hasattr(rate_limit_tool, "_call_count"):
|
||||
rate_limit_tool._call_count = 0
|
||||
|
||||
def retry_handler(
|
||||
request: ToolCallRequest, _state: Any, _runtime: Any
|
||||
) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]:
|
||||
"""Retry up to 3 times."""
|
||||
max_retries = 3
|
||||
|
||||
for attempt in range(max_retries):
|
||||
response = yield request
|
||||
|
||||
if response.action == "continue":
|
||||
return response
|
||||
|
||||
# Retry on error
|
||||
if attempt < max_retries - 1:
|
||||
continue
|
||||
|
||||
# Final attempt failed - convert to error message
|
||||
return ToolCallResponse(
|
||||
action="continue",
|
||||
result=ToolMessage(
|
||||
content=f"Failed after {max_retries} attempts",
|
||||
name=request.tool_call["name"],
|
||||
tool_call_id=request.tool_call["id"],
|
||||
status="error",
|
||||
),
|
||||
)
|
||||
msg = "Unreachable code"
|
||||
raise AssertionError(msg)
|
||||
|
||||
tool_node = ToolNode([rate_limit_tool], on_tool_call=retry_handler, handle_tool_errors=False)
|
||||
result = tool_node.invoke(
|
||||
{
|
||||
"messages": [
|
||||
AIMessage(
|
||||
"",
|
||||
tool_calls=[{"name": "rate_limit_tool", "args": {"x": 5}, "id": "1"}],
|
||||
)
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
assert len(result["messages"]) == 1
|
||||
tool_message: ToolMessage = result["messages"][0]
|
||||
assert tool_message.content == "10" # Should succeed on 3rd attempt
|
||||
assert tool_message.status != "error"
|
||||
|
||||
|
||||
def test_on_tool_call_convert_error_to_message() -> None:
|
||||
"""Test that handler can convert raised errors to error messages."""
|
||||
|
||||
def error_to_message_handler(
|
||||
request: ToolCallRequest, _state: Any, _runtime: Any
|
||||
) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]:
|
||||
"""Convert any error to a user-friendly message."""
|
||||
response = yield request
|
||||
|
||||
if response.action == "raise":
|
||||
return ToolCallResponse(
|
||||
action="continue",
|
||||
result=ToolMessage(
|
||||
content=f"Tool failed: {response.exception}",
|
||||
name=request.tool_call["name"],
|
||||
tool_call_id=request.tool_call["id"],
|
||||
status="error",
|
||||
),
|
||||
exception=response.exception,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
tool_node = ToolNode(
|
||||
[error_tool], on_tool_call=error_to_message_handler, handle_tool_errors=False
|
||||
)
|
||||
result = tool_node.invoke(
|
||||
{
|
||||
"messages": [
|
||||
AIMessage(
|
||||
"",
|
||||
tool_calls=[{"name": "error_tool", "args": {"x": 5}, "id": "1"}],
|
||||
)
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
assert len(result["messages"]) == 1
|
||||
tool_message: ToolMessage = result["messages"][0]
|
||||
assert "Tool failed" in tool_message.content
|
||||
assert "Error with value: 5" in tool_message.content
|
||||
assert tool_message.status == "error"
|
||||
|
||||
|
||||
def test_on_tool_call_let_error_raise() -> None:
|
||||
"""Test that handler can let errors propagate."""
|
||||
|
||||
def let_raise_handler(
|
||||
request: ToolCallRequest, _state: Any, _runtime: Any
|
||||
) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]:
|
||||
"""Just return the response as-is, letting errors raise."""
|
||||
response = yield request
|
||||
return response
|
||||
|
||||
tool_node = ToolNode([error_tool], on_tool_call=let_raise_handler, handle_tool_errors=False)
|
||||
|
||||
with pytest.raises(ValueError, match=r"Error with value: 5"):
|
||||
tool_node.invoke(
|
||||
{
|
||||
"messages": [
|
||||
AIMessage(
|
||||
"",
|
||||
tool_calls=[{"name": "error_tool", "args": {"x": 5}, "id": "1"}],
|
||||
)
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_on_tool_call_with_handled_errors() -> None:
|
||||
"""Test interaction between on_tool_call and handle_tool_errors."""
|
||||
call_count = {"count": 0}
|
||||
|
||||
def counting_handler(
|
||||
request: ToolCallRequest, _state: Any, _runtime: Any
|
||||
) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]:
|
||||
"""Count how many times we're called."""
|
||||
call_count["count"] += 1
|
||||
response = yield request
|
||||
return response
|
||||
|
||||
# When handle_tool_errors=True, errors are converted to ToolMessages
|
||||
# so handler sees action="continue"
|
||||
tool_node = ToolNode([error_tool], on_tool_call=counting_handler, handle_tool_errors=True)
|
||||
result = tool_node.invoke(
|
||||
{
|
||||
"messages": [
|
||||
AIMessage(
|
||||
"",
|
||||
tool_calls=[{"name": "error_tool", "args": {"x": 5}, "id": "1"}],
|
||||
)
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
assert call_count["count"] == 1
|
||||
assert len(result["messages"]) == 1
|
||||
tool_message: ToolMessage = result["messages"][0]
|
||||
assert tool_message.status == "error"
|
||||
assert "Please fix your mistakes" in tool_message.content
|
||||
|
||||
|
||||
def test_on_tool_call_must_return_value() -> None:
|
||||
"""Test that handler must return a ToolCallResponse."""
|
||||
|
||||
def no_return_handler(
|
||||
request: ToolCallRequest, _state: Any, _runtime: Any
|
||||
) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]:
|
||||
"""Handler that doesn't return anything."""
|
||||
_ = yield request
|
||||
# Implicit return None
|
||||
|
||||
tool_node = ToolNode([success_tool], on_tool_call=no_return_handler)
|
||||
|
||||
with pytest.raises(ValueError, match=r"must explicitly return a ToolCallResponse"):
|
||||
tool_node.invoke(
|
||||
{
|
||||
"messages": [
|
||||
AIMessage(
|
||||
"",
|
||||
tool_calls=[{"name": "success_tool", "args": {"x": 5}, "id": "1"}],
|
||||
)
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_on_tool_call_request_modification() -> None:
|
||||
"""Test that handler can modify the request before execution."""
|
||||
|
||||
def double_input_handler(
|
||||
request: ToolCallRequest, _state: Any, _runtime: Any
|
||||
) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]:
|
||||
"""Double the input value."""
|
||||
# Modify the tool call args
|
||||
modified_tool_call = {
|
||||
**request.tool_call,
|
||||
"args": {**request.tool_call["args"], "x": request.tool_call["args"]["x"] * 2},
|
||||
}
|
||||
modified_request = ToolCallRequest(
|
||||
tool_call=modified_tool_call,
|
||||
tool=request.tool,
|
||||
)
|
||||
response = yield modified_request
|
||||
return response
|
||||
|
||||
tool_node = ToolNode([success_tool], on_tool_call=double_input_handler)
|
||||
result = tool_node.invoke(
|
||||
{
|
||||
"messages": [
|
||||
AIMessage(
|
||||
"",
|
||||
tool_calls=[{"name": "success_tool", "args": {"x": 5}, "id": "1"}],
|
||||
)
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
assert len(result["messages"]) == 1
|
||||
tool_message: ToolMessage = result["messages"][0]
|
||||
# Input was 5, doubled to 10, then tool multiplies by 2 = 20
|
||||
assert tool_message.content == "20"
|
||||
|
||||
|
||||
def test_on_tool_call_response_validation() -> None:
|
||||
"""Test that ToolCallResponse validates action and required fields."""
|
||||
# Test action="continue" requires result
|
||||
with pytest.raises(ValueError, match=r"action='continue' requires a result"):
|
||||
ToolCallResponse(action="continue")
|
||||
|
||||
# Test action="raise" requires exception
|
||||
with pytest.raises(ValueError, match=r"action='raise' requires an exception"):
|
||||
ToolCallResponse(action="raise")
|
||||
|
||||
# Valid responses should work
|
||||
ToolCallResponse(
|
||||
action="continue",
|
||||
result=ToolMessage(content="test", tool_call_id="1", name="test"),
|
||||
)
|
||||
ToolCallResponse(action="raise", exception=ValueError("test"))
|
||||
|
||||
|
||||
def test_on_tool_call_without_handler_backward_compat() -> None:
|
||||
"""Test that tools work without on_tool_call handler (backward compatibility)."""
|
||||
# Success case
|
||||
tool_node = ToolNode([success_tool])
|
||||
result = tool_node.invoke(
|
||||
{
|
||||
"messages": [
|
||||
AIMessage(
|
||||
"",
|
||||
tool_calls=[{"name": "success_tool", "args": {"x": 5}, "id": "1"}],
|
||||
)
|
||||
]
|
||||
}
|
||||
)
|
||||
assert result["messages"][0].content == "10"
|
||||
|
||||
# Error case with handle_tool_errors=False
|
||||
tool_node_error = ToolNode([error_tool], handle_tool_errors=False)
|
||||
with pytest.raises(ValueError, match=r"Error with value: 5"):
|
||||
tool_node_error.invoke(
|
||||
{
|
||||
"messages": [
|
||||
AIMessage(
|
||||
"",
|
||||
tool_calls=[{"name": "error_tool", "args": {"x": 5}, "id": "1"}],
|
||||
)
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
# Error case with handle_tool_errors=True
|
||||
tool_node_handled = ToolNode([error_tool], handle_tool_errors=True)
|
||||
result = tool_node_handled.invoke(
|
||||
{
|
||||
"messages": [
|
||||
AIMessage(
|
||||
"",
|
||||
tool_calls=[{"name": "error_tool", "args": {"x": 5}, "id": "1"}],
|
||||
)
|
||||
]
|
||||
}
|
||||
)
|
||||
assert result["messages"][0].status == "error"
|
||||
|
||||
|
||||
def test_on_tool_call_multiple_yields() -> None:
|
||||
"""Test that handler can yield multiple times for retries."""
|
||||
attempts = {"count": 0}
|
||||
|
||||
def multi_yield_handler(
|
||||
request: ToolCallRequest, _state: Any, _runtime: Any
|
||||
) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]:
|
||||
"""Yield multiple times to track attempts."""
|
||||
max_attempts = 3
|
||||
|
||||
for _ in range(max_attempts):
|
||||
attempts["count"] += 1
|
||||
response = yield request
|
||||
|
||||
if response.action == "continue":
|
||||
return response
|
||||
|
||||
# All attempts failed
|
||||
return response
|
||||
|
||||
tool_node = ToolNode([error_tool], on_tool_call=multi_yield_handler, handle_tool_errors=False)
|
||||
|
||||
with pytest.raises(ValueError, match=r"Error with value: 5"):
|
||||
tool_node.invoke(
|
||||
{
|
||||
"messages": [
|
||||
AIMessage(
|
||||
"",
|
||||
tool_calls=[{"name": "error_tool", "args": {"x": 5}, "id": "1"}],
|
||||
)
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
assert attempts["count"] == 3
|
||||
Reference in New Issue
Block a user