mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
refactor(langchain): engage summarization based on reported usage_metadata (#34632)
This commit is contained in:
@@ -324,6 +324,25 @@ class SummarizationMiddleware(AgentMiddleware):
|
||||
]
|
||||
}
|
||||
|
||||
def _should_summarize_based_on_reported_tokens(
|
||||
self, messages: list[AnyMessage], threshold: float
|
||||
) -> bool:
|
||||
"""Check if reported token usage from last AIMessage exceeds threshold."""
|
||||
last_ai_message = next(
|
||||
(msg for msg in reversed(messages) if isinstance(msg, AIMessage)),
|
||||
None,
|
||||
)
|
||||
if ( # noqa: SIM103
|
||||
isinstance(last_ai_message, AIMessage)
|
||||
and last_ai_message.usage_metadata is not None
|
||||
and (reported_tokens := last_ai_message.usage_metadata.get("total_tokens", -1))
|
||||
and reported_tokens >= threshold
|
||||
and (message_provider := last_ai_message.response_metadata.get("model_provider"))
|
||||
and message_provider == self.model._get_ls_params().get("ls_provider") # noqa: SLF001
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _should_summarize(self, messages: list[AnyMessage], total_tokens: int) -> bool:
|
||||
"""Determine whether summarization should run for the current token usage."""
|
||||
if not self._trigger_conditions:
|
||||
@@ -334,6 +353,10 @@ class SummarizationMiddleware(AgentMiddleware):
|
||||
return True
|
||||
if kind == "tokens" and total_tokens >= value:
|
||||
return True
|
||||
if kind == "tokens" and self._should_summarize_based_on_reported_tokens(
|
||||
messages, value
|
||||
):
|
||||
return True
|
||||
if kind == "fraction":
|
||||
max_input_tokens = self._get_profile_limits()
|
||||
if max_input_tokens is None:
|
||||
@@ -343,6 +366,9 @@ class SummarizationMiddleware(AgentMiddleware):
|
||||
threshold = 1
|
||||
if total_tokens >= threshold:
|
||||
return True
|
||||
|
||||
if self._should_summarize_based_on_reported_tokens(messages, threshold):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _determine_cutoff_index(self, messages: list[AnyMessage]) -> int:
|
||||
|
||||
@@ -9,6 +9,7 @@ from langchain_core.outputs import ChatGeneration, ChatResult
|
||||
from langgraph.graph.message import REMOVE_ALL_MESSAGES
|
||||
|
||||
from langchain.agents.middleware.summarization import SummarizationMiddleware
|
||||
from langchain.chat_models import init_chat_model
|
||||
from tests.unit_tests.agents.model import FakeToolCallingModel
|
||||
|
||||
|
||||
@@ -1014,3 +1015,80 @@ def test_create_summary_uses_get_buffer_string_format() -> None:
|
||||
f"str(messages) should produce significantly more tokens. "
|
||||
f"Got ratio {str_ratio:.2f}x (expected > 1.5)"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.requires("langchain_anthropic")
|
||||
def test_usage_metadata_trigger() -> None:
|
||||
model = init_chat_model("anthropic:claude-sonnet-4-5")
|
||||
middleware = SummarizationMiddleware(
|
||||
model=model, trigger=("tokens", 10_000), keep=("messages", 4)
|
||||
)
|
||||
messages: list[AnyMessage] = [
|
||||
HumanMessage(content="msg1"),
|
||||
AIMessage(
|
||||
content="msg2",
|
||||
tool_calls=[{"name": "tool", "args": {}, "id": "call1"}],
|
||||
response_metadata={"model_provider": "anthropic"},
|
||||
usage_metadata={
|
||||
"input_tokens": 5000,
|
||||
"output_tokens": 1000,
|
||||
"total_tokens": 6000,
|
||||
},
|
||||
),
|
||||
ToolMessage(content="result", tool_call_id="call1"),
|
||||
AIMessage(
|
||||
content="msg3",
|
||||
response_metadata={"model_provider": "anthropic"},
|
||||
usage_metadata={
|
||||
"input_tokens": 6100,
|
||||
"output_tokens": 900,
|
||||
"total_tokens": 7000,
|
||||
},
|
||||
),
|
||||
HumanMessage(content="msg4"),
|
||||
AIMessage(
|
||||
content="msg5",
|
||||
response_metadata={"model_provider": "anthropic"},
|
||||
usage_metadata={
|
||||
"input_tokens": 7500,
|
||||
"output_tokens": 2501,
|
||||
"total_tokens": 10_001,
|
||||
},
|
||||
),
|
||||
]
|
||||
# reported token count should override count of zero
|
||||
assert middleware._should_summarize(messages, 0)
|
||||
|
||||
# don't engage unless model provider matches
|
||||
messages.extend(
|
||||
[
|
||||
HumanMessage(content="msg6"),
|
||||
AIMessage(
|
||||
content="msg7",
|
||||
response_metadata={"model_provider": "not-anthropic"},
|
||||
usage_metadata={
|
||||
"input_tokens": 7500,
|
||||
"output_tokens": 2501,
|
||||
"total_tokens": 10_001,
|
||||
},
|
||||
),
|
||||
]
|
||||
)
|
||||
assert not middleware._should_summarize(messages, 0)
|
||||
|
||||
# don't engage if subsequent message stays under threshold (e.g., after summarization)
|
||||
messages.extend(
|
||||
[
|
||||
HumanMessage(content="msg8"),
|
||||
AIMessage(
|
||||
content="msg9",
|
||||
response_metadata={"model_provider": "anthropic"},
|
||||
usage_metadata={
|
||||
"input_tokens": 7500,
|
||||
"output_tokens": 2499,
|
||||
"total_tokens": 9999,
|
||||
},
|
||||
),
|
||||
]
|
||||
)
|
||||
assert not middleware._should_summarize(messages, 0)
|
||||
|
||||
Reference in New Issue
Block a user