From 7a2952210efde87883af0d52381bf9bd6ddc6561 Mon Sep 17 00:00:00 2001 From: ccurme Date: Mon, 1 Dec 2025 11:22:44 -0500 Subject: [PATCH] fix(langchain): (SummarizationMiddleware) adjust token counts based on model (#34161) --- .../agents/middleware/summarization.py | 15 ++++++++++++++- .../implementations/test_summarization.py | 17 +++++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/libs/langchain_v1/langchain/agents/middleware/summarization.py b/libs/langchain_v1/langchain/agents/middleware/summarization.py index fec90b08b72..6e67b7e5d57 100644 --- a/libs/langchain_v1/langchain/agents/middleware/summarization.py +++ b/libs/langchain_v1/langchain/agents/middleware/summarization.py @@ -3,6 +3,7 @@ import uuid import warnings from collections.abc import Callable, Iterable, Mapping +from functools import partial from typing import Any, Literal, cast from langchain_core.messages import ( @@ -119,6 +120,15 @@ Example: """ +def _get_approximate_token_counter(model: BaseChatModel) -> TokenCounter: + """Tune parameters of approximate token counter based on model type.""" + if model._llm_type == "anthropic-chat": + # 3.3 was estimated in an offline experiment, comparing with Claude's token-counting + # API: https://platform.claude.com/docs/en/build-with-claude/token-counting + return partial(count_tokens_approximately, chars_per_token=3.3) + return count_tokens_approximately + + class SummarizationMiddleware(AgentMiddleware): """Summarizes conversation history when token limits are approached. @@ -234,7 +244,10 @@ class SummarizationMiddleware(AgentMiddleware): self._trigger_conditions = trigger_conditions self.keep = self._validate_context_size(keep, "keep") - self.token_counter = token_counter + if token_counter is count_tokens_approximately: + self.token_counter = _get_approximate_token_counter(self.model) + else: + self.token_counter = token_counter self.summary_prompt = summary_prompt self.trim_tokens_to_summarize = trim_tokens_to_summarize diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_summarization.py b/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_summarization.py index 3a1481051df..015b37bd2b6 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_summarization.py +++ b/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_summarization.py @@ -892,3 +892,20 @@ def test_summarization_middleware_is_safe_cutoff_at_end() -> None: # Cutoff past the length should also be safe assert middleware._is_safe_cutoff_point(messages, len(messages) + 5) + + +def test_summarization_adjust_token_counts() -> None: + test_message = HumanMessage(content="a" * 12) + + middleware = SummarizationMiddleware(model=MockChatModel(), trigger=("messages", 5)) + count_1 = middleware.token_counter([test_message]) + + class MockAnthropicModel(MockChatModel): + @property + def _llm_type(self) -> str: + return "anthropic-chat" + + middleware = SummarizationMiddleware(model=MockAnthropicModel(), trigger=("messages", 5)) + count_2 = middleware.token_counter([test_message]) + + assert count_1 != count_2