Compare commits

...

7 Commits

Author SHA1 Message Date
Sydney Runkle
6ba1177f4f Merge branch 'master' into sr/looser-api-for-summarization 2025-11-17 08:55:42 -05:00
Sydney Runkle
62c05e09c1 readable tests 2025-11-14 14:34:24 -05:00
Sydney Runkle
83b9d9f810 parametrize 2025-11-14 14:31:28 -05:00
Sydney Runkle
6deee23d8d Revert "cleanup"
This reverts commit 690aabe8d4.
2025-11-14 14:14:32 -05:00
Sydney Runkle
690aabe8d4 cleanup 2025-11-14 14:12:14 -05:00
Sydney Runkle
80554df1e6 docs 2025-11-14 14:00:15 -05:00
Sydney Runkle
72c45e65e8 summarization edits 2025-11-14 13:56:26 -05:00
2 changed files with 275 additions and 26 deletions

View File

@@ -58,10 +58,19 @@ _DEFAULT_FALLBACK_MESSAGE_COUNT = 15
_SEARCH_RANGE_FOR_TOOL_PAIRS = 5
# Each context-size spec is a 2-tuple: a literal tag naming the unit, followed by
# the amount expressed in that unit.
ContextFraction = tuple[Literal["fraction"], float]
"""Tuple specifying context size as a fraction of the model's context window."""
ContextTokens = tuple[Literal["tokens"], int]
"""Tuple specifying context size as a number of tokens."""
ContextMessages = tuple[Literal["messages"], int]
"""Tuple specifying context size as a number of messages."""
# Union of the three tagged forms above. It is also the leaf type of
# `ContextCondition` below, so it describes trigger thresholds as well.
ContextSize = ContextFraction | ContextTokens | ContextMessages
"""Context size tuple to specify how much history to preserve."""
# Either a single `ContextSize`, or a list whose items are `ContextSize` tuples
# or lists of `ContextSize` tuples.
# NOTE: the alias does not refer to itself, so exactly one level of nesting is
# expressible (a top-level OR list containing AND groups); deeper nesting is
# not covered by this type.
ContextCondition = ContextSize | list[ContextSize | list[ContextSize]]
"""Recursive type to support nested AND/OR conditions
Top-level list = OR logic, nested list = AND logic."""
class SummarizationMiddleware(AgentMiddleware):
@@ -76,7 +85,7 @@ class SummarizationMiddleware(AgentMiddleware):
self,
model: str | BaseChatModel,
*,
trigger: ContextSize | list[ContextSize] | None = None,
trigger: ContextCondition | None = None,
keep: ContextSize = ("messages", _DEFAULT_MESSAGES_TO_KEEP),
token_counter: TokenCounter = count_tokens_approximately,
summary_prompt: str = DEFAULT_SUMMARY_PROMPT,
@@ -87,13 +96,16 @@ class SummarizationMiddleware(AgentMiddleware):
Args:
model: The language model to use for generating summaries.
trigger: One or more thresholds that trigger summarization.
Provide a single `ContextSize` tuple or a list of tuples, in which case
summarization runs when any threshold is breached.
Examples: `("messages", 50)`, `("tokens", 3000)`, `[("fraction", 0.8),
("messages", 100)]`.
trigger: One or more thresholds that trigger summarization. Supports flexible
AND/OR logic via nested lists. Top-level list items are combined with OR,
nested lists are combined with AND. Examples:
- Single condition: `("messages", 50)`
- OR conditions: `[("tokens", 3000), ("messages", 100)]` (triggers when
tokens >= 3000 OR messages >= 100)
- AND conditions: `[[("tokens", 500), ("fraction", 0.8)]]` (a nested list inside
the top-level list; triggers when tokens >= 500 AND fraction >= 0.8)
- Mixed AND/OR: `[("messages", 10), [("tokens", 500), ("fraction", 0.8)]]`
(triggers when messages >= 10 OR (tokens >= 500 AND fraction >= 0.8))
keep: Context retention policy applied after summarization.
Provide a `ContextSize` tuple to specify how much history to preserve.
@@ -137,13 +149,15 @@ class SummarizationMiddleware(AgentMiddleware):
self.model = model
if trigger is None:
self.trigger: ContextSize | list[ContextSize] | None = None
trigger_conditions: list[ContextSize] = []
self.trigger: ContextCondition | None = None
trigger_conditions: list[ContextSize | list[ContextSize]] = []
elif isinstance(trigger, list):
validated_list = [self._validate_context_size(item, "trigger") for item in trigger]
# Validate and normalize nested structure
validated_list = self._validate_trigger_conditions(trigger)
self.trigger = validated_list
trigger_conditions = validated_list
else:
# Single ContextSize tuple
validated = self._validate_context_size(trigger, "trigger")
self.trigger = validated
trigger_conditions = [validated]
@@ -154,7 +168,7 @@ class SummarizationMiddleware(AgentMiddleware):
self.summary_prompt = summary_prompt
self.trim_tokens_to_summarize = trim_tokens_to_summarize
requires_profile = any(condition[0] == "fraction" for condition in self._trigger_conditions)
requires_profile = self._requires_profile(self._trigger_conditions)
if self.keep[0] == "fraction":
requires_profile = True
if requires_profile and self._get_profile_limits() is None:
@@ -220,24 +234,52 @@ class SummarizationMiddleware(AgentMiddleware):
}
def _should_summarize(self, messages: list[AnyMessage], total_tokens: int) -> bool:
    """Determine whether summarization should run for the current usage.

    Trigger conditions are evaluated with AND/OR logic:

    - Top-level items are OR'd together.
    - A nested list is an AND group: every condition in it must be satisfied.

    Args:
        messages: Messages to measure against `("messages", n)` thresholds.
        total_tokens: Token count to measure against `("tokens", n)` and
            `("fraction", f)` thresholds.

    Returns:
        `True` as soon as one top-level condition (or one whole AND group) is
        satisfied; `False` otherwise, including when no trigger conditions are
        configured.
    """
    if not self._trigger_conditions:
        return False

    for condition in self._trigger_conditions:
        # A list is an AND group; anything else is a single `(kind, value)`
        # tuple. Never unpack `condition` directly here: a two-item AND group
        # would unpack "successfully" into two tuples and be mis-evaluated.
        if isinstance(condition, list):
            satisfied = self._check_and_group(condition, messages, total_tokens)
        else:
            satisfied = self._check_single_condition(condition, messages, total_tokens)
        if satisfied:
            # OR across top-level entries: the first hit decides.
            return True
    return False
def _check_and_group(
self, and_group: list[ContextSize], messages: list[AnyMessage], total_tokens: int
) -> bool:
"""Check if all conditions in an AND group are satisfied."""
for condition in and_group:
if not self._check_single_condition(condition, messages, total_tokens):
return False
return True
def _check_single_condition(
self, condition: ContextSize, messages: list[AnyMessage], total_tokens: int
) -> bool:
"""Check if a single condition is satisfied."""
kind, value = condition
if kind == "messages":
return len(messages) >= value
if kind == "tokens":
return total_tokens >= value
if kind == "fraction":
max_input_tokens = self._get_profile_limits()
if max_input_tokens is None:
return False
threshold = int(max_input_tokens * value)
if threshold <= 0:
threshold = 1
return total_tokens >= threshold
return False
def _determine_cutoff_index(self, messages: list[AnyMessage]) -> int:
@@ -337,6 +379,45 @@ class SummarizationMiddleware(AgentMiddleware):
raise ValueError(msg)
return context
def _validate_trigger_conditions(
self, conditions: list[Any]
) -> list[ContextSize | list[ContextSize]]:
"""Validate and normalize trigger conditions with nested AND/OR logic.
Args:
conditions: List of ContextSize tuples or nested lists of ContextSize tuples.
Returns:
Validated list where top-level items are OR'd and nested lists are AND'd.
"""
validated: list[ContextSize | list[ContextSize]] = []
for item in conditions:
if isinstance(item, tuple):
# Single condition (tuple)
validated.append(self._validate_context_size(item, "trigger"))
elif isinstance(item, list):
# AND group (nested list)
if not item:
msg = "Empty AND groups are not allowed in trigger conditions."
raise ValueError(msg)
and_group = [self._validate_context_size(cond, "trigger") for cond in item]
validated.append(and_group)
else:
msg = f"Trigger conditions must be tuples or lists, got {type(item).__name__}."
raise ValueError(msg)
return validated
def _requires_profile(self, conditions: list[ContextSize | list[ContextSize]]) -> bool:
"""Check if any condition requires model profile information."""
for condition in conditions:
if isinstance(condition, list):
# AND group
if any(c[0] == "fraction" for c in condition):
return True
elif condition[0] == "fraction":
return True
return False
def _build_new_messages(self, summary: str) -> list[HumanMessage]:
return [
HumanMessage(content=f"Here is a summary of the conversation to date:\n\n{summary}")

View File

@@ -896,3 +896,171 @@ def test_summarization_middleware_is_safe_cutoff_at_end() -> None:
# Cutoff past the length should also be safe
assert middleware._is_safe_cutoff_point(messages, len(messages) + 5)
@pytest.mark.parametrize(
    ("trigger", "model_class", "tokens_per_message", "num_messages", "should_trigger"),
    [
        # --- Group 1: a single AND group -------------------------------------
        # Trigger: messages >= 5 AND tokens >= 300
        # 4 msgs x 50 tokens = 200 tokens -> neither threshold reached
        ([[("messages", 5), ("tokens", 300)]], MockChatModel, 50, 4, False),
        # 5 msgs x 50 tokens = 250 tokens -> messages reached, tokens not
        ([[("messages", 5), ("tokens", 300)]], MockChatModel, 50, 5, False),
        # 7 msgs x 50 tokens = 350 tokens -> both thresholds reached
        ([[("messages", 5), ("tokens", 300)]], MockChatModel, 50, 7, True),
        # --- Group 2: OR of a single condition and an AND group --------------
        # Trigger: messages >= 10 OR (tokens >= 500 AND messages >= 3)
        # 2 msgs x 100 tokens = 200 tokens -> nothing satisfied
        ([("messages", 10), [("tokens", 500), ("messages", 3)]], MockChatModel, 100, 2, False),
        # 5 msgs x 100 tokens = 500 tokens -> AND group satisfied
        ([("messages", 10), [("tokens", 500), ("messages", 3)]], MockChatModel, 100, 5, True),
        # 10 msgs x 10 tokens = 100 tokens -> plain messages condition satisfied
        ([("messages", 10), [("tokens", 500), ("messages", 3)]], MockChatModel, 10, 10, True),
        # --- Group 3: AND group containing a fraction threshold --------------
        # Trigger: tokens >= 400 AND fraction >= 0.5
        # ProfileChatModel has max_input_tokens=1000, so the fraction threshold is 500
        # 4 msgs x 100 tokens = 400 tokens -> below the 500-token fraction threshold
        ([[("tokens", 400), ("fraction", 0.5)]], ProfileChatModel, 100, 4, False),
        # 6 msgs x 100 tokens = 600 tokens -> 600 >= 400 and 600 >= 500
        ([[("tokens", 400), ("fraction", 0.5)]], ProfileChatModel, 100, 6, True),
    ],
)
def test_summarization_middleware_and_or_logic(
    trigger: list,
    model_class: type[BaseChatModel],
    tokens_per_message: int,
    num_messages: int,
    should_trigger: bool,
) -> None:
    """Check `before_model` against AND, OR and fraction trigger combinations."""
    history = [HumanMessage(content=str(index)) for index in range(num_messages)]

    middleware = SummarizationMiddleware(
        model=model_class(),
        trigger=trigger,
        keep=("messages", 2),
    )
    # Deterministic counter: every message costs exactly `tokens_per_message`.
    middleware.token_counter = lambda messages: len(messages) * tokens_per_message

    result = middleware.before_model({"messages": history}, None)

    # A non-None result means summarization was triggered.
    assert (result is not None) is should_trigger
def test_summarization_middleware_nested_validation() -> None:
    """Exercise constructor validation of nested AND groups in `trigger`."""
    chat_model = MockChatModel()

    # A well-formed mix of a single condition and an AND group is accepted.
    accepted = SummarizationMiddleware(
        model=chat_model,
        trigger=[("messages", 5), [("tokens", 100), ("messages", 3)]],
    )
    assert accepted.trigger is not None

    # An AND group with no members is rejected.
    with pytest.raises(ValueError, match="Empty AND groups are not allowed"):
        SummarizationMiddleware(model=chat_model, trigger=[("messages", 5), []])

    # A top-level item that is neither a tuple nor a list is rejected.
    with pytest.raises(ValueError, match="Trigger conditions must be tuples or lists"):
        SummarizationMiddleware(model=chat_model, trigger=[("messages", 5), "invalid"])

    # Members of an AND group get the same per-tuple validation.
    with pytest.raises(ValueError, match="Unsupported context size type"):
        SummarizationMiddleware(model=chat_model, trigger=[[("invalid", 100)]])
@pytest.mark.parametrize(
    ("tokens_per_message", "num_messages", "should_trigger"),
    [
        # 2 messages, 200 tokens: no condition satisfied
        (100, 2, False),
        # 5 messages, 500 tokens: (tokens >= 300 AND messages >= 5) satisfied
        (100, 5, True),
        # 4 messages, 400 tokens: too few messages for the 300-token group and
        # too few tokens for the 500-token group
        (100, 4, False),
        # 20 messages, 100 tokens: plain messages >= 20 condition satisfied
        (5, 20, True),
    ],
)
def test_summarization_middleware_complex_and_or(
    tokens_per_message: int,
    num_messages: int,
    should_trigger: bool,
) -> None:
    """Test a trigger that ORs one plain condition with two AND groups.

    Trigger when:
    - messages >= 20 OR
    - (tokens >= 300 AND messages >= 5) OR
    - (tokens >= 500 AND messages >= 3)
    """
    trigger_spec = [
        ("messages", 20),
        [("tokens", 300), ("messages", 5)],
        [("tokens", 500), ("messages", 3)],
    ]
    history = [HumanMessage(content=str(index)) for index in range(num_messages)]

    middleware = SummarizationMiddleware(
        model=MockChatModel(),
        trigger=trigger_spec,
        keep=("messages", 2),
    )
    # Deterministic counter: every message costs exactly `tokens_per_message`.
    middleware.token_counter = lambda messages: len(messages) * tokens_per_message

    result = middleware.before_model({"messages": history}, None)

    # A non-None result means summarization was triggered.
    assert (result is not None) is should_trigger