fix(langchain): (SummarizationMiddleware) adjust token counts based on model (#34161)

This commit is contained in:
ccurme
2025-12-01 11:22:44 -05:00
committed by GitHub
parent 7549845d82
commit 7a2952210e
2 changed files with 31 additions and 1 deletions

View File

@@ -3,6 +3,7 @@
import uuid
import warnings
from collections.abc import Callable, Iterable, Mapping
from functools import partial
from typing import Any, Literal, cast
from langchain_core.messages import (
@@ -119,6 +120,15 @@ Example:
"""
def _get_approximate_token_counter(model: BaseChatModel) -> TokenCounter:
    """Select an approximate token counter suited to the given model.

    Args:
        model: Chat model whose `_llm_type` decides which counter is used.

    Returns:
        `count_tokens_approximately` with `chars_per_token` bound to 3.3 when the
        model's `_llm_type` is `"anthropic-chat"`; otherwise
        `count_tokens_approximately` itself, unchanged.
    """
    is_anthropic = model._llm_type == "anthropic-chat"
    if not is_anthropic:
        return count_tokens_approximately

    # The 3.3 chars-per-token ratio was estimated in an offline experiment,
    # comparing with Claude's token-counting API:
    # https://platform.claude.com/docs/en/build-with-claude/token-counting
    return partial(count_tokens_approximately, chars_per_token=3.3)
class SummarizationMiddleware(AgentMiddleware):
"""Summarizes conversation history when token limits are approached.
@@ -234,7 +244,10 @@ class SummarizationMiddleware(AgentMiddleware):
self._trigger_conditions = trigger_conditions
self.keep = self._validate_context_size(keep, "keep")
self.token_counter = token_counter
if token_counter is count_tokens_approximately:
self.token_counter = _get_approximate_token_counter(self.model)
else:
self.token_counter = token_counter
self.summary_prompt = summary_prompt
self.trim_tokens_to_summarize = trim_tokens_to_summarize

View File

@@ -892,3 +892,20 @@ def test_summarization_middleware_is_safe_cutoff_at_end() -> None:
# Cutoff past the length should also be safe
assert middleware._is_safe_cutoff_point(messages, len(messages) + 5)
def test_summarization_adjust_token_counts() -> None:
    """The default token counter yields a different count for an Anthropic-typed model."""

    class MockAnthropicModel(MockChatModel):
        # Report the `_llm_type` value that selects the Anthropic-tuned counter.
        @property
        def _llm_type(self) -> str:
            return "anthropic-chat"

    sample = HumanMessage(content="a" * 12)

    default_middleware = SummarizationMiddleware(
        model=MockChatModel(),
        trigger=("messages", 5),
    )
    default_count = default_middleware.token_counter([sample])

    anthropic_middleware = SummarizationMiddleware(
        model=MockAnthropicModel(),
        trigger=("messages", 5),
    )
    anthropic_count = anthropic_middleware.token_counter([sample])

    assert default_count != anthropic_count