diff --git a/libs/langchain_v1/langchain/agents/middleware/summarization.py b/libs/langchain_v1/langchain/agents/middleware/summarization.py index 8f6fb2789cd..b3533c2eb6f 100644 --- a/libs/langchain_v1/langchain/agents/middleware/summarization.py +++ b/libs/langchain_v1/langchain/agents/middleware/summarization.py @@ -324,6 +324,25 @@ class SummarizationMiddleware(AgentMiddleware): ] } + def _should_summarize_based_on_reported_tokens( + self, messages: list[AnyMessage], threshold: float + ) -> bool: + """Check if reported token usage from last AIMessage exceeds threshold.""" + last_ai_message = next( + (msg for msg in reversed(messages) if isinstance(msg, AIMessage)), + None, + ) + if ( # noqa: SIM103 + isinstance(last_ai_message, AIMessage) + and last_ai_message.usage_metadata is not None + and (reported_tokens := last_ai_message.usage_metadata.get("total_tokens", -1)) + and reported_tokens >= threshold + and (message_provider := last_ai_message.response_metadata.get("model_provider")) + and message_provider == self.model._get_ls_params().get("ls_provider") # noqa: SLF001 + ): + return True + return False + def _should_summarize(self, messages: list[AnyMessage], total_tokens: int) -> bool: """Determine whether summarization should run for the current token usage.""" if not self._trigger_conditions: @@ -334,6 +353,10 @@ class SummarizationMiddleware(AgentMiddleware): return True if kind == "tokens" and total_tokens >= value: return True + if kind == "tokens" and self._should_summarize_based_on_reported_tokens( + messages, value + ): + return True if kind == "fraction": max_input_tokens = self._get_profile_limits() if max_input_tokens is None: @@ -343,6 +366,9 @@ class SummarizationMiddleware(AgentMiddleware): threshold = 1 if total_tokens >= threshold: return True + + if self._should_summarize_based_on_reported_tokens(messages, threshold): + return True return False def _determine_cutoff_index(self, messages: list[AnyMessage]) -> int: diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_summarization.py b/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_summarization.py index c0ac6c82f51..728c6c97dfe 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_summarization.py +++ b/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_summarization.py @@ -9,6 +9,7 @@ from langchain_core.outputs import ChatGeneration, ChatResult from langgraph.graph.message import REMOVE_ALL_MESSAGES from langchain.agents.middleware.summarization import SummarizationMiddleware +from langchain.chat_models import init_chat_model from tests.unit_tests.agents.model import FakeToolCallingModel @@ -1014,3 +1015,80 @@ def test_create_summary_uses_get_buffer_string_format() -> None: f"str(messages) should produce significantly more tokens. " f"Got ratio {str_ratio:.2f}x (expected > 1.5)" ) + + +@pytest.mark.requires("langchain_anthropic") +def test_usage_metadata_trigger() -> None: + model = init_chat_model("anthropic:claude-sonnet-4-5") + middleware = SummarizationMiddleware( + model=model, trigger=("tokens", 10_000), keep=("messages", 4) + ) + messages: list[AnyMessage] = [ + HumanMessage(content="msg1"), + AIMessage( + content="msg2", + tool_calls=[{"name": "tool", "args": {}, "id": "call1"}], + response_metadata={"model_provider": "anthropic"}, + usage_metadata={ + "input_tokens": 5000, + "output_tokens": 1000, + "total_tokens": 6000, + }, + ), + ToolMessage(content="result", tool_call_id="call1"), + AIMessage( + content="msg3", + response_metadata={"model_provider": "anthropic"}, + usage_metadata={ + "input_tokens": 6100, + "output_tokens": 900, + "total_tokens": 7000, + }, + ), + HumanMessage(content="msg4"), + AIMessage( + content="msg5", + response_metadata={"model_provider": "anthropic"}, + usage_metadata={ + "input_tokens": 7500, + "output_tokens": 2501, + "total_tokens": 10_001, + }, + ), + ] + # reported token count should override count of zero + assert middleware._should_summarize(messages, 0) + + # don't engage unless model provider matches + messages.extend( + [ + HumanMessage(content="msg6"), + AIMessage( + content="msg7", + response_metadata={"model_provider": "not-anthropic"}, + usage_metadata={ + "input_tokens": 7500, + "output_tokens": 2501, + "total_tokens": 10_001, + }, + ), + ] + ) + assert not middleware._should_summarize(messages, 0) + + # don't engage if subsequent message stays under threshold (e.g., after summarization) + messages.extend( + [ + HumanMessage(content="msg8"), + AIMessage( + content="msg9", + response_metadata={"model_provider": "anthropic"}, + usage_metadata={ + "input_tokens": 7500, + "output_tokens": 2499, + "total_tokens": 9999, + }, + ), + ] + ) + assert not middleware._should_summarize(messages, 0)