mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 18:50:33 +00:00
fix(langchain): keep tool call / AIMessage pairings when summarizing (#34609)
Fixes #34282

**Before:**

When using agents with tools (like file reading, web search, etc.), the conversation looks like this:

```
[User] "Read these 10 files and summarize them"
[AI] "I'll read all 10 files" + [tool_call: read_file x 10]
[Tool] "Contents of file1.txt..."
[Tool] "Contents of file2.txt..."
[Tool] "Contents of file3.txt..."
... (7 more tool responses)
```

When the conversation gets too long, `SummarizationMiddleware` kicks in to compress older messages. The problem was: If you asked to keep the last 6 messages, you'd get:

```
[Summary] "Here's what happened before..."
[Tool] "Contents of file5.txt..."
[Tool] "Contents of file6.txt..."
[Tool] "Contents of file7.txt..."
[Tool] "Contents of file8.txt..."
[Tool] "Contents of file9.txt..."
[Tool] "Contents of file10.txt..."
```

The AI's original request to read the files (`[AI]` message with `tool_calls`) was summarized away, but the tool responses remained. This caused the error:

```
Error code: 400 - "No tool call found for function call output with call_id..."
```

Many APIs require that every tool response has a matching tool request. Without the AI message, the tool responses are "orphaned."

## The fix

Now when the cutoff lands on tool messages, we **move backward** to include the AI message that requested those tools:

Same scenario, keeping last 6 messages:

```
[Summary] "Here's what happened before..."
[AI] "I'll read all 10 files" + [tool_call: read_file x 10]
[Tool] "Contents of file1.txt..."
[Tool] "Contents of file2.txt..."
... (all 10 tool responses)
```

The AI message is preserved along with its tool responses, keeping them paired together.

## Practical examples

### Example 1: Parallel tool calls

**Scenario:** Agent reads 10 files in parallel, summarization triggers (see above)

### Example 2: Mixed conversation

**Scenario:** User asks question, AI uses tools, user says thanks

```
[User] "What's the weather?"
[AI] "Let me check" + [tool_call: get_weather]
[Tool] "72F and sunny"
[AI] "It's 72F and sunny!"
[User] "Thanks!"
```

Keeping last 2 messages:

| Before (Bug) | After (Fix) |
|--------------|-------------|
| Only `[User] "Thanks!"` kept | `[AI] + [Tool] + [AI] + [User]` all kept |
| Lost the weather info | Tool pair preserved with response |

### Example 3: Multiple tool sequences

```
[User] "Search for X"
[AI] [tool_call: search]
[Tool] "Results for X"
[User] "Now search for Y"
[AI] [tool_call: search]
[Tool] "Results for Y"
[User] "Great!"
```

**Keeping last 3 messages:** If cutoff lands on `[Tool] "Results for Y"`, we now include `[AI] [tool_call: search]` to keep the pair together.
This commit is contained in:
@@ -7,6 +7,7 @@ from functools import partial
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AnyMessage,
|
||||
MessageLikeRepresentation,
|
||||
RemoveMessage,
|
||||
@@ -478,13 +479,37 @@ class SummarizationMiddleware(AgentMiddleware):
|
||||
def _find_safe_cutoff_point(self, messages: list[AnyMessage], cutoff_index: int) -> int:
|
||||
"""Find a safe cutoff point that doesn't split AI/Tool message pairs.
|
||||
|
||||
If the message at cutoff_index is a ToolMessage, advance until we find
|
||||
a non-ToolMessage. This ensures we never cut in the middle of parallel
|
||||
tool call responses.
|
||||
If the message at `cutoff_index` is a `ToolMessage`, search backward for the
|
||||
`AIMessage` containing the corresponding `tool_calls` and adjust the cutoff to
|
||||
include it. This ensures tool call requests and responses stay together.
|
||||
|
||||
Falls back to advancing forward past `ToolMessage` objects only if no matching
|
||||
`AIMessage` is found (edge case).
|
||||
"""
|
||||
while cutoff_index < len(messages) and isinstance(messages[cutoff_index], ToolMessage):
|
||||
cutoff_index += 1
|
||||
return cutoff_index
|
||||
if cutoff_index >= len(messages) or not isinstance(messages[cutoff_index], ToolMessage):
|
||||
return cutoff_index
|
||||
|
||||
# Collect tool_call_ids from consecutive ToolMessages at/after cutoff
|
||||
tool_call_ids: set[str] = set()
|
||||
idx = cutoff_index
|
||||
while idx < len(messages) and isinstance(messages[idx], ToolMessage):
|
||||
tool_msg = cast("ToolMessage", messages[idx])
|
||||
if tool_msg.tool_call_id:
|
||||
tool_call_ids.add(tool_msg.tool_call_id)
|
||||
idx += 1
|
||||
|
||||
# Search backward for AIMessage with matching tool_calls
|
||||
for i in range(cutoff_index - 1, -1, -1):
|
||||
msg = messages[i]
|
||||
if isinstance(msg, AIMessage) and msg.tool_calls:
|
||||
ai_tool_call_ids = {tc.get("id") for tc in msg.tool_calls if tc.get("id")}
|
||||
if tool_call_ids & ai_tool_call_ids:
|
||||
# Found the AIMessage - move cutoff to include it
|
||||
return i
|
||||
|
||||
# Fallback: no matching AIMessage found, advance past ToolMessages to avoid
|
||||
# orphaned tool responses
|
||||
return idx
|
||||
|
||||
def _create_summary(self, messages_to_summarize: list[AnyMessage]) -> str:
|
||||
"""Generate summary for the given messages."""
|
||||
|
||||
@@ -281,8 +281,8 @@ def test_summarization_middleware_profile_inference_triggers_summary() -> None:
|
||||
]
|
||||
|
||||
|
||||
def test_summarization_middleware_token_retention_advances_past_tool_messages() -> None:
|
||||
"""Ensure token retention advances past tool messages for aggressive summarization."""
|
||||
def test_summarization_middleware_token_retention_preserves_ai_tool_pairs() -> None:
|
||||
"""Ensure token retention preserves AI/Tool message pairs together."""
|
||||
|
||||
def token_counter(messages: list[AnyMessage]) -> int:
|
||||
return sum(len(getattr(message, "content", "")) for message in messages)
|
||||
@@ -297,7 +297,7 @@ def test_summarization_middleware_token_retention_advances_past_tool_messages()
|
||||
# Total tokens: 300 + 200 + 50 + 180 + 160 = 890
|
||||
# Target keep: 500 tokens (50% of 1000)
|
||||
# Binary search finds cutoff around index 2 (ToolMessage)
|
||||
# We advance past it to index 3 (HumanMessage)
|
||||
# We move back to index 1 to preserve the AIMessage with its ToolMessage
|
||||
messages: list[AnyMessage] = [
|
||||
HumanMessage(content="H" * 300),
|
||||
AIMessage(
|
||||
@@ -314,14 +314,15 @@ def test_summarization_middleware_token_retention_advances_past_tool_messages()
|
||||
assert result is not None
|
||||
|
||||
preserved_messages = result["messages"][2:]
|
||||
# With aggressive summarization, we advance past the ToolMessage
|
||||
# So we preserve messages from index 3 onward (the two HumanMessages)
|
||||
assert preserved_messages == messages[3:]
|
||||
# We move the cutoff back to include the AIMessage with its ToolMessage
|
||||
# So we preserve messages from index 1 onward (AI + Tool + Human + Human)
|
||||
assert preserved_messages == messages[1:]
|
||||
|
||||
# Verify preserved tokens are within budget
|
||||
target_token_count = int(1000 * 0.5)
|
||||
preserved_tokens = middleware.token_counter(preserved_messages)
|
||||
assert preserved_tokens <= target_token_count
|
||||
# Verify the AI/Tool pair is preserved together
|
||||
assert isinstance(preserved_messages[0], AIMessage)
|
||||
assert preserved_messages[0].tool_calls
|
||||
assert isinstance(preserved_messages[1], ToolMessage)
|
||||
assert preserved_messages[1].tool_call_id == preserved_messages[0].tool_calls[0]["id"]
|
||||
|
||||
|
||||
def test_summarization_middleware_missing_profile() -> None:
|
||||
@@ -666,7 +667,7 @@ def test_summarization_middleware_binary_search_edge_cases() -> None:
|
||||
|
||||
|
||||
def test_summarization_middleware_find_safe_cutoff_point() -> None:
|
||||
"""Test _find_safe_cutoff_point finds safe cutoff past ToolMessages."""
|
||||
"""Test `_find_safe_cutoff_point` preserves AI/Tool message pairs."""
|
||||
model = FakeToolCallingModel()
|
||||
middleware = SummarizationMiddleware(
|
||||
model=model, trigger=("messages", 10), keep=("messages", 2)
|
||||
@@ -676,7 +677,7 @@ def test_summarization_middleware_find_safe_cutoff_point() -> None:
|
||||
HumanMessage(content="msg1"),
|
||||
AIMessage(content="ai", tool_calls=[{"name": "tool", "args": {}, "id": "call1"}]),
|
||||
ToolMessage(content="result1", tool_call_id="call1"),
|
||||
ToolMessage(content="result2", tool_call_id="call2"),
|
||||
ToolMessage(content="result2", tool_call_id="call2"), # orphan - no matching AI
|
||||
HumanMessage(content="msg2"),
|
||||
]
|
||||
|
||||
@@ -684,8 +685,14 @@ def test_summarization_middleware_find_safe_cutoff_point() -> None:
|
||||
assert middleware._find_safe_cutoff_point(messages, 0) == 0
|
||||
assert middleware._find_safe_cutoff_point(messages, 1) == 1
|
||||
|
||||
# Starting at a ToolMessage advances to the next non-ToolMessage
|
||||
assert middleware._find_safe_cutoff_point(messages, 2) == 4
|
||||
# Starting at ToolMessage with matching AIMessage moves back to include it
|
||||
# ToolMessage at index 2 has tool_call_id="call1" which matches AIMessage at index 1
|
||||
assert middleware._find_safe_cutoff_point(messages, 2) == 1
|
||||
|
||||
# Starting at orphan ToolMessage (no matching AIMessage) falls back to advancing
|
||||
# ToolMessage at index 3 has tool_call_id="call2" with no matching AIMessage
|
||||
# Since we only collect from cutoff_index onwards, only {call2} is collected
|
||||
# No match found, so we fall back to advancing past ToolMessages
|
||||
assert middleware._find_safe_cutoff_point(messages, 3) == 4
|
||||
|
||||
# Starting at the HumanMessage after tools returns that index
|
||||
@@ -699,6 +706,65 @@ def test_summarization_middleware_find_safe_cutoff_point() -> None:
|
||||
assert middleware._find_safe_cutoff_point(messages, len(messages) + 5) == len(messages) + 5
|
||||
|
||||
|
||||
def test_summarization_middleware_find_safe_cutoff_point_orphan_tool() -> None:
|
||||
"""Test `_find_safe_cutoff_point` with truly orphan `ToolMessage` (no matching `AIMessage`)."""
|
||||
model = FakeToolCallingModel()
|
||||
middleware = SummarizationMiddleware(
|
||||
model=model, trigger=("messages", 10), keep=("messages", 2)
|
||||
)
|
||||
|
||||
# Messages where ToolMessage has no matching AIMessage at all
|
||||
messages: list[AnyMessage] = [
|
||||
HumanMessage(content="msg1"),
|
||||
AIMessage(content="ai_no_tools"), # No tool_calls
|
||||
ToolMessage(content="orphan_result", tool_call_id="orphan_call"),
|
||||
HumanMessage(content="msg2"),
|
||||
]
|
||||
|
||||
# Starting at orphan ToolMessage falls back to advancing forward
|
||||
assert middleware._find_safe_cutoff_point(messages, 2) == 3
|
||||
|
||||
|
||||
def test_summarization_cutoff_moves_backward_to_include_ai_message() -> None:
|
||||
"""Test that cutoff moves backward to include `AIMessage` with its `ToolMessage`s.
|
||||
|
||||
Previously, when the cutoff landed on a `ToolMessage`, the code would advance
|
||||
FORWARD past all `ToolMessage`s. This could result in orphaned `ToolMessage`s (kept
|
||||
without their `AIMessage`) or aggressive summarization that removed AI/Tool pairs.
|
||||
|
||||
The fix searches backward from a `ToolMessage` to find the `AIMessage` with matching
|
||||
`tool_calls`, ensuring the pair stays together in the preserved messages.
|
||||
"""
|
||||
model = FakeToolCallingModel()
|
||||
middleware = SummarizationMiddleware(
|
||||
model=model, trigger=("messages", 10), keep=("messages", 2)
|
||||
)
|
||||
|
||||
# Scenario: cutoff lands on ToolMessage that has a matching AIMessage before it
|
||||
messages: list[AnyMessage] = [
|
||||
HumanMessage(content="initial question"), # index 0
|
||||
AIMessage(
|
||||
content="I'll use a tool",
|
||||
tool_calls=[{"name": "search", "args": {"q": "test"}, "id": "call_abc"}],
|
||||
), # index 1
|
||||
ToolMessage(content="search result", tool_call_id="call_abc"), # index 2
|
||||
HumanMessage(content="followup"), # index 3
|
||||
]
|
||||
|
||||
# When cutoff is at index 2 (ToolMessage), it should move BACKWARD to index 1
|
||||
# to include the AIMessage that generated the tool call
|
||||
result = middleware._find_safe_cutoff_point(messages, 2)
|
||||
|
||||
assert result == 1, (
|
||||
f"Expected cutoff to move backward to index 1 (AIMessage), got {result}. "
|
||||
"The cutoff should preserve AI/Tool pairs together."
|
||||
)
|
||||
|
||||
assert isinstance(messages[result], AIMessage)
|
||||
assert messages[result].tool_calls # type: ignore[union-attr]
|
||||
assert messages[result].tool_calls[0]["id"] == "call_abc" # type: ignore[union-attr]
|
||||
|
||||
|
||||
def test_summarization_middleware_zero_and_negative_target_tokens() -> None:
|
||||
"""Test handling of edge cases with target token calculations."""
|
||||
# Test with very small fraction that rounds to zero
|
||||
@@ -814,7 +880,7 @@ def test_summarization_adjust_token_counts() -> None:
|
||||
|
||||
|
||||
def test_summarization_middleware_many_parallel_tool_calls_safety() -> None:
|
||||
"""Test cutoff safety with many parallel tool calls extending beyond old search range."""
|
||||
"""Test cutoff safety preserves AI message with many parallel tool calls."""
|
||||
middleware = SummarizationMiddleware(
|
||||
model=MockChatModel(), trigger=("messages", 15), keep=("messages", 5)
|
||||
)
|
||||
@@ -826,20 +892,21 @@ def test_summarization_middleware_many_parallel_tool_calls_safety() -> None:
|
||||
]
|
||||
messages: list[AnyMessage] = [human_message, ai_message, *tool_messages]
|
||||
|
||||
# Cutoff at index 7 (a ToolMessage) advances to index 12 (end of messages)
|
||||
assert middleware._find_safe_cutoff_point(messages, 7) == 12
|
||||
# Cutoff at index 7 (a ToolMessage) moves back to index 1 (AIMessage)
|
||||
# to preserve the AI/Tool pair together
|
||||
assert middleware._find_safe_cutoff_point(messages, 7) == 1
|
||||
|
||||
# Any cutoff pointing at a ToolMessage (indices 2-11) advances to index 12
|
||||
# Any cutoff pointing at a ToolMessage (indices 2-11) moves back to index 1
|
||||
for i in range(2, 12):
|
||||
assert middleware._find_safe_cutoff_point(messages, i) == 12
|
||||
assert middleware._find_safe_cutoff_point(messages, i) == 1
|
||||
|
||||
# Cutoff at index 0, 1 (before tool messages) stays the same
|
||||
assert middleware._find_safe_cutoff_point(messages, 0) == 0
|
||||
assert middleware._find_safe_cutoff_point(messages, 1) == 1
|
||||
|
||||
|
||||
def test_summarization_middleware_find_safe_cutoff_advances_past_tools() -> None:
|
||||
"""Test _find_safe_cutoff advances past ToolMessages to find safe cutoff."""
|
||||
def test_summarization_middleware_find_safe_cutoff_preserves_ai_tool_pair() -> None:
|
||||
"""Test `_find_safe_cutoff` preserves AI/Tool message pairs together."""
|
||||
middleware = SummarizationMiddleware(
|
||||
model=MockChatModel(), trigger=("messages", 10), keep=("messages", 3)
|
||||
)
|
||||
@@ -862,15 +929,15 @@ def test_summarization_middleware_find_safe_cutoff_advances_past_tools() -> None
|
||||
]
|
||||
|
||||
# Target cutoff index is len(messages) - messages_to_keep = 6 - 3 = 3
|
||||
# Index 3 is a ToolMessage, so we advance past the tool sequence to index 5
|
||||
# Index 3 is a ToolMessage, we move back to index 1 to include AIMessage
|
||||
cutoff = middleware._find_safe_cutoff(messages, messages_to_keep=3)
|
||||
assert cutoff == 5
|
||||
assert cutoff == 1
|
||||
|
||||
# With messages_to_keep=2, target cutoff index is 6 - 2 = 4
|
||||
# Index 4 is a ToolMessage, so we advance past the tool sequence to index 5
|
||||
# This is aggressive - we keep only 1 message instead of 2
|
||||
# Index 4 is a ToolMessage, we move back to index 1 to include AIMessage
|
||||
# This preserves the AI + Tools + Human, more than requested but valid
|
||||
cutoff = middleware._find_safe_cutoff(messages, messages_to_keep=2)
|
||||
assert cutoff == 5
|
||||
assert cutoff == 1
|
||||
|
||||
|
||||
def test_summarization_middleware_cutoff_at_start_of_tool_sequence() -> None:
|
||||
|
||||
Reference in New Issue
Block a user