diff --git a/libs/langchain_v1/pyproject.toml b/libs/langchain_v1/pyproject.toml index 736d7f23a12..3c92ddffea8 100644 --- a/libs/langchain_v1/pyproject.toml +++ b/libs/langchain_v1/pyproject.toml @@ -111,10 +111,7 @@ enable_error_code = "deprecated" warn_unreachable = true exclude = [ - # Exclude agents tests except middleware_typing/ which has type-checked tests "tests/unit_tests/agents/middleware/", - "tests/unit_tests/agents/specifications/", - "tests/unit_tests/agents/test_.*\\.py", ] [[tool.mypy.overrides]] @@ -147,6 +144,7 @@ unfixable = [ ] flake8-annotations.allow-star-arg-any = true +flake8-annotations.mypy-init-return = true allowed-confusables = ["–"] [tool.ruff.lint.flake8-tidy-imports] @@ -157,13 +155,7 @@ convention = "google" ignore-var-parameters = true # ignore missing documentation for *args and **kwargs parameters [tool.ruff.lint.extend-per-file-ignores] -"tests/unit_tests/agents/*" = [ - "ANN", # Annotations, needs to fix -] -"tests/unit_tests/agents/test_responses_spec.py" = ["F821"] -"tests/unit_tests/agents/test_return_direct_spec.py" = ["F821"] "tests/unit_tests/agents/test_react_agent.py" = ["ALL"] - "tests/*" = [ "D1", # Documentation rules "S101", # Tests need assertions diff --git a/libs/langchain_v1/tests/unit_tests/agents/test_agent_name.py b/libs/langchain_v1/tests/unit_tests/agents/test_agent_name.py index bdb4e5b080a..1ff2914a75b 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/test_agent_name.py +++ b/libs/langchain_v1/tests/unit_tests/agents/test_agent_name.py @@ -5,6 +5,8 @@ This module tests that the name parameter correctly sets .name on AIMessage outp from __future__ import annotations +from typing import Any, cast + from langchain_core.messages import ( AIMessage, HumanMessage, @@ -119,7 +121,9 @@ def test_lc_agent_name_in_stream_metadata() -> None: stream_mode="messages", ): if "lc_agent_name" in metadata: - metadata_with_agent_name.append(metadata["lc_agent_name"]) + # `stream()` is typed `Iterator[dict[str, Any] | Any]`, so unpacking the + # tuple leaves `metadata` as `str | Any`; cast to index it as the dict it is. + metadata_with_agent_name.append(cast("dict[str, Any]", metadata)["lc_agent_name"]) assert len(metadata_with_agent_name) > 0 assert all(name == "streaming_agent" for name in metadata_with_agent_name) @@ -155,7 +159,9 @@ def test_lc_agent_name_in_stream_metadata_multiple_iterations() -> None: stream_mode="messages", ): if "lc_agent_name" in metadata: - metadata_with_agent_name.append(metadata["lc_agent_name"]) + # `stream()` is typed `Iterator[dict[str, Any] | Any]`, so unpacking the + # tuple leaves `metadata` as `str | Any`; cast to index it as the dict it is. + metadata_with_agent_name.append(cast("dict[str, Any]", metadata)["lc_agent_name"]) # Should have metadata entries for messages from both iterations assert len(metadata_with_agent_name) > 0 @@ -176,7 +182,9 @@ async def test_lc_agent_name_in_astream_metadata() -> None: stream_mode="messages", ): if "lc_agent_name" in metadata: - metadata_with_agent_name.append(metadata["lc_agent_name"]) + # `stream()` is typed `Iterator[dict[str, Any] | Any]`, so unpacking the + # tuple leaves `metadata` as `str | Any`; cast to index it as the dict it is. + metadata_with_agent_name.append(cast("dict[str, Any]", metadata)["lc_agent_name"]) assert len(metadata_with_agent_name) > 0 assert all(name == "async_streaming_agent" for name in metadata_with_agent_name) @@ -212,7 +220,9 @@ async def test_lc_agent_name_in_astream_metadata_multiple_iterations() -> None: stream_mode="messages", ): if "lc_agent_name" in metadata: - metadata_with_agent_name.append(metadata["lc_agent_name"]) + # `stream()` is typed `Iterator[dict[str, Any] | Any]`, so unpacking the + # tuple leaves `metadata` as `str | Any`; cast to index it as the dict it is. + metadata_with_agent_name.append(cast("dict[str, Any]", metadata)["lc_agent_name"]) # Should have metadata entries for messages from both iterations assert len(metadata_with_agent_name) > 0 diff --git a/libs/langchain_v1/tests/unit_tests/agents/test_agent_streaming.py b/libs/langchain_v1/tests/unit_tests/agents/test_agent_streaming.py index 8b5275137ec..4aa58a3f1ae 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/test_agent_streaming.py +++ b/libs/langchain_v1/tests/unit_tests/agents/test_agent_streaming.py @@ -113,7 +113,7 @@ class TestAgentStreamV3Sync: # by `BaseRunStream.__init__` whenever `MessagesTransformer` is # registered. Content population is covered by langgraph tests — # here we only assert the agent streamer inherits the built-in. - assert "messages" in run._mux.extensions # type: ignore[attr-defined] + assert "messages" in run._mux.extensions assert hasattr(run, "messages") # Drain so the run closes cleanly. for tc in run.tool_calls: @@ -145,14 +145,14 @@ class TestAgentStreamV3Sync: transformers=[_Marker], ) # Both the agent default and the user transformer are registered. - assert "tool_calls" in run._mux.extensions # type: ignore[attr-defined] - assert "marker" in run._mux.extensions # type: ignore[attr-defined] + assert "tool_calls" in run._mux.extensions + assert "marker" in run._mux.extensions # `ToolCallTransformer` must come BEFORE the user's transformer in # the registration order so it processes `tools` events first. The # docstring on `create_agent` promises caller transformers are # appended after the built-in. - transformers = run._mux._transformers # type: ignore[attr-defined] + transformers = run._mux._transformers tool_call_idx = next( i for i, t in enumerate(transformers) if isinstance(t, ToolCallTransformer) ) diff --git a/libs/langchain_v1/tests/unit_tests/agents/test_create_agent_tool_validation.py b/libs/langchain_v1/tests/unit_tests/agents/test_create_agent_tool_validation.py index 10047e4dd30..eab76855871 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/test_create_agent_tool_validation.py +++ b/libs/langchain_v1/tests/unit_tests/agents/test_create_agent_tool_validation.py @@ -9,6 +9,7 @@ from langgraph.store.base import BaseStore from langgraph.store.memory import InMemoryStore from langchain.agents import AgentState, create_agent +from langchain.agents.middleware import InputAgentState from langchain.tools import InjectedState from langchain.tools import tool as dec_tool from tests.unit_tests.agents.model import FakeToolCallingModel @@ -37,6 +38,9 @@ def test_tool_invocation_error_excludes_injected_state() -> None: class TestState(AgentState[Any]): secret_data: str + class TestInputState(InputAgentState): + secret_data: str + @dec_tool def tool_with_injected_state( some_val: int, @@ -66,10 +70,10 @@ def test_tool_invocation_error_excludes_injected_state() -> None: ) result = agent.invoke( - { - "messages": [HumanMessage("Test message")], - "secret_data": "sensitive_secret_123", - } + TestInputState( + messages=[HumanMessage("Test message")], + secret_data="sensitive_secret_123", # noqa: S106 + ) ) tool_messages = [m for m in result["messages"] if m.type == "tool"] @@ -96,6 +100,9 @@ async def test_tool_invocation_error_excludes_injected_state_async() -> None: class TestState(AgentState[Any]): internal_data: str + class TestInputState(InputAgentState): + internal_data: str + @dec_tool async def async_tool_with_injected_state( query: str, @@ -126,10 +133,9 @@ async def test_tool_invocation_error_excludes_injected_state_async() -> None: ) result = await agent.ainvoke( - { - "messages": [HumanMessage("Test async")], - "internal_data": "secret_internal_value_xyz", - } + TestInputState( + messages=[HumanMessage("Test async")], internal_data="secret_internal_value_xyz" + ) ) tool_messages = [m for m in result["messages"] if m.type == "tool"] @@ -342,6 +348,9 @@ def test_create_agent_error_only_model_controllable_params() -> None: class StateWithSecrets(AgentState[Any]): password: str + class InputStateWithSecrets(InputAgentState): + password: str + @dec_tool def secure_tool( username: str, @@ -381,10 +390,10 @@ def test_create_agent_error_only_model_controllable_params() -> None: ) result = agent.invoke( - { - "messages": [HumanMessage("Create account")], - "password": "super_secret_password_12345", - } + InputStateWithSecrets( + messages=[HumanMessage("Create account")], + password="super_secret_password_12345", # noqa: S106 + ) ) tool_messages = [m for m in result["messages"] if m.type == "tool"] diff --git a/libs/langchain_v1/tests/unit_tests/agents/test_fetch_last_ai_and_tool_messages.py b/libs/langchain_v1/tests/unit_tests/agents/test_fetch_last_ai_and_tool_messages.py index bec7adaa93b..a1c9204db32 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/test_fetch_last_ai_and_tool_messages.py +++ b/libs/langchain_v1/tests/unit_tests/agents/test_fetch_last_ai_and_tool_messages.py @@ -5,14 +5,14 @@ including the scenario where no AIMessage exists in the message list (fixes issue #34792). """ -from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage +from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, SystemMessage, ToolMessage from langchain.agents.factory import _fetch_last_ai_and_tool_messages def test_fetch_last_ai_and_tool_messages_normal() -> None: """Test normal case with AIMessage and subsequent ToolMessages.""" - messages = [ + messages: list[AnyMessage] = [ HumanMessage(content="Hello"), AIMessage(content="Hi there!", tool_calls=[{"name": "test", "id": "1", "args": {}}]), ToolMessage(content="Tool result", tool_call_id="1"), @@ -29,7 +29,7 @@ def test_fetch_last_ai_and_tool_messages_normal() -> None: def test_fetch_last_ai_and_tool_messages_multiple_ai() -> None: """Test that the last AIMessage is returned when multiple exist.""" - messages = [ + messages: list[AnyMessage] = [ HumanMessage(content="First question"), AIMessage(content="First answer", id="ai1"), HumanMessage(content="Second question"), @@ -53,7 +53,7 @@ def test_fetch_last_ai_and_tool_messages_no_ai_message() -> None: The function now returns None for the AIMessage, allowing callers to handle this edge case explicitly. """ - messages = [ + messages: list[AnyMessage] = [ HumanMessage(content="Hello"), SystemMessage(content="You are a helpful assistant"), ] @@ -70,7 +70,7 @@ def test_fetch_last_ai_and_tool_messages_empty_list() -> None: This can occur after RemoveMessage(id=REMOVE_ALL_MESSAGES) clears all messages. """ - messages: list = [] + messages: list[AnyMessage] = [] ai_msg, tool_msgs = _fetch_last_ai_and_tool_messages(messages) @@ -81,7 +81,7 @@ def test_fetch_last_ai_and_tool_messages_empty_list() -> None: def test_fetch_last_ai_and_tool_messages_only_human_messages() -> None: """Test handling when only HumanMessages exist.""" - messages = [ + messages: list[AnyMessage] = [ HumanMessage(content="Hello"), HumanMessage(content="Are you there?"), ] @@ -94,7 +94,7 @@ def test_fetch_last_ai_and_tool_messages_only_human_messages() -> None: def test_fetch_last_ai_and_tool_messages_ai_without_tool_calls() -> None: """Test AIMessage without tool_calls returns empty tool messages list.""" - messages = [ + messages: list[AnyMessage] = [ HumanMessage(content="Hello"), AIMessage(content="Hi! How can I help you today?"), ] diff --git a/libs/langchain_v1/tests/unit_tests/agents/test_injected_runtime_create_agent.py b/libs/langchain_v1/tests/unit_tests/agents/test_injected_runtime_create_agent.py index 0f8b70657c3..f9d13afc322 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/test_injected_runtime_create_agent.py +++ b/libs/langchain_v1/tests/unit_tests/agents/test_injected_runtime_create_agent.py @@ -25,7 +25,7 @@ from langgraph.store.memory import InMemoryStore from typing_extensions import override from langchain.agents import create_agent -from langchain.agents.middleware.types import AgentMiddleware, AgentState +from langchain.agents.middleware.types import AgentMiddleware, AgentState, InputAgentState from langchain.tools import InjectedState, ToolRuntime from tests.unit_tests.agents.model import FakeToolCallingModel @@ -303,6 +303,9 @@ def test_tool_runtime_with_custom_state() -> None: class CustomState(AgentState[Any]): custom_field: str + class CustomInputState(InputAgentState): + custom_field: str + runtime_state = {} @tool @@ -327,10 +330,10 @@ def test_tool_runtime_with_custom_state() -> None: ) result = agent.invoke( - { - "messages": [HumanMessage("Test custom state")], - "custom_field": "custom_value", - } + CustomInputState( + messages=[HumanMessage("Test custom state")], + custom_field="custom_value", + ) ) # Verify custom field was accessible @@ -622,6 +625,10 @@ def test_combined_injected_state_runtime_store() -> None: user_id: str session_id: str + class CustomInputState(InputAgentState): + user_id: str + session_id: str + # Define explicit args schema that only includes LLM-controlled parameters weather_schema = { "type": "object", @@ -692,11 +699,11 @@ def test_combined_injected_state_runtime_store() -> None: # Invoke with custom state fields result = agent.invoke( - { - "messages": [HumanMessage("What's the weather like?")], - "user_id": "user_42", - "session_id": "session_abc123", - } + CustomInputState( + messages=[HumanMessage("What's the weather like?")], + user_id="user_42", + session_id="session_abc123", + ) ) # Verify tool executed successfully @@ -740,6 +747,10 @@ async def test_combined_injected_state_runtime_store_async() -> None: api_key: str request_id: str + class CustomInputState(InputAgentState): + api_key: str + request_id: str + # Define explicit args schema that only includes LLM-controlled parameters # Note: state, runtime, and store are NOT in this schema search_schema = { @@ -819,11 +830,11 @@ async def test_combined_injected_state_runtime_store_async() -> None: # Invoke async result = await agent.ainvoke( - { - "messages": [HumanMessage("Search for something")], - "api_key": "sk-test-key-xyz", - "request_id": "req_999", - } + CustomInputState( + messages=[HumanMessage("Search for something")], + api_key="sk-test-key-xyz", + request_id="req_999", + ) ) # Verify tool executed successfully diff --git a/libs/langchain_v1/tests/unit_tests/agents/test_kwargs_tool_runtime_injection.py b/libs/langchain_v1/tests/unit_tests/agents/test_kwargs_tool_runtime_injection.py index 1e6637fb6c3..5c817847508 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/test_kwargs_tool_runtime_injection.py +++ b/libs/langchain_v1/tests/unit_tests/agents/test_kwargs_tool_runtime_injection.py @@ -16,8 +16,8 @@ class ArgsSchema(BaseModel): """Args schema with config and runtime fields.""" query: str = Field(description="The query") - config: dict | None = Field(default=None) - runtime: dict | None = Field(default=None) + config: dict[str, Any] | None = Field(default=None) + runtime: dict[str, Any] | None = Field(default=None) def test_config_and_runtime_not_injected_to_kwargs() -> None: diff --git a/libs/langchain_v1/tests/unit_tests/agents/test_response_format.py b/libs/langchain_v1/tests/unit_tests/agents/test_response_format.py index 55ff31a0b40..f2c237f0cf6 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/test_response_format.py +++ b/libs/langchain_v1/tests/unit_tests/agents/test_response_format.py @@ -905,7 +905,7 @@ class TestSupportsProviderStrategy: """Unit tests for `_supports_provider_strategy`.""" @staticmethod - def _make_structured_model(model_name: str): + def _make_structured_model(model_name: str) -> GenericFakeChatModel: class GeminiTestChatModel(GenericFakeChatModel): model_name: str diff --git a/libs/langchain_v1/tests/unit_tests/agents/test_responses_spec.py b/libs/langchain_v1/tests/unit_tests/agents/test_responses_spec.py index 5c86d90f7c5..6c93b753b6b 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/test_responses_spec.py +++ b/libs/langchain_v1/tests/unit_tests/agents/test_responses_spec.py @@ -118,7 +118,7 @@ def test_responses_integration_matrix(case: TestCase) -> None: # Unwrap nested schema objects response_format_spec = [item.get("schema", item) for item in response_format_spec] if len(response_format_spec) == 1: - tool_output = ToolStrategy(response_format_spec[0]) + tool_output = ToolStrategy[Any](response_format_spec[0]) else: tool_output = ToolStrategy({"oneOf": response_format_spec}) diff --git a/libs/langchain_v1/tests/unit_tests/agents/test_state_schema.py b/libs/langchain_v1/tests/unit_tests/agents/test_state_schema.py index bd0bffeb35f..65f10da9d2a 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/test_state_schema.py +++ b/libs/langchain_v1/tests/unit_tests/agents/test_state_schema.py @@ -16,6 +16,7 @@ from langchain.agents import create_agent, factory from langchain.agents.middleware.types import ( AgentMiddleware, AgentState, + InputAgentState, PrivateStateAttr, ) @@ -40,6 +41,9 @@ def test_state_schema_single_custom_field() -> None: class CustomState(AgentState[Any]): custom_field: str + class CustomInputState(InputAgentState): + custom_field: str + agent = create_agent( model=FakeToolCallingModel( tool_calls=[[{"args": {"x": 1}, "id": "call_1", "name": "simple_tool"}], []] @@ -48,7 +52,9 @@ def test_state_schema_single_custom_field() -> None: state_schema=CustomState, ) - result = agent.invoke({"messages": [HumanMessage("Test")], "custom_field": "test_value"}) + result = agent.invoke( + CustomInputState(messages=[HumanMessage("Test")], custom_field="test_value") + ) assert result["custom_field"] == "test_value" assert len(result["messages"]) == 4 @@ -62,6 +68,11 @@ def test_state_schema_multiple_custom_fields() -> None: session_id: str context: str + class CustomInputState(InputAgentState): + user_id: str + session_id: str + context: str + agent = create_agent( model=FakeToolCallingModel( tool_calls=[[{"args": {"x": 1}, "id": "call_1", "name": "simple_tool"}], []] @@ -71,12 +82,12 @@ def test_state_schema_multiple_custom_fields() -> None: ) result = agent.invoke( - { - "messages": [HumanMessage("Test")], - "user_id": "user_123", - "session_id": "session_456", - "context": "test_ctx", - } + CustomInputState( + messages=[HumanMessage("Test")], + user_id="user_123", + session_id="session_456", + context="test_ctx", + ) ) assert result["user_id"] == "user_123" @@ -91,6 +102,9 @@ def test_state_schema_with_tool_runtime() -> None: class ExtendedState(AgentState[Any]): counter: int + class ExtendedInputState(InputAgentState): + counter: int + runtime_data = {} @tool @@ -107,7 +121,7 @@ def test_state_schema_with_tool_runtime() -> None: state_schema=ExtendedState, ) - result = agent.invoke({"messages": [HumanMessage("Test")], "counter": 5}) + result = agent.invoke(ExtendedInputState(messages=[HumanMessage("Test")], counter=5)) assert runtime_data["counter"] == 5 assert "Counter is 5" in result["messages"][2].content @@ -122,6 +136,10 @@ def test_state_schema_with_middleware() -> None: class MiddlewareState(AgentState[Any]): middleware_data: str + class UserAndMiddlewareInputState(InputAgentState): + user_name: str + middleware_data: str + middleware_calls = [] class TestMiddleware(AgentMiddleware[MiddlewareState, None]): @@ -142,11 +160,11 @@ def test_state_schema_with_middleware() -> None: ) result = agent.invoke( - { - "messages": [HumanMessage("Test")], - "user_name": "Alice", - "middleware_data": "test_data", - } + UserAndMiddlewareInputState( + messages=[HumanMessage("Test")], + user_name="Alice", + middleware_data="test_data", + ) ) assert result["user_name"] == "Alice" @@ -176,6 +194,9 @@ async def test_state_schema_async() -> None: class AsyncState(AgentState[Any]): async_field: str + class AsyncInputState(InputAgentState): + async_field: str + @tool async def async_tool(x: int) -> str: """Async tool.""" @@ -190,10 +211,10 @@ async def test_state_schema_async() -> None: ) result = await agent.ainvoke( - { - "messages": [HumanMessage("Test async")], - "async_field": "async_value", - } + AsyncInputState( + messages=[HumanMessage("Test async")], + async_field="async_value", + ) ) assert result["async_field"] == "async_value" @@ -213,6 +234,10 @@ def test_state_schema_with_private_state_field() -> None: public_field: str private_field: Annotated[str, PrivateStateAttr] + class InputStateWithPrivateField(InputAgentState): + public_field: str + private_field: Annotated[str, PrivateStateAttr] + captured_state = {} @tool @@ -234,11 +259,11 @@ def test_state_schema_with_private_state_field() -> None: # Invoke the agent with BOTH public and private fields result = agent.invoke( - { - "messages": [HumanMessage("Test private state")], - "public_field": "public_value", - "private_field": "private_value", # This should be filtered out - } + InputStateWithPrivateField( + messages=[HumanMessage("Test private state")], + public_field="public_value", + private_field="private_value", # This should be filtered out + ) ) # Assert that public_field is preserved in the result diff --git a/libs/langchain_v1/tests/unit_tests/agents/test_subagent_streaming.py b/libs/langchain_v1/tests/unit_tests/agents/test_subagent_streaming.py index 8505dec3d11..220011108e6 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/test_subagent_streaming.py +++ b/libs/langchain_v1/tests/unit_tests/agents/test_subagent_streaming.py @@ -8,14 +8,21 @@ tools were dropped during `stream(..., subgraphs=True)`. from __future__ import annotations -from langchain_core.messages import HumanMessage, ToolCall -from langchain_core.tools import tool +from typing import TYPE_CHECKING, Any, cast -from langchain.agents import create_agent +from langchain_core.messages import HumanMessage, ToolCall +from langchain_core.tools import BaseTool, tool + +from langchain.agents import AgentState, create_agent from tests.unit_tests.agents.model import FakeToolCallingModel +if TYPE_CHECKING: + from langgraph.graph.state import CompiledStateGraph -def _make_subagent_caller_tool(): + from langchain.agents.middleware import InputAgentState, OutputAgentState + + +def _make_subagent_caller_tool() -> BaseTool: """Build a subagent and a tool that invokes it.""" subagent = create_agent( model=FakeToolCallingModel(tool_calls=[[]]), @@ -26,12 +33,18 @@ def _make_subagent_caller_tool(): def call_subagent(query: str) -> str: """Delegate the query to a sub-agent.""" result = subagent.invoke({"messages": [HumanMessage(query)]}) - return result["messages"][-1].text + # `invoke()` returns an untyped state, so `.text` is `Any`; it is really a + # `str` (`TextAccessor`), so narrow it to satisfy the `-> str` return type. + return cast("str", result["messages"][-1].text) return call_subagent -def _make_parent_agent(call_subagent_tool) -> object: +# Return type mirrors `create_agent`'s overload; the context slot is unparameterized +# here, so it resolves to `None` (the `ContextT` default). +def _make_parent_agent( + call_subagent_tool: BaseTool, +) -> CompiledStateGraph[AgentState[Any], None, InputAgentState, OutputAgentState[Any]]: parent_tool_calls: list[list[ToolCall]] = [ [{"args": {"query": "hi"}, "id": "call_1", "name": "call_subagent"}], [], diff --git a/libs/langchain_v1/tests/unit_tests/agents/test_subagent_transformer.py b/libs/langchain_v1/tests/unit_tests/agents/test_subagent_transformer.py index d5c3202e350..dc6f1bd99cd 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/test_subagent_transformer.py +++ b/libs/langchain_v1/tests/unit_tests/agents/test_subagent_transformer.py @@ -9,6 +9,7 @@ dispatches a nested `create_agent` from a tool, giving true end-to-end coverage. from __future__ import annotations import sys +from typing import cast import pytest from langchain_core.messages import HumanMessage @@ -38,7 +39,9 @@ def test_subagents_surfaces_named_subagent() -> None: def call_weather(city: str) -> str: """Call the weather agent.""" result = weather_agent.invoke({"messages": [HumanMessage(f"weather in {city}")]}) - return result["messages"][-1].text + # `invoke()` returns an untyped state, so `.text` is `Any`; it is really a + # `str` (`TextAccessor`), so narrow it to satisfy the `-> str` return type. + return cast("str", result["messages"][-1].text) supervisor = create_agent( model=_supervisor_model(), @@ -85,7 +88,9 @@ async def test_subagents_surfaces_named_subagent_async() -> None: async def call_weather(city: str) -> str: """Call the weather agent.""" result = await weather_agent.ainvoke({"messages": [HumanMessage(f"weather in {city}")]}) - return result["messages"][-1].text + # `invoke()` returns an untyped state, so `.text` is `Any`; it is really a + # `str` (`TextAccessor`), so narrow it to satisfy the `-> str` return type. + return cast("str", result["messages"][-1].text) supervisor = create_agent( model=_supervisor_model(), @@ -148,7 +153,9 @@ def test_unnamed_inner_agent_surfaces_with_inherited_name() -> None: def call_weather(city: str) -> str: """Call an unnamed inner agent.""" result = inner_agent.invoke({"messages": [HumanMessage(f"weather in {city}")]}) - return result["messages"][-1].text + # `invoke()` returns an untyped state, so `.text` is `Any`; it is really a + # `str` (`TextAccessor`), so narrow it to satisfy the `-> str` return type. + return cast("str", result["messages"][-1].text) supervisor = create_agent( model=_supervisor_model(), @@ -183,7 +190,9 @@ def test_same_name_nested_agent_surfaced() -> None: def call_weather(city: str) -> str: """Call a same-named inner agent.""" result = inner_agent.invoke({"messages": [HumanMessage(f"weather in {city}")]}) - return result["messages"][-1].text + # `invoke()` returns an untyped state, so `.text` is `Any`; it is really a + # `str` (`TextAccessor`), so narrow it to satisfy the `-> str` return type. + return cast("str", result["messages"][-1].text) # The parent agent shares the inner agent's name. supervisor = create_agent( diff --git a/libs/langchain_v1/tests/unit_tests/agents/test_system_message.py b/libs/langchain_v1/tests/unit_tests/agents/test_system_message.py index 9c141feaf31..d7d4e24cabf 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/test_system_message.py +++ b/libs/langchain_v1/tests/unit_tests/agents/test_system_message.py @@ -765,7 +765,7 @@ class TestDynamicSystemPromptMiddleware: def test_middleware_can_return_system_message(self) -> None: """Test that middleware can return a SystemMessage with dynamic content.""" - def dynamic_system_prompt_middleware(request: ModelRequest) -> SystemMessage: + def dynamic_system_prompt_middleware(request: ModelRequest[Any]) -> SystemMessage: """Return a SystemMessage with dynamic content.""" region = getattr(request.runtime.context, "region", "n/a") return SystemMessage(content=f"You are a helpful assistant. Region: {region}") @@ -786,7 +786,6 @@ class TestDynamicSystemPromptMiddleware: runtime=runtime, model_settings={}, ) - new_system_message = dynamic_system_prompt_middleware(request) assert isinstance(new_system_message, SystemMessage)