Compare commits

...

2 Commits

Author         SHA1        Message                                          Date
Sydney Runkle  8a5a5de7d0  Merge branch 'master' into sr/fix-summarization  2025-12-03 10:48:51 -05:00
Sydney Runkle  8d20a391c9  fixing summarization bug                         2025-12-03 09:27:44 -05:00
2 changed files with 35 additions and 9 deletions

View File

@@ -56,7 +56,6 @@ Messages to summarize:
 _DEFAULT_MESSAGES_TO_KEEP = 20
 _DEFAULT_TRIM_TOKEN_LIMIT = 4000
 _DEFAULT_FALLBACK_MESSAGE_COUNT = 15
-_SEARCH_RANGE_FOR_TOOL_PAIRS = 5
 
 ContextFraction = tuple[Literal["fraction"], float]
 """Fraction of model's maximum input tokens.
@@ -480,16 +479,25 @@ class SummarizationMiddleware(AgentMiddleware):
         if cutoff_index >= len(messages):
             return True
 
-        search_start = max(0, cutoff_index - _SEARCH_RANGE_FOR_TOOL_PAIRS)
-        search_end = min(len(messages), cutoff_index + _SEARCH_RANGE_FOR_TOOL_PAIRS)
-
-        for i in range(search_start, search_end):
-            if not self._has_tool_calls(messages[i]):
-                continue
-
-            tool_call_ids = self._extract_tool_call_ids(cast("AIMessage", messages[i]))
-
-            if self._cutoff_separates_tool_pair(messages, i, cutoff_index, tool_call_ids):
-                return False
-
+        # Check tool messages at or after cutoff and find their source AI message
+        for i in range(cutoff_index, len(messages)):
+            msg = messages[i]
+            if not isinstance(msg, ToolMessage):
+                continue
+
+            # Search backwards to find the AI message that generated this tool call
+            tool_call_id = msg.tool_call_id
+
+            for j in range(i - 1, -1, -1):
+                ai_msg = messages[j]
+                if not self._has_tool_calls(ai_msg):
+                    continue
+                ai_tool_ids = self._extract_tool_call_ids(cast("AIMessage", ai_msg))
+                if tool_call_id in ai_tool_ids:
+                    # Found the AI message - check if cutoff separates them
+                    if j < cutoff_index:
+                        # AI message would be summarized, tool message would be kept
+                        return False
+                    break
+
         return True
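
For context, here is a minimal standalone sketch of the new check, written outside the middleware class and using public message attributes in place of the private _has_tool_calls / _extract_tool_call_ids helpers (the free-standing function name is hypothetical; it assumes langchain_core message types):

from langchain_core.messages import AIMessage, AnyMessage, ToolMessage


def is_safe_cutoff_point(messages: list[AnyMessage], cutoff_index: int) -> bool:
    """Return False if the cutoff would summarize an AI message whose tool result is kept."""
    if cutoff_index >= len(messages):
        return True
    for i in range(cutoff_index, len(messages)):
        msg = messages[i]
        if not isinstance(msg, ToolMessage):
            continue
        # Walk backwards from the kept tool result to the AI message that issued the call.
        for j in range(i - 1, -1, -1):
            candidate = messages[j]
            if not isinstance(candidate, AIMessage) or not candidate.tool_calls:
                continue
            if msg.tool_call_id in {tc["id"] for tc in candidate.tool_calls}:
                if j < cutoff_index:
                    # The AI message falls before the cutoff (summarized) while its
                    # tool result falls after it (kept), so the split is unsafe.
                    return False
                break
    return True

Unlike the deleted version, there is no fixed search window: the backward scan keeps going until it finds the AI message that issued the tool call, so an AI message with any number of parallel tool calls stays paired with its results.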

View File

@@ -909,3 +909,21 @@ def test_summarization_adjust_token_counts() -> None:
     count_2 = middleware.token_counter([test_message])
     assert count_1 != count_2
+
+
+def test_summarization_middleware_many_parallel_tool_calls_safety_gap() -> None:
+    """Test cutoff safety with many parallel tool calls extending beyond the old search range."""
+    middleware = SummarizationMiddleware(
+        model=MockChatModel(), trigger=("messages", 15), keep=("messages", 5)
+    )
+    tool_calls = [{"name": f"tool_{i}", "args": {}, "id": f"call_{i}"} for i in range(10)]
+    human_message = HumanMessage(content="calling 10 tools")
+    ai_message = AIMessage(content="calling 10 tools", tool_calls=tool_calls)
+    tool_messages = [
+        ToolMessage(content=f"result_{i}", tool_call_id=f"call_{i}") for i in range(10)
+    ]
+    messages: list[AnyMessage] = [human_message, ai_message, *tool_messages]
+
+    # Cutoff at index 7 would separate the AI message (index 1) from the tool messages at indices 7-11
+    is_safe = middleware._is_safe_cutoff_point(messages, 7)
+    assert is_safe is False
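
To see why the previous fixed ±5 window missed this fixture, here is a quick sketch of the old bounds, mirroring the deleted _SEARCH_RANGE_FOR_TOOL_PAIRS arithmetic (illustrative only):

# messages: index 0 = HumanMessage, index 1 = AIMessage with 10 tool calls,
# indices 2-11 = ToolMessages for call_0 .. call_9
cutoff_index = 7
search_start = max(0, cutoff_index - 5)   # 2
search_end = min(13, cutoff_index + 5)    # 12
# The old loop inspected only indices 2..11, so the AIMessage at index 1 was never
# examined and the cutoff was reported as safe; the new backward scan reaches it
# from any of the kept ToolMessages at indices 7-11.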