mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-21 14:43:07 +00:00
x
This commit is contained in:
@@ -9,7 +9,7 @@ from __future__ import annotations
|
||||
import inspect
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from types import UnionType
|
||||
from typing import TYPE_CHECKING, Union, get_args, get_origin, get_type_hints
|
||||
|
||||
from langchain_core.messages import ToolMessage
|
||||
@@ -17,13 +17,10 @@ from langchain_core.messages import ToolMessage
|
||||
from langchain.agents.middleware.types import AgentMiddleware
|
||||
|
||||
# Import ToolCallResponse locally to avoid circular import
|
||||
from langchain.tools.tool_node import ToolCallResponse
|
||||
from langchain.tools.tool_node import ToolCallRequest, ToolCallResponse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Generator
|
||||
from types import UnionType
|
||||
|
||||
from langchain.tools.tool_node import ToolCallRequest, ToolCallResponse
|
||||
from collections.abc import Callable, Generator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -72,13 +69,14 @@ def _infer_retriable_types(
|
||||
if first_param.name in type_hints:
|
||||
origin = get_origin(first_param.annotation)
|
||||
# Handle Union types
|
||||
if origin in [Union, UnionType]: # type: ignore[has-type]
|
||||
if origin in [Union, UnionType]:
|
||||
args = get_args(first_param.annotation)
|
||||
if all(isinstance(arg, type) and issubclass(arg, Exception) for arg in args):
|
||||
return tuple(args)
|
||||
msg = (
|
||||
"All types in retry predicate annotation must be Exception types. "
|
||||
"For example, `def should_retry(e: Union[TimeoutError, ConnectionError]) -> bool`. "
|
||||
"For example, `def should_retry(e: Union[TimeoutError, "
|
||||
"ConnectionError]) -> bool`. "
|
||||
f"Got '{first_param.annotation}' instead."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
@@ -260,14 +258,13 @@ class RetryMiddleware(AgentMiddleware):
|
||||
return response
|
||||
|
||||
# If predicate is provided, check if we should retry
|
||||
if self._retry_predicate is not None:
|
||||
if not self._retry_predicate(exception):
|
||||
logger.debug(
|
||||
"Retry predicate returned False for %s in tool %s",
|
||||
type(exception).__name__,
|
||||
request.tool_call["name"],
|
||||
)
|
||||
return response
|
||||
if self._retry_predicate is not None and not self._retry_predicate(exception):
|
||||
logger.debug(
|
||||
"Retry predicate returned False for %s in tool %s",
|
||||
type(exception).__name__,
|
||||
request.tool_call["name"],
|
||||
)
|
||||
return response
|
||||
|
||||
# Last attempt - return error
|
||||
if attempt > self.max_retries:
|
||||
|
||||
@@ -152,7 +152,8 @@ ToolCallHandler = Callable[
|
||||
"""Generator-based handler that intercepts tool execution.
|
||||
|
||||
Receives a ToolCallRequest, state, and runtime; yields modified ToolCallRequests;
|
||||
receives ToolCallResponses; and returns a final ToolCallResponse. Supports multiple yields for retry logic.
|
||||
receives ToolCallResponses; and returns a final ToolCallResponse. Supports multiple
|
||||
yields for retry logic.
|
||||
"""
|
||||
|
||||
|
||||
@@ -511,7 +512,7 @@ class ToolNode(RunnableCallable):
|
||||
input: list[AnyMessage] | dict[str, Any] | BaseModel,
|
||||
config: RunnableConfig,
|
||||
*,
|
||||
store: Optional[BaseStore],
|
||||
store: BaseStore | None,
|
||||
) -> Any:
|
||||
try:
|
||||
runtime = get_runtime()
|
||||
@@ -535,7 +536,7 @@ class ToolNode(RunnableCallable):
|
||||
input: list[AnyMessage] | dict[str, Any] | BaseModel,
|
||||
config: RunnableConfig,
|
||||
*,
|
||||
store: Optional[BaseStore],
|
||||
store: BaseStore | None,
|
||||
) -> Any:
|
||||
try:
|
||||
runtime = get_runtime()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Tests for on_tool_call handler functionality."""
|
||||
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage, ToolMessage
|
||||
@@ -41,7 +42,7 @@ def test_on_tool_call_passthrough() -> None:
|
||||
"""Test that a simple passthrough handler works."""
|
||||
|
||||
def passthrough_handler(
|
||||
request: ToolCallRequest, state, runtime
|
||||
request: ToolCallRequest, _state: Any, _runtime: Any
|
||||
) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]:
|
||||
"""Simply pass through without modification."""
|
||||
response = yield request
|
||||
@@ -65,14 +66,14 @@ def test_on_tool_call_passthrough() -> None:
|
||||
assert tool_message.status != "error"
|
||||
|
||||
|
||||
def test_on_tool_call_retry_success():
|
||||
def test_on_tool_call_retry_success() -> None:
|
||||
"""Test that retry handler can recover from transient errors."""
|
||||
# Reset counter
|
||||
if hasattr(rate_limit_tool, "_call_count"):
|
||||
rate_limit_tool._call_count = 0
|
||||
|
||||
def retry_handler(
|
||||
request: ToolCallRequest, state, runtime
|
||||
request: ToolCallRequest, _state: Any, _runtime: Any
|
||||
) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]:
|
||||
"""Retry up to 3 times."""
|
||||
max_retries = 3
|
||||
@@ -97,7 +98,8 @@ def test_on_tool_call_retry_success():
|
||||
status="error",
|
||||
),
|
||||
)
|
||||
raise AssertionError("Unreachable code")
|
||||
msg = "Unreachable code"
|
||||
raise AssertionError(msg)
|
||||
|
||||
tool_node = ToolNode([rate_limit_tool], on_tool_call=retry_handler, handle_tool_errors=False)
|
||||
result = tool_node.invoke(
|
||||
@@ -117,11 +119,11 @@ def test_on_tool_call_retry_success():
|
||||
assert tool_message.status != "error"
|
||||
|
||||
|
||||
def test_on_tool_call_convert_error_to_message():
|
||||
def test_on_tool_call_convert_error_to_message() -> None:
|
||||
"""Test that handler can convert raised errors to error messages."""
|
||||
|
||||
def error_to_message_handler(
|
||||
request: ToolCallRequest, state, runtime
|
||||
request: ToolCallRequest, _state: Any, _runtime: Any
|
||||
) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]:
|
||||
"""Convert any error to a user-friendly message."""
|
||||
response = yield request
|
||||
@@ -161,11 +163,11 @@ def test_on_tool_call_convert_error_to_message():
|
||||
assert tool_message.status == "error"
|
||||
|
||||
|
||||
def test_on_tool_call_let_error_raise():
|
||||
def test_on_tool_call_let_error_raise() -> None:
|
||||
"""Test that handler can let errors propagate."""
|
||||
|
||||
def let_raise_handler(
|
||||
request: ToolCallRequest, state, runtime
|
||||
request: ToolCallRequest, _state: Any, _runtime: Any
|
||||
) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]:
|
||||
"""Just return the response as-is, letting errors raise."""
|
||||
response = yield request
|
||||
@@ -173,7 +175,7 @@ def test_on_tool_call_let_error_raise():
|
||||
|
||||
tool_node = ToolNode([error_tool], on_tool_call=let_raise_handler, handle_tool_errors=False)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
with pytest.raises(ValueError, match=r"Error with value: 5"):
|
||||
tool_node.invoke(
|
||||
{
|
||||
"messages": [
|
||||
@@ -185,15 +187,13 @@ def test_on_tool_call_let_error_raise():
|
||||
}
|
||||
)
|
||||
|
||||
assert "Error with value: 5" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_on_tool_call_with_handled_errors():
|
||||
def test_on_tool_call_with_handled_errors() -> None:
|
||||
"""Test interaction between on_tool_call and handle_tool_errors."""
|
||||
call_count = {"count": 0}
|
||||
|
||||
def counting_handler(
|
||||
request: ToolCallRequest, state, runtime
|
||||
request: ToolCallRequest, _state: Any, _runtime: Any
|
||||
) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]:
|
||||
"""Count how many times we're called."""
|
||||
call_count["count"] += 1
|
||||
@@ -221,19 +221,19 @@ def test_on_tool_call_with_handled_errors():
|
||||
assert "Please fix your mistakes" in tool_message.content
|
||||
|
||||
|
||||
def test_on_tool_call_must_return_value():
|
||||
def test_on_tool_call_must_return_value() -> None:
|
||||
"""Test that handler must return a ToolCallResponse."""
|
||||
|
||||
def no_return_handler(
|
||||
request: ToolCallRequest, state, runtime
|
||||
request: ToolCallRequest, _state: Any, _runtime: Any
|
||||
) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]:
|
||||
"""Handler that doesn't return anything."""
|
||||
response = yield request
|
||||
_ = yield request
|
||||
# Implicit return None
|
||||
|
||||
tool_node = ToolNode([success_tool], on_tool_call=no_return_handler)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
with pytest.raises(ValueError, match=r"must explicitly return a ToolCallResponse"):
|
||||
tool_node.invoke(
|
||||
{
|
||||
"messages": [
|
||||
@@ -245,14 +245,12 @@ def test_on_tool_call_must_return_value():
|
||||
}
|
||||
)
|
||||
|
||||
assert "must explicitly return a ToolCallResponse" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_on_tool_call_request_modification():
|
||||
def test_on_tool_call_request_modification() -> None:
|
||||
"""Test that handler can modify the request before execution."""
|
||||
|
||||
def double_input_handler(
|
||||
request: ToolCallRequest, state, runtime
|
||||
request: ToolCallRequest, _state: Any, _runtime: Any
|
||||
) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]:
|
||||
"""Double the input value."""
|
||||
# Modify the tool call args
|
||||
@@ -286,17 +284,15 @@ def test_on_tool_call_request_modification():
|
||||
assert tool_message.content == "20"
|
||||
|
||||
|
||||
def test_on_tool_call_response_validation():
|
||||
def test_on_tool_call_response_validation() -> None:
|
||||
"""Test that ToolCallResponse validates action and required fields."""
|
||||
# Test action="return" requires result
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
with pytest.raises(ValueError, match=r"action='return' requires a result"):
|
||||
ToolCallResponse(action="return")
|
||||
assert "action='return' requires a result" in str(exc_info.value)
|
||||
|
||||
# Test action="raise" requires exception
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
with pytest.raises(ValueError, match=r"action='raise' requires an exception"):
|
||||
ToolCallResponse(action="raise")
|
||||
assert "action='raise' requires an exception" in str(exc_info.value)
|
||||
|
||||
# Valid responses should work
|
||||
ToolCallResponse(
|
||||
@@ -306,7 +302,7 @@ def test_on_tool_call_response_validation():
|
||||
ToolCallResponse(action="raise", exception=ValueError("test"))
|
||||
|
||||
|
||||
def test_on_tool_call_without_handler_backward_compat():
|
||||
def test_on_tool_call_without_handler_backward_compat() -> None:
|
||||
"""Test that tools work without on_tool_call handler (backward compatibility)."""
|
||||
# Success case
|
||||
tool_node = ToolNode([success_tool])
|
||||
@@ -324,7 +320,7 @@ def test_on_tool_call_without_handler_backward_compat():
|
||||
|
||||
# Error case with handle_tool_errors=False
|
||||
tool_node_error = ToolNode([error_tool], handle_tool_errors=False)
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(ValueError, match=r"Error with value: 5"):
|
||||
tool_node_error.invoke(
|
||||
{
|
||||
"messages": [
|
||||
@@ -351,12 +347,12 @@ def test_on_tool_call_without_handler_backward_compat():
|
||||
assert result["messages"][0].status == "error"
|
||||
|
||||
|
||||
def test_on_tool_call_multiple_yields():
|
||||
def test_on_tool_call_multiple_yields() -> None:
|
||||
"""Test that handler can yield multiple times for retries."""
|
||||
attempts = {"count": 0}
|
||||
|
||||
def multi_yield_handler(
|
||||
request: ToolCallRequest, state, runtime
|
||||
request: ToolCallRequest, _state: Any, _runtime: Any
|
||||
) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]:
|
||||
"""Yield multiple times to track attempts."""
|
||||
max_attempts = 3
|
||||
@@ -373,7 +369,7 @@ def test_on_tool_call_multiple_yields():
|
||||
|
||||
tool_node = ToolNode([error_tool], on_tool_call=multi_yield_handler, handle_tool_errors=False)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(ValueError, match=r"Error with value: 5"):
|
||||
tool_node.invoke(
|
||||
{
|
||||
"messages": [
|
||||
|
||||
Reference in New Issue
Block a user