mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-12 04:01:05 +00:00
Compare commits: cc/release ... mdrxy/Larg (1 commit, 5ce81dff22)
@@ -9,6 +9,7 @@ from .human_in_the_loop import (
    HumanInTheLoopMiddleware,
    InterruptOnConfig,
)
from .large_tool_result import LargeToolResultMiddleware
from .model_call_limit import ModelCallLimitMiddleware
from .model_fallback import ModelFallbackMiddleware
from .model_retry import ModelRetryMiddleware
@@ -54,6 +55,7 @@ __all__ = [
    "InterruptOnConfig",
    "LLMToolEmulator",
    "LLMToolSelectorMiddleware",
    "LargeToolResultMiddleware",
    "ModelCallLimitMiddleware",
    "ModelFallbackMiddleware",
    "ModelRequest",
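The export above makes the class importable directly from `langchain.agents.middleware`. A quick sanity check, illustrative only and not part of the diff (the threshold value is an arbitrary example):

```python
# Relies only on the re-export added above; the threshold here is an arbitrary example.
from langchain.agents.middleware import LargeToolResultMiddleware

middleware = LargeToolResultMiddleware(threshold=("tokens", 5000))
print(type(middleware).__name__)  # LargeToolResultMiddleware
```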
@@ -0,0 +1,327 @@
"""Large tool result middleware.

Offloads large tool results to the filesystem to prevent context overflow.
Results exceeding a configurable threshold are written to temporary files
with a truncated preview kept in the message.
"""

from __future__ import annotations

import contextlib
import tempfile
from collections.abc import Awaitable, Callable, Mapping, Sequence
from copy import deepcopy
from pathlib import Path
from typing import TYPE_CHECKING, Annotated, Any, Literal

from langchain_core.messages import AnyMessage, ToolMessage
from typing_extensions import NotRequired, override

from langchain.agents.middleware.types import (
    AgentMiddleware,
    AgentState,
    ModelCallResult,
    ModelRequest,
    ModelResponse,
    PrivateStateAttr,
)

if TYPE_CHECKING:
    from langchain_core.language_models.chat_models import BaseChatModel
    from langgraph.runtime import Runtime

__all__ = ["LargeToolResultMiddleware"]

_DEFAULT_THRESHOLD_FRACTION = 0.10
_DEFAULT_PREVIEW_LENGTH = 500
_OFFLOAD_METADATA_KEY = "large_tool_result_offloaded"

ContextFraction = tuple[Literal["fraction"], float]
ContextTokens = tuple[Literal["tokens"], int]
ContextSize = ContextFraction | ContextTokens


class LargeToolResultState(AgentState):
    """Extended state for large tool result middleware."""

    offloaded_results_dir: NotRequired[Annotated[str | None, PrivateStateAttr]]
    """Directory where large tool results are stored."""


class LargeToolResultMiddleware(AgentMiddleware[LargeToolResultState, Any]):
    """Offloads large tool results to filesystem to prevent context overflow.

    This middleware monitors tool result sizes and automatically writes results
    exceeding a threshold to temporary files, keeping a truncated preview in the
    message. This prevents massive tool outputs from triggering aggressive
    summarization or exceeding model context limits.

    The middleware uses `wrap_model_call` to process messages before each model
    invocation, ensuring large results are offloaded before token counting occurs.
    """

    state_schema = LargeToolResultState
    tools: Sequence = ()

    def __init__(
        self,
        *,
        threshold: ContextSize = ("fraction", _DEFAULT_THRESHOLD_FRACTION),
        preview_length: int = _DEFAULT_PREVIEW_LENGTH,
        temp_dir: Path | str | None = None,
        cleanup_on_end: bool = True,
    ) -> None:
        """Initialize large tool result middleware.

        Args:
            threshold: Size threshold that triggers offloading to disk.

                Provide a tuple specifying the threshold type:

                - `('fraction', 0.10)`: Offload if result exceeds 10% of model's
                  max input tokens (default)
                - `('tokens', 5000)`: Offload if result exceeds 5000 tokens

            preview_length: Number of characters to keep as preview in the message.
                Defaults to 500 characters.

            temp_dir: Directory for storing offloaded results.

                If `None` (default), creates a temporary directory that is cleaned
                up when the agent session ends.

                If provided, uses the specified directory and does not delete it
                on cleanup (user-managed).

            cleanup_on_end: Whether to clean up the temp directory when the agent
                session ends.

                Only applies when `temp_dir` is `None` (auto-created directory).

                Defaults to `True`.
        """
        super().__init__()
        self.threshold = self._validate_threshold(threshold)
        self.preview_length = preview_length
        self.user_temp_dir = Path(temp_dir) if temp_dir else None
        self.cleanup_on_end = cleanup_on_end
        self._temp_dir: tempfile.TemporaryDirectory[str] | None = None

    def _validate_threshold(self, threshold: ContextSize) -> ContextSize:
        """Validate threshold configuration."""
        kind, value = threshold
        if kind == "fraction":
            if not 0 < value <= 1:
                msg = f"Fractional threshold must be between 0 and 1, got {value}."
                raise ValueError(msg)
        elif kind == "tokens":
            if value <= 0:
                msg = f"Token threshold must be greater than 0, got {value}."
                raise ValueError(msg)
        else:
            msg = f"Unsupported threshold type: {kind}"
            raise ValueError(msg)
        return threshold

    def _get_threshold_tokens(self, model: BaseChatModel | None) -> int:
        """Calculate threshold in tokens based on configuration."""
        kind, value = self.threshold
        if kind == "tokens":
            return int(value)

        # Fractional threshold - we need the model profile to resolve it
        if model is None:
            # No model available: fall back to assuming a 100,000-token input budget
            return int(value * 100_000)

        max_input_tokens = self._get_model_max_tokens(model)
        if max_input_tokens is None:
            # Profile lacks max_input_tokens: same 100,000-token fallback
            return int(value * 100_000)

        return int(max_input_tokens * value)

    def _get_model_max_tokens(self, model: BaseChatModel) -> int | None:
        """Get model's max input tokens from profile."""
        try:
            profile = model.profile
        except AttributeError:
            return None

        if not isinstance(profile, Mapping):
            return None

        max_input_tokens = profile.get("max_input_tokens")
        return max_input_tokens if isinstance(max_input_tokens, int) else None

    def _get_or_create_temp_dir(self) -> Path:
        """Get or create the temporary directory for storing results."""
        if self.user_temp_dir is not None:
            self.user_temp_dir.mkdir(parents=True, exist_ok=True)
            return self.user_temp_dir

        if self._temp_dir is not None:
            return Path(self._temp_dir.name)

        # Create new temp directory
        self._temp_dir = tempfile.TemporaryDirectory(prefix="langchain-large-results-")
        return Path(self._temp_dir.name)

    def _estimate_tokens(self, content: str) -> int:
        """Estimate token count for a string."""
        # Approximate: ~4 characters per token
        return len(content) // 4

    def _should_offload(self, content: str, threshold_tokens: int) -> bool:
        """Determine if content should be offloaded based on size."""
        estimated_tokens = self._estimate_tokens(content)
        return estimated_tokens > threshold_tokens

    def _offload_content(self, tool_call_id: str, content: str) -> str:
        """Write content to file and return the file path."""
        temp_dir = self._get_or_create_temp_dir()
        # Sanitize tool_call_id for filename
        safe_id = "".join(c if c.isalnum() or c in "-_" else "_" for c in tool_call_id)
        file_path = temp_dir / f"{safe_id}.txt"

        file_path.write_text(content, encoding="utf-8")
        return str(file_path)

    def _create_truncated_message(self, original_content: str, file_path: str) -> str:
        """Create a truncated message with file reference."""
        preview = original_content[: self.preview_length]
        if len(original_content) > self.preview_length:
            preview += "..."

        return (
            f"[TRUNCATED - Full result saved to: {file_path}]\n\n"
            f"Preview (first {self.preview_length} chars):\n{preview}"
        )

    def _process_messages(
        self,
        messages: list[AnyMessage],
        threshold_tokens: int,
    ) -> list[AnyMessage]:
        """Process messages and offload large tool results."""
        processed: list[AnyMessage] = []
        for msg in messages:
            if not isinstance(msg, ToolMessage):
                processed.append(msg)
                continue

            # Check if already offloaded
            if msg.response_metadata.get(_OFFLOAD_METADATA_KEY):
                processed.append(msg)
                continue

            # Get content as string
            content = msg.content
            if isinstance(content, list):
                # Multimodal content - convert to string for size check
                content = str(content)
            if not isinstance(content, str):
                content = str(content)

            # Check if should offload
            if not self._should_offload(content, threshold_tokens):
                processed.append(msg)
                continue

            # Offload to file
            tool_call_id = msg.tool_call_id or f"unknown_{id(msg)}"
            file_path = self._offload_content(tool_call_id, content)

            # Create truncated message
            truncated_content = self._create_truncated_message(content, file_path)

            # Create new message with truncated content
            new_msg = msg.model_copy(
                update={
                    "content": truncated_content,
                    "response_metadata": {
                        **msg.response_metadata,
                        _OFFLOAD_METADATA_KEY: {
                            "offloaded": True,
                            "file_path": file_path,
                            "original_size_chars": len(content),
                        },
                    },
                }
            )
            processed.append(new_msg)

        return processed

    @override
    def before_agent(
        self, state: LargeToolResultState, runtime: Runtime[Any]
    ) -> dict[str, Any] | None:
        """Hook run before the agent starts (no state changes needed)."""
        # Nothing to do here - the temp directory is created lazily on first offload
        return None

    @override
    async def abefore_agent(
        self, state: LargeToolResultState, runtime: Runtime[Any]
    ) -> dict[str, Any] | None:
        """Hook run before the agent starts (no state changes needed, async)."""
        return None

    @override
    def wrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], ModelResponse],
    ) -> ModelCallResult:
        """Process messages before model call, offloading large tool results."""
        if not request.messages:
            return handler(request)

        threshold_tokens = self._get_threshold_tokens(request.model)
        processed_messages = self._process_messages(
            deepcopy(list(request.messages)),
            threshold_tokens,
        )

        return handler(request.override(messages=processed_messages))

    @override
    async def awrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
    ) -> ModelCallResult:
        """Process messages before model call, offloading large tool results (async)."""
        if not request.messages:
            return await handler(request)

        threshold_tokens = self._get_threshold_tokens(request.model)
        processed_messages = self._process_messages(
            deepcopy(list(request.messages)),
            threshold_tokens,
        )

        return await handler(request.override(messages=processed_messages))

    @override
    def after_agent(
        self, state: LargeToolResultState, runtime: Runtime[Any]
    ) -> dict[str, Any] | None:
        """Clean up temp directory on agent completion."""
        if self.cleanup_on_end and self._temp_dir is not None:
            with contextlib.suppress(Exception):
                self._temp_dir.cleanup()
            self._temp_dir = None
        return None

    @override
    async def aafter_agent(
        self, state: LargeToolResultState, runtime: Runtime[Any]
    ) -> dict[str, Any] | None:
        """Clean up temp directory on agent completion (async)."""
        if self.cleanup_on_end and self._temp_dir is not None:
            with contextlib.suppress(Exception):
                self._temp_dir.cleanup()
            self._temp_dir = None
        return None
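A minimal usage sketch for the new middleware, not part of the diff. It assumes the `create_agent` entry point accepts a `middleware` list the way the other middleware in this package are wired in, and it uses a made-up `dump_logs` tool and model id purely as stand-ins; the `LargeToolResultMiddleware` arguments mirror the defaults documented above.

```python
# Illustrative wiring - `create_agent` usage, the model id, and `dump_logs` are assumptions.
from langchain.agents import create_agent
from langchain.agents.middleware import LargeToolResultMiddleware
from langchain_core.tools import tool


@tool
def dump_logs(service: str) -> str:
    """Return the raw (potentially huge) log dump for a service."""
    return "ERROR connection reset\n" * 20_000  # ~460k characters of noise


agent = create_agent(
    model="openai:gpt-4o-mini",
    tools=[dump_logs],
    middleware=[
        # Offload any tool result above ~10% of the model's input window (the default),
        # keeping a 500-character preview inline.
        LargeToolResultMiddleware(threshold=("fraction", 0.10), preview_length=500),
    ],
)
```

With the middleware's ~4-characters-per-token estimate, that dump counts as roughly 115k tokens, so it is written to the temp directory and the model only ever sees the `[TRUNCATED - Full result saved to: ...]` preview plus the file path.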
@@ -61,6 +61,8 @@ Messages to summarize:
_DEFAULT_MESSAGES_TO_KEEP = 20
_DEFAULT_TRIM_TOKEN_LIMIT = 4000
_DEFAULT_FALLBACK_MESSAGE_COUNT = 15
_DEFAULT_MIN_PRESERVE_TURNS = 1
_TURN_PRESERVATION_SAFETY_THRESHOLD = 0.8  # Skip min_preserve_turns if turn exceeds this fraction

ContextFraction = tuple[Literal["fraction"], float]
"""Fraction of model's maximum input tokens.
@@ -147,6 +149,7 @@ class SummarizationMiddleware(AgentMiddleware):
        *,
        trigger: ContextSize | list[ContextSize] | None = None,
        keep: ContextSize = ("messages", _DEFAULT_MESSAGES_TO_KEEP),
        min_preserve_turns: int = _DEFAULT_MIN_PRESERVE_TURNS,
        token_counter: TokenCounter = count_tokens_approximately,
        summary_prompt: str = DEFAULT_SUMMARY_PROMPT,
        trim_tokens_to_summarize: int | None = _DEFAULT_TRIM_TOKEN_LIMIT,
@@ -200,6 +203,17 @@ class SummarizationMiddleware(AgentMiddleware):
                # Keep the most recent 30% of the model's max input tokens
                ("fraction", 0.3)
                ```
            min_preserve_turns: Minimum number of complete turns to always preserve,
                regardless of the `keep` policy.

                A turn consists of a `HumanMessage` followed by any `AIMessage` responses
                and their associated `ToolMessage` results.

                Defaults to `1`, ensuring the most recent complete turn is always kept.

                Set to `0` to disable this floor and rely solely on the `keep` policy.

                This prevents aggressive summarization from removing all context.
            token_counter: Function to count tokens in messages.
            summary_prompt: Prompt template for generating summaries.
            trim_tokens_to_summarize: Maximum tokens to keep when preparing messages for
@@ -248,6 +262,7 @@ class SummarizationMiddleware(AgentMiddleware):
        self._trigger_conditions = trigger_conditions

        self.keep = self._validate_context_size(keep, "keep")
        self.min_preserve_turns = min_preserve_turns
        if token_counter is count_tokens_approximately:
            self.token_counter = _get_approximate_token_counter(self.model)
        else:
@@ -346,16 +361,45 @@ class SummarizationMiddleware(AgentMiddleware):
        return False

    def _determine_cutoff_index(self, messages: list[AnyMessage]) -> int:
-        """Choose cutoff index respecting retention configuration."""
+        """Choose cutoff index respecting retention configuration and `min_preserve_turns`.
+
+        The cutoff is determined by the `keep` policy, but is further constrained by
+        `min_preserve_turns` to ensure at least N complete turns are always preserved.
+
+        Safety check: If preserving the minimum turns would exceed 80% of the model's
+        context window, fall back to the `keep` policy to avoid model rejection.
+        """
        kind, value = self.keep
        if kind in {"tokens", "fraction"}:
-            token_based_cutoff = self._find_token_based_cutoff(messages)
-            if token_based_cutoff is not None:
-                return token_based_cutoff
-            # None cutoff -> model profile data not available (caught in __init__ but
-            # here for safety), fallback to message count
-            return self._find_safe_cutoff(messages, _DEFAULT_MESSAGES_TO_KEEP)
-        return self._find_safe_cutoff(messages, cast("int", value))
+            calculated_cutoff = self._find_token_based_cutoff(messages)
+            if calculated_cutoff is None:
+                # None cutoff -> model profile data not available (caught in __init__ but
+                # here for safety), fallback to message count
+                calculated_cutoff = self._find_safe_cutoff(messages, _DEFAULT_MESSAGES_TO_KEEP)
+        else:
+            calculated_cutoff = self._find_safe_cutoff(messages, cast("int", value))
+
+        # Enforce minimum turn preservation floor (with safety check)
+        if self.min_preserve_turns > 0:
+            min_cutoff = self._find_last_n_turns(messages, self.min_preserve_turns)
+
+            # Safety check: ensure preserving these turns won't exceed context limit
+            if min_cutoff < calculated_cutoff:
+                preserved_messages = messages[min_cutoff:]
+                preserved_tokens = self.token_counter(preserved_messages)
+
+                # Check against safety threshold (80% of model's max tokens)
+                max_input_tokens = self._get_profile_limits()
+                if max_input_tokens is not None:
+                    safety_limit = int(max_input_tokens * _TURN_PRESERVATION_SAFETY_THRESHOLD)
+                    if preserved_tokens > safety_limit:
+                        # Turn is too large - fall back to keep policy
+                        return calculated_cutoff
+
+                # Safe to preserve the turn
+                calculated_cutoff = min_cutoff
+
+        return calculated_cutoff

    def _find_token_based_cutoff(self, messages: list[AnyMessage]) -> int | None:
        """Find cutoff index based on target token retention."""
@@ -461,6 +505,42 @@ class SummarizationMiddleware(AgentMiddleware):

        return messages_to_summarize, preserved_messages

    def _find_last_n_turns(self, messages: list[AnyMessage], n: int) -> int:
        """Find the cutoff index that preserves the last N complete turns.

        A 'turn' is defined as starting with a `HumanMessage` and including all
        subsequent `AIMessage` responses and their associated `ToolMessage` results,
        until the next `HumanMessage` begins a new turn.

        Args:
            messages: The list of messages to analyze.
            n: Number of turns to preserve.

        Returns:
            The cutoff index such that `messages[cutoff_index:]` contains at least the
            last N complete turns.

            Returns `0` if fewer than N turns exist (preserving all messages).
        """
        if n <= 0 or not messages:
            return len(messages)

        # Find indices of all HumanMessages (turn boundaries)
        human_indices = [i for i, msg in enumerate(messages) if isinstance(msg, HumanMessage)]

        if not human_indices:
            # No HumanMessages - preserve all messages
            return 0

        if len(human_indices) <= n:
            # Fewer turns than requested - preserve all
            return 0

        # Get the index of the HumanMessage that starts the Nth-to-last turn
        # If n=1, we want the last HumanMessage
        # If n=2, we want the second-to-last HumanMessage, etc.
        return human_indices[-n]

    def _find_safe_cutoff(self, messages: list[AnyMessage], messages_to_keep: int) -> int:
        """Find safe cutoff point that preserves AI/Tool message pairs.

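For orientation, a hedged sketch of the new `min_preserve_turns` knob, not part of the diff; the model object is a placeholder and the numbers are arbitrary, but the parameter names match the signature added above.

```python
# Illustrative only - ChatOpenAI is a placeholder; the tests below use a mock chat model.
from langchain.agents.middleware import SummarizationMiddleware
from langchain_openai import ChatOpenAI

model = ChatOpenAI(model="gpt-4o-mini")

summarizer = SummarizationMiddleware(
    model=model,
    trigger=("messages", 30),  # start summarizing once the history exceeds 30 messages
    keep=("messages", 4),      # keep only the 4 most recent messages...
    min_preserve_turns=1,      # ...but never cut into the most recent complete turn
)
```

If the most recent turn alone would exceed 80% of the model's max input tokens (`_TURN_PRESERVATION_SAFETY_THRESHOLD`), the floor is skipped and the `keep` policy wins; that path is exercised by `test_summarization_middleware_safety_check_skips_large_turn` further down.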
@@ -0,0 +1,474 @@
"""Tests for the `LargeToolResultMiddleware`."""

from __future__ import annotations

import tempfile
from pathlib import Path
from typing import TYPE_CHECKING, cast

import pytest
from langchain_core.language_models import ModelProfile
from langchain_core.language_models.fake_chat_models import FakeChatModel
from langchain_core.messages import (
    AIMessage,
    AnyMessage,
    HumanMessage,
    ToolMessage,
)

from langchain.agents.middleware.large_tool_result import (
    _OFFLOAD_METADATA_KEY,
    LargeToolResultMiddleware,
)
from langchain.agents.middleware.types import ModelRequest, ModelResponse

if TYPE_CHECKING:
    from langgraph.runtime import Runtime


class ProfileChatModel(FakeChatModel):
    """Fake chat model with profile for testing."""

    profile: ModelProfile | None = ModelProfile(max_input_tokens=10000)


def _fake_runtime() -> Runtime:
    """Create a fake runtime for testing."""
    return None  # type: ignore[return-value]


def _make_request(
    messages: list[AnyMessage],
    *,
    model: FakeChatModel | None = None,
) -> ModelRequest:
    """Create a `ModelRequest` for testing."""
    model = model or ProfileChatModel()
    state = cast("dict", {"messages": messages})
    return ModelRequest(
        model=model,
        messages=list(messages),
        system_prompt=None,
        tool_choice=None,
        tools=[],
        response_format=None,
        state=state,
        runtime=_fake_runtime(),
        model_settings={},
    )


def test_initialization_with_defaults() -> None:
    """Test `LargeToolResultMiddleware` initialization with default values."""
    middleware = LargeToolResultMiddleware()
    assert middleware.threshold == ("fraction", 0.10)
    assert middleware.preview_length == 500
    assert middleware.user_temp_dir is None
    assert middleware.cleanup_on_end is True


def test_initialization_with_custom_values() -> None:
    """Test `LargeToolResultMiddleware` initialization with custom values."""
    middleware = LargeToolResultMiddleware(
        threshold=("tokens", 5000),
        preview_length=200,
        cleanup_on_end=False,
    )
    assert middleware.threshold == ("tokens", 5000)
    assert middleware.preview_length == 200
    assert middleware.cleanup_on_end is False


def test_initialization_with_temp_dir() -> None:
    """Test initialization with custom temp directory."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        middleware = LargeToolResultMiddleware(temp_dir=tmp_dir)
        assert middleware.user_temp_dir == Path(tmp_dir)


def test_validation_invalid_fraction_threshold() -> None:
    """Test validation of invalid fractional threshold."""
    with pytest.raises(ValueError, match="Fractional threshold must be between 0 and 1"):
        LargeToolResultMiddleware(threshold=("fraction", 0.0))

    with pytest.raises(ValueError, match="Fractional threshold must be between 0 and 1"):
        LargeToolResultMiddleware(threshold=("fraction", 1.5))


def test_validation_invalid_token_threshold() -> None:
    """Test validation of invalid token threshold."""
    with pytest.raises(ValueError, match="Token threshold must be greater than 0"):
        LargeToolResultMiddleware(threshold=("tokens", 0))

    with pytest.raises(ValueError, match="Token threshold must be greater than 0"):
        LargeToolResultMiddleware(threshold=("tokens", -100))


def test_no_offload_below_threshold() -> None:
    """Test that small tool results are not offloaded."""
    middleware = LargeToolResultMiddleware(threshold=("tokens", 1000))

    tool_message = ToolMessage(content="Small result", tool_call_id="call_123")
    messages: list[AnyMessage] = [
        HumanMessage(content="Hello"),
        AIMessage(
            content="I'll use a tool", tool_calls=[{"name": "test", "args": {}, "id": "call_123"}]
        ),
        tool_message,
    ]

    request = _make_request(messages)

    modified_request = None

    def mock_handler(req: ModelRequest) -> ModelResponse:
        nonlocal modified_request
        modified_request = req
        return ModelResponse(result=[AIMessage(content="mock")])

    middleware.wrap_model_call(request, mock_handler)

    assert modified_request is not None
    # Tool message should be unchanged
    result_tool_msg = modified_request.messages[2]
    assert isinstance(result_tool_msg, ToolMessage)
    assert result_tool_msg.content == "Small result"
    assert not result_tool_msg.response_metadata.get(_OFFLOAD_METADATA_KEY)


def test_offload_large_tool_result() -> None:
    """Test that large tool results are offloaded to disk."""
    middleware = LargeToolResultMiddleware(
        threshold=("tokens", 10),  # Very low threshold
        preview_length=20,
    )

    # Create a large tool result (way over 10 tokens)
    large_content = "A" * 1000
    tool_message = ToolMessage(content=large_content, tool_call_id="call_abc123")
    messages: list[AnyMessage] = [
        HumanMessage(content="Hello"),
        AIMessage(
            content="I'll use a tool",
            tool_calls=[{"name": "test", "args": {}, "id": "call_abc123"}],
        ),
        tool_message,
    ]

    request = _make_request(messages)

    modified_request = None

    def mock_handler(req: ModelRequest) -> ModelResponse:
        nonlocal modified_request
        modified_request = req
        return ModelResponse(result=[AIMessage(content="mock")])

    middleware.wrap_model_call(request, mock_handler)

    assert modified_request is not None
    result_tool_msg = modified_request.messages[2]
    assert isinstance(result_tool_msg, ToolMessage)

    # Content should be truncated
    assert "[TRUNCATED - Full result saved to:" in result_tool_msg.content
    assert "Preview (first 20 chars):" in result_tool_msg.content
    assert "A" * 20 in result_tool_msg.content

    # Metadata should be set
    offload_metadata = result_tool_msg.response_metadata.get(_OFFLOAD_METADATA_KEY)
    assert offload_metadata is not None
    assert offload_metadata["offloaded"] is True
    assert "file_path" in offload_metadata
    assert offload_metadata["original_size_chars"] == 1000

    # Verify file exists and contains original content
    file_path = Path(offload_metadata["file_path"])
    assert file_path.exists()
    assert file_path.read_text() == large_content

    # Cleanup
    middleware.after_agent({"messages": []}, _fake_runtime())


def test_preview_content_preserved() -> None:
    """Test that preview contains the beginning of the content."""
    middleware = LargeToolResultMiddleware(
        threshold=("tokens", 10),
        preview_length=50,
    )

    content = "This is the beginning of a very long message." + "X" * 1000
    tool_message = ToolMessage(content=content, tool_call_id="call_xyz")
    messages: list[AnyMessage] = [tool_message]

    request = _make_request(messages)

    modified_request = None

    def mock_handler(req: ModelRequest) -> ModelResponse:
        nonlocal modified_request
        modified_request = req
        return ModelResponse(result=[AIMessage(content="mock")])

    middleware.wrap_model_call(request, mock_handler)

    assert modified_request is not None
    result_tool_msg = modified_request.messages[0]
    assert isinstance(result_tool_msg, ToolMessage)

    # Preview should contain the beginning of the content
    assert "This is the beginning of a very long message." in result_tool_msg.content

    # Cleanup
    middleware.after_agent({"messages": []}, _fake_runtime())


def test_temp_dir_cleanup_on_agent_end() -> None:
    """Test that temp directory is cleaned up when agent ends."""
    middleware = LargeToolResultMiddleware(
        threshold=("tokens", 10),
        cleanup_on_end=True,
    )

    # Trigger offload to create temp dir
    large_content = "A" * 1000
    tool_message = ToolMessage(content=large_content, tool_call_id="call_cleanup")
    messages: list[AnyMessage] = [tool_message]

    request = _make_request(messages)

    modified_request = None

    def mock_handler(req: ModelRequest) -> ModelResponse:
        nonlocal modified_request
        modified_request = req
        return ModelResponse(result=[AIMessage(content="mock")])

    middleware.wrap_model_call(request, mock_handler)

    assert modified_request is not None
    result_tool_msg = modified_request.messages[0]
    offload_metadata = result_tool_msg.response_metadata.get(_OFFLOAD_METADATA_KEY)
    file_path = Path(offload_metadata["file_path"])
    temp_dir = file_path.parent

    # Verify file exists before cleanup
    assert file_path.exists()
    assert temp_dir.exists()

    # Call after_agent to trigger cleanup
    middleware.after_agent({"messages": []}, _fake_runtime())

    # Temp dir should be cleaned up
    assert not temp_dir.exists()


def test_custom_temp_dir_not_deleted() -> None:
    """Test that user-provided temp dir is not deleted on cleanup."""
    with tempfile.TemporaryDirectory() as user_dir:
        user_path = Path(user_dir)
        middleware = LargeToolResultMiddleware(
            threshold=("tokens", 10),
            temp_dir=user_dir,
            cleanup_on_end=True,
        )

        # Trigger offload
        large_content = "A" * 1000
        tool_message = ToolMessage(content=large_content, tool_call_id="call_custom")
        messages: list[AnyMessage] = [tool_message]

        request = _make_request(messages)

        def mock_handler(req: ModelRequest) -> ModelResponse:
            return ModelResponse(result=[AIMessage(content="mock")])

        middleware.wrap_model_call(request, mock_handler)

        # Call after_agent
        middleware.after_agent({"messages": []}, _fake_runtime())

        # User-provided dir should still exist
        assert user_path.exists()


def test_multiple_large_results_in_turn() -> None:
    """Test handling of multiple large tool results in parallel tool calls."""
    middleware = LargeToolResultMiddleware(
        threshold=("tokens", 10),
        preview_length=10,
    )

    messages: list[AnyMessage] = [
        HumanMessage(content="Process these files"),
        AIMessage(
            content="I'll process all files",
            tool_calls=[
                {"name": "read_file", "args": {"path": "file1.txt"}, "id": "call_1"},
                {"name": "read_file", "args": {"path": "file2.txt"}, "id": "call_2"},
                {"name": "read_file", "args": {"path": "file3.txt"}, "id": "call_3"},
            ],
        ),
        ToolMessage(content="X" * 500, tool_call_id="call_1"),
        ToolMessage(content="Y" * 500, tool_call_id="call_2"),
        ToolMessage(content="Z" * 500, tool_call_id="call_3"),
    ]

    request = _make_request(messages)

    modified_request = None

    def mock_handler(req: ModelRequest) -> ModelResponse:
        nonlocal modified_request
        modified_request = req
        return ModelResponse(result=[AIMessage(content="mock")])

    middleware.wrap_model_call(request, mock_handler)

    assert modified_request is not None

    # All three tool messages should be offloaded
    for i in range(2, 5):
        tool_msg = modified_request.messages[i]
        assert isinstance(tool_msg, ToolMessage)
        assert "[TRUNCATED - Full result saved to:" in tool_msg.content
        offload_metadata = tool_msg.response_metadata.get(_OFFLOAD_METADATA_KEY)
        assert offload_metadata is not None
        assert offload_metadata["offloaded"] is True

    # Each should have a unique file
    file_paths = set()
    for i in range(2, 5):
        tool_msg = modified_request.messages[i]
        offload_metadata = tool_msg.response_metadata.get(_OFFLOAD_METADATA_KEY)
        file_paths.add(offload_metadata["file_path"])

    assert len(file_paths) == 3  # Three unique files

    # Cleanup
    middleware.after_agent({"messages": []}, _fake_runtime())


def test_already_offloaded_skipped() -> None:
    """Test that already offloaded messages are not re-processed."""
    middleware = LargeToolResultMiddleware(
        threshold=("tokens", 10),
    )

    # Pre-offloaded message
    tool_message = ToolMessage(
        content="[TRUNCATED - Full result saved to: /path/to/file.txt]\n\nPreview...",
        tool_call_id="call_already",
        response_metadata={
            _OFFLOAD_METADATA_KEY: {
                "offloaded": True,
                "file_path": "/path/to/file.txt",
                "original_size_chars": 10000,
            }
        },
    )

    messages: list[AnyMessage] = [tool_message]
    request = _make_request(messages)

    modified_request = None

    def mock_handler(req: ModelRequest) -> ModelResponse:
        nonlocal modified_request
        modified_request = req
        return ModelResponse(result=[AIMessage(content="mock")])

    middleware.wrap_model_call(request, mock_handler)

    assert modified_request is not None
    result_tool_msg = modified_request.messages[0]
    assert isinstance(result_tool_msg, ToolMessage)

    # Content should be unchanged (not re-processed)
    assert result_tool_msg.content == tool_message.content


def test_fraction_threshold_with_model_profile() -> None:
    """Test fractional threshold calculation with model profile."""
    middleware = LargeToolResultMiddleware(
        threshold=("fraction", 0.10),  # 10% of 10000 = 1000 tokens
    )

    # Content that is about 1200 tokens (4800 chars / 4 chars per token estimate)
    # This should exceed 10% of 10000 tokens
    large_content = "A" * 4800
    tool_message = ToolMessage(content=large_content, tool_call_id="call_frac")
    messages: list[AnyMessage] = [tool_message]

    # Use model with profile
    model = ProfileChatModel()
    request = _make_request(messages, model=model)

    modified_request = None

    def mock_handler(req: ModelRequest) -> ModelResponse:
        nonlocal modified_request
        modified_request = req
        return ModelResponse(result=[AIMessage(content="mock")])

    middleware.wrap_model_call(request, mock_handler)

    assert modified_request is not None
    result_tool_msg = modified_request.messages[0]
    assert isinstance(result_tool_msg, ToolMessage)

    # Should be offloaded due to exceeding 10% threshold
    assert "[TRUNCATED - Full result saved to:" in result_tool_msg.content

    # Cleanup
    middleware.after_agent({"messages": []}, _fake_runtime())


async def test_async_wrap_model_call() -> None:
    """Test async version of `wrap_model_call`."""
    middleware = LargeToolResultMiddleware(
        threshold=("tokens", 10),
        preview_length=20,
    )

    large_content = "A" * 1000
    tool_message = ToolMessage(content=large_content, tool_call_id="call_async")
    messages: list[AnyMessage] = [tool_message]

    request = _make_request(messages)

    modified_request = None

    async def mock_handler(req: ModelRequest) -> ModelResponse:
        nonlocal modified_request
        modified_request = req
        return ModelResponse(result=[AIMessage(content="mock")])

    await middleware.awrap_model_call(request, mock_handler)

    assert modified_request is not None
    result_tool_msg = modified_request.messages[0]
    assert isinstance(result_tool_msg, ToolMessage)
    assert "[TRUNCATED - Full result saved to:" in result_tool_msg.content

    # Cleanup
    await middleware.aafter_agent({"messages": []}, _fake_runtime())


def test_empty_messages_passthrough() -> None:
    """Test that empty messages list is handled correctly."""
    middleware = LargeToolResultMiddleware(threshold=("tokens", 10))

    request = _make_request([])

    called = False

    def mock_handler(req: ModelRequest) -> ModelResponse:
        nonlocal called
        called = True
        return ModelResponse(result=[AIMessage(content="mock")])

    middleware.wrap_model_call(request, mock_handler)

    assert called
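The tests above read the offload metadata directly; the sketch below shows the same fields used by a hypothetical helper that pulls the full payload back off disk. Illustrative only - the metadata key and field names are the ones defined in `large_tool_result.py`, everything else is an assumption.

```python
from __future__ import annotations

from pathlib import Path

from langchain_core.messages import ToolMessage

_OFFLOAD_METADATA_KEY = "large_tool_result_offloaded"  # same constant as the middleware


def load_offloaded_result(msg: ToolMessage) -> str | None:
    """Return the full offloaded content for a truncated `ToolMessage`, if any."""
    meta = msg.response_metadata.get(_OFFLOAD_METADATA_KEY)
    if not meta or not meta.get("offloaded"):
        return None  # never offloaded: the message content is already complete
    return Path(meta["file_path"]).read_text(encoding="utf-8")
```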
@@ -1014,3 +1014,387 @@ def test_create_summary_uses_get_buffer_string_format() -> None:
        f"str(messages) should produce significantly more tokens. "
        f"Got ratio {str_ratio:.2f}x (expected > 1.5)"
    )


def test_summarization_middleware_min_preserve_turns_default() -> None:
    """Test that `min_preserve_turns` defaults to 1 and keeps at least one turn."""
    middleware = SummarizationMiddleware(
        model=MockChatModel(),
        trigger=("messages", 5),
        keep=("messages", 1),  # Would keep only 1 message (very aggressive)
    )

    # Default min_preserve_turns should be 1
    assert middleware.min_preserve_turns == 1

    # Mock token counter
    def mock_token_counter(messages: list[AnyMessage]) -> int:
        return 1500

    middleware.token_counter = mock_token_counter

    # Messages: 3 turns
    messages: list[AnyMessage] = [
        HumanMessage(content="Turn 1"),
        AIMessage(content="Response 1"),
        HumanMessage(content="Turn 2"),
        AIMessage(content="Response 2"),
        HumanMessage(content="Turn 3"),
        AIMessage(content="Response 3"),
    ]

    state = {"messages": messages}
    result = middleware.before_model(state, None)

    # Should preserve at least 1 turn (last HumanMessage at index 4)
    assert result is not None
    preserved = result["messages"][2:]  # After RemoveMessage and summary
    # min_preserve_turns=1 means we preserve from the last HumanMessage (index 4)
    # which includes Turn 3 (HumanMessage + AIMessage = 2 messages)
    assert len(preserved) >= 2
    # The last turn starts at the last HumanMessage
    assert any(isinstance(m, HumanMessage) and m.content == "Turn 3" for m in preserved)


def test_summarization_middleware_min_preserve_turns_with_massive_tool_result() -> None:
    """Test that `min_preserve_turns` keeps turn even if tool result is massive."""
    middleware = SummarizationMiddleware(
        model=MockChatModel(),
        trigger=("messages", 4),
        keep=("messages", 1),  # Would only keep 1 message
        min_preserve_turns=1,
    )

    def mock_token_counter(messages: list[AnyMessage]) -> int:
        return 10000  # Always above trigger

    middleware.token_counter = mock_token_counter

    # Turn 2 has a massive tool result
    messages: list[AnyMessage] = [
        HumanMessage(content="Turn 1"),
        AIMessage(content="Response 1"),
        HumanMessage(content="Turn 2 - Read file"),
        AIMessage(
            content="I'll read the file",
            tool_calls=[{"name": "read", "args": {}, "id": "call_1"}],
        ),
        ToolMessage(content="X" * 10000, tool_call_id="call_1"),  # Massive result
    ]

    state = {"messages": messages}
    result = middleware.before_model(state, None)

    # Should preserve turn 2 (HumanMessage + AIMessage + ToolMessage)
    assert result is not None
    preserved = result["messages"][2:]  # After RemoveMessage and summary
    # Last turn starts at index 2 (HumanMessage "Turn 2 - Read file")
    assert len(preserved) == 3
    assert isinstance(preserved[0], HumanMessage)
    assert preserved[0].content == "Turn 2 - Read file"


def test_summarization_middleware_min_preserve_turns_zero_disables() -> None:
    """Test that `min_preserve_turns=0` disables the floor."""
    middleware = SummarizationMiddleware(
        model=MockChatModel(),
        trigger=("messages", 3),
        keep=("messages", 1),
        min_preserve_turns=0,  # Disabled
    )

    def mock_token_counter(messages: list[AnyMessage]) -> int:
        return 1500

    middleware.token_counter = mock_token_counter

    messages: list[AnyMessage] = [
        HumanMessage(content="Turn 1"),
        AIMessage(content="Response 1"),
        HumanMessage(content="Turn 2"),
    ]

    state = {"messages": messages}
    result = middleware.before_model(state, None)

    # With min_preserve_turns=0, should respect keep=1 strictly
    assert result is not None
    preserved = result["messages"][2:]
    assert len(preserved) == 1  # Only keeps 1 message


def test_summarization_middleware_cutoff_never_past_last_turn() -> None:
    """Test that cutoff never goes past the last turn, even with tight token budget."""
    middleware = SummarizationMiddleware(
        model=MockChatModel(),
        trigger=("messages", 5),
        keep=("messages", 1),  # Very aggressive - would keep only 1 message
        min_preserve_turns=1,
    )

    def mock_token_counter(messages: list[AnyMessage]) -> int:
        return 5000

    middleware.token_counter = mock_token_counter

    # Create messages where keep=1 would want to summarize almost everything
    messages: list[AnyMessage] = [
        HumanMessage(content="Old turn 1"),
        AIMessage(content="Old response 1"),
        HumanMessage(content="Old turn 2"),
        AIMessage(content="Old response 2"),
        HumanMessage(content="Last turn"),
        AIMessage(content="Last response"),
    ]

    state = {"messages": messages}
    result = middleware.before_model(state, None)

    assert result is not None
    preserved = result["messages"][2:]

    # Should preserve at least the last turn (HumanMessage + AIMessage)
    # min_preserve_turns=1 overrides keep=1 to preserve the full last turn
    assert len(preserved) >= 2
    # Find the last HumanMessage in preserved
    human_msgs = [m for m in preserved if isinstance(m, HumanMessage)]
    assert len(human_msgs) >= 1
    assert human_msgs[-1].content == "Last turn"


def test_summarization_middleware_find_last_n_turns_basic() -> None:
    """Test `_find_last_n_turns` with basic message sequences."""
    middleware = SummarizationMiddleware(model=MockChatModel(), trigger=("messages", 10))

    messages: list[AnyMessage] = [
        HumanMessage(content="Turn 1"),  # index 0
        AIMessage(content="Response 1"),  # index 1
        HumanMessage(content="Turn 2"),  # index 2
        AIMessage(content="Response 2"),  # index 3
        HumanMessage(content="Turn 3"),  # index 4
        AIMessage(content="Response 3"),  # index 5
    ]

    # n=1: should return index 4 (last HumanMessage)
    assert middleware._find_last_n_turns(messages, 1) == 4

    # n=2: should return index 2 (second-to-last HumanMessage)
    assert middleware._find_last_n_turns(messages, 2) == 2

    # n=3: should return index 0 (third-to-last HumanMessage)
    assert middleware._find_last_n_turns(messages, 3) == 0

    # n=4: fewer turns than requested, return 0
    assert middleware._find_last_n_turns(messages, 4) == 0


def test_summarization_middleware_find_last_n_turns_with_tools() -> None:
    """Test `_find_last_n_turns` with tool messages in turns."""
    middleware = SummarizationMiddleware(model=MockChatModel(), trigger=("messages", 10))

    messages: list[AnyMessage] = [
        HumanMessage(content="Turn 1"),  # index 0
        AIMessage(content="Response 1"),  # index 1
        HumanMessage(content="Turn 2"),  # index 2
        AIMessage(
            content="Using tools",
            tool_calls=[{"name": "tool", "args": {}, "id": "call_1"}],
        ),  # index 3
        ToolMessage(content="Result 1", tool_call_id="call_1"),  # index 4
        ToolMessage(content="Result 2", tool_call_id="call_2"),  # index 5
        AIMessage(content="Final response"),  # index 6
    ]

    # n=1: should return index 2 (last turn starts at "Turn 2")
    assert middleware._find_last_n_turns(messages, 1) == 2

    # n=2: should return index 0 (first turn)
    assert middleware._find_last_n_turns(messages, 2) == 0


def test_summarization_middleware_find_last_n_turns_no_human_messages() -> None:
    """Test `_find_last_n_turns` when there are no `HumanMessage`s."""
    middleware = SummarizationMiddleware(model=MockChatModel(), trigger=("messages", 10))

    messages: list[AnyMessage] = [
        AIMessage(content="AI only 1"),
        AIMessage(content="AI only 2"),
    ]

    # No turns defined by HumanMessages - preserve all
    assert middleware._find_last_n_turns(messages, 1) == 0


def test_summarization_middleware_find_last_n_turns_empty() -> None:
    """Test `_find_last_n_turns` with empty messages."""
    middleware = SummarizationMiddleware(model=MockChatModel(), trigger=("messages", 10))

    # Empty messages
    assert middleware._find_last_n_turns([], 1) == 0

    # n=0 should return len(messages)
    messages: list[AnyMessage] = [HumanMessage(content="Test")]
    assert middleware._find_last_n_turns(messages, 0) == 1


def test_summarization_middleware_turn_boundary_detection() -> None:
    """Test that turn boundaries are correctly identified at `HumanMessage`s."""
    middleware = SummarizationMiddleware(
        model=MockChatModel(),
        trigger=("messages", 5),
        keep=("messages", 2),
        min_preserve_turns=1,
    )

    def mock_token_counter(messages: list[AnyMessage]) -> int:
        return 1500

    middleware.token_counter = mock_token_counter

    # Complex sequence with tools
    messages: list[AnyMessage] = [
        HumanMessage(content="Question 1"),  # Turn 1 starts
        AIMessage(content="Answer 1"),
        HumanMessage(content="Question 2"),  # Turn 2 starts
        AIMessage(
            content="Let me check",
            tool_calls=[{"name": "search", "args": {}, "id": "call_1"}],
        ),
        ToolMessage(content="Search result", tool_call_id="call_1"),
        AIMessage(content="Based on search, here's my answer"),
    ]

    state = {"messages": messages}
    result = middleware.before_model(state, None)

    assert result is not None
    preserved = result["messages"][2:]

    # Last turn (Turn 2) should be preserved completely
    # It starts at index 2 (HumanMessage "Question 2") and includes:
    # - HumanMessage (Question 2)
    # - AIMessage (Let me check)
    # - ToolMessage (Search result)
    # - AIMessage (Based on search...)
    assert len(preserved) == 4
    assert isinstance(preserved[0], HumanMessage)
    assert preserved[0].content == "Question 2"


def test_summarization_middleware_multiple_turns_preserved() -> None:
    """Test `min_preserve_turns=2` keeps 2 complete turns."""
    middleware = SummarizationMiddleware(
        model=MockChatModel(),
        trigger=("messages", 6),
        keep=("messages", 1),  # Would only keep 1 message
        min_preserve_turns=2,  # But we want 2 turns minimum
    )

    def mock_token_counter(messages: list[AnyMessage]) -> int:
        return 2000

    middleware.token_counter = mock_token_counter

    messages: list[AnyMessage] = [
        HumanMessage(content="Turn 1"),
        AIMessage(content="Response 1"),
        HumanMessage(content="Turn 2"),
        AIMessage(content="Response 2"),
        HumanMessage(content="Turn 3"),
        AIMessage(content="Response 3"),
    ]

    state = {"messages": messages}
    result = middleware.before_model(state, None)

    assert result is not None
    preserved = result["messages"][2:]

    # Should preserve turns 2 and 3 (starting from index 2)
    assert len(preserved) == 4  # Turn 2 (2 msgs) + Turn 3 (2 msgs)
    human_msgs = [m for m in preserved if isinstance(m, HumanMessage)]
    assert len(human_msgs) == 2
    assert human_msgs[0].content == "Turn 2"
    assert human_msgs[1].content == "Turn 3"


def test_summarization_middleware_safety_check_skips_large_turn() -> None:
    """Test that `min_preserve_turns` is skipped if turn exceeds 80% of context."""
    # Model with 1000 max tokens
    model = MockChatModel()
    model.profile = {"max_input_tokens": 1000}

    middleware = SummarizationMiddleware(
        model=model,
        trigger=("messages", 4),
        keep=("messages", 2),  # Would keep 2 messages
        min_preserve_turns=1,  # Would want to keep full turn
    )

    # Token counter that makes the last turn very expensive
    def mock_token_counter(messages: list[AnyMessage]) -> int:
        # Make the last turn (messages from index 2 onward) exceed 80% of 1000 = 800 tokens
        # If checking all messages, return high count to trigger summarization
        if len(messages) >= 4:
            return 1500  # Total exceeds trigger
        if len(messages) == 2:
            return 900  # Last turn alone exceeds 80% threshold
        return len(messages) * 100

    middleware.token_counter = mock_token_counter

    messages: list[AnyMessage] = [
        HumanMessage(content="Turn 1"),
        AIMessage(content="Response 1"),
        HumanMessage(content="Turn 2 with massive content"),  # Last turn starts here
        AIMessage(content="X" * 10000),  # Massive response
    ]

    state = {"messages": messages}
    result = middleware.before_model(state, None)

    assert result is not None
    preserved = result["messages"][2:]

    # Safety check should kick in - min_preserve_turns skipped, falls back to keep=2
    # This means we keep only 2 messages (the keep policy) not the full turn
    assert len(preserved) == 2


def test_summarization_middleware_safety_check_allows_small_turn() -> None:
    """Test that `min_preserve_turns` works normally when turn is small enough."""
    # Model with 10000 max tokens
    model = MockChatModel()
    model.profile = {"max_input_tokens": 10000}

    middleware = SummarizationMiddleware(
        model=model,
        trigger=("messages", 4),
        keep=("messages", 1),  # Would keep 1 message
        min_preserve_turns=1,  # Want to keep full turn
    )

    def mock_token_counter(messages: list[AnyMessage]) -> int:
        # Last turn is small (well under 80% of 10000 = 8000)
        return len(messages) * 100

    middleware.token_counter = mock_token_counter

    messages: list[AnyMessage] = [
        HumanMessage(content="Turn 1"),
        AIMessage(content="Response 1"),
        HumanMessage(content="Turn 2"),  # Last turn starts here
        AIMessage(content="Response 2"),
    ]

    state = {"messages": messages}
    result = middleware.before_model(state, None)

    assert result is not None
    preserved = result["messages"][2:]

    # Turn is small enough, so min_preserve_turns applies
    # Should keep full turn (2 messages) not just keep=1
    assert len(preserved) == 2
    assert isinstance(preserved[0], HumanMessage)
    assert preserved[0].content == "Turn 2"
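Taken together the two changes are complementary: the offload middleware keeps giant tool payloads out of the token count, and `min_preserve_turns` stops the summarizer from erasing the turn the agent is still working on. A combined setup might look like the sketch below; as before, the `create_agent` wiring, the model object, and the specific numbers are assumptions rather than part of this commit.

```python
# Illustrative combination of both middleware from this commit; parameters are arbitrary.
from langchain.agents import create_agent
from langchain.agents.middleware import (
    LargeToolResultMiddleware,
    SummarizationMiddleware,
)
from langchain_openai import ChatOpenAI

model = ChatOpenAI(model="gpt-4o-mini")  # placeholder model

agent = create_agent(
    model=model,
    tools=[],  # add real tools here
    middleware=[
        # Swap oversized tool results for previews so they don't inflate token counts.
        LargeToolResultMiddleware(threshold=("fraction", 0.10)),
        # Summarize long histories, but always keep the latest complete turn intact.
        SummarizationMiddleware(
            model=model,
            trigger=("messages", 40),
            keep=("messages", 20),
            min_preserve_turns=1,
        ),
    ],
)
```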