diff --git a/libs/langchain_v1/langchain/agents/middleware/tool_emulator.py b/libs/langchain_v1/langchain/agents/middleware/tool_emulator.py index 3ba141badf4..1b6ee202167 100644 --- a/libs/langchain_v1/langchain/agents/middleware/tool_emulator.py +++ b/libs/langchain_v1/langchain/agents/middleware/tool_emulator.py @@ -123,7 +123,7 @@ class LLMToolEmulator(AgentMiddleware): # Extract tool information for emulation tool_args = request.tool_call["args"] - tool_description = request.tool.description + tool_description = request.tool.description if request.tool else "No description available" # Build prompt for emulator LLM prompt = ( @@ -175,7 +175,7 @@ class LLMToolEmulator(AgentMiddleware): # Extract tool information for emulation tool_args = request.tool_call["args"] - tool_description = request.tool.description + tool_description = request.tool.description if request.tool else "No description available" # Build prompt for emulator LLM prompt = ( diff --git a/libs/langchain_v1/langchain/agents/middleware/tool_retry.py b/libs/langchain_v1/langchain/agents/middleware/tool_retry.py index 146cf75625f..7fb319a46d0 100644 --- a/libs/langchain_v1/langchain/agents/middleware/tool_retry.py +++ b/libs/langchain_v1/langchain/agents/middleware/tool_retry.py @@ -299,7 +299,7 @@ class ToolRetryMiddleware(AgentMiddleware): Returns: ToolMessage or Command (the final result). """ - tool_name = request.tool.name + tool_name = request.tool.name if request.tool else request.tool_call["name"] # Check if retry should apply to this tool if not self._should_retry_tool(tool_name): @@ -348,7 +348,7 @@ class ToolRetryMiddleware(AgentMiddleware): Returns: ToolMessage or Command (the final result). """ - tool_name = request.tool.name + tool_name = request.tool.name if request.tool else request.tool_call["name"] # Check if retry should apply to this tool if not self._should_retry_tool(tool_name): diff --git a/libs/langchain_v1/langchain/tools/tool_node.py b/libs/langchain_v1/langchain/tools/tool_node.py index fd326c8c8e6..636dce084e0 100644 --- a/libs/langchain_v1/langchain/tools/tool_node.py +++ b/libs/langchain_v1/langchain/tools/tool_node.py @@ -121,13 +121,16 @@ class ToolCallRequest: Attributes: tool_call: Tool call dict with name, args, and id from model output. - tool: BaseTool instance to be invoked. + tool: BaseTool instance to be invoked, or None if tool is not + registered with the ToolNode. When tool is None, interceptors can + handle the request without validation. If the interceptor calls execute(), + validation will occur and raise an error for unregistered tools. state: Agent state (dict, list, or BaseModel). runtime: LangGraph runtime context (optional, None if outside graph). """ tool_call: ToolCall - tool: BaseTool + tool: BaseTool | None state: Any runtime: ToolRuntime @@ -728,6 +731,14 @@ class _ToolNode(RunnableCallable): call = request.tool_call tool = request.tool + # Validate tool exists when we actually need to execute it + if tool is None: + if invalid_tool_message := self._validate_tool_call(call): + return invalid_tool_message + # This should never happen if validation works correctly + msg = f"Tool {call['name']} is not registered with ToolNode" + raise TypeError(msg) + call_args = {**call, "type": "tool_call"} try: @@ -804,10 +815,9 @@ class _ToolNode(RunnableCallable): Returns: ToolMessage or Command. """ - if invalid_tool_message := self._validate_tool_call(call): - return invalid_tool_message - - tool = self.tools_by_name[call["name"]] + # Validation is deferred to _execute_tool_sync to allow interceptors + # to short-circuit requests for unregistered tools + tool = self.tools_by_name.get(call["name"]) # Create the tool request with state and runtime tool_request = ToolCallRequest( @@ -866,6 +876,14 @@ class _ToolNode(RunnableCallable): call = request.tool_call tool = request.tool + # Validate tool exists when we actually need to execute it + if tool is None: + if invalid_tool_message := self._validate_tool_call(call): + return invalid_tool_message + # This should never happen if validation works correctly + msg = f"Tool {call['name']} is not registered with ToolNode" + raise TypeError(msg) + call_args = {**call, "type": "tool_call"} try: @@ -942,10 +960,9 @@ class _ToolNode(RunnableCallable): Returns: ToolMessage or Command. """ - if invalid_tool_message := self._validate_tool_call(call): - return invalid_tool_message - - tool = self.tools_by_name[call["name"]] + # Validation is deferred to _execute_tool_async to allow interceptors + # to short-circuit requests for unregistered tools + tool = self.tools_by_name.get(call["name"]) # Create the tool request with state and runtime tool_request = ToolCallRequest( diff --git a/libs/langchain_v1/tests/unit_tests/agents/test_tool_node_interceptor_unregistered.py b/libs/langchain_v1/tests/unit_tests/agents/test_tool_node_interceptor_unregistered.py new file mode 100644 index 00000000000..bbf3060b3cb --- /dev/null +++ b/libs/langchain_v1/tests/unit_tests/agents/test_tool_node_interceptor_unregistered.py @@ -0,0 +1,571 @@ +"""Test tool node interceptor handling of unregistered tools.""" + +from collections.abc import Awaitable, Callable +from unittest.mock import Mock + +import pytest +from langchain_core.messages import AIMessage, ToolMessage +from langchain_core.runnables.config import RunnableConfig +from langchain_core.tools import tool as dec_tool +from langgraph.store.base import BaseStore +from langgraph.types import Command + +from langchain.tools.tool_node import ToolCallRequest, _ToolNode + +pytestmark = pytest.mark.anyio + + +def _create_mock_runtime(store: BaseStore | None = None) -> Mock: + """Create a mock Runtime object for testing ToolNode outside of graph context. + + This helper is needed because ToolNode._func expects a Runtime parameter + which is injected by RunnableCallable from config["configurable"]["__pregel_runtime"]. + When testing ToolNode directly (outside a graph), we need to provide this manually. + """ + mock_runtime = Mock() + mock_runtime.store = store + mock_runtime.context = None + mock_runtime.stream_writer = lambda *args, **kwargs: None + return mock_runtime + + +def _create_config_with_runtime(store: BaseStore | None = None) -> RunnableConfig: + """Create a RunnableConfig with mock Runtime for testing ToolNode. + + Returns: + RunnableConfig with __pregel_runtime in configurable dict. + """ + return {"configurable": {"__pregel_runtime": _create_mock_runtime(store)}} + + +@dec_tool +def registered_tool(x: int) -> str: + """A registered tool.""" + return f"Result: {x}" + + +def test_interceptor_can_handle_unregistered_tool_sync() -> None: + """Test that interceptor can handle requests for unregistered tools (sync).""" + + def interceptor( + request: ToolCallRequest, + execute: Callable[[ToolCallRequest], ToolMessage | Command], + ) -> ToolMessage | Command: + """Intercept and handle unregistered tools.""" + if request.tool_call["name"] == "unregistered_tool": + # Short-circuit without calling execute for unregistered tool + return ToolMessage( + content="Handled by interceptor", + tool_call_id=request.tool_call["id"], + name="unregistered_tool", + ) + # Pass through for registered tools + return execute(request) + + node = _ToolNode([registered_tool], wrap_tool_call=interceptor) + + # Test registered tool works normally + result = node.invoke( + [ + AIMessage( + "", + tool_calls=[ + { + "name": "registered_tool", + "args": {"x": 42}, + "id": "1", + "type": "tool_call", + } + ], + ) + ], + config=_create_config_with_runtime(), + ) + assert result[0].content == "Result: 42" + assert result[0].tool_call_id == "1" + + # Test unregistered tool is intercepted and handled + result = node.invoke( + [ + AIMessage( + "", + tool_calls=[ + { + "name": "unregistered_tool", + "args": {"x": 99}, + "id": "2", + "type": "tool_call", + } + ], + ) + ], + config=_create_config_with_runtime(), + ) + assert result[0].content == "Handled by interceptor" + assert result[0].tool_call_id == "2" + assert result[0].name == "unregistered_tool" + + +async def test_interceptor_can_handle_unregistered_tool_async() -> None: + """Test that interceptor can handle requests for unregistered tools (async).""" + + async def async_interceptor( + request: ToolCallRequest, + execute: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]], + ) -> ToolMessage | Command: + """Intercept and handle unregistered tools.""" + if request.tool_call["name"] == "unregistered_tool": + # Short-circuit without calling execute for unregistered tool + return ToolMessage( + content="Handled by async interceptor", + tool_call_id=request.tool_call["id"], + name="unregistered_tool", + ) + # Pass through for registered tools + return await execute(request) + + node = _ToolNode([registered_tool], awrap_tool_call=async_interceptor) + + # Test registered tool works normally + result = await node.ainvoke( + [ + AIMessage( + "", + tool_calls=[ + { + "name": "registered_tool", + "args": {"x": 42}, + "id": "1", + "type": "tool_call", + } + ], + ) + ], + config=_create_config_with_runtime(), + ) + assert result[0].content == "Result: 42" + assert result[0].tool_call_id == "1" + + # Test unregistered tool is intercepted and handled + result = await node.ainvoke( + [ + AIMessage( + "", + tool_calls=[ + { + "name": "unregistered_tool", + "args": {"x": 99}, + "id": "2", + "type": "tool_call", + } + ], + ) + ], + config=_create_config_with_runtime(), + ) + assert result[0].content == "Handled by async interceptor" + assert result[0].tool_call_id == "2" + assert result[0].name == "unregistered_tool" + + +def test_unregistered_tool_error_when_interceptor_calls_execute() -> None: + """Test that unregistered tools error if interceptor tries to execute them.""" + + def bad_interceptor( + request: ToolCallRequest, + execute: Callable[[ToolCallRequest], ToolMessage | Command], + ) -> ToolMessage | Command: + """Interceptor that tries to execute unregistered tool.""" + # This should fail validation when execute is called + return execute(request) + + node = _ToolNode([registered_tool], wrap_tool_call=bad_interceptor) + + # Registered tool should still work + result = node.invoke( + [ + AIMessage( + "", + tool_calls=[ + { + "name": "registered_tool", + "args": {"x": 42}, + "id": "1", + "type": "tool_call", + } + ], + ) + ], + config=_create_config_with_runtime(), + ) + assert result[0].content == "Result: 42" + + # Unregistered tool should error when interceptor calls execute + result = node.invoke( + [ + AIMessage( + "", + tool_calls=[ + { + "name": "unregistered_tool", + "args": {"x": 99}, + "id": "2", + "type": "tool_call", + } + ], + ) + ], + config=_create_config_with_runtime(), + ) + # Should get validation error message + assert result[0].status == "error" + assert "is not a valid tool" in result[0].content + assert result[0].tool_call_id == "2" + + +def test_interceptor_handles_mix_of_registered_and_unregistered() -> None: + """Test interceptor handling mix of registered and unregistered tools.""" + + def selective_interceptor( + request: ToolCallRequest, + execute: Callable[[ToolCallRequest], ToolMessage | Command], + ) -> ToolMessage | Command: + """Handle unregistered tools, pass through registered ones.""" + if request.tool_call["name"] == "magic_tool": + return ToolMessage( + content=f"Magic result: {request.tool_call['args'].get('value', 0) * 2}", + tool_call_id=request.tool_call["id"], + name="magic_tool", + ) + return execute(request) + + node = _ToolNode([registered_tool], wrap_tool_call=selective_interceptor) + + # Test multiple tool calls - mix of registered and unregistered + result = node.invoke( + [ + AIMessage( + "", + tool_calls=[ + { + "name": "registered_tool", + "args": {"x": 10}, + "id": "1", + "type": "tool_call", + }, + { + "name": "magic_tool", + "args": {"value": 5}, + "id": "2", + "type": "tool_call", + }, + { + "name": "registered_tool", + "args": {"x": 20}, + "id": "3", + "type": "tool_call", + }, + ], + ) + ], + config=_create_config_with_runtime(), + ) + + # All tools should execute successfully + assert len(result) == 3 + assert result[0].content == "Result: 10" + assert result[0].tool_call_id == "1" + assert result[1].content == "Magic result: 10" + assert result[1].tool_call_id == "2" + assert result[2].content == "Result: 20" + assert result[2].tool_call_id == "3" + + +def test_interceptor_command_for_unregistered_tool() -> None: + """Test interceptor returning Command for unregistered tool.""" + + def command_interceptor( + request: ToolCallRequest, + execute: Callable[[ToolCallRequest], ToolMessage | Command], + ) -> ToolMessage | Command: + """Return Command for unregistered tools.""" + if request.tool_call["name"] == "routing_tool": + return Command( + update=[ + ToolMessage( + content="Routing to special handler", + tool_call_id=request.tool_call["id"], + name="routing_tool", + ) + ], + goto="special_node", + ) + return execute(request) + + node = _ToolNode([registered_tool], wrap_tool_call=command_interceptor) + + result = node.invoke( + [ + AIMessage( + "", + tool_calls=[ + { + "name": "routing_tool", + "args": {}, + "id": "1", + "type": "tool_call", + } + ], + ) + ], + config=_create_config_with_runtime(), + ) + + # Should get Command back + assert len(result) == 1 + assert isinstance(result[0], Command) + assert result[0].goto == "special_node" + assert result[0].update is not None + assert len(result[0].update) == 1 + assert result[0].update[0].content == "Routing to special handler" + + +def test_interceptor_exception_with_unregistered_tool() -> None: + """Test that interceptor exceptions are caught by error handling.""" + + def failing_interceptor( + request: ToolCallRequest, + execute: Callable[[ToolCallRequest], ToolMessage | Command], + ) -> ToolMessage | Command: + """Interceptor that throws exception for unregistered tools.""" + if request.tool_call["name"] == "bad_tool": + msg = "Interceptor failed" + raise ValueError(msg) + return execute(request) + + node = _ToolNode([registered_tool], wrap_tool_call=failing_interceptor, handle_tool_errors=True) + + # Interceptor exception should be caught and converted to error message + result = node.invoke( + [ + AIMessage( + "", + tool_calls=[ + { + "name": "bad_tool", + "args": {}, + "id": "1", + "type": "tool_call", + } + ], + ) + ], + config=_create_config_with_runtime(), + ) + + assert len(result) == 1 + assert result[0].status == "error" + assert "Interceptor failed" in result[0].content + assert result[0].tool_call_id == "1" + + # Test that exception is raised when handle_tool_errors is False + node_no_handling = _ToolNode( + [registered_tool], wrap_tool_call=failing_interceptor, handle_tool_errors=False + ) + + with pytest.raises(ValueError, match="Interceptor failed"): + node_no_handling.invoke( + [ + AIMessage( + "", + tool_calls=[ + { + "name": "bad_tool", + "args": {}, + "id": "2", + "type": "tool_call", + } + ], + ) + ], + config=_create_config_with_runtime(), + ) + + +async def test_async_interceptor_exception_with_unregistered_tool() -> None: + """Test that async interceptor exceptions are caught by error handling.""" + + async def failing_async_interceptor( + request: ToolCallRequest, + execute: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]], + ) -> ToolMessage | Command: + """Async interceptor that throws exception for unregistered tools.""" + if request.tool_call["name"] == "bad_async_tool": + msg = "Async interceptor failed" + raise RuntimeError(msg) + return await execute(request) + + node = _ToolNode( + [registered_tool], awrap_tool_call=failing_async_interceptor, handle_tool_errors=True + ) + + # Interceptor exception should be caught and converted to error message + result = await node.ainvoke( + [ + AIMessage( + "", + tool_calls=[ + { + "name": "bad_async_tool", + "args": {}, + "id": "1", + "type": "tool_call", + } + ], + ) + ], + config=_create_config_with_runtime(), + ) + + assert len(result) == 1 + assert result[0].status == "error" + assert "Async interceptor failed" in result[0].content + assert result[0].tool_call_id == "1" + + # Test that exception is raised when handle_tool_errors is False + node_no_handling = _ToolNode( + [registered_tool], awrap_tool_call=failing_async_interceptor, handle_tool_errors=False + ) + + with pytest.raises(RuntimeError, match="Async interceptor failed"): + await node_no_handling.ainvoke( + [ + AIMessage( + "", + tool_calls=[ + { + "name": "bad_async_tool", + "args": {}, + "id": "2", + "type": "tool_call", + } + ], + ) + ], + config=_create_config_with_runtime(), + ) + + +def test_interceptor_with_dict_input_format() -> None: + """Test that interceptor works with dict input format.""" + + def interceptor( + request: ToolCallRequest, + execute: Callable[[ToolCallRequest], ToolMessage | Command], + ) -> ToolMessage | Command: + """Intercept unregistered tools with dict input.""" + if request.tool_call["name"] == "dict_tool": + return ToolMessage( + content="Handled dict input", + tool_call_id=request.tool_call["id"], + name="dict_tool", + ) + return execute(request) + + node = _ToolNode([registered_tool], wrap_tool_call=interceptor) + + # Test with dict input format + result = node.invoke( + { + "messages": [ + AIMessage( + "", + tool_calls=[ + { + "name": "dict_tool", + "args": {"value": 5}, + "id": "1", + "type": "tool_call", + } + ], + ) + ] + }, + config=_create_config_with_runtime(), + ) + + # Should return dict format output + assert isinstance(result, dict) + assert "messages" in result + assert len(result["messages"]) == 1 + assert result["messages"][0].content == "Handled dict input" + assert result["messages"][0].tool_call_id == "1" + + +def test_interceptor_verifies_tool_is_none_for_unregistered() -> None: + """Test that request.tool is None for unregistered tools.""" + + captured_requests: list[ToolCallRequest] = [] + + def capturing_interceptor( + request: ToolCallRequest, + execute: Callable[[ToolCallRequest], ToolMessage | Command], + ) -> ToolMessage | Command: + """Capture request to verify tool field.""" + captured_requests.append(request) + if request.tool is None: + # Tool is unregistered + return ToolMessage( + content=f"Unregistered: {request.tool_call['name']}", + tool_call_id=request.tool_call["id"], + name=request.tool_call["name"], + ) + # Tool is registered + return execute(request) + + node = _ToolNode([registered_tool], wrap_tool_call=capturing_interceptor) + + # Test unregistered tool + node.invoke( + [ + AIMessage( + "", + tool_calls=[ + { + "name": "unknown_tool", + "args": {}, + "id": "1", + "type": "tool_call", + } + ], + ) + ], + config=_create_config_with_runtime(), + ) + + assert len(captured_requests) == 1 + assert captured_requests[0].tool is None + assert captured_requests[0].tool_call["name"] == "unknown_tool" + + # Clear and test registered tool + captured_requests.clear() + node.invoke( + [ + AIMessage( + "", + tool_calls=[ + { + "name": "registered_tool", + "args": {"x": 10}, + "id": "2", + "type": "tool_call", + } + ], + ) + ], + config=_create_config_with_runtime(), + ) + + assert len(captured_requests) == 1 + assert captured_requests[0].tool is not None + assert captured_requests[0].tool.name == "registered_tool"