mirror of
https://github.com/hwchase17/langchain.git
synced 2026-04-16 18:02:57 +00:00
Compare commits
7 Commits
jacob/trac
...
sr/refine-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
03373459dc | ||
|
|
62c05e09c1 | ||
|
|
83b9d9f810 | ||
|
|
6deee23d8d | ||
|
|
690aabe8d4 | ||
|
|
80554df1e6 | ||
|
|
72c45e65e8 |
@@ -1,5 +1,12 @@
|
||||
"""Entrypoint to using [middleware](https://docs.langchain.com/oss/python/langchain/middleware) plugins with [Agents](https://docs.langchain.com/oss/python/langchain/agents).""" # noqa: E501
|
||||
|
||||
from ._context import (
|
||||
ContextCondition,
|
||||
ContextFraction,
|
||||
ContextMessages,
|
||||
ContextSize,
|
||||
ContextTokens,
|
||||
)
|
||||
from .context_editing import (
|
||||
ClearToolUsesEdit,
|
||||
ContextEditingMiddleware,
|
||||
@@ -45,7 +52,12 @@ __all__ = [
|
||||
"AgentState",
|
||||
"ClearToolUsesEdit",
|
||||
"CodexSandboxExecutionPolicy",
|
||||
"ContextCondition",
|
||||
"ContextEditingMiddleware",
|
||||
"ContextFraction",
|
||||
"ContextMessages",
|
||||
"ContextSize",
|
||||
"ContextTokens",
|
||||
"DockerExecutionPolicy",
|
||||
"FilesystemFileSearchMiddleware",
|
||||
"HostExecutionPolicy",
|
||||
|
||||
30
libs/langchain_v1/langchain/agents/middleware/_context.py
Normal file
30
libs/langchain_v1/langchain/agents/middleware/_context.py
Normal file
@@ -0,0 +1,30 @@
|
||||
"""Shared context size types for middleware that manages message history."""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
ContextFraction = tuple[Literal["fraction"], float]
|
||||
"""Tuple specifying context size as a fraction of the model's context window."""
|
||||
|
||||
ContextTokens = tuple[Literal["tokens"], int]
|
||||
"""Tuple specifying context size as a number of tokens."""
|
||||
|
||||
ContextMessages = tuple[Literal["messages"], int]
|
||||
"""Tuple specifying context size as a number of messages."""
|
||||
|
||||
ContextSize = ContextFraction | ContextTokens | ContextMessages
|
||||
"""Context size tuple to specify how much history to preserve or trigger conditions."""
|
||||
|
||||
ContextCondition = ContextSize | list[ContextSize | list[ContextSize]]
|
||||
"""Recursive type to support nested AND/OR conditions.
|
||||
|
||||
Top-level list = OR logic, nested list = AND logic.
|
||||
"""
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ContextCondition",
|
||||
"ContextFraction",
|
||||
"ContextMessages",
|
||||
"ContextSize",
|
||||
"ContextTokens",
|
||||
]
|
||||
@@ -9,9 +9,9 @@ chat model.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Awaitable, Callable, Iterable, Sequence
|
||||
from collections.abc import Awaitable, Callable, Iterable, Mapping, Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
@@ -30,6 +30,9 @@ from langchain.agents.middleware.types import (
|
||||
ModelResponse,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain.agents.middleware._context import ContextCondition, ContextSize
|
||||
|
||||
DEFAULT_TOOL_PLACEHOLDER = "[cleared]"
|
||||
|
||||
|
||||
@@ -54,16 +57,37 @@ class ContextEdit(Protocol):
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ClearToolUsesEdit(ContextEdit):
|
||||
"""Configuration for clearing tool outputs when token limits are exceeded."""
|
||||
"""Configuration for clearing tool outputs when token limits are exceeded.
|
||||
|
||||
trigger: int = 100_000
|
||||
"""Token count that triggers the edit."""
|
||||
Supports flexible trigger and keep configurations using `ContextSize` tuples or
|
||||
backwards-compatible integer values.
|
||||
"""
|
||||
|
||||
trigger: ContextCondition | int = 100_000
|
||||
"""Trigger condition(s) for when the edit should run.
|
||||
|
||||
Supports flexible AND/OR logic via nested lists:
|
||||
- Single condition: `("messages", 50)` or `("tokens", 3000)`
|
||||
- OR conditions: `[("tokens", 3000), ("messages", 100)]`
|
||||
- AND conditions: `[("tokens", 500), ("fraction", 0.8)]` as nested list
|
||||
- Mixed AND/OR: `[("messages", 10), [("tokens", 500), ("fraction", 0.8)]]`
|
||||
|
||||
For backwards compatibility, also accepts an integer token count.
|
||||
"""
|
||||
|
||||
clear_at_least: int = 0
|
||||
"""Minimum number of tokens to reclaim when the edit runs."""
|
||||
|
||||
keep: int = 3
|
||||
"""Number of most recent tool results that must be preserved."""
|
||||
keep: ContextSize | int = 3
|
||||
"""Context retention policy for tool results.
|
||||
|
||||
Provide a `ContextSize` tuple to specify how much history to preserve:
|
||||
- `("messages", 3)` - Keep last 3 tool results
|
||||
- `("tokens", 1000)` - Keep tool results within token budget
|
||||
- `("fraction", 0.3)` - Keep tool results within 30% of model's max tokens
|
||||
|
||||
For backwards compatibility, also accepts an integer message count.
|
||||
"""
|
||||
|
||||
clear_tool_inputs: bool = False
|
||||
"""Whether to clear the originating tool call parameters on the AI message."""
|
||||
@@ -74,6 +98,66 @@ class ClearToolUsesEdit(ContextEdit):
|
||||
placeholder: str = DEFAULT_TOOL_PLACEHOLDER
|
||||
"""Placeholder text inserted for cleared tool outputs."""
|
||||
|
||||
model: Any = None
|
||||
"""Optional model instance for fractional token limits."""
|
||||
|
||||
_trigger_conditions: list[ContextSize | list[ContextSize]] | None = None
|
||||
_keep_normalized: ContextSize | None = None
|
||||
_trigger_as_int: int | None = None
|
||||
_keep_as_int: int | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Validate and normalize trigger/keep parameters."""
|
||||
# Normalize trigger
|
||||
if isinstance(self.trigger, int):
|
||||
self._trigger_as_int = self.trigger
|
||||
self._trigger_conditions = None
|
||||
elif isinstance(self.trigger, tuple):
|
||||
# Single ContextSize
|
||||
self._validate_context_size(self.trigger, "trigger")
|
||||
self._trigger_conditions = [self.trigger]
|
||||
self._trigger_as_int = None
|
||||
elif isinstance(self.trigger, list):
|
||||
# List of conditions
|
||||
self._trigger_conditions = self._validate_trigger_conditions(self.trigger)
|
||||
self._trigger_as_int = None
|
||||
else:
|
||||
msg = f"trigger must be int or ContextCondition, got {type(self.trigger).__name__}"
|
||||
raise TypeError(msg)
|
||||
|
||||
# Normalize keep
|
||||
if isinstance(self.keep, int):
|
||||
self._keep_as_int = self.keep
|
||||
self._keep_normalized = None
|
||||
elif isinstance(self.keep, tuple):
|
||||
self._validate_context_size(self.keep, "keep")
|
||||
self._keep_normalized = self.keep
|
||||
self._keep_as_int = None
|
||||
else:
|
||||
msg = f"keep must be int or ContextSize, got {type(self.keep).__name__}"
|
||||
raise TypeError(msg)
|
||||
|
||||
# Check if model profile is required
|
||||
requires_profile = False
|
||||
if self._trigger_conditions:
|
||||
requires_profile = self._requires_profile(self._trigger_conditions)
|
||||
if self._keep_normalized and self._keep_normalized[0] == "fraction":
|
||||
requires_profile = True
|
||||
|
||||
if requires_profile and self.model is None:
|
||||
msg = (
|
||||
"model parameter is required when using fractional token limits. "
|
||||
"Pass a model instance or use absolute token/message counts instead."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
if requires_profile and self._get_profile_limits() is None:
|
||||
msg = (
|
||||
"Model profile information is required to use fractional token limits. "
|
||||
'pip install "langchain[model-profiles]" or use absolute token counts instead.'
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
messages: list[AnyMessage],
|
||||
@@ -83,17 +167,20 @@ class ClearToolUsesEdit(ContextEdit):
|
||||
"""Apply the clear-tool-uses strategy."""
|
||||
tokens = count_tokens(messages)
|
||||
|
||||
if tokens <= self.trigger:
|
||||
if not self._should_trigger(messages, tokens):
|
||||
return
|
||||
|
||||
candidates = [
|
||||
(idx, msg) for idx, msg in enumerate(messages) if isinstance(msg, ToolMessage)
|
||||
]
|
||||
|
||||
if self.keep >= len(candidates):
|
||||
# Determine how many tool results to keep
|
||||
keep_count = self._determine_keep_count(messages, tokens)
|
||||
|
||||
if keep_count >= len(candidates):
|
||||
candidates = []
|
||||
elif self.keep:
|
||||
candidates = candidates[: -self.keep]
|
||||
elif keep_count > 0:
|
||||
candidates = candidates[:-keep_count]
|
||||
|
||||
cleared_tokens = 0
|
||||
excluded_tools = set(self.exclude_tools)
|
||||
@@ -181,6 +268,191 @@ class ClearToolUsesEdit(ContextEdit):
|
||||
}
|
||||
)
|
||||
|
||||
def _should_trigger(self, messages: list[AnyMessage], total_tokens: int) -> bool:
|
||||
"""Determine whether the edit should trigger based on current state."""
|
||||
# Backwards compatibility: int trigger
|
||||
if self._trigger_as_int is not None:
|
||||
return total_tokens > self._trigger_as_int
|
||||
|
||||
# New API: ContextCondition with AND/OR logic
|
||||
if not self._trigger_conditions:
|
||||
return False
|
||||
|
||||
# OR logic across top-level conditions
|
||||
for condition in self._trigger_conditions:
|
||||
if isinstance(condition, list):
|
||||
# AND group - all must be satisfied
|
||||
if self._check_and_group(condition, messages, total_tokens):
|
||||
return True
|
||||
elif self._check_single_condition(condition, messages, total_tokens):
|
||||
# Single condition
|
||||
return True
|
||||
return False
|
||||
|
||||
def _check_and_group(
|
||||
self, and_group: list[ContextSize], messages: list[AnyMessage], total_tokens: int
|
||||
) -> bool:
|
||||
"""Check if all conditions in an AND group are satisfied."""
|
||||
for condition in and_group:
|
||||
if not self._check_single_condition(condition, messages, total_tokens):
|
||||
return False
|
||||
return True
|
||||
|
||||
def _check_single_condition(
|
||||
self, condition: ContextSize, messages: list[AnyMessage], total_tokens: int
|
||||
) -> bool:
|
||||
"""Check if a single condition is satisfied."""
|
||||
kind, value = condition
|
||||
if kind == "messages":
|
||||
return len(messages) >= value
|
||||
if kind == "tokens":
|
||||
return total_tokens >= value
|
||||
if kind == "fraction":
|
||||
max_input_tokens = self._get_profile_limits()
|
||||
if max_input_tokens is None:
|
||||
return False
|
||||
threshold = int(max_input_tokens * value)
|
||||
if threshold <= 0:
|
||||
threshold = 1
|
||||
return total_tokens >= threshold
|
||||
return False
|
||||
|
||||
def _determine_keep_count(self, messages: list[AnyMessage], total_tokens: int) -> int: # noqa: ARG002
|
||||
"""Determine how many tool results to keep based on keep configuration."""
|
||||
# Backwards compatibility: int keep
|
||||
if self._keep_as_int is not None:
|
||||
return self._keep_as_int
|
||||
|
||||
# New API: ContextSize
|
||||
if self._keep_normalized is None:
|
||||
return 0
|
||||
|
||||
kind, value = self._keep_normalized
|
||||
if kind == "messages":
|
||||
return int(value)
|
||||
if kind in {"tokens", "fraction"}:
|
||||
# For token-based keep, we need to count backwards through tool messages
|
||||
# to find how many fit within the budget
|
||||
target_tokens = self._get_target_token_count(value, kind)
|
||||
if target_tokens is None:
|
||||
return 0
|
||||
return self._count_tool_messages_within_budget(messages, target_tokens)
|
||||
return 0
|
||||
|
||||
def _get_target_token_count(self, value: float, kind: str) -> int | None:
|
||||
"""Get the target token count for token/fraction-based keep."""
|
||||
if kind == "fraction":
|
||||
max_input_tokens = self._get_profile_limits()
|
||||
if max_input_tokens is None:
|
||||
return None
|
||||
target = int(max_input_tokens * value)
|
||||
elif kind == "tokens":
|
||||
target = int(value)
|
||||
else:
|
||||
return None
|
||||
|
||||
return max(1, target) if target > 0 else 1
|
||||
|
||||
def _count_tool_messages_within_budget(
|
||||
self, messages: list[AnyMessage], target_tokens: int
|
||||
) -> int:
|
||||
"""Count how many recent tool messages fit within token budget."""
|
||||
tool_messages = [msg for msg in messages if isinstance(msg, ToolMessage)]
|
||||
if not tool_messages:
|
||||
return 0
|
||||
|
||||
# Count backwards from the end
|
||||
count = 0
|
||||
accumulated_tokens = 0
|
||||
for tool_msg in reversed(tool_messages):
|
||||
# Approximate token count for this message
|
||||
msg_tokens = len(str(tool_msg.content))
|
||||
if accumulated_tokens + msg_tokens > target_tokens and count > 0:
|
||||
break
|
||||
accumulated_tokens += msg_tokens
|
||||
count += 1
|
||||
|
||||
return count
|
||||
|
||||
def _get_profile_limits(self) -> int | None:
|
||||
"""Retrieve max input token limit from the model profile."""
|
||||
if self.model is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
profile = self.model.profile
|
||||
except (AttributeError, ImportError):
|
||||
return None
|
||||
|
||||
if not isinstance(profile, Mapping):
|
||||
return None
|
||||
|
||||
max_input_tokens = profile.get("max_input_tokens")
|
||||
|
||||
if not isinstance(max_input_tokens, int):
|
||||
return None
|
||||
|
||||
return max_input_tokens
|
||||
|
||||
def _validate_context_size(self, context: ContextSize, parameter_name: str) -> ContextSize:
|
||||
"""Validate context configuration tuples."""
|
||||
kind, value = context
|
||||
if kind == "fraction":
|
||||
if not 0 < value <= 1:
|
||||
msg = f"Fractional {parameter_name} values must be between 0 and 1, got {value}."
|
||||
raise ValueError(msg)
|
||||
elif kind in {"tokens", "messages"}:
|
||||
# For trigger, value must be > 0. For keep, value can be >= 0 (0 means keep nothing)
|
||||
if parameter_name == "trigger" and value <= 0:
|
||||
msg = f"{parameter_name} thresholds must be greater than 0, got {value}."
|
||||
raise ValueError(msg)
|
||||
if parameter_name == "keep" and value < 0:
|
||||
msg = f"{parameter_name} values must be non-negative, got {value}."
|
||||
raise ValueError(msg)
|
||||
else:
|
||||
msg = f"Unsupported context size type {kind} for {parameter_name}."
|
||||
raise ValueError(msg)
|
||||
return context
|
||||
|
||||
def _validate_trigger_conditions(
|
||||
self, conditions: list[Any]
|
||||
) -> list[ContextSize | list[ContextSize]]:
|
||||
"""Validate and normalize trigger conditions with nested AND/OR logic.
|
||||
|
||||
Args:
|
||||
conditions: List of ContextSize tuples or nested lists of ContextSize tuples.
|
||||
|
||||
Returns:
|
||||
Validated list where top-level items are OR'd and nested lists are AND'd.
|
||||
"""
|
||||
validated: list[ContextSize | list[ContextSize]] = []
|
||||
for item in conditions:
|
||||
if isinstance(item, tuple):
|
||||
# Single condition (tuple)
|
||||
validated.append(self._validate_context_size(item, "trigger"))
|
||||
elif isinstance(item, list):
|
||||
# AND group (nested list)
|
||||
if not item:
|
||||
msg = "Empty AND groups are not allowed in trigger conditions."
|
||||
raise ValueError(msg)
|
||||
and_group = [self._validate_context_size(cond, "trigger") for cond in item]
|
||||
validated.append(and_group)
|
||||
else:
|
||||
msg = f"Trigger conditions must be tuples or lists, got {type(item).__name__}."
|
||||
raise ValueError(msg)
|
||||
return validated
|
||||
|
||||
def _requires_profile(self, conditions: list[ContextSize | list[ContextSize]]) -> bool:
|
||||
"""Check if any condition requires model profile information."""
|
||||
for condition in conditions:
|
||||
if isinstance(condition, list):
|
||||
# AND group
|
||||
if any(c[0] == "fraction" for c in condition):
|
||||
return True
|
||||
elif condition[0] == "fraction":
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class ContextEditingMiddleware(AgentMiddleware):
|
||||
"""Automatically prune tool results to manage context size.
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
import uuid
|
||||
import warnings
|
||||
from collections.abc import Callable, Iterable, Mapping
|
||||
from typing import Any, Literal, cast
|
||||
from typing import Any, cast
|
||||
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
@@ -19,6 +19,10 @@ from langgraph.graph.message import (
|
||||
)
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from langchain.agents.middleware._context import (
|
||||
ContextCondition,
|
||||
ContextSize,
|
||||
)
|
||||
from langchain.agents.middleware.types import AgentMiddleware, AgentState
|
||||
from langchain.chat_models import BaseChatModel, init_chat_model
|
||||
|
||||
@@ -57,12 +61,6 @@ _DEFAULT_TRIM_TOKEN_LIMIT = 4000
|
||||
_DEFAULT_FALLBACK_MESSAGE_COUNT = 15
|
||||
_SEARCH_RANGE_FOR_TOOL_PAIRS = 5
|
||||
|
||||
ContextFraction = tuple[Literal["fraction"], float]
|
||||
ContextTokens = tuple[Literal["tokens"], int]
|
||||
ContextMessages = tuple[Literal["messages"], int]
|
||||
|
||||
ContextSize = ContextFraction | ContextTokens | ContextMessages
|
||||
|
||||
|
||||
class SummarizationMiddleware(AgentMiddleware):
|
||||
"""Summarizes conversation history when token limits are approached.
|
||||
@@ -76,7 +74,7 @@ class SummarizationMiddleware(AgentMiddleware):
|
||||
self,
|
||||
model: str | BaseChatModel,
|
||||
*,
|
||||
trigger: ContextSize | list[ContextSize] | None = None,
|
||||
trigger: ContextCondition | None = None,
|
||||
keep: ContextSize = ("messages", _DEFAULT_MESSAGES_TO_KEEP),
|
||||
token_counter: TokenCounter = count_tokens_approximately,
|
||||
summary_prompt: str = DEFAULT_SUMMARY_PROMPT,
|
||||
@@ -87,10 +85,16 @@ class SummarizationMiddleware(AgentMiddleware):
|
||||
|
||||
Args:
|
||||
model: The language model to use for generating summaries.
|
||||
trigger: One or more thresholds that trigger summarization. Provide a single
|
||||
`ContextSize` tuple or a list of tuples, in which case summarization runs
|
||||
when any threshold is breached. Examples: `("messages", 50)`, `("tokens", 3000)`,
|
||||
`[("fraction", 0.8), ("messages", 100)]`.
|
||||
trigger: One or more thresholds that trigger summarization. Supports flexible
|
||||
AND/OR logic via nested lists. Top-level list items are combined with OR,
|
||||
nested lists are combined with AND. Examples:
|
||||
- Single condition: `("messages", 50)`
|
||||
- OR conditions: `[("tokens", 3000), ("messages", 100)]` (triggers when
|
||||
tokens >= 3000 OR messages >= 100)
|
||||
- AND conditions: `[("tokens", 500), ("fraction", 0.8)]` as a nested list
|
||||
within the top-level list
|
||||
- Mixed AND/OR: `[("messages", 10), [("tokens", 500), ("fraction", 0.8)]]`
|
||||
(triggers when messages >= 10 OR (tokens >= 500 AND fraction >= 0.8))
|
||||
keep: Context retention policy applied after summarization. Provide a
|
||||
`ContextSize` tuple to specify how much history to preserve. Defaults to
|
||||
keeping the most recent 20 messages. Examples: `("messages", 20)`,
|
||||
@@ -128,13 +132,15 @@ class SummarizationMiddleware(AgentMiddleware):
|
||||
|
||||
self.model = model
|
||||
if trigger is None:
|
||||
self.trigger: ContextSize | list[ContextSize] | None = None
|
||||
trigger_conditions: list[ContextSize] = []
|
||||
self.trigger: ContextCondition | None = None
|
||||
trigger_conditions: list[ContextSize | list[ContextSize]] = []
|
||||
elif isinstance(trigger, list):
|
||||
validated_list = [self._validate_context_size(item, "trigger") for item in trigger]
|
||||
# Validate and normalize nested structure
|
||||
validated_list = self._validate_trigger_conditions(trigger)
|
||||
self.trigger = validated_list
|
||||
trigger_conditions = validated_list
|
||||
else:
|
||||
# Single ContextSize tuple
|
||||
validated = self._validate_context_size(trigger, "trigger")
|
||||
self.trigger = validated
|
||||
trigger_conditions = [validated]
|
||||
@@ -145,7 +151,7 @@ class SummarizationMiddleware(AgentMiddleware):
|
||||
self.summary_prompt = summary_prompt
|
||||
self.trim_tokens_to_summarize = trim_tokens_to_summarize
|
||||
|
||||
requires_profile = any(condition[0] == "fraction" for condition in self._trigger_conditions)
|
||||
requires_profile = self._requires_profile(self._trigger_conditions)
|
||||
if self.keep[0] == "fraction":
|
||||
requires_profile = True
|
||||
if requires_profile and self._get_profile_limits() is None:
|
||||
@@ -211,24 +217,52 @@ class SummarizationMiddleware(AgentMiddleware):
|
||||
}
|
||||
|
||||
def _should_summarize(self, messages: list[AnyMessage], total_tokens: int) -> bool:
|
||||
"""Determine whether summarization should run for the current token usage."""
|
||||
"""Determine whether summarization should run for the current token usage.
|
||||
|
||||
Evaluates trigger conditions with AND/OR logic:
|
||||
- Top-level items are OR'd together
|
||||
- Nested lists are AND'd together
|
||||
"""
|
||||
if not self._trigger_conditions:
|
||||
return False
|
||||
|
||||
for kind, value in self._trigger_conditions:
|
||||
if kind == "messages" and len(messages) >= value:
|
||||
return True
|
||||
if kind == "tokens" and total_tokens >= value:
|
||||
return True
|
||||
if kind == "fraction":
|
||||
max_input_tokens = self._get_profile_limits()
|
||||
if max_input_tokens is None:
|
||||
continue
|
||||
threshold = int(max_input_tokens * value)
|
||||
if threshold <= 0:
|
||||
threshold = 1
|
||||
if total_tokens >= threshold:
|
||||
# OR logic across top-level conditions
|
||||
for condition in self._trigger_conditions:
|
||||
if isinstance(condition, list):
|
||||
# AND group - all must be satisfied
|
||||
if self._check_and_group(condition, messages, total_tokens):
|
||||
return True
|
||||
elif self._check_single_condition(condition, messages, total_tokens):
|
||||
# Single condition
|
||||
return True
|
||||
return False
|
||||
|
||||
def _check_and_group(
|
||||
self, and_group: list[ContextSize], messages: list[AnyMessage], total_tokens: int
|
||||
) -> bool:
|
||||
"""Check if all conditions in an AND group are satisfied."""
|
||||
for condition in and_group:
|
||||
if not self._check_single_condition(condition, messages, total_tokens):
|
||||
return False
|
||||
return True
|
||||
|
||||
def _check_single_condition(
|
||||
self, condition: ContextSize, messages: list[AnyMessage], total_tokens: int
|
||||
) -> bool:
|
||||
"""Check if a single condition is satisfied."""
|
||||
kind, value = condition
|
||||
if kind == "messages":
|
||||
return len(messages) >= value
|
||||
if kind == "tokens":
|
||||
return total_tokens >= value
|
||||
if kind == "fraction":
|
||||
max_input_tokens = self._get_profile_limits()
|
||||
if max_input_tokens is None:
|
||||
return False
|
||||
threshold = int(max_input_tokens * value)
|
||||
if threshold <= 0:
|
||||
threshold = 1
|
||||
return total_tokens >= threshold
|
||||
return False
|
||||
|
||||
def _determine_cutoff_index(self, messages: list[AnyMessage]) -> int:
|
||||
@@ -328,6 +362,45 @@ class SummarizationMiddleware(AgentMiddleware):
|
||||
raise ValueError(msg)
|
||||
return context
|
||||
|
||||
def _validate_trigger_conditions(
|
||||
self, conditions: list[Any]
|
||||
) -> list[ContextSize | list[ContextSize]]:
|
||||
"""Validate and normalize trigger conditions with nested AND/OR logic.
|
||||
|
||||
Args:
|
||||
conditions: List of ContextSize tuples or nested lists of ContextSize tuples.
|
||||
|
||||
Returns:
|
||||
Validated list where top-level items are OR'd and nested lists are AND'd.
|
||||
"""
|
||||
validated: list[ContextSize | list[ContextSize]] = []
|
||||
for item in conditions:
|
||||
if isinstance(item, tuple):
|
||||
# Single condition (tuple)
|
||||
validated.append(self._validate_context_size(item, "trigger"))
|
||||
elif isinstance(item, list):
|
||||
# AND group (nested list)
|
||||
if not item:
|
||||
msg = "Empty AND groups are not allowed in trigger conditions."
|
||||
raise ValueError(msg)
|
||||
and_group = [self._validate_context_size(cond, "trigger") for cond in item]
|
||||
validated.append(and_group)
|
||||
else:
|
||||
msg = f"Trigger conditions must be tuples or lists, got {type(item).__name__}."
|
||||
raise ValueError(msg)
|
||||
return validated
|
||||
|
||||
def _requires_profile(self, conditions: list[ContextSize | list[ContextSize]]) -> bool:
|
||||
"""Check if any condition requires model profile information."""
|
||||
for condition in conditions:
|
||||
if isinstance(condition, list):
|
||||
# AND group
|
||||
if any(c[0] == "fraction" for c in condition):
|
||||
return True
|
||||
elif condition[0] == "fraction":
|
||||
return True
|
||||
return False
|
||||
|
||||
def _build_new_messages(self, summary: str) -> list[HumanMessage]:
|
||||
return [
|
||||
HumanMessage(content=f"Here is a summary of the conversation to date:\n\n{summary}")
|
||||
|
||||
@@ -399,3 +399,154 @@ async def test_exclude_tools_prevents_clearing_async() -> None:
|
||||
|
||||
assert isinstance(calc_tool, ToolMessage)
|
||||
assert calc_tool.content == "[cleared]"
|
||||
|
||||
|
||||
# New API tests
|
||||
|
||||
|
||||
def test_new_api_trigger_with_context_size_tuple() -> None:
|
||||
"""Test new API with ContextSize tuple for trigger."""
|
||||
tool_call_id = "call-1"
|
||||
ai_message = AIMessage(
|
||||
content="",
|
||||
tool_calls=[{"id": tool_call_id, "name": "search", "args": {}}],
|
||||
)
|
||||
tool_message = ToolMessage(content="x" * 200, tool_call_id=tool_call_id)
|
||||
|
||||
state, request = _make_state_and_request([ai_message, tool_message])
|
||||
|
||||
# Use new API with tuple
|
||||
edit = ClearToolUsesEdit(
|
||||
trigger=("tokens", 50),
|
||||
keep=("messages", 0),
|
||||
placeholder="[cleared]",
|
||||
)
|
||||
middleware = ContextEditingMiddleware(edits=[edit])
|
||||
|
||||
def mock_handler(req: ModelRequest) -> AIMessage:
|
||||
return AIMessage(content="mock response")
|
||||
|
||||
middleware.wrap_model_call(request, mock_handler)
|
||||
|
||||
cleared_tool = request.messages[1]
|
||||
assert isinstance(cleared_tool, ToolMessage)
|
||||
assert cleared_tool.content == "[cleared]"
|
||||
|
||||
|
||||
def test_new_api_keep_with_messages_tuple() -> None:
|
||||
"""Test new API with messages-based keep."""
|
||||
conversation: list[AIMessage | ToolMessage] = []
|
||||
for i in range(5):
|
||||
call_id = f"call-{i}"
|
||||
conversation.append(
|
||||
AIMessage(
|
||||
content="",
|
||||
tool_calls=[{"id": call_id, "name": "tool", "args": {}}],
|
||||
)
|
||||
)
|
||||
conversation.append(ToolMessage(content="x" * 50, tool_call_id=call_id))
|
||||
|
||||
state, request = _make_state_and_request(conversation)
|
||||
|
||||
edit = ClearToolUsesEdit(
|
||||
trigger=("tokens", 50),
|
||||
keep=("messages", 2), # Keep last 2 tool results
|
||||
placeholder="[cleared]",
|
||||
)
|
||||
middleware = ContextEditingMiddleware(edits=[edit])
|
||||
|
||||
def mock_handler(req: ModelRequest) -> AIMessage:
|
||||
return AIMessage(content="mock response")
|
||||
|
||||
middleware.wrap_model_call(request, mock_handler)
|
||||
|
||||
# Check that first 3 tool messages are cleared, last 2 are preserved
|
||||
tool_messages = [msg for msg in request.messages if isinstance(msg, ToolMessage)]
|
||||
cleared = [msg for msg in tool_messages if msg.content == "[cleared]"]
|
||||
preserved = [msg for msg in tool_messages if msg.content != "[cleared]"]
|
||||
|
||||
assert len(cleared) == 3
|
||||
assert len(preserved) == 2
|
||||
|
||||
|
||||
def test_new_api_or_conditions() -> None:
|
||||
"""Test new API with OR trigger conditions."""
|
||||
tool_call_id = "call-1"
|
||||
ai_message = AIMessage(
|
||||
content="",
|
||||
tool_calls=[{"id": tool_call_id, "name": "search", "args": {}}],
|
||||
)
|
||||
tool_message = ToolMessage(content="x" * 200, tool_call_id=tool_call_id)
|
||||
|
||||
state, request = _make_state_and_request([ai_message, tool_message])
|
||||
|
||||
# Use OR conditions: triggers if tokens >= 50 OR messages >= 100
|
||||
edit = ClearToolUsesEdit(
|
||||
trigger=[("tokens", 50), ("messages", 100)],
|
||||
keep=("messages", 0),
|
||||
placeholder="[cleared]",
|
||||
)
|
||||
middleware = ContextEditingMiddleware(edits=[edit])
|
||||
|
||||
def mock_handler(req: ModelRequest) -> AIMessage:
|
||||
return AIMessage(content="mock response")
|
||||
|
||||
middleware.wrap_model_call(request, mock_handler)
|
||||
|
||||
# Should trigger because tokens >= 50 (even though messages < 100)
|
||||
cleared_tool = request.messages[1]
|
||||
assert isinstance(cleared_tool, ToolMessage)
|
||||
assert cleared_tool.content == "[cleared]"
|
||||
|
||||
|
||||
def test_new_api_backwards_compatibility() -> None:
|
||||
"""Test that old integer API still works."""
|
||||
tool_call_id = "call-1"
|
||||
ai_message = AIMessage(
|
||||
content="",
|
||||
tool_calls=[{"id": tool_call_id, "name": "search", "args": {}}],
|
||||
)
|
||||
tool_message = ToolMessage(content="x" * 200, tool_call_id=tool_call_id)
|
||||
|
||||
state, request = _make_state_and_request([ai_message, tool_message])
|
||||
|
||||
# Old API with integers
|
||||
edit = ClearToolUsesEdit(
|
||||
trigger=50, # int
|
||||
keep=0, # int
|
||||
placeholder="[cleared]",
|
||||
)
|
||||
middleware = ContextEditingMiddleware(edits=[edit])
|
||||
|
||||
def mock_handler(req: ModelRequest) -> AIMessage:
|
||||
return AIMessage(content="mock response")
|
||||
|
||||
middleware.wrap_model_call(request, mock_handler)
|
||||
|
||||
cleared_tool = request.messages[1]
|
||||
assert isinstance(cleared_tool, ToolMessage)
|
||||
assert cleared_tool.content == "[cleared]"
|
||||
|
||||
|
||||
def test_new_api_validation_errors() -> None:
|
||||
"""Test that validation errors are raised for invalid configurations."""
|
||||
# Test invalid fraction value
|
||||
try:
|
||||
ClearToolUsesEdit(trigger=("fraction", 1.5), keep=("messages", 3))
|
||||
assert False, "Should have raised ValueError"
|
||||
except ValueError as e:
|
||||
assert "must be between 0 and 1" in str(e)
|
||||
|
||||
# Test invalid token count
|
||||
try:
|
||||
ClearToolUsesEdit(trigger=("tokens", -1), keep=("messages", 3))
|
||||
assert False, "Should have raised ValueError"
|
||||
except ValueError as e:
|
||||
assert "must be greater than 0" in str(e)
|
||||
|
||||
# Test unsupported type
|
||||
try:
|
||||
ClearToolUsesEdit(trigger=("invalid", 100), keep=("messages", 3)) # type: ignore[arg-type]
|
||||
assert False, "Should have raised ValueError"
|
||||
except ValueError as e:
|
||||
assert "Unsupported context size type" in str(e)
|
||||
|
||||
@@ -896,3 +896,171 @@ def test_summarization_middleware_is_safe_cutoff_at_end() -> None:
|
||||
|
||||
# Cutoff past the length should also be safe
|
||||
assert middleware._is_safe_cutoff_point(messages, len(messages) + 5)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("trigger", "model_class", "tokens_per_message", "num_messages", "should_trigger"),
    [
        # --- Group 1: pure AND logic -----------------------------------------
        # Trigger condition: messages >= 5 AND tokens >= 300.
        # 4 msgs / 200 tokens -> neither condition met.
        ([[("messages", 5), ("tokens", 300)]], MockChatModel, 50, 4, False),
        # 5 msgs / 250 tokens -> only the message condition met (tokens < 300).
        ([[("messages", 5), ("tokens", 300)]], MockChatModel, 50, 5, False),
        # 7 msgs / 350 tokens -> both conditions met.
        ([[("messages", 5), ("tokens", 300)]], MockChatModel, 50, 7, True),
        # --- Group 2: mixed OR/AND logic -------------------------------------
        # Trigger condition: messages >= 10 OR (tokens >= 500 AND messages >= 3).
        # 2 msgs / 200 tokens -> no condition met.
        ([("messages", 10), [("tokens", 500), ("messages", 3)]], MockChatModel, 100, 2, False),
        # 5 msgs / 500 tokens -> AND group satisfied.
        ([("messages", 10), [("tokens", 500), ("messages", 3)]], MockChatModel, 100, 5, True),
        # 10 msgs / 100 tokens -> first OR condition satisfied.
        ([("messages", 10), [("tokens", 500), ("messages", 3)]], MockChatModel, 10, 10, True),
        # --- Group 3: AND logic with a fraction threshold --------------------
        # Trigger condition: tokens >= 400 AND fraction >= 0.5.
        # ProfileChatModel has max_input_tokens=1000, so 0.5 -> 500 tokens.
        # 4 msgs / 400 tokens -> below the 500-token fraction threshold.
        ([[("tokens", 400), ("fraction", 0.5)]], ProfileChatModel, 100, 4, False),
        # 6 msgs / 600 tokens -> both thresholds met.
        ([[("tokens", 400), ("fraction", 0.5)]], ProfileChatModel, 100, 6, True),
    ],
)
def test_summarization_middleware_and_or_logic(
    trigger: list,
    model_class: type[BaseChatModel],
    tokens_per_message: int,
    num_messages: int,
    should_trigger: bool,
) -> None:
    """Test middleware with AND/OR logic combinations."""
    middleware = SummarizationMiddleware(
        model=model_class(),
        trigger=trigger,
        keep=("messages", 2),
    )
    # Deterministic counter: each message costs a fixed number of tokens.
    middleware.token_counter = lambda msgs: len(msgs) * tokens_per_message

    conversation = [HumanMessage(content=str(i)) for i in range(num_messages)]
    outcome = middleware.before_model({"messages": conversation}, None)

    # before_model returns an update exactly when summarization triggered.
    assert (outcome is not None) == should_trigger
|
||||
|
||||
|
||||
def test_summarization_middleware_nested_validation() -> None:
    """Test validation of nested AND groups."""
    model = MockChatModel()

    # A well-formed nested AND group is accepted and stored.
    accepted = SummarizationMiddleware(
        model=model, trigger=[("messages", 5), [("tokens", 100), ("messages", 3)]]
    )
    assert accepted.trigger is not None

    # An empty inner list is a meaningless AND group and must be rejected.
    with pytest.raises(ValueError, match="Empty AND groups are not allowed"):
        SummarizationMiddleware(model=model, trigger=[("messages", 5), []])

    # Top-level entries must be tuples (conditions) or lists (AND groups).
    with pytest.raises(ValueError, match="Trigger conditions must be tuples or lists"):
        SummarizationMiddleware(model=model, trigger=[("messages", 5), "invalid"])

    # Unknown size-type tags inside an AND group are rejected as well.
    with pytest.raises(ValueError, match="Unsupported context size type"):
        SummarizationMiddleware(model=model, trigger=[[("invalid", 100)]])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("tokens_per_message", "num_messages", "should_trigger"),
    [
        # 2 msgs / 200 tokens -> no condition met.
        (100, 2, False),
        # 5 msgs / 500 tokens -> second AND group met (tokens >= 300 AND messages >= 5).
        (100, 5, True),
        # 4 msgs / 400 tokens -> tokens < 500 blocks the third AND group.
        (100, 4, False),
        # 20 msgs / 100 tokens -> first OR condition met (messages >= 20).
        (5, 20, True),
    ],
)
def test_summarization_middleware_complex_and_or(
    tokens_per_message: int,
    num_messages: int,
    should_trigger: bool,
) -> None:
    """Test complex AND/OR combinations.

    Trigger when:
    - messages >= 20 OR
    - (tokens >= 300 AND messages >= 5) OR
    - (tokens >= 500 AND messages >= 3)
    """
    middleware = SummarizationMiddleware(
        model=MockChatModel(),
        trigger=[
            ("messages", 20),
            [("tokens", 300), ("messages", 5)],
            [("tokens", 500), ("messages", 3)],
        ],
        keep=("messages", 2),
    )
    # Fixed per-message token cost keeps the counter deterministic.
    middleware.token_counter = lambda msgs: len(msgs) * tokens_per_message

    history = [HumanMessage(content=str(i)) for i in range(num_messages)]
    outcome = middleware.before_model({"messages": history}, None)

    # Summarization happened iff before_model returned an update.
    assert (outcome is not None) == should_trigger
|
||||
|
||||
Reference in New Issue
Block a user