mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-18 21:11:06 +00:00
Compare commits
7 Commits
mdrxy/vers
...
sr/looser-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6ba1177f4f | ||
|
|
62c05e09c1 | ||
|
|
83b9d9f810 | ||
|
|
6deee23d8d | ||
|
|
690aabe8d4 | ||
|
|
80554df1e6 | ||
|
|
72c45e65e8 |
@@ -58,10 +58,19 @@ _DEFAULT_FALLBACK_MESSAGE_COUNT = 15
|
||||
_SEARCH_RANGE_FOR_TOOL_PAIRS = 5
|
||||
|
||||
ContextFraction = tuple[Literal["fraction"], float]
|
||||
"""Tuple specifying context size as a fraction of the model's context window."""
|
||||
ContextTokens = tuple[Literal["tokens"], int]
|
||||
"""Tuple specifying context size as a number of tokens."""
|
||||
ContextMessages = tuple[Literal["messages"], int]
|
||||
"""Tuple specifying context size as a number of messages."""
|
||||
|
||||
ContextSize = ContextFraction | ContextTokens | ContextMessages
|
||||
"""Context size tuple to specify how much history to preserve."""
|
||||
|
||||
ContextCondition = ContextSize | list[ContextSize | list[ContextSize]]
|
||||
"""Recursive type to support nested AND/OR conditions
|
||||
|
||||
Top-level list = OR logic, nested list = AND logic."""
|
||||
|
||||
|
||||
class SummarizationMiddleware(AgentMiddleware):
|
||||
@@ -76,7 +85,7 @@ class SummarizationMiddleware(AgentMiddleware):
|
||||
self,
|
||||
model: str | BaseChatModel,
|
||||
*,
|
||||
trigger: ContextSize | list[ContextSize] | None = None,
|
||||
trigger: ContextCondition | None = None,
|
||||
keep: ContextSize = ("messages", _DEFAULT_MESSAGES_TO_KEEP),
|
||||
token_counter: TokenCounter = count_tokens_approximately,
|
||||
summary_prompt: str = DEFAULT_SUMMARY_PROMPT,
|
||||
@@ -87,13 +96,16 @@ class SummarizationMiddleware(AgentMiddleware):
|
||||
|
||||
Args:
|
||||
model: The language model to use for generating summaries.
|
||||
trigger: One or more thresholds that trigger summarization.
|
||||
|
||||
Provide a single `ContextSize` tuple or a list of tuples, in which case
|
||||
summarization runs when any threshold is breached.
|
||||
|
||||
Examples: `("messages", 50)`, `("tokens", 3000)`, `[("fraction", 0.8),
|
||||
("messages", 100)]`.
|
||||
trigger: One or more thresholds that trigger summarization. Supports flexible
|
||||
AND/OR logic via nested lists. Top-level list items are combined with OR,
|
||||
nested lists are combined with AND. Examples:
|
||||
- Single condition: `("messages", 50)`
|
||||
- OR conditions: `[("tokens", 3000), ("messages", 100)]` (triggers when
|
||||
tokens >= 3000 OR messages >= 100)
|
||||
- AND conditions: `[("tokens", 500), ("fraction", 0.8)]` as a nested list
|
||||
within the top-level list
|
||||
- Mixed AND/OR: `[("messages", 10), [("tokens", 500), ("fraction", 0.8)]]`
|
||||
(triggers when messages >= 10 OR (tokens >= 500 AND fraction >= 0.8))
|
||||
keep: Context retention policy applied after summarization.
|
||||
|
||||
Provide a `ContextSize` tuple to specify how much history to preserve.
|
||||
@@ -137,13 +149,15 @@ class SummarizationMiddleware(AgentMiddleware):
|
||||
|
||||
self.model = model
|
||||
if trigger is None:
|
||||
self.trigger: ContextSize | list[ContextSize] | None = None
|
||||
trigger_conditions: list[ContextSize] = []
|
||||
self.trigger: ContextCondition | None = None
|
||||
trigger_conditions: list[ContextSize | list[ContextSize]] = []
|
||||
elif isinstance(trigger, list):
|
||||
validated_list = [self._validate_context_size(item, "trigger") for item in trigger]
|
||||
# Validate and normalize nested structure
|
||||
validated_list = self._validate_trigger_conditions(trigger)
|
||||
self.trigger = validated_list
|
||||
trigger_conditions = validated_list
|
||||
else:
|
||||
# Single ContextSize tuple
|
||||
validated = self._validate_context_size(trigger, "trigger")
|
||||
self.trigger = validated
|
||||
trigger_conditions = [validated]
|
||||
@@ -154,7 +168,7 @@ class SummarizationMiddleware(AgentMiddleware):
|
||||
self.summary_prompt = summary_prompt
|
||||
self.trim_tokens_to_summarize = trim_tokens_to_summarize
|
||||
|
||||
requires_profile = any(condition[0] == "fraction" for condition in self._trigger_conditions)
|
||||
requires_profile = self._requires_profile(self._trigger_conditions)
|
||||
if self.keep[0] == "fraction":
|
||||
requires_profile = True
|
||||
if requires_profile and self._get_profile_limits() is None:
|
||||
@@ -220,24 +234,52 @@ class SummarizationMiddleware(AgentMiddleware):
|
||||
}
|
||||
|
||||
def _should_summarize(self, messages: list[AnyMessage], total_tokens: int) -> bool:
|
||||
"""Determine whether summarization should run for the current token usage."""
|
||||
"""Determine whether summarization should run for the current token usage.
|
||||
|
||||
Evaluates trigger conditions with AND/OR logic:
|
||||
- Top-level items are OR'd together
|
||||
- Nested lists are AND'd together
|
||||
"""
|
||||
if not self._trigger_conditions:
|
||||
return False
|
||||
|
||||
for kind, value in self._trigger_conditions:
|
||||
if kind == "messages" and len(messages) >= value:
|
||||
return True
|
||||
if kind == "tokens" and total_tokens >= value:
|
||||
return True
|
||||
if kind == "fraction":
|
||||
max_input_tokens = self._get_profile_limits()
|
||||
if max_input_tokens is None:
|
||||
continue
|
||||
threshold = int(max_input_tokens * value)
|
||||
if threshold <= 0:
|
||||
threshold = 1
|
||||
if total_tokens >= threshold:
|
||||
# OR logic across top-level conditions
|
||||
for condition in self._trigger_conditions:
|
||||
if isinstance(condition, list):
|
||||
# AND group - all must be satisfied
|
||||
if self._check_and_group(condition, messages, total_tokens):
|
||||
return True
|
||||
elif self._check_single_condition(condition, messages, total_tokens):
|
||||
# Single condition
|
||||
return True
|
||||
return False
|
||||
|
||||
def _check_and_group(
|
||||
self, and_group: list[ContextSize], messages: list[AnyMessage], total_tokens: int
|
||||
) -> bool:
|
||||
"""Check if all conditions in an AND group are satisfied."""
|
||||
for condition in and_group:
|
||||
if not self._check_single_condition(condition, messages, total_tokens):
|
||||
return False
|
||||
return True
|
||||
|
||||
def _check_single_condition(
|
||||
self, condition: ContextSize, messages: list[AnyMessage], total_tokens: int
|
||||
) -> bool:
|
||||
"""Check if a single condition is satisfied."""
|
||||
kind, value = condition
|
||||
if kind == "messages":
|
||||
return len(messages) >= value
|
||||
if kind == "tokens":
|
||||
return total_tokens >= value
|
||||
if kind == "fraction":
|
||||
max_input_tokens = self._get_profile_limits()
|
||||
if max_input_tokens is None:
|
||||
return False
|
||||
threshold = int(max_input_tokens * value)
|
||||
if threshold <= 0:
|
||||
threshold = 1
|
||||
return total_tokens >= threshold
|
||||
return False
|
||||
|
||||
def _determine_cutoff_index(self, messages: list[AnyMessage]) -> int:
|
||||
@@ -337,6 +379,45 @@ class SummarizationMiddleware(AgentMiddleware):
|
||||
raise ValueError(msg)
|
||||
return context
|
||||
|
||||
def _validate_trigger_conditions(
|
||||
self, conditions: list[Any]
|
||||
) -> list[ContextSize | list[ContextSize]]:
|
||||
"""Validate and normalize trigger conditions with nested AND/OR logic.
|
||||
|
||||
Args:
|
||||
conditions: List of ContextSize tuples or nested lists of ContextSize tuples.
|
||||
|
||||
Returns:
|
||||
Validated list where top-level items are OR'd and nested lists are AND'd.
|
||||
"""
|
||||
validated: list[ContextSize | list[ContextSize]] = []
|
||||
for item in conditions:
|
||||
if isinstance(item, tuple):
|
||||
# Single condition (tuple)
|
||||
validated.append(self._validate_context_size(item, "trigger"))
|
||||
elif isinstance(item, list):
|
||||
# AND group (nested list)
|
||||
if not item:
|
||||
msg = "Empty AND groups are not allowed in trigger conditions."
|
||||
raise ValueError(msg)
|
||||
and_group = [self._validate_context_size(cond, "trigger") for cond in item]
|
||||
validated.append(and_group)
|
||||
else:
|
||||
msg = f"Trigger conditions must be tuples or lists, got {type(item).__name__}."
|
||||
raise ValueError(msg)
|
||||
return validated
|
||||
|
||||
def _requires_profile(self, conditions: list[ContextSize | list[ContextSize]]) -> bool:
|
||||
"""Check if any condition requires model profile information."""
|
||||
for condition in conditions:
|
||||
if isinstance(condition, list):
|
||||
# AND group
|
||||
if any(c[0] == "fraction" for c in condition):
|
||||
return True
|
||||
elif condition[0] == "fraction":
|
||||
return True
|
||||
return False
|
||||
|
||||
def _build_new_messages(self, summary: str) -> list[HumanMessage]:
|
||||
return [
|
||||
HumanMessage(content=f"Here is a summary of the conversation to date:\n\n{summary}")
|
||||
|
||||
@@ -896,3 +896,171 @@ def test_summarization_middleware_is_safe_cutoff_at_end() -> None:
|
||||
|
||||
# Cutoff past the length should also be safe
|
||||
assert middleware._is_safe_cutoff_point(messages, len(messages) + 5)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("trigger", "model_class", "tokens_per_message", "num_messages", "should_trigger"),
|
||||
[
|
||||
# ========================================================================
|
||||
# Test Case Group 1: Pure AND logic
|
||||
# Trigger condition: messages >= 5 AND tokens >= 300
|
||||
# ========================================================================
|
||||
(
|
||||
[[("messages", 5), ("tokens", 300)]],
|
||||
MockChatModel,
|
||||
50, # tokens_per_message
|
||||
4, # num_messages -> total: 4 msgs, 200 tokens
|
||||
False, # Expected: neither condition met
|
||||
),
|
||||
(
|
||||
[[("messages", 5), ("tokens", 300)]],
|
||||
MockChatModel,
|
||||
50, # tokens_per_message
|
||||
5, # num_messages -> total: 5 msgs, 250 tokens
|
||||
False, # Expected: only messages condition met (tokens < 300)
|
||||
),
|
||||
(
|
||||
[[("messages", 5), ("tokens", 300)]],
|
||||
MockChatModel,
|
||||
50, # tokens_per_message
|
||||
7, # num_messages -> total: 7 msgs, 350 tokens
|
||||
True, # Expected: both conditions met
|
||||
),
|
||||
# ========================================================================
|
||||
# Test Case Group 2: Mixed OR/AND logic
|
||||
# Trigger condition: messages >= 10 OR (tokens >= 500 AND messages >= 3)
|
||||
# ========================================================================
|
||||
(
|
||||
[("messages", 10), [("tokens", 500), ("messages", 3)]],
|
||||
MockChatModel,
|
||||
100, # tokens_per_message
|
||||
2, # num_messages -> total: 2 msgs, 200 tokens
|
||||
False, # Expected: no conditions met
|
||||
),
|
||||
(
|
||||
[("messages", 10), [("tokens", 500), ("messages", 3)]],
|
||||
MockChatModel,
|
||||
100, # tokens_per_message
|
||||
5, # num_messages -> total: 5 msgs, 500 tokens
|
||||
True, # Expected: AND group satisfied (tokens >= 500 AND messages >= 3)
|
||||
),
|
||||
(
|
||||
[("messages", 10), [("tokens", 500), ("messages", 3)]],
|
||||
MockChatModel,
|
||||
10, # tokens_per_message
|
||||
10, # num_messages -> total: 10 msgs, 100 tokens
|
||||
True, # Expected: first OR condition satisfied (messages >= 10)
|
||||
),
|
||||
# ========================================================================
|
||||
# Test Case Group 3: AND logic with fraction threshold
|
||||
# Trigger condition: tokens >= 400 AND fraction >= 0.5
|
||||
# Note: ProfileChatModel has max_input_tokens=1000, so 0.5 * 1000 = 500
|
||||
# ========================================================================
|
||||
(
|
||||
[[("tokens", 400), ("fraction", 0.5)]],
|
||||
ProfileChatModel,
|
||||
100, # tokens_per_message
|
||||
4, # num_messages -> total: 400 tokens
|
||||
False, # Expected: 400 tokens < 500 fraction threshold
|
||||
),
|
||||
(
|
||||
[[("tokens", 400), ("fraction", 0.5)]],
|
||||
ProfileChatModel,
|
||||
100, # tokens_per_message
|
||||
6, # num_messages -> total: 600 tokens
|
||||
True, # Expected: 600 >= 400 AND 600 >= 500 (both conditions met)
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_summarization_middleware_and_or_logic(
|
||||
trigger: list,
|
||||
model_class: type[BaseChatModel],
|
||||
tokens_per_message: int,
|
||||
num_messages: int,
|
||||
should_trigger: bool,
|
||||
) -> None:
|
||||
"""Test middleware with AND/OR logic combinations."""
|
||||
middleware = SummarizationMiddleware(
|
||||
model=model_class(),
|
||||
trigger=trigger,
|
||||
keep=("messages", 2),
|
||||
)
|
||||
middleware.token_counter = lambda messages: len(messages) * tokens_per_message
|
||||
|
||||
messages = [HumanMessage(content=str(i)) for i in range(num_messages)]
|
||||
state = {"messages": messages}
|
||||
result = middleware.before_model(state, None)
|
||||
|
||||
if should_trigger:
|
||||
assert result is not None
|
||||
else:
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_summarization_middleware_nested_validation() -> None:
|
||||
"""Test validation of nested AND groups."""
|
||||
model = MockChatModel()
|
||||
|
||||
# Valid nested list
|
||||
middleware = SummarizationMiddleware(
|
||||
model=model, trigger=[("messages", 5), [("tokens", 100), ("messages", 3)]]
|
||||
)
|
||||
assert middleware.trigger is not None
|
||||
|
||||
# Empty AND group should raise error
|
||||
with pytest.raises(ValueError, match="Empty AND groups are not allowed"):
|
||||
SummarizationMiddleware(model=model, trigger=[("messages", 5), []])
|
||||
|
||||
# Invalid type in trigger list
|
||||
with pytest.raises(ValueError, match="Trigger conditions must be tuples or lists"):
|
||||
SummarizationMiddleware(model=model, trigger=[("messages", 5), "invalid"])
|
||||
|
||||
# Invalid tuple in AND group
|
||||
with pytest.raises(ValueError, match="Unsupported context size type"):
|
||||
SummarizationMiddleware(model=model, trigger=[[("invalid", 100)]])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("tokens_per_message", "num_messages", "should_trigger"),
|
||||
[
|
||||
(100, 2, False), # 2 messages, 200 tokens - no conditions met
|
||||
(
|
||||
100,
|
||||
5,
|
||||
True,
|
||||
), # 5 messages, 500 tokens - second AND group satisfied (tokens >= 300 AND messages >= 5)
|
||||
(100, 4, False), # 4 messages, 400 tokens - tokens < 500 for third AND group
|
||||
(5, 20, True), # 20 messages, 100 tokens - first OR condition satisfied
|
||||
],
|
||||
)
|
||||
def test_summarization_middleware_complex_and_or(
|
||||
tokens_per_message: int,
|
||||
num_messages: int,
|
||||
should_trigger: bool,
|
||||
) -> None:
|
||||
"""Test complex AND/OR combinations.
|
||||
|
||||
Trigger when:
|
||||
- messages >= 20 OR
|
||||
- (tokens >= 300 AND messages >= 5) OR
|
||||
- (tokens >= 500 AND messages >= 3)
|
||||
"""
|
||||
middleware = SummarizationMiddleware(
|
||||
model=MockChatModel(),
|
||||
trigger=[
|
||||
("messages", 20),
|
||||
[("tokens", 300), ("messages", 5)],
|
||||
[("tokens", 500), ("messages", 3)],
|
||||
],
|
||||
keep=("messages", 2),
|
||||
)
|
||||
middleware.token_counter = lambda messages: len(messages) * tokens_per_message
|
||||
|
||||
messages = [HumanMessage(content=str(i)) for i in range(num_messages)]
|
||||
state = {"messages": messages}
|
||||
result = middleware.before_model(state, None)
|
||||
|
||||
if should_trigger:
|
||||
assert result is not None
|
||||
else:
|
||||
assert result is None
|
||||
|
||||
Reference in New Issue
Block a user