mirror of
https://github.com/hwchase17/langchain.git
synced 2026-07-01 14:47:02 +00:00
chore(langchain): improve typing in tests (#38163)
Co-authored-by: Mason Daugherty <github@mdrxy.com> Co-authored-by: Mason Daugherty <mason@langchain.dev>
This commit is contained in:
committed by
GitHub
parent
5f0abc1152
commit
5d044cd326
@@ -111,10 +111,7 @@ enable_error_code = "deprecated"
|
||||
warn_unreachable = true
|
||||
|
||||
exclude = [
|
||||
# Exclude agents tests except middleware_typing/ which has type-checked tests
|
||||
"tests/unit_tests/agents/middleware/",
|
||||
"tests/unit_tests/agents/specifications/",
|
||||
"tests/unit_tests/agents/test_.*\\.py",
|
||||
]
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
@@ -147,6 +144,7 @@ unfixable = [
|
||||
]
|
||||
|
||||
flake8-annotations.allow-star-arg-any = true
|
||||
flake8-annotations.mypy-init-return = true
|
||||
allowed-confusables = ["–"]
|
||||
|
||||
[tool.ruff.lint.flake8-tidy-imports]
|
||||
@@ -157,13 +155,7 @@ convention = "google"
|
||||
ignore-var-parameters = true # ignore missing documentation for *args and **kwargs parameters
|
||||
|
||||
[tool.ruff.lint.extend-per-file-ignores]
|
||||
"tests/unit_tests/agents/*" = [
|
||||
"ANN", # Annotations, needs to fix
|
||||
]
|
||||
"tests/unit_tests/agents/test_responses_spec.py" = ["F821"]
|
||||
"tests/unit_tests/agents/test_return_direct_spec.py" = ["F821"]
|
||||
"tests/unit_tests/agents/test_react_agent.py" = ["ALL"]
|
||||
|
||||
"tests/*" = [
|
||||
"D1", # Documentation rules
|
||||
"S101", # Tests need assertions
|
||||
|
||||
@@ -5,6 +5,8 @@ This module tests that the name parameter correctly sets .name on AIMessage outp
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, cast
|
||||
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
HumanMessage,
|
||||
@@ -119,7 +121,9 @@ def test_lc_agent_name_in_stream_metadata() -> None:
|
||||
stream_mode="messages",
|
||||
):
|
||||
if "lc_agent_name" in metadata:
|
||||
metadata_with_agent_name.append(metadata["lc_agent_name"])
|
||||
# `stream()` is typed `Iterator[dict[str, Any] | Any]`, so unpacking the
|
||||
# tuple leaves `metadata` as `str | Any`; cast to index it as the dict it is.
|
||||
metadata_with_agent_name.append(cast("dict[str, Any]", metadata)["lc_agent_name"])
|
||||
|
||||
assert len(metadata_with_agent_name) > 0
|
||||
assert all(name == "streaming_agent" for name in metadata_with_agent_name)
|
||||
@@ -155,7 +159,9 @@ def test_lc_agent_name_in_stream_metadata_multiple_iterations() -> None:
|
||||
stream_mode="messages",
|
||||
):
|
||||
if "lc_agent_name" in metadata:
|
||||
metadata_with_agent_name.append(metadata["lc_agent_name"])
|
||||
# `stream()` is typed `Iterator[dict[str, Any] | Any]`, so unpacking the
|
||||
# tuple leaves `metadata` as `str | Any`; cast to index it as the dict it is.
|
||||
metadata_with_agent_name.append(cast("dict[str, Any]", metadata)["lc_agent_name"])
|
||||
|
||||
# Should have metadata entries for messages from both iterations
|
||||
assert len(metadata_with_agent_name) > 0
|
||||
@@ -176,7 +182,9 @@ async def test_lc_agent_name_in_astream_metadata() -> None:
|
||||
stream_mode="messages",
|
||||
):
|
||||
if "lc_agent_name" in metadata:
|
||||
metadata_with_agent_name.append(metadata["lc_agent_name"])
|
||||
# `stream()` is typed `Iterator[dict[str, Any] | Any]`, so unpacking the
|
||||
# tuple leaves `metadata` as `str | Any`; cast to index it as the dict it is.
|
||||
metadata_with_agent_name.append(cast("dict[str, Any]", metadata)["lc_agent_name"])
|
||||
|
||||
assert len(metadata_with_agent_name) > 0
|
||||
assert all(name == "async_streaming_agent" for name in metadata_with_agent_name)
|
||||
@@ -212,7 +220,9 @@ async def test_lc_agent_name_in_astream_metadata_multiple_iterations() -> None:
|
||||
stream_mode="messages",
|
||||
):
|
||||
if "lc_agent_name" in metadata:
|
||||
metadata_with_agent_name.append(metadata["lc_agent_name"])
|
||||
# `stream()` is typed `Iterator[dict[str, Any] | Any]`, so unpacking the
|
||||
# tuple leaves `metadata` as `str | Any`; cast to index it as the dict it is.
|
||||
metadata_with_agent_name.append(cast("dict[str, Any]", metadata)["lc_agent_name"])
|
||||
|
||||
# Should have metadata entries for messages from both iterations
|
||||
assert len(metadata_with_agent_name) > 0
|
||||
|
||||
@@ -113,7 +113,7 @@ class TestAgentStreamV3Sync:
|
||||
# by `BaseRunStream.__init__` whenever `MessagesTransformer` is
|
||||
# registered. Content population is covered by langgraph tests —
|
||||
# here we only assert the agent streamer inherits the built-in.
|
||||
assert "messages" in run._mux.extensions # type: ignore[attr-defined]
|
||||
assert "messages" in run._mux.extensions
|
||||
assert hasattr(run, "messages")
|
||||
# Drain so the run closes cleanly.
|
||||
for tc in run.tool_calls:
|
||||
@@ -145,14 +145,14 @@ class TestAgentStreamV3Sync:
|
||||
transformers=[_Marker],
|
||||
)
|
||||
# Both the agent default and the user transformer are registered.
|
||||
assert "tool_calls" in run._mux.extensions # type: ignore[attr-defined]
|
||||
assert "marker" in run._mux.extensions # type: ignore[attr-defined]
|
||||
assert "tool_calls" in run._mux.extensions
|
||||
assert "marker" in run._mux.extensions
|
||||
|
||||
# `ToolCallTransformer` must come BEFORE the user's transformer in
|
||||
# the registration order so it processes `tools` events first. The
|
||||
# docstring on `create_agent` promises caller transformers are
|
||||
# appended after the built-in.
|
||||
transformers = run._mux._transformers # type: ignore[attr-defined]
|
||||
transformers = run._mux._transformers
|
||||
tool_call_idx = next(
|
||||
i for i, t in enumerate(transformers) if isinstance(t, ToolCallTransformer)
|
||||
)
|
||||
|
||||
@@ -9,6 +9,7 @@ from langgraph.store.base import BaseStore
|
||||
from langgraph.store.memory import InMemoryStore
|
||||
|
||||
from langchain.agents import AgentState, create_agent
|
||||
from langchain.agents.middleware import InputAgentState
|
||||
from langchain.tools import InjectedState
|
||||
from langchain.tools import tool as dec_tool
|
||||
from tests.unit_tests.agents.model import FakeToolCallingModel
|
||||
@@ -37,6 +38,9 @@ def test_tool_invocation_error_excludes_injected_state() -> None:
|
||||
class TestState(AgentState[Any]):
|
||||
secret_data: str
|
||||
|
||||
class TestInputState(InputAgentState):
|
||||
secret_data: str
|
||||
|
||||
@dec_tool
|
||||
def tool_with_injected_state(
|
||||
some_val: int,
|
||||
@@ -66,10 +70,10 @@ def test_tool_invocation_error_excludes_injected_state() -> None:
|
||||
)
|
||||
|
||||
result = agent.invoke(
|
||||
{
|
||||
"messages": [HumanMessage("Test message")],
|
||||
"secret_data": "sensitive_secret_123",
|
||||
}
|
||||
TestInputState(
|
||||
messages=[HumanMessage("Test message")],
|
||||
secret_data="sensitive_secret_123", # noqa: S106
|
||||
)
|
||||
)
|
||||
|
||||
tool_messages = [m for m in result["messages"] if m.type == "tool"]
|
||||
@@ -96,6 +100,9 @@ async def test_tool_invocation_error_excludes_injected_state_async() -> None:
|
||||
class TestState(AgentState[Any]):
|
||||
internal_data: str
|
||||
|
||||
class TestInputState(InputAgentState):
|
||||
internal_data: str
|
||||
|
||||
@dec_tool
|
||||
async def async_tool_with_injected_state(
|
||||
query: str,
|
||||
@@ -126,10 +133,9 @@ async def test_tool_invocation_error_excludes_injected_state_async() -> None:
|
||||
)
|
||||
|
||||
result = await agent.ainvoke(
|
||||
{
|
||||
"messages": [HumanMessage("Test async")],
|
||||
"internal_data": "secret_internal_value_xyz",
|
||||
}
|
||||
TestInputState(
|
||||
messages=[HumanMessage("Test async")], internal_data="secret_internal_value_xyz"
|
||||
)
|
||||
)
|
||||
|
||||
tool_messages = [m for m in result["messages"] if m.type == "tool"]
|
||||
@@ -342,6 +348,9 @@ def test_create_agent_error_only_model_controllable_params() -> None:
|
||||
class StateWithSecrets(AgentState[Any]):
|
||||
password: str
|
||||
|
||||
class InputStateWithSecrets(InputAgentState):
|
||||
password: str
|
||||
|
||||
@dec_tool
|
||||
def secure_tool(
|
||||
username: str,
|
||||
@@ -381,10 +390,10 @@ def test_create_agent_error_only_model_controllable_params() -> None:
|
||||
)
|
||||
|
||||
result = agent.invoke(
|
||||
{
|
||||
"messages": [HumanMessage("Create account")],
|
||||
"password": "super_secret_password_12345",
|
||||
}
|
||||
InputStateWithSecrets(
|
||||
messages=[HumanMessage("Create account")],
|
||||
password="super_secret_password_12345", # noqa: S106
|
||||
)
|
||||
)
|
||||
|
||||
tool_messages = [m for m in result["messages"] if m.type == "tool"]
|
||||
|
||||
@@ -5,14 +5,14 @@ including the scenario where no AIMessage exists in the message list
|
||||
(fixes issue #34792).
|
||||
"""
|
||||
|
||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
||||
from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, SystemMessage, ToolMessage
|
||||
|
||||
from langchain.agents.factory import _fetch_last_ai_and_tool_messages
|
||||
|
||||
|
||||
def test_fetch_last_ai_and_tool_messages_normal() -> None:
|
||||
"""Test normal case with AIMessage and subsequent ToolMessages."""
|
||||
messages = [
|
||||
messages: list[AnyMessage] = [
|
||||
HumanMessage(content="Hello"),
|
||||
AIMessage(content="Hi there!", tool_calls=[{"name": "test", "id": "1", "args": {}}]),
|
||||
ToolMessage(content="Tool result", tool_call_id="1"),
|
||||
@@ -29,7 +29,7 @@ def test_fetch_last_ai_and_tool_messages_normal() -> None:
|
||||
|
||||
def test_fetch_last_ai_and_tool_messages_multiple_ai() -> None:
|
||||
"""Test that the last AIMessage is returned when multiple exist."""
|
||||
messages = [
|
||||
messages: list[AnyMessage] = [
|
||||
HumanMessage(content="First question"),
|
||||
AIMessage(content="First answer", id="ai1"),
|
||||
HumanMessage(content="Second question"),
|
||||
@@ -53,7 +53,7 @@ def test_fetch_last_ai_and_tool_messages_no_ai_message() -> None:
|
||||
The function now returns None for the AIMessage, allowing callers to
|
||||
handle this edge case explicitly.
|
||||
"""
|
||||
messages = [
|
||||
messages: list[AnyMessage] = [
|
||||
HumanMessage(content="Hello"),
|
||||
SystemMessage(content="You are a helpful assistant"),
|
||||
]
|
||||
@@ -70,7 +70,7 @@ def test_fetch_last_ai_and_tool_messages_empty_list() -> None:
|
||||
|
||||
This can occur after RemoveMessage(id=REMOVE_ALL_MESSAGES) clears all messages.
|
||||
"""
|
||||
messages: list = []
|
||||
messages: list[AnyMessage] = []
|
||||
|
||||
ai_msg, tool_msgs = _fetch_last_ai_and_tool_messages(messages)
|
||||
|
||||
@@ -81,7 +81,7 @@ def test_fetch_last_ai_and_tool_messages_empty_list() -> None:
|
||||
|
||||
def test_fetch_last_ai_and_tool_messages_only_human_messages() -> None:
|
||||
"""Test handling when only HumanMessages exist."""
|
||||
messages = [
|
||||
messages: list[AnyMessage] = [
|
||||
HumanMessage(content="Hello"),
|
||||
HumanMessage(content="Are you there?"),
|
||||
]
|
||||
@@ -94,7 +94,7 @@ def test_fetch_last_ai_and_tool_messages_only_human_messages() -> None:
|
||||
|
||||
def test_fetch_last_ai_and_tool_messages_ai_without_tool_calls() -> None:
|
||||
"""Test AIMessage without tool_calls returns empty tool messages list."""
|
||||
messages = [
|
||||
messages: list[AnyMessage] = [
|
||||
HumanMessage(content="Hello"),
|
||||
AIMessage(content="Hi! How can I help you today?"),
|
||||
]
|
||||
|
||||
@@ -25,7 +25,7 @@ from langgraph.store.memory import InMemoryStore
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain.agents import create_agent
|
||||
from langchain.agents.middleware.types import AgentMiddleware, AgentState
|
||||
from langchain.agents.middleware.types import AgentMiddleware, AgentState, InputAgentState
|
||||
from langchain.tools import InjectedState, ToolRuntime
|
||||
from tests.unit_tests.agents.model import FakeToolCallingModel
|
||||
|
||||
@@ -303,6 +303,9 @@ def test_tool_runtime_with_custom_state() -> None:
|
||||
class CustomState(AgentState[Any]):
|
||||
custom_field: str
|
||||
|
||||
class CustomInputState(InputAgentState):
|
||||
custom_field: str
|
||||
|
||||
runtime_state = {}
|
||||
|
||||
@tool
|
||||
@@ -327,10 +330,10 @@ def test_tool_runtime_with_custom_state() -> None:
|
||||
)
|
||||
|
||||
result = agent.invoke(
|
||||
{
|
||||
"messages": [HumanMessage("Test custom state")],
|
||||
"custom_field": "custom_value",
|
||||
}
|
||||
CustomInputState(
|
||||
messages=[HumanMessage("Test custom state")],
|
||||
custom_field="custom_value",
|
||||
)
|
||||
)
|
||||
|
||||
# Verify custom field was accessible
|
||||
@@ -622,6 +625,10 @@ def test_combined_injected_state_runtime_store() -> None:
|
||||
user_id: str
|
||||
session_id: str
|
||||
|
||||
class CustomInputState(InputAgentState):
|
||||
user_id: str
|
||||
session_id: str
|
||||
|
||||
# Define explicit args schema that only includes LLM-controlled parameters
|
||||
weather_schema = {
|
||||
"type": "object",
|
||||
@@ -692,11 +699,11 @@ def test_combined_injected_state_runtime_store() -> None:
|
||||
|
||||
# Invoke with custom state fields
|
||||
result = agent.invoke(
|
||||
{
|
||||
"messages": [HumanMessage("What's the weather like?")],
|
||||
"user_id": "user_42",
|
||||
"session_id": "session_abc123",
|
||||
}
|
||||
CustomInputState(
|
||||
messages=[HumanMessage("What's the weather like?")],
|
||||
user_id="user_42",
|
||||
session_id="session_abc123",
|
||||
)
|
||||
)
|
||||
|
||||
# Verify tool executed successfully
|
||||
@@ -740,6 +747,10 @@ async def test_combined_injected_state_runtime_store_async() -> None:
|
||||
api_key: str
|
||||
request_id: str
|
||||
|
||||
class CustomInputState(InputAgentState):
|
||||
api_key: str
|
||||
request_id: str
|
||||
|
||||
# Define explicit args schema that only includes LLM-controlled parameters
|
||||
# Note: state, runtime, and store are NOT in this schema
|
||||
search_schema = {
|
||||
@@ -819,11 +830,11 @@ async def test_combined_injected_state_runtime_store_async() -> None:
|
||||
|
||||
# Invoke async
|
||||
result = await agent.ainvoke(
|
||||
{
|
||||
"messages": [HumanMessage("Search for something")],
|
||||
"api_key": "sk-test-key-xyz",
|
||||
"request_id": "req_999",
|
||||
}
|
||||
CustomInputState(
|
||||
messages=[HumanMessage("Search for something")],
|
||||
api_key="sk-test-key-xyz",
|
||||
request_id="req_999",
|
||||
)
|
||||
)
|
||||
|
||||
# Verify tool executed successfully
|
||||
|
||||
@@ -16,8 +16,8 @@ class ArgsSchema(BaseModel):
|
||||
"""Args schema with config and runtime fields."""
|
||||
|
||||
query: str = Field(description="The query")
|
||||
config: dict | None = Field(default=None)
|
||||
runtime: dict | None = Field(default=None)
|
||||
config: dict[str, Any] | None = Field(default=None)
|
||||
runtime: dict[str, Any] | None = Field(default=None)
|
||||
|
||||
|
||||
def test_config_and_runtime_not_injected_to_kwargs() -> None:
|
||||
|
||||
@@ -905,7 +905,7 @@ class TestSupportsProviderStrategy:
|
||||
"""Unit tests for `_supports_provider_strategy`."""
|
||||
|
||||
@staticmethod
|
||||
def _make_structured_model(model_name: str):
|
||||
def _make_structured_model(model_name: str) -> GenericFakeChatModel:
|
||||
class GeminiTestChatModel(GenericFakeChatModel):
|
||||
model_name: str
|
||||
|
||||
|
||||
@@ -118,7 +118,7 @@ def test_responses_integration_matrix(case: TestCase) -> None:
|
||||
# Unwrap nested schema objects
|
||||
response_format_spec = [item.get("schema", item) for item in response_format_spec]
|
||||
if len(response_format_spec) == 1:
|
||||
tool_output = ToolStrategy(response_format_spec[0])
|
||||
tool_output = ToolStrategy[Any](response_format_spec[0])
|
||||
else:
|
||||
tool_output = ToolStrategy({"oneOf": response_format_spec})
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ from langchain.agents import create_agent, factory
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
InputAgentState,
|
||||
PrivateStateAttr,
|
||||
)
|
||||
|
||||
@@ -40,6 +41,9 @@ def test_state_schema_single_custom_field() -> None:
|
||||
class CustomState(AgentState[Any]):
|
||||
custom_field: str
|
||||
|
||||
class CustomInputState(InputAgentState):
|
||||
custom_field: str
|
||||
|
||||
agent = create_agent(
|
||||
model=FakeToolCallingModel(
|
||||
tool_calls=[[{"args": {"x": 1}, "id": "call_1", "name": "simple_tool"}], []]
|
||||
@@ -48,7 +52,9 @@ def test_state_schema_single_custom_field() -> None:
|
||||
state_schema=CustomState,
|
||||
)
|
||||
|
||||
result = agent.invoke({"messages": [HumanMessage("Test")], "custom_field": "test_value"})
|
||||
result = agent.invoke(
|
||||
CustomInputState(messages=[HumanMessage("Test")], custom_field="test_value")
|
||||
)
|
||||
|
||||
assert result["custom_field"] == "test_value"
|
||||
assert len(result["messages"]) == 4
|
||||
@@ -62,6 +68,11 @@ def test_state_schema_multiple_custom_fields() -> None:
|
||||
session_id: str
|
||||
context: str
|
||||
|
||||
class CustomInputState(InputAgentState):
|
||||
user_id: str
|
||||
session_id: str
|
||||
context: str
|
||||
|
||||
agent = create_agent(
|
||||
model=FakeToolCallingModel(
|
||||
tool_calls=[[{"args": {"x": 1}, "id": "call_1", "name": "simple_tool"}], []]
|
||||
@@ -71,12 +82,12 @@ def test_state_schema_multiple_custom_fields() -> None:
|
||||
)
|
||||
|
||||
result = agent.invoke(
|
||||
{
|
||||
"messages": [HumanMessage("Test")],
|
||||
"user_id": "user_123",
|
||||
"session_id": "session_456",
|
||||
"context": "test_ctx",
|
||||
}
|
||||
CustomInputState(
|
||||
messages=[HumanMessage("Test")],
|
||||
user_id="user_123",
|
||||
session_id="session_456",
|
||||
context="test_ctx",
|
||||
)
|
||||
)
|
||||
|
||||
assert result["user_id"] == "user_123"
|
||||
@@ -91,6 +102,9 @@ def test_state_schema_with_tool_runtime() -> None:
|
||||
class ExtendedState(AgentState[Any]):
|
||||
counter: int
|
||||
|
||||
class ExtendedInputState(InputAgentState):
|
||||
counter: int
|
||||
|
||||
runtime_data = {}
|
||||
|
||||
@tool
|
||||
@@ -107,7 +121,7 @@ def test_state_schema_with_tool_runtime() -> None:
|
||||
state_schema=ExtendedState,
|
||||
)
|
||||
|
||||
result = agent.invoke({"messages": [HumanMessage("Test")], "counter": 5})
|
||||
result = agent.invoke(ExtendedInputState(messages=[HumanMessage("Test")], counter=5))
|
||||
|
||||
assert runtime_data["counter"] == 5
|
||||
assert "Counter is 5" in result["messages"][2].content
|
||||
@@ -122,6 +136,10 @@ def test_state_schema_with_middleware() -> None:
|
||||
class MiddlewareState(AgentState[Any]):
|
||||
middleware_data: str
|
||||
|
||||
class UserAndMiddlewareInputState(InputAgentState):
|
||||
user_name: str
|
||||
middleware_data: str
|
||||
|
||||
middleware_calls = []
|
||||
|
||||
class TestMiddleware(AgentMiddleware[MiddlewareState, None]):
|
||||
@@ -142,11 +160,11 @@ def test_state_schema_with_middleware() -> None:
|
||||
)
|
||||
|
||||
result = agent.invoke(
|
||||
{
|
||||
"messages": [HumanMessage("Test")],
|
||||
"user_name": "Alice",
|
||||
"middleware_data": "test_data",
|
||||
}
|
||||
UserAndMiddlewareInputState(
|
||||
messages=[HumanMessage("Test")],
|
||||
user_name="Alice",
|
||||
middleware_data="test_data",
|
||||
)
|
||||
)
|
||||
|
||||
assert result["user_name"] == "Alice"
|
||||
@@ -176,6 +194,9 @@ async def test_state_schema_async() -> None:
|
||||
class AsyncState(AgentState[Any]):
|
||||
async_field: str
|
||||
|
||||
class AsyncInputState(InputAgentState):
|
||||
async_field: str
|
||||
|
||||
@tool
|
||||
async def async_tool(x: int) -> str:
|
||||
"""Async tool."""
|
||||
@@ -190,10 +211,10 @@ async def test_state_schema_async() -> None:
|
||||
)
|
||||
|
||||
result = await agent.ainvoke(
|
||||
{
|
||||
"messages": [HumanMessage("Test async")],
|
||||
"async_field": "async_value",
|
||||
}
|
||||
AsyncInputState(
|
||||
messages=[HumanMessage("Test async")],
|
||||
async_field="async_value",
|
||||
)
|
||||
)
|
||||
|
||||
assert result["async_field"] == "async_value"
|
||||
@@ -213,6 +234,10 @@ def test_state_schema_with_private_state_field() -> None:
|
||||
public_field: str
|
||||
private_field: Annotated[str, PrivateStateAttr]
|
||||
|
||||
class InputStateWithPrivateField(InputAgentState):
|
||||
public_field: str
|
||||
private_field: Annotated[str, PrivateStateAttr]
|
||||
|
||||
captured_state = {}
|
||||
|
||||
@tool
|
||||
@@ -234,11 +259,11 @@ def test_state_schema_with_private_state_field() -> None:
|
||||
|
||||
# Invoke the agent with BOTH public and private fields
|
||||
result = agent.invoke(
|
||||
{
|
||||
"messages": [HumanMessage("Test private state")],
|
||||
"public_field": "public_value",
|
||||
"private_field": "private_value", # This should be filtered out
|
||||
}
|
||||
InputStateWithPrivateField(
|
||||
messages=[HumanMessage("Test private state")],
|
||||
public_field="public_value",
|
||||
private_field="private_value", # This should be filtered out
|
||||
)
|
||||
)
|
||||
|
||||
# Assert that public_field is preserved in the result
|
||||
|
||||
@@ -8,14 +8,21 @@ tools were dropped during `stream(..., subgraphs=True)`.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from langchain_core.messages import HumanMessage, ToolCall
|
||||
from langchain_core.tools import tool
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from langchain.agents import create_agent
|
||||
from langchain_core.messages import HumanMessage, ToolCall
|
||||
from langchain_core.tools import BaseTool, tool
|
||||
|
||||
from langchain.agents import AgentState, create_agent
|
||||
from tests.unit_tests.agents.model import FakeToolCallingModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
|
||||
def _make_subagent_caller_tool():
|
||||
from langchain.agents.middleware import InputAgentState, OutputAgentState
|
||||
|
||||
|
||||
def _make_subagent_caller_tool() -> BaseTool:
|
||||
"""Build a subagent and a tool that invokes it."""
|
||||
subagent = create_agent(
|
||||
model=FakeToolCallingModel(tool_calls=[[]]),
|
||||
@@ -26,12 +33,18 @@ def _make_subagent_caller_tool():
|
||||
def call_subagent(query: str) -> str:
|
||||
"""Delegate the query to a sub-agent."""
|
||||
result = subagent.invoke({"messages": [HumanMessage(query)]})
|
||||
return result["messages"][-1].text
|
||||
# `invoke()` returns an untyped state, so `.text` is `Any`; it is really a
|
||||
# `str` (`TextAccessor`), so narrow it to satisfy the `-> str` return type.
|
||||
return cast("str", result["messages"][-1].text)
|
||||
|
||||
return call_subagent
|
||||
|
||||
|
||||
def _make_parent_agent(call_subagent_tool) -> object:
|
||||
# Return type mirrors `create_agent`'s overload; the context slot is unparameterized
|
||||
# here, so it resolves to `None` (the `ContextT` default).
|
||||
def _make_parent_agent(
|
||||
call_subagent_tool: BaseTool,
|
||||
) -> CompiledStateGraph[AgentState[Any], None, InputAgentState, OutputAgentState[Any]]:
|
||||
parent_tool_calls: list[list[ToolCall]] = [
|
||||
[{"args": {"query": "hi"}, "id": "call_1", "name": "call_subagent"}],
|
||||
[],
|
||||
|
||||
@@ -9,6 +9,7 @@ dispatches a nested `create_agent` from a tool, giving true end-to-end coverage.
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from typing import cast
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import HumanMessage
|
||||
@@ -38,7 +39,9 @@ def test_subagents_surfaces_named_subagent() -> None:
|
||||
def call_weather(city: str) -> str:
|
||||
"""Call the weather agent."""
|
||||
result = weather_agent.invoke({"messages": [HumanMessage(f"weather in {city}")]})
|
||||
return result["messages"][-1].text
|
||||
# `invoke()` returns an untyped state, so `.text` is `Any`; it is really a
|
||||
# `str` (`TextAccessor`), so narrow it to satisfy the `-> str` return type.
|
||||
return cast("str", result["messages"][-1].text)
|
||||
|
||||
supervisor = create_agent(
|
||||
model=_supervisor_model(),
|
||||
@@ -85,7 +88,9 @@ async def test_subagents_surfaces_named_subagent_async() -> None:
|
||||
async def call_weather(city: str) -> str:
|
||||
"""Call the weather agent."""
|
||||
result = await weather_agent.ainvoke({"messages": [HumanMessage(f"weather in {city}")]})
|
||||
return result["messages"][-1].text
|
||||
# `invoke()` returns an untyped state, so `.text` is `Any`; it is really a
|
||||
# `str` (`TextAccessor`), so narrow it to satisfy the `-> str` return type.
|
||||
return cast("str", result["messages"][-1].text)
|
||||
|
||||
supervisor = create_agent(
|
||||
model=_supervisor_model(),
|
||||
@@ -148,7 +153,9 @@ def test_unnamed_inner_agent_surfaces_with_inherited_name() -> None:
|
||||
def call_weather(city: str) -> str:
|
||||
"""Call an unnamed inner agent."""
|
||||
result = inner_agent.invoke({"messages": [HumanMessage(f"weather in {city}")]})
|
||||
return result["messages"][-1].text
|
||||
# `invoke()` returns an untyped state, so `.text` is `Any`; it is really a
|
||||
# `str` (`TextAccessor`), so narrow it to satisfy the `-> str` return type.
|
||||
return cast("str", result["messages"][-1].text)
|
||||
|
||||
supervisor = create_agent(
|
||||
model=_supervisor_model(),
|
||||
@@ -183,7 +190,9 @@ def test_same_name_nested_agent_surfaced() -> None:
|
||||
def call_weather(city: str) -> str:
|
||||
"""Call a same-named inner agent."""
|
||||
result = inner_agent.invoke({"messages": [HumanMessage(f"weather in {city}")]})
|
||||
return result["messages"][-1].text
|
||||
# `invoke()` returns an untyped state, so `.text` is `Any`; it is really a
|
||||
# `str` (`TextAccessor`), so narrow it to satisfy the `-> str` return type.
|
||||
return cast("str", result["messages"][-1].text)
|
||||
|
||||
# The parent agent shares the inner agent's name.
|
||||
supervisor = create_agent(
|
||||
|
||||
@@ -765,7 +765,7 @@ class TestDynamicSystemPromptMiddleware:
|
||||
def test_middleware_can_return_system_message(self) -> None:
|
||||
"""Test that middleware can return a SystemMessage with dynamic content."""
|
||||
|
||||
def dynamic_system_prompt_middleware(request: ModelRequest) -> SystemMessage:
|
||||
def dynamic_system_prompt_middleware(request: ModelRequest[Any]) -> SystemMessage:
|
||||
"""Return a SystemMessage with dynamic content."""
|
||||
region = getattr(request.runtime.context, "region", "n/a")
|
||||
return SystemMessage(content=f"You are a helpful assistant. Region: {region}")
|
||||
@@ -786,7 +786,6 @@ class TestDynamicSystemPromptMiddleware:
|
||||
runtime=runtime,
|
||||
model_settings={},
|
||||
)
|
||||
|
||||
new_system_message = dynamic_system_prompt_middleware(request)
|
||||
|
||||
assert isinstance(new_system_message, SystemMessage)
|
||||
|
||||
Reference in New Issue
Block a user