Compare commits

...

3 Commits

Author          SHA1        Message                                             Date
Chester Curme   2c938b787f  Merge branch 'master' into cc/summarization_patch   2025-12-02 10:07:51 -05:00
                            # Conflicts:
                            #	libs/langchain_v1/langchain/agents/middleware/summarization.py
Chester Curme   fa18f8eda0  Merge branch 'master' into cc/summarization_patch   2025-12-01 09:39:54 -05:00
Chester Curme   b2db842cd4  treat keep threshold as a hard cap                  2025-11-24 11:21:28 -05:00
2 changed files with 66 additions and 13 deletions

libs/langchain_v1/langchain/agents/middleware/summarization.py View File

@@ -3,7 +3,7 @@
 import uuid
 import warnings
 from collections.abc import Callable, Iterable, Mapping
-from functools import partial
+from functools import cache, partial
 from typing import Any, Literal, cast

 from langchain_core.messages import (
@@ -356,6 +356,10 @@ class SummarizationMiddleware(AgentMiddleware):
         if not messages:
             return 0

+        @cache
+        def suffix_token_count(start_index: int) -> int:
+            return self.token_counter(messages[start_index:])
+
         kind, value = self.keep
         if kind == "fraction":
             max_input_tokens = self._get_profile_limits()
@@ -370,7 +374,7 @@ class SummarizationMiddleware(AgentMiddleware):
         if target_token_count <= 0:
             target_token_count = 1

-        if self.token_counter(messages) <= target_token_count:
+        if suffix_token_count(0) <= target_token_count:
             return 0

         # Use binary search to identify the earliest message index that keeps the
@@ -383,7 +387,7 @@ class SummarizationMiddleware(AgentMiddleware):
                 break
             mid = (left + right) // 2
-            if self.token_counter(messages[mid:]) <= target_token_count:
+            if suffix_token_count(mid) <= target_token_count:
                 cutoff_candidate = mid
                 right = mid
             else:
                 left = mid + 1
@@ -397,8 +401,11 @@ class SummarizationMiddleware(AgentMiddleware):
             return 0

-        cutoff_candidate = len(messages) - 1
-        for i in range(cutoff_candidate, -1, -1):
-            if self._is_safe_cutoff_point(messages, i):
+        for i in range(cutoff_candidate, len(messages) + 1):
+            if (
+                self._is_safe_cutoff_point(messages, i)
+                and suffix_token_count(i) <= target_token_count
+            ):
                 return i
         return 0
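
Taken together, the patch memoizes suffix token counts and replaces the old backward walk (which could keep more than the target to avoid splitting a tool pair) with a forward walk from the binary-search candidate, so the keep threshold now acts as a hard cap. Below is a minimal, self-contained sketch of the resulting algorithm; the messages-as-strings model and the token_counter / is_safe_cutoff_point stand-ins are assumptions for illustration, not the real middleware API, which lives on SummarizationMiddleware.

from functools import cache
from typing import Callable


def find_cutoff_index(
    messages: list[str],
    token_counter: Callable[[list[str]], int],
    is_safe_cutoff_point: Callable[[list[str], int], bool],
    target_token_count: int,
) -> int:
    # Sketch: earliest index i such that messages[i:] fits within
    # target_token_count and cutting at i is "safe" (e.g. does not
    # separate an AI message from its tool results).
    if not messages:
        return 0

    @cache
    def suffix_token_count(start_index: int) -> int:
        # Memoized so the binary search and the forward walk never
        # recount the same suffix twice.
        return token_counter(messages[start_index:])

    if suffix_token_count(0) <= target_token_count:
        return 0  # everything already fits; keep all messages

    # Binary search for the earliest index whose suffix fits the target.
    left, right = 0, len(messages)
    cutoff_candidate = None
    while left < right:
        mid = (left + right) // 2
        if suffix_token_count(mid) <= target_token_count:
            cutoff_candidate = mid
            right = mid
        else:
            left = mid + 1

    if cutoff_candidate is None:
        return 0

    # Hard cap: walk *forward* (keeping fewer messages) until the cutoff
    # is safe, never backward past the token target.
    for i in range(cutoff_candidate, len(messages) + 1):
        if is_safe_cutoff_point(messages, i) and suffix_token_count(i) <= target_token_count:
            return i
    return 0

Walking forward trades away a few extra retained messages for a guarantee that the kept history never exceeds the target, which is exactly what the updated test assertions check.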

View File

@@ -1,10 +1,17 @@
-from typing import TYPE_CHECKING
+from typing import Iterable
 from unittest.mock import patch

 import pytest
 from langchain_core.language_models import ModelProfile
 from langchain_core.language_models.chat_models import BaseChatModel
-from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, RemoveMessage, ToolMessage
+from langchain_core.messages import (
+    AIMessage,
+    AnyMessage,
+    HumanMessage,
+    MessageLikeRepresentation,
+    RemoveMessage,
+    ToolMessage,
+)
 from langchain_core.outputs import ChatGeneration, ChatResult
 from langgraph.graph.message import REMOVE_ALL_MESSAGES
@@ -316,7 +323,7 @@ def test_summarization_middleware_profile_inference_triggers_summary() -> None:
 def test_summarization_middleware_token_retention_pct_respects_tool_pairs() -> None:
-    """Ensure token retention keeps pairs together even if exceeding target tokens."""
+    """Ensure token retention never splits tool pairs while enforcing hard caps."""

     def token_counter(messages: list[AnyMessage]) -> int:
         return sum(len(getattr(message, "content", "")) for message in messages)
@@ -344,13 +351,13 @@ def test_summarization_middleware_token_retention_pct_respects_tool_pairs() -> None:
     assert result is not None
     preserved_messages = result["messages"][2:]
-    assert preserved_messages == messages[1:]
+    assert preserved_messages == messages[3:]
+    assert token_counter(preserved_messages) <= 500
+    assert not any(isinstance(msg, (AIMessage, ToolMessage)) for msg in preserved_messages)

-    target_token_count = int(1000 * 0.5)
     preserved_tokens = middleware.token_counter(preserved_messages)
-    # Tool pair retention can exceed the target token count but should keep the pair intact.
-    assert preserved_tokens > target_token_count
+    target_token_count = int(1000 * 0.5)
+    assert preserved_tokens <= target_token_count


 def test_summarization_middleware_missing_profile() -> None:
@@ -783,6 +790,45 @@ def test_summarization_middleware_tool_call_in_search_range() -> None:
     assert middleware._is_safe_cutoff_point(messages, 1)


+def test_summarization_middleware_results_under_window() -> None:
+    """Ensure summarized results stay under the model's context window."""
+
+    def _token_counter(messages: Iterable[MessageLikeRepresentation]) -> int:
+        count = 0
+        for message in messages:
+            if isinstance(message, ToolMessage):
+                count = count + 500
+            else:
+                count = count + 100
+        return count
+
+    state = {
+        "messages": [
+            HumanMessage(content="Message 1"),
+            AIMessage(
+                content="Message 2",
+                tool_calls=[
+                    {"name": "test", "args": {}, "id": "call-1"},
+                    {"name": "test", "args": {}, "id": "call-2"},
+                ],
+            ),
+            ToolMessage(content="Result 2-1", tool_call_id="call-1"),
+            ToolMessage(content="Result 2-2", tool_call_id="call-2"),
+        ]
+    }
+
+    middleware = SummarizationMiddleware(
+        model=ProfileChatModel(),
+        trigger=("fraction", 0.80),
+        keep=("fraction", 0.5),
+        token_counter=_token_counter,
+    )
+
+    result = middleware.before_model(state, None)
+    assert result is not None
+
+    count_after_summarization = _token_counter(result["messages"])
+    assert count_after_summarization <= 1000  # max_input_tokens of ProfileChatModel
+
+
 def test_summarization_middleware_zero_and_negative_target_tokens() -> None:
     """Test handling of edge cases with target token calculations."""
     # Test with very small fraction that rounds to zero
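
To make the hard-cap change concrete, here is the arithmetic the new test drives through the middleware, assuming ProfileChatModel advertises max_input_tokens=1000 (an assumption read off the final assertion's comment):

# Worked numbers from the new test; standalone, not part of the diff.
max_input_tokens = 1000                            # assumed from ProfileChatModel's profile
trigger_threshold = int(max_input_tokens * 0.80)   # trigger=("fraction", 0.80) -> 800
target_token_count = int(max_input_tokens * 0.5)   # keep=("fraction", 0.5)     -> 500

# _token_counter charges 100 per non-tool message and 500 per ToolMessage:
history_tokens = 100 + 100 + 500 + 500             # Human + AI + two tool results = 1200

assert history_tokens > trigger_threshold          # 1200 > 800, so summarization fires
# Keeping the AI message with its two tool results intact would retain
# 100 + 500 + 500 = 1100 tokens, far over the 500-token cap. The patched
# cutoff search therefore walks forward past the pair, and the summarized
# result stays under max_input_tokens, as the test's final assertion requires.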