refactor(langchain): refactor test_create_agent_tool_validation (#34443)

Simplify test for `create_agent` errors.
* Remove duplicate tests
* Test sync and async with common logic

---------

Co-authored-by: Mason Daugherty <github@mdrxy.com>
This commit is contained in:
Christophe Bornet
2026-06-10 20:42:46 +02:00
committed by GitHub
parent 3d3a4c27cc
commit 3eee4002d9

View File

@@ -3,6 +3,7 @@ from typing import Annotated, Any
import pytest
from langchain_core.messages import HumanMessage
from langchain_core.tools import BaseTool
from langgraph.prebuilt import InjectedStore, ToolRuntime
from langgraph.store.base import BaseStore
from langgraph.store.memory import InMemoryStore
@@ -13,6 +14,12 @@ from langchain.tools import tool as dec_tool
from tests.unit_tests.agents.model import FakeToolCallingModel
class UserState(AgentState[Any]):
    """Agent state schema shared by the sync and async complex-tool error tests.

    The tests supply these fields in the invocation payload and inject the
    state into the tool under test; they are not arguments provided by the
    model's tool call.
    """

    # Populated by the test payload, never by the model.
    user_id: str
    # Secret-like value; the tests treat it as data that must not leak into
    # tool error messages.
    api_key: str
    # Arbitrary per-session mapping (the tests store a token in it).
    session_data: dict[str, Any]
@pytest.mark.skipif(
sys.version_info >= (3, 14), reason="Pydantic model rebuild issue in Python 3.14"
)
@@ -27,9 +34,8 @@ def test_tool_invocation_error_excludes_injected_state() -> None:
This test uses create_agent to ensure the behavior works in a full agent context.
"""
# Define a custom state schema with injected data
class TestState(AgentState[Any]):
secret_data: str # Example of state data not controlled by LLM
secret_data: str
@dec_tool
def tool_with_injected_state(
@@ -40,29 +46,25 @@ def test_tool_invocation_error_excludes_injected_state() -> None:
_ = (some_val, state)
return "ok"
# Create a fake model that makes an incorrect tool call (missing 'some_val')
# Then returns no tool calls on the second iteration to end the loop
model = FakeToolCallingModel(
tool_calls=[
[
{
"name": "tool_with_injected_state",
"args": {"wrong_arg": "value"}, # Missing required 'some_val'
"args": {"wrong_arg": "value"},
"id": "call_1",
}
],
[], # No tool calls on second iteration to end the loop
[],
]
)
# Create an agent with the tool and custom state schema
agent = create_agent(
model=model,
tools=[tool_with_injected_state],
state_schema=TestState,
)
# Invoke the agent with injected state data
result = agent.invoke(
{
"messages": [HumanMessage("Test message")],
@@ -70,14 +72,11 @@ def test_tool_invocation_error_excludes_injected_state() -> None:
}
)
# Find the tool error message
tool_messages = [m for m in result["messages"] if m.type == "tool"]
assert len(tool_messages) == 1
tool_message = tool_messages[0]
assert tool_message.status == "error"
# The error message should contain only the LLM-provided args (wrong_arg)
# and NOT the system-injected state (secret_data)
assert "{'wrong_arg': 'value'}" in tool_message.content
assert "secret_data" not in tool_message.content
assert "sensitive_secret_123" not in tool_message.content
@@ -94,7 +93,6 @@ async def test_tool_invocation_error_excludes_injected_state_async() -> None:
the LLM receives only relevant context for correction.
"""
# Define a custom state schema
class TestState(AgentState[Any]):
internal_data: str
@@ -108,30 +106,25 @@ async def test_tool_invocation_error_excludes_injected_state_async() -> None:
_ = (query, max_results, state)
return "ok"
# Create a fake model that makes an incorrect tool call
# - query has wrong type (int instead of str)
# - max_results is missing
model = FakeToolCallingModel(
tool_calls=[
[
{
"name": "async_tool_with_injected_state",
"args": {"query": 999}, # Wrong type, missing max_results
"args": {"query": 999},
"id": "call_async_1",
}
],
[], # End the loop
[],
]
)
# Create an agent with the async tool
agent = create_agent(
model=model,
tools=[async_tool_with_injected_state],
state_schema=TestState,
)
# Invoke with state data
result = await agent.ainvoke(
{
"messages": [HumanMessage("Test async")],
@@ -139,19 +132,14 @@ async def test_tool_invocation_error_excludes_injected_state_async() -> None:
}
)
# Find the tool error message
tool_messages = [m for m in result["messages"] if m.type == "tool"]
assert len(tool_messages) == 1
tool_message = tool_messages[0]
assert tool_message.status == "error"
# Verify error mentions LLM-controlled parameters only
content = tool_message.content
assert "query" in content.lower(), "Error should mention 'query' (LLM-controlled)"
assert "max_results" in content.lower(), "Error should mention 'max_results' (LLM-controlled)"
# Verify system-injected state does not appear in the validation errors
# This keeps the error focused on what the LLM can actually fix
assert "internal_data" not in content, (
"Error should NOT mention 'internal_data' (system-injected field)"
)
@@ -159,11 +147,8 @@ async def test_tool_invocation_error_excludes_injected_state_async() -> None:
"Error should NOT contain system-injected state values"
)
# Verify only LLM-controlled parameters are in the error list
# Should see "query" and "max_results" errors, but not "state"
lines = content.split("\n")
error_lines = [line.strip() for line in lines if line.strip()]
# Find lines that look like field names (single words at start of line)
field_errors = [
line
for line in error_lines
@@ -174,7 +159,6 @@ async def test_tool_invocation_error_excludes_injected_state_async() -> None:
and not line.startswith("please")
and len(line.split()) <= 2
]
# Verify system-injected 'state' is not in the field error list
assert not any(field.lower() == "state" for field in field_errors), (
"The field 'state' (system-injected) should not appear in validation errors"
)
@@ -183,7 +167,7 @@ async def test_tool_invocation_error_excludes_injected_state_async() -> None:
@pytest.mark.skipif(
sys.version_info >= (3, 14), reason="Pydantic model rebuild issue in Python 3.14"
)
def test_create_agent_error_content_with_multiple_params() -> None:
def test_create_agent_error() -> None:
"""Test that error messages only include LLM-controlled parameter errors.
Uses create_agent to verify that when a tool with both LLM-controlled
@@ -195,16 +179,11 @@ def test_create_agent_error_content_with_multiple_params() -> None:
This ensures the LLM receives focused, actionable feedback.
"""
class TestState(AgentState[Any]):
user_id: str
api_key: str
session_data: dict[str, Any]
@dec_tool
def complex_tool(
query: str,
limit: int,
state: Annotated[TestState, InjectedState],
state: Annotated[UserState, InjectedState],
store: Annotated[BaseStore, InjectedStore()],
runtime: ToolRuntime,
) -> str:
@@ -220,10 +199,64 @@ def test_create_agent_error_content_with_multiple_params() -> None:
_ = (query, limit, state, store, runtime)
return "ok"
# Create a model that makes an incorrect tool call with multiple errors:
# - query is wrong type (int instead of str)
# - limit is missing
# Then returns no tool calls to end the loop
agent, payload = _build_complex_agent(complex_tool)
_assert_agent_error(agent.invoke(payload))
@pytest.mark.skipif(
    sys.version_info >= (3, 14), reason="Pydantic model rebuild issue in Python 3.14"
)
async def test_create_agent_error_async() -> None:
    """Test that error messages only include LLM-controlled parameter errors.

    Async variant: the tool is a coroutine function and the agent is run
    through `ainvoke`.

    Uses create_agent to verify that when a tool with both LLM-controlled
    and system-injected parameters receives invalid arguments, the error message:

    1. Contains details about LLM-controlled parameter errors (query, limit)
    2. Does NOT contain system-injected parameter names (state, store, runtime)
    3. Does NOT contain values from system-injected parameters
    4. Properly formats the validation errors for LLM correction

    This ensures the LLM receives focused, actionable feedback.
    """

    # The tool must keep the name `complex_tool`: the fake model built by
    # `_build_complex_agent` issues its tool call against that exact name.
    @dec_tool
    async def complex_tool(
        query: str,
        limit: int,
        state: Annotated[UserState, InjectedState],
        store: Annotated[BaseStore, InjectedStore()],
        runtime: ToolRuntime,
    ) -> str:
        """A complex tool with multiple injected and non-injected parameters.
        Args:
        query: The search query string.
        limit: Maximum number of results to return.
        state: The graph state (injected).
        store: The persistent store (injected).
        runtime: The tool runtime context (injected).
        """
        # Touch every parameter; presumably this silences unused-argument
        # lint warnings, since the return value ignores them.
        _ = (query, limit, state, store, runtime)
        return "ok"

    # Agent construction and the assertions are shared with the sync test;
    # only the invocation style (`await agent.ainvoke`) differs here.
    agent, payload = _build_complex_agent(complex_tool)
    _assert_agent_error(await agent.ainvoke(payload))
def _build_complex_agent(tool: BaseTool) -> tuple[Any, dict[str, Any]]:
"""Build an agent and invocation payload shared by the complex-tool error tests.
The model issues a single tool call with multiple errors:
- `query` has the wrong type (int instead of str)
- `wrong_arg` is not an expected parameter
- `limit` is missing (required)
Then returns no tool calls so the agent loop terminates.
Args:
tool: The complex tool (sync or async) under test.
Returns:
The compiled agent and the invocation payload (with sensitive state values).
"""
model = FakeToolCallingModel(
tool_calls=[
[
@@ -231,6 +264,7 @@ def test_create_agent_error_content_with_multiple_params() -> None:
"name": "complex_tool",
"args": {
"query": 12345, # Wrong type - should be str
"wrong_arg": "value", # Not an expected parameter
# "limit" is missing - required field
},
"id": "call_complex_1",
@@ -240,25 +274,25 @@ def test_create_agent_error_content_with_multiple_params() -> None:
]
)
# Create an agent with the complex tool and custom state
# Need to provide a store since the tool uses InjectedStore
# A store is required because the tool uses InjectedStore.
agent = create_agent(
model=model,
tools=[complex_tool],
state_schema=TestState,
tools=[tool],
state_schema=UserState,
store=InMemoryStore(),
)
# Invoke with sensitive data in state
result = agent.invoke(
{
"messages": [HumanMessage("Search for something")],
"user_id": "user_12345",
"api_key": "sk-secret-key-abc123xyz",
"session_data": {"token": "secret_session_token"},
}
)
payload = {
"messages": [HumanMessage("Search for something")],
"user_id": "user_12345",
"api_key": "sk-secret-key-abc123xyz",
"session_data": {"token": "secret_session_token"},
}
return agent, payload
def _assert_agent_error(result: dict[str, Any]) -> None:
# Find the tool error message
tool_messages = [m for m in result["messages"] if m.type == "tool"]
assert len(tool_messages) == 1
@@ -266,22 +300,24 @@ def test_create_agent_error_content_with_multiple_params() -> None:
assert tool_message.status == "error"
assert tool_message.tool_call_id == "call_complex_1"
content = tool_message.content
content = tool_message.content.lower()
assert "with error:" in content, "Error should be formatted with an error section"
_, _, error = content.partition("with error:")
# Verify error mentions LLM-controlled parameter issues
assert "query" in content.lower(), "Error should mention 'query' (LLM-controlled)"
assert "limit" in content.lower(), "Error should mention 'limit' (LLM-controlled)"
# Verify the error section mentions LLM-controlled parameter issues
assert "query" in error, "Error should mention 'query' (LLM-controlled)"
assert "limit" in error, "Error should mention 'limit' (LLM-controlled)"
# Should indicate validation errors occurred
assert "validation error" in content.lower() or "error" in content.lower(), (
"Error should indicate validation occurred"
)
# The LLM's original (invalid) args are echoed back so it can self-correct
assert "wrong_arg" in content, "Error should echo the LLM-provided 'wrong_arg'"
assert "12345" in content, "Error should show the invalid query value provided by LLM (12345)"
assert "complex_tool" in content, "Error should mention the tool name"
# Verify NO system-injected parameter names appear in error
# These are not controlled by the LLM and should be excluded
assert "state" not in content.lower(), "Error should NOT mention 'state' (system-injected)"
assert "store" not in content.lower(), "Error should NOT mention 'store' (system-injected)"
assert "runtime" not in content.lower(), "Error should NOT mention 'runtime' (system-injected)"
assert "state" not in content, "Error should NOT mention 'state' (system-injected)"
assert "store" not in content, "Error should NOT mention 'store' (system-injected)"
assert "runtime" not in content, "Error should NOT mention 'runtime' (system-injected)"
# Verify NO values from system-injected parameters appear in error
# The LLM doesn't control these, so they shouldn't distract from the actual issues
@@ -291,13 +327,6 @@ def test_create_agent_error_content_with_multiple_params() -> None:
"Error should NOT contain session_data value (from state)"
)
# Verify the LLM's original tool call args are present
# The error should show what the LLM actually provided to help it correct the mistake
assert "12345" in content, "Error should show the invalid query value provided by LLM (12345)"
# Check error is well-formatted
assert "complex_tool" in content, "Error should mention the tool name"
@pytest.mark.skipif(
sys.version_info >= (3, 14), reason="Pydantic model rebuild issue in Python 3.14"
@@ -311,7 +340,7 @@ def test_create_agent_error_only_model_controllable_params() -> None:
"""
class StateWithSecrets(AgentState[Any]):
password: str # Example of data not controlled by LLM
password: str
@dec_tool
def secure_tool(
@@ -329,15 +358,14 @@ def test_create_agent_error_only_model_controllable_params() -> None:
_ = state
return f"Validated {username} with email {email}"
# LLM provides invalid username (too short) and invalid email
model = FakeToolCallingModel(
tool_calls=[
[
{
"name": "secure_tool",
"args": {
"username": "ab", # Too short (needs 3-20)
"email": "not-an-email", # Invalid format
"username": "ab",
"email": "not-an-email",
},
"id": "call_secure_1",
}
@@ -363,15 +391,9 @@ def test_create_agent_error_only_model_controllable_params() -> None:
assert len(tool_messages) == 1
content = tool_messages[0].content
# The error should mention LLM-controlled parameters
# Note: Pydantic's default validation may or may not catch format issues,
# but the parameters themselves should be present in error messages
assert "username" in content.lower() or "email" in content.lower(), (
"Error should mention at least one LLM-controlled parameter"
)
# Password is system-injected and should not appear
# The LLM doesn't control it, so it shouldn't distract from the actual errors
assert "password" not in content.lower(), (
"Error should NOT mention 'password' (system-injected parameter)"
)