diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_summarization.py b/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_summarization.py
index 728c6c97dfe..d7b3157ab17 100644
--- a/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_summarization.py
+++ b/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_summarization.py
@@ -1,13 +1,32 @@
+from collections.abc import Iterable
+from typing import Any
 from unittest.mock import patch
 
 import pytest
+from langchain_core.callbacks import AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun
 from langchain_core.language_models import ModelProfile
+from langchain_core.language_models.base import (
+    LanguageModelInput,
+)
 from langchain_core.language_models.chat_models import BaseChatModel
-from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, RemoveMessage, ToolMessage
+from langchain_core.messages import (
+    AIMessage,
+    AnyMessage,
+    BaseMessage,
+    HumanMessage,
+    MessageLikeRepresentation,
+    RemoveMessage,
+    ToolMessage,
+)
 from langchain_core.messages.utils import count_tokens_approximately, get_buffer_string
 from langchain_core.outputs import ChatGeneration, ChatResult
+from langchain_core.runnables import RunnableConfig
 from langgraph.graph.message import REMOVE_ALL_MESSAGES
+from langgraph.runtime import Runtime
+from pydantic import Field
+from typing_extensions import override
 
+from langchain.agents import AgentState
 from langchain.agents.middleware.summarization import SummarizationMiddleware
 from langchain.chat_models import init_chat_model
 from tests.unit_tests.agents.model import FakeToolCallingModel
@@ -16,10 +35,25 @@ from tests.unit_tests.agents.model import FakeToolCallingModel
 class MockChatModel(BaseChatModel):
     """Mock chat model for testing."""
 
-    def invoke(self, prompt):  # type: ignore[no-untyped-def]
+    @override
+    def invoke(
+        self,
+        input: LanguageModelInput,
+        config: RunnableConfig | None = None,
+        *,
+        stop: list[str] | None = None,
+        **kwargs: Any,
+    ) -> AIMessage:
         return AIMessage(content="Generated summary")
 
-    def _generate(self, messages, **kwargs):  # type: ignore[no-untyped-def]
+    @override
+    def _generate(
+        self,
+        messages: list[BaseMessage],
+        stop: list[str] | None = None,
+        run_manager: CallbackManagerForLLMRun | None = None,
+        **kwargs: Any,
+    ) -> ChatResult:
         return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))])
 
     @property
@@ -30,7 +64,14 @@ class MockChatModel(BaseChatModel):
 class ProfileChatModel(BaseChatModel):
     """Mock chat model with profile for testing."""
 
-    def _generate(self, messages, **kwargs):  # type: ignore[no-untyped-def]
+    @override
+    def _generate(
+        self,
+        messages: list[BaseMessage],
+        stop: list[str] | None = None,
+        run_manager: CallbackManagerForLLMRun | None = None,
+        **kwargs: Any,
+    ) -> ChatResult:
         return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))])
 
     profile: ModelProfile | None = ModelProfile(max_input_tokens=1000)
@@ -79,16 +120,16 @@ def test_summarization_middleware_no_summarization_cases() -> None:
 
     # Test when summarization is disabled
     middleware_disabled = SummarizationMiddleware(model=model, trigger=None)
-    state = {"messages": [HumanMessage(content="Hello"), AIMessage(content="Hi")]}
-    result = middleware_disabled.before_model(state, None)
+    state = AgentState[Any](messages=[HumanMessage(content="Hello"), AIMessage(content="Hi")])
+    result = middleware_disabled.before_model(state, Runtime())
     assert result is None
 
     # Test when token count is below threshold
-    def mock_token_counter(messages):
+    def mock_token_counter(_: Iterable[MessageLikeRepresentation]) -> int:
        return 500  # Below threshold
 
     middleware.token_counter = mock_token_counter
-    result = middleware.before_model(state, None)
+    result = middleware.before_model(state, Runtime())
     assert result is None
@@ -98,7 +139,7 @@ def test_summarization_middleware_helper_methods() -> None:
     middleware = SummarizationMiddleware(model=model, trigger=("tokens", 1000))
 
     # Test message ID assignment
-    messages = [HumanMessage(content="Hello"), AIMessage(content="Hi")]
+    messages: list[AnyMessage] = [HumanMessage(content="Hello"), AIMessage(content="Hi")]
     middleware._ensure_message_ids(messages)
     for msg in messages:
         assert msg.id is not None
@@ -131,7 +172,7 @@ def test_summarization_middleware_summary_creation() -> None:
     middleware = SummarizationMiddleware(model=MockChatModel(), trigger=("tokens", 1000))
 
     # Test normal summary creation
-    messages = [HumanMessage(content="Hello"), AIMessage(content="Hi")]
+    messages: list[AnyMessage] = [HumanMessage(content="Hello"), AIMessage(content="Hi")]
     summary = middleware._create_summary(messages)
     assert summary == "Generated summary"
@@ -141,15 +182,30 @@
 
     # Test error handling
     class ErrorModel(BaseChatModel):
-        def invoke(self, prompt):
+        @override
+        def invoke(
+            self,
+            input: LanguageModelInput,
+            config: RunnableConfig | None = None,
+            *,
+            stop: list[str] | None = None,
+            **kwargs: Any,
+        ) -> AIMessage:
             msg = "Model error"
             raise ValueError(msg)
 
-        def _generate(self, messages, **kwargs):
+        @override
+        def _generate(
+            self,
+            messages: list[BaseMessage],
+            stop: list[str] | None = None,
+            run_manager: CallbackManagerForLLMRun | None = None,
+            **kwargs: Any,
+        ) -> ChatResult:
             return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))])
 
         @property
-        def _llm_type(self):
+        def _llm_type(self) -> str:
             return "mock"
 
     middleware_error = SummarizationMiddleware(model=ErrorModel(), trigger=("tokens", 1000))
@@ -165,12 +221,16 @@ def test_summarization_middleware_summary_creation() -> None:
 
 def test_summarization_middleware_trim_limit_none_keeps_all_messages() -> None:
     """Verify disabling trim limit preserves full message sequence."""
-    messages = [HumanMessage(content=str(i)) for i in range(10)]
+    messages: list[AnyMessage] = [HumanMessage(content=str(i)) for i in range(10)]
     middleware = SummarizationMiddleware(
         model=MockChatModel(),
         trim_tokens_to_summarize=None,
     )
-    middleware.token_counter = len
+
+    def token_counter(messages: Iterable[MessageLikeRepresentation]) -> int:
+        return len(list(messages))
+
+    middleware.token_counter = token_counter
     trimmed = middleware._trim_messages_for_summary(messages)
 
     assert trimmed is messages
@@ -179,8 +239,8 @@ def test_summarization_middleware_trim_limit_none_keeps_all_messages() -> None:
 def test_summarization_middleware_profile_inference_triggers_summary() -> None:
     """Ensure automatic profile inference triggers summarization when limits are exceeded."""
 
-    def token_counter(messages):
-        return len(messages) * 200
+    def token_counter(messages: Iterable[MessageLikeRepresentation]) -> int:
+        return len(list(messages)) * 200
 
     middleware = SummarizationMiddleware(
         model=ProfileChatModel(),
@@ -189,20 +249,20 @@ def test_summarization_middleware_profile_inference_triggers_summary() -> None:
         trigger=("fraction", 0.8),
         keep=("fraction", 0.5),
         token_counter=token_counter,
     )
-    state = {
-        "messages": [
+    state = AgentState[Any](
+        messages=[
             HumanMessage(content="Message 1"),
             AIMessage(content="Message 2"),
             HumanMessage(content="Message 3"),
             AIMessage(content="Message 4"),
         ]
-    }
+    )
     # Test we don't engage summarization
     # we have total_tokens = 4 * 200 = 800
     # and max_input_tokens = 1000
     # since 0.81 * 1000 == 810 > 800 -> summarization not triggered
-    result = middleware.before_model(state, None)
+    result = middleware.before_model(state, Runtime())
     assert result is None
 
     # Engage summarization
@@ -213,7 +273,7 @@ def test_summarization_middleware_profile_inference_triggers_summary() -> None:
         keep=("fraction", 0.5),
         token_counter=token_counter,
     )
-    result = middleware.before_model(state, None)
+    result = middleware.before_model(state, Runtime())
     assert result is not None
     assert isinstance(result["messages"][0], RemoveMessage)
     summary_message = result["messages"][1]
@@ -233,7 +293,7 @@ def test_summarization_middleware_profile_inference_triggers_summary() -> None:
         keep=("fraction", 0.6),
         token_counter=token_counter,
     )
-    result = middleware.before_model(state, None)
+    result = middleware.before_model(state, Runtime())
     assert result is not None
     assert [message.content for message in result["messages"][2:]] == [
         "Message 2",
@@ -250,7 +310,7 @@ def test_summarization_middleware_profile_inference_triggers_summary() -> None:
         keep=("fraction", 0.8),
         token_counter=token_counter,
     )
-    assert middleware.before_model(state, None) is None
+    assert middleware.before_model(state, Runtime()) is None
 
     # Test with tokens_to_keep as absolute int value
     middleware_int = SummarizationMiddleware(
@@ -259,7 +319,7 @@ def test_summarization_middleware_profile_inference_triggers_summary() -> None:
         keep=("tokens", 400),  # Keep exactly 400 tokens (2 messages)
         token_counter=token_counter,
     )
-    result = middleware_int.before_model(state, None)
+    result = middleware_int.before_model(state, Runtime())
     assert result is not None
     assert [message.content for message in result["messages"][2:]] == [
         "Message 3",
@@ -273,7 +333,7 @@ def test_summarization_middleware_profile_inference_triggers_summary() -> None:
         keep=("tokens", 600),  # Keep 600 tokens (3 messages)
         token_counter=token_counter,
     )
-    result = middleware_int_large.before_model(state, None)
+    result = middleware_int_large.before_model(state, Runtime())
     assert result is not None
     assert [message.content for message in result["messages"][2:]] == [
         "Message 2",
@@ -285,7 +345,7 @@
 def test_summarization_middleware_token_retention_preserves_ai_tool_pairs() -> None:
     """Ensure token retention preserves AI/Tool message pairs together."""
 
-    def token_counter(messages: list[AnyMessage]) -> int:
+    def token_counter(messages: Iterable[MessageLikeRepresentation]) -> int:
         return sum(len(getattr(message, "content", "")) for message in messages)
 
     middleware = SummarizationMiddleware(
@@ -310,8 +370,8 @@ def test_summarization_middleware_token_retention_preserves_ai_tool_pairs() -> N
         HumanMessage(content="H" * 160),
     ]
 
-    state = {"messages": messages}
-    result = middleware.before_model(state, None)
+    state = AgentState[Any](messages=messages)
+    result = middleware.before_model(state, Runtime())
     assert result is not None
 
     preserved_messages = result["messages"][2:]
@@ -330,24 +390,33 @@ def test_summarization_middleware_missing_profile() -> None:
     """Ensure automatic profile inference falls back when profiles are unavailable."""
 
     class ImportErrorProfileModel(BaseChatModel):
-        def _generate(self, messages, **kwargs):
+        @override
+        def _generate(
+            self,
+            messages: list[BaseMessage],
+            stop: list[str] | None = None,
+            run_manager: CallbackManagerForLLMRun | None = None,
+            **kwargs: Any,
+        ) -> ChatResult:
             raise NotImplementedError
 
         @property
         def _llm_type(self) -> str:
             return "mock"
 
-        @property
-        def profile(self):
-            msg = "Profile not available"
-            raise ImportError(msg)
+        # NOTE: Using __getattribute__ because @property cannot override Pydantic fields.
+        def __getattribute__(self, name: str) -> Any:
+            if name == "profile":
+                msg = "Profile not available"
+                raise AttributeError(msg)
+            return super().__getattribute__(name)
 
     with pytest.raises(
         ValueError,
         match="Model profile information is required to use fractional token limits",
     ):
         _ = SummarizationMiddleware(
-            model=ImportErrorProfileModel, trigger=("fraction", 0.5), keep=("messages", 1)
+            model=ImportErrorProfileModel(), trigger=("fraction", 0.5), keep=("messages", 1)
         )
@@ -360,12 +429,12 @@ def test_summarization_middleware_full_workflow() -> None:
     )
 
     # Mock high token count to trigger summarization
-    def mock_token_counter(messages):
+    def mock_token_counter(_: Iterable[MessageLikeRepresentation]) -> int:
         return 1500  # Above threshold
 
     middleware.token_counter = mock_token_counter
 
-    messages = [
+    messages: list[AnyMessage] = [
         HumanMessage(content="1"),
         HumanMessage(content="2"),
         HumanMessage(content="3"),
@@ -373,8 +442,8 @@ def test_summarization_middleware_full_workflow() -> None:
         HumanMessage(content="5"),
     ]
 
-    state = {"messages": messages}
-    result = middleware.before_model(state, None)
+    state = AgentState[Any](messages=messages)
+    result = middleware.before_model(state, Runtime())
     assert result is not None
     assert "messages" in result
@@ -399,14 +468,28 @@ async def test_summarization_middleware_full_workflow_async() -> None:
     """Test SummarizationMiddleware complete summarization workflow."""
 
     class MockModel(BaseChatModel):
-        def _generate(self, messages, **kwargs):
+        @override
+        def _generate(
+            self,
+            messages: list[BaseMessage],
+            stop: list[str] | None = None,
+            run_manager: CallbackManagerForLLMRun | None = None,
+            **kwargs: Any,
+        ) -> ChatResult:
             return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Blep"))])
 
-        async def _agenerate(self, messages, **kwargs):
+        @override
+        async def _agenerate(
+            self,
+            messages: list[BaseMessage],
+            stop: list[str] | None = None,
+            run_manager: AsyncCallbackManagerForLLMRun | None = None,
+            **kwargs: Any,
+        ) -> ChatResult:
             return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Blip"))])
 
         @property
-        def _llm_type(self):
+        def _llm_type(self) -> str:
             return "mock"
 
     middleware = SummarizationMiddleware(
@@ -414,12 +497,12 @@ async def test_summarization_middleware_full_workflow_async() -> None:
     )
 
     # Mock high token count to trigger summarization
-    def mock_token_counter(messages):
+    def mock_token_counter(_: Iterable[MessageLikeRepresentation]) -> int:
         return 1500  # Above threshold
 
     middleware.token_counter = mock_token_counter
 
-    messages = [
+    messages: list[AnyMessage] = [
         HumanMessage(content="1"),
         HumanMessage(content="2"),
         HumanMessage(content="3"),
@@ -427,8 +510,8 @@ async def test_summarization_middleware_full_workflow_async() -> None:
         HumanMessage(content="5"),
     ]
 
-    state = {"messages": messages}
-    result = await middleware.abefore_model(state, None)
+    state = AgentState[Any](messages=messages)
+    result = await middleware.abefore_model(state, Runtime())
     assert result is not None
     assert "messages" in result
@@ -451,26 +534,26 @@ def test_summarization_middleware_keep_messages() -> None:
     )
 
     # Below threshold - no summarization
-    messages_below = [
+    messages_below: list[AnyMessage] = [
         HumanMessage(content="1"),
         HumanMessage(content="2"),
         HumanMessage(content="3"),
         HumanMessage(content="4"),
     ]
-    state_below = {"messages": messages_below}
-    result = middleware.before_model(state_below, None)
+    state_below = AgentState[Any](messages=messages_below)
+    result = middleware.before_model(state_below, Runtime())
     assert result is None
 
     # At threshold - should trigger summarization
-    messages_at_threshold = [
+    messages_at_threshold: list[AnyMessage] = [
         HumanMessage(content="1"),
         HumanMessage(content="2"),
         HumanMessage(content="3"),
         HumanMessage(content="4"),
         HumanMessage(content="5"),
     ]
-    state_at = {"messages": messages_at_threshold}
-    result = middleware.before_model(state_at, None)
+    state_at = AgentState[Any](messages=messages_at_threshold)
+    result = middleware.before_model(state_at, Runtime())
     assert result is not None
     assert "messages" in result
     expected_types = ["remove", "human", "human", "human"]
@@ -479,9 +562,9 @@ def test_summarization_middleware_keep_messages() -> None:
     assert [message.content for message in result["messages"][2:]] == ["4", "5"]
 
     # Above threshold - should also trigger summarization
-    messages_above = [*messages_at_threshold, HumanMessage(content="6")]
-    state_above = {"messages": messages_above}
-    result = middleware.before_model(state_above, None)
+    messages_above: list[AnyMessage] = [*messages_at_threshold, HumanMessage(content="6")]
+    state_above = AgentState[Any](messages=messages_above)
+    result = middleware.before_model(state_above, Runtime())
     assert result is not None
     assert "messages" in result
     expected_types = ["remove", "human", "human", "human"]
@@ -491,7 +574,7 @@ def test_summarization_middleware_keep_messages() -> None:
 
     # Test with both parameters disabled
     middleware_disabled = SummarizationMiddleware(model=MockChatModel(), trigger=None)
-    result = middleware_disabled.before_model(state_above, None)
+    result = middleware_disabled.before_model(state_above, Runtime())
     assert result is None
@@ -509,7 +592,7 @@ def test_summarization_middleware_keep_messages() -> None:
     ],
 )
 def test_summarization_middleware_validation_edge_cases(
-    param_name: str, param_value: tuple[str, float | int], expected_error: str
+    param_name: str, param_value: Any, expected_error: str
 ) -> None:
     """Test validation of context size parameters with edge cases."""
     model = FakeToolCallingModel()
@@ -527,31 +610,31 @@ def test_summarization_middleware_multiple_triggers() -> None:
     )
 
     # Mock token counter to return low count
-    def mock_low_tokens(messages):
+    def mock_low_tokens(_: Iterable[MessageLikeRepresentation]) -> int:
         return 100
 
     middleware.token_counter = mock_low_tokens
 
     # Should not trigger - neither condition met
-    messages = [HumanMessage(content=str(i)) for i in range(5)]
-    state = {"messages": messages}
-    result = middleware.before_model(state, None)
+    messages: list[AnyMessage] = [HumanMessage(content=str(i)) for i in range(5)]
+    state = AgentState[Any](messages=messages)
+    result = middleware.before_model(state, Runtime())
     assert result is None
 
     # Should trigger - message count threshold met
     messages = [HumanMessage(content=str(i)) for i in range(10)]
-    state = {"messages": messages}
-    result = middleware.before_model(state, None)
+    state = AgentState[Any](messages=messages)
+    result = middleware.before_model(state, Runtime())
     assert result is not None
 
     # Test token trigger
-    def mock_high_tokens(messages):
+    def mock_high_tokens(_: Iterable[MessageLikeRepresentation]) -> int:
         return 600
 
     middleware.token_counter = mock_high_tokens
 
     messages = [HumanMessage(content=str(i)) for i in range(5)]
-    state = {"messages": messages}
-    result = middleware.before_model(state, None)
+    state = AgentState[Any](messages=messages)
+    result = middleware.before_model(state, Runtime())
     assert result is not None
@@ -559,11 +642,18 @@ def test_summarization_middleware_profile_edge_cases() -> None:
     """Test profile retrieval with various edge cases."""
 
     class NoProfileModel(BaseChatModel):
-        def _generate(self, messages, **kwargs):
+        @override
+        def _generate(
+            self,
+            messages: list[BaseMessage],
+            stop: list[str] | None = None,
+            run_manager: CallbackManagerForLLMRun | None = None,
+            **kwargs: Any,
+        ) -> ChatResult:
             return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))])
 
         @property
-        def _llm_type(self):
+        def _llm_type(self) -> str:
             return "mock"
 
     # Model without profile attribute
@@ -571,49 +661,71 @@ def test_summarization_middleware_profile_edge_cases() -> None:
     assert middleware._get_profile_limits() is None
 
     class InvalidProfileModel(BaseChatModel):
-        def _generate(self, messages, **kwargs):
+        @override
+        def _generate(
+            self,
+            messages: list[BaseMessage],
+            stop: list[str] | None = None,
+            run_manager: CallbackManagerForLLMRun | None = None,
+            **kwargs: Any,
+        ) -> ChatResult:
             return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))])
 
         @property
-        def _llm_type(self):
+        def _llm_type(self) -> str:
             return "mock"
 
-        @property
-        def profile(self):
-            return "invalid_profile_type"
+        # NOTE: Using __getattribute__ because @property cannot override Pydantic fields.
+        def __getattribute__(self, name: str) -> Any:
+            if name == "profile":
+                return "invalid_profile_type"
+            return super().__getattribute__(name)
 
     # Model with non-dict profile
     middleware = SummarizationMiddleware(model=InvalidProfileModel(), trigger=("messages", 5))
     assert middleware._get_profile_limits() is None
 
     class MissingTokensModel(BaseChatModel):
-        def _generate(self, messages, **kwargs):
+        profile: ModelProfile | None = Field(default=ModelProfile(other_field=100), exclude=True)  # type: ignore[typeddict-unknown-key]
+
+        @override
+        def _generate(
+            self,
+            messages: list[BaseMessage],
+            stop: list[str] | None = None,
+            run_manager: CallbackManagerForLLMRun | None = None,
+            **kwargs: Any,
+        ) -> ChatResult:
             return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))])
 
         @property
-        def _llm_type(self):
+        def _llm_type(self) -> str:
             return "mock"
 
-        @property
-        def profile(self):
-            return {"other_field": 100}
-
     # Model with profile but no max_input_tokens
     middleware = SummarizationMiddleware(model=MissingTokensModel(), trigger=("messages", 5))
     assert middleware._get_profile_limits() is None
 
     class InvalidTokenTypeModel(BaseChatModel):
-        def _generate(self, messages, **kwargs):
+        profile: ModelProfile | None = Field(
+            default=ModelProfile(max_input_tokens="not_an_int"),  # type: ignore[typeddict-item]
+            exclude=True,
+        )
+
+        @override
+        def _generate(
+            self,
+            messages: list[BaseMessage],
+            stop: list[str] | None = None,
+            run_manager: CallbackManagerForLLMRun | None = None,
+            **kwargs: Any,
+        ) -> ChatResult:
             return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))])
 
         @property
-        def _llm_type(self):
+        def _llm_type(self) -> str:
             return "mock"
 
-        @property
-        def profile(self):
-            return {"max_input_tokens": "not_an_int"}
-
     # Model with non-int max_input_tokens
     middleware = SummarizationMiddleware(model=InvalidTokenTypeModel(), trigger=("messages", 5))
     assert middleware._get_profile_limits() is None
@@ -624,14 +736,14 @@ def test_summarization_middleware_trim_messages_error_fallback() -> None:
     middleware = SummarizationMiddleware(model=MockChatModel(), trigger=("messages", 5))
 
     # Create a mock token counter that raises an exception
-    def failing_token_counter(messages):
+    def failing_token_counter(_: Iterable[MessageLikeRepresentation]) -> int:
         msg = "Token counting failed"
         raise ValueError(msg)
 
     middleware.token_counter = failing_token_counter
 
     # Should fall back to last 15 messages
-    messages = [HumanMessage(content=str(i)) for i in range(20)]
+    messages: list[AnyMessage] = [HumanMessage(content=str(i)) for i in range(20)]
     trimmed = middleware._trim_messages_for_summary(messages)
     assert len(trimmed) == 15
     assert trimmed == messages[-15:]
@@ -644,12 +756,12 @@ def test_summarization_middleware_binary_search_edge_cases() -> None:
     )
 
     # Test with single message that's too large
-    def token_counter_single_large(messages):
-        return len(messages) * 200
+    def token_counter_single_large(messages: Iterable[MessageLikeRepresentation]) -> int:
+        return len(list(messages)) * 200
 
     middleware.token_counter = token_counter_single_large
 
-    single_message = [HumanMessage(content="x" * 200)]
+    single_message: list[AnyMessage] = [HumanMessage(content="x" * 200)]
     cutoff = middleware._find_token_based_cutoff(single_message)
     assert cutoff == 0
@@ -658,11 +770,11 @@ def test_summarization_middleware_binary_search_edge_cases() -> None:
     assert cutoff == 0
 
     # Test when all messages fit within token budget
-    def token_counter_small(messages):
-        return len(messages) * 10
+    def token_counter_small(messages: Iterable[MessageLikeRepresentation]) -> int:
+        return len(list(messages)) * 10
 
     middleware.token_counter = token_counter_small
 
-    messages = [HumanMessage(content=str(i)) for i in range(5)]
+    messages: list[AnyMessage] = [HumanMessage(content=str(i)) for i in range(5)]
     cutoff = middleware._find_token_based_cutoff(messages)
     assert cutoff == 0
@@ -774,11 +886,11 @@ def test_summarization_middleware_zero_and_negative_target_tokens() -> None:
     )
 
     # Should set threshold to 1 when calculated value is <= 0
-    messages = [HumanMessage(content="test")]
+    messages: list[AnyMessage] = [HumanMessage(content="test")]
 
     # The trigger fraction calculation: int(1000 * 0.0001) = 0, but should be set to 1
     # Token count of 1 message should exceed threshold of 1
-    def token_counter(msgs):
+    def token_counter(_: Iterable[MessageLikeRepresentation]) -> int:
         return 2
 
     middleware.token_counter = token_counter
@@ -789,19 +901,33 @@ async def test_summarization_middleware_async_error_handling() -> None:
     """Test async summary creation with errors."""
 
    class ErrorAsyncModel(BaseChatModel):
-        def _generate(self, messages, **kwargs):
+        @override
+        def _generate(
+            self,
+            messages: list[BaseMessage],
+            stop: list[str] | None = None,
+            run_manager: CallbackManagerForLLMRun | None = None,
+            **kwargs: Any,
+        ) -> ChatResult:
             return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))])
 
-        async def _agenerate(self, messages, **kwargs):
+        @override
+        async def _agenerate(
+            self,
+            messages: list[BaseMessage],
+            stop: list[str] | None = None,
+            run_manager: AsyncCallbackManagerForLLMRun | None = None,
+            **kwargs: Any,
+        ) -> ChatResult:
             msg = "Async model error"
             raise ValueError(msg)
 
         @property
-        def _llm_type(self):
+        def _llm_type(self) -> str:
             return "mock"
 
     middleware = SummarizationMiddleware(model=ErrorAsyncModel(), trigger=("messages", 5))
 
-    messages = [HumanMessage(content="test")]
+    messages: list[AnyMessage] = [HumanMessage(content="test")]
     summary = await middleware._acreate_summary(messages)
     assert "Error generating summary: Async model error" in summary
@@ -813,7 +939,7 @@ def test_summarization_middleware_cutoff_at_boundary() -> None:
     )
 
     # When we want to keep exactly as many messages as we have
-    messages = [HumanMessage(content=str(i)) for i in range(5)]
+    messages: list[AnyMessage] = [HumanMessage(content=str(i)) for i in range(5)]
     cutoff = middleware._find_safe_cutoff(messages, 5)
     assert cutoff == 0  # Should not cut anything
@@ -848,19 +974,14 @@ def test_summarization_middleware_fraction_trigger_with_no_profile() -> None:
     )
 
     # Test that when fractional condition can't be evaluated, other triggers still work
-    messages = [HumanMessage(content=str(i)) for i in range(100)]
+    messages: list[AnyMessage] = [HumanMessage(content=str(i)) for i in range(100)]
 
     # Mock _get_profile_limits to return None
-    original_method = middleware._get_profile_limits
-    middleware._get_profile_limits = lambda: None
-
-    # Should still trigger based on message count
-    state = {"messages": messages}
-    result = middleware.before_model(state, None)
-    assert result is not None
-
-    # Restore original method
-    middleware._get_profile_limits = original_method
+    with patch.object(middleware, "_get_profile_limits", autospec=True, return_value=None):
+        # Should still trigger based on message count
+        state = AgentState[Any](messages=messages)
+        result = middleware.before_model(state, Runtime())
+        assert result is not None
 
 
 def test_summarization_adjust_token_counts() -> None: