fix(langchain): keep tool call / AIMessage pairings when summarizing (#34609)

Fixes #34282

**Before:** When using agents with tools (like file reading, web search,
etc.), the conversation looks like this:

```
[User]     "Read these 10 files and summarize them"
[AI]       "I'll read all 10 files" + [tool_call: read_file x 10]
[Tool]     "Contents of file1.txt..."
[Tool]     "Contents of file2.txt..."
[Tool]     "Contents of file3.txt..."
... (7 more tool responses)
```

When the conversation gets too long, `SummarizationMiddleware` kicks in
to compress older messages. The problem was:

If you asked to keep the last 6 messages, you'd get:

```
[Summary]  "Here's what happened before..."
[Tool]     "Contents of file5.txt..."
[Tool]     "Contents of file6.txt..."
[Tool]     "Contents of file7.txt..."
[Tool]     "Contents of file8.txt..."
[Tool]     "Contents of file9.txt..."
[Tool]     "Contents of file10.txt..."
```

The AI's original request to read the files (`[AI]` message with
`tool_calls`) was summarized away, but the tool responses remained. This
caused the error:

```
Error code: 400 - "No tool call found for function call output with call_id..."
```

Many APIs require that every tool response has a matching tool request.
Without the AI message, the tool responses are "orphaned."

## The fix

Now when the cutoff lands on tool messages, we **move backward** to
include the AI message that requested those tools:

Same scenario, keeping last 6 messages:

```
[Summary]  "Here's what happened before..."
[AI]       "I'll read all 10 files" + [tool_call: read_file x 10]
[Tool]     "Contents of file1.txt..."
[Tool]     "Contents of file2.txt..."
... (all 10 tool responses)
```

The AI message is preserved along with its tool responses, keeping them
paired together.

## Practical examples

### Example 1: Parallel tool calls

**Scenario:** Agent reads 10 files in parallel, summarization triggers
(see above)

### Example 2: Mixed conversation

**Scenario:** User asks question, AI uses tools, user says thanks

```
[User]     "What's the weather?"
[AI]       "Let me check" + [tool_call: get_weather]
[Tool]     "72F and sunny"
[AI]       "It's 72F and sunny!"
[User]     "Thanks!"
```

Keeping last 2 messages:

| Before (Bug) | After (Fix) |
|--------------|-------------|
| Only `[User] "Thanks!"` kept | `[AI] + [Tool] + [AI] + [User]` all
kept |
| Lost the weather info | Tool pair preserved with response |

### Example 3: Multiple tool sequences

```
[User]     "Search for X"
[AI]       [tool_call: search]
[Tool]     "Results for X"
[User]     "Now search for Y"
[AI]       [tool_call: search]
[Tool]     "Results for Y"
[User]     "Great!"
```

**Keeping last 3 messages:** If cutoff lands on `[Tool] "Results for
Y"`, we now include `[AI] [tool_call: search]` to keep the pair
together.
This commit is contained in:
Mason Daugherty
2026-01-08 10:07:56 -05:00
committed by GitHub
parent f805ea9601
commit 2b6911d9af
2 changed files with 124 additions and 32 deletions

View File

@@ -7,6 +7,7 @@ from functools import partial
from typing import Any, Literal, cast
from langchain_core.messages import (
AIMessage,
AnyMessage,
MessageLikeRepresentation,
RemoveMessage,
@@ -478,13 +479,37 @@ class SummarizationMiddleware(AgentMiddleware):
def _find_safe_cutoff_point(self, messages: list[AnyMessage], cutoff_index: int) -> int:
"""Find a safe cutoff point that doesn't split AI/Tool message pairs.
If the message at cutoff_index is a ToolMessage, advance until we find
a non-ToolMessage. This ensures we never cut in the middle of parallel
tool call responses.
If the message at `cutoff_index` is a `ToolMessage`, search backward for the
`AIMessage` containing the corresponding `tool_calls` and adjust the cutoff to
include it. This ensures tool call requests and responses stay together.
Falls back to advancing forward past `ToolMessage` objects only if no matching
`AIMessage` is found (edge case).
"""
while cutoff_index < len(messages) and isinstance(messages[cutoff_index], ToolMessage):
cutoff_index += 1
return cutoff_index
if cutoff_index >= len(messages) or not isinstance(messages[cutoff_index], ToolMessage):
return cutoff_index
# Collect tool_call_ids from consecutive ToolMessages at/after cutoff
tool_call_ids: set[str] = set()
idx = cutoff_index
while idx < len(messages) and isinstance(messages[idx], ToolMessage):
tool_msg = cast("ToolMessage", messages[idx])
if tool_msg.tool_call_id:
tool_call_ids.add(tool_msg.tool_call_id)
idx += 1
# Search backward for AIMessage with matching tool_calls
for i in range(cutoff_index - 1, -1, -1):
msg = messages[i]
if isinstance(msg, AIMessage) and msg.tool_calls:
ai_tool_call_ids = {tc.get("id") for tc in msg.tool_calls if tc.get("id")}
if tool_call_ids & ai_tool_call_ids:
# Found the AIMessage - move cutoff to include it
return i
# Fallback: no matching AIMessage found, advance past ToolMessages to avoid
# orphaned tool responses
return idx
def _create_summary(self, messages_to_summarize: list[AnyMessage]) -> str:
"""Generate summary for the given messages."""

View File

@@ -281,8 +281,8 @@ def test_summarization_middleware_profile_inference_triggers_summary() -> None:
]
def test_summarization_middleware_token_retention_advances_past_tool_messages() -> None:
"""Ensure token retention advances past tool messages for aggressive summarization."""
def test_summarization_middleware_token_retention_preserves_ai_tool_pairs() -> None:
"""Ensure token retention preserves AI/Tool message pairs together."""
def token_counter(messages: list[AnyMessage]) -> int:
return sum(len(getattr(message, "content", "")) for message in messages)
@@ -297,7 +297,7 @@ def test_summarization_middleware_token_retention_advances_past_tool_messages()
# Total tokens: 300 + 200 + 50 + 180 + 160 = 890
# Target keep: 500 tokens (50% of 1000)
# Binary search finds cutoff around index 2 (ToolMessage)
# We advance past it to index 3 (HumanMessage)
# We move back to index 1 to preserve the AIMessage with its ToolMessage
messages: list[AnyMessage] = [
HumanMessage(content="H" * 300),
AIMessage(
@@ -314,14 +314,15 @@ def test_summarization_middleware_token_retention_advances_past_tool_messages()
assert result is not None
preserved_messages = result["messages"][2:]
# With aggressive summarization, we advance past the ToolMessage
# So we preserve messages from index 3 onward (the two HumanMessages)
assert preserved_messages == messages[3:]
# We move the cutoff back to include the AIMessage with its ToolMessage
# So we preserve messages from index 1 onward (AI + Tool + Human + Human)
assert preserved_messages == messages[1:]
# Verify preserved tokens are within budget
target_token_count = int(1000 * 0.5)
preserved_tokens = middleware.token_counter(preserved_messages)
assert preserved_tokens <= target_token_count
# Verify the AI/Tool pair is preserved together
assert isinstance(preserved_messages[0], AIMessage)
assert preserved_messages[0].tool_calls
assert isinstance(preserved_messages[1], ToolMessage)
assert preserved_messages[1].tool_call_id == preserved_messages[0].tool_calls[0]["id"]
def test_summarization_middleware_missing_profile() -> None:
@@ -666,7 +667,7 @@ def test_summarization_middleware_binary_search_edge_cases() -> None:
def test_summarization_middleware_find_safe_cutoff_point() -> None:
"""Test _find_safe_cutoff_point finds safe cutoff past ToolMessages."""
"""Test `_find_safe_cutoff_point` preserves AI/Tool message pairs."""
model = FakeToolCallingModel()
middleware = SummarizationMiddleware(
model=model, trigger=("messages", 10), keep=("messages", 2)
@@ -676,7 +677,7 @@ def test_summarization_middleware_find_safe_cutoff_point() -> None:
HumanMessage(content="msg1"),
AIMessage(content="ai", tool_calls=[{"name": "tool", "args": {}, "id": "call1"}]),
ToolMessage(content="result1", tool_call_id="call1"),
ToolMessage(content="result2", tool_call_id="call2"),
ToolMessage(content="result2", tool_call_id="call2"), # orphan - no matching AI
HumanMessage(content="msg2"),
]
@@ -684,8 +685,14 @@ def test_summarization_middleware_find_safe_cutoff_point() -> None:
assert middleware._find_safe_cutoff_point(messages, 0) == 0
assert middleware._find_safe_cutoff_point(messages, 1) == 1
# Starting at a ToolMessage advances to the next non-ToolMessage
assert middleware._find_safe_cutoff_point(messages, 2) == 4
# Starting at ToolMessage with matching AIMessage moves back to include it
# ToolMessage at index 2 has tool_call_id="call1" which matches AIMessage at index 1
assert middleware._find_safe_cutoff_point(messages, 2) == 1
# Starting at orphan ToolMessage (no matching AIMessage) falls back to advancing
# ToolMessage at index 3 has tool_call_id="call2" with no matching AIMessage
# Since we only collect from cutoff_index onwards, only {call2} is collected
# No match found, so we fall back to advancing past ToolMessages
assert middleware._find_safe_cutoff_point(messages, 3) == 4
# Starting at the HumanMessage after tools returns that index
@@ -699,6 +706,65 @@ def test_summarization_middleware_find_safe_cutoff_point() -> None:
assert middleware._find_safe_cutoff_point(messages, len(messages) + 5) == len(messages) + 5
def test_summarization_middleware_find_safe_cutoff_point_orphan_tool() -> None:
"""Test `_find_safe_cutoff_point` with truly orphan `ToolMessage` (no matching `AIMessage`)."""
model = FakeToolCallingModel()
middleware = SummarizationMiddleware(
model=model, trigger=("messages", 10), keep=("messages", 2)
)
# Messages where ToolMessage has no matching AIMessage at all
messages: list[AnyMessage] = [
HumanMessage(content="msg1"),
AIMessage(content="ai_no_tools"), # No tool_calls
ToolMessage(content="orphan_result", tool_call_id="orphan_call"),
HumanMessage(content="msg2"),
]
# Starting at orphan ToolMessage falls back to advancing forward
assert middleware._find_safe_cutoff_point(messages, 2) == 3
def test_summarization_cutoff_moves_backward_to_include_ai_message() -> None:
"""Test that cutoff moves backward to include `AIMessage` with its `ToolMessage`s.
Previously, when the cutoff landed on a `ToolMessage`, the code would advance
FORWARD past all `ToolMessage`s. This could result in orphaned `ToolMessage`s (kept
without their `AIMessage`) or aggressive summarization that removed AI/Tool pairs.
The fix searches backward from a `ToolMessage` to find the `AIMessage` with matching
`tool_calls`, ensuring the pair stays together in the preserved messages.
"""
model = FakeToolCallingModel()
middleware = SummarizationMiddleware(
model=model, trigger=("messages", 10), keep=("messages", 2)
)
# Scenario: cutoff lands on ToolMessage that has a matching AIMessage before it
messages: list[AnyMessage] = [
HumanMessage(content="initial question"), # index 0
AIMessage(
content="I'll use a tool",
tool_calls=[{"name": "search", "args": {"q": "test"}, "id": "call_abc"}],
), # index 1
ToolMessage(content="search result", tool_call_id="call_abc"), # index 2
HumanMessage(content="followup"), # index 3
]
# When cutoff is at index 2 (ToolMessage), it should move BACKWARD to index 1
# to include the AIMessage that generated the tool call
result = middleware._find_safe_cutoff_point(messages, 2)
assert result == 1, (
f"Expected cutoff to move backward to index 1 (AIMessage), got {result}. "
"The cutoff should preserve AI/Tool pairs together."
)
assert isinstance(messages[result], AIMessage)
assert messages[result].tool_calls # type: ignore[union-attr]
assert messages[result].tool_calls[0]["id"] == "call_abc" # type: ignore[union-attr]
def test_summarization_middleware_zero_and_negative_target_tokens() -> None:
"""Test handling of edge cases with target token calculations."""
# Test with very small fraction that rounds to zero
@@ -814,7 +880,7 @@ def test_summarization_adjust_token_counts() -> None:
def test_summarization_middleware_many_parallel_tool_calls_safety() -> None:
"""Test cutoff safety with many parallel tool calls extending beyond old search range."""
"""Test cutoff safety preserves AI message with many parallel tool calls."""
middleware = SummarizationMiddleware(
model=MockChatModel(), trigger=("messages", 15), keep=("messages", 5)
)
@@ -826,20 +892,21 @@ def test_summarization_middleware_many_parallel_tool_calls_safety() -> None:
]
messages: list[AnyMessage] = [human_message, ai_message, *tool_messages]
# Cutoff at index 7 (a ToolMessage) advances to index 12 (end of messages)
assert middleware._find_safe_cutoff_point(messages, 7) == 12
# Cutoff at index 7 (a ToolMessage) moves back to index 1 (AIMessage)
# to preserve the AI/Tool pair together
assert middleware._find_safe_cutoff_point(messages, 7) == 1
# Any cutoff pointing at a ToolMessage (indices 2-11) advances to index 12
# Any cutoff pointing at a ToolMessage (indices 2-11) moves back to index 1
for i in range(2, 12):
assert middleware._find_safe_cutoff_point(messages, i) == 12
assert middleware._find_safe_cutoff_point(messages, i) == 1
# Cutoff at index 0, 1 (before tool messages) stays the same
assert middleware._find_safe_cutoff_point(messages, 0) == 0
assert middleware._find_safe_cutoff_point(messages, 1) == 1
def test_summarization_middleware_find_safe_cutoff_advances_past_tools() -> None:
"""Test _find_safe_cutoff advances past ToolMessages to find safe cutoff."""
def test_summarization_middleware_find_safe_cutoff_preserves_ai_tool_pair() -> None:
"""Test `_find_safe_cutoff` preserves AI/Tool message pairs together."""
middleware = SummarizationMiddleware(
model=MockChatModel(), trigger=("messages", 10), keep=("messages", 3)
)
@@ -862,15 +929,15 @@ def test_summarization_middleware_find_safe_cutoff_advances_past_tools() -> None
]
# Target cutoff index is len(messages) - messages_to_keep = 6 - 3 = 3
# Index 3 is a ToolMessage, so we advance past the tool sequence to index 5
# Index 3 is a ToolMessage, we move back to index 1 to include AIMessage
cutoff = middleware._find_safe_cutoff(messages, messages_to_keep=3)
assert cutoff == 5
assert cutoff == 1
# With messages_to_keep=2, target cutoff index is 6 - 2 = 4
# Index 4 is a ToolMessage, so we advance past the tool sequence to index 5
# This is aggressive - we keep only 1 message instead of 2
# Index 4 is a ToolMessage, we move back to index 1 to include AIMessage
# This preserves the AI + Tools + Human, more than requested but valid
cutoff = middleware._find_safe_cutoff(messages, messages_to_keep=2)
assert cutoff == 5
assert cutoff == 1
def test_summarization_middleware_cutoff_at_start_of_tool_sequence() -> None: