This commit is contained in:
Eugene Yurtsev
2025-10-06 16:16:19 -04:00
parent 65e073e85c
commit def2f147ae
3 changed files with 44 additions and 50 deletions

View File

@@ -9,7 +9,7 @@ from __future__ import annotations
import inspect
import logging
import time
from collections.abc import Callable
from types import UnionType
from typing import TYPE_CHECKING, Union, get_args, get_origin, get_type_hints
from langchain_core.messages import ToolMessage
@@ -17,13 +17,10 @@ from langchain_core.messages import ToolMessage
from langchain.agents.middleware.types import AgentMiddleware
# Import ToolCallResponse locally to avoid circular import
from langchain.tools.tool_node import ToolCallResponse
from langchain.tools.tool_node import ToolCallRequest, ToolCallResponse
if TYPE_CHECKING:
from collections.abc import Generator
from types import UnionType
from langchain.tools.tool_node import ToolCallRequest, ToolCallResponse
from collections.abc import Callable, Generator
logger = logging.getLogger(__name__)
@@ -72,13 +69,14 @@ def _infer_retriable_types(
if first_param.name in type_hints:
origin = get_origin(first_param.annotation)
# Handle Union types
if origin in [Union, UnionType]: # type: ignore[has-type]
if origin in [Union, UnionType]:
args = get_args(first_param.annotation)
if all(isinstance(arg, type) and issubclass(arg, Exception) for arg in args):
return tuple(args)
msg = (
"All types in retry predicate annotation must be Exception types. "
"For example, `def should_retry(e: Union[TimeoutError, ConnectionError]) -> bool`. "
"For example, `def should_retry(e: Union[TimeoutError, "
"ConnectionError]) -> bool`. "
f"Got '{first_param.annotation}' instead."
)
raise ValueError(msg)
@@ -260,14 +258,13 @@ class RetryMiddleware(AgentMiddleware):
return response
# If predicate is provided, check if we should retry
if self._retry_predicate is not None:
if not self._retry_predicate(exception):
logger.debug(
"Retry predicate returned False for %s in tool %s",
type(exception).__name__,
request.tool_call["name"],
)
return response
if self._retry_predicate is not None and not self._retry_predicate(exception):
logger.debug(
"Retry predicate returned False for %s in tool %s",
type(exception).__name__,
request.tool_call["name"],
)
return response
# Last attempt - return error
if attempt > self.max_retries:

View File

@@ -152,7 +152,8 @@ ToolCallHandler = Callable[
"""Generator-based handler that intercepts tool execution.
Receives a ToolCallRequest, state, and runtime; yields modified ToolCallRequests;
receives ToolCallResponses; and returns a final ToolCallResponse. Supports multiple yields for retry logic.
receives ToolCallResponses; and returns a final ToolCallResponse. Supports multiple
yields for retry logic.
"""
@@ -511,7 +512,7 @@ class ToolNode(RunnableCallable):
input: list[AnyMessage] | dict[str, Any] | BaseModel,
config: RunnableConfig,
*,
store: Optional[BaseStore],
store: BaseStore | None,
) -> Any:
try:
runtime = get_runtime()
@@ -535,7 +536,7 @@ class ToolNode(RunnableCallable):
input: list[AnyMessage] | dict[str, Any] | BaseModel,
config: RunnableConfig,
*,
store: Optional[BaseStore],
store: BaseStore | None,
) -> Any:
try:
runtime = get_runtime()

View File

@@ -1,6 +1,7 @@
"""Tests for on_tool_call handler functionality."""
from collections.abc import Generator
from typing import Any
import pytest
from langchain_core.messages import AIMessage, ToolMessage
@@ -41,7 +42,7 @@ def test_on_tool_call_passthrough() -> None:
"""Test that a simple passthrough handler works."""
def passthrough_handler(
request: ToolCallRequest, state, runtime
request: ToolCallRequest, _state: Any, _runtime: Any
) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]:
"""Simply pass through without modification."""
response = yield request
@@ -65,14 +66,14 @@ def test_on_tool_call_passthrough() -> None:
assert tool_message.status != "error"
def test_on_tool_call_retry_success():
def test_on_tool_call_retry_success() -> None:
"""Test that retry handler can recover from transient errors."""
# Reset counter
if hasattr(rate_limit_tool, "_call_count"):
rate_limit_tool._call_count = 0
def retry_handler(
request: ToolCallRequest, state, runtime
request: ToolCallRequest, _state: Any, _runtime: Any
) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]:
"""Retry up to 3 times."""
max_retries = 3
@@ -97,7 +98,8 @@ def test_on_tool_call_retry_success():
status="error",
),
)
raise AssertionError("Unreachable code")
msg = "Unreachable code"
raise AssertionError(msg)
tool_node = ToolNode([rate_limit_tool], on_tool_call=retry_handler, handle_tool_errors=False)
result = tool_node.invoke(
@@ -117,11 +119,11 @@ def test_on_tool_call_retry_success():
assert tool_message.status != "error"
def test_on_tool_call_convert_error_to_message():
def test_on_tool_call_convert_error_to_message() -> None:
"""Test that handler can convert raised errors to error messages."""
def error_to_message_handler(
request: ToolCallRequest, state, runtime
request: ToolCallRequest, _state: Any, _runtime: Any
) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]:
"""Convert any error to a user-friendly message."""
response = yield request
@@ -161,11 +163,11 @@ def test_on_tool_call_convert_error_to_message():
assert tool_message.status == "error"
def test_on_tool_call_let_error_raise():
def test_on_tool_call_let_error_raise() -> None:
"""Test that handler can let errors propagate."""
def let_raise_handler(
request: ToolCallRequest, state, runtime
request: ToolCallRequest, _state: Any, _runtime: Any
) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]:
"""Just return the response as-is, letting errors raise."""
response = yield request
@@ -173,7 +175,7 @@ def test_on_tool_call_let_error_raise():
tool_node = ToolNode([error_tool], on_tool_call=let_raise_handler, handle_tool_errors=False)
with pytest.raises(ValueError) as exc_info:
with pytest.raises(ValueError, match=r"Error with value: 5"):
tool_node.invoke(
{
"messages": [
@@ -185,15 +187,13 @@ def test_on_tool_call_let_error_raise():
}
)
assert "Error with value: 5" in str(exc_info.value)
def test_on_tool_call_with_handled_errors():
def test_on_tool_call_with_handled_errors() -> None:
"""Test interaction between on_tool_call and handle_tool_errors."""
call_count = {"count": 0}
def counting_handler(
request: ToolCallRequest, state, runtime
request: ToolCallRequest, _state: Any, _runtime: Any
) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]:
"""Count how many times we're called."""
call_count["count"] += 1
@@ -221,19 +221,19 @@ def test_on_tool_call_with_handled_errors():
assert "Please fix your mistakes" in tool_message.content
def test_on_tool_call_must_return_value():
def test_on_tool_call_must_return_value() -> None:
"""Test that handler must return a ToolCallResponse."""
def no_return_handler(
request: ToolCallRequest, state, runtime
request: ToolCallRequest, _state: Any, _runtime: Any
) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]:
"""Handler that doesn't return anything."""
response = yield request
_ = yield request
# Implicit return None
tool_node = ToolNode([success_tool], on_tool_call=no_return_handler)
with pytest.raises(ValueError) as exc_info:
with pytest.raises(ValueError, match=r"must explicitly return a ToolCallResponse"):
tool_node.invoke(
{
"messages": [
@@ -245,14 +245,12 @@ def test_on_tool_call_must_return_value():
}
)
assert "must explicitly return a ToolCallResponse" in str(exc_info.value)
def test_on_tool_call_request_modification():
def test_on_tool_call_request_modification() -> None:
"""Test that handler can modify the request before execution."""
def double_input_handler(
request: ToolCallRequest, state, runtime
request: ToolCallRequest, _state: Any, _runtime: Any
) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]:
"""Double the input value."""
# Modify the tool call args
@@ -286,17 +284,15 @@ def test_on_tool_call_request_modification():
assert tool_message.content == "20"
def test_on_tool_call_response_validation():
def test_on_tool_call_response_validation() -> None:
"""Test that ToolCallResponse validates action and required fields."""
# Test action="return" requires result
with pytest.raises(ValueError) as exc_info:
with pytest.raises(ValueError, match=r"action='return' requires a result"):
ToolCallResponse(action="return")
assert "action='return' requires a result" in str(exc_info.value)
# Test action="raise" requires exception
with pytest.raises(ValueError) as exc_info:
with pytest.raises(ValueError, match=r"action='raise' requires an exception"):
ToolCallResponse(action="raise")
assert "action='raise' requires an exception" in str(exc_info.value)
# Valid responses should work
ToolCallResponse(
@@ -306,7 +302,7 @@ def test_on_tool_call_response_validation():
ToolCallResponse(action="raise", exception=ValueError("test"))
def test_on_tool_call_without_handler_backward_compat():
def test_on_tool_call_without_handler_backward_compat() -> None:
"""Test that tools work without on_tool_call handler (backward compatibility)."""
# Success case
tool_node = ToolNode([success_tool])
@@ -324,7 +320,7 @@ def test_on_tool_call_without_handler_backward_compat():
# Error case with handle_tool_errors=False
tool_node_error = ToolNode([error_tool], handle_tool_errors=False)
with pytest.raises(ValueError):
with pytest.raises(ValueError, match=r"Error with value: 5"):
tool_node_error.invoke(
{
"messages": [
@@ -351,12 +347,12 @@ def test_on_tool_call_without_handler_backward_compat():
assert result["messages"][0].status == "error"
def test_on_tool_call_multiple_yields():
def test_on_tool_call_multiple_yields() -> None:
"""Test that handler can yield multiple times for retries."""
attempts = {"count": 0}
def multi_yield_handler(
request: ToolCallRequest, state, runtime
request: ToolCallRequest, _state: Any, _runtime: Any
) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]:
"""Yield multiple times to track attempts."""
max_attempts = 3
@@ -373,7 +369,7 @@ def test_on_tool_call_multiple_yields():
tool_node = ToolNode([error_tool], on_tool_call=multi_yield_handler, handle_tool_errors=False)
with pytest.raises(ValueError):
with pytest.raises(ValueError, match=r"Error with value: 5"):
tool_node.invoke(
{
"messages": [