Compare commits

...

3 Commits

Author SHA1 Message Date
Chester Curme
2c938b787f Merge branch 'master' into cc/summarization_patch
# Conflicts:
#	libs/langchain_v1/langchain/agents/middleware/summarization.py
2025-12-02 10:07:51 -05:00
Chester Curme
fa18f8eda0 Merge branch 'master' into cc/summarization_patch 2025-12-01 09:39:54 -05:00
Chester Curme
b2db842cd4 treat keep threshold as a hard cap 2025-11-24 11:21:28 -05:00
2 changed files with 66 additions and 13 deletions

View File

@@ -3,7 +3,7 @@
import uuid import uuid
import warnings import warnings
from collections.abc import Callable, Iterable, Mapping from collections.abc import Callable, Iterable, Mapping
from functools import partial from functools import cache, partial
from typing import Any, Literal, cast from typing import Any, Literal, cast
from langchain_core.messages import ( from langchain_core.messages import (
@@ -356,6 +356,10 @@ class SummarizationMiddleware(AgentMiddleware):
if not messages: if not messages:
return 0 return 0
@cache
def suffix_token_count(start_index: int) -> int:
return self.token_counter(messages[start_index:])
kind, value = self.keep kind, value = self.keep
if kind == "fraction": if kind == "fraction":
max_input_tokens = self._get_profile_limits() max_input_tokens = self._get_profile_limits()
@@ -370,7 +374,7 @@ class SummarizationMiddleware(AgentMiddleware):
if target_token_count <= 0: if target_token_count <= 0:
target_token_count = 1 target_token_count = 1
if self.token_counter(messages) <= target_token_count: if suffix_token_count(0) <= target_token_count:
return 0 return 0
# Use binary search to identify the earliest message index that keeps the # Use binary search to identify the earliest message index that keeps the
@@ -383,7 +387,7 @@ class SummarizationMiddleware(AgentMiddleware):
break break
mid = (left + right) // 2 mid = (left + right) // 2
if self.token_counter(messages[mid:]) <= target_token_count: if suffix_token_count(mid) <= target_token_count:
cutoff_candidate = mid cutoff_candidate = mid
right = mid right = mid
else: else:
@@ -397,8 +401,11 @@ class SummarizationMiddleware(AgentMiddleware):
return 0 return 0
cutoff_candidate = len(messages) - 1 cutoff_candidate = len(messages) - 1
for i in range(cutoff_candidate, -1, -1): for i in range(cutoff_candidate, len(messages) + 1):
if self._is_safe_cutoff_point(messages, i): if (
self._is_safe_cutoff_point(messages, i)
and suffix_token_count(i) <= target_token_count
):
return i return i
return 0 return 0

View File

@@ -1,10 +1,17 @@
from typing import TYPE_CHECKING from typing import Iterable
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from langchain_core.language_models import ModelProfile from langchain_core.language_models import ModelProfile
from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, RemoveMessage, ToolMessage from langchain_core.messages import (
AIMessage,
AnyMessage,
HumanMessage,
MessageLikeRepresentation,
RemoveMessage,
ToolMessage,
)
from langchain_core.outputs import ChatGeneration, ChatResult from langchain_core.outputs import ChatGeneration, ChatResult
from langgraph.graph.message import REMOVE_ALL_MESSAGES from langgraph.graph.message import REMOVE_ALL_MESSAGES
@@ -316,7 +323,7 @@ def test_summarization_middleware_profile_inference_triggers_summary() -> None:
def test_summarization_middleware_token_retention_pct_respects_tool_pairs() -> None: def test_summarization_middleware_token_retention_pct_respects_tool_pairs() -> None:
"""Ensure token retention keeps pairs together even if exceeding target tokens.""" """Ensure token retention never splits tool pairs while enforcing hard caps."""
def token_counter(messages: list[AnyMessage]) -> int: def token_counter(messages: list[AnyMessage]) -> int:
return sum(len(getattr(message, "content", "")) for message in messages) return sum(len(getattr(message, "content", "")) for message in messages)
@@ -344,13 +351,13 @@ def test_summarization_middleware_token_retention_pct_respects_tool_pairs() -> N
assert result is not None assert result is not None
preserved_messages = result["messages"][2:] preserved_messages = result["messages"][2:]
assert preserved_messages == messages[1:] assert preserved_messages == messages[3:]
assert token_counter(preserved_messages) <= 500
assert not any(isinstance(msg, (AIMessage, ToolMessage)) for msg in preserved_messages)
target_token_count = int(1000 * 0.5)
preserved_tokens = middleware.token_counter(preserved_messages) preserved_tokens = middleware.token_counter(preserved_messages)
target_token_count = int(1000 * 0.5)
# Tool pair retention can exceed the target token count but should keep the pair intact. assert preserved_tokens <= target_token_count
assert preserved_tokens > target_token_count
def test_summarization_middleware_missing_profile() -> None: def test_summarization_middleware_missing_profile() -> None:
@@ -783,6 +790,45 @@ def test_summarization_middleware_tool_call_in_search_range() -> None:
assert middleware._is_safe_cutoff_point(messages, 1) assert middleware._is_safe_cutoff_point(messages, 1)
def test_summarization_middleware_results_under_window() -> None:
"""Ensure summarization keeps the retained messages under the model's context window."""
def _token_counter(messages: Iterable[MessageLikeRepresentation]) -> int:
count = 0
for message in messages:
if isinstance(message, ToolMessage):
count = count + 500
else:
count = count + 100
return count
state = {
"messages": [
HumanMessage(content="Message 1"),
AIMessage(
content="Message 2",
tool_calls=[
{"name": "test", "args": {}, "id": "call-1"},
{"name": "test", "args": {}, "id": "call-2"},
],
),
ToolMessage(content="Result 2-1", tool_call_id="call-1"),
ToolMessage(content="Result 2-2", tool_call_id="call-2"),
]
}
middleware = SummarizationMiddleware(
model=ProfileChatModel(),
trigger=("fraction", 0.80),
keep=("fraction", 0.5),
token_counter=_token_counter,
)
result = middleware.before_model(state, None)
assert result is not None
count_after_summarization = _token_counter(result["messages"])
assert count_after_summarization <= 1000 # max_input_tokens of ProfileChatModel
def test_summarization_middleware_zero_and_negative_target_tokens() -> None: def test_summarization_middleware_zero_and_negative_target_tokens() -> None:
"""Test handling of edge cases with target token calculations.""" """Test handling of edge cases with target token calculations."""
# Test with very small fraction that rounds to zero # Test with very small fraction that rounds to zero