chore(langchain): fix types in test_injected_runtime_create_agent, test_create_agent_tool_validation (#34568)

This commit is contained in:
Christophe Bornet
2026-01-10 03:50:18 +01:00
committed by GitHub
parent f1ab8c5c80
commit 5dc8ba3c99
2 changed files with 23 additions and 22 deletions

View File

@@ -1,5 +1,5 @@
import sys
from typing import Annotated
from typing import Annotated, Any
import pytest
from langchain_core.messages import HumanMessage
@@ -28,7 +28,7 @@ def test_tool_invocation_error_excludes_injected_state() -> None:
"""
# Define a custom state schema with injected data
class TestState(AgentState):
class TestState(AgentState[Any]):
secret_data: str # Example of state data not controlled by LLM
@dec_tool
@@ -94,7 +94,7 @@ async def test_tool_invocation_error_excludes_injected_state_async() -> None:
"""
# Define a custom state schema
class TestState(AgentState):
class TestState(AgentState[Any]):
internal_data: str
@dec_tool
@@ -193,10 +193,10 @@ def test_create_agent_error_content_with_multiple_params() -> None:
This ensures the LLM receives focused, actionable feedback.
"""
class TestState(AgentState):
class TestState(AgentState[Any]):
user_id: str
api_key: str
session_data: dict
session_data: dict[str, Any]
@dec_tool
def complex_tool(
@@ -309,7 +309,7 @@ def test_create_agent_error_only_model_controllable_params() -> None:
absent from error messages. This provides focused feedback to the LLM.
"""
class StateWithSecrets(AgentState):
class StateWithSecrets(AgentState[Any]):
password: str # Example of data not controlled by LLM
@dec_tool

View File

@@ -16,7 +16,7 @@ configurations.
from __future__ import annotations
from typing import Annotated, Any
from typing import TYPE_CHECKING, Annotated, Any
from langchain_core.messages import HumanMessage, ToolMessage
from langchain_core.tools import tool
@@ -28,11 +28,14 @@ from langchain.agents.middleware.types import AgentMiddleware, AgentState
from langchain.tools import InjectedState, ToolRuntime
from tests.unit_tests.agents.model import FakeToolCallingModel
if TYPE_CHECKING:
from langgraph.runtime import Runtime
def test_tool_runtime_basic_injection() -> None:
"""Test basic ToolRuntime injection in tools with create_agent."""
# Track what was injected
injected_data = {}
injected_data: dict[str, Any] = {}
@tool
def runtime_tool(x: int, runtime: ToolRuntime) -> str:
@@ -80,7 +83,7 @@ def test_tool_runtime_basic_injection() -> None:
async def test_tool_runtime_async_injection() -> None:
"""Test ToolRuntime injection works with async tools."""
injected_data = {}
injected_data: dict[str, Any] = {}
@tool
async def async_runtime_tool(x: int, runtime: ToolRuntime) -> str:
@@ -195,7 +198,7 @@ def test_tool_runtime_with_store() -> None:
def test_tool_runtime_with_multiple_tools() -> None:
"""Test multiple tools can all access ToolRuntime."""
call_log = []
call_log: list[tuple[str, str | None, int | str]] = []
@tool
def tool_a(x: int, runtime: ToolRuntime) -> str:
@@ -242,7 +245,7 @@ def test_tool_runtime_with_multiple_tools() -> None:
def test_tool_runtime_config_access() -> None:
"""Test tools can access config through ToolRuntime."""
config_data = {}
config_data: dict[str, Any] = {}
@tool
def config_tool(x: int, runtime: ToolRuntime) -> str:
@@ -282,7 +285,7 @@ def test_tool_runtime_config_access() -> None:
def test_tool_runtime_with_custom_state() -> None:
"""Test ToolRuntime works with custom state schemas."""
class CustomState(AgentState):
class CustomState(AgentState[Any]):
custom_field: str
runtime_state = {}
@@ -464,11 +467,11 @@ def test_tool_runtime_with_middleware() -> None:
runtime_calls = []
class TestMiddleware(AgentMiddleware):
def before_model(self, state, runtime) -> dict[str, Any]:
def before_model(self, state: AgentState[Any], runtime: Runtime) -> dict[str, Any]:
middleware_calls.append("before_model")
return {}
def after_model(self, state, runtime) -> dict[str, Any]:
def after_model(self, state: AgentState[Any], runtime: Runtime) -> dict[str, Any]:
middleware_calls.append("after_model")
return {}
@@ -515,11 +518,7 @@ def test_tool_runtime_type_hints() -> None:
def typed_runtime_tool(x: int, runtime: ToolRuntime) -> str:
"""Tool with runtime access."""
# Access state dict - verify we can access standard state fields
if isinstance(runtime.state, dict):
# Count messages in state
typed_runtime["message_count"] = len(runtime.state.get("messages", []))
else:
typed_runtime["message_count"] = len(getattr(runtime.state, "messages", []))
typed_runtime["message_count"] = len(runtime.state.get("messages", []))
return f"Typed: {x}"
agent = create_agent(
@@ -546,7 +545,7 @@ def test_tool_runtime_type_hints() -> None:
def test_tool_runtime_name_based_injection() -> None:
"""Test that parameter named 'runtime' gets injected without type annotation."""
injected_data = {}
injected_data: dict[str, Any] = {}
@tool
def name_based_tool(x: int, runtime: Any) -> str:
@@ -601,7 +600,7 @@ def test_combined_injected_state_runtime_store() -> None:
injected_data = {}
# Custom state schema with additional fields
class CustomState(AgentState):
class CustomState(AgentState[Any]):
user_id: str
session_id: str
@@ -667,6 +666,7 @@ def test_combined_injected_state_runtime_store() -> None:
# Verify the tool's args schema only includes LLM-controlled parameters
tool_args_schema = multi_injection_tool.args_schema
assert isinstance(tool_args_schema, dict)
assert "location" in tool_args_schema["properties"]
assert "state" not in tool_args_schema["properties"]
assert "runtime" not in tool_args_schema["properties"]
@@ -718,7 +718,7 @@ async def test_combined_injected_state_runtime_store_async() -> None:
injected_data = {}
# Custom state schema
class CustomState(AgentState):
class CustomState(AgentState[Any]):
api_key: str
request_id: str
@@ -792,6 +792,7 @@ async def test_combined_injected_state_runtime_store_async() -> None:
# Verify the tool's args schema only includes LLM-controlled parameters
tool_args_schema = async_multi_injection_tool.args_schema
assert isinstance(tool_args_schema, dict)
assert "query" in tool_args_schema["properties"]
assert "max_results" in tool_args_schema["properties"]
assert "state" not in tool_args_schema["properties"]