From 3eee4002d924fe47a9ebcba63cef3c03f307b388 Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Wed, 10 Jun 2026 20:42:46 +0200 Subject: [PATCH] refactor(langchain): refactor `test_create_agent_tool_validation` (#34443) Simplify test for `create_agent` errors. * Remove duplicate tests * Test sync and async with common logic --------- Co-authored-by: Mason Daugherty --- .../test_create_agent_tool_validation.py | 182 ++++++++++-------- 1 file changed, 102 insertions(+), 80 deletions(-) diff --git a/libs/langchain_v1/tests/unit_tests/agents/test_create_agent_tool_validation.py b/libs/langchain_v1/tests/unit_tests/agents/test_create_agent_tool_validation.py index 571d280714f..10047e4dd30 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/test_create_agent_tool_validation.py +++ b/libs/langchain_v1/tests/unit_tests/agents/test_create_agent_tool_validation.py @@ -3,6 +3,7 @@ from typing import Annotated, Any import pytest from langchain_core.messages import HumanMessage +from langchain_core.tools import BaseTool from langgraph.prebuilt import InjectedStore, ToolRuntime from langgraph.store.base import BaseStore from langgraph.store.memory import InMemoryStore @@ -13,6 +14,12 @@ from langchain.tools import tool as dec_tool from tests.unit_tests.agents.model import FakeToolCallingModel +class UserState(AgentState[Any]): + user_id: str + api_key: str + session_data: dict[str, Any] + + @pytest.mark.skipif( sys.version_info >= (3, 14), reason="Pydantic model rebuild issue in Python 3.14" ) @@ -27,9 +34,8 @@ def test_tool_invocation_error_excludes_injected_state() -> None: This test uses create_agent to ensure the behavior works in a full agent context. """ - # Define a custom state schema with injected data class TestState(AgentState[Any]): - secret_data: str # Example of state data not controlled by LLM + secret_data: str @dec_tool def tool_with_injected_state( @@ -40,29 +46,25 @@ def test_tool_invocation_error_excludes_injected_state() -> None: _ = (some_val, state) return "ok" - # Create a fake model that makes an incorrect tool call (missing 'some_val') - # Then returns no tool calls on the second iteration to end the loop model = FakeToolCallingModel( tool_calls=[ [ { "name": "tool_with_injected_state", - "args": {"wrong_arg": "value"}, # Missing required 'some_val' + "args": {"wrong_arg": "value"}, "id": "call_1", } ], - [], # No tool calls on second iteration to end the loop + [], ] ) - # Create an agent with the tool and custom state schema agent = create_agent( model=model, tools=[tool_with_injected_state], state_schema=TestState, ) - # Invoke the agent with injected state data result = agent.invoke( { "messages": [HumanMessage("Test message")], @@ -70,14 +72,11 @@ def test_tool_invocation_error_excludes_injected_state() -> None: } ) - # Find the tool error message tool_messages = [m for m in result["messages"] if m.type == "tool"] assert len(tool_messages) == 1 tool_message = tool_messages[0] assert tool_message.status == "error" - # The error message should contain only the LLM-provided args (wrong_arg) - # and NOT the system-injected state (secret_data) assert "{'wrong_arg': 'value'}" in tool_message.content assert "secret_data" not in tool_message.content assert "sensitive_secret_123" not in tool_message.content @@ -94,7 +93,6 @@ async def test_tool_invocation_error_excludes_injected_state_async() -> None: the LLM receives only relevant context for correction. """ - # Define a custom state schema class TestState(AgentState[Any]): internal_data: str @@ -108,30 +106,25 @@ async def test_tool_invocation_error_excludes_injected_state_async() -> None: _ = (query, max_results, state) return "ok" - # Create a fake model that makes an incorrect tool call - # - query has wrong type (int instead of str) - # - max_results is missing model = FakeToolCallingModel( tool_calls=[ [ { "name": "async_tool_with_injected_state", - "args": {"query": 999}, # Wrong type, missing max_results + "args": {"query": 999}, "id": "call_async_1", } ], - [], # End the loop + [], ] ) - # Create an agent with the async tool agent = create_agent( model=model, tools=[async_tool_with_injected_state], state_schema=TestState, ) - # Invoke with state data result = await agent.ainvoke( { "messages": [HumanMessage("Test async")], @@ -139,19 +132,14 @@ async def test_tool_invocation_error_excludes_injected_state_async() -> None: } ) - # Find the tool error message tool_messages = [m for m in result["messages"] if m.type == "tool"] assert len(tool_messages) == 1 tool_message = tool_messages[0] assert tool_message.status == "error" - # Verify error mentions LLM-controlled parameters only content = tool_message.content assert "query" in content.lower(), "Error should mention 'query' (LLM-controlled)" assert "max_results" in content.lower(), "Error should mention 'max_results' (LLM-controlled)" - - # Verify system-injected state does not appear in the validation errors - # This keeps the error focused on what the LLM can actually fix assert "internal_data" not in content, ( "Error should NOT mention 'internal_data' (system-injected field)" ) @@ -159,11 +147,8 @@ async def test_tool_invocation_error_excludes_injected_state_async() -> None: "Error should NOT contain system-injected state values" ) - # Verify only LLM-controlled parameters are in the error list - # Should see "query" and "max_results" errors, but not "state" lines = content.split("\n") error_lines = [line.strip() for line in lines if line.strip()] - # Find lines that look like field names (single words at start of line) field_errors = [ line for line in error_lines @@ -174,7 +159,6 @@ async def test_tool_invocation_error_excludes_injected_state_async() -> None: and not line.startswith("please") and len(line.split()) <= 2 ] - # Verify system-injected 'state' is not in the field error list assert not any(field.lower() == "state" for field in field_errors), ( "The field 'state' (system-injected) should not appear in validation errors" ) @@ -183,7 +167,7 @@ async def test_tool_invocation_error_excludes_injected_state_async() -> None: @pytest.mark.skipif( sys.version_info >= (3, 14), reason="Pydantic model rebuild issue in Python 3.14" ) -def test_create_agent_error_content_with_multiple_params() -> None: +def test_create_agent_error() -> None: """Test that error messages only include LLM-controlled parameter errors. Uses create_agent to verify that when a tool with both LLM-controlled @@ -195,16 +179,11 @@ def test_create_agent_error_content_with_multiple_params() -> None: This ensures the LLM receives focused, actionable feedback. """ - class TestState(AgentState[Any]): - user_id: str - api_key: str - session_data: dict[str, Any] - @dec_tool def complex_tool( query: str, limit: int, - state: Annotated[TestState, InjectedState], + state: Annotated[UserState, InjectedState], store: Annotated[BaseStore, InjectedStore()], runtime: ToolRuntime, ) -> str: @@ -220,10 +199,64 @@ def test_create_agent_error_content_with_multiple_params() -> None: _ = (query, limit, state, store, runtime) return "ok" - # Create a model that makes an incorrect tool call with multiple errors: - # - query is wrong type (int instead of str) - # - limit is missing - # Then returns no tool calls to end the loop + agent, payload = _build_complex_agent(complex_tool) + _assert_agent_error(agent.invoke(payload)) + + +@pytest.mark.skipif( + sys.version_info >= (3, 14), reason="Pydantic model rebuild issue in Python 3.14" +) +async def test_create_agent_error_async() -> None: + """Test that error messages only include LLM-controlled parameter errors. + + Uses create_agent to verify that when a tool with both LLM-controlled + and system-injected parameters receives invalid arguments, the error message: + 1. Contains details about LLM-controlled parameter errors (query, limit) + 2. Does NOT contain system-injected parameter names (state, store, runtime) + 3. Does NOT contain values from system-injected parameters + 4. Properly formats the validation errors for LLM correction + This ensures the LLM receives focused, actionable feedback. + """ + + @dec_tool + async def complex_tool( + query: str, + limit: int, + state: Annotated[UserState, InjectedState], + store: Annotated[BaseStore, InjectedStore()], + runtime: ToolRuntime, + ) -> str: + """A complex tool with multiple injected and non-injected parameters. + + Args: + query: The search query string. + limit: Maximum number of results to return. + state: The graph state (injected). + store: The persistent store (injected). + runtime: The tool runtime context (injected). + """ + _ = (query, limit, state, store, runtime) + return "ok" + + agent, payload = _build_complex_agent(complex_tool) + _assert_agent_error(await agent.ainvoke(payload)) + + +def _build_complex_agent(tool: BaseTool) -> tuple[Any, dict[str, Any]]: + """Build an agent and invocation payload shared by the complex-tool error tests. + + The model issues a single tool call with multiple errors: + - `query` has the wrong type (int instead of str) + - `wrong_arg` is not an expected parameter + - `limit` is missing (required) + Then returns no tool calls so the agent loop terminates. + + Args: + tool: The complex tool (sync or async) under test. + + Returns: + The compiled agent and the invocation payload (with sensitive state values). + """ model = FakeToolCallingModel( tool_calls=[ [ @@ -231,6 +264,7 @@ def test_create_agent_error_content_with_multiple_params() -> None: "name": "complex_tool", "args": { "query": 12345, # Wrong type - should be str + "wrong_arg": "value", # Not an expected parameter # "limit" is missing - required field }, "id": "call_complex_1", @@ -240,25 +274,25 @@ def test_create_agent_error_content_with_multiple_params() -> None: ] ) - # Create an agent with the complex tool and custom state - # Need to provide a store since the tool uses InjectedStore + # A store is required because the tool uses InjectedStore. agent = create_agent( model=model, - tools=[complex_tool], - state_schema=TestState, + tools=[tool], + state_schema=UserState, store=InMemoryStore(), ) - # Invoke with sensitive data in state - result = agent.invoke( - { - "messages": [HumanMessage("Search for something")], - "user_id": "user_12345", - "api_key": "sk-secret-key-abc123xyz", - "session_data": {"token": "secret_session_token"}, - } - ) + payload = { + "messages": [HumanMessage("Search for something")], + "user_id": "user_12345", + "api_key": "sk-secret-key-abc123xyz", + "session_data": {"token": "secret_session_token"}, + } + return agent, payload + + +def _assert_agent_error(result: dict[str, Any]) -> None: # Find the tool error message tool_messages = [m for m in result["messages"] if m.type == "tool"] assert len(tool_messages) == 1 @@ -266,22 +300,24 @@ def test_create_agent_error_content_with_multiple_params() -> None: assert tool_message.status == "error" assert tool_message.tool_call_id == "call_complex_1" - content = tool_message.content + content = tool_message.content.lower() + assert "with error:" in content, "Error should be formatted with an error section" + _, _, error = content.partition("with error:") - # Verify error mentions LLM-controlled parameter issues - assert "query" in content.lower(), "Error should mention 'query' (LLM-controlled)" - assert "limit" in content.lower(), "Error should mention 'limit' (LLM-controlled)" + # Verify the error section mentions LLM-controlled parameter issues + assert "query" in error, "Error should mention 'query' (LLM-controlled)" + assert "limit" in error, "Error should mention 'limit' (LLM-controlled)" - # Should indicate validation errors occurred - assert "validation error" in content.lower() or "error" in content.lower(), ( - "Error should indicate validation occurred" - ) + # The LLM's original (invalid) args are echoed back so it can self-correct + assert "wrong_arg" in content, "Error should echo the LLM-provided 'wrong_arg'" + assert "12345" in content, "Error should show the invalid query value provided by LLM (12345)" + assert "complex_tool" in content, "Error should mention the tool name" # Verify NO system-injected parameter names appear in error # These are not controlled by the LLM and should be excluded - assert "state" not in content.lower(), "Error should NOT mention 'state' (system-injected)" - assert "store" not in content.lower(), "Error should NOT mention 'store' (system-injected)" - assert "runtime" not in content.lower(), "Error should NOT mention 'runtime' (system-injected)" + assert "state" not in content, "Error should NOT mention 'state' (system-injected)" + assert "store" not in content, "Error should NOT mention 'store' (system-injected)" + assert "runtime" not in content, "Error should NOT mention 'runtime' (system-injected)" # Verify NO values from system-injected parameters appear in error # The LLM doesn't control these, so they shouldn't distract from the actual issues @@ -291,13 +327,6 @@ def test_create_agent_error_content_with_multiple_params() -> None: "Error should NOT contain session_data value (from state)" ) - # Verify the LLM's original tool call args are present - # The error should show what the LLM actually provided to help it correct the mistake - assert "12345" in content, "Error should show the invalid query value provided by LLM (12345)" - - # Check error is well-formatted - assert "complex_tool" in content, "Error should mention the tool name" - @pytest.mark.skipif( sys.version_info >= (3, 14), reason="Pydantic model rebuild issue in Python 3.14" @@ -311,7 +340,7 @@ def test_create_agent_error_only_model_controllable_params() -> None: """ class StateWithSecrets(AgentState[Any]): - password: str # Example of data not controlled by LLM + password: str @dec_tool def secure_tool( @@ -329,15 +358,14 @@ def test_create_agent_error_only_model_controllable_params() -> None: _ = state return f"Validated {username} with email {email}" - # LLM provides invalid username (too short) and invalid email model = FakeToolCallingModel( tool_calls=[ [ { "name": "secure_tool", "args": { - "username": "ab", # Too short (needs 3-20) - "email": "not-an-email", # Invalid format + "username": "ab", + "email": "not-an-email", }, "id": "call_secure_1", } @@ -363,15 +391,9 @@ def test_create_agent_error_only_model_controllable_params() -> None: assert len(tool_messages) == 1 content = tool_messages[0].content - # The error should mention LLM-controlled parameters - # Note: Pydantic's default validation may or may not catch format issues, - # but the parameters themselves should be present in error messages assert "username" in content.lower() or "email" in content.lower(), ( "Error should mention at least one LLM-controlled parameter" ) - - # Password is system-injected and should not appear - # The LLM doesn't control it, so it shouldn't distract from the actual errors assert "password" not in content.lower(), ( "Error should NOT mention 'password' (system-injected parameter)" )