From 3030ffc248fae90765fc437a22bbeb150f095427 Mon Sep 17 00:00:00 2001 From: Sydney Runkle <54324534+sydney-runkle@users.noreply.github.com> Date: Thu, 4 Dec 2025 12:44:50 -0500 Subject: [PATCH] fix: simplify summarization cutoff logic (#34195) This PR changes how we find the cutoff for summarization, summarizing content more eagerly if the initial cutoff point isn't safe (i.e., would break apart AI + tool message pairs). This new algorithm is quite simple - it looks at the initial cutoff point and, if it's not safe, moves forward through the message list until it finds the first non-tool message. For example: ``` H AI TM --- theoretical cutoff based on keep=('messages', 3) TM AI TM ``` ``` H AI TM TM --- actual cutoff, more aggressive summarization AI TM ``` --- .../agents/middleware/summarization.py | 76 ++---- .../implementations/test_summarization.py | 251 ++++++++---------- 2 files changed, 128 insertions(+), 199 deletions(-) diff --git a/libs/langchain_v1/langchain/agents/middleware/summarization.py b/libs/langchain_v1/langchain/agents/middleware/summarization.py index 6e67b7e5d57..6055c246a18 100644 --- a/libs/langchain_v1/langchain/agents/middleware/summarization.py +++ b/libs/langchain_v1/langchain/agents/middleware/summarization.py @@ -7,7 +7,6 @@ from functools import partial from typing import Any, Literal, cast from langchain_core.messages import ( - AIMessage, AnyMessage, MessageLikeRepresentation, RemoveMessage, @@ -56,7 +55,6 @@ Messages to summarize: _DEFAULT_MESSAGES_TO_KEEP = 20 _DEFAULT_TRIM_TOKEN_LIMIT = 4000 _DEFAULT_FALLBACK_MESSAGE_COUNT = 15 -_SEARCH_RANGE_FOR_TOOL_PAIRS = 5 ContextFraction = tuple[Literal["fraction"], float] """Fraction of model's maximum input tokens. 
@@ -397,11 +395,8 @@ class SummarizationMiddleware(AgentMiddleware): return 0 cutoff_candidate = len(messages) - 1 - for i in range(cutoff_candidate, -1, -1): - if self._is_safe_cutoff_point(messages, i): - return i - - return 0 + # Advance past any ToolMessages to avoid splitting AI/Tool pairs + return self._find_safe_cutoff_point(messages, cutoff_candidate) def _get_profile_limits(self) -> int | None: """Retrieve max input token limit from the model profile.""" @@ -463,67 +458,26 @@ class SummarizationMiddleware(AgentMiddleware): Returns the index where messages can be safely cut without separating related AI and Tool messages. Returns `0` if no safe cutoff is found. + + This is aggressive with summarization - if the target cutoff lands in the + middle of tool messages, we advance past all of them (summarizing more). """ if len(messages) <= messages_to_keep: return 0 target_cutoff = len(messages) - messages_to_keep + return self._find_safe_cutoff_point(messages, target_cutoff) - for i in range(target_cutoff, -1, -1): - if self._is_safe_cutoff_point(messages, i): - return i + def _find_safe_cutoff_point(self, messages: list[AnyMessage], cutoff_index: int) -> int: + """Find a safe cutoff point that doesn't split AI/Tool message pairs. 
- return 0 - - def _is_safe_cutoff_point(self, messages: list[AnyMessage], cutoff_index: int) -> bool: - """Check if cutting at index would separate AI/Tool message pairs.""" - if cutoff_index >= len(messages): - return True - - search_start = max(0, cutoff_index - _SEARCH_RANGE_FOR_TOOL_PAIRS) - search_end = min(len(messages), cutoff_index + _SEARCH_RANGE_FOR_TOOL_PAIRS) - - for i in range(search_start, search_end): - if not self._has_tool_calls(messages[i]): - continue - - tool_call_ids = self._extract_tool_call_ids(cast("AIMessage", messages[i])) - if self._cutoff_separates_tool_pair(messages, i, cutoff_index, tool_call_ids): - return False - - return True - - def _has_tool_calls(self, message: AnyMessage) -> bool: - """Check if message is an AI message with tool calls.""" - return ( - isinstance(message, AIMessage) and hasattr(message, "tool_calls") and message.tool_calls # type: ignore[return-value] - ) - - def _extract_tool_call_ids(self, ai_message: AIMessage) -> set[str]: - """Extract tool call IDs from an AI message.""" - tool_call_ids = set() - for tc in ai_message.tool_calls: - call_id = tc.get("id") if isinstance(tc, dict) else getattr(tc, "id", None) - if call_id is not None: - tool_call_ids.add(call_id) - return tool_call_ids - - def _cutoff_separates_tool_pair( - self, - messages: list[AnyMessage], - ai_message_index: int, - cutoff_index: int, - tool_call_ids: set[str], - ) -> bool: - """Check if cutoff separates an AI message from its corresponding tool messages.""" - for j in range(ai_message_index + 1, len(messages)): - message = messages[j] - if isinstance(message, ToolMessage) and message.tool_call_id in tool_call_ids: - ai_before_cutoff = ai_message_index < cutoff_index - tool_before_cutoff = j < cutoff_index - if ai_before_cutoff != tool_before_cutoff: - return True - return False + If the message at cutoff_index is a ToolMessage, advance until we find + a non-ToolMessage. 
This ensures we never cut in the middle of parallel + tool call responses. + """ + while cutoff_index < len(messages) and isinstance(messages[cutoff_index], ToolMessage): + cutoff_index += 1 + return cutoff_index def _create_summary(self, messages_to_summarize: list[AnyMessage]) -> str: """Generate summary for the given messages.""" diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_summarization.py b/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_summarization.py index 015b37bd2b6..f811c1b4d02 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_summarization.py +++ b/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_summarization.py @@ -121,46 +121,6 @@ def test_summarization_middleware_helper_methods() -> None: assert "Here is a summary of the conversation to date:" in new_messages[0].content assert summary in new_messages[0].content - # Test tool call detection - ai_message_no_tools = AIMessage(content="Hello") - assert not middleware._has_tool_calls(ai_message_no_tools) - - ai_message_with_tools = AIMessage( - content="Hello", tool_calls=[{"name": "test", "args": {}, "id": "1"}] - ) - assert middleware._has_tool_calls(ai_message_with_tools) - - human_message = HumanMessage(content="Hello") - assert not middleware._has_tool_calls(human_message) - - -def test_summarization_middleware_tool_call_safety() -> None: - """Test SummarizationMiddleware tool call safety logic.""" - model = FakeToolCallingModel() - middleware = SummarizationMiddleware( - model=model, trigger=("tokens", 1000), keep=("messages", 3) - ) - - # Test safe cutoff point detection with tool calls - messages = [ - HumanMessage(content="1"), - AIMessage(content="2", tool_calls=[{"name": "test", "args": {}, "id": "1"}]), - ToolMessage(content="3", tool_call_id="1"), - HumanMessage(content="4"), - ] - - # Safe cutoff (doesn't separate AI/Tool pair) - is_safe = 
middleware._is_safe_cutoff_point(messages, 0) - assert is_safe is True - - # Unsafe cutoff (separates AI/Tool pair) - is_safe = middleware._is_safe_cutoff_point(messages, 2) - assert is_safe is False - - # Test tool call ID extraction - ids = middleware._extract_tool_call_ids(messages[1]) - assert ids == {"1"} - def test_summarization_middleware_summary_creation() -> None: """Test SummarizationMiddleware summary creation.""" @@ -315,8 +275,8 @@ def test_summarization_middleware_profile_inference_triggers_summary() -> None: ] -def test_summarization_middleware_token_retention_pct_respects_tool_pairs() -> None: - """Ensure token retention keeps pairs together even if exceeding target tokens.""" +def test_summarization_middleware_token_retention_advances_past_tool_messages() -> None: + """Ensure token retention advances past tool messages for aggressive summarization.""" def token_counter(messages: list[AnyMessage]) -> int: return sum(len(getattr(message, "content", "")) for message in messages) @@ -328,6 +288,10 @@ def test_summarization_middleware_token_retention_pct_respects_tool_pairs() -> N ) middleware.token_counter = token_counter + # Total tokens: 300 + 200 + 50 + 180 + 160 = 890 + # Target keep: 500 tokens (50% of 1000) + # Binary search finds cutoff around index 2 (ToolMessage) + # We advance past it to index 3 (HumanMessage) messages: list[AnyMessage] = [ HumanMessage(content="H" * 300), AIMessage( @@ -344,13 +308,14 @@ def test_summarization_middleware_token_retention_pct_respects_tool_pairs() -> N assert result is not None preserved_messages = result["messages"][2:] - assert preserved_messages == messages[1:] + # With aggressive summarization, we advance past the ToolMessage + # So we preserve messages from index 3 onward (the two HumanMessages) + assert preserved_messages == messages[3:] + # Verify preserved tokens are within budget target_token_count = int(1000 * 0.5) preserved_tokens = middleware.token_counter(preserved_messages) - - # Tool pair 
retention can exceed the target token count but should keep the pair intact. - assert preserved_tokens > target_token_count + assert preserved_tokens <= target_token_count def test_summarization_middleware_missing_profile() -> None: @@ -692,95 +657,38 @@ def test_summarization_middleware_binary_search_edge_cases() -> None: assert cutoff == 0 -def test_summarization_middleware_tool_call_extraction_edge_cases() -> None: - """Test tool call ID extraction with various message formats.""" - model = FakeToolCallingModel() - middleware = SummarizationMiddleware(model=model, trigger=("messages", 5)) - - # Test with dict-style tool calls - ai_message_dict = AIMessage( - content="test", tool_calls=[{"name": "tool1", "args": {}, "id": "id1"}] - ) - ids = middleware._extract_tool_call_ids(ai_message_dict) - assert ids == {"id1"} - - # Test with multiple tool calls - ai_message_multiple = AIMessage( - content="test", - tool_calls=[ - {"name": "tool1", "args": {}, "id": "id1"}, - {"name": "tool2", "args": {}, "id": "id2"}, - ], - ) - ids = middleware._extract_tool_call_ids(ai_message_multiple) - assert ids == {"id1", "id2"} - - # Test with empty tool calls list - ai_message_empty = AIMessage(content="test", tool_calls=[]) - ids = middleware._extract_tool_call_ids(ai_message_empty) - assert len(ids) == 0 - - -def test_summarization_middleware_complex_tool_pair_scenarios() -> None: - """Test complex tool call pairing scenarios.""" - model = FakeToolCallingModel() - middleware = SummarizationMiddleware(model=model, trigger=("messages", 5), keep=("messages", 3)) - - # Test with multiple AI messages with tool calls - messages = [ - HumanMessage(content="msg1"), - AIMessage(content="ai1", tool_calls=[{"name": "tool1", "args": {}, "id": "call1"}]), - ToolMessage(content="result1", tool_call_id="call1"), - HumanMessage(content="msg2"), - AIMessage(content="ai2", tool_calls=[{"name": "tool2", "args": {}, "id": "call2"}]), - ToolMessage(content="result2", tool_call_id="call2"), - 
HumanMessage(content="msg3"), - ] - - # Test cutoff at index 1 - unsafe (separates first AI/Tool pair) - assert not middleware._is_safe_cutoff_point(messages, 2) - - # Test cutoff at index 3 - safe (keeps first pair together) - assert middleware._is_safe_cutoff_point(messages, 3) - - # Test cutoff at index 5 - unsafe (separates second AI/Tool pair) - assert not middleware._is_safe_cutoff_point(messages, 5) - - # Test _cutoff_separates_tool_pair directly - assert middleware._cutoff_separates_tool_pair(messages, 1, 2, {"call1"}) - assert not middleware._cutoff_separates_tool_pair(messages, 1, 0, {"call1"}) - assert not middleware._cutoff_separates_tool_pair(messages, 1, 3, {"call1"}) - - -def test_summarization_middleware_tool_call_in_search_range() -> None: - """Test tool call safety with messages at edge of search range.""" +def test_summarization_middleware_find_safe_cutoff_point() -> None: + """Test _find_safe_cutoff_point finds safe cutoff past ToolMessages.""" model = FakeToolCallingModel() middleware = SummarizationMiddleware( model=model, trigger=("messages", 10), keep=("messages", 2) ) - # Create messages with tool pair separated by some distance - # Search range is 5, so messages within 5 positions of cutoff are checked - messages = [ + messages: list[AnyMessage] = [ HumanMessage(content="msg1"), - HumanMessage(content="msg2"), AIMessage(content="ai", tool_calls=[{"name": "tool", "args": {}, "id": "call1"}]), - HumanMessage(content="msg3"), - HumanMessage(content="msg4"), - ToolMessage(content="result", tool_call_id="call1"), - HumanMessage(content="msg6"), + ToolMessage(content="result1", tool_call_id="call1"), + ToolMessage(content="result2", tool_call_id="call2"), + HumanMessage(content="msg2"), ] - # Cutoff at index 3 would separate: [0,1,2] from [3,4,5,6] - # AI at index 2 is before cutoff, Tool at index 5 is after cutoff - unsafe - assert not middleware._is_safe_cutoff_point(messages, 3) + # Starting at a non-ToolMessage returns the same index + 
assert middleware._find_safe_cutoff_point(messages, 0) == 0 + assert middleware._find_safe_cutoff_point(messages, 1) == 1 - # Cutoff at index 6 keeps AI and Tool both in summarized section - assert middleware._is_safe_cutoff_point(messages, 6) + # Starting at a ToolMessage advances to the next non-ToolMessage + assert middleware._find_safe_cutoff_point(messages, 2) == 4 + assert middleware._find_safe_cutoff_point(messages, 3) == 4 - # Cutoff at index 0 or 1 also safe - both AI and Tool in preserved section - assert middleware._is_safe_cutoff_point(messages, 0) - assert middleware._is_safe_cutoff_point(messages, 1) + # Starting at the HumanMessage after tools returns that index + assert middleware._find_safe_cutoff_point(messages, 4) == 4 + + # Starting past the end returns the index unchanged + assert middleware._find_safe_cutoff_point(messages, 5) == 5 + + # Cutoff at or past length stays the same + assert middleware._find_safe_cutoff_point(messages, len(messages)) == len(messages) + assert middleware._find_safe_cutoff_point(messages, len(messages) + 5) == len(messages) + 5 def test_summarization_middleware_zero_and_negative_target_tokens() -> None: @@ -880,20 +788,6 @@ def test_summarization_middleware_fraction_trigger_with_no_profile() -> None: middleware._get_profile_limits = original_method -def test_summarization_middleware_is_safe_cutoff_at_end() -> None: - """Test _is_safe_cutoff_point when cutoff is at or past the end.""" - model = FakeToolCallingModel() - middleware = SummarizationMiddleware(model=model, trigger=("messages", 5)) - - messages = [HumanMessage(content=str(i)) for i in range(5)] - - # Cutoff at exactly the length should be safe - assert middleware._is_safe_cutoff_point(messages, len(messages)) - - # Cutoff past the length should also be safe - assert middleware._is_safe_cutoff_point(messages, len(messages) + 5) - - def test_summarization_adjust_token_counts() -> None: test_message = HumanMessage(content="a" * 12) @@ -909,3 +803,84 @@ def 
test_summarization_adjust_token_counts() -> None: count_2 = middleware.token_counter([test_message]) assert count_1 != count_2 + + +def test_summarization_middleware_many_parallel_tool_calls_safety() -> None: + """Test cutoff safety with many parallel tool calls extending beyond old search range.""" + middleware = SummarizationMiddleware( + model=MockChatModel(), trigger=("messages", 15), keep=("messages", 5) + ) + tool_calls = [{"name": f"tool_{i}", "args": {}, "id": f"call_{i}"} for i in range(10)] + human_message = HumanMessage(content="calling 10 tools") + ai_message = AIMessage(content="calling 10 tools", tool_calls=tool_calls) + tool_messages = [ + ToolMessage(content=f"result_{i}", tool_call_id=f"call_{i}") for i in range(10) + ] + messages: list[AnyMessage] = [human_message, ai_message, *tool_messages] + + # Cutoff at index 7 (a ToolMessage) advances to index 12 (end of messages) + assert middleware._find_safe_cutoff_point(messages, 7) == 12 + + # Any cutoff pointing at a ToolMessage (indices 2-11) advances to index 12 + for i in range(2, 12): + assert middleware._find_safe_cutoff_point(messages, i) == 12 + + # Cutoff at index 0, 1 (before tool messages) stays the same + assert middleware._find_safe_cutoff_point(messages, 0) == 0 + assert middleware._find_safe_cutoff_point(messages, 1) == 1 + + +def test_summarization_middleware_find_safe_cutoff_advances_past_tools() -> None: + """Test _find_safe_cutoff advances past ToolMessages to find safe cutoff.""" + middleware = SummarizationMiddleware( + model=MockChatModel(), trigger=("messages", 10), keep=("messages", 3) + ) + + # Messages: [Human, AI, Tool, Tool, Tool, Human] + messages: list[AnyMessage] = [ + HumanMessage(content="msg1"), + AIMessage( + content="ai", + tool_calls=[ + {"name": "tool1", "args": {}, "id": "call1"}, + {"name": "tool2", "args": {}, "id": "call2"}, + {"name": "tool3", "args": {}, "id": "call3"}, + ], + ), + ToolMessage(content="result1", tool_call_id="call1"), + 
ToolMessage(content="result2", tool_call_id="call2"), + ToolMessage(content="result3", tool_call_id="call3"), + HumanMessage(content="msg2"), + ] + + # Target cutoff index is len(messages) - messages_to_keep = 6 - 3 = 3 + # Index 3 is a ToolMessage, so we advance past the tool sequence to index 5 + cutoff = middleware._find_safe_cutoff(messages, messages_to_keep=3) + assert cutoff == 5 + + # With messages_to_keep=2, target cutoff index is 6 - 2 = 4 + # Index 4 is a ToolMessage, so we advance past the tool sequence to index 5 + # This is aggressive - we keep only 1 message instead of 2 + cutoff = middleware._find_safe_cutoff(messages, messages_to_keep=2) + assert cutoff == 5 + + +def test_summarization_middleware_cutoff_at_start_of_tool_sequence() -> None: + """Test cutoff when target lands exactly at the first ToolMessage.""" + middleware = SummarizationMiddleware( + model=MockChatModel(), trigger=("messages", 8), keep=("messages", 4) + ) + + messages: list[AnyMessage] = [ + HumanMessage(content="msg1"), + HumanMessage(content="msg2"), + AIMessage(content="ai", tool_calls=[{"name": "tool", "args": {}, "id": "call1"}]), + ToolMessage(content="result", tool_call_id="call1"), + HumanMessage(content="msg3"), + HumanMessage(content="msg4"), + ] + + # Target cutoff index is len(messages) - messages_to_keep = 6 - 4 = 2 + # Index 2 is an AIMessage (safe cutoff point), so no adjustment needed + cutoff = middleware._find_safe_cutoff(messages, messages_to_keep=4) + assert cutoff == 2