Compare commits

...

7 Commits

Author SHA1 Message Date
Sydney Runkle
03373459dc context trimming 2025-11-14 14:57:23 -05:00
Sydney Runkle
62c05e09c1 readable tests 2025-11-14 14:34:24 -05:00
Sydney Runkle
83b9d9f810 parametrize 2025-11-14 14:31:28 -05:00
Sydney Runkle
6deee23d8d Revert "cleanup"
This reverts commit 690aabe8d4.
2025-11-14 14:14:32 -05:00
Sydney Runkle
690aabe8d4 cleanup 2025-11-14 14:12:14 -05:00
Sydney Runkle
80554df1e6 docs 2025-11-14 14:00:15 -05:00
Sydney Runkle
72c45e65e8 summarization edits 2025-11-14 13:56:26 -05:00
6 changed files with 747 additions and 41 deletions

View File

@@ -1,5 +1,12 @@
"""Entrypoint to using [middleware](https://docs.langchain.com/oss/python/langchain/middleware) plugins with [Agents](https://docs.langchain.com/oss/python/langchain/agents).""" # noqa: E501
from ._context import (
ContextCondition,
ContextFraction,
ContextMessages,
ContextSize,
ContextTokens,
)
from .context_editing import (
ClearToolUsesEdit,
ContextEditingMiddleware,
@@ -45,7 +52,12 @@ __all__ = [
"AgentState",
"ClearToolUsesEdit",
"CodexSandboxExecutionPolicy",
"ContextCondition",
"ContextEditingMiddleware",
"ContextFraction",
"ContextMessages",
"ContextSize",
"ContextTokens",
"DockerExecutionPolicy",
"FilesystemFileSearchMiddleware",
"HostExecutionPolicy",

View File

@@ -0,0 +1,30 @@
"""Shared context size types for middleware that manages message history."""
from typing import Literal
ContextFraction = tuple[Literal["fraction"], float]
"""Tuple specifying context size as a fraction of the model's context window."""
ContextTokens = tuple[Literal["tokens"], int]
"""Tuple specifying context size as a number of tokens."""
ContextMessages = tuple[Literal["messages"], int]
"""Tuple specifying context size as a number of messages."""
ContextSize = ContextFraction | ContextTokens | ContextMessages
"""Context size tuple to specify how much history to preserve or trigger conditions."""
ContextCondition = ContextSize | list[ContextSize | list[ContextSize]]
"""Recursive type to support nested AND/OR conditions.
Top-level list = OR logic, nested list = AND logic.
"""
__all__ = [
"ContextCondition",
"ContextFraction",
"ContextMessages",
"ContextSize",
"ContextTokens",
]

View File

@@ -9,9 +9,9 @@ chat model.
from __future__ import annotations
from collections.abc import Awaitable, Callable, Iterable, Sequence
from collections.abc import Awaitable, Callable, Iterable, Mapping, Sequence
from dataclasses import dataclass
from typing import Literal
from typing import TYPE_CHECKING, Any, Literal
from langchain_core.messages import (
AIMessage,
@@ -30,6 +30,9 @@ from langchain.agents.middleware.types import (
ModelResponse,
)
if TYPE_CHECKING:
from langchain.agents.middleware._context import ContextCondition, ContextSize
DEFAULT_TOOL_PLACEHOLDER = "[cleared]"
@@ -54,16 +57,37 @@ class ContextEdit(Protocol):
@dataclass(slots=True)
class ClearToolUsesEdit(ContextEdit):
"""Configuration for clearing tool outputs when token limits are exceeded."""
"""Configuration for clearing tool outputs when token limits are exceeded.
trigger: int = 100_000
"""Token count that triggers the edit."""
Supports flexible trigger and keep configurations using `ContextSize` tuples or
backwards-compatible integer values.
"""
trigger: ContextCondition | int = 100_000
"""Trigger condition(s) for when the edit should run.
Supports flexible AND/OR logic via nested lists:
- Single condition: `("messages", 50)` or `("tokens", 3000)`
- OR conditions: `[("tokens", 3000), ("messages", 100)]`
- AND conditions: `[("tokens", 500), ("fraction", 0.8)]` as nested list
- Mixed AND/OR: `[("messages", 10), [("tokens", 500), ("fraction", 0.8)]]`
For backwards compatibility, also accepts an integer token count.
"""
clear_at_least: int = 0
"""Minimum number of tokens to reclaim when the edit runs."""
keep: int = 3
"""Number of most recent tool results that must be preserved."""
keep: ContextSize | int = 3
"""Context retention policy for tool results.
Provide a `ContextSize` tuple to specify how much history to preserve:
- `("messages", 3)` - Keep last 3 tool results
- `("tokens", 1000)` - Keep tool results within token budget
- `("fraction", 0.3)` - Keep tool results within 30% of model's max tokens
For backwards compatibility, also accepts an integer message count.
"""
clear_tool_inputs: bool = False
"""Whether to clear the originating tool call parameters on the AI message."""
@@ -74,6 +98,66 @@ class ClearToolUsesEdit(ContextEdit):
placeholder: str = DEFAULT_TOOL_PLACEHOLDER
"""Placeholder text inserted for cleared tool outputs."""
model: Any = None
"""Optional model instance for fractional token limits."""
_trigger_conditions: list[ContextSize | list[ContextSize]] | None = None
_keep_normalized: ContextSize | None = None
_trigger_as_int: int | None = None
_keep_as_int: int | None = None
def __post_init__(self) -> None:
    """Validate and normalize trigger/keep parameters.

    Populates the private caches (`_trigger_as_int`, `_trigger_conditions`,
    `_keep_as_int`, `_keep_normalized`) so later evaluation does not need to
    re-validate, then fails fast when fractional limits are requested without
    usable model profile information.
    """
    # Normalize trigger
    if isinstance(self.trigger, int):
        # Legacy API: a plain integer token threshold.
        self._trigger_as_int = self.trigger
        self._trigger_conditions = None
    elif isinstance(self.trigger, tuple):
        # Single ContextSize
        self._validate_context_size(self.trigger, "trigger")
        self._trigger_conditions = [self.trigger]
        self._trigger_as_int = None
    elif isinstance(self.trigger, list):
        # List of conditions
        self._trigger_conditions = self._validate_trigger_conditions(self.trigger)
        self._trigger_as_int = None
    else:
        msg = f"trigger must be int or ContextCondition, got {type(self.trigger).__name__}"
        raise TypeError(msg)

    # Normalize keep
    if isinstance(self.keep, int):
        # Legacy API: keep the N most recent tool results.
        self._keep_as_int = self.keep
        self._keep_normalized = None
    elif isinstance(self.keep, tuple):
        self._validate_context_size(self.keep, "keep")
        self._keep_normalized = self.keep
        self._keep_as_int = None
    else:
        msg = f"keep must be int or ContextSize, got {type(self.keep).__name__}"
        raise TypeError(msg)

    # Check if model profile is required
    requires_profile = False
    if self._trigger_conditions:
        requires_profile = self._requires_profile(self._trigger_conditions)
    if self._keep_normalized and self._keep_normalized[0] == "fraction":
        requires_profile = True
    if requires_profile and self.model is None:
        msg = (
            "model parameter is required when using fractional token limits. "
            "Pass a model instance or use absolute token/message counts instead."
        )
        raise ValueError(msg)
    if requires_profile and self._get_profile_limits() is None:
        msg = (
            "Model profile information is required to use fractional token limits. "
            'pip install "langchain[model-profiles]" or use absolute token counts instead.'
        )
        raise ValueError(msg)
def apply(
self,
messages: list[AnyMessage],
@@ -83,17 +167,20 @@ class ClearToolUsesEdit(ContextEdit):
"""Apply the clear-tool-uses strategy."""
tokens = count_tokens(messages)
if tokens <= self.trigger:
if not self._should_trigger(messages, tokens):
return
candidates = [
(idx, msg) for idx, msg in enumerate(messages) if isinstance(msg, ToolMessage)
]
if self.keep >= len(candidates):
# Determine how many tool results to keep
keep_count = self._determine_keep_count(messages, tokens)
if keep_count >= len(candidates):
candidates = []
elif self.keep:
candidates = candidates[: -self.keep]
elif keep_count > 0:
candidates = candidates[:-keep_count]
cleared_tokens = 0
excluded_tools = set(self.exclude_tools)
@@ -181,6 +268,191 @@ class ClearToolUsesEdit(ContextEdit):
}
)
def _should_trigger(self, messages: list[AnyMessage], total_tokens: int) -> bool:
    """Return True when the current state satisfies the trigger configuration.

    Legacy integer triggers compare the token total directly; the
    condition-based API ORs top-level entries and ANDs nested lists.
    """
    # Legacy path: plain integer token threshold.
    if self._trigger_as_int is not None:
        return total_tokens > self._trigger_as_int
    if not self._trigger_conditions:
        return False
    # Top-level entries are OR'd; a nested list is an AND group.
    return any(
        self._check_and_group(entry, messages, total_tokens)
        if isinstance(entry, list)
        else self._check_single_condition(entry, messages, total_tokens)
        for entry in self._trigger_conditions
    )
def _check_and_group(
    self, and_group: list[ContextSize], messages: list[AnyMessage], total_tokens: int
) -> bool:
    """Return True only when every condition in the AND group holds."""
    return all(
        self._check_single_condition(member, messages, total_tokens) for member in and_group
    )
def _check_single_condition(
    self, condition: ContextSize, messages: list[AnyMessage], total_tokens: int
) -> bool:
    """Evaluate one (kind, value) threshold against the current state."""
    kind, limit = condition
    if kind == "messages":
        return len(messages) >= limit
    if kind == "tokens":
        return total_tokens >= limit
    if kind == "fraction":
        window = self._get_profile_limits()
        if window is None:
            # A fractional threshold cannot be evaluated without profile data.
            return False
        # Clamp the derived threshold to at least one token.
        return total_tokens >= max(1, int(window * limit))
    return False
def _determine_keep_count(self, messages: list[AnyMessage], total_tokens: int) -> int:  # noqa: ARG002
    """Resolve the `keep` setting into a concrete number of tool results."""
    # Legacy path: keep was supplied as a plain integer count.
    if self._keep_as_int is not None:
        return self._keep_as_int
    if self._keep_normalized is None:
        return 0
    kind, value = self._keep_normalized
    if kind == "messages":
        return int(value)
    if kind not in {"tokens", "fraction"}:
        return 0
    # Token-budget retention: walk recent tool results until the budget
    # derived from the tokens/fraction setting is exhausted.
    budget = self._get_target_token_count(value, kind)
    if budget is None:
        return 0
    return self._count_tool_messages_within_budget(messages, budget)
def _get_target_token_count(self, value: float, kind: str) -> int | None:
    """Translate a tokens/fraction keep value into an absolute token budget."""
    if kind == "tokens":
        target = int(value)
    elif kind == "fraction":
        window = self._get_profile_limits()
        if window is None:
            # Fractional budgets require the model's context window size.
            return None
        target = int(window * value)
    else:
        return None
    # Never hand back a budget below one token.
    return max(1, target)
def _count_tool_messages_within_budget(
    self, messages: list[AnyMessage], target_tokens: int
) -> int:
    """Count how many recent tool messages fit within the token budget.

    Size is approximated as the length of each message's stringified
    content; the most recent tool message is always counted.
    """
    tool_messages = [m for m in messages if isinstance(m, ToolMessage)]
    kept = 0
    used = 0
    # Walk from newest to oldest, stopping once the budget would overflow.
    for candidate in reversed(tool_messages):
        size = len(str(candidate.content))
        if kept and used + size > target_tokens:
            break
        used += size
        kept += 1
    return kept
def _get_profile_limits(self) -> int | None:
    """Read `max_input_tokens` from the model profile, if available."""
    if self.model is None:
        return None
    try:
        profile = self.model.profile
    except (AttributeError, ImportError):
        # Profile support is optional; treat lookup failures as "unknown".
        return None
    if isinstance(profile, Mapping):
        limit = profile.get("max_input_tokens")
        if isinstance(limit, int):
            return limit
    return None
def _validate_context_size(self, context: ContextSize, parameter_name: str) -> ContextSize:
    """Validate a (kind, value) tuple, raising ValueError when out of range."""
    kind, value = context
    if kind not in {"fraction", "tokens", "messages"}:
        msg = f"Unsupported context size type {kind} for {parameter_name}."
        raise ValueError(msg)
    if kind == "fraction":
        # Fractions are proportions of the model's context window.
        if not 0 < value <= 1:
            msg = f"Fractional {parameter_name} values must be between 0 and 1, got {value}."
            raise ValueError(msg)
    else:
        # Absolute counts: triggers must be positive; keep may be zero
        # (zero means "keep nothing").
        if parameter_name == "trigger" and value <= 0:
            msg = f"{parameter_name} thresholds must be greater than 0, got {value}."
            raise ValueError(msg)
        if parameter_name == "keep" and value < 0:
            msg = f"{parameter_name} values must be non-negative, got {value}."
            raise ValueError(msg)
    return context
def _validate_trigger_conditions(
    self, conditions: list[Any]
) -> list[ContextSize | list[ContextSize]]:
    """Validate trigger entries, preserving OR at top level and AND in nests.

    Args:
        conditions: Mix of ContextSize tuples and non-empty lists of tuples.

    Returns:
        The validated structure, unchanged in shape.
    """
    normalized: list[ContextSize | list[ContextSize]] = []
    for entry in conditions:
        if isinstance(entry, tuple):
            # Bare tuple: a single OR'd condition.
            normalized.append(self._validate_context_size(entry, "trigger"))
            continue
        if isinstance(entry, list):
            # Nested list: an AND group; must not be empty.
            if not entry:
                msg = "Empty AND groups are not allowed in trigger conditions."
                raise ValueError(msg)
            normalized.append(
                [self._validate_context_size(member, "trigger") for member in entry]
            )
            continue
        msg = f"Trigger conditions must be tuples or lists, got {type(entry).__name__}."
        raise ValueError(msg)
    return normalized
def _requires_profile(self, conditions: list[ContextSize | list[ContextSize]]) -> bool:
    """Return True when any condition is fraction-based and needs profile data."""
    for entry in conditions:
        # Treat a bare tuple as a one-element group so both shapes share a path.
        group = entry if isinstance(entry, list) else [entry]
        if any(kind == "fraction" for kind, _ in group):
            return True
    return False
class ContextEditingMiddleware(AgentMiddleware):
"""Automatically prune tool results to manage context size.

View File

@@ -3,7 +3,7 @@
import uuid
import warnings
from collections.abc import Callable, Iterable, Mapping
from typing import Any, Literal, cast
from typing import Any, cast
from langchain_core.messages import (
AIMessage,
@@ -19,6 +19,10 @@ from langgraph.graph.message import (
)
from langgraph.runtime import Runtime
from langchain.agents.middleware._context import (
ContextCondition,
ContextSize,
)
from langchain.agents.middleware.types import AgentMiddleware, AgentState
from langchain.chat_models import BaseChatModel, init_chat_model
@@ -57,12 +61,6 @@ _DEFAULT_TRIM_TOKEN_LIMIT = 4000
_DEFAULT_FALLBACK_MESSAGE_COUNT = 15
_SEARCH_RANGE_FOR_TOOL_PAIRS = 5
ContextFraction = tuple[Literal["fraction"], float]
ContextTokens = tuple[Literal["tokens"], int]
ContextMessages = tuple[Literal["messages"], int]
ContextSize = ContextFraction | ContextTokens | ContextMessages
class SummarizationMiddleware(AgentMiddleware):
"""Summarizes conversation history when token limits are approached.
@@ -76,7 +74,7 @@ class SummarizationMiddleware(AgentMiddleware):
self,
model: str | BaseChatModel,
*,
trigger: ContextSize | list[ContextSize] | None = None,
trigger: ContextCondition | None = None,
keep: ContextSize = ("messages", _DEFAULT_MESSAGES_TO_KEEP),
token_counter: TokenCounter = count_tokens_approximately,
summary_prompt: str = DEFAULT_SUMMARY_PROMPT,
@@ -87,10 +85,16 @@ class SummarizationMiddleware(AgentMiddleware):
Args:
model: The language model to use for generating summaries.
trigger: One or more thresholds that trigger summarization. Provide a single
`ContextSize` tuple or a list of tuples, in which case summarization runs
when any threshold is breached. Examples: `("messages", 50)`, `("tokens", 3000)`,
`[("fraction", 0.8), ("messages", 100)]`.
trigger: One or more thresholds that trigger summarization. Supports flexible
AND/OR logic via nested lists. Top-level list items are combined with OR,
nested lists are combined with AND. Examples:
- Single condition: `("messages", 50)`
- OR conditions: `[("tokens", 3000), ("messages", 100)]` (triggers when
tokens >= 3000 OR messages >= 100)
- AND conditions: `[("tokens", 500), ("fraction", 0.8)]` as a nested list
within the top-level list
- Mixed AND/OR: `[("messages", 10), [("tokens", 500), ("fraction", 0.8)]]`
(triggers when messages >= 10 OR (tokens >= 500 AND fraction >= 0.8))
keep: Context retention policy applied after summarization. Provide a
`ContextSize` tuple to specify how much history to preserve. Defaults to
keeping the most recent 20 messages. Examples: `("messages", 20)`,
@@ -128,13 +132,15 @@ class SummarizationMiddleware(AgentMiddleware):
self.model = model
if trigger is None:
self.trigger: ContextSize | list[ContextSize] | None = None
trigger_conditions: list[ContextSize] = []
self.trigger: ContextCondition | None = None
trigger_conditions: list[ContextSize | list[ContextSize]] = []
elif isinstance(trigger, list):
validated_list = [self._validate_context_size(item, "trigger") for item in trigger]
# Validate and normalize nested structure
validated_list = self._validate_trigger_conditions(trigger)
self.trigger = validated_list
trigger_conditions = validated_list
else:
# Single ContextSize tuple
validated = self._validate_context_size(trigger, "trigger")
self.trigger = validated
trigger_conditions = [validated]
@@ -145,7 +151,7 @@ class SummarizationMiddleware(AgentMiddleware):
self.summary_prompt = summary_prompt
self.trim_tokens_to_summarize = trim_tokens_to_summarize
requires_profile = any(condition[0] == "fraction" for condition in self._trigger_conditions)
requires_profile = self._requires_profile(self._trigger_conditions)
if self.keep[0] == "fraction":
requires_profile = True
if requires_profile and self._get_profile_limits() is None:
@@ -211,24 +217,52 @@ class SummarizationMiddleware(AgentMiddleware):
}
def _should_summarize(self, messages: list[AnyMessage], total_tokens: int) -> bool:
    """Determine whether summarization should run for the current token usage.

    Evaluates trigger conditions with AND/OR logic:
    - Top-level items are OR'd together
    - Nested lists are AND'd together
    """
    # NOTE(review): the source span interleaved the pre-change loop with the
    # new condition-based implementation (diff artifact); only the coherent
    # new implementation is kept here.
    if not self._trigger_conditions:
        return False
    # OR logic across top-level conditions
    for condition in self._trigger_conditions:
        if isinstance(condition, list):
            # AND group - all must be satisfied
            if self._check_and_group(condition, messages, total_tokens):
                return True
        elif self._check_single_condition(condition, messages, total_tokens):
            # Single condition
            return True
    return False
def _check_and_group(
    self, and_group: list[ContextSize], messages: list[AnyMessage], total_tokens: int
) -> bool:
    """Return True only when every condition in the AND group is satisfied."""
    return all(
        self._check_single_condition(member, messages, total_tokens) for member in and_group
    )
def _check_single_condition(
    self, condition: ContextSize, messages: list[AnyMessage], total_tokens: int
) -> bool:
    """Evaluate one (kind, value) threshold against the conversation state."""
    kind, limit = condition
    if kind == "messages":
        return len(messages) >= limit
    if kind == "tokens":
        return total_tokens >= limit
    if kind == "fraction":
        window = self._get_profile_limits()
        if window is None:
            # No profile information: a fractional threshold never fires.
            return False
        # Derived threshold is clamped to a minimum of one token.
        return total_tokens >= max(1, int(window * limit))
    return False
def _determine_cutoff_index(self, messages: list[AnyMessage]) -> int:
@@ -328,6 +362,45 @@ class SummarizationMiddleware(AgentMiddleware):
raise ValueError(msg)
return context
def _validate_trigger_conditions(
    self, conditions: list[Any]
) -> list[ContextSize | list[ContextSize]]:
    """Validate and normalize trigger conditions with nested AND/OR logic.

    Args:
        conditions: List of ContextSize tuples or nested lists of ContextSize tuples.

    Returns:
        Validated list where top-level items are OR'd and nested lists are AND'd.
    """
    normalized: list[ContextSize | list[ContextSize]] = []
    for entry in conditions:
        if isinstance(entry, tuple):
            # A bare tuple is a single OR'd condition.
            normalized.append(self._validate_context_size(entry, "trigger"))
            continue
        if isinstance(entry, list):
            # A nested list is an AND group and must not be empty.
            if not entry:
                msg = "Empty AND groups are not allowed in trigger conditions."
                raise ValueError(msg)
            normalized.append(
                [self._validate_context_size(member, "trigger") for member in entry]
            )
            continue
        msg = f"Trigger conditions must be tuples or lists, got {type(entry).__name__}."
        raise ValueError(msg)
    return normalized
def _requires_profile(self, conditions: list[ContextSize | list[ContextSize]]) -> bool:
    """Return True when any condition needs the model profile (fraction kind)."""
    for entry in conditions:
        # Normalize bare tuples to one-element groups to share a single path.
        group = entry if isinstance(entry, list) else [entry]
        if any(kind == "fraction" for kind, _ in group):
            return True
    return False
def _build_new_messages(self, summary: str) -> list[HumanMessage]:
return [
HumanMessage(content=f"Here is a summary of the conversation to date:\n\n{summary}")

View File

@@ -399,3 +399,154 @@ async def test_exclude_tools_prevents_clearing_async() -> None:
assert isinstance(calc_tool, ToolMessage)
assert calc_tool.content == "[cleared]"
# New API tests
def test_new_api_trigger_with_context_size_tuple() -> None:
    """A trigger given as a ("tokens", n) tuple clears tool output."""
    call_id = "call-1"
    ai_msg = AIMessage(
        content="",
        tool_calls=[{"id": call_id, "name": "search", "args": {}}],
    )
    tool_msg = ToolMessage(content="x" * 200, tool_call_id=call_id)
    _, request = _make_state_and_request([ai_msg, tool_msg])

    middleware = ContextEditingMiddleware(
        edits=[
            ClearToolUsesEdit(
                trigger=("tokens", 50),
                keep=("messages", 0),
                placeholder="[cleared]",
            )
        ]
    )

    def mock_handler(req: ModelRequest) -> AIMessage:
        return AIMessage(content="mock response")

    middleware.wrap_model_call(request, mock_handler)

    cleared = request.messages[1]
    assert isinstance(cleared, ToolMessage)
    assert cleared.content == "[cleared]"
def test_new_api_keep_with_messages_tuple() -> None:
    """keep=("messages", 2) preserves only the two most recent tool results."""
    conversation: list[AIMessage | ToolMessage] = []
    for i in range(5):
        call_id = f"call-{i}"
        conversation.append(
            AIMessage(content="", tool_calls=[{"id": call_id, "name": "tool", "args": {}}])
        )
        conversation.append(ToolMessage(content="x" * 50, tool_call_id=call_id))
    _, request = _make_state_and_request(conversation)

    middleware = ContextEditingMiddleware(
        edits=[
            ClearToolUsesEdit(
                trigger=("tokens", 50),
                keep=("messages", 2),  # Keep last 2 tool results
                placeholder="[cleared]",
            )
        ]
    )

    def mock_handler(req: ModelRequest) -> AIMessage:
        return AIMessage(content="mock response")

    middleware.wrap_model_call(request, mock_handler)

    # The oldest 3 tool results are cleared; the newest 2 survive intact.
    tool_messages = [m for m in request.messages if isinstance(m, ToolMessage)]
    assert sum(m.content == "[cleared]" for m in tool_messages) == 3
    assert sum(m.content != "[cleared]" for m in tool_messages) == 2
def test_new_api_or_conditions() -> None:
    """OR'd trigger conditions fire when any single condition is met."""
    call_id = "call-1"
    ai_msg = AIMessage(
        content="",
        tool_calls=[{"id": call_id, "name": "search", "args": {}}],
    )
    tool_msg = ToolMessage(content="x" * 200, tool_call_id=call_id)
    _, request = _make_state_and_request([ai_msg, tool_msg])

    # tokens >= 50 OR messages >= 100; only the token threshold is crossed.
    middleware = ContextEditingMiddleware(
        edits=[
            ClearToolUsesEdit(
                trigger=[("tokens", 50), ("messages", 100)],
                keep=("messages", 0),
                placeholder="[cleared]",
            )
        ]
    )

    def mock_handler(req: ModelRequest) -> AIMessage:
        return AIMessage(content="mock response")

    middleware.wrap_model_call(request, mock_handler)

    cleared = request.messages[1]
    assert isinstance(cleared, ToolMessage)
    assert cleared.content == "[cleared]"
def test_new_api_backwards_compatibility() -> None:
    """Plain integer trigger/keep values keep working as before."""
    call_id = "call-1"
    ai_msg = AIMessage(
        content="",
        tool_calls=[{"id": call_id, "name": "search", "args": {}}],
    )
    tool_msg = ToolMessage(content="x" * 200, tool_call_id=call_id)
    _, request = _make_state_and_request([ai_msg, tool_msg])

    # Legacy API: bare ints instead of ContextSize tuples.
    middleware = ContextEditingMiddleware(
        edits=[
            ClearToolUsesEdit(
                trigger=50,  # int
                keep=0,  # int
                placeholder="[cleared]",
            )
        ]
    )

    def mock_handler(req: ModelRequest) -> AIMessage:
        return AIMessage(content="mock response")

    middleware.wrap_model_call(request, mock_handler)

    cleared = request.messages[1]
    assert isinstance(cleared, ToolMessage)
    assert cleared.content == "[cleared]"
def test_new_api_validation_errors() -> None:
    """Invalid trigger/keep configurations raise ValueError at construction.

    Uses try/except/else instead of `assert False` so the "should have
    raised" failure path is not stripped when Python runs with -O, and
    the try body contains only the call expected to raise.
    """
    # Fraction outside (0, 1].
    try:
        ClearToolUsesEdit(trigger=("fraction", 1.5), keep=("messages", 3))
    except ValueError as e:
        assert "must be between 0 and 1" in str(e)
    else:
        raise AssertionError("Should have raised ValueError")

    # Non-positive token trigger.
    try:
        ClearToolUsesEdit(trigger=("tokens", -1), keep=("messages", 3))
    except ValueError as e:
        assert "must be greater than 0" in str(e)
    else:
        raise AssertionError("Should have raised ValueError")

    # Unknown context size kind.
    try:
        ClearToolUsesEdit(trigger=("invalid", 100), keep=("messages", 3))  # type: ignore[arg-type]
    except ValueError as e:
        assert "Unsupported context size type" in str(e)
    else:
        raise AssertionError("Should have raised ValueError")

View File

@@ -896,3 +896,171 @@ def test_summarization_middleware_is_safe_cutoff_at_end() -> None:
# Cutoff past the length should also be safe
assert middleware._is_safe_cutoff_point(messages, len(messages) + 5)
@pytest.mark.parametrize(
    ("trigger", "model_class", "tokens_per_message", "num_messages", "should_trigger"),
    [
        # ========================================================================
        # Test Case Group 1: Pure AND logic
        # Trigger condition: messages >= 5 AND tokens >= 300
        # ========================================================================
        (
            [[("messages", 5), ("tokens", 300)]],
            MockChatModel,
            50,  # tokens_per_message
            4,  # num_messages -> total: 4 msgs, 200 tokens
            False,  # Expected: neither condition met
        ),
        (
            [[("messages", 5), ("tokens", 300)]],
            MockChatModel,
            50,  # tokens_per_message
            5,  # num_messages -> total: 5 msgs, 250 tokens
            False,  # Expected: only messages condition met (tokens < 300)
        ),
        (
            [[("messages", 5), ("tokens", 300)]],
            MockChatModel,
            50,  # tokens_per_message
            7,  # num_messages -> total: 7 msgs, 350 tokens
            True,  # Expected: both conditions met
        ),
        # ========================================================================
        # Test Case Group 2: Mixed OR/AND logic
        # Trigger condition: messages >= 10 OR (tokens >= 500 AND messages >= 3)
        # ========================================================================
        (
            [("messages", 10), [("tokens", 500), ("messages", 3)]],
            MockChatModel,
            100,  # tokens_per_message
            2,  # num_messages -> total: 2 msgs, 200 tokens
            False,  # Expected: no conditions met
        ),
        (
            [("messages", 10), [("tokens", 500), ("messages", 3)]],
            MockChatModel,
            100,  # tokens_per_message
            5,  # num_messages -> total: 5 msgs, 500 tokens
            True,  # Expected: AND group satisfied (tokens >= 500 AND messages >= 3)
        ),
        (
            [("messages", 10), [("tokens", 500), ("messages", 3)]],
            MockChatModel,
            10,  # tokens_per_message
            10,  # num_messages -> total: 10 msgs, 100 tokens
            True,  # Expected: first OR condition satisfied (messages >= 10)
        ),
        # ========================================================================
        # Test Case Group 3: AND logic with fraction threshold
        # Trigger condition: tokens >= 400 AND fraction >= 0.5
        # Note: ProfileChatModel has max_input_tokens=1000, so 0.5 * 1000 = 500
        # ========================================================================
        (
            [[("tokens", 400), ("fraction", 0.5)]],
            ProfileChatModel,
            100,  # tokens_per_message
            4,  # num_messages -> total: 400 tokens
            False,  # Expected: 400 tokens < 500 fraction threshold
        ),
        (
            [[("tokens", 400), ("fraction", 0.5)]],
            ProfileChatModel,
            100,  # tokens_per_message
            6,  # num_messages -> total: 600 tokens
            True,  # Expected: 600 >= 400 AND 600 >= 500 (both conditions met)
        ),
    ],
)
def test_summarization_middleware_and_or_logic(
    trigger: list,
    model_class: type[BaseChatModel],
    tokens_per_message: int,
    num_messages: int,
    should_trigger: bool,
) -> None:
    """Test middleware with AND/OR logic combinations."""
    middleware = SummarizationMiddleware(
        model=model_class(),
        trigger=trigger,
        keep=("messages", 2),
    )
    # Override the counter so token totals are deterministic per message count.
    middleware.token_counter = lambda messages: len(messages) * tokens_per_message
    messages = [HumanMessage(content=str(i)) for i in range(num_messages)]
    state = {"messages": messages}
    result = middleware.before_model(state, None)
    # A non-None result means summarization was scheduled.
    if should_trigger:
        assert result is not None
    else:
        assert result is None
def test_summarization_middleware_nested_validation() -> None:
    """Nested AND groups are validated when the middleware is constructed."""
    model = MockChatModel()

    # A well-formed nested structure is accepted.
    accepted = SummarizationMiddleware(
        model=model, trigger=[("messages", 5), [("tokens", 100), ("messages", 3)]]
    )
    assert accepted.trigger is not None

    # An empty nested list is rejected.
    with pytest.raises(ValueError, match="Empty AND groups are not allowed"):
        SummarizationMiddleware(model=model, trigger=[("messages", 5), []])

    # Entries must be tuples or lists.
    with pytest.raises(ValueError, match="Trigger conditions must be tuples or lists"):
        SummarizationMiddleware(model=model, trigger=[("messages", 5), "invalid"])

    # Unknown kinds inside an AND group are rejected.
    with pytest.raises(ValueError, match="Unsupported context size type"):
        SummarizationMiddleware(model=model, trigger=[[("invalid", 100)]])
@pytest.mark.parametrize(
    ("tokens_per_message", "num_messages", "should_trigger"),
    [
        (100, 2, False),  # 2 msgs, 200 tokens: nothing satisfied
        (100, 5, True),  # 5 msgs, 500 tokens: (tokens >= 300 AND messages >= 5)
        (100, 4, False),  # 4 msgs, 400 tokens: no AND group fully satisfied
        (5, 20, True),  # 20 msgs, 100 tokens: messages >= 20
    ],
)
def test_summarization_middleware_complex_and_or(
    tokens_per_message: int,
    num_messages: int,
    should_trigger: bool,
) -> None:
    """Test complex AND/OR combinations.

    Trigger when:
    - messages >= 20 OR
    - (tokens >= 300 AND messages >= 5) OR
    - (tokens >= 500 AND messages >= 3)
    """
    middleware = SummarizationMiddleware(
        model=MockChatModel(),
        trigger=[
            ("messages", 20),
            [("tokens", 300), ("messages", 5)],
            [("tokens", 500), ("messages", 3)],
        ],
        keep=("messages", 2),
    )
    # Deterministic token totals: each message counts a fixed amount.
    middleware.token_counter = lambda msgs: len(msgs) * tokens_per_message

    history = [HumanMessage(content=str(i)) for i in range(num_messages)]
    outcome = middleware.before_model({"messages": history}, None)

    # A non-None outcome means summarization was scheduled.
    assert (outcome is not None) == should_trigger