fix: simplify summarization cutoff logic (#34195)

This PR changes how we find the cutoff for summarization, summarizing
content more eagerly if the initial cutoff point isn't safe (ie, would
break apart AI + tool message pairs)

This new algorithm is quite simple - it looks at the initial cutoff
point, if it's not safe, moves forward through the message list until it
finds the first non tool message.

For example:

```
H
AI
TM
--- theoretical cutoff based keep=('messages', 3)
TM
AI
TM
```

```
H
AI
TM
TM
--- actual cutoff, more aggressive summarization
AI
TM
```
This commit is contained in:
Sydney Runkle
2025-12-04 12:44:50 -05:00
committed by GitHub
parent 1ad9de4b45
commit 3030ffc248
2 changed files with 128 additions and 199 deletions

View File

@@ -7,7 +7,6 @@ from functools import partial
from typing import Any, Literal, cast
from langchain_core.messages import (
AIMessage,
AnyMessage,
MessageLikeRepresentation,
RemoveMessage,
@@ -56,7 +55,6 @@ Messages to summarize:
_DEFAULT_MESSAGES_TO_KEEP = 20
_DEFAULT_TRIM_TOKEN_LIMIT = 4000
_DEFAULT_FALLBACK_MESSAGE_COUNT = 15
_SEARCH_RANGE_FOR_TOOL_PAIRS = 5
ContextFraction = tuple[Literal["fraction"], float]
"""Fraction of model's maximum input tokens.
@@ -397,11 +395,8 @@ class SummarizationMiddleware(AgentMiddleware):
return 0
cutoff_candidate = len(messages) - 1
for i in range(cutoff_candidate, -1, -1):
if self._is_safe_cutoff_point(messages, i):
return i
return 0
# Advance past any ToolMessages to avoid splitting AI/Tool pairs
return self._find_safe_cutoff_point(messages, cutoff_candidate)
def _get_profile_limits(self) -> int | None:
"""Retrieve max input token limit from the model profile."""
@@ -463,67 +458,26 @@ class SummarizationMiddleware(AgentMiddleware):
Returns the index where messages can be safely cut without separating
related AI and Tool messages. Returns `0` if no safe cutoff is found.
This is aggressive with summarization - if the target cutoff lands in the
middle of tool messages, we advance past all of them (summarizing more).
"""
if len(messages) <= messages_to_keep:
return 0
target_cutoff = len(messages) - messages_to_keep
return self._find_safe_cutoff_point(messages, target_cutoff)
for i in range(target_cutoff, -1, -1):
if self._is_safe_cutoff_point(messages, i):
return i
def _find_safe_cutoff_point(self, messages: list[AnyMessage], cutoff_index: int) -> int:
"""Find a safe cutoff point that doesn't split AI/Tool message pairs.
return 0
def _is_safe_cutoff_point(self, messages: list[AnyMessage], cutoff_index: int) -> bool:
"""Check if cutting at index would separate AI/Tool message pairs."""
if cutoff_index >= len(messages):
return True
search_start = max(0, cutoff_index - _SEARCH_RANGE_FOR_TOOL_PAIRS)
search_end = min(len(messages), cutoff_index + _SEARCH_RANGE_FOR_TOOL_PAIRS)
for i in range(search_start, search_end):
if not self._has_tool_calls(messages[i]):
continue
tool_call_ids = self._extract_tool_call_ids(cast("AIMessage", messages[i]))
if self._cutoff_separates_tool_pair(messages, i, cutoff_index, tool_call_ids):
return False
return True
def _has_tool_calls(self, message: AnyMessage) -> bool:
"""Check if message is an AI message with tool calls."""
return (
isinstance(message, AIMessage) and hasattr(message, "tool_calls") and message.tool_calls # type: ignore[return-value]
)
def _extract_tool_call_ids(self, ai_message: AIMessage) -> set[str]:
"""Extract tool call IDs from an AI message."""
tool_call_ids = set()
for tc in ai_message.tool_calls:
call_id = tc.get("id") if isinstance(tc, dict) else getattr(tc, "id", None)
if call_id is not None:
tool_call_ids.add(call_id)
return tool_call_ids
def _cutoff_separates_tool_pair(
self,
messages: list[AnyMessage],
ai_message_index: int,
cutoff_index: int,
tool_call_ids: set[str],
) -> bool:
"""Check if cutoff separates an AI message from its corresponding tool messages."""
for j in range(ai_message_index + 1, len(messages)):
message = messages[j]
if isinstance(message, ToolMessage) and message.tool_call_id in tool_call_ids:
ai_before_cutoff = ai_message_index < cutoff_index
tool_before_cutoff = j < cutoff_index
if ai_before_cutoff != tool_before_cutoff:
return True
return False
If the message at cutoff_index is a ToolMessage, advance until we find
a non-ToolMessage. This ensures we never cut in the middle of parallel
tool call responses.
"""
while cutoff_index < len(messages) and isinstance(messages[cutoff_index], ToolMessage):
cutoff_index += 1
return cutoff_index
def _create_summary(self, messages_to_summarize: list[AnyMessage]) -> str:
"""Generate summary for the given messages."""

View File

@@ -121,46 +121,6 @@ def test_summarization_middleware_helper_methods() -> None:
assert "Here is a summary of the conversation to date:" in new_messages[0].content
assert summary in new_messages[0].content
# Test tool call detection
ai_message_no_tools = AIMessage(content="Hello")
assert not middleware._has_tool_calls(ai_message_no_tools)
ai_message_with_tools = AIMessage(
content="Hello", tool_calls=[{"name": "test", "args": {}, "id": "1"}]
)
assert middleware._has_tool_calls(ai_message_with_tools)
human_message = HumanMessage(content="Hello")
assert not middleware._has_tool_calls(human_message)
def test_summarization_middleware_tool_call_safety() -> None:
"""Test SummarizationMiddleware tool call safety logic."""
model = FakeToolCallingModel()
middleware = SummarizationMiddleware(
model=model, trigger=("tokens", 1000), keep=("messages", 3)
)
# Test safe cutoff point detection with tool calls
messages = [
HumanMessage(content="1"),
AIMessage(content="2", tool_calls=[{"name": "test", "args": {}, "id": "1"}]),
ToolMessage(content="3", tool_call_id="1"),
HumanMessage(content="4"),
]
# Safe cutoff (doesn't separate AI/Tool pair)
is_safe = middleware._is_safe_cutoff_point(messages, 0)
assert is_safe is True
# Unsafe cutoff (separates AI/Tool pair)
is_safe = middleware._is_safe_cutoff_point(messages, 2)
assert is_safe is False
# Test tool call ID extraction
ids = middleware._extract_tool_call_ids(messages[1])
assert ids == {"1"}
def test_summarization_middleware_summary_creation() -> None:
"""Test SummarizationMiddleware summary creation."""
@@ -315,8 +275,8 @@ def test_summarization_middleware_profile_inference_triggers_summary() -> None:
]
def test_summarization_middleware_token_retention_pct_respects_tool_pairs() -> None:
"""Ensure token retention keeps pairs together even if exceeding target tokens."""
def test_summarization_middleware_token_retention_advances_past_tool_messages() -> None:
"""Ensure token retention advances past tool messages for aggressive summarization."""
def token_counter(messages: list[AnyMessage]) -> int:
return sum(len(getattr(message, "content", "")) for message in messages)
@@ -328,6 +288,10 @@ def test_summarization_middleware_token_retention_pct_respects_tool_pairs() -> N
)
middleware.token_counter = token_counter
# Total tokens: 300 + 200 + 50 + 180 + 160 = 890
# Target keep: 500 tokens (50% of 1000)
# Binary search finds cutoff around index 2 (ToolMessage)
# We advance past it to index 3 (HumanMessage)
messages: list[AnyMessage] = [
HumanMessage(content="H" * 300),
AIMessage(
@@ -344,13 +308,14 @@ def test_summarization_middleware_token_retention_pct_respects_tool_pairs() -> N
assert result is not None
preserved_messages = result["messages"][2:]
assert preserved_messages == messages[1:]
# With aggressive summarization, we advance past the ToolMessage
# So we preserve messages from index 3 onward (the two HumanMessages)
assert preserved_messages == messages[3:]
# Verify preserved tokens are within budget
target_token_count = int(1000 * 0.5)
preserved_tokens = middleware.token_counter(preserved_messages)
# Tool pair retention can exceed the target token count but should keep the pair intact.
assert preserved_tokens > target_token_count
assert preserved_tokens <= target_token_count
def test_summarization_middleware_missing_profile() -> None:
@@ -692,95 +657,38 @@ def test_summarization_middleware_binary_search_edge_cases() -> None:
assert cutoff == 0
def test_summarization_middleware_tool_call_extraction_edge_cases() -> None:
"""Test tool call ID extraction with various message formats."""
model = FakeToolCallingModel()
middleware = SummarizationMiddleware(model=model, trigger=("messages", 5))
# Test with dict-style tool calls
ai_message_dict = AIMessage(
content="test", tool_calls=[{"name": "tool1", "args": {}, "id": "id1"}]
)
ids = middleware._extract_tool_call_ids(ai_message_dict)
assert ids == {"id1"}
# Test with multiple tool calls
ai_message_multiple = AIMessage(
content="test",
tool_calls=[
{"name": "tool1", "args": {}, "id": "id1"},
{"name": "tool2", "args": {}, "id": "id2"},
],
)
ids = middleware._extract_tool_call_ids(ai_message_multiple)
assert ids == {"id1", "id2"}
# Test with empty tool calls list
ai_message_empty = AIMessage(content="test", tool_calls=[])
ids = middleware._extract_tool_call_ids(ai_message_empty)
assert len(ids) == 0
def test_summarization_middleware_complex_tool_pair_scenarios() -> None:
"""Test complex tool call pairing scenarios."""
model = FakeToolCallingModel()
middleware = SummarizationMiddleware(model=model, trigger=("messages", 5), keep=("messages", 3))
# Test with multiple AI messages with tool calls
messages = [
HumanMessage(content="msg1"),
AIMessage(content="ai1", tool_calls=[{"name": "tool1", "args": {}, "id": "call1"}]),
ToolMessage(content="result1", tool_call_id="call1"),
HumanMessage(content="msg2"),
AIMessage(content="ai2", tool_calls=[{"name": "tool2", "args": {}, "id": "call2"}]),
ToolMessage(content="result2", tool_call_id="call2"),
HumanMessage(content="msg3"),
]
# Test cutoff at index 1 - unsafe (separates first AI/Tool pair)
assert not middleware._is_safe_cutoff_point(messages, 2)
# Test cutoff at index 3 - safe (keeps first pair together)
assert middleware._is_safe_cutoff_point(messages, 3)
# Test cutoff at index 5 - unsafe (separates second AI/Tool pair)
assert not middleware._is_safe_cutoff_point(messages, 5)
# Test _cutoff_separates_tool_pair directly
assert middleware._cutoff_separates_tool_pair(messages, 1, 2, {"call1"})
assert not middleware._cutoff_separates_tool_pair(messages, 1, 0, {"call1"})
assert not middleware._cutoff_separates_tool_pair(messages, 1, 3, {"call1"})
def test_summarization_middleware_tool_call_in_search_range() -> None:
"""Test tool call safety with messages at edge of search range."""
def test_summarization_middleware_find_safe_cutoff_point() -> None:
"""Test _find_safe_cutoff_point finds safe cutoff past ToolMessages."""
model = FakeToolCallingModel()
middleware = SummarizationMiddleware(
model=model, trigger=("messages", 10), keep=("messages", 2)
)
# Create messages with tool pair separated by some distance
# Search range is 5, so messages within 5 positions of cutoff are checked
messages = [
messages: list[AnyMessage] = [
HumanMessage(content="msg1"),
HumanMessage(content="msg2"),
AIMessage(content="ai", tool_calls=[{"name": "tool", "args": {}, "id": "call1"}]),
HumanMessage(content="msg3"),
HumanMessage(content="msg4"),
ToolMessage(content="result", tool_call_id="call1"),
HumanMessage(content="msg6"),
ToolMessage(content="result1", tool_call_id="call1"),
ToolMessage(content="result2", tool_call_id="call2"),
HumanMessage(content="msg2"),
]
# Cutoff at index 3 would separate: [0,1,2] from [3,4,5,6]
# AI at index 2 is before cutoff, Tool at index 5 is after cutoff - unsafe
assert not middleware._is_safe_cutoff_point(messages, 3)
# Starting at a non-ToolMessage returns the same index
assert middleware._find_safe_cutoff_point(messages, 0) == 0
assert middleware._find_safe_cutoff_point(messages, 1) == 1
# Cutoff at index 6 keeps AI and Tool both in summarized section
assert middleware._is_safe_cutoff_point(messages, 6)
# Starting at a ToolMessage advances to the next non-ToolMessage
assert middleware._find_safe_cutoff_point(messages, 2) == 4
assert middleware._find_safe_cutoff_point(messages, 3) == 4
# Cutoff at index 0 or 1 also safe - both AI and Tool in preserved section
assert middleware._is_safe_cutoff_point(messages, 0)
assert middleware._is_safe_cutoff_point(messages, 1)
# Starting at the HumanMessage after tools returns that index
assert middleware._find_safe_cutoff_point(messages, 4) == 4
# Starting past the end returns the index unchanged
assert middleware._find_safe_cutoff_point(messages, 5) == 5
# Cutoff at or past length stays the same
assert middleware._find_safe_cutoff_point(messages, len(messages)) == len(messages)
assert middleware._find_safe_cutoff_point(messages, len(messages) + 5) == len(messages) + 5
def test_summarization_middleware_zero_and_negative_target_tokens() -> None:
@@ -880,20 +788,6 @@ def test_summarization_middleware_fraction_trigger_with_no_profile() -> None:
middleware._get_profile_limits = original_method
def test_summarization_middleware_is_safe_cutoff_at_end() -> None:
"""Test _is_safe_cutoff_point when cutoff is at or past the end."""
model = FakeToolCallingModel()
middleware = SummarizationMiddleware(model=model, trigger=("messages", 5))
messages = [HumanMessage(content=str(i)) for i in range(5)]
# Cutoff at exactly the length should be safe
assert middleware._is_safe_cutoff_point(messages, len(messages))
# Cutoff past the length should also be safe
assert middleware._is_safe_cutoff_point(messages, len(messages) + 5)
def test_summarization_adjust_token_counts() -> None:
test_message = HumanMessage(content="a" * 12)
@@ -909,3 +803,84 @@ def test_summarization_adjust_token_counts() -> None:
count_2 = middleware.token_counter([test_message])
assert count_1 != count_2
def test_summarization_middleware_many_parallel_tool_calls_safety() -> None:
"""Test cutoff safety with many parallel tool calls extending beyond old search range."""
middleware = SummarizationMiddleware(
model=MockChatModel(), trigger=("messages", 15), keep=("messages", 5)
)
tool_calls = [{"name": f"tool_{i}", "args": {}, "id": f"call_{i}"} for i in range(10)]
human_message = HumanMessage(content="calling 10 tools")
ai_message = AIMessage(content="calling 10 tools", tool_calls=tool_calls)
tool_messages = [
ToolMessage(content=f"result_{i}", tool_call_id=f"call_{i}") for i in range(10)
]
messages: list[AnyMessage] = [human_message, ai_message, *tool_messages]
# Cutoff at index 7 (a ToolMessage) advances to index 12 (end of messages)
assert middleware._find_safe_cutoff_point(messages, 7) == 12
# Any cutoff pointing at a ToolMessage (indices 2-11) advances to index 12
for i in range(2, 12):
assert middleware._find_safe_cutoff_point(messages, i) == 12
# Cutoff at index 0, 1 (before tool messages) stays the same
assert middleware._find_safe_cutoff_point(messages, 0) == 0
assert middleware._find_safe_cutoff_point(messages, 1) == 1
def test_summarization_middleware_find_safe_cutoff_advances_past_tools() -> None:
"""Test _find_safe_cutoff advances past ToolMessages to find safe cutoff."""
middleware = SummarizationMiddleware(
model=MockChatModel(), trigger=("messages", 10), keep=("messages", 3)
)
# Messages: [Human, AI, Tool, Tool, Tool, Human]
messages: list[AnyMessage] = [
HumanMessage(content="msg1"),
AIMessage(
content="ai",
tool_calls=[
{"name": "tool1", "args": {}, "id": "call1"},
{"name": "tool2", "args": {}, "id": "call2"},
{"name": "tool3", "args": {}, "id": "call3"},
],
),
ToolMessage(content="result1", tool_call_id="call1"),
ToolMessage(content="result2", tool_call_id="call2"),
ToolMessage(content="result3", tool_call_id="call3"),
HumanMessage(content="msg2"),
]
# Target cutoff index is len(messages) - messages_to_keep = 6 - 3 = 3
# Index 3 is a ToolMessage, so we advance past the tool sequence to index 5
cutoff = middleware._find_safe_cutoff(messages, messages_to_keep=3)
assert cutoff == 5
# With messages_to_keep=2, target cutoff index is 6 - 2 = 4
# Index 4 is a ToolMessage, so we advance past the tool sequence to index 5
# This is aggressive - we keep only 1 message instead of 2
cutoff = middleware._find_safe_cutoff(messages, messages_to_keep=2)
assert cutoff == 5
def test_summarization_middleware_cutoff_at_start_of_tool_sequence() -> None:
"""Test cutoff when target lands exactly at the first ToolMessage."""
middleware = SummarizationMiddleware(
model=MockChatModel(), trigger=("messages", 8), keep=("messages", 4)
)
messages: list[AnyMessage] = [
HumanMessage(content="msg1"),
HumanMessage(content="msg2"),
AIMessage(content="ai", tool_calls=[{"name": "tool", "args": {}, "id": "call1"}]),
ToolMessage(content="result", tool_call_id="call1"),
HumanMessage(content="msg3"),
HumanMessage(content="msg4"),
]
# Target cutoff index is len(messages) - messages_to_keep = 6 - 4 = 2
# Index 2 is an AIMessage (safe cutoff point), so no adjustment needed
cutoff = middleware._find_safe_cutoff(messages, messages_to_keep=4)
assert cutoff == 2