chore(langchain): fix types in test_summarization (#34656)

Co-authored-by: Mason Daugherty <mason@langchain.dev>
Co-authored-by: Mason Daugherty <github@mdrxy.com>
Author: Christophe Bornet
Date: 2026-01-09 23:54:42 +01:00
Committed by: GitHub
Parent: 9ce73a73f8
Commit: 4e40c2766a

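The change repeated throughout this diff is replacing untyped mock-model methods (silenced with `# type: ignore[no-untyped-def]`) with overrides that match `BaseChatModel`'s real signatures. A minimal standalone sketch of the pattern, using the same langchain-core imports as the test file (`TypedFakeModel` is an illustrative name, not part of the commit):

```python
from typing import Any

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from typing_extensions import override


class TypedFakeModel(BaseChatModel):
    """Fake chat model whose overrides match the base-class signatures."""

    @override
    def _generate(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: CallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> ChatResult:
        # Mirroring BaseChatModel._generate exactly is what lets the
        # `# type: ignore[no-untyped-def]` comments be dropped.
        return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))])

    @property
    def _llm_type(self) -> str:
        return "fake"
```

`@override` additionally makes type checkers flag the method if the base-class signature ever drifts.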

@@ -1,13 +1,32 @@
+from collections.abc import Iterable
+from typing import Any
+from unittest.mock import patch
 import pytest
+from langchain_core.callbacks import AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun
 from langchain_core.language_models import ModelProfile
+from langchain_core.language_models.base import (
+    LanguageModelInput,
+)
 from langchain_core.language_models.chat_models import BaseChatModel
-from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, RemoveMessage, ToolMessage
+from langchain_core.messages import (
+    AIMessage,
+    AnyMessage,
+    BaseMessage,
+    HumanMessage,
+    MessageLikeRepresentation,
+    RemoveMessage,
+    ToolMessage,
+)
 from langchain_core.messages.utils import count_tokens_approximately, get_buffer_string
 from langchain_core.outputs import ChatGeneration, ChatResult
+from langchain_core.runnables import RunnableConfig
 from langgraph.graph.message import REMOVE_ALL_MESSAGES
+from langgraph.runtime import Runtime
+from pydantic import Field
+from typing_extensions import override
+from langchain.agents import AgentState
 from langchain.agents.middleware.summarization import SummarizationMiddleware
 from langchain.chat_models import init_chat_model
 from tests.unit_tests.agents.model import FakeToolCallingModel
@@ -16,10 +35,25 @@ from tests.unit_tests.agents.model import FakeToolCallingModel
 class MockChatModel(BaseChatModel):
     """Mock chat model for testing."""

-    def invoke(self, prompt):  # type: ignore[no-untyped-def]
+    @override
+    def invoke(
+        self,
+        input: LanguageModelInput,
+        config: RunnableConfig | None = None,
+        *,
+        stop: list[str] | None = None,
+        **kwargs: Any,
+    ) -> AIMessage:
         return AIMessage(content="Generated summary")

-    def _generate(self, messages, **kwargs):  # type: ignore[no-untyped-def]
+    @override
+    def _generate(
+        self,
+        messages: list[BaseMessage],
+        stop: list[str] | None = None,
+        run_manager: CallbackManagerForLLMRun | None = None,
+        **kwargs: Any,
+    ) -> ChatResult:
         return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))])

     @property
@@ -30,7 +64,14 @@ class MockChatModel(BaseChatModel):
 class ProfileChatModel(BaseChatModel):
     """Mock chat model with profile for testing."""

-    def _generate(self, messages, **kwargs):  # type: ignore[no-untyped-def]
+    @override
+    def _generate(
+        self,
+        messages: list[BaseMessage],
+        stop: list[str] | None = None,
+        run_manager: CallbackManagerForLLMRun | None = None,
+        **kwargs: Any,
+    ) -> ChatResult:
         return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))])

     profile: ModelProfile | None = ModelProfile(max_input_tokens=1000)
@@ -79,16 +120,16 @@ def test_summarization_middleware_no_summarization_cases() -> None:
     # Test when summarization is disabled
     middleware_disabled = SummarizationMiddleware(model=model, trigger=None)
-    state = {"messages": [HumanMessage(content="Hello"), AIMessage(content="Hi")]}
-    result = middleware_disabled.before_model(state, None)
+    state = AgentState[Any](messages=[HumanMessage(content="Hello"), AIMessage(content="Hi")])
+    result = middleware_disabled.before_model(state, Runtime())
     assert result is None

     # Test when token count is below threshold
-    def mock_token_counter(messages):
+    def mock_token_counter(_: Iterable[MessageLikeRepresentation]) -> int:
         return 500  # Below threshold

     middleware.token_counter = mock_token_counter
-    result = middleware.before_model(state, None)
+    result = middleware.before_model(state, Runtime())
     assert result is None
@@ -98,7 +139,7 @@ def test_summarization_middleware_helper_methods() -> None:
     middleware = SummarizationMiddleware(model=model, trigger=("tokens", 1000))

     # Test message ID assignment
-    messages = [HumanMessage(content="Hello"), AIMessage(content="Hi")]
+    messages: list[AnyMessage] = [HumanMessage(content="Hello"), AIMessage(content="Hi")]
     middleware._ensure_message_ids(messages)
     for msg in messages:
         assert msg.id is not None
@@ -131,7 +172,7 @@ def test_summarization_middleware_summary_creation() -> None:
     middleware = SummarizationMiddleware(model=MockChatModel(), trigger=("tokens", 1000))

     # Test normal summary creation
-    messages = [HumanMessage(content="Hello"), AIMessage(content="Hi")]
+    messages: list[AnyMessage] = [HumanMessage(content="Hello"), AIMessage(content="Hi")]
     summary = middleware._create_summary(messages)
     assert summary == "Generated summary"
@@ -141,15 +182,30 @@ def test_summarization_middleware_summary_creation() -> None:
     # Test error handling
     class ErrorModel(BaseChatModel):
-        def invoke(self, prompt):
+        @override
+        def invoke(
+            self,
+            input: LanguageModelInput,
+            config: RunnableConfig | None = None,
+            *,
+            stop: list[str] | None = None,
+            **kwargs: Any,
+        ) -> AIMessage:
             msg = "Model error"
             raise ValueError(msg)

-        def _generate(self, messages, **kwargs):
+        @override
+        def _generate(
+            self,
+            messages: list[BaseMessage],
+            stop: list[str] | None = None,
+            run_manager: CallbackManagerForLLMRun | None = None,
+            **kwargs: Any,
+        ) -> ChatResult:
             return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))])

         @property
-        def _llm_type(self):
+        def _llm_type(self) -> str:
             return "mock"

     middleware_error = SummarizationMiddleware(model=ErrorModel(), trigger=("tokens", 1000))
@@ -165,12 +221,16 @@ def test_summarization_middleware_summary_creation() -> None:
 def test_summarization_middleware_trim_limit_none_keeps_all_messages() -> None:
     """Verify disabling trim limit preserves full message sequence."""
-    messages = [HumanMessage(content=str(i)) for i in range(10)]
+    messages: list[AnyMessage] = [HumanMessage(content=str(i)) for i in range(10)]
     middleware = SummarizationMiddleware(
         model=MockChatModel(),
         trim_tokens_to_summarize=None,
     )
-    middleware.token_counter = len
+
+    def token_counter(messages: Iterable[MessageLikeRepresentation]) -> int:
+        return len(list(messages))
+
+    middleware.token_counter = token_counter
     trimmed = middleware._trim_messages_for_summary(messages)
     assert trimmed is messages
@@ -179,8 +239,8 @@ def test_summarization_middleware_trim_limit_none_keeps_all_messages() -> None:
 def test_summarization_middleware_profile_inference_triggers_summary() -> None:
     """Ensure automatic profile inference triggers summarization when limits are exceeded."""

-    def token_counter(messages):
-        return len(messages) * 200
+    def token_counter(messages: Iterable[MessageLikeRepresentation]) -> int:
+        return len(list(messages)) * 200

     middleware = SummarizationMiddleware(
         model=ProfileChatModel(),
@@ -189,20 +249,20 @@ def test_summarization_middleware_profile_inference_triggers_summary() -> None:
         token_counter=token_counter,
     )

-    state = {
-        "messages": [
+    state = AgentState[Any](
+        messages=[
             HumanMessage(content="Message 1"),
             AIMessage(content="Message 2"),
             HumanMessage(content="Message 3"),
             AIMessage(content="Message 4"),
         ]
-    }
+    )

     # Test we don't engage summarization
     # we have total_tokens = 4 * 200 = 800
     # and max_input_tokens = 1000
     # since 0.81 * 1000 == 810 > 800 -> summarization not triggered
-    result = middleware.before_model(state, None)
+    result = middleware.before_model(state, Runtime())
     assert result is None

     # Engage summarization
@@ -213,7 +273,7 @@ def test_summarization_middleware_profile_inference_triggers_summary() -> None:
         keep=("fraction", 0.5),
         token_counter=token_counter,
     )
-    result = middleware.before_model(state, None)
+    result = middleware.before_model(state, Runtime())
     assert result is not None
     assert isinstance(result["messages"][0], RemoveMessage)
     summary_message = result["messages"][1]
@@ -233,7 +293,7 @@ def test_summarization_middleware_profile_inference_triggers_summary() -> None:
         keep=("fraction", 0.6),
         token_counter=token_counter,
     )
-    result = middleware.before_model(state, None)
+    result = middleware.before_model(state, Runtime())
     assert result is not None
     assert [message.content for message in result["messages"][2:]] == [
         "Message 2",
@@ -250,7 +310,7 @@ def test_summarization_middleware_profile_inference_triggers_summary() -> None:
         keep=("fraction", 0.8),
         token_counter=token_counter,
     )
-    assert middleware.before_model(state, None) is None
+    assert middleware.before_model(state, Runtime()) is None

     # Test with tokens_to_keep as absolute int value
     middleware_int = SummarizationMiddleware(
@@ -259,7 +319,7 @@ def test_summarization_middleware_profile_inference_triggers_summary() -> None:
         keep=("tokens", 400),  # Keep exactly 400 tokens (2 messages)
         token_counter=token_counter,
     )
-    result = middleware_int.before_model(state, None)
+    result = middleware_int.before_model(state, Runtime())
     assert result is not None
     assert [message.content for message in result["messages"][2:]] == [
         "Message 3",
@@ -273,7 +333,7 @@ def test_summarization_middleware_profile_inference_triggers_summary() -> None:
         keep=("tokens", 600),  # Keep 600 tokens (3 messages)
         token_counter=token_counter,
     )
-    result = middleware_int_large.before_model(state, None)
+    result = middleware_int_large.before_model(state, Runtime())
     assert result is not None
     assert [message.content for message in result["messages"][2:]] == [
         "Message 2",
@@ -285,7 +345,7 @@ def test_summarization_middleware_profile_inference_triggers_summary() -> None:
 def test_summarization_middleware_token_retention_preserves_ai_tool_pairs() -> None:
     """Ensure token retention preserves AI/Tool message pairs together."""

-    def token_counter(messages: list[AnyMessage]) -> int:
+    def token_counter(messages: Iterable[MessageLikeRepresentation]) -> int:
         return sum(len(getattr(message, "content", "")) for message in messages)

     middleware = SummarizationMiddleware(
@@ -310,8 +370,8 @@ def test_summarization_middleware_token_retention_preserves_ai_tool_pairs() -> None:
         HumanMessage(content="H" * 160),
     ]
-    state = {"messages": messages}
-    result = middleware.before_model(state, None)
+    state = AgentState[Any](messages=messages)
+    result = middleware.before_model(state, Runtime())
     assert result is not None

     preserved_messages = result["messages"][2:]
@@ -330,24 +390,33 @@ def test_summarization_middleware_missing_profile() -> None:
     """Ensure automatic profile inference falls back when profiles are unavailable."""

     class ImportErrorProfileModel(BaseChatModel):
-        def _generate(self, messages, **kwargs):
+        @override
+        def _generate(
+            self,
+            messages: list[BaseMessage],
+            stop: list[str] | None = None,
+            run_manager: CallbackManagerForLLMRun | None = None,
+            **kwargs: Any,
+        ) -> ChatResult:
             raise NotImplementedError

         @property
         def _llm_type(self) -> str:
             return "mock"

-        @property
-        def profile(self):
-            msg = "Profile not available"
-            raise ImportError(msg)
+        # NOTE: Using __getattribute__ because @property cannot override Pydantic fields.
+        def __getattribute__(self, name: str) -> Any:
+            if name == "profile":
+                msg = "Profile not available"
+                raise AttributeError(msg)
+            return super().__getattribute__(name)

     with pytest.raises(
         ValueError,
         match="Model profile information is required to use fractional token limits",
     ):
         _ = SummarizationMiddleware(
-            model=ImportErrorProfileModel, trigger=("fraction", 0.5), keep=("messages", 1)
+            model=ImportErrorProfileModel(), trigger=("fraction", 0.5), keep=("messages", 1)
         )
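Two details in this hunk are easy to miss: the raised exception changes from `ImportError` to `AttributeError`, and the interception moves from `@property` to `__getattribute__` (a property cannot shadow a field that Pydantic has already claimed on `BaseChatModel`). Only `AttributeError` makes `getattr`/`hasattr` treat the attribute as absent. A langchain-free sketch of the same trick (`FakeModel` and `NoProfileModel` are illustrative stand-ins):

```python
from typing import Any


class FakeModel:
    profile: Any = {"max_input_tokens": 1000}  # stands in for a Pydantic field


class NoProfileModel(FakeModel):
    def __getattribute__(self, name: str) -> Any:
        # Intercept every attribute lookup; fail only for "profile".
        if name == "profile":
            msg = "Profile not available"
            raise AttributeError(msg)
        return super().__getattribute__(name)


# getattr only swallows AttributeError; an ImportError would propagate out
# of the probe instead of selecting the default.
assert getattr(NoProfileModel(), "profile", None) is None
assert not hasattr(NoProfileModel(), "profile")
```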
@@ -360,12 +429,12 @@ def test_summarization_middleware_full_workflow() -> None:
     )

     # Mock high token count to trigger summarization
-    def mock_token_counter(messages):
+    def mock_token_counter(_: Iterable[MessageLikeRepresentation]) -> int:
         return 1500  # Above threshold

     middleware.token_counter = mock_token_counter

-    messages = [
+    messages: list[AnyMessage] = [
         HumanMessage(content="1"),
         HumanMessage(content="2"),
         HumanMessage(content="3"),
@@ -373,8 +442,8 @@ def test_summarization_middleware_full_workflow() -> None:
         HumanMessage(content="5"),
     ]
-    state = {"messages": messages}
-    result = middleware.before_model(state, None)
+    state = AgentState[Any](messages=messages)
+    result = middleware.before_model(state, Runtime())
     assert result is not None
     assert "messages" in result
@@ -399,14 +468,28 @@ async def test_summarization_middleware_full_workflow_async() -> None:
     """Test SummarizationMiddleware complete summarization workflow."""

     class MockModel(BaseChatModel):
-        def _generate(self, messages, **kwargs):
+        @override
+        def _generate(
+            self,
+            messages: list[BaseMessage],
+            stop: list[str] | None = None,
+            run_manager: CallbackManagerForLLMRun | None = None,
+            **kwargs: Any,
+        ) -> ChatResult:
             return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Blep"))])

-        async def _agenerate(self, messages, **kwargs):
+        @override
+        async def _agenerate(
+            self,
+            messages: list[BaseMessage],
+            stop: list[str] | None = None,
+            run_manager: AsyncCallbackManagerForLLMRun | None = None,
+            **kwargs: Any,
+        ) -> ChatResult:
             return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Blip"))])

         @property
-        def _llm_type(self):
+        def _llm_type(self) -> str:
             return "mock"

     middleware = SummarizationMiddleware(
@@ -414,12 +497,12 @@ async def test_summarization_middleware_full_workflow_async() -> None:
     )

     # Mock high token count to trigger summarization
-    def mock_token_counter(messages):
+    def mock_token_counter(_: Iterable[MessageLikeRepresentation]) -> int:
         return 1500  # Above threshold

     middleware.token_counter = mock_token_counter

-    messages = [
+    messages: list[AnyMessage] = [
         HumanMessage(content="1"),
         HumanMessage(content="2"),
         HumanMessage(content="3"),
@@ -427,8 +510,8 @@ async def test_summarization_middleware_full_workflow_async() -> None:
         HumanMessage(content="5"),
     ]
-    state = {"messages": messages}
-    result = await middleware.abefore_model(state, None)
+    state = AgentState[Any](messages=messages)
+    result = await middleware.abefore_model(state, Runtime())
     assert result is not None
     assert "messages" in result
@@ -451,26 +534,26 @@ def test_summarization_middleware_keep_messages() -> None:
     )

     # Below threshold - no summarization
-    messages_below = [
+    messages_below: list[AnyMessage] = [
         HumanMessage(content="1"),
         HumanMessage(content="2"),
         HumanMessage(content="3"),
         HumanMessage(content="4"),
     ]
-    state_below = {"messages": messages_below}
-    result = middleware.before_model(state_below, None)
+    state_below = AgentState[Any](messages=messages_below)
+    result = middleware.before_model(state_below, Runtime())
     assert result is None

     # At threshold - should trigger summarization
-    messages_at_threshold = [
+    messages_at_threshold: list[AnyMessage] = [
         HumanMessage(content="1"),
         HumanMessage(content="2"),
         HumanMessage(content="3"),
         HumanMessage(content="4"),
         HumanMessage(content="5"),
     ]
-    state_at = {"messages": messages_at_threshold}
-    result = middleware.before_model(state_at, None)
+    state_at = AgentState[Any](messages=messages_at_threshold)
+    result = middleware.before_model(state_at, Runtime())
     assert result is not None
     assert "messages" in result
     expected_types = ["remove", "human", "human", "human"]
@@ -479,9 +562,9 @@ def test_summarization_middleware_keep_messages() -> None:
     assert [message.content for message in result["messages"][2:]] == ["4", "5"]

     # Above threshold - should also trigger summarization
-    messages_above = [*messages_at_threshold, HumanMessage(content="6")]
-    state_above = {"messages": messages_above}
-    result = middleware.before_model(state_above, None)
+    messages_above: list[AnyMessage] = [*messages_at_threshold, HumanMessage(content="6")]
+    state_above = AgentState[Any](messages=messages_above)
+    result = middleware.before_model(state_above, Runtime())
     assert result is not None
     assert "messages" in result
     expected_types = ["remove", "human", "human", "human"]
@@ -491,7 +574,7 @@ def test_summarization_middleware_keep_messages() -> None:
     # Test with both parameters disabled
     middleware_disabled = SummarizationMiddleware(model=MockChatModel(), trigger=None)
-    result = middleware_disabled.before_model(state_above, None)
+    result = middleware_disabled.before_model(state_above, Runtime())
     assert result is None
@@ -509,7 +592,7 @@ def test_summarization_middleware_keep_messages() -> None:
     ],
 )
 def test_summarization_middleware_validation_edge_cases(
-    param_name: str, param_value: tuple[str, float | int], expected_error: str
+    param_name: str, param_value: Any, expected_error: str
 ) -> None:
     """Test validation of context size parameters with edge cases."""
     model = FakeToolCallingModel()
@@ -527,31 +610,31 @@ def test_summarization_middleware_multiple_triggers() -> None:
     )

     # Mock token counter to return low count
-    def mock_low_tokens(messages):
+    def mock_low_tokens(_: Iterable[MessageLikeRepresentation]) -> int:
         return 100

     middleware.token_counter = mock_low_tokens

     # Should not trigger - neither condition met
-    messages = [HumanMessage(content=str(i)) for i in range(5)]
-    state = {"messages": messages}
-    result = middleware.before_model(state, None)
+    messages: list[AnyMessage] = [HumanMessage(content=str(i)) for i in range(5)]
+    state = AgentState[Any](messages=messages)
+    result = middleware.before_model(state, Runtime())
     assert result is None

     # Should trigger - message count threshold met
     messages = [HumanMessage(content=str(i)) for i in range(10)]
-    state = {"messages": messages}
-    result = middleware.before_model(state, None)
+    state = AgentState[Any](messages=messages)
+    result = middleware.before_model(state, Runtime())
     assert result is not None

     # Test token trigger
-    def mock_high_tokens(messages):
+    def mock_high_tokens(_: Iterable[MessageLikeRepresentation]) -> int:
         return 600

     middleware.token_counter = mock_high_tokens
     messages = [HumanMessage(content=str(i)) for i in range(5)]
-    state = {"messages": messages}
-    result = middleware.before_model(state, None)
+    state = AgentState[Any](messages=messages)
+    result = middleware.before_model(state, Runtime())
     assert result is not None
@@ -559,11 +642,18 @@ def test_summarization_middleware_profile_edge_cases() -> None:
     """Test profile retrieval with various edge cases."""

     class NoProfileModel(BaseChatModel):
-        def _generate(self, messages, **kwargs):
+        @override
+        def _generate(
+            self,
+            messages: list[BaseMessage],
+            stop: list[str] | None = None,
+            run_manager: CallbackManagerForLLMRun | None = None,
+            **kwargs: Any,
+        ) -> ChatResult:
             return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))])

         @property
-        def _llm_type(self):
+        def _llm_type(self) -> str:
             return "mock"

     # Model without profile attribute
@@ -571,49 +661,71 @@ def test_summarization_middleware_profile_edge_cases() -> None:
     assert middleware._get_profile_limits() is None

     class InvalidProfileModel(BaseChatModel):
-        def _generate(self, messages, **kwargs):
+        @override
+        def _generate(
+            self,
+            messages: list[BaseMessage],
+            stop: list[str] | None = None,
+            run_manager: CallbackManagerForLLMRun | None = None,
+            **kwargs: Any,
+        ) -> ChatResult:
             return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))])

         @property
-        def _llm_type(self):
+        def _llm_type(self) -> str:
             return "mock"

-        @property
-        def profile(self):
-            return "invalid_profile_type"
+        # NOTE: Using __getattribute__ because @property cannot override Pydantic fields.
+        def __getattribute__(self, name: str) -> Any:
+            if name == "profile":
+                return "invalid_profile_type"
+            return super().__getattribute__(name)

     # Model with non-dict profile
     middleware = SummarizationMiddleware(model=InvalidProfileModel(), trigger=("messages", 5))
     assert middleware._get_profile_limits() is None

     class MissingTokensModel(BaseChatModel):
-        def _generate(self, messages, **kwargs):
+        profile: ModelProfile | None = Field(default=ModelProfile(other_field=100), exclude=True)  # type: ignore[typeddict-unknown-key]
+
+        @override
+        def _generate(
+            self,
+            messages: list[BaseMessage],
+            stop: list[str] | None = None,
+            run_manager: CallbackManagerForLLMRun | None = None,
+            **kwargs: Any,
+        ) -> ChatResult:
             return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))])

         @property
-        def _llm_type(self):
+        def _llm_type(self) -> str:
             return "mock"

-        @property
-        def profile(self):
-            return {"other_field": 100}
-
     # Model with profile but no max_input_tokens
     middleware = SummarizationMiddleware(model=MissingTokensModel(), trigger=("messages", 5))
     assert middleware._get_profile_limits() is None

     class InvalidTokenTypeModel(BaseChatModel):
-        def _generate(self, messages, **kwargs):
+        profile: ModelProfile | None = Field(
+            default=ModelProfile(max_input_tokens="not_an_int"),  # type: ignore[typeddict-item]
+            exclude=True,
+        )
+
+        @override
+        def _generate(
+            self,
+            messages: list[BaseMessage],
+            stop: list[str] | None = None,
+            run_manager: CallbackManagerForLLMRun | None = None,
+            **kwargs: Any,
+        ) -> ChatResult:
             return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))])

         @property
-        def _llm_type(self):
+        def _llm_type(self) -> str:
             return "mock"

-        @property
-        def profile(self):
-            return {"max_input_tokens": "not_an_int"}
-
     # Model with non-int max_input_tokens
     middleware = SummarizationMiddleware(model=InvalidTokenTypeModel(), trigger=("messages", 5))
     assert middleware._get_profile_limits() is None
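These two models run because the deliberately ill-typed defaults are never checked at runtime: `ModelProfile` is a TypedDict (plain dict construction, no runtime checks, hence the `typeddict-*` ignore codes), and Pydantic v2 does not validate field defaults unless `validate_default=True`. A self-contained sketch under those assumptions (`Profile` and `BadProfileModel` are illustrative):

```python
from pydantic import BaseModel, Field
from typing_extensions import TypedDict


class Profile(TypedDict, total=False):  # stand-in for ModelProfile
    max_input_tokens: int


class BadProfileModel(BaseModel):
    # Ill-typed default for a negative test: TypedDict(...) is a plain dict
    # at runtime, and Pydantic skips validation of defaults by default, so
    # only the type checker ever objects -- which the ignore comment silences.
    profile: Profile | None = Field(
        default=Profile(max_input_tokens="not_an_int"),  # type: ignore[typeddict-item]
        exclude=True,
    )


assert BadProfileModel().profile == {"max_input_tokens": "not_an_int"}
```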
@@ -624,14 +736,14 @@ def test_summarization_middleware_trim_messages_error_fallback() -> None:
     middleware = SummarizationMiddleware(model=MockChatModel(), trigger=("messages", 5))

     # Create a mock token counter that raises an exception
-    def failing_token_counter(messages):
+    def failing_token_counter(_: Iterable[MessageLikeRepresentation]) -> int:
         msg = "Token counting failed"
         raise ValueError(msg)

     middleware.token_counter = failing_token_counter

     # Should fall back to last 15 messages
-    messages = [HumanMessage(content=str(i)) for i in range(20)]
+    messages: list[AnyMessage] = [HumanMessage(content=str(i)) for i in range(20)]
     trimmed = middleware._trim_messages_for_summary(messages)
     assert len(trimmed) == 15
     assert trimmed == messages[-15:]
@@ -644,12 +756,12 @@ def test_summarization_middleware_binary_search_edge_cases() -> None:
     )

     # Test with single message that's too large
-    def token_counter_single_large(messages):
-        return len(messages) * 200
+    def token_counter_single_large(messages: Iterable[MessageLikeRepresentation]) -> int:
+        return len(list(messages)) * 200

     middleware.token_counter = token_counter_single_large
-    single_message = [HumanMessage(content="x" * 200)]
+    single_message: list[AnyMessage] = [HumanMessage(content="x" * 200)]
     cutoff = middleware._find_token_based_cutoff(single_message)
     assert cutoff == 0
@@ -658,11 +770,11 @@ def test_summarization_middleware_binary_search_edge_cases() -> None:
     assert cutoff == 0

     # Test when all messages fit within token budget
-    def token_counter_small(messages):
-        return len(messages) * 10
+    def token_counter_small(messages: Iterable[MessageLikeRepresentation]) -> int:
+        return len(list(messages)) * 10

     middleware.token_counter = token_counter_small
-    messages = [HumanMessage(content=str(i)) for i in range(5)]
+    messages: list[AnyMessage] = [HumanMessage(content=str(i)) for i in range(5)]
     cutoff = middleware._find_token_based_cutoff(messages)
     assert cutoff == 0
@@ -774,11 +886,11 @@ def test_summarization_middleware_zero_and_negative_target_tokens() -> None:
     )

     # Should set threshold to 1 when calculated value is <= 0
-    messages = [HumanMessage(content="test")]
+    messages: list[AnyMessage] = [HumanMessage(content="test")]

     # The trigger fraction calculation: int(1000 * 0.0001) = 0, but should be set to 1
     # Token count of 1 message should exceed threshold of 1
-    def token_counter(msgs):
+    def token_counter(_: Iterable[MessageLikeRepresentation]) -> int:
         return 2

     middleware.token_counter = token_counter
@@ -789,19 +901,33 @@ async def test_summarization_middleware_async_error_handling() -> None:
    """Test async summary creation with errors."""

     class ErrorAsyncModel(BaseChatModel):
-        def _generate(self, messages, **kwargs):
+        @override
+        def _generate(
+            self,
+            messages: list[BaseMessage],
+            stop: list[str] | None = None,
+            run_manager: CallbackManagerForLLMRun | None = None,
+            **kwargs: Any,
+        ) -> ChatResult:
             return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))])

-        async def _agenerate(self, messages, **kwargs):
+        @override
+        async def _agenerate(
+            self,
+            messages: list[BaseMessage],
+            stop: list[str] | None = None,
+            run_manager: AsyncCallbackManagerForLLMRun | None = None,
+            **kwargs: Any,
+        ) -> ChatResult:
             msg = "Async model error"
             raise ValueError(msg)

         @property
-        def _llm_type(self):
+        def _llm_type(self) -> str:
             return "mock"

     middleware = SummarizationMiddleware(model=ErrorAsyncModel(), trigger=("messages", 5))
-    messages = [HumanMessage(content="test")]
+    messages: list[AnyMessage] = [HumanMessage(content="test")]
     summary = await middleware._acreate_summary(messages)
     assert "Error generating summary: Async model error" in summary
@@ -813,7 +939,7 @@ def test_summarization_middleware_cutoff_at_boundary() -> None:
     )

     # When we want to keep exactly as many messages as we have
-    messages = [HumanMessage(content=str(i)) for i in range(5)]
+    messages: list[AnyMessage] = [HumanMessage(content=str(i)) for i in range(5)]
     cutoff = middleware._find_safe_cutoff(messages, 5)
     assert cutoff == 0  # Should not cut anything
@@ -848,19 +974,14 @@ def test_summarization_middleware_fraction_trigger_with_no_profile() -> None:
     )

     # Test that when fractional condition can't be evaluated, other triggers still work
-    messages = [HumanMessage(content=str(i)) for i in range(100)]
+    messages: list[AnyMessage] = [HumanMessage(content=str(i)) for i in range(100)]

-    # Mock _get_profile_limits to return None
-    original_method = middleware._get_profile_limits
-    middleware._get_profile_limits = lambda: None
-
-    # Should still trigger based on message count
-    state = {"messages": messages}
-    result = middleware.before_model(state, None)
-    assert result is not None
-
-    # Restore original method
-    middleware._get_profile_limits = original_method
+    with patch.object(middleware, "_get_profile_limits", autospec=True, return_value=None):
+        # Should still trigger based on message count
+        state = AgentState[Any](messages=messages)
+        result = middleware.before_model(state, Runtime())
+        assert result is not None


 def test_summarization_adjust_token_counts() -> None:
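The last hunk above replaces the manual save/patch/restore dance with `unittest.mock.patch.object`, which restores the original method even if an assertion fails mid-block. A standalone sketch of the pattern (`Middleware` is a hypothetical stand-in for the real class):

```python
from unittest.mock import patch


class Middleware:
    def _get_profile_limits(self) -> int | None:
        return 1000  # pretend this reads a model profile


mw = Middleware()
# autospec=True keeps the mock's signature in sync with the real method;
# return_value configures what the patched call returns.
with patch.object(mw, "_get_profile_limits", autospec=True, return_value=None):
    assert mw._get_profile_limits() is None
# Restored automatically on exit, even if the assert above had failed.
assert mw._get_profile_limits() == 1000
```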