From 0040e1a8aa366a1aeff585661c17cafb4281aafe Mon Sep 17 00:00:00 2001
From: ccurme
Date: Mon, 9 Feb 2026 15:27:17 -0500
Subject: [PATCH] fix(langchain): fix token counting on partial message sequences (#35101)

---
 .../agents/middleware/summarization.py        |  6 +++-
 .../implementations/test_summarization.py     | 28 ++++++++++++++++++-
 2 files changed, 32 insertions(+), 2 deletions(-)

diff --git a/libs/langchain_v1/langchain/agents/middleware/summarization.py b/libs/langchain_v1/langchain/agents/middleware/summarization.py
index a59b23ffc03..42d2760b4e1 100644
--- a/libs/langchain_v1/langchain/agents/middleware/summarization.py
+++ b/libs/langchain_v1/langchain/agents/middleware/summarization.py
@@ -267,8 +267,12 @@ class SummarizationMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, R
         self.keep = self._validate_context_size(keep, "keep")
         if token_counter is count_tokens_approximately:
             self.token_counter = _get_approximate_token_counter(self.model)
+            self._partial_token_counter: TokenCounter = partial(  # type: ignore[call-arg]
+                self.token_counter, use_usage_metadata_scaling=False
+            )
         else:
             self.token_counter = token_counter
+            self._partial_token_counter = token_counter
         self.summary_prompt = summary_prompt
         self.trim_tokens_to_summarize = trim_tokens_to_summarize

@@ -452,7 +456,7 @@ class SummarizationMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, R
                 break

             mid = (left + right) // 2
-            if self.token_counter(messages[mid:]) <= target_token_count:
+            if self._partial_token_counter(messages[mid:]) <= target_token_count:
                 cutoff_candidate = mid
                 right = mid
             else:
diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_summarization.py b/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_summarization.py
index bbe5b1388ac..91ef582ef4a 100644
--- a/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_summarization.py
+++ b/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_summarization.py
@@ -353,8 +353,8 @@ def test_summarization_middleware_token_retention_preserves_ai_tool_pairs() -> N
         model=ProfileChatModel(),
         trigger=("fraction", 0.1),
         keep=("fraction", 0.5),
+        token_counter=token_counter,
     )
-    middleware.token_counter = token_counter

     # Total tokens: 300 + 200 + 50 + 180 + 160 = 890
     # Target keep: 500 tokens (50% of 1000)
@@ -1028,6 +1028,32 @@ def test_summarization_middleware_many_parallel_tool_calls_safety() -> None:
     assert middleware._find_safe_cutoff_point(messages, 1) == 1


+def test_summarization_before_model_uses_unscaled_tokens_for_cutoff() -> None:
+    calls: list[dict[str, Any]] = []
+
+    def fake_counter(_: Iterable[MessageLikeRepresentation], **kwargs: Any) -> int:
+        calls.append(kwargs)
+        return 100
+
+    with patch(
+        "langchain.agents.middleware.summarization.count_tokens_approximately",
+        side_effect=fake_counter,
+    ) as mock_counter:
+        middleware = SummarizationMiddleware(
+            model=MockChatModel(),
+            trigger=("tokens", 1),
+            keep=("tokens", 1),
+            token_counter=mock_counter,
+        )
+    state = AgentState[Any](messages=[HumanMessage(content="one"), HumanMessage(content="two")])
+    assert middleware.before_model(state, Runtime()) is not None
+
+    # Verify that partial token counting is supported; with the default token
+    # counter, partial counts disable use_usage_metadata_scaling.
+    assert any(call.get("use_usage_metadata_scaling") is False for call in calls)
+    assert any(call.get("use_usage_metadata_scaling") is True for call in calls)
+
+
 def test_summarization_middleware_find_safe_cutoff_preserves_ai_tool_pair() -> None:
     """Test `_find_safe_cutoff` preserves AI/Tool message pairs together."""
     middleware = SummarizationMiddleware(
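
Note (illustrative, not part of the patch): the change pre-binds
use_usage_metadata_scaling=False via partial() so that the binary-search
cutoff probes count suffixes like messages[mid:] without usage-metadata
scaling, which is presumably calibrated on full message sequences and so
misestimates a partial suffix. Below is a minimal, self-contained sketch of
that pre-binding technique under those assumptions; approx_count and its
1.5x scaling factor are hypothetical stand-ins, not the library's API. Only
functools.partial and the use_usage_metadata_scaling keyword mirror the
change above.

    from functools import partial
    from typing import Callable

    def approx_count(messages: list[str], *, use_usage_metadata_scaling: bool = True) -> int:
        # Hypothetical approximate counter: the scaled path stands in for a
        # counter calibrated against usage metadata reported for full
        # sequences, which need not hold for a suffix like messages[mid:].
        base = sum(len(m.split()) for m in messages)
        return int(base * 1.5) if use_usage_metadata_scaling else base

    # Pre-bind the keyword once, as the patch does for _partial_token_counter,
    # so every cutoff probe counts partial sequences without the scaling.
    partial_counter: Callable[[list[str]], int] = partial(
        approx_count, use_usage_metadata_scaling=False
    )

    messages = ["one two three", "four five", "six"]
    assert approx_count(messages) == 9         # scaled count of the full sequence
    assert partial_counter(messages[1:]) == 3  # unscaled count of a suffix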