Compare commits

...

1 Commit

Author: Mason Daugherty
SHA1: 5ce81dff22
Message: feat(langchain): LargeToolResultMiddleware
Date: 2026-01-08 14:39:51 -05:00
5 changed files with 1275 additions and 8 deletions

View File

@@ -9,6 +9,7 @@ from .human_in_the_loop import (
HumanInTheLoopMiddleware,
InterruptOnConfig,
)
from .large_tool_result import LargeToolResultMiddleware
from .model_call_limit import ModelCallLimitMiddleware
from .model_fallback import ModelFallbackMiddleware
from .model_retry import ModelRetryMiddleware
@@ -54,6 +55,7 @@ __all__ = [
"InterruptOnConfig",
"LLMToolEmulator",
"LLMToolSelectorMiddleware",
"LargeToolResultMiddleware",
"ModelCallLimitMiddleware",
"ModelFallbackMiddleware",
"ModelRequest",

View File

@@ -0,0 +1,327 @@
"""Large tool result middleware.
Offloads large tool results to the filesystem to prevent context overflow.
Results exceeding a configurable threshold are written to temporary files
with a truncated preview kept in the message.
"""
from __future__ import annotations
import contextlib
import tempfile
from collections.abc import Awaitable, Callable, Mapping, Sequence
from copy import deepcopy
from pathlib import Path
from typing import TYPE_CHECKING, Annotated, Any, Literal
from langchain_core.messages import AnyMessage, ToolMessage
from typing_extensions import NotRequired, override
from langchain.agents.middleware.types import (
AgentMiddleware,
AgentState,
ModelCallResult,
ModelRequest,
ModelResponse,
PrivateStateAttr,
)
if TYPE_CHECKING:
from langchain_core.language_models.chat_models import BaseChatModel
from langgraph.runtime import Runtime
__all__ = ["LargeToolResultMiddleware"]
_DEFAULT_THRESHOLD_FRACTION = 0.10
_DEFAULT_PREVIEW_LENGTH = 500
_OFFLOAD_METADATA_KEY = "large_tool_result_offloaded"
ContextFraction = tuple[Literal["fraction"], float]
ContextTokens = tuple[Literal["tokens"], int]
ContextSize = ContextFraction | ContextTokens
class LargeToolResultState(AgentState):
"""Extended state for large tool result middleware."""
offloaded_results_dir: NotRequired[Annotated[str | None, PrivateStateAttr]]
"""Directory where large tool results are stored."""
class LargeToolResultMiddleware(AgentMiddleware[LargeToolResultState, Any]):
"""Offloads large tool results to filesystem to prevent context overflow.
This middleware monitors tool result sizes and automatically writes results
exceeding a threshold to temporary files, keeping a truncated preview in the
message. This prevents massive tool outputs from triggering aggressive
summarization or exceeding model context limits.
The middleware uses `wrap_model_call` to process messages before each model
invocation, ensuring large results are offloaded before token counting occurs.
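Example (illustrative sketch; assumes `create_agent` from `langchain.agents` and a
placeholder `search_tool`):
```python
from langchain.agents import create_agent
from langchain.agents.middleware import LargeToolResultMiddleware

# Offload any tool result larger than 5000 tokens; keep a short preview inline.
agent = create_agent(
    model="openai:gpt-4o",
    tools=[search_tool],  # `search_tool` stands in for any tool
    middleware=[LargeToolResultMiddleware(threshold=("tokens", 5000))],
)
```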
"""
state_schema = LargeToolResultState
tools: Sequence = ()
def __init__(
self,
*,
threshold: ContextSize = ("fraction", _DEFAULT_THRESHOLD_FRACTION),
preview_length: int = _DEFAULT_PREVIEW_LENGTH,
temp_dir: Path | str | None = None,
cleanup_on_end: bool = True,
) -> None:
"""Initialize large tool result middleware.
Args:
threshold: Size threshold that triggers offloading to disk.
Provide a tuple specifying the threshold type:
- `('fraction', 0.10)`: Offload if result exceeds 10% of model's
max input tokens (default)
- `('tokens', 5000)`: Offload if result exceeds 5000 tokens
preview_length: Number of characters to keep as preview in the message.
Defaults to 500 characters.
temp_dir: Directory for storing offloaded results.
If `None` (default), creates a temporary directory that is cleaned
up when the agent session ends.
If provided, uses the specified directory and does not delete it
on cleanup (user-managed).
cleanup_on_end: Whether to clean up the temp directory when the agent
session ends.
Only applies when `temp_dir` is `None` (auto-created directory).
Defaults to `True`.
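Example (illustrative values; the path is only a placeholder):
```python
# Offload results over 5000 tokens, keep a 200-character preview, and write
# offloaded files into a caller-managed directory that is never auto-deleted.
LargeToolResultMiddleware(
    threshold=("tokens", 5000),
    preview_length=200,
    temp_dir="/tmp/tool-results",
)
```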
"""
super().__init__()
self.threshold = self._validate_threshold(threshold)
self.preview_length = preview_length
self.user_temp_dir = Path(temp_dir) if temp_dir else None
self.cleanup_on_end = cleanup_on_end
self._temp_dir: tempfile.TemporaryDirectory[str] | None = None
def _validate_threshold(self, threshold: ContextSize) -> ContextSize:
"""Validate threshold configuration."""
kind, value = threshold
if kind == "fraction":
if not 0 < value <= 1:
msg = f"Fractional threshold must be between 0 and 1, got {value}."
raise ValueError(msg)
elif kind == "tokens":
if value <= 0:
msg = f"Token threshold must be greater than 0, got {value}."
raise ValueError(msg)
else:
msg = f"Unsupported threshold type: {kind}"
raise ValueError(msg)
return threshold
def _get_threshold_tokens(self, model: BaseChatModel | None) -> int:
"""Calculate threshold in tokens based on configuration."""
kind, value = self.threshold
if kind == "tokens":
return int(value)
# Fractional threshold - need model profile
if model is None:
# No model available - assume a default 100k-token context window
return int(value * 100_000)
max_input_tokens = self._get_model_max_tokens(model)
if max_input_tokens is None:
# Profile lacks max_input_tokens - assume a default 100k-token context window
return int(value * 100_000)
return int(max_input_tokens * value)
def _get_model_max_tokens(self, model: BaseChatModel) -> int | None:
"""Get model's max input tokens from profile."""
try:
profile = model.profile
except AttributeError:
return None
if not isinstance(profile, Mapping):
return None
max_input_tokens = profile.get("max_input_tokens")
return max_input_tokens if isinstance(max_input_tokens, int) else None
def _get_or_create_temp_dir(self) -> Path:
"""Get or create the temporary directory for storing results."""
if self.user_temp_dir is not None:
self.user_temp_dir.mkdir(parents=True, exist_ok=True)
return self.user_temp_dir
if self._temp_dir is not None:
return Path(self._temp_dir.name)
# Create new temp directory
self._temp_dir = tempfile.TemporaryDirectory(prefix="langchain-large-results-")
return Path(self._temp_dir.name)
def _estimate_tokens(self, content: str) -> int:
"""Estimate token count for a string."""
# Approximate: ~4 characters per token
return len(content) // 4
def _should_offload(self, content: str, threshold_tokens: int) -> bool:
"""Determine if content should be offloaded based on size."""
estimated_tokens = self._estimate_tokens(content)
return estimated_tokens > threshold_tokens
def _offload_content(self, tool_call_id: str, content: str) -> str:
"""Write content to file and return the file path."""
temp_dir = self._get_or_create_temp_dir()
# Sanitize tool_call_id for filename
safe_id = "".join(c if c.isalnum() or c in "-_" else "_" for c in tool_call_id)
file_path = temp_dir / f"{safe_id}.txt"
file_path.write_text(content, encoding="utf-8")
return str(file_path)
def _create_truncated_message(self, original_content: str, file_path: str) -> str:
"""Create a truncated message with file reference."""
preview = original_content[: self.preview_length]
if len(original_content) > self.preview_length:
preview += "..."
return (
f"[TRUNCATED - Full result saved to: {file_path}]\n\n"
f"Preview (first {self.preview_length} chars):\n{preview}"
)
def _process_messages(
self,
messages: list[AnyMessage],
threshold_tokens: int,
) -> list[AnyMessage]:
"""Process messages and offload large tool results."""
processed: list[AnyMessage] = []
for msg in messages:
if not isinstance(msg, ToolMessage):
processed.append(msg)
continue
# Check if already offloaded
if msg.response_metadata.get(_OFFLOAD_METADATA_KEY):
processed.append(msg)
continue
# Get content as string
content = msg.content
if not isinstance(content, str):
# Multimodal or other non-string content - stringify for the size check
content = str(content)
# Check if should offload
if not self._should_offload(content, threshold_tokens):
processed.append(msg)
continue
# Offload to file
tool_call_id = msg.tool_call_id or f"unknown_{id(msg)}"
file_path = self._offload_content(tool_call_id, content)
# Create truncated message
truncated_content = self._create_truncated_message(content, file_path)
# Create new message with truncated content
new_msg = msg.model_copy(
update={
"content": truncated_content,
"response_metadata": {
**msg.response_metadata,
_OFFLOAD_METADATA_KEY: {
"offloaded": True,
"file_path": file_path,
"original_size_chars": len(content),
},
},
}
)
processed.append(new_msg)
return processed
@override
def before_agent(
self, state: LargeToolResultState, runtime: Runtime[Any]
) -> dict[str, Any] | None:
"""Initialize temp directory tracking in state."""
# Just return None - we'll create temp dir lazily when needed
return None
@override
async def abefore_agent(
self, state: LargeToolResultState, runtime: Runtime[Any]
) -> dict[str, Any] | None:
"""Initialize temp directory tracking in state (async)."""
return None
@override
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
"""Process messages before model call, offloading large tool results."""
if not request.messages:
return handler(request)
threshold_tokens = self._get_threshold_tokens(request.model)
processed_messages = self._process_messages(
deepcopy(list(request.messages)),
threshold_tokens,
)
return handler(request.override(messages=processed_messages))
@override
async def awrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ModelCallResult:
"""Process messages before model call, offloading large tool results (async)."""
if not request.messages:
return await handler(request)
threshold_tokens = self._get_threshold_tokens(request.model)
processed_messages = self._process_messages(
deepcopy(list(request.messages)),
threshold_tokens,
)
return await handler(request.override(messages=processed_messages))
@override
def after_agent(
self, state: LargeToolResultState, runtime: Runtime[Any]
) -> dict[str, Any] | None:
"""Clean up temp directory on agent completion."""
if self.cleanup_on_end and self._temp_dir is not None:
with contextlib.suppress(Exception):
self._temp_dir.cleanup()
self._temp_dir = None
return None
@override
async def aafter_agent(
self, state: LargeToolResultState, runtime: Runtime[Any]
) -> dict[str, Any] | None:
"""Clean up temp directory on agent completion (async)."""
if self.cleanup_on_end and self._temp_dir is not None:
with contextlib.suppress(Exception):
self._temp_dir.cleanup()
self._temp_dir = None
return None

View File

@@ -61,6 +61,8 @@ Messages to summarize:
_DEFAULT_MESSAGES_TO_KEEP = 20
_DEFAULT_TRIM_TOKEN_LIMIT = 4000
_DEFAULT_FALLBACK_MESSAGE_COUNT = 15
_DEFAULT_MIN_PRESERVE_TURNS = 1
_TURN_PRESERVATION_SAFETY_THRESHOLD = 0.8  # Skip min_preserve_turns if the preserved turns exceed this fraction of the model's max input tokens
ContextFraction = tuple[Literal["fraction"], float]
"""Fraction of model's maximum input tokens.
@@ -147,6 +149,7 @@ class SummarizationMiddleware(AgentMiddleware):
*,
trigger: ContextSize | list[ContextSize] | None = None,
keep: ContextSize = ("messages", _DEFAULT_MESSAGES_TO_KEEP),
min_preserve_turns: int = _DEFAULT_MIN_PRESERVE_TURNS,
token_counter: TokenCounter = count_tokens_approximately,
summary_prompt: str = DEFAULT_SUMMARY_PROMPT,
trim_tokens_to_summarize: int | None = _DEFAULT_TRIM_TOKEN_LIMIT,
@@ -200,6 +203,17 @@ class SummarizationMiddleware(AgentMiddleware):
# Keep the most recent 30% of the model's max input tokens
("fraction", 0.3)
```
min_preserve_turns: Minimum number of complete turns to always preserve,
regardless of the `keep` policy.
A turn consists of a `HumanMessage` followed by any `AIMessage` responses
and their associated `ToolMessage` results. This floor prevents aggressive
summarization from removing all recent context.
Defaults to `1`, ensuring the most recent complete turn is always kept.
Set to `0` to disable this floor and rely solely on the `keep` policy.
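For example (illustrative; `model` is any chat model):
```python
# Keep at least the last two complete turns, even with an aggressive keep policy
SummarizationMiddleware(
    model=model,
    keep=("messages", 5),
    min_preserve_turns=2,
)
```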
token_counter: Function to count tokens in messages.
summary_prompt: Prompt template for generating summaries.
trim_tokens_to_summarize: Maximum tokens to keep when preparing messages for
@@ -248,6 +262,7 @@ class SummarizationMiddleware(AgentMiddleware):
self._trigger_conditions = trigger_conditions
self.keep = self._validate_context_size(keep, "keep")
self.min_preserve_turns = min_preserve_turns
if token_counter is count_tokens_approximately:
self.token_counter = _get_approximate_token_counter(self.model)
else:
@@ -346,16 +361,45 @@ class SummarizationMiddleware(AgentMiddleware):
return False
def _determine_cutoff_index(self, messages: list[AnyMessage]) -> int:
"""Choose cutoff index respecting retention configuration."""
"""Choose cutoff index respecting retention configuration and `min_preserve_turns`.
The cutoff is determined by the `keep` policy, but is further constrained by
`min_preserve_turns` to ensure at least N complete turns are always preserved.
Safety check: If preserving the minimum turns would exceed 80% of the model's
context window, fall back to the `keep` policy to avoid model rejection.
"""
kind, value = self.keep
if kind in {"tokens", "fraction"}:
token_based_cutoff = self._find_token_based_cutoff(messages)
if token_based_cutoff is not None:
return token_based_cutoff
# None cutoff -> model profile data not available (caught in __init__ but
# here for safety), fallback to message count
return self._find_safe_cutoff(messages, _DEFAULT_MESSAGES_TO_KEEP)
return self._find_safe_cutoff(messages, cast("int", value))
calculated_cutoff = self._find_token_based_cutoff(messages)
if calculated_cutoff is None:
# None cutoff -> model profile data not available (caught in __init__ but
# here for safety), fallback to message count
calculated_cutoff = self._find_safe_cutoff(messages, _DEFAULT_MESSAGES_TO_KEEP)
else:
calculated_cutoff = self._find_safe_cutoff(messages, cast("int", value))
# Enforce minimum turn preservation floor (with safety check)
if self.min_preserve_turns > 0:
min_cutoff = self._find_last_n_turns(messages, self.min_preserve_turns)
# Safety check: ensure preserving these turns won't exceed context limit
if min_cutoff < calculated_cutoff:
preserved_messages = messages[min_cutoff:]
preserved_tokens = self.token_counter(preserved_messages)
# Check against safety threshold (80% of model's max tokens)
max_input_tokens = self._get_profile_limits()
if max_input_tokens is not None:
safety_limit = int(max_input_tokens * _TURN_PRESERVATION_SAFETY_THRESHOLD)
if preserved_tokens > safety_limit:
# Turn is too large - fall back to keep policy
return calculated_cutoff
# Safe to preserve the turn
calculated_cutoff = min_cutoff
return calculated_cutoff
def _find_token_based_cutoff(self, messages: list[AnyMessage]) -> int | None:
"""Find cutoff index based on target token retention."""
@@ -461,6 +505,42 @@ class SummarizationMiddleware(AgentMiddleware):
return messages_to_summarize, preserved_messages
def _find_last_n_turns(self, messages: list[AnyMessage], n: int) -> int:
"""Find the cutoff index that preserves the last N complete turns.
A 'turn' is defined as starting with a `HumanMessage` and including all
subsequent `AIMessage` responses and their associated `ToolMessage` results,
until the next `HumanMessage` begins a new turn.
Args:
messages: The list of messages to analyze.
n: Number of turns to preserve.
Returns:
The cutoff index such that `messages[cutoff_index:]` contains at least the
last N complete turns.
Returns `0` if fewer than N turns exist (preserving all messages).
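Example (illustrative):
```python
# messages = [Human, AI, Human, AI(tool_calls), Tool, AI]
# _find_last_n_turns(messages, 1) == 2  # last turn starts at the last HumanMessage
# _find_last_n_turns(messages, 2) == 0  # only two turns exist, so keep everything
```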
"""
if n <= 0 or not messages:
return len(messages)
# Find indices of all HumanMessages (turn boundaries)
human_indices = [i for i, msg in enumerate(messages) if isinstance(msg, HumanMessage)]
if not human_indices:
# No HumanMessages - preserve all messages
return 0
if len(human_indices) <= n:
# Fewer turns than requested - preserve all
return 0
# Get the index of the HumanMessage that starts the Nth-to-last turn
# If n=1, we want the last HumanMessage
# If n=2, we want the second-to-last HumanMessage, etc.
return human_indices[-n]
def _find_safe_cutoff(self, messages: list[AnyMessage], messages_to_keep: int) -> int:
"""Find safe cutoff point that preserves AI/Tool message pairs.

View File

@@ -0,0 +1,474 @@
"""Tests for the `LargeToolResultMiddleware`."""
from __future__ import annotations
import tempfile
from pathlib import Path
from typing import TYPE_CHECKING, cast
import pytest
from langchain_core.language_models import ModelProfile
from langchain_core.language_models.fake_chat_models import FakeChatModel
from langchain_core.messages import (
AIMessage,
AnyMessage,
HumanMessage,
ToolMessage,
)
from langchain.agents.middleware.large_tool_result import (
_OFFLOAD_METADATA_KEY,
LargeToolResultMiddleware,
)
from langchain.agents.middleware.types import ModelRequest, ModelResponse
if TYPE_CHECKING:
from langgraph.runtime import Runtime
class ProfileChatModel(FakeChatModel):
"""Fake chat model with profile for testing."""
profile: ModelProfile | None = ModelProfile(max_input_tokens=10000)
def _fake_runtime() -> Runtime:
"""Create a fake runtime for testing."""
return None # type: ignore[return-value]
def _make_request(
messages: list[AnyMessage],
*,
model: FakeChatModel | None = None,
) -> ModelRequest:
"""Create a `ModelRequest` for testing."""
model = model or ProfileChatModel()
state = cast("dict", {"messages": messages})
return ModelRequest(
model=model,
messages=list(messages),
system_prompt=None,
tool_choice=None,
tools=[],
response_format=None,
state=state,
runtime=_fake_runtime(),
model_settings={},
)
def test_initialization_with_defaults() -> None:
"""Test `LargeToolResultMiddleware` initialization with default values."""
middleware = LargeToolResultMiddleware()
assert middleware.threshold == ("fraction", 0.10)
assert middleware.preview_length == 500
assert middleware.user_temp_dir is None
assert middleware.cleanup_on_end is True
def test_initialization_with_custom_values() -> None:
"""Test `LargeToolResultMiddleware` initialization with custom values."""
middleware = LargeToolResultMiddleware(
threshold=("tokens", 5000),
preview_length=200,
cleanup_on_end=False,
)
assert middleware.threshold == ("tokens", 5000)
assert middleware.preview_length == 200
assert middleware.cleanup_on_end is False
def test_initialization_with_temp_dir() -> None:
"""Test initialization with custom temp directory."""
with tempfile.TemporaryDirectory() as tmp_dir:
middleware = LargeToolResultMiddleware(temp_dir=tmp_dir)
assert middleware.user_temp_dir == Path(tmp_dir)
def test_validation_invalid_fraction_threshold() -> None:
"""Test validation of invalid fractional threshold."""
with pytest.raises(ValueError, match="Fractional threshold must be between 0 and 1"):
LargeToolResultMiddleware(threshold=("fraction", 0.0))
with pytest.raises(ValueError, match="Fractional threshold must be between 0 and 1"):
LargeToolResultMiddleware(threshold=("fraction", 1.5))
def test_validation_invalid_token_threshold() -> None:
"""Test validation of invalid token threshold."""
with pytest.raises(ValueError, match="Token threshold must be greater than 0"):
LargeToolResultMiddleware(threshold=("tokens", 0))
with pytest.raises(ValueError, match="Token threshold must be greater than 0"):
LargeToolResultMiddleware(threshold=("tokens", -100))
def test_no_offload_below_threshold() -> None:
"""Test that small tool results are not offloaded."""
middleware = LargeToolResultMiddleware(threshold=("tokens", 1000))
tool_message = ToolMessage(content="Small result", tool_call_id="call_123")
messages: list[AnyMessage] = [
HumanMessage(content="Hello"),
AIMessage(
content="I'll use a tool", tool_calls=[{"name": "test", "args": {}, "id": "call_123"}]
),
tool_message,
]
request = _make_request(messages)
modified_request = None
def mock_handler(req: ModelRequest) -> ModelResponse:
nonlocal modified_request
modified_request = req
return ModelResponse(result=[AIMessage(content="mock")])
middleware.wrap_model_call(request, mock_handler)
assert modified_request is not None
# Tool message should be unchanged
result_tool_msg = modified_request.messages[2]
assert isinstance(result_tool_msg, ToolMessage)
assert result_tool_msg.content == "Small result"
assert not result_tool_msg.response_metadata.get(_OFFLOAD_METADATA_KEY)
def test_offload_large_tool_result() -> None:
"""Test that large tool results are offloaded to disk."""
middleware = LargeToolResultMiddleware(
threshold=("tokens", 10), # Very low threshold
preview_length=20,
)
# Create a large tool result (way over 10 tokens)
large_content = "A" * 1000
tool_message = ToolMessage(content=large_content, tool_call_id="call_abc123")
messages: list[AnyMessage] = [
HumanMessage(content="Hello"),
AIMessage(
content="I'll use a tool",
tool_calls=[{"name": "test", "args": {}, "id": "call_abc123"}],
),
tool_message,
]
request = _make_request(messages)
modified_request = None
def mock_handler(req: ModelRequest) -> ModelResponse:
nonlocal modified_request
modified_request = req
return ModelResponse(result=[AIMessage(content="mock")])
middleware.wrap_model_call(request, mock_handler)
assert modified_request is not None
result_tool_msg = modified_request.messages[2]
assert isinstance(result_tool_msg, ToolMessage)
# Content should be truncated
assert "[TRUNCATED - Full result saved to:" in result_tool_msg.content
assert "Preview (first 20 chars):" in result_tool_msg.content
assert "A" * 20 in result_tool_msg.content
# Metadata should be set
offload_metadata = result_tool_msg.response_metadata.get(_OFFLOAD_METADATA_KEY)
assert offload_metadata is not None
assert offload_metadata["offloaded"] is True
assert "file_path" in offload_metadata
assert offload_metadata["original_size_chars"] == 1000
# Verify file exists and contains original content
file_path = Path(offload_metadata["file_path"])
assert file_path.exists()
assert file_path.read_text() == large_content
# Cleanup
middleware.after_agent({"messages": []}, _fake_runtime())
def test_preview_content_preserved() -> None:
"""Test that preview contains the beginning of the content."""
middleware = LargeToolResultMiddleware(
threshold=("tokens", 10),
preview_length=50,
)
content = "This is the beginning of a very long message." + "X" * 1000
tool_message = ToolMessage(content=content, tool_call_id="call_xyz")
messages: list[AnyMessage] = [tool_message]
request = _make_request(messages)
modified_request = None
def mock_handler(req: ModelRequest) -> ModelResponse:
nonlocal modified_request
modified_request = req
return ModelResponse(result=[AIMessage(content="mock")])
middleware.wrap_model_call(request, mock_handler)
assert modified_request is not None
result_tool_msg = modified_request.messages[0]
assert isinstance(result_tool_msg, ToolMessage)
# Preview should contain the beginning of the content
assert "This is the beginning of a very long message." in result_tool_msg.content
# Cleanup
middleware.after_agent({"messages": []}, _fake_runtime())
def test_temp_dir_cleanup_on_agent_end() -> None:
"""Test that temp directory is cleaned up when agent ends."""
middleware = LargeToolResultMiddleware(
threshold=("tokens", 10),
cleanup_on_end=True,
)
# Trigger offload to create temp dir
large_content = "A" * 1000
tool_message = ToolMessage(content=large_content, tool_call_id="call_cleanup")
messages: list[AnyMessage] = [tool_message]
request = _make_request(messages)
modified_request = None
def mock_handler(req: ModelRequest) -> ModelResponse:
nonlocal modified_request
modified_request = req
return ModelResponse(result=[AIMessage(content="mock")])
middleware.wrap_model_call(request, mock_handler)
assert modified_request is not None
result_tool_msg = modified_request.messages[0]
offload_metadata = result_tool_msg.response_metadata.get(_OFFLOAD_METADATA_KEY)
file_path = Path(offload_metadata["file_path"])
temp_dir = file_path.parent
# Verify file exists before cleanup
assert file_path.exists()
assert temp_dir.exists()
# Call after_agent to trigger cleanup
middleware.after_agent({"messages": []}, _fake_runtime())
# Temp dir should be cleaned up
assert not temp_dir.exists()
def test_custom_temp_dir_not_deleted() -> None:
"""Test that user-provided temp dir is not deleted on cleanup."""
with tempfile.TemporaryDirectory() as user_dir:
user_path = Path(user_dir)
middleware = LargeToolResultMiddleware(
threshold=("tokens", 10),
temp_dir=user_dir,
cleanup_on_end=True,
)
# Trigger offload
large_content = "A" * 1000
tool_message = ToolMessage(content=large_content, tool_call_id="call_custom")
messages: list[AnyMessage] = [tool_message]
request = _make_request(messages)
def mock_handler(req: ModelRequest) -> ModelResponse:
return ModelResponse(result=[AIMessage(content="mock")])
middleware.wrap_model_call(request, mock_handler)
# Call after_agent
middleware.after_agent({"messages": []}, _fake_runtime())
# User-provided dir should still exist
assert user_path.exists()
def test_multiple_large_results_in_turn() -> None:
"""Test handling of multiple large tool results in parallel tool calls."""
middleware = LargeToolResultMiddleware(
threshold=("tokens", 10),
preview_length=10,
)
messages: list[AnyMessage] = [
HumanMessage(content="Process these files"),
AIMessage(
content="I'll process all files",
tool_calls=[
{"name": "read_file", "args": {"path": "file1.txt"}, "id": "call_1"},
{"name": "read_file", "args": {"path": "file2.txt"}, "id": "call_2"},
{"name": "read_file", "args": {"path": "file3.txt"}, "id": "call_3"},
],
),
ToolMessage(content="X" * 500, tool_call_id="call_1"),
ToolMessage(content="Y" * 500, tool_call_id="call_2"),
ToolMessage(content="Z" * 500, tool_call_id="call_3"),
]
request = _make_request(messages)
modified_request = None
def mock_handler(req: ModelRequest) -> ModelResponse:
nonlocal modified_request
modified_request = req
return ModelResponse(result=[AIMessage(content="mock")])
middleware.wrap_model_call(request, mock_handler)
assert modified_request is not None
# All three tool messages should be offloaded
for i in range(2, 5):
tool_msg = modified_request.messages[i]
assert isinstance(tool_msg, ToolMessage)
assert "[TRUNCATED - Full result saved to:" in tool_msg.content
offload_metadata = tool_msg.response_metadata.get(_OFFLOAD_METADATA_KEY)
assert offload_metadata is not None
assert offload_metadata["offloaded"] is True
# Each should have a unique file
file_paths = set()
for i in range(2, 5):
tool_msg = modified_request.messages[i]
offload_metadata = tool_msg.response_metadata.get(_OFFLOAD_METADATA_KEY)
file_paths.add(offload_metadata["file_path"])
assert len(file_paths) == 3 # Three unique files
# Cleanup
middleware.after_agent({"messages": []}, _fake_runtime())
def test_already_offloaded_skipped() -> None:
"""Test that already offloaded messages are not re-processed."""
middleware = LargeToolResultMiddleware(
threshold=("tokens", 10),
)
# Pre-offloaded message
tool_message = ToolMessage(
content="[TRUNCATED - Full result saved to: /path/to/file.txt]\n\nPreview...",
tool_call_id="call_already",
response_metadata={
_OFFLOAD_METADATA_KEY: {
"offloaded": True,
"file_path": "/path/to/file.txt",
"original_size_chars": 10000,
}
},
)
messages: list[AnyMessage] = [tool_message]
request = _make_request(messages)
modified_request = None
def mock_handler(req: ModelRequest) -> ModelResponse:
nonlocal modified_request
modified_request = req
return ModelResponse(result=[AIMessage(content="mock")])
middleware.wrap_model_call(request, mock_handler)
assert modified_request is not None
result_tool_msg = modified_request.messages[0]
assert isinstance(result_tool_msg, ToolMessage)
# Content should be unchanged (not re-processed)
assert result_tool_msg.content == tool_message.content
def test_fraction_threshold_with_model_profile() -> None:
"""Test fractional threshold calculation with model profile."""
middleware = LargeToolResultMiddleware(
threshold=("fraction", 0.10), # 10% of 10000 = 1000 tokens
)
# Content that is about 1200 tokens (4800 chars / 4 chars per token estimate)
# This should exceed 10% of 10000 tokens
large_content = "A" * 4800
tool_message = ToolMessage(content=large_content, tool_call_id="call_frac")
messages: list[AnyMessage] = [tool_message]
# Use model with profile
model = ProfileChatModel()
request = _make_request(messages, model=model)
modified_request = None
def mock_handler(req: ModelRequest) -> ModelResponse:
nonlocal modified_request
modified_request = req
return ModelResponse(result=[AIMessage(content="mock")])
middleware.wrap_model_call(request, mock_handler)
assert modified_request is not None
result_tool_msg = modified_request.messages[0]
assert isinstance(result_tool_msg, ToolMessage)
# Should be offloaded due to exceeding 10% threshold
assert "[TRUNCATED - Full result saved to:" in result_tool_msg.content
# Cleanup
middleware.after_agent({"messages": []}, _fake_runtime())
async def test_async_wrap_model_call() -> None:
"""Test async version of `wrap_model_call`."""
middleware = LargeToolResultMiddleware(
threshold=("tokens", 10),
preview_length=20,
)
large_content = "A" * 1000
tool_message = ToolMessage(content=large_content, tool_call_id="call_async")
messages: list[AnyMessage] = [tool_message]
request = _make_request(messages)
modified_request = None
async def mock_handler(req: ModelRequest) -> ModelResponse:
nonlocal modified_request
modified_request = req
return ModelResponse(result=[AIMessage(content="mock")])
await middleware.awrap_model_call(request, mock_handler)
assert modified_request is not None
result_tool_msg = modified_request.messages[0]
assert isinstance(result_tool_msg, ToolMessage)
assert "[TRUNCATED - Full result saved to:" in result_tool_msg.content
# Cleanup
await middleware.aafter_agent({"messages": []}, _fake_runtime())
def test_empty_messages_passthrough() -> None:
"""Test that empty messages list is handled correctly."""
middleware = LargeToolResultMiddleware(threshold=("tokens", 10))
request = _make_request([])
called = False
def mock_handler(req: ModelRequest) -> ModelResponse:
nonlocal called
called = True
return ModelResponse(result=[AIMessage(content="mock")])
middleware.wrap_model_call(request, mock_handler)
assert called

View File

@@ -1014,3 +1014,387 @@ def test_create_summary_uses_get_buffer_string_format() -> None:
f"str(messages) should produce significantly more tokens. "
f"Got ratio {str_ratio:.2f}x (expected > 1.5)"
)
def test_summarization_middleware_min_preserve_turns_default() -> None:
"""Test that `min_preserve_turns` defaults to 1 and keeps at least one turn."""
middleware = SummarizationMiddleware(
model=MockChatModel(),
trigger=("messages", 5),
keep=("messages", 1), # Would keep only 1 message (very aggressive)
)
# Default min_preserve_turns should be 1
assert middleware.min_preserve_turns == 1
# Mock token counter
def mock_token_counter(messages: list[AnyMessage]) -> int:
return 1500
middleware.token_counter = mock_token_counter
# Messages: 3 turns
messages: list[AnyMessage] = [
HumanMessage(content="Turn 1"),
AIMessage(content="Response 1"),
HumanMessage(content="Turn 2"),
AIMessage(content="Response 2"),
HumanMessage(content="Turn 3"),
AIMessage(content="Response 3"),
]
state = {"messages": messages}
result = middleware.before_model(state, None)
# Should preserve at least 1 turn (last HumanMessage at index 4)
assert result is not None
preserved = result["messages"][2:] # After RemoveMessage and summary
# min_preserve_turns=1 means we preserve from the last HumanMessage (index 4)
# which includes Turn 3 (HumanMessage + AIMessage = 2 messages)
assert len(preserved) >= 2
# The last turn starts at the last HumanMessage
assert any(isinstance(m, HumanMessage) and m.content == "Turn 3" for m in preserved)
def test_summarization_middleware_min_preserve_turns_with_massive_tool_result() -> None:
"""Test that `min_preserve_turns` keeps turn even if tool result is massive."""
middleware = SummarizationMiddleware(
model=MockChatModel(),
trigger=("messages", 4),
keep=("messages", 1), # Would only keep 1 message
min_preserve_turns=1,
)
def mock_token_counter(messages: list[AnyMessage]) -> int:
return 10000 # Always above trigger
middleware.token_counter = mock_token_counter
# Turn 2 has a massive tool result
messages: list[AnyMessage] = [
HumanMessage(content="Turn 1"),
AIMessage(content="Response 1"),
HumanMessage(content="Turn 2 - Read file"),
AIMessage(
content="I'll read the file",
tool_calls=[{"name": "read", "args": {}, "id": "call_1"}],
),
ToolMessage(content="X" * 10000, tool_call_id="call_1"), # Massive result
]
state = {"messages": messages}
result = middleware.before_model(state, None)
# Should preserve turn 2 (HumanMessage + AIMessage + ToolMessage)
assert result is not None
preserved = result["messages"][2:] # After RemoveMessage and summary
# Last turn starts at index 2 (HumanMessage "Turn 2 - Read file")
assert len(preserved) == 3
assert isinstance(preserved[0], HumanMessage)
assert preserved[0].content == "Turn 2 - Read file"
def test_summarization_middleware_min_preserve_turns_zero_disables() -> None:
"""Test that `min_preserve_turns=0` disables the floor."""
middleware = SummarizationMiddleware(
model=MockChatModel(),
trigger=("messages", 3),
keep=("messages", 1),
min_preserve_turns=0, # Disabled
)
def mock_token_counter(messages: list[AnyMessage]) -> int:
return 1500
middleware.token_counter = mock_token_counter
messages: list[AnyMessage] = [
HumanMessage(content="Turn 1"),
AIMessage(content="Response 1"),
HumanMessage(content="Turn 2"),
]
state = {"messages": messages}
result = middleware.before_model(state, None)
# With min_preserve_turns=0, should respect keep=1 strictly
assert result is not None
preserved = result["messages"][2:]
assert len(preserved) == 1 # Only keeps 1 message
def test_summarization_middleware_cutoff_never_past_last_turn() -> None:
"""Test that cutoff never goes past the last turn, even with tight token budget."""
middleware = SummarizationMiddleware(
model=MockChatModel(),
trigger=("messages", 5),
keep=("messages", 1), # Very aggressive - would keep only 1 message
min_preserve_turns=1,
)
def mock_token_counter(messages: list[AnyMessage]) -> int:
return 5000
middleware.token_counter = mock_token_counter
# Create messages where keep=1 would want to summarize almost everything
messages: list[AnyMessage] = [
HumanMessage(content="Old turn 1"),
AIMessage(content="Old response 1"),
HumanMessage(content="Old turn 2"),
AIMessage(content="Old response 2"),
HumanMessage(content="Last turn"),
AIMessage(content="Last response"),
]
state = {"messages": messages}
result = middleware.before_model(state, None)
assert result is not None
preserved = result["messages"][2:]
# Should preserve at least the last turn (HumanMessage + AIMessage)
# min_preserve_turns=1 overrides keep=1 to preserve the full last turn
assert len(preserved) >= 2
# Find the last HumanMessage in preserved
human_msgs = [m for m in preserved if isinstance(m, HumanMessage)]
assert len(human_msgs) >= 1
assert human_msgs[-1].content == "Last turn"
def test_summarization_middleware_find_last_n_turns_basic() -> None:
"""Test `_find_last_n_turns` with basic message sequences."""
middleware = SummarizationMiddleware(model=MockChatModel(), trigger=("messages", 10))
messages: list[AnyMessage] = [
HumanMessage(content="Turn 1"), # index 0
AIMessage(content="Response 1"), # index 1
HumanMessage(content="Turn 2"), # index 2
AIMessage(content="Response 2"), # index 3
HumanMessage(content="Turn 3"), # index 4
AIMessage(content="Response 3"), # index 5
]
# n=1: should return index 4 (last HumanMessage)
assert middleware._find_last_n_turns(messages, 1) == 4
# n=2: should return index 2 (second-to-last HumanMessage)
assert middleware._find_last_n_turns(messages, 2) == 2
# n=3: should return index 0 (third-to-last HumanMessage)
assert middleware._find_last_n_turns(messages, 3) == 0
# n=4: fewer turns than requested, return 0
assert middleware._find_last_n_turns(messages, 4) == 0
def test_summarization_middleware_find_last_n_turns_with_tools() -> None:
"""Test `_find_last_n_turns` with tool messages in turns."""
middleware = SummarizationMiddleware(model=MockChatModel(), trigger=("messages", 10))
messages: list[AnyMessage] = [
HumanMessage(content="Turn 1"), # index 0
AIMessage(content="Response 1"), # index 1
HumanMessage(content="Turn 2"), # index 2
AIMessage(
content="Using tools",
tool_calls=[{"name": "tool", "args": {}, "id": "call_1"}],
), # index 3
ToolMessage(content="Result 1", tool_call_id="call_1"), # index 4
ToolMessage(content="Result 2", tool_call_id="call_2"), # index 5
AIMessage(content="Final response"), # index 6
]
# n=1: should return index 2 (last turn starts at "Turn 2")
assert middleware._find_last_n_turns(messages, 1) == 2
# n=2: should return index 0 (first turn)
assert middleware._find_last_n_turns(messages, 2) == 0
def test_summarization_middleware_find_last_n_turns_no_human_messages() -> None:
"""Test `_find_last_n_turns` when there are no `HumanMessage`s."""
middleware = SummarizationMiddleware(model=MockChatModel(), trigger=("messages", 10))
messages: list[AnyMessage] = [
AIMessage(content="AI only 1"),
AIMessage(content="AI only 2"),
]
# No turns defined by HumanMessages - preserve all
assert middleware._find_last_n_turns(messages, 1) == 0
def test_summarization_middleware_find_last_n_turns_empty() -> None:
"""Test _find_last_n_turns with empty messages."""
middleware = SummarizationMiddleware(model=MockChatModel(), trigger=("messages", 10))
# Empty messages
assert middleware._find_last_n_turns([], 1) == 0
# n=0 should return len(messages)
messages: list[AnyMessage] = [HumanMessage(content="Test")]
assert middleware._find_last_n_turns(messages, 0) == 1
def test_summarization_middleware_turn_boundary_detection() -> None:
"""Test that turn boundaries are correctly identified at `HumanMessage`s."""
middleware = SummarizationMiddleware(
model=MockChatModel(),
trigger=("messages", 5),
keep=("messages", 2),
min_preserve_turns=1,
)
def mock_token_counter(messages: list[AnyMessage]) -> int:
return 1500
middleware.token_counter = mock_token_counter
# Complex sequence with tools
messages: list[AnyMessage] = [
HumanMessage(content="Question 1"), # Turn 1 starts
AIMessage(content="Answer 1"),
HumanMessage(content="Question 2"), # Turn 2 starts
AIMessage(
content="Let me check",
tool_calls=[{"name": "search", "args": {}, "id": "call_1"}],
),
ToolMessage(content="Search result", tool_call_id="call_1"),
AIMessage(content="Based on search, here's my answer"),
]
state = {"messages": messages}
result = middleware.before_model(state, None)
assert result is not None
preserved = result["messages"][2:]
# Last turn (Turn 2) should be preserved completely
# It starts at index 2 (HumanMessage "Question 2") and includes:
# - HumanMessage (Question 2)
# - AIMessage (Let me check)
# - ToolMessage (Search result)
# - AIMessage (Based on search...)
assert len(preserved) == 4
assert isinstance(preserved[0], HumanMessage)
assert preserved[0].content == "Question 2"
def test_summarization_middleware_multiple_turns_preserved() -> None:
"""Test `min_preserve_turns=2` keeps 2 complete turns."""
middleware = SummarizationMiddleware(
model=MockChatModel(),
trigger=("messages", 6),
keep=("messages", 1), # Would only keep 1 message
min_preserve_turns=2, # But we want 2 turns minimum
)
def mock_token_counter(messages: list[AnyMessage]) -> int:
return 2000
middleware.token_counter = mock_token_counter
messages: list[AnyMessage] = [
HumanMessage(content="Turn 1"),
AIMessage(content="Response 1"),
HumanMessage(content="Turn 2"),
AIMessage(content="Response 2"),
HumanMessage(content="Turn 3"),
AIMessage(content="Response 3"),
]
state = {"messages": messages}
result = middleware.before_model(state, None)
assert result is not None
preserved = result["messages"][2:]
# Should preserve turns 2 and 3 (starting from index 2)
assert len(preserved) == 4 # Turn 2 (2 msgs) + Turn 3 (2 msgs)
human_msgs = [m for m in preserved if isinstance(m, HumanMessage)]
assert len(human_msgs) == 2
assert human_msgs[0].content == "Turn 2"
assert human_msgs[1].content == "Turn 3"
def test_summarization_middleware_safety_check_skips_large_turn() -> None:
"""Test that `min_preserve_turns` is skipped if turn exceeds 80% of context."""
# Model with 1000 max tokens
model = MockChatModel()
model.profile = {"max_input_tokens": 1000}
middleware = SummarizationMiddleware(
model=model,
trigger=("messages", 4),
keep=("messages", 2), # Would keep 2 messages
min_preserve_turns=1, # Would want to keep full turn
)
# Token counter that makes the last turn very expensive
def mock_token_counter(messages: list[AnyMessage]) -> int:
# Make the last turn (messages from index 2 onward) exceed 80% of 1000 = 800 tokens
# If checking all messages, return high count to trigger summarization
if len(messages) >= 4:
return 1500 # Total exceeds trigger
if len(messages) == 2:
return 900 # Last turn alone exceeds 80% threshold
return len(messages) * 100
middleware.token_counter = mock_token_counter
messages: list[AnyMessage] = [
HumanMessage(content="Turn 1"),
AIMessage(content="Response 1"),
HumanMessage(content="Turn 2 with massive content"), # Last turn starts here
AIMessage(content="X" * 10000), # Massive response
]
state = {"messages": messages}
result = middleware.before_model(state, None)
assert result is not None
preserved = result["messages"][2:]
# Safety check should kick in - min_preserve_turns skipped, falls back to keep=2
# This means we keep only 2 messages (the keep policy) not the full turn
assert len(preserved) == 2
def test_summarization_middleware_safety_check_allows_small_turn() -> None:
"""Test that `min_preserve_turns` works normally when turn is small enough."""
# Model with 10000 max tokens
model = MockChatModel()
model.profile = {"max_input_tokens": 10000}
middleware = SummarizationMiddleware(
model=model,
trigger=("messages", 4),
keep=("messages", 1), # Would keep 1 message
min_preserve_turns=1, # Want to keep full turn
)
def mock_token_counter(messages: list[AnyMessage]) -> int:
# Last turn is small (well under 80% of 10000 = 8000)
return len(messages) * 100
middleware.token_counter = mock_token_counter
messages: list[AnyMessage] = [
HumanMessage(content="Turn 1"),
AIMessage(content="Response 1"),
HumanMessage(content="Turn 2"), # Last turn starts here
AIMessage(content="Response 2"),
]
state = {"messages": messages}
result = middleware.before_model(state, None)
assert result is not None
preserved = result["messages"][2:]
# Turn is small enough, so min_preserve_turns applies
# Should keep full turn (2 messages) not just keep=1
assert len(preserved) == 2
assert isinstance(preserved[0], HumanMessage)
assert preserved[0].content == "Turn 2"