mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-19 21:35:33 +00:00
Compare commits
12 Commits
bot/refres
...
sr/do-not-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
838c383074 | ||
|
|
c534c365a0 | ||
|
|
ef1f031801 | ||
|
|
0380e50d31 | ||
|
|
6829386f50 | ||
|
|
5067be1760 | ||
|
|
d2d8588ccc | ||
|
|
7341885289 | ||
|
|
cbd75b04d8 | ||
|
|
e0372c14c0 | ||
|
|
35081b1eac | ||
|
|
3f64f5faf5 |
@@ -641,44 +641,55 @@ class ChildTool(BaseTool):
|
||||
if input_args is not None:
|
||||
if isinstance(input_args, dict):
|
||||
return tool_input
|
||||
|
||||
# Get the function signature to identify injected args
|
||||
sig = inspect.signature(self._run)
|
||||
injected_args = []
|
||||
injected_tool_call_id_param = None
|
||||
for param_name, param in sig.parameters.items():
|
||||
if param.annotation != inspect.Parameter.empty:
|
||||
if _is_injected_arg_type(param.annotation, injected_type=InjectedToolCallId):
|
||||
injected_tool_call_id_param = param_name
|
||||
elif _is_injected_arg_type(param.annotation):
|
||||
injected_args.append(param_name)
|
||||
|
||||
# Check if InjectedToolCallId is required
|
||||
if injected_tool_call_id_param and tool_call_id is None:
|
||||
msg = (
|
||||
"When tool includes an InjectedToolCallId "
|
||||
"argument, tool must always be invoked with a full "
|
||||
"model ToolCall of the form: {'args': {...}, "
|
||||
"'name': '...', 'type': 'tool_call', "
|
||||
"'tool_call_id': '...'}"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
# Filter out injected args from tool_input for validation
|
||||
# Keep a reference to the full tool_input for return value filtering
|
||||
validation_input = {
|
||||
k: v for k, v in tool_input.items() if k not in injected_args
|
||||
}
|
||||
|
||||
if issubclass(input_args, BaseModel):
|
||||
for k, v in get_all_basemodel_annotations(input_args).items():
|
||||
if _is_injected_arg_type(v, injected_type=InjectedToolCallId):
|
||||
if tool_call_id is None:
|
||||
msg = (
|
||||
"When tool includes an InjectedToolCallId "
|
||||
"argument, tool must always be invoked with a full "
|
||||
"model ToolCall of the form: {'args': {...}, "
|
||||
"'name': '...', 'type': 'tool_call', "
|
||||
"'tool_call_id': '...'}"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
tool_input[k] = tool_call_id
|
||||
result = input_args.model_validate(tool_input)
|
||||
result = input_args.model_validate(validation_input)
|
||||
result_dict = result.model_dump()
|
||||
elif issubclass(input_args, BaseModelV1):
|
||||
for k, v in get_all_basemodel_annotations(input_args).items():
|
||||
if _is_injected_arg_type(v, injected_type=InjectedToolCallId):
|
||||
if tool_call_id is None:
|
||||
msg = (
|
||||
"When tool includes an InjectedToolCallId "
|
||||
"argument, tool must always be invoked with a full "
|
||||
"model ToolCall of the form: {'args': {...}, "
|
||||
"'name': '...', 'type': 'tool_call', "
|
||||
"'tool_call_id': '...'}"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
tool_input[k] = tool_call_id
|
||||
result = input_args.parse_obj(tool_input)
|
||||
result = input_args.parse_obj(validation_input)
|
||||
result_dict = result.dict()
|
||||
else:
|
||||
msg = (
|
||||
f"args_schema must be a Pydantic BaseModel, got {self.args_schema}"
|
||||
)
|
||||
raise NotImplementedError(msg)
|
||||
return {
|
||||
|
||||
# Build the final result, including validated args and injected tool_call_id
|
||||
final_result = {
|
||||
k: getattr(result, k) for k, v in result_dict.items() if k in tool_input
|
||||
}
|
||||
if injected_tool_call_id_param:
|
||||
final_result[injected_tool_call_id_param] = tool_call_id
|
||||
|
||||
return final_result
|
||||
return tool_input
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@@ -196,6 +196,7 @@ class StructuredTool(BaseTool):
|
||||
parse_docstring=parse_docstring,
|
||||
error_on_invalid_docstring=error_on_invalid_docstring,
|
||||
filter_args=_filter_schema_args(source_function),
|
||||
include_injected=False,
|
||||
)
|
||||
description_ = description
|
||||
if description is None and not parse_docstring:
|
||||
|
||||
@@ -312,8 +312,26 @@ class ToolInvocationError(ToolException):
|
||||
source: The exception that occurred.
|
||||
tool_kwargs: The keyword arguments that were passed to the tool.
|
||||
"""
|
||||
# Format the ValidationError without input values to avoid exposing
|
||||
# injected arguments in error messages
|
||||
error_count = source.error_count()
|
||||
errors_list = source.errors(include_input=False)
|
||||
|
||||
# Store the filtered errors for programmatic access
|
||||
self.filtered_errors = errors_list
|
||||
|
||||
error_str = f"{error_count} validation error{'s' if error_count > 1 else ''} for {tool_name}\n"
|
||||
for error in errors_list:
|
||||
loc = " -> ".join(str(loc_part) for loc_part in error["loc"])
|
||||
error_str += f"{loc}\n {error['msg']}"
|
||||
if error.get("type"):
|
||||
error_str += f" [type={error['type']}]"
|
||||
if error.get("url"):
|
||||
error_str += f"\n For further information visit {error['url']}"
|
||||
error_str += "\n"
|
||||
|
||||
self.message = TOOL_INVOCATION_ERROR_TEMPLATE.format(
|
||||
tool_name=tool_name, tool_kwargs=tool_kwargs, error=source
|
||||
tool_name=tool_name, tool_kwargs=tool_kwargs, error=error_str
|
||||
)
|
||||
self.tool_name = tool_name
|
||||
self.tool_kwargs = tool_kwargs
|
||||
@@ -623,17 +641,9 @@ class _ToolNode(RunnableCallable):
|
||||
)
|
||||
tool_runtimes.append(tool_runtime)
|
||||
|
||||
# Inject tool arguments (including runtime)
|
||||
|
||||
injected_tool_calls = []
|
||||
input_types = [input_type] * len(tool_calls)
|
||||
for call, tool_runtime in zip(tool_calls, tool_runtimes, strict=False):
|
||||
injected_call = self._inject_tool_args(call, tool_runtime) # type: ignore[arg-type]
|
||||
injected_tool_calls.append(injected_call)
|
||||
with get_executor_for_config(config) as executor:
|
||||
outputs = list(
|
||||
executor.map(self._run_one, injected_tool_calls, input_types, tool_runtimes)
|
||||
)
|
||||
outputs = list(executor.map(self._run_one, tool_calls, input_types, tool_runtimes))
|
||||
|
||||
return self._combine_tool_outputs(outputs, input_type)
|
||||
|
||||
@@ -660,12 +670,9 @@ class _ToolNode(RunnableCallable):
|
||||
)
|
||||
tool_runtimes.append(tool_runtime)
|
||||
|
||||
injected_tool_calls = []
|
||||
coros = []
|
||||
for call, tool_runtime in zip(tool_calls, tool_runtimes, strict=False):
|
||||
injected_call = self._inject_tool_args(call, tool_runtime) # type: ignore[arg-type]
|
||||
injected_tool_calls.append(injected_call)
|
||||
coros.append(self._arun_one(injected_call, input_type, tool_runtime)) # type: ignore[arg-type]
|
||||
coros.append(self._arun_one(call, input_type, tool_runtime)) # type: ignore[arg-type]
|
||||
outputs = await asyncio.gather(*coros)
|
||||
|
||||
return self._combine_tool_outputs(outputs, input_type)
|
||||
@@ -742,12 +749,15 @@ class _ToolNode(RunnableCallable):
|
||||
msg = f"Tool {call['name']} is not registered with ToolNode"
|
||||
raise TypeError(msg)
|
||||
|
||||
call_args = {**call, "type": "tool_call"}
|
||||
# Inject state, store, and runtime right before invocation
|
||||
injected_call = self._inject_tool_args(call, request.runtime)
|
||||
call_args = {**injected_call, "type": "tool_call"}
|
||||
|
||||
try:
|
||||
try:
|
||||
response = tool.invoke(call_args, config)
|
||||
except ValidationError as exc:
|
||||
# Use original call["args"] without injected values for error reporting
|
||||
raise ToolInvocationError(call["name"], exc, call["args"]) from exc
|
||||
|
||||
# GraphInterrupt is a special exception that will always be raised.
|
||||
@@ -887,12 +897,15 @@ class _ToolNode(RunnableCallable):
|
||||
msg = f"Tool {call['name']} is not registered with ToolNode"
|
||||
raise TypeError(msg)
|
||||
|
||||
call_args = {**call, "type": "tool_call"}
|
||||
# Inject state, store, and runtime right before invocation
|
||||
injected_call = self._inject_tool_args(call, request.runtime)
|
||||
call_args = {**injected_call, "type": "tool_call"}
|
||||
|
||||
try:
|
||||
try:
|
||||
response = await tool.ainvoke(call_args, config)
|
||||
except ValidationError as exc:
|
||||
# Use original call["args"] without injected values for error reporting
|
||||
raise ToolInvocationError(call["name"], exc, call["args"]) from exc
|
||||
|
||||
# GraphInterrupt is a special exception that will always be raised.
|
||||
|
||||
@@ -0,0 +1,671 @@
|
||||
"""Unit tests for ValidationError filtering in ToolNode.
|
||||
|
||||
This module tests that validation errors are filtered to only include arguments
|
||||
that the LLM controls. Injected arguments (InjectedState, InjectedStore,
|
||||
ToolRuntime) are automatically provided by the system and should not appear in
|
||||
validation error messages. This ensures the LLM receives focused, actionable
|
||||
feedback about the parameters it can actually control, improving error correction
|
||||
and reducing confusion from irrelevant system implementation details.
|
||||
"""
|
||||
|
||||
from typing import Annotated
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
from langchain_core.tools import tool as dec_tool
|
||||
from langgraph.store.base import BaseStore
|
||||
from langgraph.store.memory import InMemoryStore
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from langchain.agents import create_agent
|
||||
from langchain.agents.middleware.types import AgentState
|
||||
from langchain.tools import InjectedState, InjectedStore
|
||||
from langchain.tools.tool_node import ToolInvocationError, ToolRuntime, _ToolNode
|
||||
|
||||
from .model import FakeToolCallingModel
|
||||
|
||||
pytestmark = pytest.mark.anyio
|
||||
|
||||
|
||||
def _create_mock_runtime(store: BaseStore | None = None) -> Mock:
    """Build a stand-in Runtime so ToolNode can run outside a graph.

    Args:
        store: Optional store exposed as the mock's ``store`` attribute.

    Returns:
        A ``Mock`` whose ``store`` is the given value, whose ``context`` is
        ``None`` and whose ``stream_writer`` accepts any arguments and
        returns ``None``.
    """

    def _discard(*args, **kwargs):
        # Stream writes are irrelevant to these tests; swallow them.
        return None

    # Keyword arguments to Mock() are applied as attributes on the new mock.
    return Mock(store=store, context=None, stream_writer=_discard)
|
||||
|
||||
|
||||
def _create_config_with_runtime(store: BaseStore | None = None) -> RunnableConfig:
    """Wrap a mock Runtime in the config mapping passed to ToolNode.

    Args:
        store: Optional store forwarded to the mock runtime.

    Returns:
        A config dict carrying the mock runtime under
        ``configurable["__pregel_runtime"]``.
    """
    runtime = _create_mock_runtime(store)
    configurable = {"__pregel_runtime": runtime}
    return {"configurable": configurable}
|
||||
|
||||
|
||||
async def test_filter_injected_state_validation_errors() -> None:
    """An invalid LLM argument is reported without naming the injected state.

    The tool takes an ``InjectedState`` parameter next to an ``int`` argument.
    The call supplies a non-integer ``value``; the resulting error message
    must name ``value`` and must not contain ``state``.
    """

    @dec_tool
    def my_tool(
        value: int,
        state: Annotated[dict, InjectedState],
    ) -> str:
        """Tool that uses injected state.

        Args:
            value: An integer value.
            state: The graph state (injected).
        """
        return f"value={value}, messages={len(state.get('messages', []))}"

    node = _ToolNode([my_tool])

    # 'value' is declared as int, so a string makes validation fail.
    bad_call = {
        "name": "my_tool",
        "args": {"value": "not_an_int"},
        "id": "call_1",
        "type": "tool_call",
    }
    graph_input = {"messages": [AIMessage("hi?", tool_calls=[bad_call])]}
    output = await node.ainvoke(
        graph_input,
        config=_create_config_with_runtime(),
    )

    # Exactly one ToolMessage comes back, flagged as an error.
    replies = output["messages"]
    assert len(replies) == 1
    reply = replies[0]
    assert reply.status == "error"
    assert reply.tool_call_id == "call_1"

    # The LLM-controlled argument is named; the injected one is not.
    assert "value" in reply.content
    assert "state" not in reply.content.lower()
|
||||
|
||||
|
||||
async def test_filter_injected_store_validation_errors() -> None:
    """Test that validation errors for InjectedStore arguments are filtered out.

    InjectedStore parameters are not controlled by the LLM, so any validation
    errors related to them should not appear in error messages. The call omits
    the required ``key`` argument; the error must report the missing ``key``
    and must not mention ``store`` at all.
    """

    @dec_tool
    def my_tool(
        key: str,
        store: Annotated[BaseStore, InjectedStore()],
    ) -> str:
        """Tool that uses injected store.

        Args:
            key: A key to look up.
            store: The persistent store (injected).
        """
        return f"key={key}"

    tool_node = _ToolNode([my_tool])

    # Call with invalid 'key' argument (missing required argument)
    result = await tool_node.ainvoke(
        {
            "messages": [
                AIMessage(
                    "hi?",
                    tool_calls=[
                        {
                            "name": "my_tool",
                            "args": {},  # Missing 'key'
                            "id": "call_1",
                            "type": "tool_call",
                        }
                    ],
                )
            ]
        },
        config=_create_config_with_runtime(store=InMemoryStore()),
    )

    # Should get a ToolMessage with error
    assert len(result["messages"]) == 1
    tool_message = result["messages"][0]
    assert tool_message.status == "error"
    assert tool_message.tool_call_id == "call_1"

    content = tool_message.content.lower()

    # Error should mention 'key' is required
    assert "key" in content
    assert "field required" in content or "missing" in content

    # This is the check the test is named for: the injected store must not
    # leak into the message. Without it the test would pass even if store
    # filtering were broken.
    assert "store" not in content
|
||||
|
||||
|
||||
async def test_filter_tool_runtime_validation_errors() -> None:
    """An invalid LLM argument is reported without naming the ToolRuntime.

    The tool takes a ``ToolRuntime`` parameter next to a ``str`` argument.
    The call passes an integer for ``query``; the error message must mention
    ``query`` and must not contain ``runtime``.
    """

    @dec_tool
    def my_tool(
        query: str,
        runtime: ToolRuntime,
    ) -> str:
        """Tool that uses ToolRuntime.

        Args:
            query: A query string.
            runtime: The tool runtime context (injected).
        """
        return f"query={query}"

    node = _ToolNode([my_tool])

    # 'query' is declared as str, so an int makes validation fail.
    bad_call = {
        "name": "my_tool",
        "args": {"query": 123},
        "id": "call_1",
        "type": "tool_call",
    }
    graph_input = {"messages": [AIMessage("hi?", tool_calls=[bad_call])]}
    output = await node.ainvoke(
        graph_input,
        config=_create_config_with_runtime(),
    )

    # Exactly one ToolMessage comes back, flagged as an error.
    replies = output["messages"]
    assert len(replies) == 1
    reply = replies[0]
    assert reply.status == "error"

    # The LLM-controlled argument is named; the injected one is not.
    lowered = reply.content.lower()
    assert "query" in lowered
    assert "runtime" not in lowered
|
||||
|
||||
|
||||
async def test_filter_multiple_injected_args() -> None:
    """None of several injected parameters appears in the error message.

    The tool combines ``InjectedState``, ``InjectedStore`` and ``ToolRuntime``
    with a single ``int`` argument. The call passes a non-integer ``value``;
    the error must name ``value`` and contain none of ``state``, ``store`` or
    ``runtime``.
    """

    @dec_tool
    def my_tool(
        value: int,
        state: Annotated[dict, InjectedState],
        store: Annotated[BaseStore, InjectedStore()],
        runtime: ToolRuntime,
    ) -> str:
        """Tool with multiple injected arguments.

        Args:
            value: An integer value.
            state: The graph state (injected).
            store: The persistent store (injected).
            runtime: The tool runtime context (injected).
        """
        return f"value={value}"

    node = _ToolNode([my_tool])

    # Only 'value' is wrong; the injected arguments are supplied by the node.
    bad_call = {
        "name": "my_tool",
        "args": {"value": "not_an_int"},
        "id": "call_1",
        "type": "tool_call",
    }
    graph_input = {"messages": [AIMessage("hi?", tool_calls=[bad_call])]}
    output = await node.ainvoke(
        graph_input,
        config=_create_config_with_runtime(store=InMemoryStore()),
    )

    reply = output["messages"][0]
    assert reply.status == "error"

    # The LLM-controlled argument is reported ...
    assert "value" in reply.content
    # ... and every injected argument stays out of the message.
    lowered = reply.content.lower()
    assert "state" not in lowered
    assert "store" not in lowered
    assert "runtime" not in lowered
|
||||
|
||||
|
||||
async def test_no_filtering_when_all_errors_are_model_args() -> None:
    """Every failing LLM-controlled argument is still reported.

    The tool has two regular arguments and one ``InjectedState`` parameter.
    Both regular arguments are given the wrong type; the error must name both
    of them and must not contain ``state``.
    """

    @dec_tool
    def my_tool(
        value1: int,
        value2: str,
        state: Annotated[dict, InjectedState],
    ) -> str:
        """Tool with both regular and injected arguments.

        Args:
            value1: First value.
            value2: Second value.
            state: The graph state (injected).
        """
        return f"value1={value1}, value2={value2}"

    node = _ToolNode([my_tool])

    # Both non-injected parameters receive a value of the wrong type.
    bad_args = {
        "value1": "not_an_int",
        "value2": 456,
    }
    bad_call = {
        "name": "my_tool",
        "args": bad_args,
        "id": "call_1",
        "type": "tool_call",
    }
    graph_input = {"messages": [AIMessage("hi?", tool_calls=[bad_call])]}
    output = await node.ainvoke(
        graph_input,
        config=_create_config_with_runtime(),
    )

    reply = output["messages"][0]
    assert reply.status == "error"

    # Both LLM-controlled failures are present.
    for arg_name in ("value1", "value2"):
        assert arg_name in reply.content
    # The injected state is absent.
    assert "state" not in reply.content.lower()
|
||||
|
||||
|
||||
async def test_validation_error_with_no_injected_args() -> None:
    """A tool without injected parameters reports all of its failures.

    Both arguments of the tool are given the wrong type, and the error
    message must name each of them.
    """

    @dec_tool
    def my_tool(value1: int, value2: str) -> str:
        """Regular tool without injected arguments.

        Args:
            value1: First value.
            value2: Second value.
        """
        return f"{value1} {value2}"

    node = _ToolNode([my_tool])

    bad_call = {
        "name": "my_tool",
        "args": {"value1": "invalid", "value2": 123},
        "id": "call_1",
        "type": "tool_call",
    }
    graph_input = {"messages": [AIMessage("hi?", tool_calls=[bad_call])]}
    output = await node.ainvoke(
        graph_input,
        config=_create_config_with_runtime(),
    )

    reply = output["messages"][0]
    assert reply.status == "error"

    # Nothing is injected, so nothing is filtered: both names must appear.
    for arg_name in ("value1", "value2"):
        assert arg_name in reply.content
|
||||
|
||||
|
||||
async def test_tool_invocation_error_without_handle_errors() -> None:
    """The raised ToolInvocationError exposes only LLM-controlled failures.

    With ``handle_tool_errors=False`` the node raises instead of returning a
    ToolMessage. The exception's ``filtered_errors`` must contain an entry
    located at ``value`` and no entry located at ``state``.
    """

    @dec_tool
    def my_tool(
        value: int,
        state: Annotated[dict, InjectedState],
    ) -> str:
        """Tool with injected state.

        Args:
            value: An integer value.
            state: The graph state (injected).
        """
        return f"value={value}"

    node = _ToolNode([my_tool], handle_tool_errors=False)

    bad_call = {
        "name": "my_tool",
        "args": {"value": "not_an_int"},
        "id": "call_1",
        "type": "tool_call",
    }

    # Error handling is off, so the validation failure propagates.
    with pytest.raises(ToolInvocationError) as exc_info:
        await node.ainvoke(
            {"messages": [AIMessage("hi?", tool_calls=[bad_call])]},
            config=_create_config_with_runtime(),
        )

    raised = exc_info.value
    assert raised.tool_name == "my_tool"
    assert raised.filtered_errors is not None
    assert len(raised.filtered_errors) > 0

    # Compare on the stringified locations of each recorded error.
    locations = [str(entry["loc"]) for entry in raised.filtered_errors]
    assert any("value" in location for location in locations)
    assert all("state" not in location for location in locations)
|
||||
|
||||
|
||||
async def test_sync_tool_validation_error_filtering() -> None:
    """The synchronous ``invoke`` path filters injected arguments too.

    Same scenario as the async state test, driven through ``invoke``: the
    error must name ``value`` and must not contain ``state``.
    """

    @dec_tool
    def my_tool(
        value: int,
        state: Annotated[dict, InjectedState],
    ) -> str:
        """Sync tool with injected state.

        Args:
            value: An integer value.
            state: The graph state (injected).
        """
        return f"value={value}"

    node = _ToolNode([my_tool])

    bad_call = {
        "name": "my_tool",
        "args": {"value": "not_an_int"},
        "id": "call_1",
        "type": "tool_call",
    }
    graph_input = {"messages": [AIMessage("hi?", tool_calls=[bad_call])]}

    # Deliberately the blocking entry point, not ainvoke.
    output = node.invoke(
        graph_input,
        config=_create_config_with_runtime(),
    )

    reply = output["messages"][0]
    assert reply.status == "error"
    assert "value" in reply.content
    assert "state" not in reply.content.lower()
|
||||
|
||||
|
||||
async def test_create_agent_error_content_with_multiple_params() -> None:
    """Test that error messages only include LLM-controlled parameter errors.

    Uses create_agent to verify that when a tool with both LLM-controlled
    and system-injected parameters receives invalid arguments, the error message:
    1. Contains details about LLM-controlled parameter errors (query, limit)
    2. Does NOT contain system-injected parameter names (state, store, runtime)
    3. Does NOT contain values from system-injected parameters
    4. Properly formats the validation errors for LLM correction

    This ensures the LLM receives focused, actionable feedback.
    """

    class TestState(AgentState):
        user_id: str
        api_key: str
        session_data: dict

    @dec_tool
    def complex_tool(
        query: str,
        limit: int,
        state: Annotated[TestState, InjectedState],
        store: Annotated[BaseStore, InjectedStore()],
        runtime: ToolRuntime,
    ) -> str:
        """A complex tool with multiple injected and non-injected parameters.

        Args:
            query: The search query string.
            limit: Maximum number of results to return.
            state: The graph state (injected).
            store: The persistent store (injected).
            runtime: The tool runtime context (injected).
        """
        # Access injected params to verify they work in normal execution
        user = state.get("user_id", "unknown")
        return f"Results for '{query}' (limit={limit}, user={user})"

    # Create a model that makes an incorrect tool call with multiple errors:
    # - query is wrong type (int instead of str)
    # - limit is missing
    # Then returns no tool calls to end the loop
    model = FakeToolCallingModel(
        tool_calls=[
            [
                {
                    "name": "complex_tool",
                    "args": {
                        "query": 12345,  # Wrong type - should be str
                        # "limit" is missing - required field
                    },
                    "id": "call_complex_1",
                }
            ],
            [],  # No tool calls on second iteration to end the loop
        ]
    )

    # Create an agent with the complex tool and custom state
    # Need to provide a store since the tool uses InjectedStore
    agent = create_agent(
        model=model,
        tools=[complex_tool],
        state_schema=TestState,
        store=InMemoryStore(),
    )

    # Invoke with sensitive data in state
    result = agent.invoke(
        {
            "messages": [HumanMessage("Search for something")],
            "user_id": "user_12345",
            "api_key": "sk-secret-key-abc123xyz",
            "session_data": {"token": "secret_session_token"},
        }
    )

    # Find the tool error message
    tool_messages = [m for m in result["messages"] if m.type == "tool"]
    assert len(tool_messages) == 1
    tool_message = tool_messages[0]
    assert tool_message.status == "error"
    assert tool_message.tool_call_id == "call_complex_1"

    content = tool_message.content
    lowered = content.lower()

    # Verify error mentions LLM-controlled parameter issues
    assert "query" in lowered, "Error should mention 'query' (LLM-controlled)"
    assert "limit" in lowered, "Error should mention 'limit' (LLM-controlled)"

    # Should indicate validation errors occurred. The previous check also
    # accepted a bare "error", which matches almost any message; require the
    # "validation error" header that the formatted message starts with.
    assert "validation error" in lowered, "Error should indicate validation occurred"

    # Verify NO system-injected parameter names appear in error
    # These are not controlled by the LLM and should be excluded
    assert "state" not in lowered, "Error should NOT mention 'state' (system-injected)"
    assert "store" not in lowered, "Error should NOT mention 'store' (system-injected)"
    assert "runtime" not in lowered, "Error should NOT mention 'runtime' (system-injected)"

    # Verify NO values from system-injected parameters appear in error
    # The LLM doesn't control these, so they shouldn't distract from the actual issues
    assert "user_12345" not in content, "Error should NOT contain user_id value (from state)"
    assert "sk-secret-key" not in content, "Error should NOT contain api_key value (from state)"
    assert "secret_session_token" not in content, (
        "Error should NOT contain session_data value (from state)"
    )

    # Verify the LLM's original tool call args are present
    # The error should show what the LLM actually provided to help it correct the mistake
    assert "12345" in content, "Error should show the invalid query value provided by LLM (12345)"

    # Check error is well-formatted
    assert "complex_tool" in content, "Error should mention the tool name"
|
||||
|
||||
|
||||
async def test_create_agent_error_only_model_controllable_params() -> None:
    """Test that errors only include LLM-controllable parameter issues.

    Focused test ensuring that validation errors for LLM-controlled parameters
    are clearly reported, while system-injected parameters remain completely
    absent from error messages. This provides focused feedback to the LLM.

    The tool call sends a non-string ``username`` and omits ``email`` so that
    validation really fails. Both parameters are plain ``str``, so well-formed
    strings (however short or malformed as an email) would validate and the
    tool would simply succeed.
    """

    class StateWithSecrets(AgentState):
        password: str  # Example of data not controlled by LLM

    @dec_tool
    def secure_tool(
        username: str,
        email: str,
        state: Annotated[StateWithSecrets, InjectedState],
    ) -> str:
        """Tool that validates user credentials.

        Args:
            username: The username (3-20 chars).
            email: The email address.
            state: State with password (system-injected).
        """
        return f"Validated {username} with email {email}"

    # LLM provides arguments that fail schema validation:
    # - username is the wrong type (int instead of str)
    # - email is missing
    # Then returns no tool calls to end the loop
    model = FakeToolCallingModel(
        tool_calls=[
            [
                {
                    "name": "secure_tool",
                    "args": {
                        "username": 42,  # Wrong type - should be str
                        # "email" is missing - required field
                    },
                    "id": "call_secure_1",
                }
            ],
            [],
        ]
    )

    agent = create_agent(
        model=model,
        tools=[secure_tool],
        state_schema=StateWithSecrets,
    )

    result = agent.invoke(
        {
            "messages": [HumanMessage("Create account")],
            "password": "super_secret_password_12345",
        }
    )

    tool_messages = [m for m in result["messages"] if m.type == "tool"]
    assert len(tool_messages) == 1
    tool_message = tool_messages[0]

    # Guard against the test passing on a successful tool run: the message
    # must be an error reply to the failing call.
    assert tool_message.status == "error"
    assert tool_message.tool_call_id == "call_secure_1"

    content = tool_message.content
    lowered = content.lower()

    # Both failing LLM-controlled parameters should be reported
    assert "username" in lowered, "Error should mention 'username' (LLM-controlled)"
    assert "email" in lowered, "Error should mention 'email' (LLM-controlled)"

    # Password is system-injected and should not appear
    # The LLM doesn't control it, so it shouldn't distract from the actual errors
    assert "password" not in lowered, (
        "Error should NOT mention 'password' (system-injected parameter)"
    )
    assert "super_secret_password" not in content, (
        "Error should NOT contain password value (from system-injected state)"
    )
|
||||
Reference in New Issue
Block a user