chore(langchain): add types in agent middleware tests (#38188)

Co-authored-by: Mason Daugherty <github@mdrxy.com>
Co-authored-by: Mason Daugherty <mason@langchain.dev>
This commit is contained in:
Christophe Bornet
2026-06-23 06:18:11 +02:00
committed by GitHub
parent 18fe3473a5
commit 64177b6fc5
14 changed files with 208 additions and 131 deletions

View File

@@ -15,7 +15,7 @@ import uuid
import weakref
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Annotated, Any, Literal, cast
from typing import TYPE_CHECKING, Annotated, Any, Literal, cast, overload
from langchain_core.messages import ToolMessage
from langchain_core.runnables import run_in_executor
@@ -758,6 +758,24 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState[ResponseT], ContextT, R
matches_by_type.setdefault(rule.pii_type, []).extend(matches)
return updated, matches_by_type
@overload
def _run_shell_tool(
self,
resources: _SessionResources,
payload: dict[str, Any],
*,
tool_call_id: str,
) -> ToolMessage: ...
@overload
def _run_shell_tool(
self,
resources: _SessionResources,
payload: dict[str, Any],
*,
tool_call_id: None,
) -> str: ...
def _run_shell_tool(
self,
resources: _SessionResources,
@@ -853,6 +871,26 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState[ResponseT], ContextT, R
artifact=artifact,
)
@overload
def _format_tool_message(
self,
content: str,
tool_call_id: str,
*,
status: Literal["success", "error"],
artifact: dict[str, Any] | None = None,
) -> ToolMessage: ...
@overload
def _format_tool_message(
self,
content: str,
tool_call_id: None,
*,
status: Literal["success", "error"],
artifact: dict[str, Any] | None = None,
) -> str: ...
def _format_tool_message(
self,
content: str,

View File

@@ -110,10 +110,6 @@ strict = true
enable_error_code = "deprecated"
warn_unreachable = true
exclude = [
"tests/unit_tests/agents/middleware/",
]
[[tool.mypy.overrides]]
module = ["pytest_socket.*", "vcr.*"]
ignore_missing_imports = true

View File

@@ -277,7 +277,7 @@ class TestChainModelCallHandlers:
test: str
state_values = []
runtime_values = []
runtime_values: list[tuple[str, Runtime[Any]]] = []
def outer(
request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse]

View File

@@ -6,24 +6,31 @@ that are not declared upfront when creating the agent.
import asyncio
from collections.abc import Awaitable, Callable
from typing import Any
from typing import TYPE_CHECKING, Any
import pytest
from langchain_core.messages import HumanMessage, ToolCall, ToolMessage
from langchain_core.tools import tool
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.graph.state import CompiledStateGraph
from langgraph.types import Command
from langchain.agents.factory import create_agent
from langchain.agents.middleware.types import (
AgentMiddleware,
AgentState,
InputAgentState,
ModelCallResult,
ModelRequest,
ModelResponse,
OutputAgentState,
ToolCallRequest,
)
from tests.unit_tests.agents.model import FakeToolCallingModel
if TYPE_CHECKING:
from langchain_core.runnables import RunnableConfig
@tool
def static_tool(value: str) -> str:
@@ -156,7 +163,7 @@ class ConditionalDynamicToolMiddleware(AgentMiddleware):
def _should_add_tool(self, request: ModelRequest) -> bool:
messages = request.state.get("messages", [])
return messages and "calculator" in str(messages[-1].content).lower()
return bool(messages) and "calculator" in str(messages[-1].content).lower()
def wrap_model_call(
self,
@@ -205,10 +212,15 @@ def get_tool_messages(result: dict[str, Any]) -> list[ToolMessage]:
return [m for m in result["messages"] if isinstance(m, ToolMessage)]
async def invoke_agent(agent: Any, message: str, *, use_async: bool) -> dict[str, Any]:
async def invoke_agent(
agent: CompiledStateGraph[AgentState[Any], None, InputAgentState, OutputAgentState[Any]],
message: str,
*,
use_async: bool,
) -> dict[str, Any]:
"""Invoke agent synchronously or asynchronously based on flag."""
input_data = {"messages": [HumanMessage(message)]}
config = {"configurable": {"thread_id": "test"}}
input_data: InputAgentState = {"messages": [HumanMessage(message)]}
config: RunnableConfig = {"configurable": {"thread_id": "test"}}
if use_async:
return await agent.ainvoke(input_data, config)
# Run sync invoke in thread pool to avoid blocking the event loop
@@ -240,7 +252,7 @@ async def test_dynamic_tool_basic(*, use_async: bool, tools: list[Any] | None) -
agent = create_agent(
model=model,
tools=tools, # type: ignore[arg-type]
tools=tools,
middleware=[DynamicToolMiddleware()],
checkpointer=InMemorySaver(),
)

View File

@@ -1,6 +1,6 @@
import sys
from collections.abc import Awaitable, Callable
from typing import Annotated, Any, Generic
from typing import TYPE_CHECKING, Annotated, Any
import pytest
from langchain_core.language_models import GenericFakeChatModel
@@ -14,6 +14,7 @@ from syrupy.assertion import SnapshotAssertion
from typing_extensions import override
from langchain.agents.factory import create_agent
from langchain.agents.middleware import InputAgentState
from langchain.agents.middleware.types import (
AgentMiddleware,
AgentState,
@@ -23,7 +24,6 @@ from langchain.agents.middleware.types import (
OmitFromInput,
OmitFromOutput,
PrivateStateAttr,
ResponseT,
after_agent,
after_model,
before_agent,
@@ -35,6 +35,9 @@ from langchain.tools import InjectedState
from tests.unit_tests.agents.messages import _AnyIdHumanMessage, _AnyIdToolMessage
from tests.unit_tests.agents.model import FakeToolCallingModel
if TYPE_CHECKING:
from langchain_core.runnables import RunnableConfig
def test_create_agent_invoke(
sync_checkpointer: BaseCheckpointSaver[str],
@@ -96,8 +99,8 @@ def test_create_agent_invoke(
checkpointer=sync_checkpointer,
)
thread1 = {"configurable": {"thread_id": "1"}}
assert agent_one.invoke({"messages": ["hello"]}, thread1) == {
thread1: RunnableConfig = {"configurable": {"thread_id": "1"}}
assert agent_one.invoke({"messages": [HumanMessage("hello")]}, thread1) == {
"messages": [
_AnyIdHumanMessage(content="hello"),
AIMessage(
@@ -201,8 +204,8 @@ def test_create_agent_jump(
if isinstance(sync_checkpointer, InMemorySaver):
assert agent_one.get_graph().draw_mermaid() == snapshot
thread1 = {"configurable": {"thread_id": "1"}}
assert agent_one.invoke({"messages": []}, thread1) == {"messages": []}
thread1: RunnableConfig = {"configurable": {"thread_id": "1"}}
assert agent_one.invoke(InputAgentState(messages=[]), thread1) == {"messages": []}
assert calls == ["NoopSeven.before_model", "NoopEight.before_model"]
@@ -368,6 +371,11 @@ def test_public_private_state_for_custom_middleware() -> None:
omit_output: Annotated[str, OmitFromOutput]
private_state: Annotated[str, PrivateStateAttr]
class CustomInputState(InputAgentState):
omit_input: str
omit_output: str
private_state: str
class CustomMiddleware(AgentMiddleware[CustomState]):
state_schema: type[CustomState] = CustomState
@@ -380,12 +388,12 @@ def test_public_private_state_for_custom_middleware() -> None:
agent = create_agent(model=FakeToolCallingModel(), middleware=[CustomMiddleware()])
result = agent.invoke(
{
"messages": [HumanMessage("Hello")],
"omit_input": "test in",
"private_state": "test in",
"omit_output": "test in",
}
CustomInputState(
messages=[HumanMessage("Hello")],
omit_input="test in",
private_state="test in",
omit_output="test in",
)
)
assert "omit_input" in result
assert "omit_output" not in result
@@ -420,7 +428,11 @@ def test_runtime_injected_into_middleware() -> None:
# custom state w/in a function
class CustomState(AgentState[ResponseT], Generic[ResponseT]):
class CustomState(AgentState[Any]):
custom_state: str
class _CustomInputState(InputAgentState):
custom_state: str
@@ -457,10 +469,10 @@ agent = create_agent(
def test_injected_state_in_middleware_agent() -> None:
"""Test that custom state is properly injected into tools when using middleware."""
result = agent.invoke(
{
"custom_state": "I love pizza",
"messages": [HumanMessage("Call the test state tool")],
}
_CustomInputState(
messages=[HumanMessage("Call the test state tool")],
custom_state="I love pizza",
)
)
messages = result["messages"]

View File

@@ -13,6 +13,7 @@ from langchain.agents.factory import create_agent
from langchain.agents.middleware.types import (
AgentMiddleware,
AgentState,
InputAgentState,
ModelCallResult,
ModelRequest,
ModelResponse,
@@ -180,6 +181,9 @@ def test_middleware_can_add_and_remove_tools() -> None:
class AdminState(AgentState[Any]):
is_admin: bool
class AdminInputState(InputAgentState):
is_admin: bool
class ConditionalToolMiddleware(AgentMiddleware[AdminState]):
state_schema = AdminState
@@ -209,11 +213,11 @@ def test_middleware_can_add_and_remove_tools() -> None:
# Test non-admin user - should not have access to admin_tool
# We can't directly inspect the bound model, but we can verify the agent runs
result = agent.invoke({"messages": [HumanMessage("Hello")], "is_admin": False})
result = agent.invoke(AdminInputState(messages=[HumanMessage("Hello")], is_admin=False))
assert "messages" in result
# Test admin user - should have access to all tools
result = agent.invoke({"messages": [HumanMessage("Hello")], "is_admin": True})
result = agent.invoke(AdminInputState(messages=[HumanMessage("Hello")], is_admin=True))
assert "messages" in result
@@ -383,8 +387,8 @@ def test_tool_node_not_accepted() -> None:
tool_node = ToolNode([some_tool])
with pytest.raises(TypeError, match="'ToolNode' object is not iterable"):
create_agent(
create_agent( # type: ignore[call-overload]
model=FakeToolCallingModel(),
tools=tool_node, # type: ignore[arg-type]
tools=tool_node,
system_prompt="You are a helpful assistant.",
)

View File

@@ -60,7 +60,7 @@ def test_middleware_transformer_registered_on_compiled_graph() -> None:
run = agent.stream_events({"messages": [HumanMessage("hi")]}, version="v3")
assert "middleware_marker" in run._mux.extensions # type: ignore[attr-defined]
assert "middleware_marker" in run._mux.extensions
# Drain to close the run cleanly.
list(run.tool_calls)
@@ -80,7 +80,7 @@ def test_middleware_and_user_transformers_compose_in_order() -> None:
run = agent.stream_events({"messages": [HumanMessage("hi")]}, version="v3")
transformers = run._mux._transformers # type: ignore[attr-defined]
transformers = run._mux._transformers
tool_call_idx = next(
i for i, t in enumerate(transformers) if isinstance(t, ToolCallTransformer)
)
@@ -119,7 +119,7 @@ def test_transformers_from_multiple_middleware_preserve_middleware_order() -> No
run = agent.stream_events({"messages": [HumanMessage("hi")]}, version="v3")
transformers = run._mux._transformers # type: ignore[attr-defined]
transformers = run._mux._transformers
idx_a = next(i for i, t in enumerate(transformers) if isinstance(t, _MarkerA))
idx_b = next(i for i, t in enumerate(transformers) if isinstance(t, _MarkerB))
assert idx_a < idx_b
@@ -136,7 +136,7 @@ def test_middleware_without_transformers_does_not_affect_registry() -> None:
agent = create_agent(model=FakeToolCallingModel(), tools=[], middleware=[_Middleware()])
run = agent.stream_events({"messages": [HumanMessage("hi")]}, version="v3")
transformers = run._mux._transformers # type: ignore[attr-defined]
transformers = run._mux._transformers
assert any(isinstance(t, ToolCallTransformer) for t in transformers)
assert not any(isinstance(t, _MiddlewareMarker) for t in transformers)

View File

@@ -169,7 +169,7 @@ class TestRetryLogic:
def __init__(self, max_retries: int = 3):
super().__init__()
self.max_retries = max_retries
self.attempts = []
self.attempts: list[int] = []
def wrap_model_call(
self,
@@ -1195,7 +1195,7 @@ class TestWrapModelCallDecorator:
messages: list[Any]
custom_field: str
@wrap_model_call(state_schema=CustomState)
@wrap_model_call(state_schema=CustomState) # type: ignore[type-var]
def middleware_with_schema(
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],

View File

@@ -153,7 +153,7 @@ class TestCustomStateField:
summary: str
class SummaryMiddleware(AgentMiddleware):
state_schema = MyState # type: ignore[assignment]
state_schema = MyState
def wrap_model_call(
self,
@@ -437,7 +437,7 @@ class TestComposition:
custom_key: str
class OuterMiddleware(AgentMiddleware):
state_schema = MyState # type: ignore[assignment]
state_schema = MyState
def wrap_model_call(
self,
@@ -456,7 +456,7 @@ class TestComposition:
)
class InnerMiddleware(AgentMiddleware):
state_schema = MyState # type: ignore[assignment]
state_schema = MyState
def wrap_model_call(
self,
@@ -493,7 +493,7 @@ class TestComposition:
outer_key: str
class OuterMiddleware(AgentMiddleware):
state_schema = MyState # type: ignore[assignment]
state_schema = MyState
def wrap_model_call(
self,
@@ -507,7 +507,7 @@ class TestComposition:
)
class InnerMiddleware(AgentMiddleware):
state_schema = MyState # type: ignore[assignment]
state_schema = MyState
def wrap_model_call(
self,
@@ -550,7 +550,7 @@ class TestComposition:
return handler(request)
class InnerMiddleware(AgentMiddleware):
state_schema = MyState # type: ignore[assignment]
state_schema = MyState
def wrap_model_call(
self,
@@ -752,7 +752,7 @@ class TestAsyncComposition:
return await handler(request)
class InnerMiddleware(AgentMiddleware):
state_schema = MyState # type: ignore[assignment]
state_schema = MyState
async def awrap_model_call(
self,

View File

@@ -37,14 +37,14 @@ class _TokenCountingChatModel(FakeChatModel):
def get_num_tokens_from_messages(
self,
messages: list[BaseMessage],
tools: Sequence | None = None,
tools: Sequence[Any] | None = None,
) -> int:
return sum(_count_message_tokens(message) for message in messages)
def _count_message_tokens(message: MessageLikeRepresentation) -> int:
if isinstance(message, (AIMessage, ToolMessage)):
return _count_content(message.content)
return _count_content(cast("MessageLikeRepresentation", message.content))
if isinstance(message, str):
return len(message)
return len(str(message))

View File

@@ -1,3 +1,5 @@
from typing import TYPE_CHECKING
import pytest
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langchain_core.tools import tool
@@ -5,6 +7,7 @@ from langgraph.checkpoint.memory import InMemorySaver
from langgraph.runtime import Runtime
from langchain.agents.factory import create_agent
from langchain.agents.middleware import InputAgentState
from langchain.agents.middleware.model_call_limit import (
ModelCallLimitExceededError,
ModelCallLimitMiddleware,
@@ -12,6 +15,9 @@ from langchain.agents.middleware.model_call_limit import (
)
from tests.unit_tests.agents.model import FakeToolCallingModel
if TYPE_CHECKING:
from langchain_core.runnables import RunnableConfig
@tool
def simple_tool(value: str) -> str:
@@ -215,13 +221,13 @@ def test_run_limit_resets_between_invocations() -> None:
agent = create_agent(model=model, middleware=[middleware], checkpointer=InMemorySaver())
thread_config = {"configurable": {"thread_id": "test_thread"}}
agent.invoke({"messages": [HumanMessage("Hello")]}, thread_config)
agent.invoke({"messages": [HumanMessage("Hello again")]}, thread_config)
agent.invoke({"messages": [HumanMessage("Hello third")]}, thread_config)
thread_config: RunnableConfig = {"configurable": {"thread_id": "test_thread"}}
agent.invoke(InputAgentState(messages=[HumanMessage("Hello")]), thread_config)
agent.invoke(InputAgentState(messages=[HumanMessage("Hello again")]), thread_config)
agent.invoke(InputAgentState(messages=[HumanMessage("Hello third")]), thread_config)
# Fourth run: should raise, thread_model_call_count == 3 (limit)
with pytest.raises(ModelCallLimitExceededError) as exc_info:
agent.invoke({"messages": [HumanMessage("Hello fourth")]}, thread_config)
agent.invoke(InputAgentState(messages=[HumanMessage("Hello fourth")]), thread_config)
error_msg = str(exc_info.value)
assert "thread limit (3/3)" in error_msg

View File

@@ -15,6 +15,7 @@ from langchain_core.messages import (
)
from langchain_core.tools import tool
from langgraph.runtime import Runtime
from langgraph.stream._types import ProtocolEvent
from langgraph.stream.transformers import MessagesTransformer
from langchain.agents import AgentState
@@ -578,7 +579,7 @@ class TestCustomDetector:
KeyError: 'value' when used with hash or mask strategies.
"""
def detect_phone(content: str) -> list[dict]: # type: ignore[type-arg]
def detect_phone(content: str) -> list[dict[str, Any]]:
return [
{"text": m.group(), "start": m.start(), "end": m.end()}
for m in re.finditer(r"\+91[\s.-]?\d{10}", content)
@@ -586,7 +587,7 @@ class TestCustomDetector:
middleware = PIIMiddleware(
"indian_phone",
detector=detect_phone,
detector=detect_phone, # type: ignore[arg-type]
strategy="hash",
apply_to_input=True,
)
@@ -601,7 +602,7 @@ class TestCustomDetector:
def test_custom_callable_detector_with_text_key_mask(self) -> None:
"""Custom detectors returning 'text' instead of 'value' must work with mask strategy."""
def detect_phone(content: str) -> list[dict]: # type: ignore[type-arg]
def detect_phone(content: str) -> list[dict[str, Any]]:
return [
{"text": m.group(), "start": m.start(), "end": m.end()}
for m in re.finditer(r"\+91[\s.-]?\d{10}", content)
@@ -609,7 +610,7 @@ class TestCustomDetector:
middleware = PIIMiddleware(
"indian_phone",
detector=detect_phone,
detector=detect_phone, # type: ignore[arg-type]
strategy="mask",
apply_to_input=True,
)
@@ -715,7 +716,7 @@ class TestMultipleMiddleware:
# ============================================================================
def _make_delta_event(text: str, *, index: int = 0, run_id: str = "r1") -> dict[str, Any]:
def _make_delta_event(text: str, *, index: int = 0, run_id: str = "r1") -> ProtocolEvent:
"""Build a `messages` protocol event for a text content-block delta."""
return {
"type": "event",
@@ -735,7 +736,7 @@ def _make_delta_event(text: str, *, index: int = 0, run_id: str = "r1") -> dict[
}
def _make_finish_event(text: str, *, index: int = 0, run_id: str = "r1") -> dict[str, Any]:
def _make_finish_event(text: str, *, index: int = 0, run_id: str = "r1") -> ProtocolEvent:
"""Build a `messages` protocol event for content-block-finish on a text block."""
return {
"type": "event",
@@ -755,7 +756,7 @@ def _make_finish_event(text: str, *, index: int = 0, run_id: str = "r1") -> dict
}
def _emitted_text(events: list[dict[str, Any]]) -> str:
def _emitted_text(events: list[ProtocolEvent]) -> tuple[str, dict[int, str]]:
"""Concatenate delta + finalized text the way a streaming consumer would."""
parts = []
final_by_index: dict[int, str] = {}
@@ -772,10 +773,10 @@ def _emitted_text(events: list[dict[str, Any]]) -> str:
final_by_index[payload["index"]] = content["text"]
# Concatenated delta stream is what the consumer sees in real time;
# finalized text is the snapshot. Return both via a tuple-like dict.
return "".join(parts), final_by_index # type: ignore[return-value]
return "".join(parts), final_by_index
def _run_transformer(transformer: Any, events: list[dict[str, Any]]) -> list[dict[str, Any]]:
def _run_transformer(transformer: Any, events: list[ProtocolEvent]) -> list[ProtocolEvent]:
"""Feed events through the transformer (mutates in place) and return them."""
for event in events:
transformer.process(event)
@@ -840,6 +841,7 @@ class TestPIIStreamTransformer:
redacted = transformer._redact_value(msg)
# Original untouched.
assert isinstance(msg.content[0], dict)
assert msg.content[0]["text"] == "Reach me at alice@example.com"
# Redacted copy walked every block.
assert "alice@example.com" not in redacted.content[0]["text"]
@@ -898,7 +900,7 @@ class TestPIIStreamTransformer:
rule = RedactionRule(pii_type="email").resolve()
transformer = _PIIStreamTransformer(rule=rule, lookback=32)
events = [
events: list[ProtocolEvent] = [
{
"type": "event",
"method": "messages",
@@ -953,7 +955,7 @@ class TestPIIStreamTransformer:
rule = RedactionRule(pii_type="email", strategy="block").resolve()
transformer = _PIIStreamTransformer(rule=rule)
event = {
event: ProtocolEvent = {
"type": "event",
"method": "messages",
"params": {
@@ -986,7 +988,7 @@ class TestPIIStreamTransformer:
rule = RedactionRule(pii_type="email").resolve()
transformer = _PIIStreamTransformer(rule=rule) # default lookback=128
events = [
events: list[ProtocolEvent] = [
{
"type": "event",
"method": "messages",
@@ -1031,7 +1033,7 @@ class TestPIIStreamTransformer:
# 50-char args with PII near the start; emit_end = 50 - 8 = 42.
args = '{"to": "alice@example.com", "subject": "hi"}'
events = [
events: list[ProtocolEvent] = [
{
"type": "event",
"method": "messages",
@@ -1082,7 +1084,7 @@ class TestPIIStreamTransformer:
'{"to": "alice@example',
'{"to": "alice@example.com"}',
]
events = [
events: list[ProtocolEvent] = [
{
"type": "event",
"method": "messages",
@@ -1124,7 +1126,7 @@ class TestPIIStreamTransformer:
rule = RedactionRule(pii_type="email").resolve()
transformer = _PIIStreamTransformer(rule=rule)
events = [
events: list[ProtocolEvent] = [
{
"type": "event",
"method": "messages",
@@ -1156,7 +1158,7 @@ class TestPIIStreamTransformer:
rule = RedactionRule(pii_type="email", strategy="block").resolve()
transformer = _PIIStreamTransformer(rule=rule)
event = {
event: ProtocolEvent = {
"type": "event",
"method": "messages",
"params": {
@@ -1192,7 +1194,7 @@ class TestPIIStreamTransformer:
],
id="m1",
)
event: dict[str, Any] = {
event: ProtocolEvent = {
"type": "event",
"method": "messages",
"params": {
@@ -1217,7 +1219,7 @@ class TestPIIStreamTransformer:
],
id="m1",
)
event: dict[str, Any] = {
event: ProtocolEvent = {
"type": "event",
"method": "messages",
"params": {
@@ -1235,7 +1237,7 @@ class TestPIIStreamTransformer:
transformer = _PIIStreamTransformer(rule=rule)
msg = ToolMessage(content="Result: alice@example.com", tool_call_id="c1", id="m1")
event: dict[str, Any] = {
event: ProtocolEvent = {
"type": "event",
"method": "messages",
"params": {
@@ -1258,7 +1260,7 @@ class TestPIIStreamTransformer:
"""Tools events route to the new handler without error."""
rule = RedactionRule(pii_type="email").resolve()
transformer = _PIIStreamTransformer(rule=rule)
event = {
event: ProtocolEvent = {
"type": "event",
"method": "tools",
"params": {
@@ -1279,7 +1281,7 @@ class TestPIIStreamTransformer:
rule = RedactionRule(pii_type="email").resolve()
transformer = _PIIStreamTransformer(rule=rule)
event = {
event: ProtocolEvent = {
"type": "event",
"method": "tools",
"params": {
@@ -1303,7 +1305,7 @@ class TestPIIStreamTransformer:
rule = RedactionRule(pii_type="email").resolve()
transformer = _PIIStreamTransformer(rule=rule)
event = {
event: ProtocolEvent = {
"type": "event",
"method": "tools",
"params": {
@@ -1365,7 +1367,7 @@ class TestPIIStreamTransformer:
rule = RedactionRule(pii_type="email").resolve()
transformer = _PIIStreamTransformer(rule=rule)
events = [
events: list[ProtocolEvent] = [
{
"type": "event",
"method": "messages",
@@ -1403,6 +1405,7 @@ class TestPIIStreamTransformer:
content="",
invalid_tool_calls=[
InvalidToolCall(
type="invalid_tool_call",
name="send_email",
args='{"to": "alice@example.com"} BROKEN',
id="c1",
@@ -1411,16 +1414,18 @@ class TestPIIStreamTransformer:
],
id="m1",
)
redacted = transformer._redact_value(msg)
redacted: AIMessage = transformer._redact_value(msg)
assert redacted.invalid_tool_calls[0]["args"] is not None
assert "alice@example.com" not in redacted.invalid_tool_calls[0]["args"]
assert "[REDACTED_EMAIL]" in redacted.invalid_tool_calls[0]["args"]
assert msg.invalid_tool_calls[0]["args"] is not None
assert "alice@example.com" in msg.invalid_tool_calls[0]["args"]
def test_tool_started_input_is_redacted(self) -> None:
rule = RedactionRule(pii_type="email").resolve()
transformer = _PIIStreamTransformer(rule=rule)
event = {
event: ProtocolEvent = {
"type": "event",
"method": "tools",
"params": {
@@ -1445,7 +1450,7 @@ class TestPIIStreamTransformer:
rule = RedactionRule(pii_type="email").resolve()
transformer = _PIIStreamTransformer(rule=rule, lookback=64)
events = [
events: list[ProtocolEvent] = [
{
"type": "event",
"method": "tools",
@@ -1483,7 +1488,7 @@ class TestPIIStreamTransformer:
rule = RedactionRule(pii_type="email").resolve()
transformer = _PIIStreamTransformer(rule=rule)
event = {
event: ProtocolEvent = {
"type": "event",
"method": "tools",
"params": {
@@ -1506,7 +1511,7 @@ class TestPIIStreamTransformer:
rule = RedactionRule(pii_type="email").resolve()
transformer = _PIIStreamTransformer(rule=rule)
event = {
event: ProtocolEvent = {
"type": "event",
"method": "tools",
"params": {
@@ -1529,7 +1534,7 @@ class TestPIIStreamTransformer:
rule = RedactionRule(pii_type="email").resolve()
transformer = _PIIStreamTransformer(rule=rule)
event = {
event: ProtocolEvent = {
"type": "event",
"method": "tools",
"params": {
@@ -1551,7 +1556,7 @@ class TestPIIStreamTransformer:
rule = RedactionRule(pii_type="email").resolve()
transformer = _PIIStreamTransformer(rule=rule, lookback=64)
delta_event = {
delta_event: ProtocolEvent = {
"type": "event",
"method": "tools",
"params": {
@@ -1567,7 +1572,7 @@ class TestPIIStreamTransformer:
transformer.process(delta_event)
assert "c1" in transformer._tool_buffers
finish_event = {
finish_event: ProtocolEvent = {
"type": "event",
"method": "tools",
"params": {
@@ -1588,7 +1593,7 @@ class TestPIIStreamTransformer:
rule = RedactionRule(pii_type="email", strategy="block").resolve()
transformer = _PIIStreamTransformer(rule=rule)
event = {
event: ProtocolEvent = {
"type": "event",
"method": "tools",
"params": {
@@ -1609,7 +1614,7 @@ class TestPIIStreamTransformer:
rule = RedactionRule(pii_type="email", strategy="block").resolve()
transformer = _PIIStreamTransformer(rule=rule)
event = {
event: ProtocolEvent = {
"type": "event",
"method": "messages",
"params": {
@@ -1646,7 +1651,7 @@ class TestPIIStreamTransformer:
_make_finish_event("Reach me at alice@example.com tomorrow."),
]
_run_transformer(transformer, events)
streamed, finals = _emitted_text(events) # type: ignore[misc]
streamed, finals = _emitted_text(events)
# The raw email never reaches the wire — the delta is held in the
# lookback buffer and the finalize snapshot is the redacted text.
@@ -1665,7 +1670,7 @@ class TestPIIStreamTransformer:
_make_finish_event("Hi, contact alice@example.com when ready"),
]
_run_transformer(transformer, events)
streamed, finals = _emitted_text(events) # type: ignore[misc]
streamed, finals = _emitted_text(events)
# The held-buffer should have prevented the raw email from being
# released until detection ran over the concatenation.
@@ -1695,7 +1700,7 @@ class TestPIIStreamTransformer:
_make_finish_event(text),
]
_run_transformer(transformer, events)
streamed, finals = _emitted_text(events) # type: ignore[misc]
streamed, finals = _emitted_text(events)
# No prefix of the email reaches the wire.
assert email not in streamed
@@ -1716,7 +1721,7 @@ class TestPIIStreamTransformer:
_make_finish_event("Card: 5425 2334 3010 9903 next"),
]
_run_transformer(transformer, events)
streamed, finals = _emitted_text(events) # type: ignore[misc]
streamed, finals = _emitted_text(events)
# No prefix of the card may reach the wire — the lookback buffer
# holds whitespace-separated groups until detection runs over the
@@ -1762,7 +1767,7 @@ class TestPIIStreamTransformer:
transformer.process(clean_event) # no raise yet
# Second delta completes the email — detection fires, raises.
completing_event = _make_delta_event("@example.com soon")
completing_event: ProtocolEvent = _make_delta_event("@example.com soon")
with pytest.raises(PIIDetectionError):
transformer.process(completing_event)
@@ -1777,7 +1782,7 @@ class TestPIIStreamTransformer:
_make_finish_event("Hello there, how are you?"),
]
_run_transformer(transformer, events)
streamed, finals = _emitted_text(events) # type: ignore[misc]
streamed, finals = _emitted_text(events)
# Deltas hold everything back; the finalize event carries the
# whole block at once. `ChatModelStream._resolve_block_text`
@@ -1801,7 +1806,7 @@ class TestPIIStreamTransformer:
_make_finish_event("alice@example.com"),
]
_run_transformer(transformer, events)
_, finals = _emitted_text(events) # type: ignore[misc]
_, finals = _emitted_text(events)
assert "alice@example.com" not in finals[0]
assert "[REDACTED_EMAIL]" in finals[0]
@@ -1851,7 +1856,7 @@ class TestPIIStreamTransformer:
assert ("r1", 0) in transformer._buffers
# message-finish for the run wipes any (run-id, *) entries.
message_finish_event = {
message_finish_event: ProtocolEvent = {
"type": "event",
"method": "messages",
"params": {
@@ -1886,7 +1891,7 @@ class TestPIIStreamTransformer:
_make_finish_event("hello alice@example.com goodbye"),
]
_run_transformer(transformer, events)
_, finals = _emitted_text(events) # type: ignore[misc]
_, finals = _emitted_text(events)
# The finalized snapshot always re-runs detection over the full text.
assert "alice@example.com" not in finals[0]
@@ -1901,7 +1906,7 @@ class TestPIIStreamTransformer:
rule = RedactionRule(pii_type="email").resolve()
transformer = _PIIStreamTransformer(rule=rule)
data_event = {
data_event: ProtocolEvent = {
"type": "event",
"method": "messages",
"params": {
@@ -1934,7 +1939,7 @@ class TestPIIStreamTransformer:
agent = create_agent(model, [], middleware=[PIIMiddleware("email", apply_to_output=True)])
run = agent.stream_events({"messages": [HumanMessage("hi")]}, version="v3")
transformers = run._mux._transformers # type: ignore[attr-defined]
transformers = run._mux._transformers
pii_idx = next(
i for i, t in enumerate(transformers) if isinstance(t, _PIIStreamTransformer)

View File

@@ -30,7 +30,9 @@ from typing_extensions import override
from langchain.agents import AgentState
from langchain.agents.middleware.summarization import (
ContextSize,
SummarizationMiddleware,
TriggerClause,
_provider_matches,
)
from langchain.chat_models import init_chat_model
@@ -1153,8 +1155,8 @@ def test_summarization_middleware_cutoff_at_start_of_tool_sequence() -> None:
def test_trigger_copies_mutable_inputs() -> None:
"""Test caller mutations do not change stored trigger configuration."""
model = FakeToolCallingModel()
clause = {"tokens": 1000}
trigger = [clause]
clause: TriggerClause = {"tokens": 1000}
trigger: list[ContextSize | TriggerClause] = [clause]
middleware = SummarizationMiddleware(
model=model,
@@ -1171,7 +1173,7 @@ def test_trigger_copies_mutable_inputs() -> None:
return 500
middleware.token_counter = token_counter_low
state = {"messages": [HumanMessage(content="1"), HumanMessage(content="2")]}
state = AgentState(messages=[HumanMessage(content="1"), HumanMessage(content="2")])
result = middleware.before_model(state, Runtime())
assert result is None
@@ -1231,13 +1233,13 @@ def test_and_trigger_conditions() -> None:
return 1500 # Above token threshold
middleware.token_counter = token_counter_high
state = {
"messages": [
state = AgentState(
messages=[
HumanMessage(content="1"),
AIMessage(content="2"),
HumanMessage(content="3"),
]
}
)
result = middleware.before_model(state, Runtime())
assert result is None, "Should not summarize when only tokens condition is met"
@@ -1247,8 +1249,8 @@ def test_and_trigger_conditions() -> None:
return 500 # Below token threshold
middleware.token_counter = token_counter_low
state = {
"messages": [
state = AgentState(
messages=[
HumanMessage(content="1"),
AIMessage(content="2"),
HumanMessage(content="3"),
@@ -1256,7 +1258,7 @@ def test_and_trigger_conditions() -> None:
HumanMessage(content="5"),
AIMessage(content="6"),
]
}
)
result = middleware.before_model(state, Runtime())
assert result is None, "Should not summarize when only messages condition is met"
@@ -1289,14 +1291,14 @@ def test_or_trigger_conditions_with_and_clauses() -> None:
return 5500
middleware.token_counter = token_counter_5500
state = {
"messages": [
state = AgentState(
messages=[
HumanMessage(content="1"),
AIMessage(content="2"),
HumanMessage(content="3"),
AIMessage(content="4"),
]
}
)
result = middleware.before_model(state, Runtime())
assert result is not None, "Should summarize when first OR clause is met"
@@ -1306,7 +1308,7 @@ def test_or_trigger_conditions_with_and_clauses() -> None:
return 3500
middleware.token_counter = token_counter_3500
state = {"messages": [HumanMessage(content=str(i)) for i in range(7)]}
state = AgentState(messages=[HumanMessage(content=str(i)) for i in range(7)])
result = middleware.before_model(state, Runtime())
assert result is not None, "Should summarize when second OR clause is met"
@@ -1318,14 +1320,14 @@ def test_or_trigger_conditions_with_and_clauses() -> None:
return 4500
middleware.token_counter = token_counter_4500
state = {
"messages": [
state = AgentState(
messages=[
HumanMessage(content="1"),
AIMessage(content="2"),
HumanMessage(content="3"),
AIMessage(content="4"),
]
}
)
result = middleware.before_model(state, Runtime())
assert result is None, "Should not summarize when no complete clause is met"
@@ -1337,7 +1339,7 @@ async def test_and_trigger_conditions_async() -> None:
trigger={"tokens": 1000, "messages": 5},
keep=("messages", 2),
)
state = {"messages": [HumanMessage(content=str(i)) for i in range(6)]}
state = AgentState(messages=[HumanMessage(content=str(i)) for i in range(6)])
# Only the messages threshold met (tokens below) -> should not summarize.
def token_counter_low(_messages: Iterable[MessageLikeRepresentation]) -> int:
@@ -1367,7 +1369,7 @@ async def test_or_trigger_conditions_with_and_clauses_async() -> None:
],
keep=("messages", 2),
)
state = {"messages": [HumanMessage(content=str(i)) for i in range(4)]}
state = AgentState(messages=[HumanMessage(content=str(i)) for i in range(4)])
# First clause met (tokens = 5500, messages = 4) -> should summarize.
def token_counter_5500(_messages: Iterable[MessageLikeRepresentation]) -> int:
@@ -1401,7 +1403,7 @@ def test_backward_compatibility_tuple_trigger() -> None:
return 1500
middleware_single.token_counter = token_counter_high
state = {"messages": [HumanMessage(content=str(i)) for i in range(3)]}
state = AgentState(messages=[HumanMessage(content=str(i)) for i in range(3)])
result = middleware_single.before_model(state, Runtime())
assert result is not None, "Single tuple trigger should work"
@@ -1414,7 +1416,7 @@ def test_backward_compatibility_tuple_trigger() -> None:
# Should trigger with high tokens (first condition met)
middleware_list.token_counter = token_counter_high
state = {"messages": [HumanMessage(content=str(i)) for i in range(3)]}
state = AgentState(messages=[HumanMessage(content=str(i)) for i in range(3)])
result = middleware_list.before_model(state, Runtime())
assert result is not None, "List of tuples should trigger when any condition met"
@@ -1423,7 +1425,7 @@ def test_backward_compatibility_tuple_trigger() -> None:
return 100
middleware_list.token_counter = token_counter_low
state = {"messages": [HumanMessage(content=str(i)) for i in range(6)]}
state = AgentState(messages=[HumanMessage(content=str(i)) for i in range(6)])
result = middleware_list.before_model(state, Runtime())
assert result is not None, "List of tuples should trigger when second condition met"
@@ -1447,7 +1449,7 @@ def test_mixed_and_or_conditions() -> None:
return 4500
middleware.token_counter = token_counter_high
state = {"messages": [HumanMessage(content=str(i)) for i in range(12)]}
state = AgentState(messages=[HumanMessage(content=str(i)) for i in range(12)])
result = middleware.before_model(state, Runtime())
assert result is not None, "Should trigger when AND clause is met"
@@ -1456,13 +1458,13 @@ def test_mixed_and_or_conditions() -> None:
return 1000
middleware.token_counter = token_counter_low
state = {"messages": [HumanMessage(content=str(i)) for i in range(55)]}
state = AgentState(messages=[HumanMessage(content=str(i)) for i in range(55)])
result = middleware.before_model(state, Runtime())
assert result is not None, "Should trigger when simple messages condition is met"
# Test case 3: Neither condition met
middleware.token_counter = token_counter_low
state = {"messages": [HumanMessage(content=str(i)) for i in range(8)]}
state = AgentState(messages=[HumanMessage(content=str(i)) for i in range(8)])
result = middleware.before_model(state, Runtime())
assert result is None, "Should not trigger when no condition is met"
@@ -1484,21 +1486,21 @@ def test_fraction_in_and_trigger() -> None:
# Test case 1: Both conditions met
# 5 messages * 200 = 1000 tokens (profile max is 1000)
# 1000 / 1000 = 1.0 >= 0.8 AND messages = 5 >= 5
state = {"messages": [HumanMessage(content=str(i)) for i in range(5)]}
state = AgentState(messages=[HumanMessage(content=str(i)) for i in range(5)])
result = middleware.before_model(state, Runtime())
assert result is not None, "Should trigger when both fraction and messages conditions met"
# Test case 2: Only messages condition met
# 3 messages * 200 = 600 tokens
# 600 / 1000 = 0.6 < 0.8 and messages = 3 < 5
state = {"messages": [HumanMessage(content=str(i)) for i in range(3)]}
state = AgentState(messages=[HumanMessage(content=str(i)) for i in range(3)])
result = middleware.before_model(state, Runtime())
assert result is None, "Should not trigger when neither condition is fully met"
# Test case 3: High fraction but not enough messages
# 4 messages * 200 = 800 tokens
# 800 / 1000 = 0.8 >= 0.8 but messages = 4 < 5
state = {"messages": [HumanMessage(content=str(i)) for i in range(4)]}
state = AgentState(messages=[HumanMessage(content=str(i)) for i in range(4)])
result = middleware.before_model(state, Runtime())
assert result is None, "Should not trigger when only fraction condition is met"
@@ -1511,7 +1513,7 @@ def test_trigger_validation_errors() -> None:
with pytest.raises(ValueError, match="Unsupported trigger metric"):
SummarizationMiddleware(
model=model,
trigger={"invalid_metric": 100},
trigger={"invalid_metric": 100}, # type: ignore[arg-type]
)
# Invalid fraction value (> 1) — shares the tuple path's message via
@@ -1547,21 +1549,21 @@ def test_trigger_validation_errors() -> None:
with pytest.raises(ValueError, match="Fraction trigger values must be numeric"):
SummarizationMiddleware(
model=model,
trigger={"fraction": "invalid"},
trigger={"fraction": "invalid"}, # type: ignore[arg-type]
)
# Float value for an integer metric is rejected (no silent truncation)
with pytest.raises(ValueError, match="tokens trigger values must be integers"):
SummarizationMiddleware(
model=model,
trigger={"tokens": 1000.5},
trigger={"tokens": 1000.5}, # type: ignore[arg-type]
)
# Numeric string for an integer metric is rejected (no silent coercion)
with pytest.raises(ValueError, match="messages trigger values must be integers"):
SummarizationMiddleware(
model=model,
trigger={"messages": "10"},
trigger={"messages": "10"}, # type: ignore[arg-type]
)
# Boolean is rejected (bool is an int subclass)
@@ -1575,7 +1577,7 @@ def test_trigger_validation_errors() -> None:
with pytest.raises(TypeError, match="Unsupported trigger item type"):
SummarizationMiddleware(
model=model,
trigger=["invalid"],
trigger=["invalid"], # type: ignore[list-item]
)
# Unsupported top-level trigger type (not a tuple, dict, or list)
@@ -1616,7 +1618,7 @@ def test_empty_list_trigger_never_summarizes() -> None:
token_counter=lambda _: 10_000,
)
assert middleware._trigger_conditions == []
state = {"messages": [HumanMessage(content=str(i)) for i in range(50)]}
state = AgentState(messages=[HumanMessage(content=str(i)) for i in range(50)])
assert middleware.before_model(state, Runtime()) is None

View File

@@ -7,6 +7,7 @@ from langgraph.checkpoint.memory import InMemorySaver
from langchain.agents.factory import create_agent
from langchain.agents.middleware.tool_call_limit import (
ExitBehavior,
ToolCallLimitExceededError,
ToolCallLimitMiddleware,
ToolCallLimitState,
@@ -34,7 +35,8 @@ def test_middleware_initialization_validation() -> None:
assert middleware.run_limit is None
# Test exit behaviors
for behavior in ["error", "end", "continue"]:
behaviors: tuple[ExitBehavior, ...] = ("error", "end", "continue")
for behavior in behaviors:
middleware = ToolCallLimitMiddleware(thread_limit=5, exit_behavior=behavior)
assert middleware.exit_behavior == behavior