mirror of
https://github.com/hwchase17/langchain.git
synced 2026-04-25 01:16:55 +00:00
fix: simplify summarization cutoff logic (#34195)
This PR changes how we find the cutoff for summarization: content is
summarized more eagerly if the initial cutoff point isn't safe (i.e., it
would break apart AI + tool message pairs).
The new algorithm is simple: it looks at the initial cutoff point and, if
that point is not safe, moves forward through the message list until it
finds the first non-tool message.
For example:
```
H
AI
TM
--- theoretical cutoff based on keep=('messages', 3)
TM
AI
TM
```
```
H
AI
TM
TM
--- actual cutoff, more aggressive summarization
AI
TM
```
This commit is contained in:
@@ -7,7 +7,6 @@ from functools import partial
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AnyMessage,
|
||||
MessageLikeRepresentation,
|
||||
RemoveMessage,
|
||||
@@ -56,7 +55,6 @@ Messages to summarize:
|
||||
_DEFAULT_MESSAGES_TO_KEEP = 20
|
||||
_DEFAULT_TRIM_TOKEN_LIMIT = 4000
|
||||
_DEFAULT_FALLBACK_MESSAGE_COUNT = 15
|
||||
_SEARCH_RANGE_FOR_TOOL_PAIRS = 5
|
||||
|
||||
ContextFraction = tuple[Literal["fraction"], float]
|
||||
"""Fraction of model's maximum input tokens.
|
||||
@@ -397,11 +395,8 @@ class SummarizationMiddleware(AgentMiddleware):
|
||||
return 0
|
||||
cutoff_candidate = len(messages) - 1
|
||||
|
||||
for i in range(cutoff_candidate, -1, -1):
|
||||
if self._is_safe_cutoff_point(messages, i):
|
||||
return i
|
||||
|
||||
return 0
|
||||
# Advance past any ToolMessages to avoid splitting AI/Tool pairs
|
||||
return self._find_safe_cutoff_point(messages, cutoff_candidate)
|
||||
|
||||
def _get_profile_limits(self) -> int | None:
|
||||
"""Retrieve max input token limit from the model profile."""
|
||||
@@ -463,67 +458,26 @@ class SummarizationMiddleware(AgentMiddleware):
|
||||
|
||||
Returns the index where messages can be safely cut without separating
|
||||
related AI and Tool messages. Returns `0` if no safe cutoff is found.
|
||||
|
||||
This is aggressive with summarization - if the target cutoff lands in the
|
||||
middle of tool messages, we advance past all of them (summarizing more).
|
||||
"""
|
||||
if len(messages) <= messages_to_keep:
|
||||
return 0
|
||||
|
||||
target_cutoff = len(messages) - messages_to_keep
|
||||
return self._find_safe_cutoff_point(messages, target_cutoff)
|
||||
|
||||
for i in range(target_cutoff, -1, -1):
|
||||
if self._is_safe_cutoff_point(messages, i):
|
||||
return i
|
||||
def _find_safe_cutoff_point(self, messages: list[AnyMessage], cutoff_index: int) -> int:
|
||||
"""Find a safe cutoff point that doesn't split AI/Tool message pairs.
|
||||
|
||||
return 0
|
||||
|
||||
def _is_safe_cutoff_point(self, messages: list[AnyMessage], cutoff_index: int) -> bool:
|
||||
"""Check if cutting at index would separate AI/Tool message pairs."""
|
||||
if cutoff_index >= len(messages):
|
||||
return True
|
||||
|
||||
search_start = max(0, cutoff_index - _SEARCH_RANGE_FOR_TOOL_PAIRS)
|
||||
search_end = min(len(messages), cutoff_index + _SEARCH_RANGE_FOR_TOOL_PAIRS)
|
||||
|
||||
for i in range(search_start, search_end):
|
||||
if not self._has_tool_calls(messages[i]):
|
||||
continue
|
||||
|
||||
tool_call_ids = self._extract_tool_call_ids(cast("AIMessage", messages[i]))
|
||||
if self._cutoff_separates_tool_pair(messages, i, cutoff_index, tool_call_ids):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _has_tool_calls(self, message: AnyMessage) -> bool:
|
||||
"""Check if message is an AI message with tool calls."""
|
||||
return (
|
||||
isinstance(message, AIMessage) and hasattr(message, "tool_calls") and message.tool_calls # type: ignore[return-value]
|
||||
)
|
||||
|
||||
def _extract_tool_call_ids(self, ai_message: AIMessage) -> set[str]:
|
||||
"""Extract tool call IDs from an AI message."""
|
||||
tool_call_ids = set()
|
||||
for tc in ai_message.tool_calls:
|
||||
call_id = tc.get("id") if isinstance(tc, dict) else getattr(tc, "id", None)
|
||||
if call_id is not None:
|
||||
tool_call_ids.add(call_id)
|
||||
return tool_call_ids
|
||||
|
||||
def _cutoff_separates_tool_pair(
|
||||
self,
|
||||
messages: list[AnyMessage],
|
||||
ai_message_index: int,
|
||||
cutoff_index: int,
|
||||
tool_call_ids: set[str],
|
||||
) -> bool:
|
||||
"""Check if cutoff separates an AI message from its corresponding tool messages."""
|
||||
for j in range(ai_message_index + 1, len(messages)):
|
||||
message = messages[j]
|
||||
if isinstance(message, ToolMessage) and message.tool_call_id in tool_call_ids:
|
||||
ai_before_cutoff = ai_message_index < cutoff_index
|
||||
tool_before_cutoff = j < cutoff_index
|
||||
if ai_before_cutoff != tool_before_cutoff:
|
||||
return True
|
||||
return False
|
||||
If the message at cutoff_index is a ToolMessage, advance until we find
|
||||
a non-ToolMessage. This ensures we never cut in the middle of parallel
|
||||
tool call responses.
|
||||
"""
|
||||
while cutoff_index < len(messages) and isinstance(messages[cutoff_index], ToolMessage):
|
||||
cutoff_index += 1
|
||||
return cutoff_index
|
||||
|
||||
def _create_summary(self, messages_to_summarize: list[AnyMessage]) -> str:
|
||||
"""Generate summary for the given messages."""
|
||||
|
||||
@@ -121,46 +121,6 @@ def test_summarization_middleware_helper_methods() -> None:
|
||||
assert "Here is a summary of the conversation to date:" in new_messages[0].content
|
||||
assert summary in new_messages[0].content
|
||||
|
||||
# Test tool call detection
|
||||
ai_message_no_tools = AIMessage(content="Hello")
|
||||
assert not middleware._has_tool_calls(ai_message_no_tools)
|
||||
|
||||
ai_message_with_tools = AIMessage(
|
||||
content="Hello", tool_calls=[{"name": "test", "args": {}, "id": "1"}]
|
||||
)
|
||||
assert middleware._has_tool_calls(ai_message_with_tools)
|
||||
|
||||
human_message = HumanMessage(content="Hello")
|
||||
assert not middleware._has_tool_calls(human_message)
|
||||
|
||||
|
||||
def test_summarization_middleware_tool_call_safety() -> None:
|
||||
"""Test SummarizationMiddleware tool call safety logic."""
|
||||
model = FakeToolCallingModel()
|
||||
middleware = SummarizationMiddleware(
|
||||
model=model, trigger=("tokens", 1000), keep=("messages", 3)
|
||||
)
|
||||
|
||||
# Test safe cutoff point detection with tool calls
|
||||
messages = [
|
||||
HumanMessage(content="1"),
|
||||
AIMessage(content="2", tool_calls=[{"name": "test", "args": {}, "id": "1"}]),
|
||||
ToolMessage(content="3", tool_call_id="1"),
|
||||
HumanMessage(content="4"),
|
||||
]
|
||||
|
||||
# Safe cutoff (doesn't separate AI/Tool pair)
|
||||
is_safe = middleware._is_safe_cutoff_point(messages, 0)
|
||||
assert is_safe is True
|
||||
|
||||
# Unsafe cutoff (separates AI/Tool pair)
|
||||
is_safe = middleware._is_safe_cutoff_point(messages, 2)
|
||||
assert is_safe is False
|
||||
|
||||
# Test tool call ID extraction
|
||||
ids = middleware._extract_tool_call_ids(messages[1])
|
||||
assert ids == {"1"}
|
||||
|
||||
|
||||
def test_summarization_middleware_summary_creation() -> None:
|
||||
"""Test SummarizationMiddleware summary creation."""
|
||||
@@ -315,8 +275,8 @@ def test_summarization_middleware_profile_inference_triggers_summary() -> None:
|
||||
]
|
||||
|
||||
|
||||
def test_summarization_middleware_token_retention_pct_respects_tool_pairs() -> None:
|
||||
"""Ensure token retention keeps pairs together even if exceeding target tokens."""
|
||||
def test_summarization_middleware_token_retention_advances_past_tool_messages() -> None:
|
||||
"""Ensure token retention advances past tool messages for aggressive summarization."""
|
||||
|
||||
def token_counter(messages: list[AnyMessage]) -> int:
|
||||
return sum(len(getattr(message, "content", "")) for message in messages)
|
||||
@@ -328,6 +288,10 @@ def test_summarization_middleware_token_retention_pct_respects_tool_pairs() -> N
|
||||
)
|
||||
middleware.token_counter = token_counter
|
||||
|
||||
# Total tokens: 300 + 200 + 50 + 180 + 160 = 890
|
||||
# Target keep: 500 tokens (50% of 1000)
|
||||
# Binary search finds cutoff around index 2 (ToolMessage)
|
||||
# We advance past it to index 3 (HumanMessage)
|
||||
messages: list[AnyMessage] = [
|
||||
HumanMessage(content="H" * 300),
|
||||
AIMessage(
|
||||
@@ -344,13 +308,14 @@ def test_summarization_middleware_token_retention_pct_respects_tool_pairs() -> N
|
||||
assert result is not None
|
||||
|
||||
preserved_messages = result["messages"][2:]
|
||||
assert preserved_messages == messages[1:]
|
||||
# With aggressive summarization, we advance past the ToolMessage
|
||||
# So we preserve messages from index 3 onward (the two HumanMessages)
|
||||
assert preserved_messages == messages[3:]
|
||||
|
||||
# Verify preserved tokens are within budget
|
||||
target_token_count = int(1000 * 0.5)
|
||||
preserved_tokens = middleware.token_counter(preserved_messages)
|
||||
|
||||
# Tool pair retention can exceed the target token count but should keep the pair intact.
|
||||
assert preserved_tokens > target_token_count
|
||||
assert preserved_tokens <= target_token_count
|
||||
|
||||
|
||||
def test_summarization_middleware_missing_profile() -> None:
|
||||
@@ -692,95 +657,38 @@ def test_summarization_middleware_binary_search_edge_cases() -> None:
|
||||
assert cutoff == 0
|
||||
|
||||
|
||||
def test_summarization_middleware_tool_call_extraction_edge_cases() -> None:
|
||||
"""Test tool call ID extraction with various message formats."""
|
||||
model = FakeToolCallingModel()
|
||||
middleware = SummarizationMiddleware(model=model, trigger=("messages", 5))
|
||||
|
||||
# Test with dict-style tool calls
|
||||
ai_message_dict = AIMessage(
|
||||
content="test", tool_calls=[{"name": "tool1", "args": {}, "id": "id1"}]
|
||||
)
|
||||
ids = middleware._extract_tool_call_ids(ai_message_dict)
|
||||
assert ids == {"id1"}
|
||||
|
||||
# Test with multiple tool calls
|
||||
ai_message_multiple = AIMessage(
|
||||
content="test",
|
||||
tool_calls=[
|
||||
{"name": "tool1", "args": {}, "id": "id1"},
|
||||
{"name": "tool2", "args": {}, "id": "id2"},
|
||||
],
|
||||
)
|
||||
ids = middleware._extract_tool_call_ids(ai_message_multiple)
|
||||
assert ids == {"id1", "id2"}
|
||||
|
||||
# Test with empty tool calls list
|
||||
ai_message_empty = AIMessage(content="test", tool_calls=[])
|
||||
ids = middleware._extract_tool_call_ids(ai_message_empty)
|
||||
assert len(ids) == 0
|
||||
|
||||
|
||||
def test_summarization_middleware_complex_tool_pair_scenarios() -> None:
|
||||
"""Test complex tool call pairing scenarios."""
|
||||
model = FakeToolCallingModel()
|
||||
middleware = SummarizationMiddleware(model=model, trigger=("messages", 5), keep=("messages", 3))
|
||||
|
||||
# Test with multiple AI messages with tool calls
|
||||
messages = [
|
||||
HumanMessage(content="msg1"),
|
||||
AIMessage(content="ai1", tool_calls=[{"name": "tool1", "args": {}, "id": "call1"}]),
|
||||
ToolMessage(content="result1", tool_call_id="call1"),
|
||||
HumanMessage(content="msg2"),
|
||||
AIMessage(content="ai2", tool_calls=[{"name": "tool2", "args": {}, "id": "call2"}]),
|
||||
ToolMessage(content="result2", tool_call_id="call2"),
|
||||
HumanMessage(content="msg3"),
|
||||
]
|
||||
|
||||
# Test cutoff at index 1 - unsafe (separates first AI/Tool pair)
|
||||
assert not middleware._is_safe_cutoff_point(messages, 2)
|
||||
|
||||
# Test cutoff at index 3 - safe (keeps first pair together)
|
||||
assert middleware._is_safe_cutoff_point(messages, 3)
|
||||
|
||||
# Test cutoff at index 5 - unsafe (separates second AI/Tool pair)
|
||||
assert not middleware._is_safe_cutoff_point(messages, 5)
|
||||
|
||||
# Test _cutoff_separates_tool_pair directly
|
||||
assert middleware._cutoff_separates_tool_pair(messages, 1, 2, {"call1"})
|
||||
assert not middleware._cutoff_separates_tool_pair(messages, 1, 0, {"call1"})
|
||||
assert not middleware._cutoff_separates_tool_pair(messages, 1, 3, {"call1"})
|
||||
|
||||
|
||||
def test_summarization_middleware_tool_call_in_search_range() -> None:
|
||||
"""Test tool call safety with messages at edge of search range."""
|
||||
def test_summarization_middleware_find_safe_cutoff_point() -> None:
|
||||
"""Test _find_safe_cutoff_point finds safe cutoff past ToolMessages."""
|
||||
model = FakeToolCallingModel()
|
||||
middleware = SummarizationMiddleware(
|
||||
model=model, trigger=("messages", 10), keep=("messages", 2)
|
||||
)
|
||||
|
||||
# Create messages with tool pair separated by some distance
|
||||
# Search range is 5, so messages within 5 positions of cutoff are checked
|
||||
messages = [
|
||||
messages: list[AnyMessage] = [
|
||||
HumanMessage(content="msg1"),
|
||||
HumanMessage(content="msg2"),
|
||||
AIMessage(content="ai", tool_calls=[{"name": "tool", "args": {}, "id": "call1"}]),
|
||||
HumanMessage(content="msg3"),
|
||||
HumanMessage(content="msg4"),
|
||||
ToolMessage(content="result", tool_call_id="call1"),
|
||||
HumanMessage(content="msg6"),
|
||||
ToolMessage(content="result1", tool_call_id="call1"),
|
||||
ToolMessage(content="result2", tool_call_id="call2"),
|
||||
HumanMessage(content="msg2"),
|
||||
]
|
||||
|
||||
# Cutoff at index 3 would separate: [0,1,2] from [3,4,5,6]
|
||||
# AI at index 2 is before cutoff, Tool at index 5 is after cutoff - unsafe
|
||||
assert not middleware._is_safe_cutoff_point(messages, 3)
|
||||
# Starting at a non-ToolMessage returns the same index
|
||||
assert middleware._find_safe_cutoff_point(messages, 0) == 0
|
||||
assert middleware._find_safe_cutoff_point(messages, 1) == 1
|
||||
|
||||
# Cutoff at index 6 keeps AI and Tool both in summarized section
|
||||
assert middleware._is_safe_cutoff_point(messages, 6)
|
||||
# Starting at a ToolMessage advances to the next non-ToolMessage
|
||||
assert middleware._find_safe_cutoff_point(messages, 2) == 4
|
||||
assert middleware._find_safe_cutoff_point(messages, 3) == 4
|
||||
|
||||
# Cutoff at index 0 or 1 also safe - both AI and Tool in preserved section
|
||||
assert middleware._is_safe_cutoff_point(messages, 0)
|
||||
assert middleware._is_safe_cutoff_point(messages, 1)
|
||||
# Starting at the HumanMessage after tools returns that index
|
||||
assert middleware._find_safe_cutoff_point(messages, 4) == 4
|
||||
|
||||
# Starting past the end returns the index unchanged
|
||||
assert middleware._find_safe_cutoff_point(messages, 5) == 5
|
||||
|
||||
# Cutoff at or past length stays the same
|
||||
assert middleware._find_safe_cutoff_point(messages, len(messages)) == len(messages)
|
||||
assert middleware._find_safe_cutoff_point(messages, len(messages) + 5) == len(messages) + 5
|
||||
|
||||
|
||||
def test_summarization_middleware_zero_and_negative_target_tokens() -> None:
|
||||
@@ -880,20 +788,6 @@ def test_summarization_middleware_fraction_trigger_with_no_profile() -> None:
|
||||
middleware._get_profile_limits = original_method
|
||||
|
||||
|
||||
def test_summarization_middleware_is_safe_cutoff_at_end() -> None:
|
||||
"""Test _is_safe_cutoff_point when cutoff is at or past the end."""
|
||||
model = FakeToolCallingModel()
|
||||
middleware = SummarizationMiddleware(model=model, trigger=("messages", 5))
|
||||
|
||||
messages = [HumanMessage(content=str(i)) for i in range(5)]
|
||||
|
||||
# Cutoff at exactly the length should be safe
|
||||
assert middleware._is_safe_cutoff_point(messages, len(messages))
|
||||
|
||||
# Cutoff past the length should also be safe
|
||||
assert middleware._is_safe_cutoff_point(messages, len(messages) + 5)
|
||||
|
||||
|
||||
def test_summarization_adjust_token_counts() -> None:
|
||||
test_message = HumanMessage(content="a" * 12)
|
||||
|
||||
@@ -909,3 +803,84 @@ def test_summarization_adjust_token_counts() -> None:
|
||||
count_2 = middleware.token_counter([test_message])
|
||||
|
||||
assert count_1 != count_2
|
||||
|
||||
|
||||
def test_summarization_middleware_many_parallel_tool_calls_safety() -> None:
|
||||
"""Test cutoff safety with many parallel tool calls extending beyond old search range."""
|
||||
middleware = SummarizationMiddleware(
|
||||
model=MockChatModel(), trigger=("messages", 15), keep=("messages", 5)
|
||||
)
|
||||
tool_calls = [{"name": f"tool_{i}", "args": {}, "id": f"call_{i}"} for i in range(10)]
|
||||
human_message = HumanMessage(content="calling 10 tools")
|
||||
ai_message = AIMessage(content="calling 10 tools", tool_calls=tool_calls)
|
||||
tool_messages = [
|
||||
ToolMessage(content=f"result_{i}", tool_call_id=f"call_{i}") for i in range(10)
|
||||
]
|
||||
messages: list[AnyMessage] = [human_message, ai_message, *tool_messages]
|
||||
|
||||
# Cutoff at index 7 (a ToolMessage) advances to index 12 (end of messages)
|
||||
assert middleware._find_safe_cutoff_point(messages, 7) == 12
|
||||
|
||||
# Any cutoff pointing at a ToolMessage (indices 2-11) advances to index 12
|
||||
for i in range(2, 12):
|
||||
assert middleware._find_safe_cutoff_point(messages, i) == 12
|
||||
|
||||
# Cutoff at index 0, 1 (before tool messages) stays the same
|
||||
assert middleware._find_safe_cutoff_point(messages, 0) == 0
|
||||
assert middleware._find_safe_cutoff_point(messages, 1) == 1
|
||||
|
||||
|
||||
def test_summarization_middleware_find_safe_cutoff_advances_past_tools() -> None:
|
||||
"""Test _find_safe_cutoff advances past ToolMessages to find safe cutoff."""
|
||||
middleware = SummarizationMiddleware(
|
||||
model=MockChatModel(), trigger=("messages", 10), keep=("messages", 3)
|
||||
)
|
||||
|
||||
# Messages: [Human, AI, Tool, Tool, Tool, Human]
|
||||
messages: list[AnyMessage] = [
|
||||
HumanMessage(content="msg1"),
|
||||
AIMessage(
|
||||
content="ai",
|
||||
tool_calls=[
|
||||
{"name": "tool1", "args": {}, "id": "call1"},
|
||||
{"name": "tool2", "args": {}, "id": "call2"},
|
||||
{"name": "tool3", "args": {}, "id": "call3"},
|
||||
],
|
||||
),
|
||||
ToolMessage(content="result1", tool_call_id="call1"),
|
||||
ToolMessage(content="result2", tool_call_id="call2"),
|
||||
ToolMessage(content="result3", tool_call_id="call3"),
|
||||
HumanMessage(content="msg2"),
|
||||
]
|
||||
|
||||
# Target cutoff index is len(messages) - messages_to_keep = 6 - 3 = 3
|
||||
# Index 3 is a ToolMessage, so we advance past the tool sequence to index 5
|
||||
cutoff = middleware._find_safe_cutoff(messages, messages_to_keep=3)
|
||||
assert cutoff == 5
|
||||
|
||||
# With messages_to_keep=2, target cutoff index is 6 - 2 = 4
|
||||
# Index 4 is a ToolMessage, so we advance past the tool sequence to index 5
|
||||
# This is aggressive - we keep only 1 message instead of 2
|
||||
cutoff = middleware._find_safe_cutoff(messages, messages_to_keep=2)
|
||||
assert cutoff == 5
|
||||
|
||||
|
||||
def test_summarization_middleware_cutoff_at_start_of_tool_sequence() -> None:
|
||||
"""Test cutoff when target lands exactly at the first ToolMessage."""
|
||||
middleware = SummarizationMiddleware(
|
||||
model=MockChatModel(), trigger=("messages", 8), keep=("messages", 4)
|
||||
)
|
||||
|
||||
messages: list[AnyMessage] = [
|
||||
HumanMessage(content="msg1"),
|
||||
HumanMessage(content="msg2"),
|
||||
AIMessage(content="ai", tool_calls=[{"name": "tool", "args": {}, "id": "call1"}]),
|
||||
ToolMessage(content="result", tool_call_id="call1"),
|
||||
HumanMessage(content="msg3"),
|
||||
HumanMessage(content="msg4"),
|
||||
]
|
||||
|
||||
# Target cutoff index is len(messages) - messages_to_keep = 6 - 4 = 2
|
||||
# Index 2 is an AIMessage (safe cutoff point), so no adjustment needed
|
||||
cutoff = middleware._find_safe_cutoff(messages, messages_to_keep=4)
|
||||
assert cutoff == 2
|
||||
|
||||
Reference in New Issue
Block a user