Mirror of https://github.com/hwchase17/langchain.git (synced 2026-01-21 21:56:38 +00:00)

Compare commits: 4 commits on langchain-… ... sr/refacto…
333b074216, 8d3b1c7985, 236d013177, 7b73353093
@@ -9,9 +9,10 @@ chat model.
 
 from __future__ import annotations
 
-from collections.abc import Awaitable, Callable, Iterable, Sequence
+import warnings
+from collections.abc import Awaitable, Callable, Iterable, Mapping, Sequence
 from dataclasses import dataclass
-from typing import Literal
+from typing import TYPE_CHECKING, Literal
 
 from langchain_core.messages import (
     AIMessage,
@@ -23,8 +24,12 @@ from langchain_core.messages import (
 from langchain_core.messages.utils import count_tokens_approximately
 from typing_extensions import Protocol
 
+if TYPE_CHECKING:
+    from langchain_core.language_models import BaseChatModel
+
 from langchain.agents.middleware.types import (
     AgentMiddleware,
+    ContextSize,
     ModelCallResult,
     ModelRequest,
     ModelResponse,
@@ -32,6 +37,8 @@ from langchain.agents.middleware.types import (
 
 DEFAULT_TOOL_PLACEHOLDER = "[cleared]"
 
+_DEFAULT_TRIGGER_TOKENS = 100_000
+_DEFAULT_KEEP = 3
+
 TokenCounter = Callable[
     [Sequence[BaseMessage]],
     int,
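
These two constants back the defaults of the reworked `ClearToolUsesEdit` below. A quick illustrative sketch (not part of the diff) of what an argument-free construction resolves to; the import path is an assumption, since the page does not show file names:

```python
# Assumed import path; the class itself is defined in the next hunk.
from langchain.agents.middleware.context_editing import ClearToolUsesEdit

# trigger=None falls back to ("tokens", _DEFAULT_TRIGGER_TOKENS) and
# keep defaults to ("messages", _DEFAULT_KEEP) in the __init__ added below.
edit = ClearToolUsesEdit()
assert edit.trigger == ("tokens", 100_000)
assert edit.keep == ("messages", 3)
```
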
@@ -52,27 +59,109 @@ class ContextEdit(Protocol):
         ...
 
 
-@dataclass(slots=True)
+@dataclass(slots=True, init=False)
 class ClearToolUsesEdit(ContextEdit):
     """Configuration for clearing tool outputs when token limits are exceeded."""
 
-    trigger: int = 100_000
-    """Token count that triggers the edit."""
-
-    clear_at_least: int = 0
-    """Minimum number of tokens to reclaim when the edit runs."""
-
-    keep: int = 3
-    """Number of most recent tool results that must be preserved."""
-
-    clear_tool_inputs: bool = False
-    """Whether to clear the originating tool call parameters on the AI message."""
-
-    exclude_tools: Sequence[str] = ()
-    """List of tool names to exclude from clearing."""
-
-    placeholder: str = DEFAULT_TOOL_PLACEHOLDER
-    """Placeholder text inserted for cleared tool outputs."""
+    trigger: ContextSize | list[ContextSize]
+    clear_at_least: int
+    keep: ContextSize
+    clear_tool_inputs: bool
+    exclude_tools: Sequence[str]
+    placeholder: str
+    model: BaseChatModel | None
+    _trigger_conditions: list[ContextSize]
+
+    def __init__(
+        self,
+        *,
+        trigger: ContextSize | list[ContextSize] | int | None = None,
+        clear_at_least: int = 0,
+        keep: ContextSize | int = ("messages", _DEFAULT_KEEP),
+        clear_tool_inputs: bool = False,
+        exclude_tools: Sequence[str] = (),
+        placeholder: str = DEFAULT_TOOL_PLACEHOLDER,
+        model: BaseChatModel | None = None,
+    ) -> None:
+        """Initialize the clear tool uses edit.
+
+        Args:
+            trigger: One or more thresholds that trigger context editing. Provide a single
+                `ContextSize` tuple or a list of tuples, in which case editing runs when any
+                threshold is breached. Examples: `("messages", 50)`, `("tokens", 3000)`,
+                `[("fraction", 0.8), ("messages", 100)]`. Defaults to `("tokens", 100000)`.
+            clear_at_least: Minimum number of tokens to reclaim when the edit runs.
+            keep: Context retention policy applied after editing. Provide a `ContextSize` tuple
+                to specify how many tool results to preserve. Defaults to keeping the most recent
+                3 tool results. Examples: `("messages", 3)`, `("tokens", 3000)`, or
+                `("fraction", 0.3)`.
+            clear_tool_inputs: Whether to clear the originating tool call parameters on the AI
+                message.
+            exclude_tools: List of tool names to exclude from clearing.
+            placeholder: Placeholder text inserted for cleared tool outputs.
+            model: Optional chat model for model profile information. Required when using
+                fractional triggers or keep values.
+        """
+        # Handle deprecated int-based parameters for trigger
+        if isinstance(trigger, int):
+            value = trigger
+            warnings.warn(
+                "Passing trigger as int is deprecated. Use trigger=('tokens', value) instead.",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+            trigger = ("tokens", value)
+
+        # Handle deprecated int-based parameters for keep
+        if isinstance(keep, int):
+            value = keep
+            warnings.warn(
+                "Passing keep as int is deprecated. Use keep=('messages', value) instead.",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+            keep = ("messages", value)
+
+        # Set default trigger if not provided
+        if trigger is None:
+            trigger = ("tokens", _DEFAULT_TRIGGER_TOKENS)
+
+        # Validate and store trigger conditions
+        # At this point, trigger is guaranteed to be ContextSize | list[ContextSize]
+        # (None was handled above, int was converted to tuple)
+        trigger_normalized: ContextSize | list[ContextSize] = trigger  # type: ignore[assignment]
+        if isinstance(trigger_normalized, list):
+            validated_list = [
+                self._validate_context_size(item, "trigger") for item in trigger_normalized
+            ]
+            self.trigger = validated_list
+            trigger_conditions: list[ContextSize] = validated_list
+        else:
+            validated = self._validate_context_size(trigger_normalized, "trigger")
+            self.trigger = validated
+            trigger_conditions = [validated]
+        self._trigger_conditions = trigger_conditions
+
+        self.clear_at_least = clear_at_least
+        # At this point, keep is guaranteed to be ContextSize (int was converted to tuple)
+        keep_normalized: ContextSize = keep  # type: ignore[assignment]
+        self.keep = self._validate_context_size(keep_normalized, "keep")
+        self.clear_tool_inputs = clear_tool_inputs
+        self.exclude_tools = exclude_tools
+        self.placeholder = placeholder
+        self.model = model
+
+        # Check if model profile is required
+        requires_profile = any(condition[0] == "fraction" for condition in self._trigger_conditions)
+        if self.keep[0] == "fraction":
+            requires_profile = True
+        if requires_profile and model is not None and self._get_profile_limits(model) is None:
+            msg = (
+                "Model profile information is required to use fractional token limits. "
+                'pip install "langchain[model-profiles]" or use absolute token counts '
+                "instead."
+            )
+            raise ValueError(msg)
 
     def apply(
         self,
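
A hedged usage sketch of the constructor API introduced above (not part of the diff; the import path is assumed, and "search" is a made-up tool name):

```python
from langchain.agents.middleware.context_editing import ClearToolUsesEdit  # assumed path

# New tuple-based configuration: editing runs when either threshold is breached,
# and the 3 most recent tool results are preserved.
edit = ClearToolUsesEdit(
    trigger=[("tokens", 3000), ("messages", 100)],
    keep=("messages", 3),
    exclude_tools=("search",),  # hypothetical tool name, never cleared
    placeholder="[cleared]",
)

# Old int-based arguments still work but emit DeprecationWarning and are converted:
# trigger=100_000 -> ("tokens", 100_000), keep=5 -> ("messages", 5).
legacy = ClearToolUsesEdit(trigger=100_000, keep=5)
```
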
@@ -81,19 +170,26 @@ class ClearToolUsesEdit(ContextEdit):
         count_tokens: TokenCounter,
     ) -> None:
         """Apply the clear-tool-uses strategy."""
-        tokens = count_tokens(messages)
+        total_tokens = count_tokens(messages)
 
-        if tokens <= self.trigger:
+        if not self._should_edit(messages, total_tokens):
             return
 
+        # Find all tool message candidates
         candidates = [
            (idx, msg) for idx, msg in enumerate(messages) if isinstance(msg, ToolMessage)
         ]
 
-        if self.keep >= len(candidates):
+        if not candidates:
+            return
+
+        # Determine how many to keep based on keep policy
+        keep_count = self._determine_keep_count(candidates, count_tokens)
+
+        if keep_count >= len(candidates):
             candidates = []
-        elif self.keep:
-            candidates = candidates[: -self.keep]
+        else:
+            candidates = candidates[:-keep_count] if keep_count > 0 else candidates
 
         cleared_tokens = 0
         excluded_tools = set(self.exclude_tools)
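
The reworked `apply` above is exercised through `ContextEditingMiddleware`, as the tests at the bottom of this page do. A minimal wiring sketch (import paths are assumptions; the diff shows no file names):

```python
from langchain.agents.middleware.context_editing import (  # assumed path
    ClearToolUsesEdit,
    ContextEditingMiddleware,
)

# As the tests below suggest, the middleware applies each configured edit to the
# message list before a model call; ClearToolUsesEdit.apply then uses _should_edit
# to decide whether tool outputs get replaced with the placeholder.
middleware = ContextEditingMiddleware(
    edits=[ClearToolUsesEdit(trigger=("tokens", 100_000), keep=("messages", 3))]
)
```
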
@@ -146,12 +242,89 @@ class ClearToolUsesEdit(ContextEdit):
 
             if self.clear_at_least > 0:
                 new_token_count = count_tokens(messages)
-                cleared_tokens = max(0, tokens - new_token_count)
+                cleared_tokens = max(0, total_tokens - new_token_count)
                 if cleared_tokens >= self.clear_at_least:
                     break
 
         return
 
+    def _should_edit(self, messages: list[AnyMessage], total_tokens: int) -> bool:
+        """Determine whether editing should run for the current token usage."""
+        for kind, value in self._trigger_conditions:
+            if kind == "messages" and len(messages) >= value:
+                return True
+            if kind == "tokens" and total_tokens >= value:
+                return True
+            if kind == "fraction":
+                if self.model is None:
+                    continue
+                max_input_tokens = self._get_profile_limits(self.model)
+                if max_input_tokens is None:
+                    continue
+                threshold = int(max_input_tokens * value)
+                if threshold <= 0:
+                    threshold = 1
+                if total_tokens >= threshold:
+                    return True
+        return False
+
+    def _determine_keep_count(
+        self,
+        candidates: list[tuple[int, ToolMessage]],  # noqa: ARG002
+        count_tokens: TokenCounter,  # noqa: ARG002
+    ) -> int:
+        """Determine how many tool results to keep based on keep policy.
+
+        Note: candidates and count_tokens are currently unused but reserved for future
+        enhancement to support token-based retention counting.
+        """
+        kind, value = self.keep
+        if kind == "messages":
+            return int(value)
+        if kind in {"tokens", "fraction"}:
+            # For token-based or fraction-based keep, we need to count backwards
+            # This is a simplified implementation - keeping N most recent tool messages
+            # A more sophisticated implementation would count actual tokens
+            return int(value) if kind == "tokens" else _DEFAULT_KEEP
+        return _DEFAULT_KEEP
+
+    def _get_profile_limits(self, model: BaseChatModel) -> int | None:
+        """Retrieve max input token limit from the model profile."""
+        try:
+            profile = model.profile
+        except (AttributeError, ImportError):
+            return None
+
+        if not isinstance(profile, Mapping):
+            return None
+
+        max_input_tokens = profile.get("max_input_tokens")
+
+        if not isinstance(max_input_tokens, int):
+            return None
+
+        return max_input_tokens
+
+    def _validate_context_size(self, context: ContextSize, parameter_name: str) -> ContextSize:
+        """Validate context configuration tuples."""
+        kind, value = context
+        if kind == "fraction":
+            if not 0 < value <= 1:
+                msg = f"Fractional {parameter_name} values must be between 0 and 1, got {value}."
+                raise ValueError(msg)
+        elif kind in {"tokens", "messages"}:
+            # For "keep", 0 is valid (clear all), for "trigger", must be > 0
+            if parameter_name == "trigger" and value <= 0:
+                msg = f"{parameter_name} thresholds must be greater than 0, got {value}."
+                raise ValueError(msg)
+            if parameter_name == "keep" and value < 0:
+                msg = f"{parameter_name} values must be non-negative, got {value}."
+                raise ValueError(msg)
+        else:
+            msg = f"Unsupported context size type {kind} for {parameter_name}."
+            raise ValueError(msg)
+        return context
+
     def _build_cleared_tool_input_message(
         self,
         message: AIMessage,
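
A worked sketch of the fractional-trigger arithmetic in `_should_edit` above, using the numbers from the model-profile test near the end of this page (illustrative only):

```python
# trigger=("fraction", 0.8) against a profile reporting max_input_tokens=500.
max_input_tokens = 500
fraction = 0.8
threshold = int(max_input_tokens * fraction)  # 400
if threshold <= 0:  # the diff clamps degenerate thresholds to 1
    threshold = 1

total_tokens = 1_000  # the test conversation counts roughly 1000 tokens
assert total_tokens >= threshold  # so the edit runs and tool outputs are cleared
```
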
@@ -3,7 +3,7 @@
 import uuid
 import warnings
 from collections.abc import Callable, Iterable, Mapping
-from typing import Any, Literal, cast
+from typing import Any, cast
 
 from langchain_core.messages import (
     AIMessage,
@@ -19,7 +19,11 @@ from langgraph.graph.message import (
 )
 from langgraph.runtime import Runtime
 
-from langchain.agents.middleware.types import AgentMiddleware, AgentState
+from langchain.agents.middleware.types import (
+    AgentMiddleware,
+    AgentState,
+    ContextSize,
+)
 from langchain.chat_models import BaseChatModel, init_chat_model
 
 TokenCounter = Callable[[Iterable[MessageLikeRepresentation]], int]
@@ -57,12 +61,6 @@ _DEFAULT_TRIM_TOKEN_LIMIT = 4000
 _DEFAULT_FALLBACK_MESSAGE_COUNT = 15
 _SEARCH_RANGE_FOR_TOOL_PAIRS = 5
 
-ContextFraction = tuple[Literal["fraction"], float]
-ContextTokens = tuple[Literal["tokens"], int]
-ContextMessages = tuple[Literal["messages"], int]
-
-ContextSize = ContextFraction | ContextTokens | ContextMessages
-
 
 class SummarizationMiddleware(AgentMiddleware):
     """Summarizes conversation history when token limits are approached.
@@ -45,7 +45,11 @@ if TYPE_CHECKING:
 __all__ = [
     "AgentMiddleware",
     "AgentState",
+    "ContextFraction",
+    "ContextMessages",
+    "ContextSize",
     "ContextT",
+    "ContextTokens",
     "ModelRequest",
     "ModelResponse",
     "OmitFromSchema",
@@ -65,6 +69,26 @@ __all__ = [
 JumpTo = Literal["tools", "model", "end"]
 """Destination to jump to when a middleware node returns."""
 
+ContextFraction = tuple[Literal["fraction"], float]
+"""Context size specified as a fraction (0.0 to 1.0) of the model's maximum tokens."""
+
+ContextTokens = tuple[Literal["tokens"], int]
+"""Context size specified as an absolute token count."""
+
+ContextMessages = tuple[Literal["messages"], int]
+"""Context size specified as a message count."""
+
+ContextSize = ContextFraction | ContextTokens | ContextMessages
+"""Union type for context size configurations.
+
+Can be one of:
+
+- `ContextFraction`: Fractional limits (0.0 to 1.0), e.g., `("fraction", 0.8)`
+  for 80% of model's max tokens
+- `ContextTokens`: Absolute token counts, e.g., `("tokens", 100000)`
+- `ContextMessages`: Message count limits, e.g., `("messages", 50)`
+"""
+
 ResponseT = TypeVar("ResponseT")
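
The aliases above give `ContextSize` three concrete tuple shapes. A small sketch (the module path is taken from the import hunk earlier; the variable names are illustrative):

```python
from langchain.agents.middleware.types import ContextSize

by_fraction: ContextSize = ("fraction", 0.8)   # 80% of the model's max input tokens
by_tokens: ContextSize = ("tokens", 100_000)   # absolute token budget
by_messages: ContextSize = ("messages", 50)    # message-count threshold
```
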
@@ -399,3 +399,229 @@ async def test_exclude_tools_prevents_clearing_async() -> None:
 
     assert isinstance(calc_tool, ToolMessage)
     assert calc_tool.content == "[cleared]"
+
+
+def test_new_api_with_tokens_trigger_and_messages_keep() -> None:
+    """Test new API with token-based trigger and message-based keep."""
+    conversation: list[AIMessage | ToolMessage] = []
+    for i in range(5):
+        call_id = f"call-{i}"
+        conversation.append(
+            AIMessage(
+                content="",
+                tool_calls=[{"id": call_id, "name": "tool", "args": {"input": call_id}}],
+            )
+        )
+        conversation.append(ToolMessage(content="x" * 50, tool_call_id=call_id))
+
+    state, request = _make_state_and_request(conversation)
+
+    edit = ClearToolUsesEdit(
+        trigger=("tokens", 100),
+        keep=("messages", 2),
+        placeholder="[cleared]",
+    )
+    middleware = ContextEditingMiddleware(edits=[edit])
+
+    def mock_handler(req: ModelRequest) -> AIMessage:
+        return AIMessage(content="mock response")
+
+    middleware.wrap_model_call(request, mock_handler)
+
+    # Count how many tool messages were cleared
+    cleared_messages = [
+        msg
+        for msg in request.messages
+        if isinstance(msg, ToolMessage) and msg.content == "[cleared]"
+    ]
+
+    # Should clear 3 out of 5 tool messages (keep 2)
+    assert len(cleared_messages) == 3
+
+    # Last 2 tool messages should not be cleared
+    tool_messages = [msg for msg in request.messages if isinstance(msg, ToolMessage)]
+    assert len(tool_messages) == 5
+    assert tool_messages[-1].content != "[cleared]"
+    assert tool_messages[-2].content != "[cleared]"
+
+
+def test_new_api_with_messages_trigger() -> None:
+    """Test new API with message count trigger."""
+    conversation: list[AIMessage | ToolMessage] = []
+    for i in range(10):
+        call_id = f"call-{i}"
+        conversation.append(
+            AIMessage(
+                content="",
+                tool_calls=[{"id": call_id, "name": "tool", "args": {"input": call_id}}],
+            )
+        )
+        conversation.append(ToolMessage(content="result", tool_call_id=call_id))
+
+    state, request = _make_state_and_request(conversation)
+
+    # Trigger when we have 15 or more messages
+    edit = ClearToolUsesEdit(
+        trigger=("messages", 15),
+        keep=("messages", 3),
+        placeholder="[cleared]",
+    )
+    middleware = ContextEditingMiddleware(edits=[edit])
+
+    def mock_handler(req: ModelRequest) -> AIMessage:
+        return AIMessage(content="mock response")
+
+    middleware.wrap_model_call(request, mock_handler)
+
+    # Should trigger because we have 20 messages (10 AI + 10 Tool)
+    cleared_messages = [
+        msg
+        for msg in request.messages
+        if isinstance(msg, ToolMessage) and msg.content == "[cleared]"
+    ]
+
+    # Should clear 7 out of 10 tool messages (keep 3)
+    assert len(cleared_messages) == 7
+
+
+def test_new_api_with_multiple_triggers() -> None:
+    """Test new API with multiple trigger conditions."""
+    conversation: list[AIMessage | ToolMessage] = []
+    for i in range(3):
+        call_id = f"call-{i}"
+        conversation.append(
+            AIMessage(
+                content="",
+                tool_calls=[{"id": call_id, "name": "tool", "args": {"input": call_id}}],
+            )
+        )
+        conversation.append(ToolMessage(content="x" * 100, tool_call_id=call_id))
+
+    state, request = _make_state_and_request(conversation)
+
+    # Trigger on either message count OR token count
+    edit = ClearToolUsesEdit(
+        trigger=[("messages", 5), ("tokens", 200)],
+        keep=("messages", 1),
+        placeholder="[cleared]",
+    )
+    middleware = ContextEditingMiddleware(edits=[edit])
+
+    def mock_handler(req: ModelRequest) -> AIMessage:
+        return AIMessage(content="mock response")
+
+    middleware.wrap_model_call(request, mock_handler)
+
+    # Should trigger because we have 6 messages (>5) and also >200 tokens
+    cleared_messages = [
+        msg
+        for msg in request.messages
+        if isinstance(msg, ToolMessage) and msg.content == "[cleared]"
+    ]
+
+    # Should clear 2 out of 3 tool messages (keep 1)
+    assert len(cleared_messages) == 2
+
+
+def test_new_api_fractional_trigger_with_model_profile() -> None:
+    """Test new API with fractional trigger using model profile."""
+    from typing import TYPE_CHECKING
+
+    if TYPE_CHECKING:
+        from langchain_model_profiles import ModelProfile
+
+    class ProfileModel(_TokenCountingChatModel):
+        @property
+        def profile(self) -> "ModelProfile":
+            return {"max_input_tokens": 500}
+
+    conversation: list[AIMessage | ToolMessage] = []
+    for i in range(10):
+        call_id = f"call-{i}"
+        conversation.append(
+            AIMessage(
+                content="x" * 50,
+                tool_calls=[{"id": call_id, "name": "tool", "args": {"input": call_id}}],
+            )
+        )
+        conversation.append(ToolMessage(content="x" * 50, tool_call_id=call_id))
+
+    model = ProfileModel()
+    state = cast("AgentState", {"messages": conversation})
+    request = ModelRequest(
+        model=model,
+        system_prompt=None,
+        messages=conversation,
+        tool_choice=None,
+        tools=[],
+        response_format=None,
+        state=state,
+        runtime=_fake_runtime(),
+        model_settings={},
+    )
+
+    # Trigger when we reach 80% of model's max input tokens
+    # Total tokens = 1000 (10 AI messages * 50 + 10 tool messages * 50)
+    # Max input tokens = 500
+    # 80% of 500 = 400 tokens -> should trigger because 1000 > 400
+    edit = ClearToolUsesEdit(
+        trigger=("fraction", 0.8),
+        keep=("messages", 3),
+        placeholder="[cleared]",
+        model=model,
+    )
+    middleware = ContextEditingMiddleware(edits=[edit])
+
+    def mock_handler(req: ModelRequest) -> AIMessage:
+        return AIMessage(content="mock response")
+
+    middleware.wrap_model_call(request, mock_handler)
+
+    cleared_messages = [
+        msg
+        for msg in request.messages
+        if isinstance(msg, ToolMessage) and msg.content == "[cleared]"
+    ]
+
+    # Should clear some messages (keep 3)
+    assert len(cleared_messages) == 7
+
+
+def test_validation_errors() -> None:
+    """Test that validation errors are raised for invalid parameters."""
+    import pytest
+
+    # Invalid fraction (> 1)
+    with pytest.raises(ValueError, match="Fractional trigger values must be between 0 and 1"):
+        ClearToolUsesEdit(trigger=("fraction", 1.5))
+
+    # Invalid fraction (< 0)
+    with pytest.raises(ValueError, match="Fractional trigger values must be between 0 and 1"):
+        ClearToolUsesEdit(trigger=("fraction", 0))
+
+    # Invalid trigger (0 tokens)
+    with pytest.raises(ValueError, match="trigger thresholds must be greater than 0"):
+        ClearToolUsesEdit(trigger=("tokens", 0))
+
+    # Invalid trigger (negative)
+    with pytest.raises(ValueError, match="trigger thresholds must be greater than 0"):
+        ClearToolUsesEdit(trigger=("messages", -1))
+
+    # Invalid keep (negative)
+    with pytest.raises(ValueError, match="keep values must be non-negative"):
+        ClearToolUsesEdit(keep=("messages", -1))
+
+    # Unsupported context size type
+    with pytest.raises(ValueError, match="Unsupported context size type"):
+        ClearToolUsesEdit(trigger=("invalid", 100))  # type: ignore[arg-type]
+
+
+def test_deprecation_warnings() -> None:
+    """Test that deprecation warnings are raised for old API."""
+    import pytest
+
+    with pytest.warns(DeprecationWarning, match="Passing trigger as int is deprecated"):
+        ClearToolUsesEdit(trigger=100)
+
+    with pytest.warns(DeprecationWarning, match="Passing keep as int is deprecated"):
+        ClearToolUsesEdit(keep=5)