fix(langchain): (SummarizationMiddleware) adjust token counts based on model (#34161)

This commit is contained in:
ccurme
2025-12-01 11:22:44 -05:00
committed by GitHub
parent 7549845d82
commit 7a2952210e
2 changed files with 31 additions and 1 deletions

View File

@@ -3,6 +3,7 @@
import uuid import uuid
import warnings import warnings
from collections.abc import Callable, Iterable, Mapping from collections.abc import Callable, Iterable, Mapping
from functools import partial
from typing import Any, Literal, cast from typing import Any, Literal, cast
from langchain_core.messages import ( from langchain_core.messages import (
@@ -119,6 +120,15 @@ Example:
""" """
def _get_approximate_token_counter(model: BaseChatModel) -> TokenCounter:
    """Pick an approximate token counter tuned to the given model type.

    Args:
        model: Chat model whose `_llm_type` selects the tuning.

    Returns:
        `count_tokens_approximately` with `chars_per_token=3.3` pre-bound when
        the model reports `"anthropic-chat"`; otherwise the unmodified
        `count_tokens_approximately`.
    """
    if model._llm_type != "anthropic-chat":
        return count_tokens_approximately
    # The 3.3 chars-per-token ratio was estimated in an offline experiment,
    # comparing with Claude's token-counting API:
    # https://platform.claude.com/docs/en/build-with-claude/token-counting
    return partial(count_tokens_approximately, chars_per_token=3.3)
class SummarizationMiddleware(AgentMiddleware): class SummarizationMiddleware(AgentMiddleware):
"""Summarizes conversation history when token limits are approached. """Summarizes conversation history when token limits are approached.
@@ -234,7 +244,10 @@ class SummarizationMiddleware(AgentMiddleware):
self._trigger_conditions = trigger_conditions self._trigger_conditions = trigger_conditions
self.keep = self._validate_context_size(keep, "keep") self.keep = self._validate_context_size(keep, "keep")
self.token_counter = token_counter if token_counter is count_tokens_approximately:
self.token_counter = _get_approximate_token_counter(self.model)
else:
self.token_counter = token_counter
self.summary_prompt = summary_prompt self.summary_prompt = summary_prompt
self.trim_tokens_to_summarize = trim_tokens_to_summarize self.trim_tokens_to_summarize = trim_tokens_to_summarize

View File

@@ -892,3 +892,20 @@ def test_summarization_middleware_is_safe_cutoff_at_end() -> None:
# Cutoff past the length should also be safe # Cutoff past the length should also be safe
assert middleware._is_safe_cutoff_point(messages, len(messages) + 5) assert middleware._is_safe_cutoff_point(messages, len(messages) + 5)
def test_summarization_adjust_token_counts() -> None:
    """Check that the middleware's token counter differs for an Anthropic model."""

    class MockAnthropicModel(MockChatModel):
        # Report the Anthropic chat type so the middleware selects the tuned counter.
        @property
        def _llm_type(self) -> str:
            return "anthropic-chat"

    test_message = HumanMessage(content="a" * 12)

    # Count the same message with a middleware built on the plain mock model ...
    default_middleware = SummarizationMiddleware(model=MockChatModel(), trigger=("messages", 5))
    count_1 = default_middleware.token_counter([test_message])

    # ... and with one built on the Anthropic-typed mock model.
    anthropic_middleware = SummarizationMiddleware(
        model=MockAnthropicModel(), trigger=("messages", 5)
    )
    count_2 = anthropic_middleware.token_counter([test_message])

    assert count_1 != count_2