Compare commits

...

1 Commits

Author SHA1 Message Date
Sydney Runkle
7230be12bb cc pass at fixing tool message bug 2025-10-20 13:14:45 -04:00
2 changed files with 180 additions and 2 deletions

View File

@@ -748,7 +748,9 @@ class _ToolNode(RunnableCallable):
try:
response = tool.invoke(call_args, config)
except ValidationError as exc:
raise ToolInvocationError(call["name"], exc, call["args"]) from exc
# Filter out injected arguments before including in error message
filtered_args = self._filter_injected_args(call["name"], call["args"])
raise ToolInvocationError(call["name"], exc, filtered_args) from exc
# GraphInterrupt is a special exception that will always be raised.
# It can be triggered in the following scenarios,
@@ -893,7 +895,9 @@ class _ToolNode(RunnableCallable):
try:
response = await tool.ainvoke(call_args, config)
except ValidationError as exc:
raise ToolInvocationError(call["name"], exc, call["args"]) from exc
# Filter out injected arguments before including in error message
filtered_args = self._filter_injected_args(call["name"], call["args"])
raise ToolInvocationError(call["name"], exc, filtered_args) from exc
# GraphInterrupt is a special exception that will always be raised.
# It can be triggered in the following scenarios,
@@ -1197,6 +1201,42 @@ class _ToolNode(RunnableCallable):
tool_call_with_store = self._inject_store(tool_call_with_state, tool_runtime.store)
return self._inject_runtime(tool_call_with_store, tool_runtime)
def _filter_injected_args(self, tool_name: str, tool_args: dict[str, Any]) -> dict[str, Any]:
"""Filter out injected arguments from tool arguments.
When tool invocation fails, we want to show only the model-provided arguments
in error messages, not the runtime-injected arguments (state, store, runtime).
Args:
tool_name: The name of the tool.
tool_args: The tool arguments dictionary (may include injected args).
Returns:
A new dictionary with injected arguments removed.
"""
if tool_name not in self.tools_by_name:
return tool_args
# Collect all injected argument names
injected_arg_names: set[str] = set()
# Add state argument names
state_args = self._tool_to_state_args.get(tool_name, {})
injected_arg_names.update(state_args.keys())
# Add store argument name
store_arg = self._tool_to_store_arg.get(tool_name)
if store_arg:
injected_arg_names.add(store_arg)
# Add runtime argument name
runtime_arg = self._tool_to_runtime_arg.get(tool_name)
if runtime_arg:
injected_arg_names.add(runtime_arg)
# Filter out injected arguments
return {k: v for k, v in tool_args.items() if k not in injected_arg_names}
def _validate_tool_command(
self,
command: Command,

View File

@@ -37,6 +37,7 @@ from typing_extensions import TypedDict
from langchain.tools import (
InjectedState,
InjectedStore,
ToolRuntime,
)
from langchain.tools.tool_node import _ToolNode
from langchain.tools.tool_node import TOOL_CALL_ERROR_TEMPLATE, ToolInvocationError, tools_condition
@@ -1548,3 +1549,140 @@ def test_tool_node_stream_writer() -> None:
},
),
]
def test_tool_invocation_error_excludes_injected_args() -> None:
"""Test that ToolInvocationError excludes injected arguments from error messages.
When a tool with injected arguments (state, store, runtime) fails validation,
the error message should only show the model-provided arguments, not the
injected runtime context.
"""
store = InMemoryStore()
# Create a tool with multiple injected arguments
@dec_tool
def tool_with_injections(
x: int, # Model-provided arg
state: Annotated[dict[str, Any], InjectedState], # Injected state
store_arg: Annotated[BaseStore, InjectedStore], # Injected store
runtime: ToolRuntime, # Injected runtime
) -> str:
"""Tool that requires injected arguments."""
# This validation error will be raised if x is wrong type
return f"x={x}"
node = _ToolNode([tool_with_injections], handle_tool_errors=True)
# Create a tool call with invalid argument (string instead of int)
# This will cause a ValidationError
tool_call = ToolCall(
name="tool_with_injections",
args={"x": "not_an_int"}, # Invalid: should be int
id="test_call_123",
)
msg = AIMessage("test", tool_calls=[tool_call])
# Invoke the node - should handle the validation error
result = node.invoke({"messages": [msg]}, config=_create_config_with_runtime(store=store))
# Get the error message from the tool message
tool_message = result["messages"][-1]
assert isinstance(tool_message, ToolMessage)
assert tool_message.status == "error"
error_content = tool_message.content
# Verify the error message contains the model-provided argument
assert "x" in error_content
assert "not_an_int" in error_content
# Verify the error message does NOT contain injected arguments or their values
# These should be filtered out
assert "state" not in error_content.lower() or "state=" not in error_content.lower()
assert "store_arg" not in error_content.lower()
assert "runtime" not in error_content.lower()
assert "ToolRuntime" not in error_content
assert "InMemoryStore" not in error_content
assert "BaseStore" not in error_content
# The error should mention the tool name
assert "tool_with_injections" in error_content
async def test_tool_invocation_error_excludes_injected_args_async() -> None:
"""Test async version: ToolInvocationError excludes injected arguments.
Similar to the sync test, but verifies the async execution path also
properly filters injected arguments from error messages.
"""
@dec_tool
async def async_tool_with_runtime(
value: str, # Model-provided arg
runtime: ToolRuntime, # Injected runtime
) -> str:
"""Async tool with runtime injection."""
return f"value={value}"
node = _ToolNode([async_tool_with_runtime], handle_tool_errors=True)
# Create a tool call with invalid argument type
tool_call = ToolCall(
name="async_tool_with_runtime",
args={"value": 123}, # Invalid: should be string, not int
id="async_call_456",
)
msg = AIMessage("test", tool_calls=[tool_call])
# Invoke async
result = await node.ainvoke({"messages": [msg]}, config=_create_config_with_runtime())
# Get the error message
tool_message = result["messages"][-1]
assert isinstance(tool_message, ToolMessage)
assert tool_message.status == "error"
error_content = tool_message.content
# Verify error contains the model-provided argument
assert "value" in error_content
# Verify error does NOT contain injected runtime
assert "runtime" not in error_content.lower() or "runtime=" not in error_content.lower()
assert "ToolRuntime" not in error_content
def test_tool_invocation_error_with_no_injections() -> None:
"""Test that tools without injections still show all args in error messages."""
@dec_tool
def simple_tool(a: int, b: int) -> int:
"""Simple tool with no injections."""
return a + b
node = _ToolNode([simple_tool], handle_tool_errors=True)
# Create invalid call
tool_call = ToolCall(
name="simple_tool",
args={"a": "invalid", "b": 2}, # 'a' should be int
id="simple_call",
)
msg = AIMessage("test", tool_calls=[tool_call])
result = node.invoke({"messages": [msg]}, config=_create_config_with_runtime())
tool_message = result["messages"][-1]
assert isinstance(tool_message, ToolMessage)
assert tool_message.status == "error"
error_content = tool_message.content
# Both args should be in the error since neither is injected
assert "a" in error_content
assert "invalid" in error_content
assert "b" in error_content