mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 18:50:33 +00:00
chore: add more robust test for runtime injection w/ explicit args_schema (#34051)
This commit is contained in:
@@ -21,11 +21,14 @@ from typing import Any
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.prebuilt import InjectedStore
|
||||
from langgraph.store.base import BaseStore
|
||||
from langgraph.store.memory import InMemoryStore
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from langchain.agents import create_agent
|
||||
from langchain.agents.middleware.types import AgentState
|
||||
from langchain.tools import ToolRuntime
|
||||
from langchain.tools import InjectedState, ToolRuntime
|
||||
|
||||
from .model import FakeToolCallingModel
|
||||
|
||||
@@ -589,3 +592,243 @@ def test_tool_runtime_name_based_injection() -> None:
|
||||
assert injected_data["tool_call_id"] == "name_call_123"
|
||||
assert injected_data["state"] is not None
|
||||
assert "messages" in injected_data["state"]
|
||||
|
||||
|
||||
def test_combined_injected_state_runtime_store() -> None:
|
||||
"""Test that all injection mechanisms work together in create_agent.
|
||||
|
||||
This test verifies that a tool can receive injected state, tool runtime,
|
||||
and injected store simultaneously when specified in the function signature
|
||||
but not in the explicit args schema. This is modeled after the pattern
|
||||
from mre.py where multiple injection types are combined.
|
||||
"""
|
||||
# Track what was injected
|
||||
injected_data = {}
|
||||
|
||||
# Custom state schema with additional fields
|
||||
class CustomState(AgentState):
|
||||
user_id: str
|
||||
session_id: str
|
||||
|
||||
# Define explicit args schema that only includes LLM-controlled parameters
|
||||
weather_schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {"type": "string", "description": "The location to get weather for"},
|
||||
},
|
||||
"required": ["location"],
|
||||
}
|
||||
|
||||
@tool(args_schema=weather_schema)
|
||||
def multi_injection_tool(
|
||||
location: str,
|
||||
state: Annotated[Any, InjectedState],
|
||||
runtime: ToolRuntime,
|
||||
store: Annotated[Any, InjectedStore()],
|
||||
) -> str:
|
||||
"""Tool that uses injected state, runtime, and store together.
|
||||
|
||||
Args:
|
||||
location: The location to get weather for (LLM-controlled).
|
||||
state: The graph state (injected).
|
||||
runtime: The tool runtime context (injected).
|
||||
store: The persistent store (injected).
|
||||
"""
|
||||
# Capture all injected parameters
|
||||
injected_data["state"] = state
|
||||
injected_data["user_id"] = state.get("user_id", "unknown")
|
||||
injected_data["session_id"] = state.get("session_id", "unknown")
|
||||
injected_data["runtime"] = runtime
|
||||
injected_data["tool_call_id"] = runtime.tool_call_id
|
||||
injected_data["store"] = store
|
||||
injected_data["store_is_none"] = store is None
|
||||
|
||||
# Verify runtime.state matches the state parameter
|
||||
injected_data["runtime_state_matches"] = runtime.state == state
|
||||
|
||||
return f"Weather info for {location}"
|
||||
|
||||
# Create model that calls the tool
|
||||
model = FakeToolCallingModel(
|
||||
tool_calls=[
|
||||
[
|
||||
{
|
||||
"name": "multi_injection_tool",
|
||||
"args": {"location": "San Francisco"}, # Only LLM-controlled arg
|
||||
"id": "call_weather_123",
|
||||
}
|
||||
],
|
||||
[], # End the loop
|
||||
]
|
||||
)
|
||||
|
||||
# Create agent with custom state and store
|
||||
agent = create_agent(
|
||||
model=model,
|
||||
tools=[multi_injection_tool],
|
||||
state_schema=CustomState,
|
||||
store=InMemoryStore(),
|
||||
)
|
||||
|
||||
# Verify the tool's args schema only includes LLM-controlled parameters
|
||||
tool_args_schema = multi_injection_tool.args_schema
|
||||
assert "location" in tool_args_schema["properties"]
|
||||
assert "state" not in tool_args_schema["properties"]
|
||||
assert "runtime" not in tool_args_schema["properties"]
|
||||
assert "store" not in tool_args_schema["properties"]
|
||||
|
||||
# Invoke with custom state fields
|
||||
result = agent.invoke(
|
||||
{
|
||||
"messages": [HumanMessage("What's the weather like?")],
|
||||
"user_id": "user_42",
|
||||
"session_id": "session_abc123",
|
||||
}
|
||||
)
|
||||
|
||||
# Verify tool executed successfully
|
||||
tool_messages = [msg for msg in result["messages"] if isinstance(msg, ToolMessage)]
|
||||
assert len(tool_messages) == 1
|
||||
tool_message = tool_messages[0]
|
||||
assert tool_message.content == "Weather info for San Francisco"
|
||||
assert tool_message.tool_call_id == "call_weather_123"
|
||||
|
||||
# Verify all injections worked correctly
|
||||
assert injected_data["state"] is not None
|
||||
assert "messages" in injected_data["state"]
|
||||
|
||||
# Verify custom state fields were accessible
|
||||
assert injected_data["user_id"] == "user_42"
|
||||
assert injected_data["session_id"] == "session_abc123"
|
||||
|
||||
# Verify runtime was injected
|
||||
assert injected_data["runtime"] is not None
|
||||
assert injected_data["tool_call_id"] == "call_weather_123"
|
||||
|
||||
# Verify store was injected
|
||||
assert injected_data["store_is_none"] is False
|
||||
assert injected_data["store"] is not None
|
||||
|
||||
# Verify runtime.state matches the injected state
|
||||
assert injected_data["runtime_state_matches"] is True
|
||||
|
||||
|
||||
async def test_combined_injected_state_runtime_store_async() -> None:
|
||||
"""Test that all injection mechanisms work together in async execution.
|
||||
|
||||
This async version verifies that injected state, tool runtime, and injected
|
||||
store all work correctly with async tools in create_agent.
|
||||
"""
|
||||
# Track what was injected
|
||||
injected_data = {}
|
||||
|
||||
# Custom state schema
|
||||
class CustomState(AgentState):
|
||||
api_key: str
|
||||
request_id: str
|
||||
|
||||
# Define explicit args schema that only includes LLM-controlled parameters
|
||||
# Note: state, runtime, and store are NOT in this schema
|
||||
search_schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string", "description": "The search query"},
|
||||
"max_results": {"type": "integer", "description": "Maximum number of results"},
|
||||
},
|
||||
"required": ["query", "max_results"],
|
||||
}
|
||||
|
||||
@tool(args_schema=search_schema)
|
||||
async def async_multi_injection_tool(
|
||||
query: str,
|
||||
max_results: int,
|
||||
state: Annotated[Any, InjectedState],
|
||||
runtime: ToolRuntime,
|
||||
store: Annotated[Any, InjectedStore()],
|
||||
) -> str:
|
||||
"""Async tool with multiple injection types.
|
||||
|
||||
Args:
|
||||
query: The search query (LLM-controlled).
|
||||
max_results: Maximum number of results (LLM-controlled).
|
||||
state: The graph state (injected).
|
||||
runtime: The tool runtime context (injected).
|
||||
store: The persistent store (injected).
|
||||
"""
|
||||
# Capture all injected parameters
|
||||
injected_data["state"] = state
|
||||
injected_data["api_key"] = state.get("api_key", "unknown")
|
||||
injected_data["request_id"] = state.get("request_id", "unknown")
|
||||
injected_data["runtime"] = runtime
|
||||
injected_data["tool_call_id"] = runtime.tool_call_id
|
||||
injected_data["config"] = runtime.config
|
||||
injected_data["store"] = store
|
||||
|
||||
# Verify we can write to the store
|
||||
if store is not None:
|
||||
await store.aput(("test", "namespace"), "test_key", {"query": query})
|
||||
# Read back to verify it worked
|
||||
item = await store.aget(("test", "namespace"), "test_key")
|
||||
injected_data["store_write_success"] = item is not None
|
||||
|
||||
return f"Found {max_results} results for '{query}'"
|
||||
|
||||
# Create model that calls the async tool
|
||||
model = FakeToolCallingModel(
|
||||
tool_calls=[
|
||||
[
|
||||
{
|
||||
"name": "async_multi_injection_tool",
|
||||
"args": {"query": "test search", "max_results": 10},
|
||||
"id": "call_search_456",
|
||||
}
|
||||
],
|
||||
[],
|
||||
]
|
||||
)
|
||||
|
||||
# Create agent with custom state and store
|
||||
agent = create_agent(
|
||||
model=model,
|
||||
tools=[async_multi_injection_tool],
|
||||
state_schema=CustomState,
|
||||
store=InMemoryStore(),
|
||||
)
|
||||
|
||||
# Verify the tool's args schema only includes LLM-controlled parameters
|
||||
tool_args_schema = async_multi_injection_tool.args_schema
|
||||
assert "query" in tool_args_schema["properties"]
|
||||
assert "max_results" in tool_args_schema["properties"]
|
||||
assert "state" not in tool_args_schema["properties"]
|
||||
assert "runtime" not in tool_args_schema["properties"]
|
||||
assert "store" not in tool_args_schema["properties"]
|
||||
|
||||
# Invoke async
|
||||
result = await agent.ainvoke(
|
||||
{
|
||||
"messages": [HumanMessage("Search for something")],
|
||||
"api_key": "sk-test-key-xyz",
|
||||
"request_id": "req_999",
|
||||
}
|
||||
)
|
||||
|
||||
# Verify tool executed successfully
|
||||
tool_messages = [msg for msg in result["messages"] if isinstance(msg, ToolMessage)]
|
||||
assert len(tool_messages) == 1
|
||||
tool_message = tool_messages[0]
|
||||
assert tool_message.content == "Found 10 results for 'test search'"
|
||||
assert tool_message.tool_call_id == "call_search_456"
|
||||
|
||||
# Verify all injections worked correctly
|
||||
assert injected_data["state"] is not None
|
||||
assert injected_data["api_key"] == "sk-test-key-xyz"
|
||||
assert injected_data["request_id"] == "req_999"
|
||||
|
||||
# Verify runtime was injected
|
||||
assert injected_data["runtime"] is not None
|
||||
assert injected_data["tool_call_id"] == "call_search_456"
|
||||
assert injected_data["config"] is not None
|
||||
|
||||
# Verify store was injected and writable
|
||||
assert injected_data["store"] is not None
|
||||
assert injected_data["store_write_success"] is True
|
||||
|
||||
8
libs/langchain_v1/uv.lock
generated
8
libs/langchain_v1/uv.lock
generated
@@ -2174,7 +2174,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "langchain-core"
|
||||
version = "1.0.6"
|
||||
version = "1.0.7"
|
||||
source = { editable = "../core" }
|
||||
dependencies = [
|
||||
{ name = "jsonpatch" },
|
||||
@@ -2626,15 +2626,15 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "langgraph-prebuilt"
|
||||
version = "1.0.4"
|
||||
version = "1.0.5"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "langchain-core" },
|
||||
{ name = "langgraph-checkpoint" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/84/08/45857c7c65f696307834af13946a72293e6cc49141de887f0957c2eb2c46/langgraph_prebuilt-1.0.4.tar.gz", hash = "sha256:7b4f9e97a146d2d625695c3549bdb432974b80817165139ec2ec869721e72c0f", size = 142470, upload-time = "2025-11-13T19:02:14.807Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/46/f9/54f8891b32159e4542236817aea2ee83de0de18bce28e9bdba08c7f93001/langgraph_prebuilt-1.0.5.tar.gz", hash = "sha256:85802675ad778cc7240fd02d47db1e0b59c0c86d8369447d77ce47623845db2d", size = 144453, upload-time = "2025-11-20T16:47:39.23Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/69/14/a83e50129f66df783a68acb89e7b3e9c39b5c128a8748e961bc2b187f003/langgraph_prebuilt-1.0.4-py3-none-any.whl", hash = "sha256:50b1aa2b434783b6da30785568cf7155136b484750cc2ec695c0d4255db08262", size = 34414, upload-time = "2025-11-13T19:02:13.416Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/87/5e/aeba4a5b39fe6e874e0dd003a82da71c7153e671312671a8dacc5cb7c1af/langgraph_prebuilt-1.0.5-py3-none-any.whl", hash = "sha256:22369563e1848862ace53fbc11b027c28dd04a9ac39314633bb95f2a7e258496", size = 35072, upload-time = "2025-11-20T16:47:38.187Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
Reference in New Issue
Block a user