From 36c381b1496827db4405e49918e64cfda87ae701 Mon Sep 17 00:00:00 2001 From: ccurme Date: Fri, 15 May 2026 14:08:31 -0400 Subject: [PATCH] fix(langchain): alias Bedrock providers in summarization token check (#37453) --- .../agents/middleware/summarization.py | 22 ++++- .../implementations/test_summarization.py | 84 ++++++++++++++++++- 2 files changed, 104 insertions(+), 2 deletions(-) diff --git a/libs/langchain_v1/langchain/agents/middleware/summarization.py b/libs/langchain_v1/langchain/agents/middleware/summarization.py index 6c1091152c1..ac31da5b299 100644 --- a/libs/langchain_v1/langchain/agents/middleware/summarization.py +++ b/libs/langchain_v1/langchain/agents/middleware/summarization.py @@ -77,6 +77,23 @@ _DEFAULT_MESSAGES_TO_KEEP = 20 _DEFAULT_TRIM_TOKEN_LIMIT = 4000 _DEFAULT_FALLBACK_MESSAGE_COUNT = 15 +# Some providers tag emitted messages with a `model_provider` string that differs from +# their LangSmith `ls_provider`. The reported-token check below compares the two, so we +# accept known aliases per `ls_provider`. +_LS_PROVIDER_ALIASES: dict[str, frozenset[str]] = { + "amazon_bedrock": frozenset({"bedrock", "bedrock_converse"}), +} + + +def _provider_matches(message_provider: str, model_ls_provider: str | None) -> bool: + if model_ls_provider is None: + return False + if message_provider == model_ls_provider: + return True + aliases = _LS_PROVIDER_ALIASES.get(model_ls_provider) + return aliases is not None and message_provider in aliases + + ContextFraction = tuple[Literal["fraction"], float] """Fraction of model's maximum input tokens. @@ -379,7 +396,10 @@ class SummarizationMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, R and (reported_tokens := last_ai_message.usage_metadata.get("total_tokens", -1)) and reported_tokens >= threshold and (message_provider := last_ai_message.response_metadata.get("model_provider")) - and message_provider == self.model._get_ls_params().get("ls_provider") # noqa: SLF001 + and _provider_matches( + message_provider, + self.model._get_ls_params().get("ls_provider"), # noqa: SLF001 + ) ): return True return False diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_summarization.py b/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_summarization.py index 1ae973c7ad4..8dac1ff76cb 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_summarization.py +++ b/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_summarization.py @@ -6,6 +6,7 @@ import pytest from langchain_core.callbacks import AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun from langchain_core.language_models import ModelProfile from langchain_core.language_models.base import ( + LangSmithParams, LanguageModelInput, ) from langchain_core.language_models.chat_models import BaseChatModel @@ -27,7 +28,10 @@ from pydantic import Field from typing_extensions import override from langchain.agents import AgentState -from langchain.agents.middleware.summarization import SummarizationMiddleware +from langchain.agents.middleware.summarization import ( + SummarizationMiddleware, + _provider_matches, +) from langchain.chat_models import init_chat_model from tests.unit_tests.agents.model import FakeToolCallingModel @@ -1219,6 +1223,84 @@ def test_usage_metadata_trigger() -> None: assert not middleware._should_summarize(messages, 0) +def test_provider_matches() -> None: + """Direct equality matches, plus Bedrock aliases under amazon_bedrock.""" + assert _provider_matches("anthropic", "anthropic") + assert _provider_matches("openai", "openai") + # Bedrock chat models tag messages with model_provider="bedrock" or + # "bedrock_converse" but trace under ls_provider="amazon_bedrock". + assert _provider_matches("bedrock", "amazon_bedrock") + assert _provider_matches("bedrock_converse", "amazon_bedrock") + # Non-matches + assert not _provider_matches("openai", "anthropic") + assert not _provider_matches("bedrock", "anthropic") + assert not _provider_matches("anthropic", None) + + +class _MockBedrockChatModel(BaseChatModel): + """Mock model that mimics ChatBedrockConverse's ls_provider for tracing.""" + + @override + def _generate( + self, + messages: list[BaseMessage], + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, + **kwargs: Any, + ) -> ChatResult: + return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))]) + + @property + def _llm_type(self) -> str: + return "amazon_bedrock_converse_chat" + + @override + def _get_ls_params(self, stop: list[str] | None = None, **kwargs: Any) -> LangSmithParams: + return LangSmithParams(ls_provider="amazon_bedrock", ls_model_type="chat") + + +def test_reported_tokens_trigger_for_bedrock_converse() -> None: + """Bedrock messages should satisfy the reported-token check. + + Despite the model_provider/ls_provider mismatch (bedrock_converse vs. + amazon_bedrock), the reported-token check should still trigger summarization. + """ + middleware = SummarizationMiddleware( + model=_MockBedrockChatModel(), + trigger=("tokens", 10_000), + keep=("messages", 4), + ) + messages: list[AnyMessage] = [ + HumanMessage(content="msg1"), + AIMessage( + content="msg2", + response_metadata={"model_provider": "bedrock_converse"}, + usage_metadata={ + "input_tokens": 7500, + "output_tokens": 2501, + "total_tokens": 10_001, + }, + ), + ] + # reported token count (10_001) should override the supplied count of 0 + assert middleware._should_summarize(messages, 0) + + # mismatched provider should not engage + messages_other_provider: list[AnyMessage] = [ + HumanMessage(content="msg1"), + AIMessage( + content="msg2", + response_metadata={"model_provider": "anthropic"}, + usage_metadata={ + "input_tokens": 7500, + "output_tokens": 2501, + "total_tokens": 10_001, + }, + ), + ] + assert not middleware._should_summarize(messages_other_provider, 0) + + class ConfigCapturingModel(BaseChatModel): """Mock model that captures the config passed to invoke/ainvoke."""