mirror of
https://github.com/hwchase17/langchain.git
synced 2026-04-15 17:44:08 +00:00
Compare commits
9 Commits
v1.2
...
sr/refacto
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
96525c825c | ||
|
|
841529615e | ||
|
|
79b675db9e | ||
|
|
99ece77cd4 | ||
|
|
5d6c7e2433 | ||
|
|
c7d2bf34e6 | ||
|
|
4f7059bd6d | ||
|
|
571d947c82 | ||
|
|
57faf33be0 |
@@ -1,5 +1,11 @@
|
||||
"""Entrypoint to using [middleware](https://docs.langchain.com/oss/python/langchain/middleware) plugins with [Agents](https://docs.langchain.com/oss/python/langchain/agents).""" # noqa: E501
|
||||
|
||||
from ._context import (
|
||||
ContextFraction,
|
||||
ContextMessages,
|
||||
ContextSize,
|
||||
ContextTokens,
|
||||
)
|
||||
from .context_editing import (
|
||||
ClearToolUsesEdit,
|
||||
ContextEditingMiddleware,
|
||||
@@ -46,6 +52,10 @@ __all__ = [
|
||||
"ClearToolUsesEdit",
|
||||
"CodexSandboxExecutionPolicy",
|
||||
"ContextEditingMiddleware",
|
||||
"ContextFraction",
|
||||
"ContextMessages",
|
||||
"ContextSize",
|
||||
"ContextTokens",
|
||||
"DockerExecutionPolicy",
|
||||
"FilesystemFileSearchMiddleware",
|
||||
"HostExecutionPolicy",
|
||||
|
||||
145
libs/langchain_v1/langchain/agents/middleware/_context.py
Normal file
145
libs/langchain_v1/langchain/agents/middleware/_context.py
Normal file
@@ -0,0 +1,145 @@
|
||||
"""Shared types and utilities for context size management.
|
||||
|
||||
This module centralizes type definitions and helper functions used across
|
||||
context editing and summarization middleware implementations.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
from collections.abc import Callable, Iterable
|
||||
from typing import Literal, TypeAlias
|
||||
|
||||
from langchain_core.messages import MessageLikeRepresentation
|
||||
|
||||
# Type aliases for context size specifications
|
||||
ContextFraction: TypeAlias = tuple[Literal["fraction"], float]
|
||||
"""Fractional threshold (0.0 to 1.0) of max model input tokens.
|
||||
|
||||
Example: `("fraction", 0.8)` means trigger at 80% of max input tokens.
|
||||
"""
|
||||
|
||||
ContextTokens: TypeAlias = tuple[Literal["tokens"], int]
|
||||
"""Absolute token count threshold.
|
||||
|
||||
Example: `("tokens", 100_000)` means trigger at exactly 100,000 tokens.
|
||||
"""
|
||||
|
||||
ContextMessages: TypeAlias = tuple[Literal["messages"], int]
|
||||
"""Message count threshold.
|
||||
|
||||
Example: `("messages", 50)` means trigger at exactly 50 messages.
|
||||
"""
|
||||
|
||||
ContextSize: TypeAlias = ContextFraction | ContextTokens | ContextMessages
|
||||
"""Union type for context size specifications.
|
||||
|
||||
Provides type-safe representation of context size thresholds using one of:
|
||||
- `("fraction", float)`: Fractional threshold (0.0 to 1.0) of max model input tokens
|
||||
- `("tokens", int)`: Absolute token count threshold
|
||||
- `("messages", int)`: Message count threshold
|
||||
"""
|
||||
|
||||
# Token counter callable type
TokenCounter: TypeAlias = Callable[
    [Iterable[MessageLikeRepresentation]],
    int,
]
"""Callable that counts tokens in messages.

Takes an iterable of message-like representations (a `Sequence[BaseMessage]`
satisfies this) and returns the total token count as an `int`.
"""
|
||||
|
||||
|
||||
def coerce_to_context_size(
    value: int | ContextSize, *, kind: Literal["trigger", "keep"], param_name: str
) -> ContextSize:
    """Coerce integer values to ContextSize tuples for backwards compatibility.

    The legacy API accepted bare integers: a trigger integer meant a token
    count, a keep integer meant a message count. Passing an integer emits a
    `DeprecationWarning`; tuples are returned untouched.

    Args:
        value: Integer (legacy form) or ContextSize tuple.
        kind: Whether this is for a trigger or keep parameter.
        param_name: Name of the parameter for deprecation warnings.

    Returns:
        ContextSize tuple.
    """
    if not isinstance(value, int):
        return value
    # Backwards compat with the old API: trigger ints were token counts,
    # keep ints were message counts.
    unit: Literal["tokens", "messages"] = "tokens" if kind == "trigger" else "messages"
    warnings.warn(
        f"{param_name}={value} (int) is deprecated. "
        f"Use {param_name}=('{unit}', {value}) instead.",
        DeprecationWarning,
        stacklevel=3,
    )
    return (unit, value)
|
||||
|
||||
|
||||
def validate_context_size(
    context: ContextSize,
    parameter_name: str,
    *,
    allow_zero_for_keep: bool = False,
) -> ContextSize:
    """Validate context configuration tuples.

    Fractions must lie in the half-open interval (0, 1]. Token and message
    counts must be positive, with one exception: a `keep` parameter may be
    zero (meaning "keep nothing") when `allow_zero_for_keep` is set.

    Args:
        context: The ContextSize tuple to validate.
        parameter_name: Name of the parameter being validated (for error messages).
        allow_zero_for_keep: Whether to allow zero values for "keep" parameters.

    Returns:
        The validated ContextSize tuple, unchanged.

    Raises:
        ValueError: If the context configuration is invalid.
    """
    kind, value = context
    if kind == "fraction":
        if not 0 < value <= 1:
            raise ValueError(
                f"Fractional {parameter_name} values must be between 0 and 1, got {value}."
            )
        return context
    if kind not in ("tokens", "messages"):
        raise ValueError(f"Unsupported context size type {kind} for {parameter_name}.")
    if parameter_name == "keep" and allow_zero_for_keep:
        # Context editing allows keep == 0 (retain no tool results).
        if value < 0:
            raise ValueError(f"{parameter_name} thresholds must be >= 0, got {value}.")
    elif parameter_name == "trigger":
        # Triggers must be >= 1 for both context editing and summarization.
        if value < 1:
            raise ValueError(f"{parameter_name} thresholds must be >= 1, got {value}.")
    elif value <= 0:
        # Summarization keep (allow_zero_for_keep=False) must be strictly positive.
        raise ValueError(f"{parameter_name} thresholds must be >= 1, got {value}.")
    return context
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ContextFraction",
|
||||
"ContextMessages",
|
||||
"ContextSize",
|
||||
"ContextTokens",
|
||||
"TokenCounter",
|
||||
"coerce_to_context_size",
|
||||
"validate_context_size",
|
||||
]
|
||||
@@ -9,20 +9,33 @@ chat model.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Awaitable, Callable, Iterable, Sequence
|
||||
import warnings
|
||||
from collections.abc import Awaitable, Callable, Iterable, Mapping, Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
from typing import Any, Literal
|
||||
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AnyMessage,
|
||||
BaseMessage,
|
||||
SystemMessage,
|
||||
ToolMessage,
|
||||
)
|
||||
from langchain_core.messages.utils import count_tokens_approximately
|
||||
from typing_extensions import Protocol
|
||||
|
||||
from langchain.agents.middleware._context import (
|
||||
ContextFraction,
|
||||
ContextMessages,
|
||||
ContextSize,
|
||||
ContextTokens,
|
||||
TokenCounter,
|
||||
)
|
||||
from langchain.agents.middleware._context import (
|
||||
coerce_to_context_size as _coerce_to_context_size,
|
||||
)
|
||||
from langchain.agents.middleware._context import (
|
||||
validate_context_size as _validate_context_size_base,
|
||||
)
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
ModelCallResult,
|
||||
@@ -30,13 +43,8 @@ from langchain.agents.middleware.types import (
|
||||
ModelResponse,
|
||||
)
|
||||
|
||||
DEFAULT_TOOL_PLACEHOLDER = "[cleared]"
|
||||
|
||||
|
||||
TokenCounter = Callable[
|
||||
[Sequence[BaseMessage]],
|
||||
int,
|
||||
]
|
||||
_DEFAULT_TOOL_PLACEHOLDER = "[cleared]"
|
||||
"""Default placeholder text for cleared tool outputs."""
|
||||
|
||||
|
||||
class ContextEdit(Protocol):
|
||||
@@ -52,28 +60,125 @@ class ContextEdit(Protocol):
|
||||
...
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
@dataclass(slots=True, init=False)
|
||||
class ClearToolUsesEdit(ContextEdit):
|
||||
"""Configuration for clearing tool outputs when token limits are exceeded."""
|
||||
|
||||
trigger: int = 100_000
|
||||
"""Token count that triggers the edit."""
|
||||
trigger: ContextSize | list[ContextSize]
|
||||
"""One or more thresholds that trigger the edit.
|
||||
|
||||
clear_at_least: int = 0
|
||||
"""Minimum number of tokens to reclaim when the edit runs."""
|
||||
Provide a single `ContextSize` tuple or a list of tuples, in which case
|
||||
the edit runs when any threshold is breached.
|
||||
|
||||
keep: int = 3
|
||||
"""Number of most recent tool results that must be preserved."""
|
||||
Examples: `("messages", 50)`, `("tokens", 100_000)`,
|
||||
`[("fraction", 0.8), ("messages", 100)]`.
|
||||
"""
|
||||
|
||||
clear_tool_inputs: bool = False
|
||||
keep: ContextSize
|
||||
"""Context retention policy for tool results.
|
||||
|
||||
Provide a `ContextSize` tuple to specify how many tool results to preserve.
|
||||
|
||||
Defaults to keeping the most recent 3 tool results.
|
||||
|
||||
Examples: `("messages", 3)`, `("tokens", 3000)`, or `("fraction", 0.3)`.
|
||||
"""
|
||||
|
||||
clear_tool_inputs: bool
|
||||
"""Whether to clear the originating tool call parameters on the AI message."""
|
||||
|
||||
exclude_tools: Sequence[str] = ()
|
||||
exclude_tools: Sequence[str]
|
||||
"""List of tool names to exclude from clearing."""
|
||||
|
||||
placeholder: str = DEFAULT_TOOL_PLACEHOLDER
|
||||
placeholder: str
|
||||
"""Placeholder text inserted for cleared tool outputs."""
|
||||
|
||||
_model: Any
|
||||
"""Reference to the model for accessing profile information."""
|
||||
|
||||
def __init__(
    self,
    *,
    trigger: ContextSize | list[ContextSize] | int | list[int] = ("tokens", 100_000),
    keep: ContextSize | int = ("messages", 3),
    clear_tool_inputs: bool = False,
    exclude_tools: Sequence[str] = (),
    placeholder: str = _DEFAULT_TOOL_PLACEHOLDER,
    **deprecated_kwargs: Any,
) -> None:
    """Initialize the clear tool uses edit configuration.

    Args:
        trigger: One or more thresholds that trigger the edit.

            Provide a single `ContextSize` tuple or a list of tuples, in which
            case the edit runs when any threshold is breached.

            For backwards compatibility, integers are interpreted as token
            counts (with a `DeprecationWarning`).

            Examples: `("messages", 50)`, `("tokens", 100_000)`, `100_000`,
            `[("fraction", 0.8), ("messages", 100)]`.
        keep: Context retention policy for tool results.

            Provide a `ContextSize` tuple to specify how many tool results to
            preserve. For backwards compatibility, integers are interpreted as
            message counts (with a `DeprecationWarning`).

            Defaults to keeping the most recent 3 tool results.

            Examples: `("messages", 3)`, `3`, `("tokens", 3000)`, or
            `("fraction", 0.3)`.
        clear_tool_inputs: Whether to clear the originating tool call
            parameters on the AI message.
        exclude_tools: List of tool names to exclude from clearing.
        placeholder: Placeholder text inserted for cleared tool outputs.
        **deprecated_kwargs: Only `clear_at_least` is accepted (deprecated and
            ignored).

    Raises:
        TypeError: If an unknown keyword argument is supplied.
        ValueError: If a trigger or keep specification is invalid.
    """
    # Handle the deprecated clear_at_least parameter (accepted but ignored).
    if "clear_at_least" in deprecated_kwargs:
        warnings.warn(
            "clear_at_least is deprecated and will be ignored. "
            "Use keep=('tokens', value) to control token retention instead.",
            DeprecationWarning,
            stacklevel=2,
        )
    # Previously, unknown keyword arguments were silently swallowed by
    # **deprecated_kwargs, hiding caller typos. Surface them instead.
    unknown = set(deprecated_kwargs) - {"clear_at_least"}
    if unknown:
        msg = f"Unexpected keyword argument(s): {', '.join(sorted(unknown))}"
        raise TypeError(msg)

    # Coerce legacy ints and validate each trigger threshold.
    # Coercion is called directly from here (not via a helper) so the
    # DeprecationWarning's stacklevel=3 still points at the caller.
    if isinstance(trigger, list):
        validated_triggers: list[ContextSize] = []
        for idx, item in enumerate(trigger):
            if isinstance(item, int):
                item = _coerce_to_context_size(
                    item, kind="trigger", param_name=f"trigger[{idx}]"
                )
            validated_triggers.append(self._validate_context_size(item, "trigger"))
        self.trigger = validated_triggers
    else:
        if isinstance(trigger, int):
            trigger = _coerce_to_context_size(trigger, kind="trigger", param_name="trigger")
        self.trigger = self._validate_context_size(trigger, "trigger")

    # Coerce legacy int keep (message count) and validate.
    if isinstance(keep, int):
        keep = _coerce_to_context_size(keep, kind="keep", param_name="keep")
    self.keep = self._validate_context_size(keep, "keep")

    self.clear_tool_inputs = clear_tool_inputs
    self.exclude_tools = exclude_tools
    self.placeholder = placeholder
    self._model = None  # Populated by ContextEditingMiddleware before apply().
|
||||
|
||||
def _validate_context_size(self, context: ContextSize, parameter_name: str) -> ContextSize:
    """Validate context configuration tuples.

    Delegates to the shared validator; context editing permits `keep == 0`
    (meaning "retain no tool results"), hence `allow_zero_for_keep=True`.
    """
    return _validate_context_size_base(
        context,
        parameter_name,
        allow_zero_for_keep=True,
    )
|
||||
|
||||
def apply(
|
||||
self,
|
||||
messages: list[AnyMessage],
|
||||
@@ -83,19 +188,21 @@ class ClearToolUsesEdit(ContextEdit):
|
||||
"""Apply the clear-tool-uses strategy."""
|
||||
tokens = count_tokens(messages)
|
||||
|
||||
if tokens <= self.trigger:
|
||||
if not self._should_trigger(messages, tokens):
|
||||
return
|
||||
|
||||
candidates = [
|
||||
(idx, msg) for idx, msg in enumerate(messages) if isinstance(msg, ToolMessage)
|
||||
]
|
||||
|
||||
if self.keep >= len(candidates):
|
||||
candidates = []
|
||||
elif self.keep:
|
||||
candidates = candidates[: -self.keep]
|
||||
# Calculate how many tool results to keep
|
||||
keep_count = self._calculate_keep_count(candidates)
|
||||
|
||||
if keep_count >= len(candidates):
|
||||
candidates = []
|
||||
elif keep_count > 0:
|
||||
candidates = candidates[:-keep_count]
|
||||
|
||||
cleared_tokens = 0
|
||||
excluded_tools = set(self.exclude_tools)
|
||||
|
||||
for idx, tool_message in candidates:
|
||||
@@ -144,14 +251,80 @@ class ClearToolUsesEdit(ContextEdit):
|
||||
tool_message.tool_call_id,
|
||||
)
|
||||
|
||||
if self.clear_at_least > 0:
|
||||
new_token_count = count_tokens(messages)
|
||||
cleared_tokens = max(0, tokens - new_token_count)
|
||||
if cleared_tokens >= self.clear_at_least:
|
||||
break
|
||||
|
||||
return
|
||||
|
||||
def _should_trigger(
    self,
    messages: list[AnyMessage],
    total_tokens: int,
) -> bool:
    """Determine whether the edit should run for the current context usage.

    Conditions are OR-ed: the edit runs as soon as any configured threshold
    is met. Fractional thresholds are skipped when no model profile limit is
    available.
    """
    conditions: list[ContextSize] = (
        self.trigger if isinstance(self.trigger, list) else [self.trigger]
    )

    for kind, value in conditions:
        if kind == "messages":
            if len(messages) >= value:
                return True
        elif kind == "tokens":
            if total_tokens >= value:
                return True
        elif kind == "fraction":
            limit = self._get_profile_limits()
            if limit is None:
                # No max_input_tokens available; this condition cannot apply.
                continue
            # Clamp to at least 1 so tiny fractions still form a real threshold.
            if total_tokens >= max(1, int(limit * value)):
                return True
    return False
|
||||
|
||||
def _calculate_keep_count(
    self,
    candidates: list[tuple[int, ToolMessage]],
) -> int:
    """Calculate how many tool results to keep based on retention policy.

    Interprets ``self.keep``:

    - ``("messages", n)``: keep the ``n`` most recent tool results.
    - ``("tokens", n)``: NOTE(review): currently treated as a *message* count,
      not a token budget -- no per-message token counting happens here.
      Confirm whether this simplification is acceptable.
    - ``("fraction", f)``: keep ``f`` of the candidate tool results (at least
      one). NOTE(review): the profile limit is fetched only as an availability
      check; the fraction is applied to ``len(candidates)``, not to tokens.
      Falls back to 3 when no ``max_input_tokens`` limit is available.

    Args:
        candidates: ``(index, ToolMessage)`` pairs eligible for clearing.

    Returns:
        Number of most-recent tool results to preserve.
    """
    keep_policy: ContextSize = self.keep
    kind, value = keep_policy
    if kind == "messages":
        return int(value)
    if kind == "tokens":
        # For token-based retention, we would need to count tokens per tool
        # message. For simplicity the token value is used as a message count.
        # This is a simplified implementation - could be enhanced.
        return int(value)
    if kind == "fraction":
        max_input_tokens = self._get_profile_limits()
        if max_input_tokens is None:
            # Fallback to default message count
            return 3
        target_count = int(len(candidates) * value)
        if target_count <= 0:
            target_count = 1
        return target_count
    return 3  # Default fallback
|
||||
|
||||
def _get_profile_limits(self) -> int | None:
    """Retrieve max input token limit from the model profile.

    Returns `None` whenever the limit cannot be determined: no model has
    been attached yet, the model exposes no usable `profile`, the profile
    is not a mapping, or `max_input_tokens` is missing/non-integer.
    """
    model = self._model
    if model is None:
        return None

    try:
        profile = model.profile
    except (AttributeError, ImportError):
        # Profile access may fail on models without profile support.
        return None

    if not isinstance(profile, Mapping):
        return None

    limit = profile.get("max_input_tokens")
    return limit if isinstance(limit, int) else None
|
||||
|
||||
def _build_cleared_tool_input_message(
|
||||
self,
|
||||
message: AIMessage,
|
||||
@@ -215,6 +388,25 @@ class ContextEditingMiddleware(AgentMiddleware):
|
||||
self.edits = list(edits or (ClearToolUsesEdit(),))
|
||||
self.token_count_method = token_count_method
|
||||
|
||||
# Validate that fractional limits can be used
|
||||
requires_profile = False
|
||||
for edit in self.edits:
|
||||
if isinstance(edit, ClearToolUsesEdit):
|
||||
trigger_conditions: list[ContextSize] = (
|
||||
edit.trigger if isinstance(edit.trigger, list) else [edit.trigger]
|
||||
)
|
||||
for condition in trigger_conditions:
|
||||
if condition[0] == "fraction":
|
||||
requires_profile = True
|
||||
break
|
||||
keep_policy: ContextSize = edit.keep
|
||||
if keep_policy[0] == "fraction":
|
||||
requires_profile = True
|
||||
|
||||
if requires_profile:
|
||||
# Just warn, don't raise - we'll handle it gracefully at runtime
|
||||
pass
|
||||
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
@@ -225,20 +417,21 @@ class ContextEditingMiddleware(AgentMiddleware):
|
||||
return handler(request)
|
||||
|
||||
if self.token_count_method == "approximate": # noqa: S105
|
||||
|
||||
def count_tokens(messages: Sequence[BaseMessage]) -> int:
|
||||
return count_tokens_approximately(messages)
|
||||
count_tokens: TokenCounter = count_tokens_approximately
|
||||
else:
|
||||
system_msg = (
|
||||
[SystemMessage(content=request.system_prompt)] if request.system_prompt else []
|
||||
)
|
||||
|
||||
def count_tokens(messages: Sequence[BaseMessage]) -> int:
|
||||
def count_tokens(messages: Iterable[Any]) -> int:
|
||||
return request.model.get_num_tokens_from_messages(
|
||||
system_msg + list(messages), request.tools
|
||||
)
|
||||
|
||||
# Set model reference on edits that need it
|
||||
for edit in self.edits:
|
||||
if isinstance(edit, ClearToolUsesEdit):
|
||||
edit._model = request.model
|
||||
edit.apply(request.messages, count_tokens=count_tokens)
|
||||
|
||||
return handler(request)
|
||||
@@ -253,20 +446,21 @@ class ContextEditingMiddleware(AgentMiddleware):
|
||||
return await handler(request)
|
||||
|
||||
if self.token_count_method == "approximate": # noqa: S105
|
||||
|
||||
def count_tokens(messages: Sequence[BaseMessage]) -> int:
|
||||
return count_tokens_approximately(messages)
|
||||
count_tokens: TokenCounter = count_tokens_approximately
|
||||
else:
|
||||
system_msg = (
|
||||
[SystemMessage(content=request.system_prompt)] if request.system_prompt else []
|
||||
)
|
||||
|
||||
def count_tokens(messages: Sequence[BaseMessage]) -> int:
|
||||
def count_tokens(messages: Iterable[Any]) -> int:
|
||||
return request.model.get_num_tokens_from_messages(
|
||||
system_msg + list(messages), request.tools
|
||||
)
|
||||
|
||||
# Set model reference on edits that need it
|
||||
for edit in self.edits:
|
||||
if isinstance(edit, ClearToolUsesEdit):
|
||||
edit._model = request.model
|
||||
edit.apply(request.messages, count_tokens=count_tokens)
|
||||
|
||||
return await handler(request)
|
||||
@@ -275,4 +469,8 @@ class ContextEditingMiddleware(AgentMiddleware):
|
||||
__all__ = [
|
||||
"ClearToolUsesEdit",
|
||||
"ContextEditingMiddleware",
|
||||
"ContextFraction",
|
||||
"ContextMessages",
|
||||
"ContextSize",
|
||||
"ContextTokens",
|
||||
]
|
||||
|
||||
@@ -2,13 +2,12 @@
|
||||
|
||||
import uuid
|
||||
import warnings
|
||||
from collections.abc import Callable, Iterable, Mapping
|
||||
from typing import Any, Literal, cast
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, cast
|
||||
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AnyMessage,
|
||||
MessageLikeRepresentation,
|
||||
RemoveMessage,
|
||||
ToolMessage,
|
||||
)
|
||||
@@ -19,10 +18,24 @@ from langgraph.graph.message import (
|
||||
)
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from langchain.agents.middleware._context import ContextSize, TokenCounter
|
||||
from langchain.agents.middleware._context import (
|
||||
validate_context_size as _validate_context_size_base,
|
||||
)
|
||||
from langchain.agents.middleware.types import AgentMiddleware, AgentState
|
||||
from langchain.chat_models import BaseChatModel, init_chat_model
|
||||
|
||||
TokenCounter = Callable[[Iterable[MessageLikeRepresentation]], int]
|
||||
_DEFAULT_MESSAGES_TO_KEEP = 20
|
||||
"""Default number of messages to keep after summarization."""
|
||||
|
||||
_DEFAULT_TRIM_TOKEN_LIMIT = 4000
|
||||
"""Default token limit when trimming messages for summarization."""
|
||||
|
||||
_DEFAULT_FALLBACK_MESSAGE_COUNT = 15
|
||||
"""Default fallback message count when trimming fails."""
|
||||
|
||||
_SEARCH_RANGE_FOR_TOOL_PAIRS = 5
|
||||
"""Range to search for AI/Tool message pairs when determining safe cutoff points."""
|
||||
|
||||
DEFAULT_SUMMARY_PROMPT = """<role>
|
||||
Context Extraction Assistant
|
||||
@@ -52,17 +65,6 @@ Messages to summarize:
|
||||
{messages}
|
||||
</messages>""" # noqa: E501
|
||||
|
||||
_DEFAULT_MESSAGES_TO_KEEP = 20
|
||||
_DEFAULT_TRIM_TOKEN_LIMIT = 4000
|
||||
_DEFAULT_FALLBACK_MESSAGE_COUNT = 15
|
||||
_SEARCH_RANGE_FOR_TOOL_PAIRS = 5
|
||||
|
||||
ContextFraction = tuple[Literal["fraction"], float]
|
||||
ContextTokens = tuple[Literal["tokens"], int]
|
||||
ContextMessages = tuple[Literal["messages"], int]
|
||||
|
||||
ContextSize = ContextFraction | ContextTokens | ContextMessages
|
||||
|
||||
|
||||
class SummarizationMiddleware(AgentMiddleware):
|
||||
"""Summarizes conversation history when token limits are approached.
|
||||
@@ -323,19 +325,8 @@ class SummarizationMiddleware(AgentMiddleware):
|
||||
|
||||
def _validate_context_size(self, context: ContextSize, parameter_name: str) -> ContextSize:
|
||||
"""Validate context configuration tuples."""
|
||||
kind, value = context
|
||||
if kind == "fraction":
|
||||
if not 0 < value <= 1:
|
||||
msg = f"Fractional {parameter_name} values must be between 0 and 1, got {value}."
|
||||
raise ValueError(msg)
|
||||
elif kind in {"tokens", "messages"}:
|
||||
if value <= 0:
|
||||
msg = f"{parameter_name} thresholds must be greater than 0, got {value}."
|
||||
raise ValueError(msg)
|
||||
else:
|
||||
msg = f"Unsupported context size type {kind} for {parameter_name}."
|
||||
raise ValueError(msg)
|
||||
return context
|
||||
# For summarization, we don't allow zero values for keep
|
||||
return _validate_context_size_base(context, parameter_name, allow_zero_for_keep=False)
|
||||
|
||||
def _build_new_messages(self, summary: str) -> list[HumanMessage]:
|
||||
return [
|
||||
|
||||
@@ -4,6 +4,8 @@ from __future__ import annotations
|
||||
|
||||
from typing import Iterable, cast
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.agents.middleware.context_editing import (
|
||||
ClearToolUsesEdit,
|
||||
ContextEditingMiddleware,
|
||||
@@ -399,3 +401,105 @@ async def test_exclude_tools_prevents_clearing_async() -> None:
|
||||
|
||||
assert isinstance(calc_tool, ToolMessage)
|
||||
assert calc_tool.content == "[cleared]"
|
||||
|
||||
|
||||
def test_new_api_with_context_size_tuples() -> None:
    """Test the new API with ContextSize tuples."""
    call_id = "call-1"
    ai_msg = AIMessage(
        content="",
        tool_calls=[{"id": call_id, "name": "search", "args": {}}],
    )
    tool_msg = ToolMessage(content="x" * 200, tool_call_id=call_id)

    state, request = _make_state_and_request([ai_msg, tool_msg])

    # Message-count based trigger and retention: 2 messages trips the edit,
    # keep 0 means every tool result is eligible for clearing.
    middleware = ContextEditingMiddleware(
        edits=[
            ClearToolUsesEdit(
                trigger=("messages", 2),
                keep=("messages", 0),
                placeholder="[cleared]",
            )
        ]
    )

    def mock_handler(req: ModelRequest) -> AIMessage:
        return AIMessage(content="mock response")

    middleware.wrap_model_call(request, mock_handler)

    cleared = request.messages[1]
    assert isinstance(cleared, ToolMessage)
    assert cleared.content == "[cleared]"
|
||||
|
||||
|
||||
def test_multiple_trigger_conditions() -> None:
    """Test multiple trigger conditions (OR logic)."""
    call_id = "call-1"
    ai_msg = AIMessage(
        content="",
        tool_calls=[{"id": call_id, "name": "search", "args": {}}],
    )
    tool_msg = ToolMessage(content="x" * 200, tool_call_id=call_id)

    state, request = _make_state_and_request([ai_msg, tool_msg])

    # With multiple triggers the edit fires when ANY threshold is breached:
    # the message threshold (10) is not met, but the token threshold (50) is.
    middleware = ContextEditingMiddleware(
        edits=[
            ClearToolUsesEdit(
                trigger=[("messages", 10), ("tokens", 50)],
                keep=("messages", 0),
                placeholder="[cleared]",
            )
        ]
    )

    def mock_handler(req: ModelRequest) -> AIMessage:
        return AIMessage(content="mock response")

    middleware.wrap_model_call(request, mock_handler)

    cleared = request.messages[1]
    assert isinstance(cleared, ToolMessage)
    assert cleared.content == "[cleared]"
|
||||
|
||||
|
||||
def test_backwards_compatibility_deprecation_warnings() -> None:
    """Test that integer parameters raise deprecation warnings."""
    import warnings

    # Legacy int trigger is interpreted as a token count.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        edit = ClearToolUsesEdit(trigger=100_000)
    assert len(caught) == 1
    assert issubclass(caught[0].category, DeprecationWarning)
    assert "trigger=100000 (int) is deprecated" in str(caught[0].message)
    assert edit.trigger == ("tokens", 100_000)

    # Legacy int keep is interpreted as a message count.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        edit = ClearToolUsesEdit(keep=5)
    assert len(caught) == 1
    assert issubclass(caught[0].category, DeprecationWarning)
    assert "keep=5 (int) is deprecated" in str(caught[0].message)
    assert edit.keep == ("messages", 5)
|
||||
|
||||
|
||||
# Each case: a constructor argument and the ValueError message (regex fragment)
# it must raise with. Exercises the shared _context validator through
# ClearToolUsesEdit, where keep=("messages", 0) is valid but negatives are not.
@pytest.mark.parametrize(
    ("param_name", "param_value", "expected_error"),
    [
        # Fractions must lie in (0, 1].
        ("trigger", ("fraction", 1.5), "Fractional trigger values must be between 0 and 1"),
        ("trigger", ("fraction", 0.0), "Fractional trigger values must be between 0 and 1"),
        # Triggers must be >= 1.
        ("trigger", ("tokens", 0), "trigger thresholds must be >= 1"),
        ("trigger", ("messages", -5), "trigger thresholds must be >= 1"),
        # Context-editing keep allows zero but not negatives.
        ("keep", ("messages", -1), "keep thresholds must be >= 0"),
        ("keep", ("fraction", -0.1), "Fractional keep values must be between 0 and 1"),
        # Unknown context size kinds are rejected outright.
        ("trigger", ("invalid", 100), "Unsupported context size type"),
        ("keep", ("invalid", 100), "Unsupported context size type"),
    ],
)
def test_validation_errors(
    param_name: str, param_value: tuple[str, float | int], expected_error: str
) -> None:
    """Test validation of ContextSize parameters."""
    with pytest.raises(ValueError, match=expected_error):
        ClearToolUsesEdit(**{param_name: param_value})  # type: ignore[arg-type]
|
||||
|
||||
@@ -531,9 +531,9 @@ def test_summarization_middleware_keep_messages() -> None:
|
||||
("trigger", ("fraction", 0.0), "Fractional trigger values must be between 0 and 1"),
|
||||
("trigger", ("fraction", 1.5), "Fractional trigger values must be between 0 and 1"),
|
||||
("keep", ("fraction", -0.1), "Fractional keep values must be between 0 and 1"),
|
||||
("trigger", ("tokens", 0), "trigger thresholds must be greater than 0"),
|
||||
("trigger", ("messages", -5), "trigger thresholds must be greater than 0"),
|
||||
("keep", ("tokens", 0), "keep thresholds must be greater than 0"),
|
||||
("trigger", ("tokens", 0), "trigger thresholds must be >= 1"),
|
||||
("trigger", ("messages", -5), "trigger thresholds must be >= 1"),
|
||||
("keep", ("tokens", 0), "keep thresholds must be >= 1"),
|
||||
("trigger", ("invalid", 100), "Unsupported context size type"),
|
||||
("keep", ("invalid", 100), "Unsupported context size type"),
|
||||
],
|
||||
|
||||
2
libs/langchain_v1/uv.lock
generated
2
libs/langchain_v1/uv.lock
generated
@@ -2395,7 +2395,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "langchain-openai"
|
||||
version = "1.0.2"
|
||||
version = "1.0.3"
|
||||
source = { editable = "../partners/openai" }
|
||||
dependencies = [
|
||||
{ name = "langchain-core" },
|
||||
|
||||
Reference in New Issue
Block a user