fix(langchain): alias Bedrock providers in summarization token check (#37453)

Authored by ccurme on 2026-05-15 14:08:31 -04:00, committed by GitHub
parent 0831e445cf
commit 36c381b149
2 changed files with 104 additions and 2 deletions
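Background: Bedrock chat models tag the messages they emit with model_provider="bedrock" or "bedrock_converse" in response_metadata, while their LangSmith params report ls_provider="amazon_bedrock". The summarization middleware's reported-token check compared these two strings for strict equality, so it never fired for Bedrock models. Below is a minimal standalone sketch of the before/after comparison, using only the values that appear in the diff and tests that follow.

# Standalone sketch of the mismatch this commit fixes (values taken from the diff below).
message_provider = "bedrock_converse"  # AIMessage.response_metadata["model_provider"]
ls_provider = "amazon_bedrock"         # model._get_ls_params()["ls_provider"]

# Old check: strict equality, which never held for Bedrock models.
assert (message_provider == ls_provider) is False

# New check: known aliases per ls_provider are accepted (see _provider_matches below).
_LS_PROVIDER_ALIASES = {"amazon_bedrock": frozenset({"bedrock", "bedrock_converse"})}
aliases = _LS_PROVIDER_ALIASES.get(ls_provider)
assert aliases is not None and message_provider in aliases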

View File

@@ -77,6 +77,23 @@ _DEFAULT_MESSAGES_TO_KEEP = 20
 _DEFAULT_TRIM_TOKEN_LIMIT = 4000
 _DEFAULT_FALLBACK_MESSAGE_COUNT = 15
 
+# Some providers tag emitted messages with a `model_provider` string that differs from
+# their LangSmith `ls_provider`. The reported-token check below compares the two, so we
+# accept known aliases per `ls_provider`.
+_LS_PROVIDER_ALIASES: dict[str, frozenset[str]] = {
+    "amazon_bedrock": frozenset({"bedrock", "bedrock_converse"}),
+}
+
+
+def _provider_matches(message_provider: str, model_ls_provider: str | None) -> bool:
+    if model_ls_provider is None:
+        return False
+    if message_provider == model_ls_provider:
+        return True
+    aliases = _LS_PROVIDER_ALIASES.get(model_ls_provider)
+    return aliases is not None and message_provider in aliases
+
 ContextFraction = tuple[Literal["fraction"], float]
 """Fraction of model's maximum input tokens.
@@ -379,7 +396,10 @@ class SummarizationMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, R
             and (reported_tokens := last_ai_message.usage_metadata.get("total_tokens", -1))
             and reported_tokens >= threshold
             and (message_provider := last_ai_message.response_metadata.get("model_provider"))
-            and message_provider == self.model._get_ls_params().get("ls_provider")  # noqa: SLF001
+            and _provider_matches(
+                message_provider,
+                self.model._get_ls_params().get("ls_provider"),  # noqa: SLF001
+            )
         ):
             return True
         return False
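With the alias in place, the reported-token trigger now works for Bedrock-backed agents. A rough usage sketch follows, assuming langchain_aws.ChatBedrockConverse as the provider integration and the create_agent entry point; the model ID is illustrative, and the thresholds mirror the tests below.

from langchain.agents import create_agent
from langchain.agents.middleware.summarization import SummarizationMiddleware
from langchain_aws import ChatBedrockConverse  # assumed provider package

model = ChatBedrockConverse(model="anthropic.claude-3-5-sonnet-20240620-v1:0")  # illustrative model ID
agent = create_agent(
    model=model,
    tools=[],
    middleware=[
        # Summarize once usage_metadata reports >= 10_000 total tokens,
        # keeping the 4 most recent messages (same values as the tests below).
        SummarizationMiddleware(model=model, trigger=("tokens", 10_000), keep=("messages", 4)),
    ],
)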

View File

@@ -6,6 +6,7 @@ import pytest
 from langchain_core.callbacks import AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun
 from langchain_core.language_models import ModelProfile
 from langchain_core.language_models.base import (
+    LangSmithParams,
     LanguageModelInput,
 )
 from langchain_core.language_models.chat_models import BaseChatModel
@@ -27,7 +28,10 @@ from pydantic import Field
 from typing_extensions import override
 
 from langchain.agents import AgentState
-from langchain.agents.middleware.summarization import SummarizationMiddleware
+from langchain.agents.middleware.summarization import (
+    SummarizationMiddleware,
+    _provider_matches,
+)
 from langchain.chat_models import init_chat_model
 from tests.unit_tests.agents.model import FakeToolCallingModel
@@ -1219,6 +1223,84 @@ def test_usage_metadata_trigger() -> None:
     assert not middleware._should_summarize(messages, 0)
+
+
+def test_provider_matches() -> None:
+    """Direct equality matches, plus Bedrock aliases under amazon_bedrock."""
+    assert _provider_matches("anthropic", "anthropic")
+    assert _provider_matches("openai", "openai")
+    # Bedrock chat models tag messages with model_provider="bedrock" or
+    # "bedrock_converse" but trace under ls_provider="amazon_bedrock".
+    assert _provider_matches("bedrock", "amazon_bedrock")
+    assert _provider_matches("bedrock_converse", "amazon_bedrock")
+    # Non-matches
+    assert not _provider_matches("openai", "anthropic")
+    assert not _provider_matches("bedrock", "anthropic")
+    assert not _provider_matches("anthropic", None)
+
+
+class _MockBedrockChatModel(BaseChatModel):
+    """Mock model that mimics ChatBedrockConverse's ls_provider for tracing."""
+
+    @override
+    def _generate(
+        self,
+        messages: list[BaseMessage],
+        stop: list[str] | None = None,
+        run_manager: CallbackManagerForLLMRun | None = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))])
+
+    @property
+    def _llm_type(self) -> str:
+        return "amazon_bedrock_converse_chat"
+
+    @override
+    def _get_ls_params(self, stop: list[str] | None = None, **kwargs: Any) -> LangSmithParams:
+        return LangSmithParams(ls_provider="amazon_bedrock", ls_model_type="chat")
+
+
+def test_reported_tokens_trigger_for_bedrock_converse() -> None:
+    """Bedrock messages should satisfy the reported-token check.
+
+    Despite the model_provider/ls_provider mismatch (bedrock_converse vs.
+    amazon_bedrock), the reported-token check should still trigger summarization.
+    """
+    middleware = SummarizationMiddleware(
+        model=_MockBedrockChatModel(),
+        trigger=("tokens", 10_000),
+        keep=("messages", 4),
+    )
+    messages: list[AnyMessage] = [
+        HumanMessage(content="msg1"),
+        AIMessage(
+            content="msg2",
+            response_metadata={"model_provider": "bedrock_converse"},
+            usage_metadata={
+                "input_tokens": 7500,
+                "output_tokens": 2501,
+                "total_tokens": 10_001,
+            },
+        ),
+    ]
+    # reported token count (10_001) should override the supplied count of 0
+    assert middleware._should_summarize(messages, 0)
+
+    # mismatched provider should not engage
+    messages_other_provider: list[AnyMessage] = [
+        HumanMessage(content="msg1"),
+        AIMessage(
+            content="msg2",
+            response_metadata={"model_provider": "anthropic"},
+            usage_metadata={
+                "input_tokens": 7500,
+                "output_tokens": 2501,
+                "total_tokens": 10_001,
+            },
+        ),
+    ]
+    assert not middleware._should_summarize(messages_other_provider, 0)
+
+
 class ConfigCapturingModel(BaseChatModel):
     """Mock model that captures the config passed to invoke/ainvoke."""