fix(langchain): fix token counting on partial message sequences (#35101)

This commit is contained in:
ccurme
2026-02-09 15:27:17 -05:00
committed by GitHub
parent ce5f73e07c
commit 0040e1a8aa
2 changed files with 32 additions and 2 deletions

View File

@@ -267,8 +267,12 @@ class SummarizationMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, R
self.keep = self._validate_context_size(keep, "keep")
if token_counter is count_tokens_approximately:
self.token_counter = _get_approximate_token_counter(self.model)
self._partial_token_counter: TokenCounter = partial( # type: ignore[call-arg]
self.token_counter, use_usage_metadata_scaling=False
)
else:
self.token_counter = token_counter
self._partial_token_counter = token_counter
self.summary_prompt = summary_prompt
self.trim_tokens_to_summarize = trim_tokens_to_summarize
@@ -452,7 +456,7 @@ class SummarizationMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, R
break
mid = (left + right) // 2
if self.token_counter(messages[mid:]) <= target_token_count:
if self._partial_token_counter(messages[mid:]) <= target_token_count:
cutoff_candidate = mid
right = mid
else:

View File

@@ -353,8 +353,8 @@ def test_summarization_middleware_token_retention_preserves_ai_tool_pairs() -> N
model=ProfileChatModel(),
trigger=("fraction", 0.1),
keep=("fraction", 0.5),
token_counter=token_counter,
)
middleware.token_counter = token_counter
# Total tokens: 300 + 200 + 50 + 180 + 160 = 890
# Target keep: 500 tokens (50% of 1000)
@@ -1028,6 +1028,32 @@ def test_summarization_middleware_many_parallel_tool_calls_safety() -> None:
assert middleware._find_safe_cutoff_point(messages, 1) == 1
def test_summarization_before_model_uses_unscaled_tokens_for_cutoff() -> None:
    """Cutoff search must count partial message sequences without scaling.

    Patches the module-level ``count_tokens_approximately`` with a recording
    stub so that every keyword-argument set passed to the token counter is
    captured, then checks that ``before_model`` drove the counter both with
    and without ``use_usage_metadata_scaling``.
    """
    recorded_kwargs: list[dict[str, Any]] = []

    def recording_counter(_: Iterable[MessageLikeRepresentation], **kwargs: Any) -> int:
        # Record how each call was made; the constant return value keeps the
        # summarization trigger condition satisfied.
        recorded_kwargs.append(kwargs)
        return 100

    with patch(
        "langchain.agents.middleware.summarization.count_tokens_approximately",
        side_effect=recording_counter,
    ) as mock_counter:
        middleware = SummarizationMiddleware(
            model=MockChatModel(),
            trigger=("tokens", 1),
            keep=("tokens", 1),
            token_counter=mock_counter,
        )
        state = AgentState[Any](messages=[HumanMessage(content="one"), HumanMessage(content="two")])
        assert middleware.before_model(state, Runtime()) is not None

    # Partial-sequence counts (used for the cutoff search) must disable
    # usage-metadata scaling, while full-sequence counts keep it on — so both
    # flag values should have been observed across the recorded calls.
    scaling_flags = [kw.get("use_usage_metadata_scaling") for kw in recorded_kwargs]
    assert any(flag is False for flag in scaling_flags)
    assert any(flag is True for flag in scaling_flags)
def test_summarization_middleware_find_safe_cutoff_preserves_ai_tool_pair() -> None:
"""Test `_find_safe_cutoff` preserves AI/Tool message pairs together."""
middleware = SummarizationMiddleware(