chore(langchain): improve typing in tests (#38163)

Co-authored-by: Mason Daugherty <github@mdrxy.com>
Co-authored-by: Mason Daugherty <mason@langchain.dev>
This commit is contained in:
Christophe Bornet
2026-06-16 01:02:22 +02:00
committed by GitHub
parent 5f0abc1152
commit 5d044cd326
13 changed files with 157 additions and 89 deletions

View File

@@ -111,10 +111,7 @@ enable_error_code = "deprecated"
warn_unreachable = true
exclude = [
# Exclude agents tests except middleware_typing/ which has type-checked tests
"tests/unit_tests/agents/middleware/",
"tests/unit_tests/agents/specifications/",
"tests/unit_tests/agents/test_.*\\.py",
]
[[tool.mypy.overrides]]
@@ -147,6 +144,7 @@ unfixable = [
]
flake8-annotations.allow-star-arg-any = true
flake8-annotations.mypy-init-return = true
allowed-confusables = [""]
[tool.ruff.lint.flake8-tidy-imports]
@@ -157,13 +155,7 @@ convention = "google"
ignore-var-parameters = true # ignore missing documentation for *args and **kwargs parameters
[tool.ruff.lint.extend-per-file-ignores]
"tests/unit_tests/agents/*" = [
"ANN", # Annotations, needs to fix
]
"tests/unit_tests/agents/test_responses_spec.py" = ["F821"]
"tests/unit_tests/agents/test_return_direct_spec.py" = ["F821"]
"tests/unit_tests/agents/test_react_agent.py" = ["ALL"]
"tests/*" = [
"D1", # Documentation rules
"S101", # Tests need assertions

View File

@@ -5,6 +5,8 @@ This module tests that the name parameter correctly sets .name on AIMessage outp
from __future__ import annotations
from typing import Any, cast
from langchain_core.messages import (
AIMessage,
HumanMessage,
@@ -119,7 +121,9 @@ def test_lc_agent_name_in_stream_metadata() -> None:
stream_mode="messages",
):
if "lc_agent_name" in metadata:
metadata_with_agent_name.append(metadata["lc_agent_name"])
# `stream()` is typed `Iterator[dict[str, Any] | Any]`, so unpacking the
# tuple leaves `metadata` as `str | Any`; cast to index it as the dict it is.
metadata_with_agent_name.append(cast("dict[str, Any]", metadata)["lc_agent_name"])
assert len(metadata_with_agent_name) > 0
assert all(name == "streaming_agent" for name in metadata_with_agent_name)
@@ -155,7 +159,9 @@ def test_lc_agent_name_in_stream_metadata_multiple_iterations() -> None:
stream_mode="messages",
):
if "lc_agent_name" in metadata:
metadata_with_agent_name.append(metadata["lc_agent_name"])
# `stream()` is typed `Iterator[dict[str, Any] | Any]`, so unpacking the
# tuple leaves `metadata` as `str | Any`; cast to index it as the dict it is.
metadata_with_agent_name.append(cast("dict[str, Any]", metadata)["lc_agent_name"])
# Should have metadata entries for messages from both iterations
assert len(metadata_with_agent_name) > 0
@@ -176,7 +182,9 @@ async def test_lc_agent_name_in_astream_metadata() -> None:
stream_mode="messages",
):
if "lc_agent_name" in metadata:
metadata_with_agent_name.append(metadata["lc_agent_name"])
# `stream()` is typed `Iterator[dict[str, Any] | Any]`, so unpacking the
# tuple leaves `metadata` as `str | Any`; cast to index it as the dict it is.
metadata_with_agent_name.append(cast("dict[str, Any]", metadata)["lc_agent_name"])
assert len(metadata_with_agent_name) > 0
assert all(name == "async_streaming_agent" for name in metadata_with_agent_name)
@@ -212,7 +220,9 @@ async def test_lc_agent_name_in_astream_metadata_multiple_iterations() -> None:
stream_mode="messages",
):
if "lc_agent_name" in metadata:
metadata_with_agent_name.append(metadata["lc_agent_name"])
# `stream()` is typed `Iterator[dict[str, Any] | Any]`, so unpacking the
# tuple leaves `metadata` as `str | Any`; cast to index it as the dict it is.
metadata_with_agent_name.append(cast("dict[str, Any]", metadata)["lc_agent_name"])
# Should have metadata entries for messages from both iterations
assert len(metadata_with_agent_name) > 0

View File

@@ -113,7 +113,7 @@ class TestAgentStreamV3Sync:
# by `BaseRunStream.__init__` whenever `MessagesTransformer` is
# registered. Content population is covered by langgraph tests —
# here we only assert the agent streamer inherits the built-in.
assert "messages" in run._mux.extensions # type: ignore[attr-defined]
assert "messages" in run._mux.extensions
assert hasattr(run, "messages")
# Drain so the run closes cleanly.
for tc in run.tool_calls:
@@ -145,14 +145,14 @@ class TestAgentStreamV3Sync:
transformers=[_Marker],
)
# Both the agent default and the user transformer are registered.
assert "tool_calls" in run._mux.extensions # type: ignore[attr-defined]
assert "marker" in run._mux.extensions # type: ignore[attr-defined]
assert "tool_calls" in run._mux.extensions
assert "marker" in run._mux.extensions
# `ToolCallTransformer` must come BEFORE the user's transformer in
# the registration order so it processes `tools` events first. The
# docstring on `create_agent` promises caller transformers are
# appended after the built-in.
transformers = run._mux._transformers # type: ignore[attr-defined]
transformers = run._mux._transformers
tool_call_idx = next(
i for i, t in enumerate(transformers) if isinstance(t, ToolCallTransformer)
)

View File

@@ -9,6 +9,7 @@ from langgraph.store.base import BaseStore
from langgraph.store.memory import InMemoryStore
from langchain.agents import AgentState, create_agent
from langchain.agents.middleware import InputAgentState
from langchain.tools import InjectedState
from langchain.tools import tool as dec_tool
from tests.unit_tests.agents.model import FakeToolCallingModel
@@ -37,6 +38,9 @@ def test_tool_invocation_error_excludes_injected_state() -> None:
class TestState(AgentState[Any]):
secret_data: str
class TestInputState(InputAgentState):
secret_data: str
@dec_tool
def tool_with_injected_state(
some_val: int,
@@ -66,10 +70,10 @@ def test_tool_invocation_error_excludes_injected_state() -> None:
)
result = agent.invoke(
{
"messages": [HumanMessage("Test message")],
"secret_data": "sensitive_secret_123",
}
TestInputState(
messages=[HumanMessage("Test message")],
secret_data="sensitive_secret_123", # noqa: S106
)
)
tool_messages = [m for m in result["messages"] if m.type == "tool"]
@@ -96,6 +100,9 @@ async def test_tool_invocation_error_excludes_injected_state_async() -> None:
class TestState(AgentState[Any]):
internal_data: str
class TestInputState(InputAgentState):
internal_data: str
@dec_tool
async def async_tool_with_injected_state(
query: str,
@@ -126,10 +133,9 @@ async def test_tool_invocation_error_excludes_injected_state_async() -> None:
)
result = await agent.ainvoke(
{
"messages": [HumanMessage("Test async")],
"internal_data": "secret_internal_value_xyz",
}
TestInputState(
messages=[HumanMessage("Test async")], internal_data="secret_internal_value_xyz"
)
)
tool_messages = [m for m in result["messages"] if m.type == "tool"]
@@ -342,6 +348,9 @@ def test_create_agent_error_only_model_controllable_params() -> None:
class StateWithSecrets(AgentState[Any]):
password: str
class InputStateWithSecrets(InputAgentState):
password: str
@dec_tool
def secure_tool(
username: str,
@@ -381,10 +390,10 @@ def test_create_agent_error_only_model_controllable_params() -> None:
)
result = agent.invoke(
{
"messages": [HumanMessage("Create account")],
"password": "super_secret_password_12345",
}
InputStateWithSecrets(
messages=[HumanMessage("Create account")],
password="super_secret_password_12345", # noqa: S106
)
)
tool_messages = [m for m in result["messages"] if m.type == "tool"]

View File

@@ -5,14 +5,14 @@ including the scenario where no AIMessage exists in the message list
(fixes issue #34792).
"""
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, SystemMessage, ToolMessage
from langchain.agents.factory import _fetch_last_ai_and_tool_messages
def test_fetch_last_ai_and_tool_messages_normal() -> None:
"""Test normal case with AIMessage and subsequent ToolMessages."""
messages = [
messages: list[AnyMessage] = [
HumanMessage(content="Hello"),
AIMessage(content="Hi there!", tool_calls=[{"name": "test", "id": "1", "args": {}}]),
ToolMessage(content="Tool result", tool_call_id="1"),
@@ -29,7 +29,7 @@ def test_fetch_last_ai_and_tool_messages_normal() -> None:
def test_fetch_last_ai_and_tool_messages_multiple_ai() -> None:
"""Test that the last AIMessage is returned when multiple exist."""
messages = [
messages: list[AnyMessage] = [
HumanMessage(content="First question"),
AIMessage(content="First answer", id="ai1"),
HumanMessage(content="Second question"),
@@ -53,7 +53,7 @@ def test_fetch_last_ai_and_tool_messages_no_ai_message() -> None:
The function now returns None for the AIMessage, allowing callers to
handle this edge case explicitly.
"""
messages = [
messages: list[AnyMessage] = [
HumanMessage(content="Hello"),
SystemMessage(content="You are a helpful assistant"),
]
@@ -70,7 +70,7 @@ def test_fetch_last_ai_and_tool_messages_empty_list() -> None:
This can occur after RemoveMessage(id=REMOVE_ALL_MESSAGES) clears all messages.
"""
messages: list = []
messages: list[AnyMessage] = []
ai_msg, tool_msgs = _fetch_last_ai_and_tool_messages(messages)
@@ -81,7 +81,7 @@ def test_fetch_last_ai_and_tool_messages_empty_list() -> None:
def test_fetch_last_ai_and_tool_messages_only_human_messages() -> None:
"""Test handling when only HumanMessages exist."""
messages = [
messages: list[AnyMessage] = [
HumanMessage(content="Hello"),
HumanMessage(content="Are you there?"),
]
@@ -94,7 +94,7 @@ def test_fetch_last_ai_and_tool_messages_only_human_messages() -> None:
def test_fetch_last_ai_and_tool_messages_ai_without_tool_calls() -> None:
"""Test AIMessage without tool_calls returns empty tool messages list."""
messages = [
messages: list[AnyMessage] = [
HumanMessage(content="Hello"),
AIMessage(content="Hi! How can I help you today?"),
]

View File

@@ -25,7 +25,7 @@ from langgraph.store.memory import InMemoryStore
from typing_extensions import override
from langchain.agents import create_agent
from langchain.agents.middleware.types import AgentMiddleware, AgentState
from langchain.agents.middleware.types import AgentMiddleware, AgentState, InputAgentState
from langchain.tools import InjectedState, ToolRuntime
from tests.unit_tests.agents.model import FakeToolCallingModel
@@ -303,6 +303,9 @@ def test_tool_runtime_with_custom_state() -> None:
class CustomState(AgentState[Any]):
custom_field: str
class CustomInputState(InputAgentState):
custom_field: str
runtime_state = {}
@tool
@@ -327,10 +330,10 @@ def test_tool_runtime_with_custom_state() -> None:
)
result = agent.invoke(
{
"messages": [HumanMessage("Test custom state")],
"custom_field": "custom_value",
}
CustomInputState(
messages=[HumanMessage("Test custom state")],
custom_field="custom_value",
)
)
# Verify custom field was accessible
@@ -622,6 +625,10 @@ def test_combined_injected_state_runtime_store() -> None:
user_id: str
session_id: str
class CustomInputState(InputAgentState):
user_id: str
session_id: str
# Define explicit args schema that only includes LLM-controlled parameters
weather_schema = {
"type": "object",
@@ -692,11 +699,11 @@ def test_combined_injected_state_runtime_store() -> None:
# Invoke with custom state fields
result = agent.invoke(
{
"messages": [HumanMessage("What's the weather like?")],
"user_id": "user_42",
"session_id": "session_abc123",
}
CustomInputState(
messages=[HumanMessage("What's the weather like?")],
user_id="user_42",
session_id="session_abc123",
)
)
# Verify tool executed successfully
@@ -740,6 +747,10 @@ async def test_combined_injected_state_runtime_store_async() -> None:
api_key: str
request_id: str
class CustomInputState(InputAgentState):
api_key: str
request_id: str
# Define explicit args schema that only includes LLM-controlled parameters
# Note: state, runtime, and store are NOT in this schema
search_schema = {
@@ -819,11 +830,11 @@ async def test_combined_injected_state_runtime_store_async() -> None:
# Invoke async
result = await agent.ainvoke(
{
"messages": [HumanMessage("Search for something")],
"api_key": "sk-test-key-xyz",
"request_id": "req_999",
}
CustomInputState(
messages=[HumanMessage("Search for something")],
api_key="sk-test-key-xyz",
request_id="req_999",
)
)
# Verify tool executed successfully

View File

@@ -16,8 +16,8 @@ class ArgsSchema(BaseModel):
"""Args schema with config and runtime fields."""
query: str = Field(description="The query")
config: dict | None = Field(default=None)
runtime: dict | None = Field(default=None)
config: dict[str, Any] | None = Field(default=None)
runtime: dict[str, Any] | None = Field(default=None)
def test_config_and_runtime_not_injected_to_kwargs() -> None:

View File

@@ -905,7 +905,7 @@ class TestSupportsProviderStrategy:
"""Unit tests for `_supports_provider_strategy`."""
@staticmethod
def _make_structured_model(model_name: str):
def _make_structured_model(model_name: str) -> GenericFakeChatModel:
class GeminiTestChatModel(GenericFakeChatModel):
model_name: str

View File

@@ -118,7 +118,7 @@ def test_responses_integration_matrix(case: TestCase) -> None:
# Unwrap nested schema objects
response_format_spec = [item.get("schema", item) for item in response_format_spec]
if len(response_format_spec) == 1:
tool_output = ToolStrategy(response_format_spec[0])
tool_output = ToolStrategy[Any](response_format_spec[0])
else:
tool_output = ToolStrategy({"oneOf": response_format_spec})

View File

@@ -16,6 +16,7 @@ from langchain.agents import create_agent, factory
from langchain.agents.middleware.types import (
AgentMiddleware,
AgentState,
InputAgentState,
PrivateStateAttr,
)
@@ -40,6 +41,9 @@ def test_state_schema_single_custom_field() -> None:
class CustomState(AgentState[Any]):
custom_field: str
class CustomInputState(InputAgentState):
custom_field: str
agent = create_agent(
model=FakeToolCallingModel(
tool_calls=[[{"args": {"x": 1}, "id": "call_1", "name": "simple_tool"}], []]
@@ -48,7 +52,9 @@ def test_state_schema_single_custom_field() -> None:
state_schema=CustomState,
)
result = agent.invoke({"messages": [HumanMessage("Test")], "custom_field": "test_value"})
result = agent.invoke(
CustomInputState(messages=[HumanMessage("Test")], custom_field="test_value")
)
assert result["custom_field"] == "test_value"
assert len(result["messages"]) == 4
@@ -62,6 +68,11 @@ def test_state_schema_multiple_custom_fields() -> None:
session_id: str
context: str
class CustomInputState(InputAgentState):
user_id: str
session_id: str
context: str
agent = create_agent(
model=FakeToolCallingModel(
tool_calls=[[{"args": {"x": 1}, "id": "call_1", "name": "simple_tool"}], []]
@@ -71,12 +82,12 @@ def test_state_schema_multiple_custom_fields() -> None:
)
result = agent.invoke(
{
"messages": [HumanMessage("Test")],
"user_id": "user_123",
"session_id": "session_456",
"context": "test_ctx",
}
CustomInputState(
messages=[HumanMessage("Test")],
user_id="user_123",
session_id="session_456",
context="test_ctx",
)
)
assert result["user_id"] == "user_123"
@@ -91,6 +102,9 @@ def test_state_schema_with_tool_runtime() -> None:
class ExtendedState(AgentState[Any]):
counter: int
class ExtendedInputState(InputAgentState):
counter: int
runtime_data = {}
@tool
@@ -107,7 +121,7 @@ def test_state_schema_with_tool_runtime() -> None:
state_schema=ExtendedState,
)
result = agent.invoke({"messages": [HumanMessage("Test")], "counter": 5})
result = agent.invoke(ExtendedInputState(messages=[HumanMessage("Test")], counter=5))
assert runtime_data["counter"] == 5
assert "Counter is 5" in result["messages"][2].content
@@ -122,6 +136,10 @@ def test_state_schema_with_middleware() -> None:
class MiddlewareState(AgentState[Any]):
middleware_data: str
class UserAndMiddlewareInputState(InputAgentState):
user_name: str
middleware_data: str
middleware_calls = []
class TestMiddleware(AgentMiddleware[MiddlewareState, None]):
@@ -142,11 +160,11 @@ def test_state_schema_with_middleware() -> None:
)
result = agent.invoke(
{
"messages": [HumanMessage("Test")],
"user_name": "Alice",
"middleware_data": "test_data",
}
UserAndMiddlewareInputState(
messages=[HumanMessage("Test")],
user_name="Alice",
middleware_data="test_data",
)
)
assert result["user_name"] == "Alice"
@@ -176,6 +194,9 @@ async def test_state_schema_async() -> None:
class AsyncState(AgentState[Any]):
async_field: str
class AsyncInputState(InputAgentState):
async_field: str
@tool
async def async_tool(x: int) -> str:
"""Async tool."""
@@ -190,10 +211,10 @@ async def test_state_schema_async() -> None:
)
result = await agent.ainvoke(
{
"messages": [HumanMessage("Test async")],
"async_field": "async_value",
}
AsyncInputState(
messages=[HumanMessage("Test async")],
async_field="async_value",
)
)
assert result["async_field"] == "async_value"
@@ -213,6 +234,10 @@ def test_state_schema_with_private_state_field() -> None:
public_field: str
private_field: Annotated[str, PrivateStateAttr]
class InputStateWithPrivateField(InputAgentState):
public_field: str
private_field: Annotated[str, PrivateStateAttr]
captured_state = {}
@tool
@@ -234,11 +259,11 @@ def test_state_schema_with_private_state_field() -> None:
# Invoke the agent with BOTH public and private fields
result = agent.invoke(
{
"messages": [HumanMessage("Test private state")],
"public_field": "public_value",
"private_field": "private_value", # This should be filtered out
}
InputStateWithPrivateField(
messages=[HumanMessage("Test private state")],
public_field="public_value",
private_field="private_value", # This should be filtered out
)
)
# Assert that public_field is preserved in the result

View File

@@ -8,14 +8,21 @@ tools were dropped during `stream(..., subgraphs=True)`.
from __future__ import annotations
from langchain_core.messages import HumanMessage, ToolCall
from langchain_core.tools import tool
from typing import TYPE_CHECKING, Any, cast
from langchain.agents import create_agent
from langchain_core.messages import HumanMessage, ToolCall
from langchain_core.tools import BaseTool, tool
from langchain.agents import AgentState, create_agent
from tests.unit_tests.agents.model import FakeToolCallingModel
if TYPE_CHECKING:
from langgraph.graph.state import CompiledStateGraph
def _make_subagent_caller_tool():
from langchain.agents.middleware import InputAgentState, OutputAgentState
def _make_subagent_caller_tool() -> BaseTool:
"""Build a subagent and a tool that invokes it."""
subagent = create_agent(
model=FakeToolCallingModel(tool_calls=[[]]),
@@ -26,12 +33,18 @@ def _make_subagent_caller_tool():
def call_subagent(query: str) -> str:
"""Delegate the query to a sub-agent."""
result = subagent.invoke({"messages": [HumanMessage(query)]})
return result["messages"][-1].text
# `invoke()` returns an untyped state, so `.text` is `Any`; it is really a
# `str` (`TextAccessor`), so narrow it to satisfy the `-> str` return type.
return cast("str", result["messages"][-1].text)
return call_subagent
def _make_parent_agent(call_subagent_tool) -> object:
# Return type mirrors `create_agent`'s overload; the context slot is unparameterized
# here, so it resolves to `None` (the `ContextT` default).
def _make_parent_agent(
call_subagent_tool: BaseTool,
) -> CompiledStateGraph[AgentState[Any], None, InputAgentState, OutputAgentState[Any]]:
parent_tool_calls: list[list[ToolCall]] = [
[{"args": {"query": "hi"}, "id": "call_1", "name": "call_subagent"}],
[],

View File

@@ -9,6 +9,7 @@ dispatches a nested `create_agent` from a tool, giving true end-to-end coverage.
from __future__ import annotations
import sys
from typing import cast
import pytest
from langchain_core.messages import HumanMessage
@@ -38,7 +39,9 @@ def test_subagents_surfaces_named_subagent() -> None:
def call_weather(city: str) -> str:
"""Call the weather agent."""
result = weather_agent.invoke({"messages": [HumanMessage(f"weather in {city}")]})
return result["messages"][-1].text
# `invoke()` returns an untyped state, so `.text` is `Any`; it is really a
# `str` (`TextAccessor`), so narrow it to satisfy the `-> str` return type.
return cast("str", result["messages"][-1].text)
supervisor = create_agent(
model=_supervisor_model(),
@@ -85,7 +88,9 @@ async def test_subagents_surfaces_named_subagent_async() -> None:
async def call_weather(city: str) -> str:
"""Call the weather agent."""
result = await weather_agent.ainvoke({"messages": [HumanMessage(f"weather in {city}")]})
return result["messages"][-1].text
# `invoke()` returns an untyped state, so `.text` is `Any`; it is really a
# `str` (`TextAccessor`), so narrow it to satisfy the `-> str` return type.
return cast("str", result["messages"][-1].text)
supervisor = create_agent(
model=_supervisor_model(),
@@ -148,7 +153,9 @@ def test_unnamed_inner_agent_surfaces_with_inherited_name() -> None:
def call_weather(city: str) -> str:
"""Call an unnamed inner agent."""
result = inner_agent.invoke({"messages": [HumanMessage(f"weather in {city}")]})
return result["messages"][-1].text
# `invoke()` returns an untyped state, so `.text` is `Any`; it is really a
# `str` (`TextAccessor`), so narrow it to satisfy the `-> str` return type.
return cast("str", result["messages"][-1].text)
supervisor = create_agent(
model=_supervisor_model(),
@@ -183,7 +190,9 @@ def test_same_name_nested_agent_surfaced() -> None:
def call_weather(city: str) -> str:
"""Call a same-named inner agent."""
result = inner_agent.invoke({"messages": [HumanMessage(f"weather in {city}")]})
return result["messages"][-1].text
# `invoke()` returns an untyped state, so `.text` is `Any`; it is really a
# `str` (`TextAccessor`), so narrow it to satisfy the `-> str` return type.
return cast("str", result["messages"][-1].text)
# The parent agent shares the inner agent's name.
supervisor = create_agent(

View File

@@ -765,7 +765,7 @@ class TestDynamicSystemPromptMiddleware:
def test_middleware_can_return_system_message(self) -> None:
"""Test that middleware can return a SystemMessage with dynamic content."""
def dynamic_system_prompt_middleware(request: ModelRequest) -> SystemMessage:
def dynamic_system_prompt_middleware(request: ModelRequest[Any]) -> SystemMessage:
"""Return a SystemMessage with dynamic content."""
region = getattr(request.runtime.context, "region", "n/a")
return SystemMessage(content=f"You are a helpful assistant. Region: {region}")
@@ -786,7 +786,6 @@ class TestDynamicSystemPromptMiddleware:
runtime=runtime,
model_settings={},
)
new_system_message = dynamic_system_prompt_middleware(request)
assert isinstance(new_system_message, SystemMessage)