mirror of
https://github.com/hwchase17/langchain.git
synced 2026-07-01 06:42:37 +00:00
chore(langchain): add types in agent middleware tests (#38188)
Co-authored-by: Mason Daugherty <github@mdrxy.com> Co-authored-by: Mason Daugherty <mason@langchain.dev>
This commit is contained in:
committed by
GitHub
parent
18fe3473a5
commit
64177b6fc5
@@ -15,7 +15,7 @@ import uuid
|
||||
import weakref
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Annotated, Any, Literal, cast
|
||||
from typing import TYPE_CHECKING, Annotated, Any, Literal, cast, overload
|
||||
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langchain_core.runnables import run_in_executor
|
||||
@@ -758,6 +758,24 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState[ResponseT], ContextT, R
|
||||
matches_by_type.setdefault(rule.pii_type, []).extend(matches)
|
||||
return updated, matches_by_type
|
||||
|
||||
@overload
|
||||
def _run_shell_tool(
|
||||
self,
|
||||
resources: _SessionResources,
|
||||
payload: dict[str, Any],
|
||||
*,
|
||||
tool_call_id: str,
|
||||
) -> ToolMessage: ...
|
||||
|
||||
@overload
|
||||
def _run_shell_tool(
|
||||
self,
|
||||
resources: _SessionResources,
|
||||
payload: dict[str, Any],
|
||||
*,
|
||||
tool_call_id: None,
|
||||
) -> str: ...
|
||||
|
||||
def _run_shell_tool(
|
||||
self,
|
||||
resources: _SessionResources,
|
||||
@@ -853,6 +871,26 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState[ResponseT], ContextT, R
|
||||
artifact=artifact,
|
||||
)
|
||||
|
||||
@overload
|
||||
def _format_tool_message(
|
||||
self,
|
||||
content: str,
|
||||
tool_call_id: str,
|
||||
*,
|
||||
status: Literal["success", "error"],
|
||||
artifact: dict[str, Any] | None = None,
|
||||
) -> ToolMessage: ...
|
||||
|
||||
@overload
|
||||
def _format_tool_message(
|
||||
self,
|
||||
content: str,
|
||||
tool_call_id: None,
|
||||
*,
|
||||
status: Literal["success", "error"],
|
||||
artifact: dict[str, Any] | None = None,
|
||||
) -> str: ...
|
||||
|
||||
def _format_tool_message(
|
||||
self,
|
||||
content: str,
|
||||
|
||||
@@ -110,10 +110,6 @@ strict = true
|
||||
enable_error_code = "deprecated"
|
||||
warn_unreachable = true
|
||||
|
||||
exclude = [
|
||||
"tests/unit_tests/agents/middleware/",
|
||||
]
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = ["pytest_socket.*", "vcr.*"]
|
||||
ignore_missing_imports = true
|
||||
|
||||
@@ -277,7 +277,7 @@ class TestChainModelCallHandlers:
|
||||
test: str
|
||||
|
||||
state_values = []
|
||||
runtime_values = []
|
||||
runtime_values: list[tuple[str, Runtime[Any]]] = []
|
||||
|
||||
def outer(
|
||||
request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse]
|
||||
|
||||
@@ -6,24 +6,31 @@ that are not declared upfront when creating the agent.
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import HumanMessage, ToolCall, ToolMessage
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
from langgraph.types import Command
|
||||
|
||||
from langchain.agents.factory import create_agent
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
InputAgentState,
|
||||
ModelCallResult,
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
OutputAgentState,
|
||||
ToolCallRequest,
|
||||
)
|
||||
from tests.unit_tests.agents.model import FakeToolCallingModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
|
||||
@tool
|
||||
def static_tool(value: str) -> str:
|
||||
@@ -156,7 +163,7 @@ class ConditionalDynamicToolMiddleware(AgentMiddleware):
|
||||
|
||||
def _should_add_tool(self, request: ModelRequest) -> bool:
|
||||
messages = request.state.get("messages", [])
|
||||
return messages and "calculator" in str(messages[-1].content).lower()
|
||||
return bool(messages) and "calculator" in str(messages[-1].content).lower()
|
||||
|
||||
def wrap_model_call(
|
||||
self,
|
||||
@@ -205,10 +212,15 @@ def get_tool_messages(result: dict[str, Any]) -> list[ToolMessage]:
|
||||
return [m for m in result["messages"] if isinstance(m, ToolMessage)]
|
||||
|
||||
|
||||
async def invoke_agent(agent: Any, message: str, *, use_async: bool) -> dict[str, Any]:
|
||||
async def invoke_agent(
|
||||
agent: CompiledStateGraph[AgentState[Any], None, InputAgentState, OutputAgentState[Any]],
|
||||
message: str,
|
||||
*,
|
||||
use_async: bool,
|
||||
) -> dict[str, Any]:
|
||||
"""Invoke agent synchronously or asynchronously based on flag."""
|
||||
input_data = {"messages": [HumanMessage(message)]}
|
||||
config = {"configurable": {"thread_id": "test"}}
|
||||
input_data: InputAgentState = {"messages": [HumanMessage(message)]}
|
||||
config: RunnableConfig = {"configurable": {"thread_id": "test"}}
|
||||
if use_async:
|
||||
return await agent.ainvoke(input_data, config)
|
||||
# Run sync invoke in thread pool to avoid blocking the event loop
|
||||
@@ -240,7 +252,7 @@ async def test_dynamic_tool_basic(*, use_async: bool, tools: list[Any] | None) -
|
||||
|
||||
agent = create_agent(
|
||||
model=model,
|
||||
tools=tools, # type: ignore[arg-type]
|
||||
tools=tools,
|
||||
middleware=[DynamicToolMiddleware()],
|
||||
checkpointer=InMemorySaver(),
|
||||
)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import sys
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Annotated, Any, Generic
|
||||
from typing import TYPE_CHECKING, Annotated, Any
|
||||
|
||||
import pytest
|
||||
from langchain_core.language_models import GenericFakeChatModel
|
||||
@@ -14,6 +14,7 @@ from syrupy.assertion import SnapshotAssertion
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain.agents.factory import create_agent
|
||||
from langchain.agents.middleware import InputAgentState
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
@@ -23,7 +24,6 @@ from langchain.agents.middleware.types import (
|
||||
OmitFromInput,
|
||||
OmitFromOutput,
|
||||
PrivateStateAttr,
|
||||
ResponseT,
|
||||
after_agent,
|
||||
after_model,
|
||||
before_agent,
|
||||
@@ -35,6 +35,9 @@ from langchain.tools import InjectedState
|
||||
from tests.unit_tests.agents.messages import _AnyIdHumanMessage, _AnyIdToolMessage
|
||||
from tests.unit_tests.agents.model import FakeToolCallingModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
|
||||
def test_create_agent_invoke(
|
||||
sync_checkpointer: BaseCheckpointSaver[str],
|
||||
@@ -96,8 +99,8 @@ def test_create_agent_invoke(
|
||||
checkpointer=sync_checkpointer,
|
||||
)
|
||||
|
||||
thread1 = {"configurable": {"thread_id": "1"}}
|
||||
assert agent_one.invoke({"messages": ["hello"]}, thread1) == {
|
||||
thread1: RunnableConfig = {"configurable": {"thread_id": "1"}}
|
||||
assert agent_one.invoke({"messages": [HumanMessage("hello")]}, thread1) == {
|
||||
"messages": [
|
||||
_AnyIdHumanMessage(content="hello"),
|
||||
AIMessage(
|
||||
@@ -201,8 +204,8 @@ def test_create_agent_jump(
|
||||
if isinstance(sync_checkpointer, InMemorySaver):
|
||||
assert agent_one.get_graph().draw_mermaid() == snapshot
|
||||
|
||||
thread1 = {"configurable": {"thread_id": "1"}}
|
||||
assert agent_one.invoke({"messages": []}, thread1) == {"messages": []}
|
||||
thread1: RunnableConfig = {"configurable": {"thread_id": "1"}}
|
||||
assert agent_one.invoke(InputAgentState(messages=[]), thread1) == {"messages": []}
|
||||
assert calls == ["NoopSeven.before_model", "NoopEight.before_model"]
|
||||
|
||||
|
||||
@@ -368,6 +371,11 @@ def test_public_private_state_for_custom_middleware() -> None:
|
||||
omit_output: Annotated[str, OmitFromOutput]
|
||||
private_state: Annotated[str, PrivateStateAttr]
|
||||
|
||||
class CustomInputState(InputAgentState):
|
||||
omit_input: str
|
||||
omit_output: str
|
||||
private_state: str
|
||||
|
||||
class CustomMiddleware(AgentMiddleware[CustomState]):
|
||||
state_schema: type[CustomState] = CustomState
|
||||
|
||||
@@ -380,12 +388,12 @@ def test_public_private_state_for_custom_middleware() -> None:
|
||||
|
||||
agent = create_agent(model=FakeToolCallingModel(), middleware=[CustomMiddleware()])
|
||||
result = agent.invoke(
|
||||
{
|
||||
"messages": [HumanMessage("Hello")],
|
||||
"omit_input": "test in",
|
||||
"private_state": "test in",
|
||||
"omit_output": "test in",
|
||||
}
|
||||
CustomInputState(
|
||||
messages=[HumanMessage("Hello")],
|
||||
omit_input="test in",
|
||||
private_state="test in",
|
||||
omit_output="test in",
|
||||
)
|
||||
)
|
||||
assert "omit_input" in result
|
||||
assert "omit_output" not in result
|
||||
@@ -420,7 +428,11 @@ def test_runtime_injected_into_middleware() -> None:
|
||||
# custom state w/in a function
|
||||
|
||||
|
||||
class CustomState(AgentState[ResponseT], Generic[ResponseT]):
|
||||
class CustomState(AgentState[Any]):
|
||||
custom_state: str
|
||||
|
||||
|
||||
class _CustomInputState(InputAgentState):
|
||||
custom_state: str
|
||||
|
||||
|
||||
@@ -457,10 +469,10 @@ agent = create_agent(
|
||||
def test_injected_state_in_middleware_agent() -> None:
|
||||
"""Test that custom state is properly injected into tools when using middleware."""
|
||||
result = agent.invoke(
|
||||
{
|
||||
"custom_state": "I love pizza",
|
||||
"messages": [HumanMessage("Call the test state tool")],
|
||||
}
|
||||
_CustomInputState(
|
||||
messages=[HumanMessage("Call the test state tool")],
|
||||
custom_state="I love pizza",
|
||||
)
|
||||
)
|
||||
|
||||
messages = result["messages"]
|
||||
|
||||
@@ -13,6 +13,7 @@ from langchain.agents.factory import create_agent
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
InputAgentState,
|
||||
ModelCallResult,
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
@@ -180,6 +181,9 @@ def test_middleware_can_add_and_remove_tools() -> None:
|
||||
class AdminState(AgentState[Any]):
|
||||
is_admin: bool
|
||||
|
||||
class AdminInputState(InputAgentState):
|
||||
is_admin: bool
|
||||
|
||||
class ConditionalToolMiddleware(AgentMiddleware[AdminState]):
|
||||
state_schema = AdminState
|
||||
|
||||
@@ -209,11 +213,11 @@ def test_middleware_can_add_and_remove_tools() -> None:
|
||||
|
||||
# Test non-admin user - should not have access to admin_tool
|
||||
# We can't directly inspect the bound model, but we can verify the agent runs
|
||||
result = agent.invoke({"messages": [HumanMessage("Hello")], "is_admin": False})
|
||||
result = agent.invoke(AdminInputState(messages=[HumanMessage("Hello")], is_admin=False))
|
||||
assert "messages" in result
|
||||
|
||||
# Test admin user - should have access to all tools
|
||||
result = agent.invoke({"messages": [HumanMessage("Hello")], "is_admin": True})
|
||||
result = agent.invoke(AdminInputState(messages=[HumanMessage("Hello")], is_admin=True))
|
||||
assert "messages" in result
|
||||
|
||||
|
||||
@@ -383,8 +387,8 @@ def test_tool_node_not_accepted() -> None:
|
||||
tool_node = ToolNode([some_tool])
|
||||
|
||||
with pytest.raises(TypeError, match="'ToolNode' object is not iterable"):
|
||||
create_agent(
|
||||
create_agent( # type: ignore[call-overload]
|
||||
model=FakeToolCallingModel(),
|
||||
tools=tool_node, # type: ignore[arg-type]
|
||||
tools=tool_node,
|
||||
system_prompt="You are a helpful assistant.",
|
||||
)
|
||||
|
||||
@@ -60,7 +60,7 @@ def test_middleware_transformer_registered_on_compiled_graph() -> None:
|
||||
|
||||
run = agent.stream_events({"messages": [HumanMessage("hi")]}, version="v3")
|
||||
|
||||
assert "middleware_marker" in run._mux.extensions # type: ignore[attr-defined]
|
||||
assert "middleware_marker" in run._mux.extensions
|
||||
# Drain to close the run cleanly.
|
||||
list(run.tool_calls)
|
||||
|
||||
@@ -80,7 +80,7 @@ def test_middleware_and_user_transformers_compose_in_order() -> None:
|
||||
|
||||
run = agent.stream_events({"messages": [HumanMessage("hi")]}, version="v3")
|
||||
|
||||
transformers = run._mux._transformers # type: ignore[attr-defined]
|
||||
transformers = run._mux._transformers
|
||||
tool_call_idx = next(
|
||||
i for i, t in enumerate(transformers) if isinstance(t, ToolCallTransformer)
|
||||
)
|
||||
@@ -119,7 +119,7 @@ def test_transformers_from_multiple_middleware_preserve_middleware_order() -> No
|
||||
|
||||
run = agent.stream_events({"messages": [HumanMessage("hi")]}, version="v3")
|
||||
|
||||
transformers = run._mux._transformers # type: ignore[attr-defined]
|
||||
transformers = run._mux._transformers
|
||||
idx_a = next(i for i, t in enumerate(transformers) if isinstance(t, _MarkerA))
|
||||
idx_b = next(i for i, t in enumerate(transformers) if isinstance(t, _MarkerB))
|
||||
assert idx_a < idx_b
|
||||
@@ -136,7 +136,7 @@ def test_middleware_without_transformers_does_not_affect_registry() -> None:
|
||||
agent = create_agent(model=FakeToolCallingModel(), tools=[], middleware=[_Middleware()])
|
||||
run = agent.stream_events({"messages": [HumanMessage("hi")]}, version="v3")
|
||||
|
||||
transformers = run._mux._transformers # type: ignore[attr-defined]
|
||||
transformers = run._mux._transformers
|
||||
assert any(isinstance(t, ToolCallTransformer) for t in transformers)
|
||||
assert not any(isinstance(t, _MiddlewareMarker) for t in transformers)
|
||||
|
||||
|
||||
@@ -169,7 +169,7 @@ class TestRetryLogic:
|
||||
def __init__(self, max_retries: int = 3):
|
||||
super().__init__()
|
||||
self.max_retries = max_retries
|
||||
self.attempts = []
|
||||
self.attempts: list[int] = []
|
||||
|
||||
def wrap_model_call(
|
||||
self,
|
||||
@@ -1195,7 +1195,7 @@ class TestWrapModelCallDecorator:
|
||||
messages: list[Any]
|
||||
custom_field: str
|
||||
|
||||
@wrap_model_call(state_schema=CustomState)
|
||||
@wrap_model_call(state_schema=CustomState) # type: ignore[type-var]
|
||||
def middleware_with_schema(
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
|
||||
@@ -153,7 +153,7 @@ class TestCustomStateField:
|
||||
summary: str
|
||||
|
||||
class SummaryMiddleware(AgentMiddleware):
|
||||
state_schema = MyState # type: ignore[assignment]
|
||||
state_schema = MyState
|
||||
|
||||
def wrap_model_call(
|
||||
self,
|
||||
@@ -437,7 +437,7 @@ class TestComposition:
|
||||
custom_key: str
|
||||
|
||||
class OuterMiddleware(AgentMiddleware):
|
||||
state_schema = MyState # type: ignore[assignment]
|
||||
state_schema = MyState
|
||||
|
||||
def wrap_model_call(
|
||||
self,
|
||||
@@ -456,7 +456,7 @@ class TestComposition:
|
||||
)
|
||||
|
||||
class InnerMiddleware(AgentMiddleware):
|
||||
state_schema = MyState # type: ignore[assignment]
|
||||
state_schema = MyState
|
||||
|
||||
def wrap_model_call(
|
||||
self,
|
||||
@@ -493,7 +493,7 @@ class TestComposition:
|
||||
outer_key: str
|
||||
|
||||
class OuterMiddleware(AgentMiddleware):
|
||||
state_schema = MyState # type: ignore[assignment]
|
||||
state_schema = MyState
|
||||
|
||||
def wrap_model_call(
|
||||
self,
|
||||
@@ -507,7 +507,7 @@ class TestComposition:
|
||||
)
|
||||
|
||||
class InnerMiddleware(AgentMiddleware):
|
||||
state_schema = MyState # type: ignore[assignment]
|
||||
state_schema = MyState
|
||||
|
||||
def wrap_model_call(
|
||||
self,
|
||||
@@ -550,7 +550,7 @@ class TestComposition:
|
||||
return handler(request)
|
||||
|
||||
class InnerMiddleware(AgentMiddleware):
|
||||
state_schema = MyState # type: ignore[assignment]
|
||||
state_schema = MyState
|
||||
|
||||
def wrap_model_call(
|
||||
self,
|
||||
@@ -752,7 +752,7 @@ class TestAsyncComposition:
|
||||
return await handler(request)
|
||||
|
||||
class InnerMiddleware(AgentMiddleware):
|
||||
state_schema = MyState # type: ignore[assignment]
|
||||
state_schema = MyState
|
||||
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
|
||||
@@ -37,14 +37,14 @@ class _TokenCountingChatModel(FakeChatModel):
|
||||
def get_num_tokens_from_messages(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
tools: Sequence | None = None,
|
||||
tools: Sequence[Any] | None = None,
|
||||
) -> int:
|
||||
return sum(_count_message_tokens(message) for message in messages)
|
||||
|
||||
|
||||
def _count_message_tokens(message: MessageLikeRepresentation) -> int:
|
||||
if isinstance(message, (AIMessage, ToolMessage)):
|
||||
return _count_content(message.content)
|
||||
return _count_content(cast("MessageLikeRepresentation", message.content))
|
||||
if isinstance(message, str):
|
||||
return len(message)
|
||||
return len(str(message))
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
|
||||
from langchain_core.tools import tool
|
||||
@@ -5,6 +7,7 @@ from langgraph.checkpoint.memory import InMemorySaver
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from langchain.agents.factory import create_agent
|
||||
from langchain.agents.middleware import InputAgentState
|
||||
from langchain.agents.middleware.model_call_limit import (
|
||||
ModelCallLimitExceededError,
|
||||
ModelCallLimitMiddleware,
|
||||
@@ -12,6 +15,9 @@ from langchain.agents.middleware.model_call_limit import (
|
||||
)
|
||||
from tests.unit_tests.agents.model import FakeToolCallingModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
|
||||
@tool
|
||||
def simple_tool(value: str) -> str:
|
||||
@@ -215,13 +221,13 @@ def test_run_limit_resets_between_invocations() -> None:
|
||||
|
||||
agent = create_agent(model=model, middleware=[middleware], checkpointer=InMemorySaver())
|
||||
|
||||
thread_config = {"configurable": {"thread_id": "test_thread"}}
|
||||
agent.invoke({"messages": [HumanMessage("Hello")]}, thread_config)
|
||||
agent.invoke({"messages": [HumanMessage("Hello again")]}, thread_config)
|
||||
agent.invoke({"messages": [HumanMessage("Hello third")]}, thread_config)
|
||||
thread_config: RunnableConfig = {"configurable": {"thread_id": "test_thread"}}
|
||||
agent.invoke(InputAgentState(messages=[HumanMessage("Hello")]), thread_config)
|
||||
agent.invoke(InputAgentState(messages=[HumanMessage("Hello again")]), thread_config)
|
||||
agent.invoke(InputAgentState(messages=[HumanMessage("Hello third")]), thread_config)
|
||||
|
||||
# Fourth run: should raise, thread_model_call_count == 3 (limit)
|
||||
with pytest.raises(ModelCallLimitExceededError) as exc_info:
|
||||
agent.invoke({"messages": [HumanMessage("Hello fourth")]}, thread_config)
|
||||
agent.invoke(InputAgentState(messages=[HumanMessage("Hello fourth")]), thread_config)
|
||||
error_msg = str(exc_info.value)
|
||||
assert "thread limit (3/3)" in error_msg
|
||||
|
||||
@@ -15,6 +15,7 @@ from langchain_core.messages import (
|
||||
)
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.runtime import Runtime
|
||||
from langgraph.stream._types import ProtocolEvent
|
||||
from langgraph.stream.transformers import MessagesTransformer
|
||||
|
||||
from langchain.agents import AgentState
|
||||
@@ -578,7 +579,7 @@ class TestCustomDetector:
|
||||
KeyError: 'value' when used with hash or mask strategies.
|
||||
"""
|
||||
|
||||
def detect_phone(content: str) -> list[dict]: # type: ignore[type-arg]
|
||||
def detect_phone(content: str) -> list[dict[str, Any]]:
|
||||
return [
|
||||
{"text": m.group(), "start": m.start(), "end": m.end()}
|
||||
for m in re.finditer(r"\+91[\s.-]?\d{10}", content)
|
||||
@@ -586,7 +587,7 @@ class TestCustomDetector:
|
||||
|
||||
middleware = PIIMiddleware(
|
||||
"indian_phone",
|
||||
detector=detect_phone,
|
||||
detector=detect_phone, # type: ignore[arg-type]
|
||||
strategy="hash",
|
||||
apply_to_input=True,
|
||||
)
|
||||
@@ -601,7 +602,7 @@ class TestCustomDetector:
|
||||
def test_custom_callable_detector_with_text_key_mask(self) -> None:
|
||||
"""Custom detectors returning 'text' instead of 'value' must work with mask strategy."""
|
||||
|
||||
def detect_phone(content: str) -> list[dict]: # type: ignore[type-arg]
|
||||
def detect_phone(content: str) -> list[dict[str, Any]]:
|
||||
return [
|
||||
{"text": m.group(), "start": m.start(), "end": m.end()}
|
||||
for m in re.finditer(r"\+91[\s.-]?\d{10}", content)
|
||||
@@ -609,7 +610,7 @@ class TestCustomDetector:
|
||||
|
||||
middleware = PIIMiddleware(
|
||||
"indian_phone",
|
||||
detector=detect_phone,
|
||||
detector=detect_phone, # type: ignore[arg-type]
|
||||
strategy="mask",
|
||||
apply_to_input=True,
|
||||
)
|
||||
@@ -715,7 +716,7 @@ class TestMultipleMiddleware:
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def _make_delta_event(text: str, *, index: int = 0, run_id: str = "r1") -> dict[str, Any]:
|
||||
def _make_delta_event(text: str, *, index: int = 0, run_id: str = "r1") -> ProtocolEvent:
|
||||
"""Build a `messages` protocol event for a text content-block delta."""
|
||||
return {
|
||||
"type": "event",
|
||||
@@ -735,7 +736,7 @@ def _make_delta_event(text: str, *, index: int = 0, run_id: str = "r1") -> dict[
|
||||
}
|
||||
|
||||
|
||||
def _make_finish_event(text: str, *, index: int = 0, run_id: str = "r1") -> dict[str, Any]:
|
||||
def _make_finish_event(text: str, *, index: int = 0, run_id: str = "r1") -> ProtocolEvent:
|
||||
"""Build a `messages` protocol event for content-block-finish on a text block."""
|
||||
return {
|
||||
"type": "event",
|
||||
@@ -755,7 +756,7 @@ def _make_finish_event(text: str, *, index: int = 0, run_id: str = "r1") -> dict
|
||||
}
|
||||
|
||||
|
||||
def _emitted_text(events: list[dict[str, Any]]) -> str:
|
||||
def _emitted_text(events: list[ProtocolEvent]) -> tuple[str, dict[int, str]]:
|
||||
"""Concatenate delta + finalized text the way a streaming consumer would."""
|
||||
parts = []
|
||||
final_by_index: dict[int, str] = {}
|
||||
@@ -772,10 +773,10 @@ def _emitted_text(events: list[dict[str, Any]]) -> str:
|
||||
final_by_index[payload["index"]] = content["text"]
|
||||
# Concatenated delta stream is what the consumer sees in real time;
|
||||
# finalized text is the snapshot. Return both via a tuple-like dict.
|
||||
return "".join(parts), final_by_index # type: ignore[return-value]
|
||||
return "".join(parts), final_by_index
|
||||
|
||||
|
||||
def _run_transformer(transformer: Any, events: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
def _run_transformer(transformer: Any, events: list[ProtocolEvent]) -> list[ProtocolEvent]:
|
||||
"""Feed events through the transformer (mutates in place) and return them."""
|
||||
for event in events:
|
||||
transformer.process(event)
|
||||
@@ -840,6 +841,7 @@ class TestPIIStreamTransformer:
|
||||
redacted = transformer._redact_value(msg)
|
||||
|
||||
# Original untouched.
|
||||
assert isinstance(msg.content[0], dict)
|
||||
assert msg.content[0]["text"] == "Reach me at alice@example.com"
|
||||
# Redacted copy walked every block.
|
||||
assert "alice@example.com" not in redacted.content[0]["text"]
|
||||
@@ -898,7 +900,7 @@ class TestPIIStreamTransformer:
|
||||
rule = RedactionRule(pii_type="email").resolve()
|
||||
transformer = _PIIStreamTransformer(rule=rule, lookback=32)
|
||||
|
||||
events = [
|
||||
events: list[ProtocolEvent] = [
|
||||
{
|
||||
"type": "event",
|
||||
"method": "messages",
|
||||
@@ -953,7 +955,7 @@ class TestPIIStreamTransformer:
|
||||
rule = RedactionRule(pii_type="email", strategy="block").resolve()
|
||||
transformer = _PIIStreamTransformer(rule=rule)
|
||||
|
||||
event = {
|
||||
event: ProtocolEvent = {
|
||||
"type": "event",
|
||||
"method": "messages",
|
||||
"params": {
|
||||
@@ -986,7 +988,7 @@ class TestPIIStreamTransformer:
|
||||
rule = RedactionRule(pii_type="email").resolve()
|
||||
transformer = _PIIStreamTransformer(rule=rule) # default lookback=128
|
||||
|
||||
events = [
|
||||
events: list[ProtocolEvent] = [
|
||||
{
|
||||
"type": "event",
|
||||
"method": "messages",
|
||||
@@ -1031,7 +1033,7 @@ class TestPIIStreamTransformer:
|
||||
|
||||
# 50-char args with PII near the start; emit_end = 50 - 8 = 42.
|
||||
args = '{"to": "alice@example.com", "subject": "hi"}'
|
||||
events = [
|
||||
events: list[ProtocolEvent] = [
|
||||
{
|
||||
"type": "event",
|
||||
"method": "messages",
|
||||
@@ -1082,7 +1084,7 @@ class TestPIIStreamTransformer:
|
||||
'{"to": "alice@example',
|
||||
'{"to": "alice@example.com"}',
|
||||
]
|
||||
events = [
|
||||
events: list[ProtocolEvent] = [
|
||||
{
|
||||
"type": "event",
|
||||
"method": "messages",
|
||||
@@ -1124,7 +1126,7 @@ class TestPIIStreamTransformer:
|
||||
rule = RedactionRule(pii_type="email").resolve()
|
||||
transformer = _PIIStreamTransformer(rule=rule)
|
||||
|
||||
events = [
|
||||
events: list[ProtocolEvent] = [
|
||||
{
|
||||
"type": "event",
|
||||
"method": "messages",
|
||||
@@ -1156,7 +1158,7 @@ class TestPIIStreamTransformer:
|
||||
rule = RedactionRule(pii_type="email", strategy="block").resolve()
|
||||
transformer = _PIIStreamTransformer(rule=rule)
|
||||
|
||||
event = {
|
||||
event: ProtocolEvent = {
|
||||
"type": "event",
|
||||
"method": "messages",
|
||||
"params": {
|
||||
@@ -1192,7 +1194,7 @@ class TestPIIStreamTransformer:
|
||||
],
|
||||
id="m1",
|
||||
)
|
||||
event: dict[str, Any] = {
|
||||
event: ProtocolEvent = {
|
||||
"type": "event",
|
||||
"method": "messages",
|
||||
"params": {
|
||||
@@ -1217,7 +1219,7 @@ class TestPIIStreamTransformer:
|
||||
],
|
||||
id="m1",
|
||||
)
|
||||
event: dict[str, Any] = {
|
||||
event: ProtocolEvent = {
|
||||
"type": "event",
|
||||
"method": "messages",
|
||||
"params": {
|
||||
@@ -1235,7 +1237,7 @@ class TestPIIStreamTransformer:
|
||||
transformer = _PIIStreamTransformer(rule=rule)
|
||||
|
||||
msg = ToolMessage(content="Result: alice@example.com", tool_call_id="c1", id="m1")
|
||||
event: dict[str, Any] = {
|
||||
event: ProtocolEvent = {
|
||||
"type": "event",
|
||||
"method": "messages",
|
||||
"params": {
|
||||
@@ -1258,7 +1260,7 @@ class TestPIIStreamTransformer:
|
||||
"""Tools events route to the new handler without error."""
|
||||
rule = RedactionRule(pii_type="email").resolve()
|
||||
transformer = _PIIStreamTransformer(rule=rule)
|
||||
event = {
|
||||
event: ProtocolEvent = {
|
||||
"type": "event",
|
||||
"method": "tools",
|
||||
"params": {
|
||||
@@ -1279,7 +1281,7 @@ class TestPIIStreamTransformer:
|
||||
rule = RedactionRule(pii_type="email").resolve()
|
||||
transformer = _PIIStreamTransformer(rule=rule)
|
||||
|
||||
event = {
|
||||
event: ProtocolEvent = {
|
||||
"type": "event",
|
||||
"method": "tools",
|
||||
"params": {
|
||||
@@ -1303,7 +1305,7 @@ class TestPIIStreamTransformer:
|
||||
rule = RedactionRule(pii_type="email").resolve()
|
||||
transformer = _PIIStreamTransformer(rule=rule)
|
||||
|
||||
event = {
|
||||
event: ProtocolEvent = {
|
||||
"type": "event",
|
||||
"method": "tools",
|
||||
"params": {
|
||||
@@ -1365,7 +1367,7 @@ class TestPIIStreamTransformer:
|
||||
rule = RedactionRule(pii_type="email").resolve()
|
||||
transformer = _PIIStreamTransformer(rule=rule)
|
||||
|
||||
events = [
|
||||
events: list[ProtocolEvent] = [
|
||||
{
|
||||
"type": "event",
|
||||
"method": "messages",
|
||||
@@ -1403,6 +1405,7 @@ class TestPIIStreamTransformer:
|
||||
content="",
|
||||
invalid_tool_calls=[
|
||||
InvalidToolCall(
|
||||
type="invalid_tool_call",
|
||||
name="send_email",
|
||||
args='{"to": "alice@example.com"} BROKEN',
|
||||
id="c1",
|
||||
@@ -1411,16 +1414,18 @@ class TestPIIStreamTransformer:
|
||||
],
|
||||
id="m1",
|
||||
)
|
||||
redacted = transformer._redact_value(msg)
|
||||
redacted: AIMessage = transformer._redact_value(msg)
|
||||
assert redacted.invalid_tool_calls[0]["args"] is not None
|
||||
assert "alice@example.com" not in redacted.invalid_tool_calls[0]["args"]
|
||||
assert "[REDACTED_EMAIL]" in redacted.invalid_tool_calls[0]["args"]
|
||||
assert msg.invalid_tool_calls[0]["args"] is not None
|
||||
assert "alice@example.com" in msg.invalid_tool_calls[0]["args"]
|
||||
|
||||
def test_tool_started_input_is_redacted(self) -> None:
|
||||
rule = RedactionRule(pii_type="email").resolve()
|
||||
transformer = _PIIStreamTransformer(rule=rule)
|
||||
|
||||
event = {
|
||||
event: ProtocolEvent = {
|
||||
"type": "event",
|
||||
"method": "tools",
|
||||
"params": {
|
||||
@@ -1445,7 +1450,7 @@ class TestPIIStreamTransformer:
|
||||
rule = RedactionRule(pii_type="email").resolve()
|
||||
transformer = _PIIStreamTransformer(rule=rule, lookback=64)
|
||||
|
||||
events = [
|
||||
events: list[ProtocolEvent] = [
|
||||
{
|
||||
"type": "event",
|
||||
"method": "tools",
|
||||
@@ -1483,7 +1488,7 @@ class TestPIIStreamTransformer:
|
||||
rule = RedactionRule(pii_type="email").resolve()
|
||||
transformer = _PIIStreamTransformer(rule=rule)
|
||||
|
||||
event = {
|
||||
event: ProtocolEvent = {
|
||||
"type": "event",
|
||||
"method": "tools",
|
||||
"params": {
|
||||
@@ -1506,7 +1511,7 @@ class TestPIIStreamTransformer:
|
||||
rule = RedactionRule(pii_type="email").resolve()
|
||||
transformer = _PIIStreamTransformer(rule=rule)
|
||||
|
||||
event = {
|
||||
event: ProtocolEvent = {
|
||||
"type": "event",
|
||||
"method": "tools",
|
||||
"params": {
|
||||
@@ -1529,7 +1534,7 @@ class TestPIIStreamTransformer:
|
||||
rule = RedactionRule(pii_type="email").resolve()
|
||||
transformer = _PIIStreamTransformer(rule=rule)
|
||||
|
||||
event = {
|
||||
event: ProtocolEvent = {
|
||||
"type": "event",
|
||||
"method": "tools",
|
||||
"params": {
|
||||
@@ -1551,7 +1556,7 @@ class TestPIIStreamTransformer:
|
||||
rule = RedactionRule(pii_type="email").resolve()
|
||||
transformer = _PIIStreamTransformer(rule=rule, lookback=64)
|
||||
|
||||
delta_event = {
|
||||
delta_event: ProtocolEvent = {
|
||||
"type": "event",
|
||||
"method": "tools",
|
||||
"params": {
|
||||
@@ -1567,7 +1572,7 @@ class TestPIIStreamTransformer:
|
||||
transformer.process(delta_event)
|
||||
assert "c1" in transformer._tool_buffers
|
||||
|
||||
finish_event = {
|
||||
finish_event: ProtocolEvent = {
|
||||
"type": "event",
|
||||
"method": "tools",
|
||||
"params": {
|
||||
@@ -1588,7 +1593,7 @@ class TestPIIStreamTransformer:
|
||||
rule = RedactionRule(pii_type="email", strategy="block").resolve()
|
||||
transformer = _PIIStreamTransformer(rule=rule)
|
||||
|
||||
event = {
|
||||
event: ProtocolEvent = {
|
||||
"type": "event",
|
||||
"method": "tools",
|
||||
"params": {
|
||||
@@ -1609,7 +1614,7 @@ class TestPIIStreamTransformer:
|
||||
rule = RedactionRule(pii_type="email", strategy="block").resolve()
|
||||
transformer = _PIIStreamTransformer(rule=rule)
|
||||
|
||||
event = {
|
||||
event: ProtocolEvent = {
|
||||
"type": "event",
|
||||
"method": "messages",
|
||||
"params": {
|
||||
@@ -1646,7 +1651,7 @@ class TestPIIStreamTransformer:
|
||||
_make_finish_event("Reach me at alice@example.com tomorrow."),
|
||||
]
|
||||
_run_transformer(transformer, events)
|
||||
streamed, finals = _emitted_text(events) # type: ignore[misc]
|
||||
streamed, finals = _emitted_text(events)
|
||||
|
||||
# The raw email never reaches the wire — the delta is held in the
|
||||
# lookback buffer and the finalize snapshot is the redacted text.
|
||||
@@ -1665,7 +1670,7 @@ class TestPIIStreamTransformer:
|
||||
_make_finish_event("Hi, contact alice@example.com when ready"),
|
||||
]
|
||||
_run_transformer(transformer, events)
|
||||
streamed, finals = _emitted_text(events) # type: ignore[misc]
|
||||
streamed, finals = _emitted_text(events)
|
||||
|
||||
# The held-buffer should have prevented the raw email from being
|
||||
# released until detection ran over the concatenation.
|
||||
@@ -1695,7 +1700,7 @@ class TestPIIStreamTransformer:
|
||||
_make_finish_event(text),
|
||||
]
|
||||
_run_transformer(transformer, events)
|
||||
streamed, finals = _emitted_text(events) # type: ignore[misc]
|
||||
streamed, finals = _emitted_text(events)
|
||||
|
||||
# No prefix of the email reaches the wire.
|
||||
assert email not in streamed
|
||||
@@ -1716,7 +1721,7 @@ class TestPIIStreamTransformer:
|
||||
_make_finish_event("Card: 5425 2334 3010 9903 next"),
|
||||
]
|
||||
_run_transformer(transformer, events)
|
||||
streamed, finals = _emitted_text(events) # type: ignore[misc]
|
||||
streamed, finals = _emitted_text(events)
|
||||
|
||||
# No prefix of the card may reach the wire — the lookback buffer
|
||||
# holds whitespace-separated groups until detection runs over the
|
||||
@@ -1762,7 +1767,7 @@ class TestPIIStreamTransformer:
|
||||
transformer.process(clean_event) # no raise yet
|
||||
|
||||
# Second delta completes the email — detection fires, raises.
|
||||
completing_event = _make_delta_event("@example.com soon")
|
||||
completing_event: ProtocolEvent = _make_delta_event("@example.com soon")
|
||||
with pytest.raises(PIIDetectionError):
|
||||
transformer.process(completing_event)
|
||||
|
||||
@@ -1777,7 +1782,7 @@ class TestPIIStreamTransformer:
|
||||
_make_finish_event("Hello there, how are you?"),
|
||||
]
|
||||
_run_transformer(transformer, events)
|
||||
streamed, finals = _emitted_text(events) # type: ignore[misc]
|
||||
streamed, finals = _emitted_text(events)
|
||||
|
||||
# Deltas hold everything back; the finalize event carries the
|
||||
# whole block at once. `ChatModelStream._resolve_block_text`
|
||||
@@ -1801,7 +1806,7 @@ class TestPIIStreamTransformer:
|
||||
_make_finish_event("alice@example.com"),
|
||||
]
|
||||
_run_transformer(transformer, events)
|
||||
_, finals = _emitted_text(events) # type: ignore[misc]
|
||||
_, finals = _emitted_text(events)
|
||||
|
||||
assert "alice@example.com" not in finals[0]
|
||||
assert "[REDACTED_EMAIL]" in finals[0]
|
||||
@@ -1851,7 +1856,7 @@ class TestPIIStreamTransformer:
|
||||
assert ("r1", 0) in transformer._buffers
|
||||
|
||||
# message-finish for the run wipes any (run-id, *) entries.
|
||||
message_finish_event = {
|
||||
message_finish_event: ProtocolEvent = {
|
||||
"type": "event",
|
||||
"method": "messages",
|
||||
"params": {
|
||||
@@ -1886,7 +1891,7 @@ class TestPIIStreamTransformer:
|
||||
_make_finish_event("hello alice@example.com goodbye"),
|
||||
]
|
||||
_run_transformer(transformer, events)
|
||||
_, finals = _emitted_text(events) # type: ignore[misc]
|
||||
_, finals = _emitted_text(events)
|
||||
# The finalized snapshot always re-runs detection over the full text.
|
||||
assert "alice@example.com" not in finals[0]
|
||||
|
||||
@@ -1901,7 +1906,7 @@ class TestPIIStreamTransformer:
|
||||
rule = RedactionRule(pii_type="email").resolve()
|
||||
transformer = _PIIStreamTransformer(rule=rule)
|
||||
|
||||
data_event = {
|
||||
data_event: ProtocolEvent = {
|
||||
"type": "event",
|
||||
"method": "messages",
|
||||
"params": {
|
||||
@@ -1934,7 +1939,7 @@ class TestPIIStreamTransformer:
|
||||
agent = create_agent(model, [], middleware=[PIIMiddleware("email", apply_to_output=True)])
|
||||
|
||||
run = agent.stream_events({"messages": [HumanMessage("hi")]}, version="v3")
|
||||
transformers = run._mux._transformers # type: ignore[attr-defined]
|
||||
transformers = run._mux._transformers
|
||||
|
||||
pii_idx = next(
|
||||
i for i, t in enumerate(transformers) if isinstance(t, _PIIStreamTransformer)
|
||||
|
||||
@@ -30,7 +30,9 @@ from typing_extensions import override
|
||||
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware.summarization import (
|
||||
ContextSize,
|
||||
SummarizationMiddleware,
|
||||
TriggerClause,
|
||||
_provider_matches,
|
||||
)
|
||||
from langchain.chat_models import init_chat_model
|
||||
@@ -1153,8 +1155,8 @@ def test_summarization_middleware_cutoff_at_start_of_tool_sequence() -> None:
|
||||
def test_trigger_copies_mutable_inputs() -> None:
|
||||
"""Test caller mutations do not change stored trigger configuration."""
|
||||
model = FakeToolCallingModel()
|
||||
clause = {"tokens": 1000}
|
||||
trigger = [clause]
|
||||
clause: TriggerClause = {"tokens": 1000}
|
||||
trigger: list[ContextSize | TriggerClause] = [clause]
|
||||
|
||||
middleware = SummarizationMiddleware(
|
||||
model=model,
|
||||
@@ -1171,7 +1173,7 @@ def test_trigger_copies_mutable_inputs() -> None:
|
||||
return 500
|
||||
|
||||
middleware.token_counter = token_counter_low
|
||||
state = {"messages": [HumanMessage(content="1"), HumanMessage(content="2")]}
|
||||
state = AgentState(messages=[HumanMessage(content="1"), HumanMessage(content="2")])
|
||||
result = middleware.before_model(state, Runtime())
|
||||
assert result is None
|
||||
|
||||
@@ -1231,13 +1233,13 @@ def test_and_trigger_conditions() -> None:
|
||||
return 1500 # Above token threshold
|
||||
|
||||
middleware.token_counter = token_counter_high
|
||||
state = {
|
||||
"messages": [
|
||||
state = AgentState(
|
||||
messages=[
|
||||
HumanMessage(content="1"),
|
||||
AIMessage(content="2"),
|
||||
HumanMessage(content="3"),
|
||||
]
|
||||
}
|
||||
)
|
||||
result = middleware.before_model(state, Runtime())
|
||||
assert result is None, "Should not summarize when only tokens condition is met"
|
||||
|
||||
@@ -1247,8 +1249,8 @@ def test_and_trigger_conditions() -> None:
|
||||
return 500 # Below token threshold
|
||||
|
||||
middleware.token_counter = token_counter_low
|
||||
state = {
|
||||
"messages": [
|
||||
state = AgentState(
|
||||
messages=[
|
||||
HumanMessage(content="1"),
|
||||
AIMessage(content="2"),
|
||||
HumanMessage(content="3"),
|
||||
@@ -1256,7 +1258,7 @@ def test_and_trigger_conditions() -> None:
|
||||
HumanMessage(content="5"),
|
||||
AIMessage(content="6"),
|
||||
]
|
||||
}
|
||||
)
|
||||
result = middleware.before_model(state, Runtime())
|
||||
assert result is None, "Should not summarize when only messages condition is met"
|
||||
|
||||
@@ -1289,14 +1291,14 @@ def test_or_trigger_conditions_with_and_clauses() -> None:
|
||||
return 5500
|
||||
|
||||
middleware.token_counter = token_counter_5500
|
||||
state = {
|
||||
"messages": [
|
||||
state = AgentState(
|
||||
messages=[
|
||||
HumanMessage(content="1"),
|
||||
AIMessage(content="2"),
|
||||
HumanMessage(content="3"),
|
||||
AIMessage(content="4"),
|
||||
]
|
||||
}
|
||||
)
|
||||
result = middleware.before_model(state, Runtime())
|
||||
assert result is not None, "Should summarize when first OR clause is met"
|
||||
|
||||
@@ -1306,7 +1308,7 @@ def test_or_trigger_conditions_with_and_clauses() -> None:
|
||||
return 3500
|
||||
|
||||
middleware.token_counter = token_counter_3500
|
||||
state = {"messages": [HumanMessage(content=str(i)) for i in range(7)]}
|
||||
state = AgentState(messages=[HumanMessage(content=str(i)) for i in range(7)])
|
||||
result = middleware.before_model(state, Runtime())
|
||||
assert result is not None, "Should summarize when second OR clause is met"
|
||||
|
||||
@@ -1318,14 +1320,14 @@ def test_or_trigger_conditions_with_and_clauses() -> None:
|
||||
return 4500
|
||||
|
||||
middleware.token_counter = token_counter_4500
|
||||
state = {
|
||||
"messages": [
|
||||
state = AgentState(
|
||||
messages=[
|
||||
HumanMessage(content="1"),
|
||||
AIMessage(content="2"),
|
||||
HumanMessage(content="3"),
|
||||
AIMessage(content="4"),
|
||||
]
|
||||
}
|
||||
)
|
||||
result = middleware.before_model(state, Runtime())
|
||||
assert result is None, "Should not summarize when no complete clause is met"
|
||||
|
||||
@@ -1337,7 +1339,7 @@ async def test_and_trigger_conditions_async() -> None:
|
||||
trigger={"tokens": 1000, "messages": 5},
|
||||
keep=("messages", 2),
|
||||
)
|
||||
state = {"messages": [HumanMessage(content=str(i)) for i in range(6)]}
|
||||
state = AgentState(messages=[HumanMessage(content=str(i)) for i in range(6)])
|
||||
|
||||
# Only the messages threshold met (tokens below) -> should not summarize.
|
||||
def token_counter_low(_messages: Iterable[MessageLikeRepresentation]) -> int:
|
||||
@@ -1367,7 +1369,7 @@ async def test_or_trigger_conditions_with_and_clauses_async() -> None:
|
||||
],
|
||||
keep=("messages", 2),
|
||||
)
|
||||
state = {"messages": [HumanMessage(content=str(i)) for i in range(4)]}
|
||||
state = AgentState(messages=[HumanMessage(content=str(i)) for i in range(4)])
|
||||
|
||||
# First clause met (tokens = 5500, messages = 4) -> should summarize.
|
||||
def token_counter_5500(_messages: Iterable[MessageLikeRepresentation]) -> int:
|
||||
@@ -1401,7 +1403,7 @@ def test_backward_compatibility_tuple_trigger() -> None:
|
||||
return 1500
|
||||
|
||||
middleware_single.token_counter = token_counter_high
|
||||
state = {"messages": [HumanMessage(content=str(i)) for i in range(3)]}
|
||||
state = AgentState(messages=[HumanMessage(content=str(i)) for i in range(3)])
|
||||
result = middleware_single.before_model(state, Runtime())
|
||||
assert result is not None, "Single tuple trigger should work"
|
||||
|
||||
@@ -1414,7 +1416,7 @@ def test_backward_compatibility_tuple_trigger() -> None:
|
||||
|
||||
# Should trigger with high tokens (first condition met)
|
||||
middleware_list.token_counter = token_counter_high
|
||||
state = {"messages": [HumanMessage(content=str(i)) for i in range(3)]}
|
||||
state = AgentState(messages=[HumanMessage(content=str(i)) for i in range(3)])
|
||||
result = middleware_list.before_model(state, Runtime())
|
||||
assert result is not None, "List of tuples should trigger when any condition met"
|
||||
|
||||
@@ -1423,7 +1425,7 @@ def test_backward_compatibility_tuple_trigger() -> None:
|
||||
return 100
|
||||
|
||||
middleware_list.token_counter = token_counter_low
|
||||
state = {"messages": [HumanMessage(content=str(i)) for i in range(6)]}
|
||||
state = AgentState(messages=[HumanMessage(content=str(i)) for i in range(6)])
|
||||
result = middleware_list.before_model(state, Runtime())
|
||||
assert result is not None, "List of tuples should trigger when second condition met"
|
||||
|
||||
@@ -1447,7 +1449,7 @@ def test_mixed_and_or_conditions() -> None:
|
||||
return 4500
|
||||
|
||||
middleware.token_counter = token_counter_high
|
||||
state = {"messages": [HumanMessage(content=str(i)) for i in range(12)]}
|
||||
state = AgentState(messages=[HumanMessage(content=str(i)) for i in range(12)])
|
||||
result = middleware.before_model(state, Runtime())
|
||||
assert result is not None, "Should trigger when AND clause is met"
|
||||
|
||||
@@ -1456,13 +1458,13 @@ def test_mixed_and_or_conditions() -> None:
|
||||
return 1000
|
||||
|
||||
middleware.token_counter = token_counter_low
|
||||
state = {"messages": [HumanMessage(content=str(i)) for i in range(55)]}
|
||||
state = AgentState(messages=[HumanMessage(content=str(i)) for i in range(55)])
|
||||
result = middleware.before_model(state, Runtime())
|
||||
assert result is not None, "Should trigger when simple messages condition is met"
|
||||
|
||||
# Test case 3: Neither condition met
|
||||
middleware.token_counter = token_counter_low
|
||||
state = {"messages": [HumanMessage(content=str(i)) for i in range(8)]}
|
||||
state = AgentState(messages=[HumanMessage(content=str(i)) for i in range(8)])
|
||||
result = middleware.before_model(state, Runtime())
|
||||
assert result is None, "Should not trigger when no condition is met"
|
||||
|
||||
@@ -1484,21 +1486,21 @@ def test_fraction_in_and_trigger() -> None:
|
||||
# Test case 1: Both conditions met
|
||||
# 5 messages * 200 = 1000 tokens (profile max is 1000)
|
||||
# 1000 / 1000 = 1.0 >= 0.8 AND messages = 5 >= 5
|
||||
state = {"messages": [HumanMessage(content=str(i)) for i in range(5)]}
|
||||
state = AgentState(messages=[HumanMessage(content=str(i)) for i in range(5)])
|
||||
result = middleware.before_model(state, Runtime())
|
||||
assert result is not None, "Should trigger when both fraction and messages conditions met"
|
||||
|
||||
# Test case 2: Only messages condition met
|
||||
# 3 messages * 200 = 600 tokens
|
||||
# 600 / 1000 = 0.6 < 0.8 and messages = 3 < 5
|
||||
state = {"messages": [HumanMessage(content=str(i)) for i in range(3)]}
|
||||
state = AgentState(messages=[HumanMessage(content=str(i)) for i in range(3)])
|
||||
result = middleware.before_model(state, Runtime())
|
||||
assert result is None, "Should not trigger when neither condition is fully met"
|
||||
|
||||
# Test case 3: High fraction but not enough messages
|
||||
# 4 messages * 200 = 800 tokens
|
||||
# 800 / 1000 = 0.8 >= 0.8 but messages = 4 < 5
|
||||
state = {"messages": [HumanMessage(content=str(i)) for i in range(4)]}
|
||||
state = AgentState(messages=[HumanMessage(content=str(i)) for i in range(4)])
|
||||
result = middleware.before_model(state, Runtime())
|
||||
assert result is None, "Should not trigger when only fraction condition is met"
|
||||
|
||||
@@ -1511,7 +1513,7 @@ def test_trigger_validation_errors() -> None:
|
||||
with pytest.raises(ValueError, match="Unsupported trigger metric"):
|
||||
SummarizationMiddleware(
|
||||
model=model,
|
||||
trigger={"invalid_metric": 100},
|
||||
trigger={"invalid_metric": 100}, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
# Invalid fraction value (> 1) — shares the tuple path's message via
|
||||
@@ -1547,21 +1549,21 @@ def test_trigger_validation_errors() -> None:
|
||||
with pytest.raises(ValueError, match="Fraction trigger values must be numeric"):
|
||||
SummarizationMiddleware(
|
||||
model=model,
|
||||
trigger={"fraction": "invalid"},
|
||||
trigger={"fraction": "invalid"}, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
# Float value for an integer metric is rejected (no silent truncation)
|
||||
with pytest.raises(ValueError, match="tokens trigger values must be integers"):
|
||||
SummarizationMiddleware(
|
||||
model=model,
|
||||
trigger={"tokens": 1000.5},
|
||||
trigger={"tokens": 1000.5}, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
# Numeric string for an integer metric is rejected (no silent coercion)
|
||||
with pytest.raises(ValueError, match="messages trigger values must be integers"):
|
||||
SummarizationMiddleware(
|
||||
model=model,
|
||||
trigger={"messages": "10"},
|
||||
trigger={"messages": "10"}, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
# Boolean is rejected (bool is an int subclass)
|
||||
@@ -1575,7 +1577,7 @@ def test_trigger_validation_errors() -> None:
|
||||
with pytest.raises(TypeError, match="Unsupported trigger item type"):
|
||||
SummarizationMiddleware(
|
||||
model=model,
|
||||
trigger=["invalid"],
|
||||
trigger=["invalid"], # type: ignore[list-item]
|
||||
)
|
||||
|
||||
# Unsupported top-level trigger type (not a tuple, dict, or list)
|
||||
@@ -1616,7 +1618,7 @@ def test_empty_list_trigger_never_summarizes() -> None:
|
||||
token_counter=lambda _: 10_000,
|
||||
)
|
||||
assert middleware._trigger_conditions == []
|
||||
state = {"messages": [HumanMessage(content=str(i)) for i in range(50)]}
|
||||
state = AgentState(messages=[HumanMessage(content=str(i)) for i in range(50)])
|
||||
assert middleware.before_model(state, Runtime()) is None
|
||||
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ from langgraph.checkpoint.memory import InMemorySaver
|
||||
|
||||
from langchain.agents.factory import create_agent
|
||||
from langchain.agents.middleware.tool_call_limit import (
|
||||
ExitBehavior,
|
||||
ToolCallLimitExceededError,
|
||||
ToolCallLimitMiddleware,
|
||||
ToolCallLimitState,
|
||||
@@ -34,7 +35,8 @@ def test_middleware_initialization_validation() -> None:
|
||||
assert middleware.run_limit is None
|
||||
|
||||
# Test exit behaviors
|
||||
for behavior in ["error", "end", "continue"]:
|
||||
behaviors: tuple[ExitBehavior, ...] = ("error", "end", "continue")
|
||||
for behavior in behaviors:
|
||||
middleware = ToolCallLimitMiddleware(thread_limit=5, exit_behavior=behavior)
|
||||
assert middleware.exit_behavior == behavior
|
||||
|
||||
|
||||
Reference in New Issue
Block a user