mirror of
https://github.com/hwchase17/langchain.git
synced 2026-07-01 14:47:02 +00:00
refactor(langchain): refactor test_create_agent_tool_validation (#34443)
Simplify test for `create_agent` errors. * Remove duplicate tests * Test sync and async with common logic --------- Co-authored-by: Mason Daugherty <github@mdrxy.com>
This commit is contained in:
committed by
GitHub
parent
3d3a4c27cc
commit
3eee4002d9
@@ -3,6 +3,7 @@ from typing import Annotated, Any
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.tools import BaseTool
|
||||
from langgraph.prebuilt import InjectedStore, ToolRuntime
|
||||
from langgraph.store.base import BaseStore
|
||||
from langgraph.store.memory import InMemoryStore
|
||||
@@ -13,6 +14,12 @@ from langchain.tools import tool as dec_tool
|
||||
from tests.unit_tests.agents.model import FakeToolCallingModel
|
||||
|
||||
|
||||
class UserState(AgentState[Any]):
|
||||
user_id: str
|
||||
api_key: str
|
||||
session_data: dict[str, Any]
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info >= (3, 14), reason="Pydantic model rebuild issue in Python 3.14"
|
||||
)
|
||||
@@ -27,9 +34,8 @@ def test_tool_invocation_error_excludes_injected_state() -> None:
|
||||
This test uses create_agent to ensure the behavior works in a full agent context.
|
||||
"""
|
||||
|
||||
# Define a custom state schema with injected data
|
||||
class TestState(AgentState[Any]):
|
||||
secret_data: str # Example of state data not controlled by LLM
|
||||
secret_data: str
|
||||
|
||||
@dec_tool
|
||||
def tool_with_injected_state(
|
||||
@@ -40,29 +46,25 @@ def test_tool_invocation_error_excludes_injected_state() -> None:
|
||||
_ = (some_val, state)
|
||||
return "ok"
|
||||
|
||||
# Create a fake model that makes an incorrect tool call (missing 'some_val')
|
||||
# Then returns no tool calls on the second iteration to end the loop
|
||||
model = FakeToolCallingModel(
|
||||
tool_calls=[
|
||||
[
|
||||
{
|
||||
"name": "tool_with_injected_state",
|
||||
"args": {"wrong_arg": "value"}, # Missing required 'some_val'
|
||||
"args": {"wrong_arg": "value"},
|
||||
"id": "call_1",
|
||||
}
|
||||
],
|
||||
[], # No tool calls on second iteration to end the loop
|
||||
[],
|
||||
]
|
||||
)
|
||||
|
||||
# Create an agent with the tool and custom state schema
|
||||
agent = create_agent(
|
||||
model=model,
|
||||
tools=[tool_with_injected_state],
|
||||
state_schema=TestState,
|
||||
)
|
||||
|
||||
# Invoke the agent with injected state data
|
||||
result = agent.invoke(
|
||||
{
|
||||
"messages": [HumanMessage("Test message")],
|
||||
@@ -70,14 +72,11 @@ def test_tool_invocation_error_excludes_injected_state() -> None:
|
||||
}
|
||||
)
|
||||
|
||||
# Find the tool error message
|
||||
tool_messages = [m for m in result["messages"] if m.type == "tool"]
|
||||
assert len(tool_messages) == 1
|
||||
tool_message = tool_messages[0]
|
||||
assert tool_message.status == "error"
|
||||
|
||||
# The error message should contain only the LLM-provided args (wrong_arg)
|
||||
# and NOT the system-injected state (secret_data)
|
||||
assert "{'wrong_arg': 'value'}" in tool_message.content
|
||||
assert "secret_data" not in tool_message.content
|
||||
assert "sensitive_secret_123" not in tool_message.content
|
||||
@@ -94,7 +93,6 @@ async def test_tool_invocation_error_excludes_injected_state_async() -> None:
|
||||
the LLM receives only relevant context for correction.
|
||||
"""
|
||||
|
||||
# Define a custom state schema
|
||||
class TestState(AgentState[Any]):
|
||||
internal_data: str
|
||||
|
||||
@@ -108,30 +106,25 @@ async def test_tool_invocation_error_excludes_injected_state_async() -> None:
|
||||
_ = (query, max_results, state)
|
||||
return "ok"
|
||||
|
||||
# Create a fake model that makes an incorrect tool call
|
||||
# - query has wrong type (int instead of str)
|
||||
# - max_results is missing
|
||||
model = FakeToolCallingModel(
|
||||
tool_calls=[
|
||||
[
|
||||
{
|
||||
"name": "async_tool_with_injected_state",
|
||||
"args": {"query": 999}, # Wrong type, missing max_results
|
||||
"args": {"query": 999},
|
||||
"id": "call_async_1",
|
||||
}
|
||||
],
|
||||
[], # End the loop
|
||||
[],
|
||||
]
|
||||
)
|
||||
|
||||
# Create an agent with the async tool
|
||||
agent = create_agent(
|
||||
model=model,
|
||||
tools=[async_tool_with_injected_state],
|
||||
state_schema=TestState,
|
||||
)
|
||||
|
||||
# Invoke with state data
|
||||
result = await agent.ainvoke(
|
||||
{
|
||||
"messages": [HumanMessage("Test async")],
|
||||
@@ -139,19 +132,14 @@ async def test_tool_invocation_error_excludes_injected_state_async() -> None:
|
||||
}
|
||||
)
|
||||
|
||||
# Find the tool error message
|
||||
tool_messages = [m for m in result["messages"] if m.type == "tool"]
|
||||
assert len(tool_messages) == 1
|
||||
tool_message = tool_messages[0]
|
||||
assert tool_message.status == "error"
|
||||
|
||||
# Verify error mentions LLM-controlled parameters only
|
||||
content = tool_message.content
|
||||
assert "query" in content.lower(), "Error should mention 'query' (LLM-controlled)"
|
||||
assert "max_results" in content.lower(), "Error should mention 'max_results' (LLM-controlled)"
|
||||
|
||||
# Verify system-injected state does not appear in the validation errors
|
||||
# This keeps the error focused on what the LLM can actually fix
|
||||
assert "internal_data" not in content, (
|
||||
"Error should NOT mention 'internal_data' (system-injected field)"
|
||||
)
|
||||
@@ -159,11 +147,8 @@ async def test_tool_invocation_error_excludes_injected_state_async() -> None:
|
||||
"Error should NOT contain system-injected state values"
|
||||
)
|
||||
|
||||
# Verify only LLM-controlled parameters are in the error list
|
||||
# Should see "query" and "max_results" errors, but not "state"
|
||||
lines = content.split("\n")
|
||||
error_lines = [line.strip() for line in lines if line.strip()]
|
||||
# Find lines that look like field names (single words at start of line)
|
||||
field_errors = [
|
||||
line
|
||||
for line in error_lines
|
||||
@@ -174,7 +159,6 @@ async def test_tool_invocation_error_excludes_injected_state_async() -> None:
|
||||
and not line.startswith("please")
|
||||
and len(line.split()) <= 2
|
||||
]
|
||||
# Verify system-injected 'state' is not in the field error list
|
||||
assert not any(field.lower() == "state" for field in field_errors), (
|
||||
"The field 'state' (system-injected) should not appear in validation errors"
|
||||
)
|
||||
@@ -183,7 +167,7 @@ async def test_tool_invocation_error_excludes_injected_state_async() -> None:
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info >= (3, 14), reason="Pydantic model rebuild issue in Python 3.14"
|
||||
)
|
||||
def test_create_agent_error_content_with_multiple_params() -> None:
|
||||
def test_create_agent_error() -> None:
|
||||
"""Test that error messages only include LLM-controlled parameter errors.
|
||||
|
||||
Uses create_agent to verify that when a tool with both LLM-controlled
|
||||
@@ -195,16 +179,11 @@ def test_create_agent_error_content_with_multiple_params() -> None:
|
||||
This ensures the LLM receives focused, actionable feedback.
|
||||
"""
|
||||
|
||||
class TestState(AgentState[Any]):
|
||||
user_id: str
|
||||
api_key: str
|
||||
session_data: dict[str, Any]
|
||||
|
||||
@dec_tool
|
||||
def complex_tool(
|
||||
query: str,
|
||||
limit: int,
|
||||
state: Annotated[TestState, InjectedState],
|
||||
state: Annotated[UserState, InjectedState],
|
||||
store: Annotated[BaseStore, InjectedStore()],
|
||||
runtime: ToolRuntime,
|
||||
) -> str:
|
||||
@@ -220,10 +199,64 @@ def test_create_agent_error_content_with_multiple_params() -> None:
|
||||
_ = (query, limit, state, store, runtime)
|
||||
return "ok"
|
||||
|
||||
# Create a model that makes an incorrect tool call with multiple errors:
|
||||
# - query is wrong type (int instead of str)
|
||||
# - limit is missing
|
||||
# Then returns no tool calls to end the loop
|
||||
agent, payload = _build_complex_agent(complex_tool)
|
||||
_assert_agent_error(agent.invoke(payload))
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info >= (3, 14), reason="Pydantic model rebuild issue in Python 3.14"
|
||||
)
|
||||
async def test_create_agent_error_async() -> None:
|
||||
"""Test that error messages only include LLM-controlled parameter errors.
|
||||
|
||||
Uses create_agent to verify that when a tool with both LLM-controlled
|
||||
and system-injected parameters receives invalid arguments, the error message:
|
||||
1. Contains details about LLM-controlled parameter errors (query, limit)
|
||||
2. Does NOT contain system-injected parameter names (state, store, runtime)
|
||||
3. Does NOT contain values from system-injected parameters
|
||||
4. Properly formats the validation errors for LLM correction
|
||||
This ensures the LLM receives focused, actionable feedback.
|
||||
"""
|
||||
|
||||
@dec_tool
|
||||
async def complex_tool(
|
||||
query: str,
|
||||
limit: int,
|
||||
state: Annotated[UserState, InjectedState],
|
||||
store: Annotated[BaseStore, InjectedStore()],
|
||||
runtime: ToolRuntime,
|
||||
) -> str:
|
||||
"""A complex tool with multiple injected and non-injected parameters.
|
||||
|
||||
Args:
|
||||
query: The search query string.
|
||||
limit: Maximum number of results to return.
|
||||
state: The graph state (injected).
|
||||
store: The persistent store (injected).
|
||||
runtime: The tool runtime context (injected).
|
||||
"""
|
||||
_ = (query, limit, state, store, runtime)
|
||||
return "ok"
|
||||
|
||||
agent, payload = _build_complex_agent(complex_tool)
|
||||
_assert_agent_error(await agent.ainvoke(payload))
|
||||
|
||||
|
||||
def _build_complex_agent(tool: BaseTool) -> tuple[Any, dict[str, Any]]:
|
||||
"""Build an agent and invocation payload shared by the complex-tool error tests.
|
||||
|
||||
The model issues a single tool call with multiple errors:
|
||||
- `query` has the wrong type (int instead of str)
|
||||
- `wrong_arg` is not an expected parameter
|
||||
- `limit` is missing (required)
|
||||
Then returns no tool calls so the agent loop terminates.
|
||||
|
||||
Args:
|
||||
tool: The complex tool (sync or async) under test.
|
||||
|
||||
Returns:
|
||||
The compiled agent and the invocation payload (with sensitive state values).
|
||||
"""
|
||||
model = FakeToolCallingModel(
|
||||
tool_calls=[
|
||||
[
|
||||
@@ -231,6 +264,7 @@ def test_create_agent_error_content_with_multiple_params() -> None:
|
||||
"name": "complex_tool",
|
||||
"args": {
|
||||
"query": 12345, # Wrong type - should be str
|
||||
"wrong_arg": "value", # Not an expected parameter
|
||||
# "limit" is missing - required field
|
||||
},
|
||||
"id": "call_complex_1",
|
||||
@@ -240,25 +274,25 @@ def test_create_agent_error_content_with_multiple_params() -> None:
|
||||
]
|
||||
)
|
||||
|
||||
# Create an agent with the complex tool and custom state
|
||||
# Need to provide a store since the tool uses InjectedStore
|
||||
# A store is required because the tool uses InjectedStore.
|
||||
agent = create_agent(
|
||||
model=model,
|
||||
tools=[complex_tool],
|
||||
state_schema=TestState,
|
||||
tools=[tool],
|
||||
state_schema=UserState,
|
||||
store=InMemoryStore(),
|
||||
)
|
||||
|
||||
# Invoke with sensitive data in state
|
||||
result = agent.invoke(
|
||||
{
|
||||
"messages": [HumanMessage("Search for something")],
|
||||
"user_id": "user_12345",
|
||||
"api_key": "sk-secret-key-abc123xyz",
|
||||
"session_data": {"token": "secret_session_token"},
|
||||
}
|
||||
)
|
||||
payload = {
|
||||
"messages": [HumanMessage("Search for something")],
|
||||
"user_id": "user_12345",
|
||||
"api_key": "sk-secret-key-abc123xyz",
|
||||
"session_data": {"token": "secret_session_token"},
|
||||
}
|
||||
|
||||
return agent, payload
|
||||
|
||||
|
||||
def _assert_agent_error(result: dict[str, Any]) -> None:
|
||||
# Find the tool error message
|
||||
tool_messages = [m for m in result["messages"] if m.type == "tool"]
|
||||
assert len(tool_messages) == 1
|
||||
@@ -266,22 +300,24 @@ def test_create_agent_error_content_with_multiple_params() -> None:
|
||||
assert tool_message.status == "error"
|
||||
assert tool_message.tool_call_id == "call_complex_1"
|
||||
|
||||
content = tool_message.content
|
||||
content = tool_message.content.lower()
|
||||
assert "with error:" in content, "Error should be formatted with an error section"
|
||||
_, _, error = content.partition("with error:")
|
||||
|
||||
# Verify error mentions LLM-controlled parameter issues
|
||||
assert "query" in content.lower(), "Error should mention 'query' (LLM-controlled)"
|
||||
assert "limit" in content.lower(), "Error should mention 'limit' (LLM-controlled)"
|
||||
# Verify the error section mentions LLM-controlled parameter issues
|
||||
assert "query" in error, "Error should mention 'query' (LLM-controlled)"
|
||||
assert "limit" in error, "Error should mention 'limit' (LLM-controlled)"
|
||||
|
||||
# Should indicate validation errors occurred
|
||||
assert "validation error" in content.lower() or "error" in content.lower(), (
|
||||
"Error should indicate validation occurred"
|
||||
)
|
||||
# The LLM's original (invalid) args are echoed back so it can self-correct
|
||||
assert "wrong_arg" in content, "Error should echo the LLM-provided 'wrong_arg'"
|
||||
assert "12345" in content, "Error should show the invalid query value provided by LLM (12345)"
|
||||
assert "complex_tool" in content, "Error should mention the tool name"
|
||||
|
||||
# Verify NO system-injected parameter names appear in error
|
||||
# These are not controlled by the LLM and should be excluded
|
||||
assert "state" not in content.lower(), "Error should NOT mention 'state' (system-injected)"
|
||||
assert "store" not in content.lower(), "Error should NOT mention 'store' (system-injected)"
|
||||
assert "runtime" not in content.lower(), "Error should NOT mention 'runtime' (system-injected)"
|
||||
assert "state" not in content, "Error should NOT mention 'state' (system-injected)"
|
||||
assert "store" not in content, "Error should NOT mention 'store' (system-injected)"
|
||||
assert "runtime" not in content, "Error should NOT mention 'runtime' (system-injected)"
|
||||
|
||||
# Verify NO values from system-injected parameters appear in error
|
||||
# The LLM doesn't control these, so they shouldn't distract from the actual issues
|
||||
@@ -291,13 +327,6 @@ def test_create_agent_error_content_with_multiple_params() -> None:
|
||||
"Error should NOT contain session_data value (from state)"
|
||||
)
|
||||
|
||||
# Verify the LLM's original tool call args are present
|
||||
# The error should show what the LLM actually provided to help it correct the mistake
|
||||
assert "12345" in content, "Error should show the invalid query value provided by LLM (12345)"
|
||||
|
||||
# Check error is well-formatted
|
||||
assert "complex_tool" in content, "Error should mention the tool name"
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info >= (3, 14), reason="Pydantic model rebuild issue in Python 3.14"
|
||||
@@ -311,7 +340,7 @@ def test_create_agent_error_only_model_controllable_params() -> None:
|
||||
"""
|
||||
|
||||
class StateWithSecrets(AgentState[Any]):
|
||||
password: str # Example of data not controlled by LLM
|
||||
password: str
|
||||
|
||||
@dec_tool
|
||||
def secure_tool(
|
||||
@@ -329,15 +358,14 @@ def test_create_agent_error_only_model_controllable_params() -> None:
|
||||
_ = state
|
||||
return f"Validated {username} with email {email}"
|
||||
|
||||
# LLM provides invalid username (too short) and invalid email
|
||||
model = FakeToolCallingModel(
|
||||
tool_calls=[
|
||||
[
|
||||
{
|
||||
"name": "secure_tool",
|
||||
"args": {
|
||||
"username": "ab", # Too short (needs 3-20)
|
||||
"email": "not-an-email", # Invalid format
|
||||
"username": "ab",
|
||||
"email": "not-an-email",
|
||||
},
|
||||
"id": "call_secure_1",
|
||||
}
|
||||
@@ -363,15 +391,9 @@ def test_create_agent_error_only_model_controllable_params() -> None:
|
||||
assert len(tool_messages) == 1
|
||||
content = tool_messages[0].content
|
||||
|
||||
# The error should mention LLM-controlled parameters
|
||||
# Note: Pydantic's default validation may or may not catch format issues,
|
||||
# but the parameters themselves should be present in error messages
|
||||
assert "username" in content.lower() or "email" in content.lower(), (
|
||||
"Error should mention at least one LLM-controlled parameter"
|
||||
)
|
||||
|
||||
# Password is system-injected and should not appear
|
||||
# The LLM doesn't control it, so it shouldn't distract from the actual errors
|
||||
assert "password" not in content.lower(), (
|
||||
"Error should NOT mention 'password' (system-injected parameter)"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user