chore: increase coverage for shell, filesystem, and summarization middleware (#33928)

cc-generated; just a start here, but wanted to bump coverage up from ~70%.
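
To check the new coverage locally, a pytest-cov run along these lines should work (a sketch; the test directory and package paths are assumptions, not part of this commit):

    # hypothetical paths; adjust to the local checkout layout
    pytest tests/unit_tests/agents/middleware \
        --cov=langchain.agents.middleware \
        --cov-report=term-missing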
Sydney Runkle
2025-11-14 13:30:36 -05:00
committed by GitHub
parent 1bc88028e6
commit 189dcf7295
4 changed files with 837 additions and 96 deletions


@@ -7,6 +7,9 @@ import pytest
from langchain.agents.middleware.file_search import (
FilesystemFileSearchMiddleware,
_expand_include_patterns,
_is_valid_include_pattern,
_match_include_pattern,
)
@@ -259,3 +262,105 @@ class TestPathTraversalSecurity:
assert result == "No matches found"
assert "secret" not in result
class TestExpandIncludePatterns:
"""Tests for _expand_include_patterns helper function."""
def test_expand_patterns_basic_brace_expansion(self) -> None:
"""Test basic brace expansion with multiple options."""
result = _expand_include_patterns("*.{py,txt}")
assert result == ["*.py", "*.txt"]
def test_expand_patterns_nested_braces(self) -> None:
"""Test nested brace expansion."""
result = _expand_include_patterns("test.{a,b}.{c,d}")
assert result is not None
assert len(result) == 4
assert "test.a.c" in result
assert "test.b.d" in result
@pytest.mark.parametrize(
"pattern",
[
"*.py}", # closing brace without opening
"*.{}", # empty braces
"*.{py", # unclosed brace
],
)
def test_expand_patterns_invalid_braces(self, pattern: str) -> None:
"""Test patterns with invalid brace syntax return None."""
result = _expand_include_patterns(pattern)
assert result is None
class TestValidateIncludePattern:
"""Tests for _is_valid_include_pattern helper function."""
@pytest.mark.parametrize(
"pattern",
[
"", # empty pattern
"*.py\x00", # null byte
"*.py\n", # newline
],
)
def test_validate_invalid_patterns(self, pattern: str) -> None:
"""Test that invalid patterns are rejected."""
assert not _is_valid_include_pattern(pattern)
class TestMatchIncludePattern:
"""Tests for _match_include_pattern helper function."""
def test_match_pattern_with_braces(self) -> None:
"""Test matching with brace expansion."""
assert _match_include_pattern("test.py", "*.{py,txt}")
assert _match_include_pattern("test.txt", "*.{py,txt}")
assert not _match_include_pattern("test.md", "*.{py,txt}")
def test_match_pattern_invalid_expansion(self) -> None:
"""Test matching with pattern that cannot be expanded returns False."""
assert not _match_include_pattern("test.py", "*.{}")
class TestGrepEdgeCases:
"""Tests for edge cases in grep search."""
def test_grep_with_special_chars_in_pattern(self, tmp_path: Path) -> None:
"""Test grep with special characters in pattern."""
(tmp_path / "test.py").write_text("def test():\n pass\n", encoding="utf-8")
middleware = FilesystemFileSearchMiddleware(root_path=str(tmp_path), use_ripgrep=False)
result = middleware.grep_search.func(pattern="def.*:")
assert "/test.py" in result
def test_grep_case_insensitive(self, tmp_path: Path) -> None:
"""Test grep with case-insensitive search."""
(tmp_path / "test.py").write_text("HELLO world\n", encoding="utf-8")
middleware = FilesystemFileSearchMiddleware(root_path=str(tmp_path), use_ripgrep=False)
result = middleware.grep_search.func(pattern="(?i)hello")
assert "/test.py" in result
def test_grep_with_large_file_skipping(self, tmp_path: Path) -> None:
"""Test that grep skips files larger than max_file_size_mb."""
# Create a file larger than 1MB
large_content = "x" * (2 * 1024 * 1024) # 2MB
(tmp_path / "large.txt").write_text(large_content, encoding="utf-8")
(tmp_path / "small.txt").write_text("x", encoding="utf-8")
middleware = FilesystemFileSearchMiddleware(
root_path=str(tmp_path),
use_ripgrep=False,
max_file_size_mb=1, # 1MB limit
)
result = middleware.grep_search.func(pattern="x")
# Large file should be skipped
assert "/small.txt" in result


@@ -1,17 +1,21 @@
from __future__ import annotations
import asyncio
import gc
import tempfile
import time
from pathlib import Path
import pytest
from langchain_core.messages import AIMessage, ToolMessage
from langchain_core.tools.base import ToolException
from langchain.agents.middleware.shell_tool import (
HostExecutionPolicy,
RedactionRule,
ShellToolMiddleware,
_SessionResources,
_ShellToolInput,
)
from langchain.agents.middleware.types import AgentState
@@ -173,3 +177,294 @@ def test_session_resources_finalizer_cleans_up(tmp_path: Path) -> None:
assert not finalizer.alive
assert session.stopped
assert not tempdir_path.exists()
def test_shell_tool_input_validation() -> None:
"""Test _ShellToolInput validation rules."""
# Both command and restart not allowed
with pytest.raises(ValueError, match="only one"):
_ShellToolInput(command="ls", restart=True)
# Neither command nor restart provided
with pytest.raises(ValueError, match="requires either"):
_ShellToolInput()
# Valid: command only
valid_cmd = _ShellToolInput(command="ls")
assert valid_cmd.command == "ls"
assert not valid_cmd.restart
# Valid: restart only
valid_restart = _ShellToolInput(restart=True)
assert valid_restart.restart is True
assert valid_restart.command is None
def test_normalize_shell_command_empty() -> None:
"""Test that empty shell command raises an error."""
with pytest.raises(ValueError, match="at least one argument"):
ShellToolMiddleware(shell_command=[])
def test_normalize_env_non_string_keys() -> None:
"""Test that non-string environment keys raise an error."""
with pytest.raises(TypeError, match="must be strings"):
ShellToolMiddleware(env={123: "value"}) # type: ignore[dict-item]
def test_normalize_env_coercion(tmp_path: Path) -> None:
"""Test that environment values are coerced to strings."""
middleware = ShellToolMiddleware(
workspace_root=tmp_path / "workspace", env={"NUM": 42, "BOOL": True}
)
try:
state: AgentState = _empty_state()
updates = middleware.before_agent(state, None)
if updates:
state.update(updates)
resources = middleware._ensure_resources(state) # type: ignore[attr-defined]
result = middleware._run_shell_tool(
resources, {"command": "echo $NUM $BOOL"}, tool_call_id=None
)
assert "42" in result
assert "True" in result
finally:
updates = middleware.after_agent(state, None)
if updates:
state.update(updates)
def test_shell_tool_missing_command_string(tmp_path: Path) -> None:
"""Test that shell tool raises an error when command is not a string."""
middleware = ShellToolMiddleware(workspace_root=tmp_path / "workspace")
try:
state: AgentState = _empty_state()
updates = middleware.before_agent(state, None)
if updates:
state.update(updates)
resources = middleware._ensure_resources(state) # type: ignore[attr-defined]
with pytest.raises(ToolException, match="expects a 'command' string"):
middleware._run_shell_tool(resources, {"command": None}, tool_call_id=None)
with pytest.raises(ToolException, match="expects a 'command' string"):
middleware._run_shell_tool(
resources,
{"command": 123}, # type: ignore[dict-item]
tool_call_id=None,
)
finally:
updates = middleware.after_agent(state, None)
if updates:
state.update(updates)
def test_tool_message_formatting_with_id(tmp_path: Path) -> None:
"""Test that tool messages are properly formatted with tool_call_id."""
middleware = ShellToolMiddleware(workspace_root=tmp_path / "workspace")
try:
state: AgentState = _empty_state()
updates = middleware.before_agent(state, None)
if updates:
state.update(updates)
resources = middleware._ensure_resources(state) # type: ignore[attr-defined]
result = middleware._run_shell_tool(
resources, {"command": "echo test"}, tool_call_id="test-id-123"
)
assert isinstance(result, ToolMessage)
assert result.tool_call_id == "test-id-123"
assert result.name == "shell"
assert result.status == "success"
assert "test" in result.content
finally:
updates = middleware.after_agent(state, None)
if updates:
state.update(updates)
def test_nonzero_exit_code_returns_error(tmp_path: Path) -> None:
"""Test that non-zero exit codes are marked as errors."""
middleware = ShellToolMiddleware(workspace_root=tmp_path / "workspace")
try:
state: AgentState = _empty_state()
updates = middleware.before_agent(state, None)
if updates:
state.update(updates)
resources = middleware._ensure_resources(state) # type: ignore[attr-defined]
result = middleware._run_shell_tool(
resources,
{"command": "false"}, # Command that exits with 1 but doesn't kill shell
tool_call_id="test-id",
)
assert isinstance(result, ToolMessage)
assert result.status == "error"
assert "Exit code: 1" in result.content
assert result.artifact["exit_code"] == 1 # type: ignore[index]
finally:
updates = middleware.after_agent(state, None)
if updates:
state.update(updates)
def test_truncation_by_bytes(tmp_path: Path) -> None:
"""Test that output is truncated by bytes when max_output_bytes is exceeded."""
policy = HostExecutionPolicy(max_output_bytes=50, command_timeout=5.0)
middleware = ShellToolMiddleware(workspace_root=tmp_path / "workspace", execution_policy=policy)
try:
state: AgentState = _empty_state()
updates = middleware.before_agent(state, None)
if updates:
state.update(updates)
resources = middleware._ensure_resources(state) # type: ignore[attr-defined]
result = middleware._run_shell_tool(
resources, {"command": "python3 -c 'print(\"x\" * 100)'"}, tool_call_id=None
)
assert "truncated at 50 bytes" in result.lower()
finally:
updates = middleware.after_agent(state, None)
if updates:
state.update(updates)
def test_startup_command_failure(tmp_path: Path) -> None:
"""Test that startup command failure raises an error."""
policy = HostExecutionPolicy(startup_timeout=1.0)
middleware = ShellToolMiddleware(
workspace_root=tmp_path / "workspace", startup_commands=("exit 1",), execution_policy=policy
)
state: AgentState = _empty_state()
with pytest.raises(RuntimeError, match="Startup command.*failed"):
middleware.before_agent(state, None)
def test_shutdown_command_failure_logged(tmp_path: Path) -> None:
"""Test that shutdown command failures are logged but don't raise."""
policy = HostExecutionPolicy(command_timeout=1.0)
middleware = ShellToolMiddleware(
workspace_root=tmp_path / "workspace",
shutdown_commands=("exit 1",),
execution_policy=policy,
)
try:
state: AgentState = _empty_state()
updates = middleware.before_agent(state, None)
if updates:
state.update(updates)
finally:
# Should not raise despite shutdown command failing
middleware.after_agent(state, None)
def test_shutdown_command_timeout_logged(tmp_path: Path) -> None:
"""Test that shutdown command timeouts are logged but don't raise."""
policy = HostExecutionPolicy(command_timeout=0.1)
middleware = ShellToolMiddleware(
workspace_root=tmp_path / "workspace",
execution_policy=policy,
shutdown_commands=("sleep 2",),
)
try:
state: AgentState = _empty_state()
updates = middleware.before_agent(state, None)
if updates:
state.update(updates)
finally:
# Should not raise despite shutdown command timing out
middleware.after_agent(state, None)
def test_ensure_resources_missing_state(tmp_path: Path) -> None:
"""Test that _ensure_resources raises when resources are missing."""
middleware = ShellToolMiddleware(workspace_root=tmp_path / "workspace")
state: AgentState = _empty_state()
with pytest.raises(ToolException, match="Shell session resources are unavailable"):
middleware._ensure_resources(state) # type: ignore[attr-defined]
def test_empty_output_replaced_with_no_output(tmp_path: Path) -> None:
"""Test that empty command output is replaced with '<no output>'."""
middleware = ShellToolMiddleware(workspace_root=tmp_path / "workspace")
try:
state: AgentState = _empty_state()
updates = middleware.before_agent(state, None)
if updates:
state.update(updates)
resources = middleware._ensure_resources(state) # type: ignore[attr-defined]
result = middleware._run_shell_tool(
resources,
{"command": "true"}, # Command that produces no output
tool_call_id=None,
)
assert "<no output>" in result
finally:
updates = middleware.after_agent(state, None)
if updates:
state.update(updates)
def test_stderr_output_labeling(tmp_path: Path) -> None:
"""Test that stderr output is properly labeled."""
middleware = ShellToolMiddleware(workspace_root=tmp_path / "workspace")
try:
state: AgentState = _empty_state()
updates = middleware.before_agent(state, None)
if updates:
state.update(updates)
resources = middleware._ensure_resources(state) # type: ignore[attr-defined]
result = middleware._run_shell_tool(
resources, {"command": "echo error >&2"}, tool_call_id=None
)
assert "[stderr] error" in result
finally:
updates = middleware.after_agent(state, None)
if updates:
state.update(updates)
@pytest.mark.parametrize(
("startup_commands", "expected"),
[
("echo test", ("echo test",)), # String
(["echo test", "pwd"], ("echo test", "pwd")), # List
(("echo test",), ("echo test",)), # Tuple
(None, ()), # None
],
)
def test_normalize_commands_string_tuple_list(
tmp_path: Path,
startup_commands: str | list[str] | tuple[str, ...] | None,
expected: tuple[str, ...],
) -> None:
"""Test various command normalization formats."""
middleware = ShellToolMiddleware(
workspace_root=tmp_path / "workspace", startup_commands=startup_commands
)
assert middleware._startup_commands == expected # type: ignore[attr-defined]
def test_async_methods_delegate_to_sync(tmp_path: Path) -> None:
"""Test that async methods properly delegate to sync methods."""
middleware = ShellToolMiddleware(workspace_root=tmp_path / "workspace")
try:
state: AgentState = _empty_state()
# Test abefore_agent
updates = asyncio.run(middleware.abefore_agent(state, None))
if updates:
state.update(updates)
# Test aafter_agent
asyncio.run(middleware.aafter_agent(state, None))
finally:
pass


@@ -1,11 +1,11 @@
from typing import TYPE_CHECKING
from unittest.mock import patch
import pytest
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, RemoveMessage, ToolMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from langgraph.graph.message import REMOVE_ALL_MESSAGES
from langchain.agents.middleware.summarization import SummarizationMiddleware
@@ -15,6 +15,35 @@ if TYPE_CHECKING:
from langchain_model_profiles import ModelProfile
class MockChatModel(BaseChatModel):
"""Mock chat model for testing."""
def invoke(self, prompt): # type: ignore[no-untyped-def]
return AIMessage(content="Generated summary")
def _generate(self, messages, **kwargs): # type: ignore[no-untyped-def]
return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))])
@property
def _llm_type(self) -> str:
return "mock"
class ProfileChatModel(BaseChatModel):
"""Mock chat model with profile for testing."""
def _generate(self, messages, **kwargs): # type: ignore[no-untyped-def]
return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))])
@property
def _llm_type(self) -> str:
return "mock"
@property
def profile(self) -> "ModelProfile":
return {"max_input_tokens": 1000}
def test_summarization_middleware_initialization() -> None:
"""Test SummarizationMiddleware initialization."""
model = FakeToolCallingModel()
@@ -139,19 +168,7 @@ def test_summarization_middleware_tool_call_safety() -> None:
def test_summarization_middleware_summary_creation() -> None:
"""Test SummarizationMiddleware summary creation."""
class MockModel(BaseChatModel):
def invoke(self, prompt):
return AIMessage(content="Generated summary")
def _generate(self, messages, **kwargs):
return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))])
@property
def _llm_type(self):
return "mock"
middleware = SummarizationMiddleware(model=MockModel(), trigger=("tokens", 1000))
middleware = SummarizationMiddleware(model=MockChatModel(), trigger=("tokens", 1000))
# Test normal summary creation
messages = [HumanMessage(content="Hello"), AIMessage(content="Hi")]
@@ -180,25 +197,16 @@ def test_summarization_middleware_summary_creation() -> None:
# Test we raise warning if max_tokens_before_summary or messages_to_keep is specified
with pytest.warns(DeprecationWarning, match="max_tokens_before_summary is deprecated"):
SummarizationMiddleware(model=MockModel(), max_tokens_before_summary=500)
SummarizationMiddleware(model=MockChatModel(), max_tokens_before_summary=500)
with pytest.warns(DeprecationWarning, match="messages_to_keep is deprecated"):
SummarizationMiddleware(model=MockModel(), messages_to_keep=5)
SummarizationMiddleware(model=MockChatModel(), messages_to_keep=5)
def test_summarization_middleware_trim_limit_none_keeps_all_messages() -> None:
"""Verify disabling trim limit preserves full message sequence."""
class MockModel(BaseChatModel):
def _generate(self, messages, **kwargs):
return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))])
@property
def _llm_type(self):
return "mock"
messages = [HumanMessage(content=str(i)) for i in range(10)]
middleware = SummarizationMiddleware(
model=MockModel(),
model=MockChatModel(),
trim_tokens_to_summarize=None,
)
middleware.token_counter = lambda msgs: len(msgs)
@@ -209,23 +217,10 @@ def test_summarization_middleware_trim_limit_none_keeps_all_messages() -> None:
def test_summarization_middleware_profile_inference_triggers_summary() -> None:
"""Ensure automatic profile inference triggers summarization when limits are exceeded."""
class ProfileModel(BaseChatModel):
def _generate(self, messages, **kwargs):
return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))])
@property
def _llm_type(self) -> str:
return "mock"
@property
def profile(self) -> "ModelProfile":
return {"max_input_tokens": 1000}
token_counter = lambda messages: len(messages) * 200
middleware = SummarizationMiddleware(
model=ProfileModel(),
model=ProfileChatModel(),
trigger=("fraction", 0.81),
keep=("fraction", 0.5),
token_counter=token_counter,
@@ -250,7 +245,7 @@ def test_summarization_middleware_profile_inference_triggers_summary() -> None:
# Engage summarization
# 0.80 * 1000 == 800 <= 800
middleware = SummarizationMiddleware(
model=ProfileModel(),
model=ProfileChatModel(),
trigger=("fraction", 0.80),
keep=("fraction", 0.5),
token_counter=token_counter,
@@ -270,7 +265,7 @@ def test_summarization_middleware_profile_inference_triggers_summary() -> None:
# With keep=("fraction", 0.6) the target token allowance becomes 600,
# so the cutoff shifts to keep the last three messages instead of two.
middleware = SummarizationMiddleware(
model=ProfileModel(),
model=ProfileChatModel(),
trigger=("fraction", 0.80),
keep=("fraction", 0.6),
token_counter=token_counter,
@@ -287,7 +282,7 @@ def test_summarization_middleware_profile_inference_triggers_summary() -> None:
# context (target tokens = 800), so token-based retention keeps everything
# and summarization is skipped entirely.
middleware = SummarizationMiddleware(
model=ProfileModel(),
model=ProfileChatModel(),
trigger=("fraction", 0.80),
keep=("fraction", 0.8),
token_counter=token_counter,
@@ -296,7 +291,7 @@ def test_summarization_middleware_profile_inference_triggers_summary() -> None:
# Test with tokens_to_keep as absolute int value
middleware_int = SummarizationMiddleware(
model=ProfileModel(),
model=ProfileChatModel(),
trigger=("fraction", 0.80),
keep=("tokens", 400), # Keep exactly 400 tokens (2 messages)
token_counter=token_counter,
@@ -310,7 +305,7 @@ def test_summarization_middleware_profile_inference_triggers_summary() -> None:
# Test with tokens_to_keep as larger int value
middleware_int_large = SummarizationMiddleware(
model=ProfileModel(),
model=ProfileChatModel(),
trigger=("fraction", 0.80),
keep=("tokens", 600), # Keep 600 tokens (3 messages)
token_counter=token_counter,
@@ -327,23 +322,11 @@ def test_summarization_middleware_profile_inference_triggers_summary() -> None:
def test_summarization_middleware_token_retention_pct_respects_tool_pairs() -> None:
"""Ensure token retention keeps pairs together even if exceeding target tokens."""
class ProfileModel(BaseChatModel):
def _generate(self, messages, **kwargs):
return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))])
@property
def _llm_type(self) -> str:
return "mock"
@property
def profile(self) -> "ModelProfile":
return {"max_input_tokens": 1000}
def token_counter(messages):
def token_counter(messages: list[AnyMessage]) -> int:
return sum(len(getattr(message, "content", "")) for message in messages)
middleware = SummarizationMiddleware(
model=ProfileModel(),
model=ProfileChatModel(),
trigger=("fraction", 0.1),
keep=("fraction", 0.5),
)
@@ -400,22 +383,10 @@ def test_summarization_middleware_missing_profile() -> None:
def test_summarization_middleware_full_workflow() -> None:
"""Test SummarizationMiddleware complete summarization workflow."""
class MockModel(BaseChatModel):
def invoke(self, prompt):
return AIMessage(content="Generated summary")
def _generate(self, messages, **kwargs):
return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))])
@property
def _llm_type(self):
return "mock"
with pytest.warns(DeprecationWarning):
# keep test for functionality
middleware = SummarizationMiddleware(
model=MockModel(), max_tokens_before_summary=1000, messages_to_keep=2
model=MockChatModel(), max_tokens_before_summary=1000, messages_to_keep=2
)
# Mock high token count to trigger summarization
@@ -504,21 +475,9 @@ async def test_summarization_middleware_full_workflow_async() -> None:
def test_summarization_middleware_keep_messages() -> None:
"""Test SummarizationMiddleware with keep parameter specifying messages."""
class MockModel(BaseChatModel):
def invoke(self, prompt):
return AIMessage(content="Generated summary")
def _generate(self, messages, **kwargs):
return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))])
@property
def _llm_type(self):
return "mock"
# Test that summarization is triggered when message count reaches threshold
middleware = SummarizationMiddleware(
model=MockModel(), trigger=("messages", 5), keep=("messages", 2)
model=MockChatModel(), trigger=("messages", 5), keep=("messages", 2)
)
# Below threshold - no summarization
@@ -561,6 +520,379 @@ def test_summarization_middleware_keep_messages() -> None:
assert [message.content for message in result["messages"][2:]] == ["5", "6"]
# Test with both parameters disabled
middleware_disabled = SummarizationMiddleware(model=MockModel(), trigger=None)
middleware_disabled = SummarizationMiddleware(model=MockChatModel(), trigger=None)
result = middleware_disabled.before_model(state_above, None)
assert result is None
@pytest.mark.parametrize(
("param_name", "param_value", "expected_error"),
[
("trigger", ("fraction", 0.0), "Fractional trigger values must be between 0 and 1"),
("trigger", ("fraction", 1.5), "Fractional trigger values must be between 0 and 1"),
("keep", ("fraction", -0.1), "Fractional keep values must be between 0 and 1"),
("trigger", ("tokens", 0), "trigger thresholds must be greater than 0"),
("trigger", ("messages", -5), "trigger thresholds must be greater than 0"),
("keep", ("tokens", 0), "keep thresholds must be greater than 0"),
("trigger", ("invalid", 100), "Unsupported context size type"),
("keep", ("invalid", 100), "Unsupported context size type"),
],
)
def test_summarization_middleware_validation_edge_cases(
param_name: str, param_value: tuple[str, float | int], expected_error: str
) -> None:
"""Test validation of context size parameters with edge cases."""
model = FakeToolCallingModel()
with pytest.raises(ValueError, match=expected_error):
SummarizationMiddleware(model=model, **{param_name: param_value})
def test_summarization_middleware_multiple_triggers() -> None:
"""Test middleware with multiple trigger conditions."""
# Test with multiple triggers - should activate when ANY condition is met
middleware = SummarizationMiddleware(
model=MockChatModel(),
trigger=[("messages", 10), ("tokens", 500)],
keep=("messages", 2),
)
# Mock token counter to return low count
def mock_low_tokens(messages):
return 100
middleware.token_counter = mock_low_tokens
# Should not trigger - neither condition met
messages = [HumanMessage(content=str(i)) for i in range(5)]
state = {"messages": messages}
result = middleware.before_model(state, None)
assert result is None
# Should trigger - message count threshold met
messages = [HumanMessage(content=str(i)) for i in range(10)]
state = {"messages": messages}
result = middleware.before_model(state, None)
assert result is not None
# Test token trigger
def mock_high_tokens(messages):
return 600
middleware.token_counter = mock_high_tokens
messages = [HumanMessage(content=str(i)) for i in range(5)]
state = {"messages": messages}
result = middleware.before_model(state, None)
assert result is not None
def test_summarization_middleware_profile_edge_cases() -> None:
"""Test profile retrieval with various edge cases."""
class NoProfileModel(BaseChatModel):
def _generate(self, messages, **kwargs):
return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))])
@property
def _llm_type(self):
return "mock"
# Model without profile attribute
middleware = SummarizationMiddleware(model=NoProfileModel(), trigger=("messages", 5))
assert middleware._get_profile_limits() is None
class InvalidProfileModel(BaseChatModel):
def _generate(self, messages, **kwargs):
return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))])
@property
def _llm_type(self):
return "mock"
@property
def profile(self):
return "invalid_profile_type"
# Model with non-dict profile
middleware = SummarizationMiddleware(model=InvalidProfileModel(), trigger=("messages", 5))
assert middleware._get_profile_limits() is None
class MissingTokensModel(BaseChatModel):
def _generate(self, messages, **kwargs):
return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))])
@property
def _llm_type(self):
return "mock"
@property
def profile(self):
return {"other_field": 100}
# Model with profile but no max_input_tokens
middleware = SummarizationMiddleware(model=MissingTokensModel(), trigger=("messages", 5))
assert middleware._get_profile_limits() is None
class InvalidTokenTypeModel(BaseChatModel):
def _generate(self, messages, **kwargs):
return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))])
@property
def _llm_type(self):
return "mock"
@property
def profile(self):
return {"max_input_tokens": "not_an_int"}
# Model with non-int max_input_tokens
middleware = SummarizationMiddleware(model=InvalidTokenTypeModel(), trigger=("messages", 5))
assert middleware._get_profile_limits() is None
def test_summarization_middleware_trim_messages_error_fallback() -> None:
"""Test that trim_messages_for_summary falls back gracefully on errors."""
middleware = SummarizationMiddleware(model=MockChatModel(), trigger=("messages", 5))
# Create a mock token counter that raises an exception
def failing_token_counter(messages):
raise Exception("Token counting failed")
middleware.token_counter = failing_token_counter
# Should fall back to last 15 messages
messages = [HumanMessage(content=str(i)) for i in range(20)]
trimmed = middleware._trim_messages_for_summary(messages)
assert len(trimmed) == 15
assert trimmed == messages[-15:]
def test_summarization_middleware_binary_search_edge_cases() -> None:
"""Test binary search in _find_token_based_cutoff with edge cases."""
middleware = SummarizationMiddleware(
model=MockChatModel(), trigger=("messages", 5), keep=("tokens", 100)
)
# Test with single message that's too large
def token_counter_single_large(messages):
return len(messages) * 200
middleware.token_counter = token_counter_single_large
single_message = [HumanMessage(content="x" * 200)]
cutoff = middleware._find_token_based_cutoff(single_message)
assert cutoff == 0
# Test with empty messages
cutoff = middleware._find_token_based_cutoff([])
assert cutoff == 0
# Test when all messages fit within token budget
def token_counter_small(messages):
return len(messages) * 10
middleware.token_counter = token_counter_small
messages = [HumanMessage(content=str(i)) for i in range(5)]
cutoff = middleware._find_token_based_cutoff(messages)
assert cutoff == 0
def test_summarization_middleware_tool_call_extraction_edge_cases() -> None:
"""Test tool call ID extraction with various message formats."""
model = FakeToolCallingModel()
middleware = SummarizationMiddleware(model=model, trigger=("messages", 5))
# Test with dict-style tool calls
ai_message_dict = AIMessage(
content="test", tool_calls=[{"name": "tool1", "args": {}, "id": "id1"}]
)
ids = middleware._extract_tool_call_ids(ai_message_dict)
assert ids == {"id1"}
# Test with multiple tool calls
ai_message_multiple = AIMessage(
content="test",
tool_calls=[
{"name": "tool1", "args": {}, "id": "id1"},
{"name": "tool2", "args": {}, "id": "id2"},
],
)
ids = middleware._extract_tool_call_ids(ai_message_multiple)
assert ids == {"id1", "id2"}
# Test with empty tool calls list
ai_message_empty = AIMessage(content="test", tool_calls=[])
ids = middleware._extract_tool_call_ids(ai_message_empty)
assert len(ids) == 0
def test_summarization_middleware_complex_tool_pair_scenarios() -> None:
"""Test complex tool call pairing scenarios."""
model = FakeToolCallingModel()
middleware = SummarizationMiddleware(model=model, trigger=("messages", 5), keep=("messages", 3))
# Test with multiple AI messages with tool calls
messages = [
HumanMessage(content="msg1"),
AIMessage(content="ai1", tool_calls=[{"name": "tool1", "args": {}, "id": "call1"}]),
ToolMessage(content="result1", tool_call_id="call1"),
HumanMessage(content="msg2"),
AIMessage(content="ai2", tool_calls=[{"name": "tool2", "args": {}, "id": "call2"}]),
ToolMessage(content="result2", tool_call_id="call2"),
HumanMessage(content="msg3"),
]
# Test cutoff at index 2 - unsafe (separates first AI/Tool pair)
assert not middleware._is_safe_cutoff_point(messages, 2)
# Test cutoff at index 3 - safe (keeps first pair together)
assert middleware._is_safe_cutoff_point(messages, 3)
# Test cutoff at index 5 - unsafe (separates second AI/Tool pair)
assert not middleware._is_safe_cutoff_point(messages, 5)
# Test _cutoff_separates_tool_pair directly
assert middleware._cutoff_separates_tool_pair(messages, 1, 2, {"call1"})
assert not middleware._cutoff_separates_tool_pair(messages, 1, 0, {"call1"})
assert not middleware._cutoff_separates_tool_pair(messages, 1, 3, {"call1"})
def test_summarization_middleware_tool_call_in_search_range() -> None:
"""Test tool call safety with messages at edge of search range."""
model = FakeToolCallingModel()
middleware = SummarizationMiddleware(
model=model, trigger=("messages", 10), keep=("messages", 2)
)
# Create messages with tool pair separated by some distance
# Search range is 5, so messages within 5 positions of cutoff are checked
messages = [
HumanMessage(content="msg1"),
HumanMessage(content="msg2"),
AIMessage(content="ai", tool_calls=[{"name": "tool", "args": {}, "id": "call1"}]),
HumanMessage(content="msg3"),
HumanMessage(content="msg4"),
ToolMessage(content="result", tool_call_id="call1"),
HumanMessage(content="msg6"),
]
# Cutoff at index 3 would separate: [0,1,2] from [3,4,5,6]
# AI at index 2 is before cutoff, Tool at index 5 is after cutoff - unsafe
assert not middleware._is_safe_cutoff_point(messages, 3)
# Cutoff at index 6 keeps AI and Tool both in summarized section
assert middleware._is_safe_cutoff_point(messages, 6)
# Cutoff at index 0 or 1 also safe - both AI and Tool in preserved section
assert middleware._is_safe_cutoff_point(messages, 0)
assert middleware._is_safe_cutoff_point(messages, 1)
def test_summarization_middleware_zero_and_negative_target_tokens() -> None:
"""Test handling of edge cases with target token calculations."""
# Test with very small fraction that rounds to zero
middleware = SummarizationMiddleware(
model=ProfileChatModel(), trigger=("fraction", 0.0001), keep=("fraction", 0.0001)
)
# Should set threshold to 1 when calculated value is <= 0
messages = [HumanMessage(content="test")]
state = {"messages": messages}
# The trigger fraction calculation: int(1000 * 0.0001) = 0, but should be set to 1
# Token count of 1 message should exceed threshold of 1
def token_counter(msgs):
return 2
middleware.token_counter = token_counter
assert middleware._should_summarize(messages, 2)
async def test_summarization_middleware_async_error_handling() -> None:
"""Test async summary creation with errors."""
class ErrorAsyncModel(BaseChatModel):
def _generate(self, messages, **kwargs):
return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))])
async def _agenerate(self, messages, **kwargs):
raise Exception("Async model error")
@property
def _llm_type(self):
return "mock"
middleware = SummarizationMiddleware(model=ErrorAsyncModel(), trigger=("messages", 5))
messages = [HumanMessage(content="test")]
summary = await middleware._acreate_summary(messages)
assert "Error generating summary: Async model error" in summary
def test_summarization_middleware_cutoff_at_boundary() -> None:
"""Test cutoff index determination at exact message boundaries."""
middleware = SummarizationMiddleware(
model=MockChatModel(), trigger=("messages", 5), keep=("messages", 5)
)
# When we want to keep exactly as many messages as we have
messages = [HumanMessage(content=str(i)) for i in range(5)]
cutoff = middleware._find_safe_cutoff(messages, 5)
assert cutoff == 0 # Should not cut anything
# When we want to keep more messages than we have
cutoff = middleware._find_safe_cutoff(messages, 10)
assert cutoff == 0
def test_summarization_middleware_deprecated_parameters_with_defaults() -> None:
"""Test that deprecated parameters work correctly with default values."""
# Test that deprecated max_tokens_before_summary is ignored when trigger is set
with pytest.warns(DeprecationWarning):
middleware = SummarizationMiddleware(
model=MockChatModel(), trigger=("tokens", 2000), max_tokens_before_summary=1000
)
assert middleware.trigger == ("tokens", 2000)
# Test that messages_to_keep is ignored when keep is not default
with pytest.warns(DeprecationWarning):
middleware = SummarizationMiddleware(
model=MockChatModel(), keep=("messages", 5), messages_to_keep=10
)
assert middleware.keep == ("messages", 5)
def test_summarization_middleware_fraction_trigger_with_no_profile() -> None:
"""Test fractional trigger condition when profile data becomes unavailable."""
middleware = SummarizationMiddleware(
model=ProfileChatModel(),
trigger=[("fraction", 0.5), ("messages", 100)],
keep=("messages", 5),
)
# Test that when fractional condition can't be evaluated, other triggers still work
messages = [HumanMessage(content=str(i)) for i in range(100)]
# Mock _get_profile_limits to return None
original_method = middleware._get_profile_limits
middleware._get_profile_limits = lambda: None
# Should still trigger based on message count
state = {"messages": messages}
result = middleware.before_model(state, None)
assert result is not None
# Restore original method
middleware._get_profile_limits = original_method
def test_summarization_middleware_is_safe_cutoff_at_end() -> None:
"""Test _is_safe_cutoff_point when cutoff is at or past the end."""
model = FakeToolCallingModel()
middleware = SummarizationMiddleware(model=model, trigger=("messages", 5))
messages = [HumanMessage(content=str(i)) for i in range(5)]
# Cutoff at exactly the length should be safe
assert middleware._is_safe_cutoff_point(messages, len(messages))
# Cutoff past the length should also be safe
assert middleware._is_safe_cutoff_point(messages, len(messages) + 5)


@@ -1,22 +1,22 @@
"""Unit tests for LLM tool selection middleware."""
import typing
from typing import Union, Any, Literal
from itertools import cycle
from typing import Any, Literal, Union
import pytest
from pydantic import BaseModel
from langchain.agents import create_agent
from langchain.agents.middleware import AgentState, ModelRequest, wrap_model_call
from langchain.agents.middleware import LLMToolSelectorMiddleware
from langchain.agents.middleware import LLMToolSelectorMiddleware, ModelRequest, wrap_model_call
from langchain.agents.middleware.tool_selection import _create_tool_selection_response
from langchain.agents.middleware.types import AgentState
from langchain.messages import AIMessage
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
from langchain_core.messages import BaseMessage
from langchain_core.messages import HumanMessage
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from langchain_core.tools import tool
from langchain_core.tools import BaseTool, tool
@tool
@@ -596,3 +596,12 @@ class TestDuplicateAndInvalidTools:
assert len(tool_names) == 2
assert "get_weather" in tool_names
assert "search_web" in tool_names
class TestEdgeCases:
"""Test edge cases and error handling."""
def test_empty_tools_list_raises_error(self) -> None:
"""Test that empty tools list raises an error in schema creation."""
with pytest.raises(AssertionError, match="tools must be non-empty"):
_create_tool_selection_response([])