Mirror of https://github.com/hwchase17/langchain.git (synced 2026-02-21 06:33:41 +00:00)
fix(langchain): fix token counting on partial message sequences (#35101)
@@ -267,8 +267,12 @@ class SummarizationMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, R
         self.keep = self._validate_context_size(keep, "keep")
         if token_counter is count_tokens_approximately:
             self.token_counter = _get_approximate_token_counter(self.model)
+            self._partial_token_counter: TokenCounter = partial(  # type: ignore[call-arg]
+                self.token_counter, use_usage_metadata_scaling=False
+            )
         else:
             self.token_counter = token_counter
+            self._partial_token_counter = token_counter
         self.summary_prompt = summary_prompt
         self.trim_tokens_to_summarize = trim_tokens_to_summarize
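The constructor change pairs the default approximate counter with an unscaled variant: usage-metadata scaling calibrates the estimate against a full message history, so applying it to a suffix of that history would skew the count. A minimal sketch of the functools.partial pattern in isolation, with a hypothetical stand-in counter (not the library's implementation):

from functools import partial

def approx_count(messages: list[str], *, use_usage_metadata_scaling: bool = True) -> int:
    # Hypothetical counter: ~10 tokens per message, optionally corrected by a
    # scaling factor derived from usage metadata on the full history.
    base = 10 * len(messages)
    return int(base * 1.1) if use_usage_metadata_scaling else base

# Freezing the keyword gives a counter that always skips the scaling step,
# analogous to what the middleware stores as _partial_token_counter.
partial_count = partial(approx_count, use_usage_metadata_scaling=False)

assert approx_count(["hi", "there"]) == 22   # scaled estimate, full history
assert partial_count(["hi", "there"]) == 20  # raw estimate, partial sequence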
@@ -452,7 +456,7 @@ class SummarizationMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, R
                 break

             mid = (left + right) // 2
-            if self.token_counter(messages[mid:]) <= target_token_count:
+            if self._partial_token_counter(messages[mid:]) <= target_token_count:
                 cutoff_candidate = mid
                 right = mid
             else:
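This one-line change is the heart of the fix: the loop is a leftmost binary search over message suffixes, and every probe counts messages[mid:], a partial sequence, so it must use the unscaled counter. Because the token count of messages[mid:] is non-increasing in mid, the search converges on the smallest cutoff whose retained suffix still fits the target. A self-contained sketch of the same search shape (simplified; the real method additionally adjusts the cutoff to keep AI/Tool message pairs together):

from typing import Callable

def find_cutoff(messages: list[str], target: int, count: Callable[[list[str]], int]) -> int:
    # Smallest index whose suffix fits within target tokens.
    left, right = 0, len(messages)
    cutoff_candidate = len(messages)
    while left < right:
        mid = (left + right) // 2
        if count(messages[mid:]) <= target:
            cutoff_candidate = mid
            right = mid       # suffix fits; try to keep even more messages
        else:
            left = mid + 1    # suffix too large; drop more from the front
    return cutoff_candidate

# Each message "costs" its length; keep a suffix of at most 5 tokens.
msgs = ["aa", "bb", "c", "dd"]
assert find_cutoff(msgs, 5, lambda ms: sum(len(m) for m in ms)) == 1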
@@ -353,8 +353,8 @@ def test_summarization_middleware_token_retention_preserves_ai_tool_pairs() -> N
         model=ProfileChatModel(),
         trigger=("fraction", 0.1),
         keep=("fraction", 0.5),
+        token_counter=token_counter,
     )
-    middleware.token_counter = token_counter

     # Total tokens: 300 + 200 + 50 + 180 + 160 = 890
     # Target keep: 500 tokens (50% of 1000)
@@ -1028,6 +1028,32 @@ def test_summarization_middleware_many_parallel_tool_calls_safety() -> None:
     assert middleware._find_safe_cutoff_point(messages, 1) == 1


+def test_summarization_before_model_uses_unscaled_tokens_for_cutoff() -> None:
+    calls: list[dict[str, Any]] = []
+
+    def fake_counter(_: Iterable[MessageLikeRepresentation], **kwargs: Any) -> int:
+        calls.append(kwargs)
+        return 100
+
+    with patch(
+        "langchain.agents.middleware.summarization.count_tokens_approximately",
+        side_effect=fake_counter,
+    ) as mock_counter:
+        middleware = SummarizationMiddleware(
+            model=MockChatModel(),
+            trigger=("tokens", 1),
+            keep=("tokens", 1),
+            token_counter=mock_counter,
+        )
+        state = AgentState[Any](messages=[HumanMessage(content="one"), HumanMessage(content="two")])
+        assert middleware.before_model(state, Runtime()) is not None
+
+        # Test we support partial token counting (which for default token counter does not
+        # use use_usage_metadata_scaling)
+        assert any(call.get("use_usage_metadata_scaling") is False for call in calls)
+        assert any(call.get("use_usage_metadata_scaling") is True for call in calls)
+
+
 def test_summarization_middleware_find_safe_cutoff_preserves_ai_tool_pair() -> None:
     """Test `_find_safe_cutoff` preserves AI/Tool message pairs together."""
     middleware = SummarizationMiddleware(
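The new test routes every token count through a recording spy and asserts that both call styles occurred: scaled (use_usage_metadata_scaling is True) and unscaled (False), matching full-history counting versus partial-sequence cutoff probes. The kwargs-recording pattern with unittest.mock.patch, shown in isolation (the patch target below is an arbitrary stand-in, not the code under test):

import json
from typing import Any
from unittest.mock import patch

calls: list[dict[str, Any]] = []

def spy(*args: Any, **kwargs: Any) -> str:
    # side_effect receives exactly the arguments the patched callable was
    # invoked with, so appending kwargs captures every keyword per call.
    calls.append(kwargs)
    return "{}"

with patch("json.dumps", side_effect=spy):
    json.dumps({"x": 1}, sort_keys=True)

assert calls == [{"sort_keys": True}]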