mirror of
https://github.com/hwchase17/langchain.git
synced 2026-03-18 02:53:16 +00:00
chore(langchain): fix types in test_summarization (#34656)
Co-authored-by: Mason Daugherty <mason@langchain.dev> Co-authored-by: Mason Daugherty <github@mdrxy.com>
This commit is contained in:
committed by
GitHub
parent
9ce73a73f8
commit
4e40c2766a
@@ -1,13 +1,32 @@
|
||||
from collections.abc import Iterable
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from langchain_core.callbacks import AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun
|
||||
from langchain_core.language_models import ModelProfile
|
||||
from langchain_core.language_models.base import (
|
||||
LanguageModelInput,
|
||||
)
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, RemoveMessage, ToolMessage
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AnyMessage,
|
||||
BaseMessage,
|
||||
HumanMessage,
|
||||
MessageLikeRepresentation,
|
||||
RemoveMessage,
|
||||
ToolMessage,
|
||||
)
|
||||
from langchain_core.messages.utils import count_tokens_approximately, get_buffer_string
|
||||
from langchain_core.outputs import ChatGeneration, ChatResult
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.graph.message import REMOVE_ALL_MESSAGES
|
||||
from langgraph.runtime import Runtime
|
||||
from pydantic import Field
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware.summarization import SummarizationMiddleware
|
||||
from langchain.chat_models import init_chat_model
|
||||
from tests.unit_tests.agents.model import FakeToolCallingModel
|
||||
@@ -16,10 +35,25 @@ from tests.unit_tests.agents.model import FakeToolCallingModel
|
||||
class MockChatModel(BaseChatModel):
|
||||
"""Mock chat model for testing."""
|
||||
|
||||
def invoke(self, prompt): # type: ignore[no-untyped-def]
|
||||
@override
|
||||
def invoke(
|
||||
self,
|
||||
input: LanguageModelInput,
|
||||
config: RunnableConfig | None = None,
|
||||
*,
|
||||
stop: list[str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> AIMessage:
|
||||
return AIMessage(content="Generated summary")
|
||||
|
||||
def _generate(self, messages, **kwargs): # type: ignore[no-untyped-def]
|
||||
@override
|
||||
def _generate(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: list[str] | None = None,
|
||||
run_manager: CallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))])
|
||||
|
||||
@property
|
||||
@@ -30,7 +64,14 @@ class MockChatModel(BaseChatModel):
|
||||
class ProfileChatModel(BaseChatModel):
|
||||
"""Mock chat model with profile for testing."""
|
||||
|
||||
def _generate(self, messages, **kwargs): # type: ignore[no-untyped-def]
|
||||
@override
|
||||
def _generate(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: list[str] | None = None,
|
||||
run_manager: CallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))])
|
||||
|
||||
profile: ModelProfile | None = ModelProfile(max_input_tokens=1000)
|
||||
@@ -79,16 +120,16 @@ def test_summarization_middleware_no_summarization_cases() -> None:
|
||||
|
||||
# Test when summarization is disabled
|
||||
middleware_disabled = SummarizationMiddleware(model=model, trigger=None)
|
||||
state = {"messages": [HumanMessage(content="Hello"), AIMessage(content="Hi")]}
|
||||
result = middleware_disabled.before_model(state, None)
|
||||
state = AgentState[Any](messages=[HumanMessage(content="Hello"), AIMessage(content="Hi")])
|
||||
result = middleware_disabled.before_model(state, Runtime())
|
||||
assert result is None
|
||||
|
||||
# Test when token count is below threshold
|
||||
def mock_token_counter(messages):
|
||||
def mock_token_counter(_: Iterable[MessageLikeRepresentation]) -> int:
|
||||
return 500 # Below threshold
|
||||
|
||||
middleware.token_counter = mock_token_counter
|
||||
result = middleware.before_model(state, None)
|
||||
result = middleware.before_model(state, Runtime())
|
||||
assert result is None
|
||||
|
||||
|
||||
@@ -98,7 +139,7 @@ def test_summarization_middleware_helper_methods() -> None:
|
||||
middleware = SummarizationMiddleware(model=model, trigger=("tokens", 1000))
|
||||
|
||||
# Test message ID assignment
|
||||
messages = [HumanMessage(content="Hello"), AIMessage(content="Hi")]
|
||||
messages: list[AnyMessage] = [HumanMessage(content="Hello"), AIMessage(content="Hi")]
|
||||
middleware._ensure_message_ids(messages)
|
||||
for msg in messages:
|
||||
assert msg.id is not None
|
||||
@@ -131,7 +172,7 @@ def test_summarization_middleware_summary_creation() -> None:
|
||||
middleware = SummarizationMiddleware(model=MockChatModel(), trigger=("tokens", 1000))
|
||||
|
||||
# Test normal summary creation
|
||||
messages = [HumanMessage(content="Hello"), AIMessage(content="Hi")]
|
||||
messages: list[AnyMessage] = [HumanMessage(content="Hello"), AIMessage(content="Hi")]
|
||||
summary = middleware._create_summary(messages)
|
||||
assert summary == "Generated summary"
|
||||
|
||||
@@ -141,15 +182,30 @@ def test_summarization_middleware_summary_creation() -> None:
|
||||
|
||||
# Test error handling
|
||||
class ErrorModel(BaseChatModel):
|
||||
def invoke(self, prompt):
|
||||
@override
|
||||
def invoke(
|
||||
self,
|
||||
input: LanguageModelInput,
|
||||
config: RunnableConfig | None = None,
|
||||
*,
|
||||
stop: list[str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> AIMessage:
|
||||
msg = "Model error"
|
||||
raise ValueError(msg)
|
||||
|
||||
def _generate(self, messages, **kwargs):
|
||||
@override
|
||||
def _generate(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: list[str] | None = None,
|
||||
run_manager: CallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))])
|
||||
|
||||
@property
|
||||
def _llm_type(self):
|
||||
def _llm_type(self) -> str:
|
||||
return "mock"
|
||||
|
||||
middleware_error = SummarizationMiddleware(model=ErrorModel(), trigger=("tokens", 1000))
|
||||
@@ -165,12 +221,16 @@ def test_summarization_middleware_summary_creation() -> None:
|
||||
|
||||
def test_summarization_middleware_trim_limit_none_keeps_all_messages() -> None:
|
||||
"""Verify disabling trim limit preserves full message sequence."""
|
||||
messages = [HumanMessage(content=str(i)) for i in range(10)]
|
||||
messages: list[AnyMessage] = [HumanMessage(content=str(i)) for i in range(10)]
|
||||
middleware = SummarizationMiddleware(
|
||||
model=MockChatModel(),
|
||||
trim_tokens_to_summarize=None,
|
||||
)
|
||||
middleware.token_counter = len
|
||||
|
||||
def token_counter(messages: Iterable[MessageLikeRepresentation]) -> int:
|
||||
return len(list(messages))
|
||||
|
||||
middleware.token_counter = token_counter
|
||||
|
||||
trimmed = middleware._trim_messages_for_summary(messages)
|
||||
assert trimmed is messages
|
||||
@@ -179,8 +239,8 @@ def test_summarization_middleware_trim_limit_none_keeps_all_messages() -> None:
|
||||
def test_summarization_middleware_profile_inference_triggers_summary() -> None:
|
||||
"""Ensure automatic profile inference triggers summarization when limits are exceeded."""
|
||||
|
||||
def token_counter(messages):
|
||||
return len(messages) * 200
|
||||
def token_counter(messages: Iterable[MessageLikeRepresentation]) -> int:
|
||||
return len(list(messages)) * 200
|
||||
|
||||
middleware = SummarizationMiddleware(
|
||||
model=ProfileChatModel(),
|
||||
@@ -189,20 +249,20 @@ def test_summarization_middleware_profile_inference_triggers_summary() -> None:
|
||||
token_counter=token_counter,
|
||||
)
|
||||
|
||||
state = {
|
||||
"messages": [
|
||||
state = AgentState[Any](
|
||||
messages=[
|
||||
HumanMessage(content="Message 1"),
|
||||
AIMessage(content="Message 2"),
|
||||
HumanMessage(content="Message 3"),
|
||||
AIMessage(content="Message 4"),
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
# Test we don't engage summarization
|
||||
# we have total_tokens = 4 * 200 = 800
|
||||
# and max_input_tokens = 1000
|
||||
# since 0.81 * 1000 == 810 > 800 -> summarization not triggered
|
||||
result = middleware.before_model(state, None)
|
||||
result = middleware.before_model(state, Runtime())
|
||||
assert result is None
|
||||
|
||||
# Engage summarization
|
||||
@@ -213,7 +273,7 @@ def test_summarization_middleware_profile_inference_triggers_summary() -> None:
|
||||
keep=("fraction", 0.5),
|
||||
token_counter=token_counter,
|
||||
)
|
||||
result = middleware.before_model(state, None)
|
||||
result = middleware.before_model(state, Runtime())
|
||||
assert result is not None
|
||||
assert isinstance(result["messages"][0], RemoveMessage)
|
||||
summary_message = result["messages"][1]
|
||||
@@ -233,7 +293,7 @@ def test_summarization_middleware_profile_inference_triggers_summary() -> None:
|
||||
keep=("fraction", 0.6),
|
||||
token_counter=token_counter,
|
||||
)
|
||||
result = middleware.before_model(state, None)
|
||||
result = middleware.before_model(state, Runtime())
|
||||
assert result is not None
|
||||
assert [message.content for message in result["messages"][2:]] == [
|
||||
"Message 2",
|
||||
@@ -250,7 +310,7 @@ def test_summarization_middleware_profile_inference_triggers_summary() -> None:
|
||||
keep=("fraction", 0.8),
|
||||
token_counter=token_counter,
|
||||
)
|
||||
assert middleware.before_model(state, None) is None
|
||||
assert middleware.before_model(state, Runtime()) is None
|
||||
|
||||
# Test with tokens_to_keep as absolute int value
|
||||
middleware_int = SummarizationMiddleware(
|
||||
@@ -259,7 +319,7 @@ def test_summarization_middleware_profile_inference_triggers_summary() -> None:
|
||||
keep=("tokens", 400), # Keep exactly 400 tokens (2 messages)
|
||||
token_counter=token_counter,
|
||||
)
|
||||
result = middleware_int.before_model(state, None)
|
||||
result = middleware_int.before_model(state, Runtime())
|
||||
assert result is not None
|
||||
assert [message.content for message in result["messages"][2:]] == [
|
||||
"Message 3",
|
||||
@@ -273,7 +333,7 @@ def test_summarization_middleware_profile_inference_triggers_summary() -> None:
|
||||
keep=("tokens", 600), # Keep 600 tokens (3 messages)
|
||||
token_counter=token_counter,
|
||||
)
|
||||
result = middleware_int_large.before_model(state, None)
|
||||
result = middleware_int_large.before_model(state, Runtime())
|
||||
assert result is not None
|
||||
assert [message.content for message in result["messages"][2:]] == [
|
||||
"Message 2",
|
||||
@@ -285,7 +345,7 @@ def test_summarization_middleware_profile_inference_triggers_summary() -> None:
|
||||
def test_summarization_middleware_token_retention_preserves_ai_tool_pairs() -> None:
|
||||
"""Ensure token retention preserves AI/Tool message pairs together."""
|
||||
|
||||
def token_counter(messages: list[AnyMessage]) -> int:
|
||||
def token_counter(messages: Iterable[MessageLikeRepresentation]) -> int:
|
||||
return sum(len(getattr(message, "content", "")) for message in messages)
|
||||
|
||||
middleware = SummarizationMiddleware(
|
||||
@@ -310,8 +370,8 @@ def test_summarization_middleware_token_retention_preserves_ai_tool_pairs() -> N
|
||||
HumanMessage(content="H" * 160),
|
||||
]
|
||||
|
||||
state = {"messages": messages}
|
||||
result = middleware.before_model(state, None)
|
||||
state = AgentState[Any](messages=messages)
|
||||
result = middleware.before_model(state, Runtime())
|
||||
assert result is not None
|
||||
|
||||
preserved_messages = result["messages"][2:]
|
||||
@@ -330,24 +390,33 @@ def test_summarization_middleware_missing_profile() -> None:
|
||||
"""Ensure automatic profile inference falls back when profiles are unavailable."""
|
||||
|
||||
class ImportErrorProfileModel(BaseChatModel):
|
||||
def _generate(self, messages, **kwargs):
|
||||
@override
|
||||
def _generate(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: list[str] | None = None,
|
||||
run_manager: CallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "mock"
|
||||
|
||||
@property
|
||||
def profile(self):
|
||||
msg = "Profile not available"
|
||||
raise ImportError(msg)
|
||||
# NOTE: Using __getattribute__ because @property cannot override Pydantic fields.
|
||||
def __getattribute__(self, name: str) -> Any:
|
||||
if name == "profile":
|
||||
msg = "Profile not available"
|
||||
raise AttributeError(msg)
|
||||
return super().__getattribute__(name)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="Model profile information is required to use fractional token limits",
|
||||
):
|
||||
_ = SummarizationMiddleware(
|
||||
model=ImportErrorProfileModel, trigger=("fraction", 0.5), keep=("messages", 1)
|
||||
model=ImportErrorProfileModel(), trigger=("fraction", 0.5), keep=("messages", 1)
|
||||
)
|
||||
|
||||
|
||||
@@ -360,12 +429,12 @@ def test_summarization_middleware_full_workflow() -> None:
|
||||
)
|
||||
|
||||
# Mock high token count to trigger summarization
|
||||
def mock_token_counter(messages):
|
||||
def mock_token_counter(_: Iterable[MessageLikeRepresentation]) -> int:
|
||||
return 1500 # Above threshold
|
||||
|
||||
middleware.token_counter = mock_token_counter
|
||||
|
||||
messages = [
|
||||
messages: list[AnyMessage] = [
|
||||
HumanMessage(content="1"),
|
||||
HumanMessage(content="2"),
|
||||
HumanMessage(content="3"),
|
||||
@@ -373,8 +442,8 @@ def test_summarization_middleware_full_workflow() -> None:
|
||||
HumanMessage(content="5"),
|
||||
]
|
||||
|
||||
state = {"messages": messages}
|
||||
result = middleware.before_model(state, None)
|
||||
state = AgentState[Any](messages=messages)
|
||||
result = middleware.before_model(state, Runtime())
|
||||
|
||||
assert result is not None
|
||||
assert "messages" in result
|
||||
@@ -399,14 +468,28 @@ async def test_summarization_middleware_full_workflow_async() -> None:
|
||||
"""Test SummarizationMiddleware complete summarization workflow."""
|
||||
|
||||
class MockModel(BaseChatModel):
|
||||
def _generate(self, messages, **kwargs):
|
||||
@override
|
||||
def _generate(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: list[str] | None = None,
|
||||
run_manager: CallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Blep"))])
|
||||
|
||||
async def _agenerate(self, messages, **kwargs):
|
||||
@override
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: list[str] | None = None,
|
||||
run_manager: AsyncCallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Blip"))])
|
||||
|
||||
@property
|
||||
def _llm_type(self):
|
||||
def _llm_type(self) -> str:
|
||||
return "mock"
|
||||
|
||||
middleware = SummarizationMiddleware(
|
||||
@@ -414,12 +497,12 @@ async def test_summarization_middleware_full_workflow_async() -> None:
|
||||
)
|
||||
|
||||
# Mock high token count to trigger summarization
|
||||
def mock_token_counter(messages):
|
||||
def mock_token_counter(_: Iterable[MessageLikeRepresentation]) -> int:
|
||||
return 1500 # Above threshold
|
||||
|
||||
middleware.token_counter = mock_token_counter
|
||||
|
||||
messages = [
|
||||
messages: list[AnyMessage] = [
|
||||
HumanMessage(content="1"),
|
||||
HumanMessage(content="2"),
|
||||
HumanMessage(content="3"),
|
||||
@@ -427,8 +510,8 @@ async def test_summarization_middleware_full_workflow_async() -> None:
|
||||
HumanMessage(content="5"),
|
||||
]
|
||||
|
||||
state = {"messages": messages}
|
||||
result = await middleware.abefore_model(state, None)
|
||||
state = AgentState[Any](messages=messages)
|
||||
result = await middleware.abefore_model(state, Runtime())
|
||||
|
||||
assert result is not None
|
||||
assert "messages" in result
|
||||
@@ -451,26 +534,26 @@ def test_summarization_middleware_keep_messages() -> None:
|
||||
)
|
||||
|
||||
# Below threshold - no summarization
|
||||
messages_below = [
|
||||
messages_below: list[AnyMessage] = [
|
||||
HumanMessage(content="1"),
|
||||
HumanMessage(content="2"),
|
||||
HumanMessage(content="3"),
|
||||
HumanMessage(content="4"),
|
||||
]
|
||||
state_below = {"messages": messages_below}
|
||||
result = middleware.before_model(state_below, None)
|
||||
state_below = AgentState[Any](messages=messages_below)
|
||||
result = middleware.before_model(state_below, Runtime())
|
||||
assert result is None
|
||||
|
||||
# At threshold - should trigger summarization
|
||||
messages_at_threshold = [
|
||||
messages_at_threshold: list[AnyMessage] = [
|
||||
HumanMessage(content="1"),
|
||||
HumanMessage(content="2"),
|
||||
HumanMessage(content="3"),
|
||||
HumanMessage(content="4"),
|
||||
HumanMessage(content="5"),
|
||||
]
|
||||
state_at = {"messages": messages_at_threshold}
|
||||
result = middleware.before_model(state_at, None)
|
||||
state_at = AgentState[Any](messages=messages_at_threshold)
|
||||
result = middleware.before_model(state_at, Runtime())
|
||||
assert result is not None
|
||||
assert "messages" in result
|
||||
expected_types = ["remove", "human", "human", "human"]
|
||||
@@ -479,9 +562,9 @@ def test_summarization_middleware_keep_messages() -> None:
|
||||
assert [message.content for message in result["messages"][2:]] == ["4", "5"]
|
||||
|
||||
# Above threshold - should also trigger summarization
|
||||
messages_above = [*messages_at_threshold, HumanMessage(content="6")]
|
||||
state_above = {"messages": messages_above}
|
||||
result = middleware.before_model(state_above, None)
|
||||
messages_above: list[AnyMessage] = [*messages_at_threshold, HumanMessage(content="6")]
|
||||
state_above = AgentState[Any](messages=messages_above)
|
||||
result = middleware.before_model(state_above, Runtime())
|
||||
assert result is not None
|
||||
assert "messages" in result
|
||||
expected_types = ["remove", "human", "human", "human"]
|
||||
@@ -491,7 +574,7 @@ def test_summarization_middleware_keep_messages() -> None:
|
||||
|
||||
# Test with both parameters disabled
|
||||
middleware_disabled = SummarizationMiddleware(model=MockChatModel(), trigger=None)
|
||||
result = middleware_disabled.before_model(state_above, None)
|
||||
result = middleware_disabled.before_model(state_above, Runtime())
|
||||
assert result is None
|
||||
|
||||
|
||||
@@ -509,7 +592,7 @@ def test_summarization_middleware_keep_messages() -> None:
|
||||
],
|
||||
)
|
||||
def test_summarization_middleware_validation_edge_cases(
|
||||
param_name: str, param_value: tuple[str, float | int], expected_error: str
|
||||
param_name: str, param_value: Any, expected_error: str
|
||||
) -> None:
|
||||
"""Test validation of context size parameters with edge cases."""
|
||||
model = FakeToolCallingModel()
|
||||
@@ -527,31 +610,31 @@ def test_summarization_middleware_multiple_triggers() -> None:
|
||||
)
|
||||
|
||||
# Mock token counter to return low count
|
||||
def mock_low_tokens(messages):
|
||||
def mock_low_tokens(_: Iterable[MessageLikeRepresentation]) -> int:
|
||||
return 100
|
||||
|
||||
middleware.token_counter = mock_low_tokens
|
||||
|
||||
# Should not trigger - neither condition met
|
||||
messages = [HumanMessage(content=str(i)) for i in range(5)]
|
||||
state = {"messages": messages}
|
||||
result = middleware.before_model(state, None)
|
||||
messages: list[AnyMessage] = [HumanMessage(content=str(i)) for i in range(5)]
|
||||
state = AgentState[Any](messages=messages)
|
||||
result = middleware.before_model(state, Runtime())
|
||||
assert result is None
|
||||
|
||||
# Should trigger - message count threshold met
|
||||
messages = [HumanMessage(content=str(i)) for i in range(10)]
|
||||
state = {"messages": messages}
|
||||
result = middleware.before_model(state, None)
|
||||
state = AgentState[Any](messages=messages)
|
||||
result = middleware.before_model(state, Runtime())
|
||||
assert result is not None
|
||||
|
||||
# Test token trigger
|
||||
def mock_high_tokens(messages):
|
||||
def mock_high_tokens(_: Iterable[MessageLikeRepresentation]) -> int:
|
||||
return 600
|
||||
|
||||
middleware.token_counter = mock_high_tokens
|
||||
messages = [HumanMessage(content=str(i)) for i in range(5)]
|
||||
state = {"messages": messages}
|
||||
result = middleware.before_model(state, None)
|
||||
state = AgentState[Any](messages=messages)
|
||||
result = middleware.before_model(state, Runtime())
|
||||
assert result is not None
|
||||
|
||||
|
||||
@@ -559,11 +642,18 @@ def test_summarization_middleware_profile_edge_cases() -> None:
|
||||
"""Test profile retrieval with various edge cases."""
|
||||
|
||||
class NoProfileModel(BaseChatModel):
|
||||
def _generate(self, messages, **kwargs):
|
||||
@override
|
||||
def _generate(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: list[str] | None = None,
|
||||
run_manager: CallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))])
|
||||
|
||||
@property
|
||||
def _llm_type(self):
|
||||
def _llm_type(self) -> str:
|
||||
return "mock"
|
||||
|
||||
# Model without profile attribute
|
||||
@@ -571,49 +661,71 @@ def test_summarization_middleware_profile_edge_cases() -> None:
|
||||
assert middleware._get_profile_limits() is None
|
||||
|
||||
class InvalidProfileModel(BaseChatModel):
|
||||
def _generate(self, messages, **kwargs):
|
||||
@override
|
||||
def _generate(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: list[str] | None = None,
|
||||
run_manager: CallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))])
|
||||
|
||||
@property
|
||||
def _llm_type(self):
|
||||
def _llm_type(self) -> str:
|
||||
return "mock"
|
||||
|
||||
@property
|
||||
def profile(self):
|
||||
return "invalid_profile_type"
|
||||
# NOTE: Using __getattribute__ because @property cannot override Pydantic fields.
|
||||
def __getattribute__(self, name: str) -> Any:
|
||||
if name == "profile":
|
||||
return "invalid_profile_type"
|
||||
return super().__getattribute__(name)
|
||||
|
||||
# Model with non-dict profile
|
||||
middleware = SummarizationMiddleware(model=InvalidProfileModel(), trigger=("messages", 5))
|
||||
assert middleware._get_profile_limits() is None
|
||||
|
||||
class MissingTokensModel(BaseChatModel):
|
||||
def _generate(self, messages, **kwargs):
|
||||
profile: ModelProfile | None = Field(default=ModelProfile(other_field=100), exclude=True) # type: ignore[typeddict-unknown-key]
|
||||
|
||||
@override
|
||||
def _generate(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: list[str] | None = None,
|
||||
run_manager: CallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))])
|
||||
|
||||
@property
|
||||
def _llm_type(self):
|
||||
def _llm_type(self) -> str:
|
||||
return "mock"
|
||||
|
||||
@property
|
||||
def profile(self):
|
||||
return {"other_field": 100}
|
||||
|
||||
# Model with profile but no max_input_tokens
|
||||
middleware = SummarizationMiddleware(model=MissingTokensModel(), trigger=("messages", 5))
|
||||
assert middleware._get_profile_limits() is None
|
||||
|
||||
class InvalidTokenTypeModel(BaseChatModel):
|
||||
def _generate(self, messages, **kwargs):
|
||||
profile: ModelProfile | None = Field(
|
||||
default=ModelProfile(max_input_tokens="not_an_int"), # type: ignore[typeddict-item]
|
||||
exclude=True,
|
||||
)
|
||||
|
||||
@override
|
||||
def _generate(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: list[str] | None = None,
|
||||
run_manager: CallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))])
|
||||
|
||||
@property
|
||||
def _llm_type(self):
|
||||
def _llm_type(self) -> str:
|
||||
return "mock"
|
||||
|
||||
@property
|
||||
def profile(self):
|
||||
return {"max_input_tokens": "not_an_int"}
|
||||
|
||||
# Model with non-int max_input_tokens
|
||||
middleware = SummarizationMiddleware(model=InvalidTokenTypeModel(), trigger=("messages", 5))
|
||||
assert middleware._get_profile_limits() is None
|
||||
@@ -624,14 +736,14 @@ def test_summarization_middleware_trim_messages_error_fallback() -> None:
|
||||
middleware = SummarizationMiddleware(model=MockChatModel(), trigger=("messages", 5))
|
||||
|
||||
# Create a mock token counter that raises an exception
|
||||
def failing_token_counter(messages):
|
||||
def failing_token_counter(_: Iterable[MessageLikeRepresentation]) -> int:
|
||||
msg = "Token counting failed"
|
||||
raise ValueError(msg)
|
||||
|
||||
middleware.token_counter = failing_token_counter
|
||||
|
||||
# Should fall back to last 15 messages
|
||||
messages = [HumanMessage(content=str(i)) for i in range(20)]
|
||||
messages: list[AnyMessage] = [HumanMessage(content=str(i)) for i in range(20)]
|
||||
trimmed = middleware._trim_messages_for_summary(messages)
|
||||
assert len(trimmed) == 15
|
||||
assert trimmed == messages[-15:]
|
||||
@@ -644,12 +756,12 @@ def test_summarization_middleware_binary_search_edge_cases() -> None:
|
||||
)
|
||||
|
||||
# Test with single message that's too large
|
||||
def token_counter_single_large(messages):
|
||||
return len(messages) * 200
|
||||
def token_counter_single_large(messages: Iterable[MessageLikeRepresentation]) -> int:
|
||||
return len(list(messages)) * 200
|
||||
|
||||
middleware.token_counter = token_counter_single_large
|
||||
|
||||
single_message = [HumanMessage(content="x" * 200)]
|
||||
single_message: list[AnyMessage] = [HumanMessage(content="x" * 200)]
|
||||
cutoff = middleware._find_token_based_cutoff(single_message)
|
||||
assert cutoff == 0
|
||||
|
||||
@@ -658,11 +770,11 @@ def test_summarization_middleware_binary_search_edge_cases() -> None:
|
||||
assert cutoff == 0
|
||||
|
||||
# Test when all messages fit within token budget
|
||||
def token_counter_small(messages):
|
||||
return len(messages) * 10
|
||||
def token_counter_small(messages: Iterable[MessageLikeRepresentation]) -> int:
|
||||
return len(list(messages)) * 10
|
||||
|
||||
middleware.token_counter = token_counter_small
|
||||
messages = [HumanMessage(content=str(i)) for i in range(5)]
|
||||
messages: list[AnyMessage] = [HumanMessage(content=str(i)) for i in range(5)]
|
||||
cutoff = middleware._find_token_based_cutoff(messages)
|
||||
assert cutoff == 0
|
||||
|
||||
@@ -774,11 +886,11 @@ def test_summarization_middleware_zero_and_negative_target_tokens() -> None:
|
||||
)
|
||||
|
||||
# Should set threshold to 1 when calculated value is <= 0
|
||||
messages = [HumanMessage(content="test")]
|
||||
messages: list[AnyMessage] = [HumanMessage(content="test")]
|
||||
|
||||
# The trigger fraction calculation: int(1000 * 0.0001) = 0, but should be set to 1
|
||||
# Token count of 1 message should exceed threshold of 1
|
||||
def token_counter(msgs):
|
||||
def token_counter(_: Iterable[MessageLikeRepresentation]) -> int:
|
||||
return 2
|
||||
|
||||
middleware.token_counter = token_counter
|
||||
@@ -789,19 +901,33 @@ async def test_summarization_middleware_async_error_handling() -> None:
|
||||
"""Test async summary creation with errors."""
|
||||
|
||||
class ErrorAsyncModel(BaseChatModel):
|
||||
def _generate(self, messages, **kwargs):
|
||||
@override
|
||||
def _generate(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: list[str] | None = None,
|
||||
run_manager: CallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))])
|
||||
|
||||
async def _agenerate(self, messages, **kwargs):
|
||||
@override
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: list[str] | None = None,
|
||||
run_manager: AsyncCallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
msg = "Async model error"
|
||||
raise ValueError(msg)
|
||||
|
||||
@property
|
||||
def _llm_type(self):
|
||||
def _llm_type(self) -> str:
|
||||
return "mock"
|
||||
|
||||
middleware = SummarizationMiddleware(model=ErrorAsyncModel(), trigger=("messages", 5))
|
||||
messages = [HumanMessage(content="test")]
|
||||
messages: list[AnyMessage] = [HumanMessage(content="test")]
|
||||
summary = await middleware._acreate_summary(messages)
|
||||
assert "Error generating summary: Async model error" in summary
|
||||
|
||||
@@ -813,7 +939,7 @@ def test_summarization_middleware_cutoff_at_boundary() -> None:
|
||||
)
|
||||
|
||||
# When we want to keep exactly as many messages as we have
|
||||
messages = [HumanMessage(content=str(i)) for i in range(5)]
|
||||
messages: list[AnyMessage] = [HumanMessage(content=str(i)) for i in range(5)]
|
||||
cutoff = middleware._find_safe_cutoff(messages, 5)
|
||||
assert cutoff == 0 # Should not cut anything
|
||||
|
||||
@@ -848,19 +974,14 @@ def test_summarization_middleware_fraction_trigger_with_no_profile() -> None:
|
||||
)
|
||||
|
||||
# Test that when fractional condition can't be evaluated, other triggers still work
|
||||
messages = [HumanMessage(content=str(i)) for i in range(100)]
|
||||
messages: list[AnyMessage] = [HumanMessage(content=str(i)) for i in range(100)]
|
||||
|
||||
# Mock _get_profile_limits to return None
|
||||
original_method = middleware._get_profile_limits
|
||||
middleware._get_profile_limits = lambda: None
|
||||
|
||||
# Should still trigger based on message count
|
||||
state = {"messages": messages}
|
||||
result = middleware.before_model(state, None)
|
||||
assert result is not None
|
||||
|
||||
# Restore original method
|
||||
middleware._get_profile_limits = original_method
|
||||
with patch.object(middleware, "_get_profile_limits", autospec=True, return_value=None):
|
||||
# Should still trigger based on message count
|
||||
state = AgentState[Any](messages=messages)
|
||||
result = middleware.before_model(state, Runtime())
|
||||
assert result is not None
|
||||
|
||||
|
||||
def test_summarization_adjust_token_counts() -> None:
|
||||
|
||||
Reference in New Issue
Block a user