mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
fix(langchain): (SummarizationMiddleware) adjust token counts based on model (#34161)
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
import uuid
|
||||
import warnings
|
||||
from collections.abc import Callable, Iterable, Mapping
|
||||
from functools import partial
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
from langchain_core.messages import (
|
||||
@@ -119,6 +120,15 @@ Example:
|
||||
"""
|
||||
|
||||
|
||||
def _get_approximate_token_counter(model: BaseChatModel) -> TokenCounter:
    """Pick an approximate token counter tuned to the given model type.

    Args:
        model: Chat model whose `_llm_type` selects the tuning.

    Returns:
        `count_tokens_approximately` with `chars_per_token=3.3` bound when the
        model reports `"anthropic-chat"`; otherwise the unmodified
        `count_tokens_approximately`.
    """
    if model._llm_type != "anthropic-chat":
        # No model-specific tuning known: use the counter's own defaults.
        return count_tokens_approximately

    # 3.3 was estimated in an offline experiment, comparing with Claude's token-counting
    # API: https://platform.claude.com/docs/en/build-with-claude/token-counting
    return partial(count_tokens_approximately, chars_per_token=3.3)
|
||||
|
||||
|
||||
class SummarizationMiddleware(AgentMiddleware):
|
||||
"""Summarizes conversation history when token limits are approached.
|
||||
|
||||
@@ -234,7 +244,10 @@ class SummarizationMiddleware(AgentMiddleware):
|
||||
self._trigger_conditions = trigger_conditions
|
||||
|
||||
self.keep = self._validate_context_size(keep, "keep")
|
||||
self.token_counter = token_counter
|
||||
if token_counter is count_tokens_approximately:
|
||||
self.token_counter = _get_approximate_token_counter(self.model)
|
||||
else:
|
||||
self.token_counter = token_counter
|
||||
self.summary_prompt = summary_prompt
|
||||
self.trim_tokens_to_summarize = trim_tokens_to_summarize
|
||||
|
||||
|
||||
@@ -892,3 +892,20 @@ def test_summarization_middleware_is_safe_cutoff_at_end() -> None:
|
||||
|
||||
# Cutoff past the length should also be safe
|
||||
assert middleware._is_safe_cutoff_point(messages, len(messages) + 5)
|
||||
|
||||
|
||||
def test_summarization_adjust_token_counts() -> None:
    """Token counts from the middleware differ for an Anthropic-typed model."""

    class MockAnthropicModel(MockChatModel):
        """Mock chat model that reports the Anthropic `_llm_type`."""

        @property
        def _llm_type(self) -> str:
            return "anthropic-chat"

    sample = HumanMessage(content="a" * 12)

    # Same configuration for both middlewares; only the model differs.
    default_middleware = SummarizationMiddleware(
        model=MockChatModel(), trigger=("messages", 5)
    )
    anthropic_middleware = SummarizationMiddleware(
        model=MockAnthropicModel(), trigger=("messages", 5)
    )

    default_count = default_middleware.token_counter([sample])
    anthropic_count = anthropic_middleware.token_counter([sample])

    assert default_count != anthropic_count
|
||||
|
||||
Reference in New Issue
Block a user