diff --git a/libs/langchain_v1/langchain/agents/middleware/tool_error_handling.py b/libs/langchain_v1/langchain/agents/middleware/tool_error_handling.py index 37441d9be59..f165d499385 100644 --- a/libs/langchain_v1/langchain/agents/middleware/tool_error_handling.py +++ b/libs/langchain_v1/langchain/agents/middleware/tool_error_handling.py @@ -9,7 +9,7 @@ from __future__ import annotations import inspect import logging import time -from collections.abc import Callable +from types import UnionType from typing import TYPE_CHECKING, Union, get_args, get_origin, get_type_hints from langchain_core.messages import ToolMessage @@ -17,13 +17,10 @@ from langchain_core.messages import ToolMessage from langchain.agents.middleware.types import AgentMiddleware # Import ToolCallResponse locally to avoid circular import -from langchain.tools.tool_node import ToolCallResponse +from langchain.tools.tool_node import ToolCallRequest, ToolCallResponse if TYPE_CHECKING: - from collections.abc import Generator - from types import UnionType - - from langchain.tools.tool_node import ToolCallRequest, ToolCallResponse + from collections.abc import Callable, Generator logger = logging.getLogger(__name__) @@ -72,13 +69,14 @@ def _infer_retriable_types( if first_param.name in type_hints: origin = get_origin(first_param.annotation) # Handle Union types - if origin in [Union, UnionType]: # type: ignore[has-type] + if origin in [Union, UnionType]: args = get_args(first_param.annotation) if all(isinstance(arg, type) and issubclass(arg, Exception) for arg in args): return tuple(args) msg = ( "All types in retry predicate annotation must be Exception types. " - "For example, `def should_retry(e: Union[TimeoutError, ConnectionError]) -> bool`. " + "For example, `def should_retry(e: Union[TimeoutError, " + "ConnectionError]) -> bool`. " f"Got '{first_param.annotation}' instead." ) raise ValueError(msg) @@ -260,14 +258,13 @@ class RetryMiddleware(AgentMiddleware): return response # If predicate is provided, check if we should retry - if self._retry_predicate is not None: - if not self._retry_predicate(exception): - logger.debug( - "Retry predicate returned False for %s in tool %s", - type(exception).__name__, - request.tool_call["name"], - ) - return response + if self._retry_predicate is not None and not self._retry_predicate(exception): + logger.debug( + "Retry predicate returned False for %s in tool %s", + type(exception).__name__, + request.tool_call["name"], + ) + return response # Last attempt - return error if attempt > self.max_retries: diff --git a/libs/langchain_v1/langchain/tools/tool_node.py b/libs/langchain_v1/langchain/tools/tool_node.py index 56d29fc95ce..6be59457ecc 100644 --- a/libs/langchain_v1/langchain/tools/tool_node.py +++ b/libs/langchain_v1/langchain/tools/tool_node.py @@ -152,7 +152,8 @@ ToolCallHandler = Callable[ """Generator-based handler that intercepts tool execution. Receives a ToolCallRequest, state, and runtime; yields modified ToolCallRequests; -receives ToolCallResponses; and returns a final ToolCallResponse. Supports multiple yields for retry logic. +receives ToolCallResponses; and returns a final ToolCallResponse. Supports multiple +yields for retry logic. """ @@ -511,7 +512,7 @@ class ToolNode(RunnableCallable): input: list[AnyMessage] | dict[str, Any] | BaseModel, config: RunnableConfig, *, - store: Optional[BaseStore], + store: BaseStore | None, ) -> Any: try: runtime = get_runtime() @@ -535,7 +536,7 @@ class ToolNode(RunnableCallable): input: list[AnyMessage] | dict[str, Any] | BaseModel, config: RunnableConfig, *, - store: Optional[BaseStore], + store: BaseStore | None, ) -> Any: try: runtime = get_runtime() diff --git a/libs/langchain_v1/tests/unit_tests/tools/test_on_tool_call.py b/libs/langchain_v1/tests/unit_tests/tools/test_on_tool_call.py index f844aff7352..5a21df58b40 100644 --- a/libs/langchain_v1/tests/unit_tests/tools/test_on_tool_call.py +++ b/libs/langchain_v1/tests/unit_tests/tools/test_on_tool_call.py @@ -1,6 +1,7 @@ """Tests for on_tool_call handler functionality.""" from collections.abc import Generator +from typing import Any import pytest from langchain_core.messages import AIMessage, ToolMessage @@ -41,7 +42,7 @@ def test_on_tool_call_passthrough() -> None: """Test that a simple passthrough handler works.""" def passthrough_handler( - request: ToolCallRequest, state, runtime + request: ToolCallRequest, _state: Any, _runtime: Any ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: """Simply pass through without modification.""" response = yield request @@ -65,14 +66,14 @@ def test_on_tool_call_passthrough() -> None: assert tool_message.status != "error" -def test_on_tool_call_retry_success(): +def test_on_tool_call_retry_success() -> None: """Test that retry handler can recover from transient errors.""" # Reset counter if hasattr(rate_limit_tool, "_call_count"): rate_limit_tool._call_count = 0 def retry_handler( - request: ToolCallRequest, state, runtime + request: ToolCallRequest, _state: Any, _runtime: Any ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: """Retry up to 3 times.""" max_retries = 3 @@ -97,7 +98,8 @@ def test_on_tool_call_retry_success(): status="error", ), ) - raise AssertionError("Unreachable code") + msg = "Unreachable code" + raise AssertionError(msg) tool_node = ToolNode([rate_limit_tool], on_tool_call=retry_handler, handle_tool_errors=False) result = tool_node.invoke( @@ -117,11 +119,11 @@ def test_on_tool_call_retry_success(): assert tool_message.status != "error" -def test_on_tool_call_convert_error_to_message(): +def test_on_tool_call_convert_error_to_message() -> None: """Test that handler can convert raised errors to error messages.""" def error_to_message_handler( - request: ToolCallRequest, state, runtime + request: ToolCallRequest, _state: Any, _runtime: Any ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: """Convert any error to a user-friendly message.""" response = yield request @@ -161,11 +163,11 @@ def test_on_tool_call_convert_error_to_message(): assert tool_message.status == "error" -def test_on_tool_call_let_error_raise(): +def test_on_tool_call_let_error_raise() -> None: """Test that handler can let errors propagate.""" def let_raise_handler( - request: ToolCallRequest, state, runtime + request: ToolCallRequest, _state: Any, _runtime: Any ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: """Just return the response as-is, letting errors raise.""" response = yield request @@ -173,7 +175,7 @@ def test_on_tool_call_let_error_raise(): tool_node = ToolNode([error_tool], on_tool_call=let_raise_handler, handle_tool_errors=False) - with pytest.raises(ValueError) as exc_info: + with pytest.raises(ValueError, match=r"Error with value: 5"): tool_node.invoke( { "messages": [ @@ -185,15 +187,13 @@ def test_on_tool_call_let_error_raise(): } ) - assert "Error with value: 5" in str(exc_info.value) - -def test_on_tool_call_with_handled_errors(): +def test_on_tool_call_with_handled_errors() -> None: """Test interaction between on_tool_call and handle_tool_errors.""" call_count = {"count": 0} def counting_handler( - request: ToolCallRequest, state, runtime + request: ToolCallRequest, _state: Any, _runtime: Any ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: """Count how many times we're called.""" call_count["count"] += 1 @@ -221,19 +221,19 @@ def test_on_tool_call_with_handled_errors(): assert "Please fix your mistakes" in tool_message.content -def test_on_tool_call_must_return_value(): +def test_on_tool_call_must_return_value() -> None: """Test that handler must return a ToolCallResponse.""" def no_return_handler( - request: ToolCallRequest, state, runtime + request: ToolCallRequest, _state: Any, _runtime: Any ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: """Handler that doesn't return anything.""" - response = yield request + _ = yield request # Implicit return None tool_node = ToolNode([success_tool], on_tool_call=no_return_handler) - with pytest.raises(ValueError) as exc_info: + with pytest.raises(ValueError, match=r"must explicitly return a ToolCallResponse"): tool_node.invoke( { "messages": [ @@ -245,14 +245,12 @@ def test_on_tool_call_must_return_value(): } ) - assert "must explicitly return a ToolCallResponse" in str(exc_info.value) - -def test_on_tool_call_request_modification(): +def test_on_tool_call_request_modification() -> None: """Test that handler can modify the request before execution.""" def double_input_handler( - request: ToolCallRequest, state, runtime + request: ToolCallRequest, _state: Any, _runtime: Any ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: """Double the input value.""" # Modify the tool call args @@ -286,17 +284,15 @@ def test_on_tool_call_request_modification(): assert tool_message.content == "20" -def test_on_tool_call_response_validation(): +def test_on_tool_call_response_validation() -> None: """Test that ToolCallResponse validates action and required fields.""" # Test action="return" requires result - with pytest.raises(ValueError) as exc_info: + with pytest.raises(ValueError, match=r"action='return' requires a result"): ToolCallResponse(action="return") - assert "action='return' requires a result" in str(exc_info.value) # Test action="raise" requires exception - with pytest.raises(ValueError) as exc_info: + with pytest.raises(ValueError, match=r"action='raise' requires an exception"): ToolCallResponse(action="raise") - assert "action='raise' requires an exception" in str(exc_info.value) # Valid responses should work ToolCallResponse( @@ -306,7 +302,7 @@ def test_on_tool_call_response_validation(): ToolCallResponse(action="raise", exception=ValueError("test")) -def test_on_tool_call_without_handler_backward_compat(): +def test_on_tool_call_without_handler_backward_compat() -> None: """Test that tools work without on_tool_call handler (backward compatibility).""" # Success case tool_node = ToolNode([success_tool]) @@ -324,7 +320,7 @@ def test_on_tool_call_without_handler_backward_compat(): # Error case with handle_tool_errors=False tool_node_error = ToolNode([error_tool], handle_tool_errors=False) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match=r"Error with value: 5"): tool_node_error.invoke( { "messages": [ @@ -351,12 +347,12 @@ def test_on_tool_call_without_handler_backward_compat(): assert result["messages"][0].status == "error" -def test_on_tool_call_multiple_yields(): +def test_on_tool_call_multiple_yields() -> None: """Test that handler can yield multiple times for retries.""" attempts = {"count": 0} def multi_yield_handler( - request: ToolCallRequest, state, runtime + request: ToolCallRequest, _state: Any, _runtime: Any ) -> Generator[ToolCallRequest, ToolCallResponse, ToolCallResponse]: """Yield multiple times to track attempts.""" max_attempts = 3 @@ -373,7 +369,7 @@ def test_on_tool_call_multiple_yields(): tool_node = ToolNode([error_tool], on_tool_call=multi_yield_handler, handle_tool_errors=False) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match=r"Error with value: 5"): tool_node.invoke( { "messages": [