Mirror of https://github.com/hwchase17/langchain.git,
synced 2026-05-17 13:00:49 +00:00.
Commit: fix(langchain): alias Bedrock providers in summarization token check (#37453)
@@ -77,6 +77,23 @@ _DEFAULT_MESSAGES_TO_KEEP = 20
|
||||
_DEFAULT_TRIM_TOKEN_LIMIT = 4000
|
||||
_DEFAULT_FALLBACK_MESSAGE_COUNT = 15
|
||||
|
||||
# Some providers tag emitted messages with a `model_provider` string that differs from
|
||||
# their LangSmith `ls_provider`. The reported-token check below compares the two, so we
|
||||
# accept known aliases per `ls_provider`.
|
||||
_LS_PROVIDER_ALIASES: dict[str, frozenset[str]] = {
|
||||
"amazon_bedrock": frozenset({"bedrock", "bedrock_converse"}),
|
||||
}
|
||||
|
||||
|
||||
def _provider_matches(message_provider: str, model_ls_provider: str | None) -> bool:
|
||||
if model_ls_provider is None:
|
||||
return False
|
||||
if message_provider == model_ls_provider:
|
||||
return True
|
||||
aliases = _LS_PROVIDER_ALIASES.get(model_ls_provider)
|
||||
return aliases is not None and message_provider in aliases
|
||||
|
||||
|
||||
ContextFraction = tuple[Literal["fraction"], float]
|
||||
"""Fraction of model's maximum input tokens.
|
||||
|
||||
@@ -379,7 +396,10 @@ class SummarizationMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, R
|
||||
and (reported_tokens := last_ai_message.usage_metadata.get("total_tokens", -1))
|
||||
and reported_tokens >= threshold
|
||||
and (message_provider := last_ai_message.response_metadata.get("model_provider"))
|
||||
and message_provider == self.model._get_ls_params().get("ls_provider") # noqa: SLF001
|
||||
and _provider_matches(
|
||||
message_provider,
|
||||
self.model._get_ls_params().get("ls_provider"), # noqa: SLF001
|
||||
)
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -6,6 +6,7 @@ import pytest
|
||||
from langchain_core.callbacks import AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun
|
||||
from langchain_core.language_models import ModelProfile
|
||||
from langchain_core.language_models.base import (
|
||||
LangSmithParams,
|
||||
LanguageModelInput,
|
||||
)
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
@@ -27,7 +28,10 @@ from pydantic import Field
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware.summarization import SummarizationMiddleware
|
||||
from langchain.agents.middleware.summarization import (
|
||||
SummarizationMiddleware,
|
||||
_provider_matches,
|
||||
)
|
||||
from langchain.chat_models import init_chat_model
|
||||
from tests.unit_tests.agents.model import FakeToolCallingModel
|
||||
|
||||
@@ -1219,6 +1223,84 @@ def test_usage_metadata_trigger() -> None:
|
||||
assert not middleware._should_summarize(messages, 0)
|
||||
|
||||
|
||||
def test_provider_matches() -> None:
    """Direct equality matches, plus Bedrock aliases under amazon_bedrock."""
    # Exact equality between message provider and ls_provider always matches.
    for provider in ("anthropic", "openai"):
        assert _provider_matches(provider, provider)
    # Bedrock chat models tag messages with model_provider="bedrock" or
    # "bedrock_converse" but trace under ls_provider="amazon_bedrock".
    for alias in ("bedrock", "bedrock_converse"):
        assert _provider_matches(alias, "amazon_bedrock")
    # Anything else — including a missing ls_provider — is a non-match.
    assert not _provider_matches("openai", "anthropic")
    assert not _provider_matches("bedrock", "anthropic")
    assert not _provider_matches("anthropic", None)
|
||||
class _MockBedrockChatModel(BaseChatModel):
    """Mock model that mimics ChatBedrockConverse's ls_provider for tracing."""

    @override
    def _generate(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: CallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> ChatResult:
        # The content is irrelevant to the provider-matching tests; any fixed
        # AIMessage is sufficient.
        reply = AIMessage(content="Summary")
        return ChatResult(generations=[ChatGeneration(message=reply)])

    @property
    def _llm_type(self) -> str:
        return "amazon_bedrock_converse_chat"

    @override
    def _get_ls_params(self, stop: list[str] | None = None, **kwargs: Any) -> LangSmithParams:
        # Real Bedrock chat models trace under ls_provider="amazon_bedrock",
        # which differs from the model_provider tag on emitted messages.
        return LangSmithParams(ls_provider="amazon_bedrock", ls_model_type="chat")
||||
def test_reported_tokens_trigger_for_bedrock_converse() -> None:
    """Bedrock messages should satisfy the reported-token check.

    Despite the model_provider/ls_provider mismatch (bedrock_converse vs.
    amazon_bedrock), the reported-token check should still trigger summarization.
    """
    middleware = SummarizationMiddleware(
        model=_MockBedrockChatModel(),
        trigger=("tokens", 10_000),
        keep=("messages", 4),
    )

    def _conversation(provider: str) -> list[AnyMessage]:
        # One human turn plus an AI turn whose reported usage (10_001 total
        # tokens) exceeds the 10_000-token trigger.
        return [
            HumanMessage(content="msg1"),
            AIMessage(
                content="msg2",
                response_metadata={"model_provider": provider},
                usage_metadata={
                    "input_tokens": 7500,
                    "output_tokens": 2501,
                    "total_tokens": 10_001,
                },
            ),
        ]

    # reported token count (10_001) should override the supplied count of 0
    assert middleware._should_summarize(_conversation("bedrock_converse"), 0)

    # mismatched provider should not engage
    assert not middleware._should_summarize(_conversation("anthropic"), 0)
||||
|
||||
class ConfigCapturingModel(BaseChatModel):
|
||||
"""Mock model that captures the config passed to invoke/ainvoke."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user