diff --git a/libs/langchain_v1/langchain/agents/middleware/shell_tool.py b/libs/langchain_v1/langchain/agents/middleware/shell_tool.py index e2326640932..1b5029538f5 100644 --- a/libs/langchain_v1/langchain/agents/middleware/shell_tool.py +++ b/libs/langchain_v1/langchain/agents/middleware/shell_tool.py @@ -15,7 +15,7 @@ import uuid import weakref from dataclasses import dataclass, field from pathlib import Path -from typing import TYPE_CHECKING, Annotated, Any, Literal, cast +from typing import TYPE_CHECKING, Annotated, Any, Literal, cast, overload from langchain_core.messages import ToolMessage from langchain_core.runnables import run_in_executor @@ -758,6 +758,24 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState[ResponseT], ContextT, R matches_by_type.setdefault(rule.pii_type, []).extend(matches) return updated, matches_by_type + @overload + def _run_shell_tool( + self, + resources: _SessionResources, + payload: dict[str, Any], + *, + tool_call_id: str, + ) -> ToolMessage: ... + + @overload + def _run_shell_tool( + self, + resources: _SessionResources, + payload: dict[str, Any], + *, + tool_call_id: None, + ) -> str: ... + def _run_shell_tool( self, resources: _SessionResources, @@ -853,6 +871,26 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState[ResponseT], ContextT, R artifact=artifact, ) + @overload + def _format_tool_message( + self, + content: str, + tool_call_id: str, + *, + status: Literal["success", "error"], + artifact: dict[str, Any] | None = None, + ) -> ToolMessage: ... + + @overload + def _format_tool_message( + self, + content: str, + tool_call_id: None, + *, + status: Literal["success", "error"], + artifact: dict[str, Any] | None = None, + ) -> str: ... + def _format_tool_message( self, content: str, diff --git a/libs/langchain_v1/pyproject.toml b/libs/langchain_v1/pyproject.toml index 98c642177a0..11be8694607 100644 --- a/libs/langchain_v1/pyproject.toml +++ b/libs/langchain_v1/pyproject.toml @@ -110,10 +110,6 @@ strict = true enable_error_code = "deprecated" warn_unreachable = true -exclude = [ - "tests/unit_tests/agents/middleware/", -] - [[tool.mypy.overrides]] module = ["pytest_socket.*", "vcr.*"] ignore_missing_imports = true diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/core/test_composition.py b/libs/langchain_v1/tests/unit_tests/agents/middleware/core/test_composition.py index 61e609cabf4..009098b9033 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/middleware/core/test_composition.py +++ b/libs/langchain_v1/tests/unit_tests/agents/middleware/core/test_composition.py @@ -277,7 +277,7 @@ class TestChainModelCallHandlers: test: str state_values = [] - runtime_values = [] + runtime_values: list[tuple[str, Runtime[Any]]] = [] def outer( request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse] diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/core/test_dynamic_tools.py b/libs/langchain_v1/tests/unit_tests/agents/middleware/core/test_dynamic_tools.py index 682eb27aa94..1190099e367 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/middleware/core/test_dynamic_tools.py +++ b/libs/langchain_v1/tests/unit_tests/agents/middleware/core/test_dynamic_tools.py @@ -6,24 +6,31 @@ that are not declared upfront when creating the agent. import asyncio from collections.abc import Awaitable, Callable -from typing import Any +from typing import TYPE_CHECKING, Any import pytest from langchain_core.messages import HumanMessage, ToolCall, ToolMessage from langchain_core.tools import tool from langgraph.checkpoint.memory import InMemorySaver +from langgraph.graph.state import CompiledStateGraph from langgraph.types import Command from langchain.agents.factory import create_agent from langchain.agents.middleware.types import ( AgentMiddleware, + AgentState, + InputAgentState, ModelCallResult, ModelRequest, ModelResponse, + OutputAgentState, ToolCallRequest, ) from tests.unit_tests.agents.model import FakeToolCallingModel +if TYPE_CHECKING: + from langchain_core.runnables import RunnableConfig + @tool def static_tool(value: str) -> str: @@ -156,7 +163,7 @@ class ConditionalDynamicToolMiddleware(AgentMiddleware): def _should_add_tool(self, request: ModelRequest) -> bool: messages = request.state.get("messages", []) - return messages and "calculator" in str(messages[-1].content).lower() + return bool(messages) and "calculator" in str(messages[-1].content).lower() def wrap_model_call( self, @@ -205,10 +212,15 @@ def get_tool_messages(result: dict[str, Any]) -> list[ToolMessage]: return [m for m in result["messages"] if isinstance(m, ToolMessage)] -async def invoke_agent(agent: Any, message: str, *, use_async: bool) -> dict[str, Any]: +async def invoke_agent( + agent: CompiledStateGraph[AgentState[Any], None, InputAgentState, OutputAgentState[Any]], + message: str, + *, + use_async: bool, +) -> dict[str, Any]: """Invoke agent synchronously or asynchronously based on flag.""" - input_data = {"messages": [HumanMessage(message)]} - config = {"configurable": {"thread_id": "test"}} + input_data: InputAgentState = {"messages": [HumanMessage(message)]} + config: RunnableConfig = {"configurable": {"thread_id": "test"}} if use_async: return await agent.ainvoke(input_data, config) # Run sync invoke in thread pool to avoid blocking the event loop @@ -240,7 +252,7 @@ async def test_dynamic_tool_basic(*, use_async: bool, tools: list[Any] | None) - agent = create_agent( model=model, - tools=tools, # type: ignore[arg-type] + tools=tools, middleware=[DynamicToolMiddleware()], checkpointer=InMemorySaver(), ) diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/core/test_framework.py b/libs/langchain_v1/tests/unit_tests/agents/middleware/core/test_framework.py index 0da495ce320..9e170caf482 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/middleware/core/test_framework.py +++ b/libs/langchain_v1/tests/unit_tests/agents/middleware/core/test_framework.py @@ -1,6 +1,6 @@ import sys from collections.abc import Awaitable, Callable -from typing import Annotated, Any, Generic +from typing import TYPE_CHECKING, Annotated, Any import pytest from langchain_core.language_models import GenericFakeChatModel @@ -14,6 +14,7 @@ from syrupy.assertion import SnapshotAssertion from typing_extensions import override from langchain.agents.factory import create_agent +from langchain.agents.middleware import InputAgentState from langchain.agents.middleware.types import ( AgentMiddleware, AgentState, @@ -23,7 +24,6 @@ from langchain.agents.middleware.types import ( OmitFromInput, OmitFromOutput, PrivateStateAttr, - ResponseT, after_agent, after_model, before_agent, @@ -35,6 +35,9 @@ from langchain.tools import InjectedState from tests.unit_tests.agents.messages import _AnyIdHumanMessage, _AnyIdToolMessage from tests.unit_tests.agents.model import FakeToolCallingModel +if TYPE_CHECKING: + from langchain_core.runnables import RunnableConfig + def test_create_agent_invoke( sync_checkpointer: BaseCheckpointSaver[str], @@ -96,8 +99,8 @@ def test_create_agent_invoke( checkpointer=sync_checkpointer, ) - thread1 = {"configurable": {"thread_id": "1"}} - assert agent_one.invoke({"messages": ["hello"]}, thread1) == { + thread1: RunnableConfig = {"configurable": {"thread_id": "1"}} + assert agent_one.invoke({"messages": [HumanMessage("hello")]}, thread1) == { "messages": [ _AnyIdHumanMessage(content="hello"), AIMessage( @@ -201,8 +204,8 @@ def test_create_agent_jump( if isinstance(sync_checkpointer, InMemorySaver): assert agent_one.get_graph().draw_mermaid() == snapshot - thread1 = {"configurable": {"thread_id": "1"}} - assert agent_one.invoke({"messages": []}, thread1) == {"messages": []} + thread1: RunnableConfig = {"configurable": {"thread_id": "1"}} + assert agent_one.invoke(InputAgentState(messages=[]), thread1) == {"messages": []} assert calls == ["NoopSeven.before_model", "NoopEight.before_model"] @@ -368,6 +371,11 @@ def test_public_private_state_for_custom_middleware() -> None: omit_output: Annotated[str, OmitFromOutput] private_state: Annotated[str, PrivateStateAttr] + class CustomInputState(InputAgentState): + omit_input: str + omit_output: str + private_state: str + class CustomMiddleware(AgentMiddleware[CustomState]): state_schema: type[CustomState] = CustomState @@ -380,12 +388,12 @@ def test_public_private_state_for_custom_middleware() -> None: agent = create_agent(model=FakeToolCallingModel(), middleware=[CustomMiddleware()]) result = agent.invoke( - { - "messages": [HumanMessage("Hello")], - "omit_input": "test in", - "private_state": "test in", - "omit_output": "test in", - } + CustomInputState( + messages=[HumanMessage("Hello")], + omit_input="test in", + private_state="test in", + omit_output="test in", + ) ) assert "omit_input" in result assert "omit_output" not in result @@ -420,7 +428,11 @@ def test_runtime_injected_into_middleware() -> None: # custom state w/in a function -class CustomState(AgentState[ResponseT], Generic[ResponseT]): +class CustomState(AgentState[Any]): + custom_state: str + + +class _CustomInputState(InputAgentState): custom_state: str @@ -457,10 +469,10 @@ agent = create_agent( def test_injected_state_in_middleware_agent() -> None: """Test that custom state is properly injected into tools when using middleware.""" result = agent.invoke( - { - "custom_state": "I love pizza", - "messages": [HumanMessage("Call the test state tool")], - } + _CustomInputState( + messages=[HumanMessage("Call the test state tool")], + custom_state="I love pizza", + ) ) messages = result["messages"] diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/core/test_tools.py b/libs/langchain_v1/tests/unit_tests/agents/middleware/core/test_tools.py index 7e008e720be..587345b1f15 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/middleware/core/test_tools.py +++ b/libs/langchain_v1/tests/unit_tests/agents/middleware/core/test_tools.py @@ -13,6 +13,7 @@ from langchain.agents.factory import create_agent from langchain.agents.middleware.types import ( AgentMiddleware, AgentState, + InputAgentState, ModelCallResult, ModelRequest, ModelResponse, @@ -180,6 +181,9 @@ def test_middleware_can_add_and_remove_tools() -> None: class AdminState(AgentState[Any]): is_admin: bool + class AdminInputState(InputAgentState): + is_admin: bool + class ConditionalToolMiddleware(AgentMiddleware[AdminState]): state_schema = AdminState @@ -209,11 +213,11 @@ def test_middleware_can_add_and_remove_tools() -> None: # Test non-admin user - should not have access to admin_tool # We can't directly inspect the bound model, but we can verify the agent runs - result = agent.invoke({"messages": [HumanMessage("Hello")], "is_admin": False}) + result = agent.invoke(AdminInputState(messages=[HumanMessage("Hello")], is_admin=False)) assert "messages" in result # Test admin user - should have access to all tools - result = agent.invoke({"messages": [HumanMessage("Hello")], "is_admin": True}) + result = agent.invoke(AdminInputState(messages=[HumanMessage("Hello")], is_admin=True)) assert "messages" in result @@ -383,8 +387,8 @@ def test_tool_node_not_accepted() -> None: tool_node = ToolNode([some_tool]) with pytest.raises(TypeError, match="'ToolNode' object is not iterable"): - create_agent( + create_agent( # type: ignore[call-overload] model=FakeToolCallingModel(), - tools=tool_node, # type: ignore[arg-type] + tools=tool_node, system_prompt="You are a helpful assistant.", ) diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/core/test_transformers.py b/libs/langchain_v1/tests/unit_tests/agents/middleware/core/test_transformers.py index c660b8bb13c..a4c2f0902c7 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/middleware/core/test_transformers.py +++ b/libs/langchain_v1/tests/unit_tests/agents/middleware/core/test_transformers.py @@ -60,7 +60,7 @@ def test_middleware_transformer_registered_on_compiled_graph() -> None: run = agent.stream_events({"messages": [HumanMessage("hi")]}, version="v3") - assert "middleware_marker" in run._mux.extensions # type: ignore[attr-defined] + assert "middleware_marker" in run._mux.extensions # Drain to close the run cleanly. list(run.tool_calls) @@ -80,7 +80,7 @@ def test_middleware_and_user_transformers_compose_in_order() -> None: run = agent.stream_events({"messages": [HumanMessage("hi")]}, version="v3") - transformers = run._mux._transformers # type: ignore[attr-defined] + transformers = run._mux._transformers tool_call_idx = next( i for i, t in enumerate(transformers) if isinstance(t, ToolCallTransformer) ) @@ -119,7 +119,7 @@ def test_transformers_from_multiple_middleware_preserve_middleware_order() -> No run = agent.stream_events({"messages": [HumanMessage("hi")]}, version="v3") - transformers = run._mux._transformers # type: ignore[attr-defined] + transformers = run._mux._transformers idx_a = next(i for i, t in enumerate(transformers) if isinstance(t, _MarkerA)) idx_b = next(i for i, t in enumerate(transformers) if isinstance(t, _MarkerB)) assert idx_a < idx_b @@ -136,7 +136,7 @@ def test_middleware_without_transformers_does_not_affect_registry() -> None: agent = create_agent(model=FakeToolCallingModel(), tools=[], middleware=[_Middleware()]) run = agent.stream_events({"messages": [HumanMessage("hi")]}, version="v3") - transformers = run._mux._transformers # type: ignore[attr-defined] + transformers = run._mux._transformers assert any(isinstance(t, ToolCallTransformer) for t in transformers) assert not any(isinstance(t, _MiddlewareMarker) for t in transformers) diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/core/test_wrap_model_call.py b/libs/langchain_v1/tests/unit_tests/agents/middleware/core/test_wrap_model_call.py index 65fb57a90fd..1890038bdc1 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/middleware/core/test_wrap_model_call.py +++ b/libs/langchain_v1/tests/unit_tests/agents/middleware/core/test_wrap_model_call.py @@ -169,7 +169,7 @@ class TestRetryLogic: def __init__(self, max_retries: int = 3): super().__init__() self.max_retries = max_retries - self.attempts = [] + self.attempts: list[int] = [] def wrap_model_call( self, @@ -1195,7 +1195,7 @@ class TestWrapModelCallDecorator: messages: list[Any] custom_field: str - @wrap_model_call(state_schema=CustomState) + @wrap_model_call(state_schema=CustomState) # type: ignore[type-var] def middleware_with_schema( request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse], diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/core/test_wrap_model_call_state_update.py b/libs/langchain_v1/tests/unit_tests/agents/middleware/core/test_wrap_model_call_state_update.py index 4257f31e990..e24493bf821 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/middleware/core/test_wrap_model_call_state_update.py +++ b/libs/langchain_v1/tests/unit_tests/agents/middleware/core/test_wrap_model_call_state_update.py @@ -153,7 +153,7 @@ class TestCustomStateField: summary: str class SummaryMiddleware(AgentMiddleware): - state_schema = MyState # type: ignore[assignment] + state_schema = MyState def wrap_model_call( self, @@ -437,7 +437,7 @@ class TestComposition: custom_key: str class OuterMiddleware(AgentMiddleware): - state_schema = MyState # type: ignore[assignment] + state_schema = MyState def wrap_model_call( self, @@ -456,7 +456,7 @@ class TestComposition: ) class InnerMiddleware(AgentMiddleware): - state_schema = MyState # type: ignore[assignment] + state_schema = MyState def wrap_model_call( self, @@ -493,7 +493,7 @@ class TestComposition: outer_key: str class OuterMiddleware(AgentMiddleware): - state_schema = MyState # type: ignore[assignment] + state_schema = MyState def wrap_model_call( self, @@ -507,7 +507,7 @@ class TestComposition: ) class InnerMiddleware(AgentMiddleware): - state_schema = MyState # type: ignore[assignment] + state_schema = MyState def wrap_model_call( self, @@ -550,7 +550,7 @@ class TestComposition: return handler(request) class InnerMiddleware(AgentMiddleware): - state_schema = MyState # type: ignore[assignment] + state_schema = MyState def wrap_model_call( self, @@ -752,7 +752,7 @@ class TestAsyncComposition: return await handler(request) class InnerMiddleware(AgentMiddleware): - state_schema = MyState # type: ignore[assignment] + state_schema = MyState async def awrap_model_call( self, diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_context_editing.py b/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_context_editing.py index 39f79ac7452..0220e078ce5 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_context_editing.py +++ b/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_context_editing.py @@ -37,14 +37,14 @@ class _TokenCountingChatModel(FakeChatModel): def get_num_tokens_from_messages( self, messages: list[BaseMessage], - tools: Sequence | None = None, + tools: Sequence[Any] | None = None, ) -> int: return sum(_count_message_tokens(message) for message in messages) def _count_message_tokens(message: MessageLikeRepresentation) -> int: if isinstance(message, (AIMessage, ToolMessage)): - return _count_content(message.content) + return _count_content(cast("MessageLikeRepresentation", message.content)) if isinstance(message, str): return len(message) return len(str(message)) diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_model_call_limit.py b/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_model_call_limit.py index 05ac31d6abe..dc444f8a208 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_model_call_limit.py +++ b/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_model_call_limit.py @@ -1,3 +1,5 @@ +from typing import TYPE_CHECKING + import pytest from langchain_core.messages import AIMessage, HumanMessage, ToolMessage from langchain_core.tools import tool @@ -5,6 +7,7 @@ from langgraph.checkpoint.memory import InMemorySaver from langgraph.runtime import Runtime from langchain.agents.factory import create_agent +from langchain.agents.middleware import InputAgentState from langchain.agents.middleware.model_call_limit import ( ModelCallLimitExceededError, ModelCallLimitMiddleware, @@ -12,6 +15,9 @@ from langchain.agents.middleware.model_call_limit import ( ) from tests.unit_tests.agents.model import FakeToolCallingModel +if TYPE_CHECKING: + from langchain_core.runnables import RunnableConfig + @tool def simple_tool(value: str) -> str: @@ -215,13 +221,13 @@ def test_run_limit_resets_between_invocations() -> None: agent = create_agent(model=model, middleware=[middleware], checkpointer=InMemorySaver()) - thread_config = {"configurable": {"thread_id": "test_thread"}} - agent.invoke({"messages": [HumanMessage("Hello")]}, thread_config) - agent.invoke({"messages": [HumanMessage("Hello again")]}, thread_config) - agent.invoke({"messages": [HumanMessage("Hello third")]}, thread_config) + thread_config: RunnableConfig = {"configurable": {"thread_id": "test_thread"}} + agent.invoke(InputAgentState(messages=[HumanMessage("Hello")]), thread_config) + agent.invoke(InputAgentState(messages=[HumanMessage("Hello again")]), thread_config) + agent.invoke(InputAgentState(messages=[HumanMessage("Hello third")]), thread_config) # Fourth run: should raise, thread_model_call_count == 3 (limit) with pytest.raises(ModelCallLimitExceededError) as exc_info: - agent.invoke({"messages": [HumanMessage("Hello fourth")]}, thread_config) + agent.invoke(InputAgentState(messages=[HumanMessage("Hello fourth")]), thread_config) error_msg = str(exc_info.value) assert "thread limit (3/3)" in error_msg diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_pii.py b/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_pii.py index acac3cfc444..1f7f2162f74 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_pii.py +++ b/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_pii.py @@ -15,6 +15,7 @@ from langchain_core.messages import ( ) from langchain_core.tools import tool from langgraph.runtime import Runtime +from langgraph.stream._types import ProtocolEvent from langgraph.stream.transformers import MessagesTransformer from langchain.agents import AgentState @@ -578,7 +579,7 @@ class TestCustomDetector: KeyError: 'value' when used with hash or mask strategies. """ - def detect_phone(content: str) -> list[dict]: # type: ignore[type-arg] + def detect_phone(content: str) -> list[dict[str, Any]]: return [ {"text": m.group(), "start": m.start(), "end": m.end()} for m in re.finditer(r"\+91[\s.-]?\d{10}", content) @@ -586,7 +587,7 @@ class TestCustomDetector: middleware = PIIMiddleware( "indian_phone", - detector=detect_phone, + detector=detect_phone, # type: ignore[arg-type] strategy="hash", apply_to_input=True, ) @@ -601,7 +602,7 @@ class TestCustomDetector: def test_custom_callable_detector_with_text_key_mask(self) -> None: """Custom detectors returning 'text' instead of 'value' must work with mask strategy.""" - def detect_phone(content: str) -> list[dict]: # type: ignore[type-arg] + def detect_phone(content: str) -> list[dict[str, Any]]: return [ {"text": m.group(), "start": m.start(), "end": m.end()} for m in re.finditer(r"\+91[\s.-]?\d{10}", content) @@ -609,7 +610,7 @@ class TestCustomDetector: middleware = PIIMiddleware( "indian_phone", - detector=detect_phone, + detector=detect_phone, # type: ignore[arg-type] strategy="mask", apply_to_input=True, ) @@ -715,7 +716,7 @@ class TestMultipleMiddleware: # ============================================================================ -def _make_delta_event(text: str, *, index: int = 0, run_id: str = "r1") -> dict[str, Any]: +def _make_delta_event(text: str, *, index: int = 0, run_id: str = "r1") -> ProtocolEvent: """Build a `messages` protocol event for a text content-block delta.""" return { "type": "event", @@ -735,7 +736,7 @@ def _make_delta_event(text: str, *, index: int = 0, run_id: str = "r1") -> dict[ } -def _make_finish_event(text: str, *, index: int = 0, run_id: str = "r1") -> dict[str, Any]: +def _make_finish_event(text: str, *, index: int = 0, run_id: str = "r1") -> ProtocolEvent: """Build a `messages` protocol event for content-block-finish on a text block.""" return { "type": "event", @@ -755,7 +756,7 @@ def _make_finish_event(text: str, *, index: int = 0, run_id: str = "r1") -> dict } -def _emitted_text(events: list[dict[str, Any]]) -> str: +def _emitted_text(events: list[ProtocolEvent]) -> tuple[str, dict[int, str]]: """Concatenate delta + finalized text the way a streaming consumer would.""" parts = [] final_by_index: dict[int, str] = {} @@ -772,10 +773,10 @@ def _emitted_text(events: list[dict[str, Any]]) -> str: final_by_index[payload["index"]] = content["text"] # Concatenated delta stream is what the consumer sees in real time; # finalized text is the snapshot. Return both via a tuple-like dict. - return "".join(parts), final_by_index # type: ignore[return-value] + return "".join(parts), final_by_index -def _run_transformer(transformer: Any, events: list[dict[str, Any]]) -> list[dict[str, Any]]: +def _run_transformer(transformer: Any, events: list[ProtocolEvent]) -> list[ProtocolEvent]: """Feed events through the transformer (mutates in place) and return them.""" for event in events: transformer.process(event) @@ -840,6 +841,7 @@ class TestPIIStreamTransformer: redacted = transformer._redact_value(msg) # Original untouched. + assert isinstance(msg.content[0], dict) assert msg.content[0]["text"] == "Reach me at alice@example.com" # Redacted copy walked every block. assert "alice@example.com" not in redacted.content[0]["text"] @@ -898,7 +900,7 @@ class TestPIIStreamTransformer: rule = RedactionRule(pii_type="email").resolve() transformer = _PIIStreamTransformer(rule=rule, lookback=32) - events = [ + events: list[ProtocolEvent] = [ { "type": "event", "method": "messages", @@ -953,7 +955,7 @@ class TestPIIStreamTransformer: rule = RedactionRule(pii_type="email", strategy="block").resolve() transformer = _PIIStreamTransformer(rule=rule) - event = { + event: ProtocolEvent = { "type": "event", "method": "messages", "params": { @@ -986,7 +988,7 @@ class TestPIIStreamTransformer: rule = RedactionRule(pii_type="email").resolve() transformer = _PIIStreamTransformer(rule=rule) # default lookback=128 - events = [ + events: list[ProtocolEvent] = [ { "type": "event", "method": "messages", @@ -1031,7 +1033,7 @@ class TestPIIStreamTransformer: # 50-char args with PII near the start; emit_end = 50 - 8 = 42. args = '{"to": "alice@example.com", "subject": "hi"}' - events = [ + events: list[ProtocolEvent] = [ { "type": "event", "method": "messages", @@ -1082,7 +1084,7 @@ class TestPIIStreamTransformer: '{"to": "alice@example', '{"to": "alice@example.com"}', ] - events = [ + events: list[ProtocolEvent] = [ { "type": "event", "method": "messages", @@ -1124,7 +1126,7 @@ class TestPIIStreamTransformer: rule = RedactionRule(pii_type="email").resolve() transformer = _PIIStreamTransformer(rule=rule) - events = [ + events: list[ProtocolEvent] = [ { "type": "event", "method": "messages", @@ -1156,7 +1158,7 @@ class TestPIIStreamTransformer: rule = RedactionRule(pii_type="email", strategy="block").resolve() transformer = _PIIStreamTransformer(rule=rule) - event = { + event: ProtocolEvent = { "type": "event", "method": "messages", "params": { @@ -1192,7 +1194,7 @@ class TestPIIStreamTransformer: ], id="m1", ) - event: dict[str, Any] = { + event: ProtocolEvent = { "type": "event", "method": "messages", "params": { @@ -1217,7 +1219,7 @@ class TestPIIStreamTransformer: ], id="m1", ) - event: dict[str, Any] = { + event: ProtocolEvent = { "type": "event", "method": "messages", "params": { @@ -1235,7 +1237,7 @@ class TestPIIStreamTransformer: transformer = _PIIStreamTransformer(rule=rule) msg = ToolMessage(content="Result: alice@example.com", tool_call_id="c1", id="m1") - event: dict[str, Any] = { + event: ProtocolEvent = { "type": "event", "method": "messages", "params": { @@ -1258,7 +1260,7 @@ class TestPIIStreamTransformer: """Tools events route to the new handler without error.""" rule = RedactionRule(pii_type="email").resolve() transformer = _PIIStreamTransformer(rule=rule) - event = { + event: ProtocolEvent = { "type": "event", "method": "tools", "params": { @@ -1279,7 +1281,7 @@ class TestPIIStreamTransformer: rule = RedactionRule(pii_type="email").resolve() transformer = _PIIStreamTransformer(rule=rule) - event = { + event: ProtocolEvent = { "type": "event", "method": "tools", "params": { @@ -1303,7 +1305,7 @@ class TestPIIStreamTransformer: rule = RedactionRule(pii_type="email").resolve() transformer = _PIIStreamTransformer(rule=rule) - event = { + event: ProtocolEvent = { "type": "event", "method": "tools", "params": { @@ -1365,7 +1367,7 @@ class TestPIIStreamTransformer: rule = RedactionRule(pii_type="email").resolve() transformer = _PIIStreamTransformer(rule=rule) - events = [ + events: list[ProtocolEvent] = [ { "type": "event", "method": "messages", @@ -1403,6 +1405,7 @@ class TestPIIStreamTransformer: content="", invalid_tool_calls=[ InvalidToolCall( + type="invalid_tool_call", name="send_email", args='{"to": "alice@example.com"} BROKEN', id="c1", @@ -1411,16 +1414,18 @@ class TestPIIStreamTransformer: ], id="m1", ) - redacted = transformer._redact_value(msg) + redacted: AIMessage = transformer._redact_value(msg) + assert redacted.invalid_tool_calls[0]["args"] is not None assert "alice@example.com" not in redacted.invalid_tool_calls[0]["args"] assert "[REDACTED_EMAIL]" in redacted.invalid_tool_calls[0]["args"] + assert msg.invalid_tool_calls[0]["args"] is not None assert "alice@example.com" in msg.invalid_tool_calls[0]["args"] def test_tool_started_input_is_redacted(self) -> None: rule = RedactionRule(pii_type="email").resolve() transformer = _PIIStreamTransformer(rule=rule) - event = { + event: ProtocolEvent = { "type": "event", "method": "tools", "params": { @@ -1445,7 +1450,7 @@ class TestPIIStreamTransformer: rule = RedactionRule(pii_type="email").resolve() transformer = _PIIStreamTransformer(rule=rule, lookback=64) - events = [ + events: list[ProtocolEvent] = [ { "type": "event", "method": "tools", @@ -1483,7 +1488,7 @@ class TestPIIStreamTransformer: rule = RedactionRule(pii_type="email").resolve() transformer = _PIIStreamTransformer(rule=rule) - event = { + event: ProtocolEvent = { "type": "event", "method": "tools", "params": { @@ -1506,7 +1511,7 @@ class TestPIIStreamTransformer: rule = RedactionRule(pii_type="email").resolve() transformer = _PIIStreamTransformer(rule=rule) - event = { + event: ProtocolEvent = { "type": "event", "method": "tools", "params": { @@ -1529,7 +1534,7 @@ class TestPIIStreamTransformer: rule = RedactionRule(pii_type="email").resolve() transformer = _PIIStreamTransformer(rule=rule) - event = { + event: ProtocolEvent = { "type": "event", "method": "tools", "params": { @@ -1551,7 +1556,7 @@ class TestPIIStreamTransformer: rule = RedactionRule(pii_type="email").resolve() transformer = _PIIStreamTransformer(rule=rule, lookback=64) - delta_event = { + delta_event: ProtocolEvent = { "type": "event", "method": "tools", "params": { @@ -1567,7 +1572,7 @@ class TestPIIStreamTransformer: transformer.process(delta_event) assert "c1" in transformer._tool_buffers - finish_event = { + finish_event: ProtocolEvent = { "type": "event", "method": "tools", "params": { @@ -1588,7 +1593,7 @@ class TestPIIStreamTransformer: rule = RedactionRule(pii_type="email", strategy="block").resolve() transformer = _PIIStreamTransformer(rule=rule) - event = { + event: ProtocolEvent = { "type": "event", "method": "tools", "params": { @@ -1609,7 +1614,7 @@ class TestPIIStreamTransformer: rule = RedactionRule(pii_type="email", strategy="block").resolve() transformer = _PIIStreamTransformer(rule=rule) - event = { + event: ProtocolEvent = { "type": "event", "method": "messages", "params": { @@ -1646,7 +1651,7 @@ class TestPIIStreamTransformer: _make_finish_event("Reach me at alice@example.com tomorrow."), ] _run_transformer(transformer, events) - streamed, finals = _emitted_text(events) # type: ignore[misc] + streamed, finals = _emitted_text(events) # The raw email never reaches the wire — the delta is held in the # lookback buffer and the finalize snapshot is the redacted text. @@ -1665,7 +1670,7 @@ class TestPIIStreamTransformer: _make_finish_event("Hi, contact alice@example.com when ready"), ] _run_transformer(transformer, events) - streamed, finals = _emitted_text(events) # type: ignore[misc] + streamed, finals = _emitted_text(events) # The held-buffer should have prevented the raw email from being # released until detection ran over the concatenation. @@ -1695,7 +1700,7 @@ class TestPIIStreamTransformer: _make_finish_event(text), ] _run_transformer(transformer, events) - streamed, finals = _emitted_text(events) # type: ignore[misc] + streamed, finals = _emitted_text(events) # No prefix of the email reaches the wire. assert email not in streamed @@ -1716,7 +1721,7 @@ class TestPIIStreamTransformer: _make_finish_event("Card: 5425 2334 3010 9903 next"), ] _run_transformer(transformer, events) - streamed, finals = _emitted_text(events) # type: ignore[misc] + streamed, finals = _emitted_text(events) # No prefix of the card may reach the wire — the lookback buffer # holds whitespace-separated groups until detection runs over the @@ -1762,7 +1767,7 @@ class TestPIIStreamTransformer: transformer.process(clean_event) # no raise yet # Second delta completes the email — detection fires, raises. - completing_event = _make_delta_event("@example.com soon") + completing_event: ProtocolEvent = _make_delta_event("@example.com soon") with pytest.raises(PIIDetectionError): transformer.process(completing_event) @@ -1777,7 +1782,7 @@ class TestPIIStreamTransformer: _make_finish_event("Hello there, how are you?"), ] _run_transformer(transformer, events) - streamed, finals = _emitted_text(events) # type: ignore[misc] + streamed, finals = _emitted_text(events) # Deltas hold everything back; the finalize event carries the # whole block at once. `ChatModelStream._resolve_block_text` @@ -1801,7 +1806,7 @@ class TestPIIStreamTransformer: _make_finish_event("alice@example.com"), ] _run_transformer(transformer, events) - _, finals = _emitted_text(events) # type: ignore[misc] + _, finals = _emitted_text(events) assert "alice@example.com" not in finals[0] assert "[REDACTED_EMAIL]" in finals[0] @@ -1851,7 +1856,7 @@ class TestPIIStreamTransformer: assert ("r1", 0) in transformer._buffers # message-finish for the run wipes any (run-id, *) entries. - message_finish_event = { + message_finish_event: ProtocolEvent = { "type": "event", "method": "messages", "params": { @@ -1886,7 +1891,7 @@ class TestPIIStreamTransformer: _make_finish_event("hello alice@example.com goodbye"), ] _run_transformer(transformer, events) - _, finals = _emitted_text(events) # type: ignore[misc] + _, finals = _emitted_text(events) # The finalized snapshot always re-runs detection over the full text. assert "alice@example.com" not in finals[0] @@ -1901,7 +1906,7 @@ class TestPIIStreamTransformer: rule = RedactionRule(pii_type="email").resolve() transformer = _PIIStreamTransformer(rule=rule) - data_event = { + data_event: ProtocolEvent = { "type": "event", "method": "messages", "params": { @@ -1934,7 +1939,7 @@ class TestPIIStreamTransformer: agent = create_agent(model, [], middleware=[PIIMiddleware("email", apply_to_output=True)]) run = agent.stream_events({"messages": [HumanMessage("hi")]}, version="v3") - transformers = run._mux._transformers # type: ignore[attr-defined] + transformers = run._mux._transformers pii_idx = next( i for i, t in enumerate(transformers) if isinstance(t, _PIIStreamTransformer) diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_summarization.py b/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_summarization.py index 460bf513f8b..21d62fba241 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_summarization.py +++ b/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_summarization.py @@ -30,7 +30,9 @@ from typing_extensions import override from langchain.agents import AgentState from langchain.agents.middleware.summarization import ( + ContextSize, SummarizationMiddleware, + TriggerClause, _provider_matches, ) from langchain.chat_models import init_chat_model @@ -1153,8 +1155,8 @@ def test_summarization_middleware_cutoff_at_start_of_tool_sequence() -> None: def test_trigger_copies_mutable_inputs() -> None: """Test caller mutations do not change stored trigger configuration.""" model = FakeToolCallingModel() - clause = {"tokens": 1000} - trigger = [clause] + clause: TriggerClause = {"tokens": 1000} + trigger: list[ContextSize | TriggerClause] = [clause] middleware = SummarizationMiddleware( model=model, @@ -1171,7 +1173,7 @@ def test_trigger_copies_mutable_inputs() -> None: return 500 middleware.token_counter = token_counter_low - state = {"messages": [HumanMessage(content="1"), HumanMessage(content="2")]} + state = AgentState(messages=[HumanMessage(content="1"), HumanMessage(content="2")]) result = middleware.before_model(state, Runtime()) assert result is None @@ -1231,13 +1233,13 @@ def test_and_trigger_conditions() -> None: return 1500 # Above token threshold middleware.token_counter = token_counter_high - state = { - "messages": [ + state = AgentState( + messages=[ HumanMessage(content="1"), AIMessage(content="2"), HumanMessage(content="3"), ] - } + ) result = middleware.before_model(state, Runtime()) assert result is None, "Should not summarize when only tokens condition is met" @@ -1247,8 +1249,8 @@ def test_and_trigger_conditions() -> None: return 500 # Below token threshold middleware.token_counter = token_counter_low - state = { - "messages": [ + state = AgentState( + messages=[ HumanMessage(content="1"), AIMessage(content="2"), HumanMessage(content="3"), @@ -1256,7 +1258,7 @@ def test_and_trigger_conditions() -> None: HumanMessage(content="5"), AIMessage(content="6"), ] - } + ) result = middleware.before_model(state, Runtime()) assert result is None, "Should not summarize when only messages condition is met" @@ -1289,14 +1291,14 @@ def test_or_trigger_conditions_with_and_clauses() -> None: return 5500 middleware.token_counter = token_counter_5500 - state = { - "messages": [ + state = AgentState( + messages=[ HumanMessage(content="1"), AIMessage(content="2"), HumanMessage(content="3"), AIMessage(content="4"), ] - } + ) result = middleware.before_model(state, Runtime()) assert result is not None, "Should summarize when first OR clause is met" @@ -1306,7 +1308,7 @@ def test_or_trigger_conditions_with_and_clauses() -> None: return 3500 middleware.token_counter = token_counter_3500 - state = {"messages": [HumanMessage(content=str(i)) for i in range(7)]} + state = AgentState(messages=[HumanMessage(content=str(i)) for i in range(7)]) result = middleware.before_model(state, Runtime()) assert result is not None, "Should summarize when second OR clause is met" @@ -1318,14 +1320,14 @@ def test_or_trigger_conditions_with_and_clauses() -> None: return 4500 middleware.token_counter = token_counter_4500 - state = { - "messages": [ + state = AgentState( + messages=[ HumanMessage(content="1"), AIMessage(content="2"), HumanMessage(content="3"), AIMessage(content="4"), ] - } + ) result = middleware.before_model(state, Runtime()) assert result is None, "Should not summarize when no complete clause is met" @@ -1337,7 +1339,7 @@ async def test_and_trigger_conditions_async() -> None: trigger={"tokens": 1000, "messages": 5}, keep=("messages", 2), ) - state = {"messages": [HumanMessage(content=str(i)) for i in range(6)]} + state = AgentState(messages=[HumanMessage(content=str(i)) for i in range(6)]) # Only the messages threshold met (tokens below) -> should not summarize. def token_counter_low(_messages: Iterable[MessageLikeRepresentation]) -> int: @@ -1367,7 +1369,7 @@ async def test_or_trigger_conditions_with_and_clauses_async() -> None: ], keep=("messages", 2), ) - state = {"messages": [HumanMessage(content=str(i)) for i in range(4)]} + state = AgentState(messages=[HumanMessage(content=str(i)) for i in range(4)]) # First clause met (tokens = 5500, messages = 4) -> should summarize. def token_counter_5500(_messages: Iterable[MessageLikeRepresentation]) -> int: @@ -1401,7 +1403,7 @@ def test_backward_compatibility_tuple_trigger() -> None: return 1500 middleware_single.token_counter = token_counter_high - state = {"messages": [HumanMessage(content=str(i)) for i in range(3)]} + state = AgentState(messages=[HumanMessage(content=str(i)) for i in range(3)]) result = middleware_single.before_model(state, Runtime()) assert result is not None, "Single tuple trigger should work" @@ -1414,7 +1416,7 @@ def test_backward_compatibility_tuple_trigger() -> None: # Should trigger with high tokens (first condition met) middleware_list.token_counter = token_counter_high - state = {"messages": [HumanMessage(content=str(i)) for i in range(3)]} + state = AgentState(messages=[HumanMessage(content=str(i)) for i in range(3)]) result = middleware_list.before_model(state, Runtime()) assert result is not None, "List of tuples should trigger when any condition met" @@ -1423,7 +1425,7 @@ def test_backward_compatibility_tuple_trigger() -> None: return 100 middleware_list.token_counter = token_counter_low - state = {"messages": [HumanMessage(content=str(i)) for i in range(6)]} + state = AgentState(messages=[HumanMessage(content=str(i)) for i in range(6)]) result = middleware_list.before_model(state, Runtime()) assert result is not None, "List of tuples should trigger when second condition met" @@ -1447,7 +1449,7 @@ def test_mixed_and_or_conditions() -> None: return 4500 middleware.token_counter = token_counter_high - state = {"messages": [HumanMessage(content=str(i)) for i in range(12)]} + state = AgentState(messages=[HumanMessage(content=str(i)) for i in range(12)]) result = middleware.before_model(state, Runtime()) assert result is not None, "Should trigger when AND clause is met" @@ -1456,13 +1458,13 @@ def test_mixed_and_or_conditions() -> None: return 1000 middleware.token_counter = token_counter_low - state = {"messages": [HumanMessage(content=str(i)) for i in range(55)]} + state = AgentState(messages=[HumanMessage(content=str(i)) for i in range(55)]) result = middleware.before_model(state, Runtime()) assert result is not None, "Should trigger when simple messages condition is met" # Test case 3: Neither condition met middleware.token_counter = token_counter_low - state = {"messages": [HumanMessage(content=str(i)) for i in range(8)]} + state = AgentState(messages=[HumanMessage(content=str(i)) for i in range(8)]) result = middleware.before_model(state, Runtime()) assert result is None, "Should not trigger when no condition is met" @@ -1484,21 +1486,21 @@ def test_fraction_in_and_trigger() -> None: # Test case 1: Both conditions met # 5 messages * 200 = 1000 tokens (profile max is 1000) # 1000 / 1000 = 1.0 >= 0.8 AND messages = 5 >= 5 - state = {"messages": [HumanMessage(content=str(i)) for i in range(5)]} + state = AgentState(messages=[HumanMessage(content=str(i)) for i in range(5)]) result = middleware.before_model(state, Runtime()) assert result is not None, "Should trigger when both fraction and messages conditions met" # Test case 2: Only messages condition met # 3 messages * 200 = 600 tokens # 600 / 1000 = 0.6 < 0.8 and messages = 3 < 5 - state = {"messages": [HumanMessage(content=str(i)) for i in range(3)]} + state = AgentState(messages=[HumanMessage(content=str(i)) for i in range(3)]) result = middleware.before_model(state, Runtime()) assert result is None, "Should not trigger when neither condition is fully met" # Test case 3: High fraction but not enough messages # 4 messages * 200 = 800 tokens # 800 / 1000 = 0.8 >= 0.8 but messages = 4 < 5 - state = {"messages": [HumanMessage(content=str(i)) for i in range(4)]} + state = AgentState(messages=[HumanMessage(content=str(i)) for i in range(4)]) result = middleware.before_model(state, Runtime()) assert result is None, "Should not trigger when only fraction condition is met" @@ -1511,7 +1513,7 @@ def test_trigger_validation_errors() -> None: with pytest.raises(ValueError, match="Unsupported trigger metric"): SummarizationMiddleware( model=model, - trigger={"invalid_metric": 100}, + trigger={"invalid_metric": 100}, # type: ignore[arg-type] ) # Invalid fraction value (> 1) — shares the tuple path's message via @@ -1547,21 +1549,21 @@ def test_trigger_validation_errors() -> None: with pytest.raises(ValueError, match="Fraction trigger values must be numeric"): SummarizationMiddleware( model=model, - trigger={"fraction": "invalid"}, + trigger={"fraction": "invalid"}, # type: ignore[arg-type] ) # Float value for an integer metric is rejected (no silent truncation) with pytest.raises(ValueError, match="tokens trigger values must be integers"): SummarizationMiddleware( model=model, - trigger={"tokens": 1000.5}, + trigger={"tokens": 1000.5}, # type: ignore[arg-type] ) # Numeric string for an integer metric is rejected (no silent coercion) with pytest.raises(ValueError, match="messages trigger values must be integers"): SummarizationMiddleware( model=model, - trigger={"messages": "10"}, + trigger={"messages": "10"}, # type: ignore[arg-type] ) # Boolean is rejected (bool is an int subclass) @@ -1575,7 +1577,7 @@ def test_trigger_validation_errors() -> None: with pytest.raises(TypeError, match="Unsupported trigger item type"): SummarizationMiddleware( model=model, - trigger=["invalid"], + trigger=["invalid"], # type: ignore[list-item] ) # Unsupported top-level trigger type (not a tuple, dict, or list) @@ -1616,7 +1618,7 @@ def test_empty_list_trigger_never_summarizes() -> None: token_counter=lambda _: 10_000, ) assert middleware._trigger_conditions == [] - state = {"messages": [HumanMessage(content=str(i)) for i in range(50)]} + state = AgentState(messages=[HumanMessage(content=str(i)) for i in range(50)]) assert middleware.before_model(state, Runtime()) is None diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_tool_call_limit.py b/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_tool_call_limit.py index ade644a3788..9050eb6751e 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_tool_call_limit.py +++ b/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_tool_call_limit.py @@ -7,6 +7,7 @@ from langgraph.checkpoint.memory import InMemorySaver from langchain.agents.factory import create_agent from langchain.agents.middleware.tool_call_limit import ( + ExitBehavior, ToolCallLimitExceededError, ToolCallLimitMiddleware, ToolCallLimitState, @@ -34,7 +35,8 @@ def test_middleware_initialization_validation() -> None: assert middleware.run_limit is None # Test exit behaviors - for behavior in ["error", "end", "continue"]: + behaviors: tuple[ExitBehavior, ...] = ("error", "end", "continue") + for behavior in behaviors: middleware = ToolCallLimitMiddleware(thread_limit=5, exit_behavior=behavior) assert middleware.exit_behavior == behavior