refactor(langchain): engage summarization based on reported usage_metadata (#34632)

This commit is contained in:
ccurme
2026-01-08 11:12:00 -05:00
committed by GitHub
parent 50c5bb5607
commit d383f00489
2 changed files with 104 additions and 0 deletions

View File

@@ -324,6 +324,25 @@ class SummarizationMiddleware(AgentMiddleware):
]
}
def _should_summarize_based_on_reported_tokens(
    self, messages: list[AnyMessage], threshold: float
) -> bool:
    """Decide whether the latest AIMessage's reported usage reaches `threshold`.

    Only the most recent `AIMessage` in `messages` is inspected. The check passes
    when that message carries `usage_metadata` whose `total_tokens` is non-zero and
    greater than or equal to `threshold`, and its `response_metadata` names a
    `model_provider` equal to the `ls_provider` reported by `self.model`.

    Args:
        messages: Conversation history to scan, searched from the end.
        threshold: Token count at or above which the check passes.

    Returns:
        `True` if every condition above holds, `False` otherwise.
    """
    # Walk backwards so the first hit is the most recent model response.
    latest_ai = None
    for candidate in reversed(messages):
        if isinstance(candidate, AIMessage):
            latest_ai = candidate
            break
    if latest_ai is None:
        return False

    usage = latest_ai.usage_metadata
    if usage is None:
        return False

    # A missing count falls back to -1; a reported count of zero is treated as absent.
    reported = usage.get("total_tokens", -1)
    if not (reported and reported >= threshold):
        return False

    provider = latest_ai.response_metadata.get("model_provider")
    if not provider:
        return False

    # Only accept counts reported by the same provider as the configured model.
    expected_provider = self.model._get_ls_params().get("ls_provider")  # noqa: SLF001
    return bool(provider == expected_provider)
def _should_summarize(self, messages: list[AnyMessage], total_tokens: int) -> bool:
"""Determine whether summarization should run for the current token usage."""
if not self._trigger_conditions:
@@ -334,6 +353,10 @@ class SummarizationMiddleware(AgentMiddleware):
return True
if kind == "tokens" and total_tokens >= value:
return True
if kind == "tokens" and self._should_summarize_based_on_reported_tokens(
messages, value
):
return True
if kind == "fraction":
max_input_tokens = self._get_profile_limits()
if max_input_tokens is None:
@@ -343,6 +366,9 @@ class SummarizationMiddleware(AgentMiddleware):
threshold = 1
if total_tokens >= threshold:
return True
if self._should_summarize_based_on_reported_tokens(messages, threshold):
return True
return False
def _determine_cutoff_index(self, messages: list[AnyMessage]) -> int:

View File

@@ -9,6 +9,7 @@ from langchain_core.outputs import ChatGeneration, ChatResult
from langgraph.graph.message import REMOVE_ALL_MESSAGES
from langchain.agents.middleware.summarization import SummarizationMiddleware
from langchain.chat_models import init_chat_model
from tests.unit_tests.agents.model import FakeToolCallingModel
@@ -1014,3 +1015,80 @@ def test_create_summary_uses_get_buffer_string_format() -> None:
f"str(messages) should produce significantly more tokens. "
f"Got ratio {str_ratio:.2f}x (expected > 1.5)"
)
@pytest.mark.requires("langchain_anthropic")
def test_usage_metadata_trigger() -> None:
    """Exercise the token trigger using usage reported on the latest AIMessage.

    `_should_summarize` is always called with a counted total of zero, so any
    positive result must come from the reported `usage_metadata`.
    """
    middleware = SummarizationMiddleware(
        model=init_chat_model("anthropic:claude-sonnet-4-5"),
        trigger=("tokens", 10_000),
        keep=("messages", 4),
    )

    tool_calling_reply = AIMessage(
        content="msg2",
        tool_calls=[{"name": "tool", "args": {}, "id": "call1"}],
        response_metadata={"model_provider": "anthropic"},
        usage_metadata={
            "input_tokens": 5000,
            "output_tokens": 1000,
            "total_tokens": 6000,
        },
    )
    reply_under_threshold = AIMessage(
        content="msg3",
        response_metadata={"model_provider": "anthropic"},
        usage_metadata={
            "input_tokens": 6100,
            "output_tokens": 900,
            "total_tokens": 7000,
        },
    )
    reply_over_threshold = AIMessage(
        content="msg5",
        response_metadata={"model_provider": "anthropic"},
        usage_metadata={
            "input_tokens": 7500,
            "output_tokens": 2501,
            "total_tokens": 10_001,
        },
    )
    messages: list[AnyMessage] = [
        HumanMessage(content="msg1"),
        tool_calling_reply,
        ToolMessage(content="result", tool_call_id="call1"),
        reply_under_threshold,
        HumanMessage(content="msg4"),
        reply_over_threshold,
    ]

    # reported token count should override count of zero
    assert middleware._should_summarize(messages, 0)

    # don't engage unless model provider matches
    reply_from_other_provider = AIMessage(
        content="msg7",
        response_metadata={"model_provider": "not-anthropic"},
        usage_metadata={
            "input_tokens": 7500,
            "output_tokens": 2501,
            "total_tokens": 10_001,
        },
    )
    messages += [HumanMessage(content="msg6"), reply_from_other_provider]
    assert not middleware._should_summarize(messages, 0)

    # don't engage if subsequent message stays under threshold (e.g., after summarization)
    reply_just_under_threshold = AIMessage(
        content="msg9",
        response_metadata={"model_provider": "anthropic"},
        usage_metadata={
            "input_tokens": 7500,
            "output_tokens": 2499,
            "total_tokens": 9999,
        },
    )
    messages += [HumanMessage(content="msg8"), reply_just_under_threshold]
    assert not middleware._should_summarize(messages, 0)