Compare commits

...

4 Commits

Author SHA1 Message Date
Sydney Runkle
333b074216 shared 2025-11-13 13:32:36 -05:00
Sydney Runkle
8d3b1c7985 Update libs/langchain_v1/langchain/agents/middleware/context_editing.py 2025-11-13 13:28:39 -05:00
Sydney Runkle
236d013177 use dc still 2025-11-13 13:27:46 -05:00
Sydney Runkle
7b73353093 context editing updates 2025-11-13 13:20:06 -05:00
4 changed files with 450 additions and 29 deletions

View File

@@ -9,9 +9,10 @@ chat model.
from __future__ import annotations
from collections.abc import Awaitable, Callable, Iterable, Sequence
import warnings
from collections.abc import Awaitable, Callable, Iterable, Mapping, Sequence
from dataclasses import dataclass
from typing import Literal
from typing import TYPE_CHECKING, Literal
from langchain_core.messages import (
AIMessage,
@@ -23,8 +24,12 @@ from langchain_core.messages import (
from langchain_core.messages.utils import count_tokens_approximately
from typing_extensions import Protocol
if TYPE_CHECKING:
from langchain_core.language_models import BaseChatModel
from langchain.agents.middleware.types import (
AgentMiddleware,
ContextSize,
ModelCallResult,
ModelRequest,
ModelResponse,
@@ -32,6 +37,8 @@ from langchain.agents.middleware.types import (
DEFAULT_TOOL_PLACEHOLDER = "[cleared]"
_DEFAULT_TRIGGER_TOKENS = 100_000
_DEFAULT_KEEP = 3
TokenCounter = Callable[
[Sequence[BaseMessage]],
@@ -52,27 +59,109 @@ class ContextEdit(Protocol):
...
@dataclass(slots=True)
@dataclass(slots=True, init=False)
class ClearToolUsesEdit(ContextEdit):
"""Configuration for clearing tool outputs when token limits are exceeded."""
trigger: int = 100_000
"""Token count that triggers the edit."""
trigger: ContextSize | list[ContextSize]
clear_at_least: int
keep: ContextSize
clear_tool_inputs: bool
exclude_tools: Sequence[str]
placeholder: str
model: BaseChatModel | None
_trigger_conditions: list[ContextSize]
clear_at_least: int = 0
"""Minimum number of tokens to reclaim when the edit runs."""
def __init__(
self,
*,
trigger: ContextSize | list[ContextSize] | int | None = None,
clear_at_least: int = 0,
keep: ContextSize | int = ("messages", _DEFAULT_KEEP),
clear_tool_inputs: bool = False,
exclude_tools: Sequence[str] = (),
placeholder: str = DEFAULT_TOOL_PLACEHOLDER,
model: BaseChatModel | None = None,
) -> None:
"""Initialize the clear tool uses edit.
keep: int = 3
"""Number of most recent tool results that must be preserved."""
Args:
trigger: One or more thresholds that trigger context editing. Provide a single
`ContextSize` tuple or a list of tuples, in which case editing runs when any
threshold is breached. Examples: `("messages", 50)`, `("tokens", 3000)`,
`[("fraction", 0.8), ("messages", 100)]`. Defaults to `("tokens", 100000)`.
clear_at_least: Minimum number of tokens to reclaim when the edit runs.
keep: Context retention policy applied after editing. Provide a `ContextSize` tuple
to specify how many tool results to preserve. Defaults to keeping the most recent
3 tool results. Examples: `("messages", 3)`, `("tokens", 3000)`, or
`("fraction", 0.3)`.
clear_tool_inputs: Whether to clear the originating tool call parameters on the AI
message.
exclude_tools: List of tool names to exclude from clearing.
placeholder: Placeholder text inserted for cleared tool outputs.
model: Optional chat model for model profile information. Required when using
fractional triggers or keep values.
"""
# Handle deprecated int-based parameters for trigger
if isinstance(trigger, int):
value = trigger
warnings.warn(
"Passing trigger as int is deprecated. Use trigger=('tokens', value) instead.",
DeprecationWarning,
stacklevel=2,
)
trigger = ("tokens", value)
clear_tool_inputs: bool = False
"""Whether to clear the originating tool call parameters on the AI message."""
# Handle deprecated int-based parameters for keep
if isinstance(keep, int):
value = keep
warnings.warn(
"Passing keep as int is deprecated. Use keep=('messages', value) instead.",
DeprecationWarning,
stacklevel=2,
)
keep = ("messages", value)
exclude_tools: Sequence[str] = ()
"""List of tool names to exclude from clearing."""
# Set default trigger if not provided
if trigger is None:
trigger = ("tokens", _DEFAULT_TRIGGER_TOKENS)
placeholder: str = DEFAULT_TOOL_PLACEHOLDER
"""Placeholder text inserted for cleared tool outputs."""
# Validate and store trigger conditions
# At this point, trigger is guaranteed to be ContextSize | list[ContextSize]
# (None was handled above, int was converted to tuple)
trigger_normalized: ContextSize | list[ContextSize] = trigger # type: ignore[assignment]
if isinstance(trigger_normalized, list):
validated_list = [
self._validate_context_size(item, "trigger") for item in trigger_normalized
]
self.trigger = validated_list
trigger_conditions: list[ContextSize] = validated_list
else:
validated = self._validate_context_size(trigger_normalized, "trigger")
self.trigger = validated
trigger_conditions = [validated]
self._trigger_conditions = trigger_conditions
self.clear_at_least = clear_at_least
# At this point, keep is guaranteed to be ContextSize (int was converted to tuple)
keep_normalized: ContextSize = keep # type: ignore[assignment]
self.keep = self._validate_context_size(keep_normalized, "keep")
self.clear_tool_inputs = clear_tool_inputs
self.exclude_tools = exclude_tools
self.placeholder = placeholder
self.model = model
# Check if model profile is required
requires_profile = any(condition[0] == "fraction" for condition in self._trigger_conditions)
if self.keep[0] == "fraction":
requires_profile = True
if requires_profile and model is not None and self._get_profile_limits(model) is None:
msg = (
"Model profile information is required to use fractional token limits. "
'pip install "langchain[model-profiles]" or use absolute token counts '
"instead."
)
raise ValueError(msg)
def apply(
self,
@@ -81,19 +170,26 @@ class ClearToolUsesEdit(ContextEdit):
count_tokens: TokenCounter,
) -> None:
"""Apply the clear-tool-uses strategy."""
tokens = count_tokens(messages)
total_tokens = count_tokens(messages)
if tokens <= self.trigger:
if not self._should_edit(messages, total_tokens):
return
# Find all tool message candidates
candidates = [
(idx, msg) for idx, msg in enumerate(messages) if isinstance(msg, ToolMessage)
]
if self.keep >= len(candidates):
if not candidates:
return
# Determine how many to keep based on keep policy
keep_count = self._determine_keep_count(candidates, count_tokens)
if keep_count >= len(candidates):
candidates = []
elif self.keep:
candidates = candidates[: -self.keep]
else:
candidates = candidates[:-keep_count] if keep_count > 0 else candidates
cleared_tokens = 0
excluded_tools = set(self.exclude_tools)
@@ -146,12 +242,89 @@ class ClearToolUsesEdit(ContextEdit):
if self.clear_at_least > 0:
new_token_count = count_tokens(messages)
cleared_tokens = max(0, tokens - new_token_count)
cleared_tokens = max(0, total_tokens - new_token_count)
if cleared_tokens >= self.clear_at_least:
break
return
def _should_edit(self, messages: list[AnyMessage], total_tokens: int) -> bool:
    """Return True when any configured trigger threshold has been reached.

    Checks each trigger condition in order; the first one that is satisfied
    short-circuits the scan. Fractional triggers are silently skipped when no
    model (or no usable model profile) is available.
    """
    for condition_kind, threshold_value in self._trigger_conditions:
        if condition_kind == "messages":
            if len(messages) >= threshold_value:
                return True
        elif condition_kind == "tokens":
            if total_tokens >= threshold_value:
                return True
        elif condition_kind == "fraction":
            # Fractional triggers need the model's max-input-token limit.
            if self.model is None:
                continue
            limit = self._get_profile_limits(self.model)
            if limit is None:
                continue
            # Clamp to at least 1 so tiny fractions still form a real threshold.
            token_threshold = max(1, int(limit * threshold_value))
            if total_tokens >= token_threshold:
                return True
    return False
def _determine_keep_count(
    self,
    candidates: list[tuple[int, ToolMessage]],  # noqa: ARG002
    count_tokens: TokenCounter,  # noqa: ARG002
) -> int:
    """Resolve the keep policy into the number of tool results to preserve.

    The candidates and count_tokens arguments are accepted but unused today;
    they are reserved for a future token-aware retention strategy.
    """
    policy_kind, policy_value = self.keep
    if policy_kind == "messages":
        return int(policy_value)
    if policy_kind == "tokens":
        # Simplified: a token budget is currently treated as a message count.
        return int(policy_value)
    if policy_kind == "fraction":
        # No per-message token accounting yet; fall back to the default.
        return _DEFAULT_KEEP
    return _DEFAULT_KEEP
def _get_profile_limits(self, model: BaseChatModel) -> int | None:
    """Return the model profile's max input token limit, or None if unknown."""
    try:
        profile = model.profile
    except (AttributeError, ImportError):
        # Profile access may be missing entirely or require an optional extra.
        return None
    if not isinstance(profile, Mapping):
        return None
    limit = profile.get("max_input_tokens")
    # Only accept an explicit integer limit; anything else is unusable.
    return limit if isinstance(limit, int) else None
def _validate_context_size(self, context: ContextSize, parameter_name: str) -> ContextSize:
    """Check a (kind, value) context tuple and return it unchanged if valid.

    Raises:
        ValueError: If the kind is unsupported or the value is out of range
            for the given parameter (`"trigger"` or `"keep"`).
    """
    size_kind, size_value = context
    if size_kind == "fraction":
        if not 0 < size_value <= 1:
            msg = f"Fractional {parameter_name} values must be between 0 and 1, got {size_value}."
            raise ValueError(msg)
    elif size_kind in ("tokens", "messages"):
        # Triggers must be strictly positive; keep may be zero ("clear all").
        if parameter_name == "trigger" and size_value <= 0:
            msg = f"{parameter_name} thresholds must be greater than 0, got {size_value}."
            raise ValueError(msg)
        if parameter_name == "keep" and size_value < 0:
            msg = f"{parameter_name} values must be non-negative, got {size_value}."
            raise ValueError(msg)
    else:
        msg = f"Unsupported context size type {size_kind} for {parameter_name}."
        raise ValueError(msg)
    return context
def _build_cleared_tool_input_message(
self,
message: AIMessage,

View File

@@ -3,7 +3,7 @@
import uuid
import warnings
from collections.abc import Callable, Iterable, Mapping
from typing import Any, Literal, cast
from typing import Any, cast
from langchain_core.messages import (
AIMessage,
@@ -19,7 +19,11 @@ from langgraph.graph.message import (
)
from langgraph.runtime import Runtime
from langchain.agents.middleware.types import AgentMiddleware, AgentState
from langchain.agents.middleware.types import (
AgentMiddleware,
AgentState,
ContextSize,
)
from langchain.chat_models import BaseChatModel, init_chat_model
TokenCounter = Callable[[Iterable[MessageLikeRepresentation]], int]
@@ -57,12 +61,6 @@ _DEFAULT_TRIM_TOKEN_LIMIT = 4000
_DEFAULT_FALLBACK_MESSAGE_COUNT = 15
_SEARCH_RANGE_FOR_TOOL_PAIRS = 5
ContextFraction = tuple[Literal["fraction"], float]
ContextTokens = tuple[Literal["tokens"], int]
ContextMessages = tuple[Literal["messages"], int]
ContextSize = ContextFraction | ContextTokens | ContextMessages
class SummarizationMiddleware(AgentMiddleware):
"""Summarizes conversation history when token limits are approached.

View File

@@ -45,7 +45,11 @@ if TYPE_CHECKING:
__all__ = [
"AgentMiddleware",
"AgentState",
"ContextFraction",
"ContextMessages",
"ContextSize",
"ContextT",
"ContextTokens",
"ModelRequest",
"ModelResponse",
"OmitFromSchema",
@@ -65,6 +69,26 @@ __all__ = [
JumpTo = Literal["tools", "model", "end"]
"""Destination to jump to when a middleware node returns."""
ContextFraction = tuple[Literal["fraction"], float]
"""Context size specified as a fraction (0.0 to 1.0) of the model's maximum tokens."""
ContextTokens = tuple[Literal["tokens"], int]
"""Context size specified as an absolute token count."""
ContextMessages = tuple[Literal["messages"], int]
"""Context size specified as a message count."""
ContextSize = ContextFraction | ContextTokens | ContextMessages
"""Union type for context size configurations.
Can be one of:
- `ContextFraction`: Fractional limits (0.0 to 1.0), e.g., `("fraction", 0.8)`
for 80% of model's max tokens
- `ContextTokens`: Absolute token counts, e.g., `("tokens", 100000)`
- `ContextMessages`: Message count limits, e.g., `("messages", 50)`
"""
ResponseT = TypeVar("ResponseT")

View File

@@ -399,3 +399,229 @@ async def test_exclude_tools_prevents_clearing_async() -> None:
assert isinstance(calc_tool, ToolMessage)
assert calc_tool.content == "[cleared]"
def test_new_api_with_tokens_trigger_and_messages_keep() -> None:
    """Token-based trigger with message-based keep clears all but the newest results."""
    messages: list[AIMessage | ToolMessage] = []
    for idx in range(5):
        tool_call_id = f"call-{idx}"
        messages.append(
            AIMessage(
                content="",
                tool_calls=[{"id": tool_call_id, "name": "tool", "args": {"input": tool_call_id}}],
            )
        )
        messages.append(ToolMessage(content="x" * 50, tool_call_id=tool_call_id))

    state, request = _make_state_and_request(messages)

    middleware = ContextEditingMiddleware(
        edits=[
            ClearToolUsesEdit(
                trigger=("tokens", 100),
                keep=("messages", 2),
                placeholder="[cleared]",
            )
        ]
    )

    middleware.wrap_model_call(request, lambda req: AIMessage(content="mock response"))

    tool_messages = [m for m in request.messages if isinstance(m, ToolMessage)]
    cleared = [m for m in tool_messages if m.content == "[cleared]"]

    # 3 of 5 tool results cleared; the 2 most recent survive untouched.
    assert len(cleared) == 3
    assert len(tool_messages) == 5
    assert tool_messages[-1].content != "[cleared]"
    assert tool_messages[-2].content != "[cleared]"
def test_new_api_with_messages_trigger() -> None:
    """Message-count trigger fires once the conversation is long enough."""
    messages: list[AIMessage | ToolMessage] = []
    for idx in range(10):
        tool_call_id = f"call-{idx}"
        messages.append(
            AIMessage(
                content="",
                tool_calls=[{"id": tool_call_id, "name": "tool", "args": {"input": tool_call_id}}],
            )
        )
        messages.append(ToolMessage(content="result", tool_call_id=tool_call_id))

    state, request = _make_state_and_request(messages)

    # Fires at >= 15 messages; the conversation has 20 (10 AI + 10 tool).
    middleware = ContextEditingMiddleware(
        edits=[
            ClearToolUsesEdit(
                trigger=("messages", 15),
                keep=("messages", 3),
                placeholder="[cleared]",
            )
        ]
    )

    middleware.wrap_model_call(request, lambda req: AIMessage(content="mock response"))

    cleared = [
        m for m in request.messages if isinstance(m, ToolMessage) and m.content == "[cleared]"
    ]
    # 7 of 10 tool results cleared; the newest 3 are kept.
    assert len(cleared) == 7
def test_new_api_with_multiple_triggers() -> None:
    """Editing runs when any one of several trigger conditions is breached."""
    messages: list[AIMessage | ToolMessage] = []
    for idx in range(3):
        tool_call_id = f"call-{idx}"
        messages.append(
            AIMessage(
                content="",
                tool_calls=[{"id": tool_call_id, "name": "tool", "args": {"input": tool_call_id}}],
            )
        )
        messages.append(ToolMessage(content="x" * 100, tool_call_id=tool_call_id))

    state, request = _make_state_and_request(messages)

    # Either condition alone would fire: 6 messages (> 5) and > 200 tokens.
    middleware = ContextEditingMiddleware(
        edits=[
            ClearToolUsesEdit(
                trigger=[("messages", 5), ("tokens", 200)],
                keep=("messages", 1),
                placeholder="[cleared]",
            )
        ]
    )

    middleware.wrap_model_call(request, lambda req: AIMessage(content="mock response"))

    cleared = [
        m for m in request.messages if isinstance(m, ToolMessage) and m.content == "[cleared]"
    ]
    # 2 of 3 tool results cleared; the newest one is kept.
    assert len(cleared) == 2
def test_new_api_fractional_trigger_with_model_profile() -> None:
    """Fractional trigger resolves against the model profile's max input tokens."""
    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        from langchain_model_profiles import ModelProfile

    class ProfileModel(_TokenCountingChatModel):
        @property
        def profile(self) -> "ModelProfile":
            return {"max_input_tokens": 500}

    messages: list[AIMessage | ToolMessage] = []
    for idx in range(10):
        tool_call_id = f"call-{idx}"
        messages.append(
            AIMessage(
                content="x" * 50,
                tool_calls=[{"id": tool_call_id, "name": "tool", "args": {"input": tool_call_id}}],
            )
        )
        messages.append(ToolMessage(content="x" * 50, tool_call_id=tool_call_id))

    model = ProfileModel()
    state = cast("AgentState", {"messages": messages})
    request = ModelRequest(
        model=model,
        system_prompt=None,
        messages=messages,
        tool_choice=None,
        tools=[],
        response_format=None,
        state=state,
        runtime=_fake_runtime(),
        model_settings={},
    )

    # Threshold is 80% of 500 = 400 tokens; the conversation totals ~1000.
    middleware = ContextEditingMiddleware(
        edits=[
            ClearToolUsesEdit(
                trigger=("fraction", 0.8),
                keep=("messages", 3),
                placeholder="[cleared]",
                model=model,
            )
        ]
    )

    middleware.wrap_model_call(request, lambda req: AIMessage(content="mock response"))

    cleared = [
        m for m in request.messages if isinstance(m, ToolMessage) and m.content == "[cleared]"
    ]
    # 7 of 10 tool results cleared; the newest 3 are kept.
    assert len(cleared) == 7
def test_validation_errors() -> None:
    """Invalid trigger/keep configurations raise ValueError at construction."""
    import pytest

    # Fraction above the (0, 1] range.
    with pytest.raises(ValueError, match="Fractional trigger values must be between 0 and 1"):
        ClearToolUsesEdit(trigger=("fraction", 1.5))

    # Fraction at zero is rejected as well.
    with pytest.raises(ValueError, match="Fractional trigger values must be between 0 and 1"):
        ClearToolUsesEdit(trigger=("fraction", 0))

    # Token triggers must be strictly positive.
    with pytest.raises(ValueError, match="trigger thresholds must be greater than 0"):
        ClearToolUsesEdit(trigger=("tokens", 0))

    # Message triggers must be strictly positive too.
    with pytest.raises(ValueError, match="trigger thresholds must be greater than 0"):
        ClearToolUsesEdit(trigger=("messages", -1))

    # keep may be zero but never negative.
    with pytest.raises(ValueError, match="keep values must be non-negative"):
        ClearToolUsesEdit(keep=("messages", -1))

    # Unknown context-size kinds are rejected outright.
    with pytest.raises(ValueError, match="Unsupported context size type"):
        ClearToolUsesEdit(trigger=("invalid", 100))  # type: ignore[arg-type]
def test_deprecation_warnings() -> None:
    """Legacy int-valued arguments still work but emit DeprecationWarning."""
    import pytest

    with pytest.warns(DeprecationWarning, match="Passing trigger as int is deprecated"):
        ClearToolUsesEdit(trigger=100)

    with pytest.warns(DeprecationWarning, match="Passing keep as int is deprecated"):
        ClearToolUsesEdit(keep=5)