mirror of
https://github.com/hwchase17/langchain.git
synced 2026-05-03 01:46:42 +00:00
feat(anthropic): add AnthropicToolIdSanitizationMiddleware
This commit is contained in:
@@ -13,9 +13,13 @@ from langchain_anthropic.middleware.file_search import (
|
||||
from langchain_anthropic.middleware.prompt_caching import (
|
||||
AnthropicPromptCachingMiddleware,
|
||||
)
|
||||
from langchain_anthropic.middleware.tool_id_sanitization import (
|
||||
AnthropicToolIdSanitizationMiddleware,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AnthropicPromptCachingMiddleware",
|
||||
"AnthropicToolIdSanitizationMiddleware",
|
||||
"ClaudeBashToolMiddleware",
|
||||
"FilesystemClaudeMemoryMiddleware",
|
||||
"FilesystemClaudeTextEditorMiddleware",
|
||||
|
||||
@@ -0,0 +1,475 @@
|
||||
"""Anthropic tool-call ID sanitization middleware."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
import re
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any, Literal
|
||||
from warnings import warn
|
||||
|
||||
from langchain_core.messages import AIMessage, AnyMessage, ToolMessage
|
||||
|
||||
from langchain_anthropic.chat_models import ChatAnthropic
|
||||
|
||||
try:
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
ModelCallResult,
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
)
|
||||
except ModuleNotFoundError as e:
|
||||
msg = (
|
||||
"AnthropicToolIdSanitizationMiddleware requires 'langchain' to be "
|
||||
"installed. This middleware is designed for use with LangChain agents. "
|
||||
"Install it with: pip install langchain"
|
||||
)
|
||||
raise ImportError(msg) from e
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_ANTHROPIC_ID_PATTERN = re.compile(r"^[a-zA-Z0-9_-]+$")
|
||||
"""Character set Anthropic enforces server-side on `tool_use.id`.
|
||||
|
||||
Mirrors the regex echoed in the API errors returned for non-conforming IDs:
|
||||
`messages.N.content.M.tool_use.id: String should match pattern
|
||||
'^[a-zA-Z0-9_-]+$'`.
|
||||
"""
|
||||
|
||||
_INVALID_CHAR_PATTERN = re.compile(r"[^a-zA-Z0-9_-]")
|
||||
"""Complement of the character class in `_ANTHROPIC_ID_PATTERN`.
|
||||
|
||||
Matches any single illegal character. Used by `_make_safe_id` to substitute
|
||||
offending characters with `_`.
|
||||
"""
|
||||
|
||||
|
||||
def _is_tool_use_type(block_type: Any) -> bool:
|
||||
"""Return `True` if `block_type` names an ID-bearing tool-use block.
|
||||
|
||||
Covers `tool_use` (client tools), `server_tool_use` (Anthropic-emitted
|
||||
server tools), and `mcp_tool_use` (MCP). Uses suffix matching so future
|
||||
`*_tool_use` variants are picked up automatically.
|
||||
"""
|
||||
return isinstance(block_type, str) and block_type.endswith("tool_use")
|
||||
|
||||
|
||||
def _is_client_tool_use_type(block_type: Any) -> bool:
|
||||
"""Return `True` if `block_type` is a client `tool_use` block."""
|
||||
return block_type == "tool_use"
|
||||
|
||||
|
||||
def _is_tool_result_type(block_type: Any) -> bool:
|
||||
"""Return `True` if `block_type` names a `tool_use_id`-bearing result block.
|
||||
|
||||
Covers `tool_result`, `mcp_tool_result`, and Anthropic server-tool result
|
||||
variants like `web_search_tool_result`, `code_execution_tool_result`, and
|
||||
`bash_tool_result`. Uses suffix matching for forward compatibility.
|
||||
"""
|
||||
return isinstance(block_type, str) and block_type.endswith("tool_result")
|
||||
|
||||
|
||||
def _is_valid_id(tool_id: str | None) -> bool:
|
||||
"""Return `True` if `tool_id` matches Anthropic's required pattern."""
|
||||
if not tool_id:
|
||||
return False
|
||||
return _ANTHROPIC_ID_PATTERN.match(tool_id) is not None
|
||||
|
||||
|
||||
def _make_safe_id(original: str, used: set[str]) -> str:
|
||||
"""Derive an Anthropic-safe id, avoiding collisions within a single call.
|
||||
|
||||
Replaces every illegal character with `_`. If the resulting `base` already
|
||||
appears in `used`, appends a sha256-derived suffix (and a counter on
|
||||
further collisions) so each invalid input gets a distinct safe output
|
||||
within one sanitization pass.
|
||||
|
||||
Args:
|
||||
original: The invalid tool-call id to sanitize.
|
||||
used: The set of ids already taken in this pass — both pre-existing
|
||||
valid ids and previously-allocated safe ids.
|
||||
|
||||
Returns:
|
||||
A regex-conformant id distinct from every entry in `used`.
|
||||
"""
|
||||
base = _INVALID_CHAR_PATTERN.sub("_", original) or "tool"
|
||||
if base not in used:
|
||||
return base
|
||||
digest = hashlib.sha256(original.encode("utf-8")).hexdigest()[:8]
|
||||
candidate = f"{base}_{digest}"
|
||||
counter = 0
|
||||
while candidate in used:
|
||||
counter += 1
|
||||
candidate = f"{base}_{digest}_{counter}"
|
||||
return candidate
|
||||
|
||||
|
||||
class AnthropicToolIdSanitizationMiddleware(AgentMiddleware):
|
||||
"""Rewrite illegal tool-call IDs before sending to Anthropic.
|
||||
|
||||
Anthropic enforces `tool_use.id` matching `^[a-zA-Z0-9_-]+$`. Conversation
|
||||
histories produced by other providers can violate this — e.g. Kimi-K2 emits
|
||||
IDs of the form `functions.<name>:<idx>`, where `.` and `:` are illegal.
|
||||
Replaying such a thread against Claude raises a 400 error.
|
||||
|
||||
This middleware scans `request.messages` for offending IDs, builds a
|
||||
deterministic `bad_id -> safe_id` map, and rewrites every occurrence:
|
||||
|
||||
- `AIMessage.tool_calls[*]["id"]`
|
||||
- `AIMessage.content[*]["id"]` for `tool_use` / `server_tool_use` /
|
||||
`mcp_tool_use` blocks
|
||||
- `ToolMessage.tool_call_id`
|
||||
- `ToolMessage.content[*]["tool_use_id"]` for `tool_result` and the
|
||||
`*_tool_result` variants (`web_search_tool_result`,
|
||||
`code_execution_tool_result`, `bash_tool_result`, `mcp_tool_result`)
|
||||
|
||||
Within an `AIMessage`, position-paired `tool_calls[i]` and `tool_use`
|
||||
content blocks are forced to share the same final id even when their
|
||||
inputs disagree (drift), so Anthropic never receives a mismatched pair.
|
||||
|
||||
Only the outgoing `ModelRequest` is modified — graph state and persisted
|
||||
checkpoints are left untouched, so the sanitization is idempotent across
|
||||
turns and safe to combine with HITL resume.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
unsupported_model_behavior: Literal["ignore", "warn", "raise"] = "ignore",
|
||||
) -> None:
|
||||
"""Initialize the middleware.
|
||||
|
||||
Args:
|
||||
unsupported_model_behavior: Behavior when the bound model is not
|
||||
`ChatAnthropic`.
|
||||
|
||||
`'ignore'` skips sanitization silently (default — other
|
||||
providers may accept the original IDs).
|
||||
|
||||
`'warn'` emits a warning and skips sanitization.
|
||||
|
||||
`'raise'` raises `ValueError`.
|
||||
|
||||
Raises:
|
||||
ValueError: If `unsupported_model_behavior` is not one of the
|
||||
three allowed strings. `Literal` is enforced statically only;
|
||||
this guards against typos at runtime.
|
||||
"""
|
||||
allowed = ("ignore", "warn", "raise")
|
||||
if unsupported_model_behavior not in allowed:
|
||||
msg = (
|
||||
f"unsupported_model_behavior must be one of {allowed}; "
|
||||
f"got {unsupported_model_behavior!r}"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
self.unsupported_model_behavior = unsupported_model_behavior
|
||||
|
||||
def _should_run(self, request: ModelRequest) -> bool:
|
||||
"""Return `True` if the bound model is `ChatAnthropic`.
|
||||
|
||||
Args:
|
||||
request: The model request to inspect.
|
||||
|
||||
Returns:
|
||||
`True` when sanitization should run; `False` to bypass.
|
||||
|
||||
Raises:
|
||||
ValueError: When the bound model is not `ChatAnthropic` and
|
||||
`unsupported_model_behavior='raise'`.
|
||||
"""
|
||||
if isinstance(request.model, ChatAnthropic):
|
||||
return True
|
||||
msg = (
|
||||
"AnthropicToolIdSanitizationMiddleware only supports Anthropic "
|
||||
f"models, not instances of {type(request.model)}"
|
||||
)
|
||||
if self.unsupported_model_behavior == "raise":
|
||||
raise ValueError(msg)
|
||||
if self.unsupported_model_behavior == "warn":
|
||||
warn(msg, stacklevel=3)
|
||||
else:
|
||||
# `ignore` mode — leave a breadcrumb so users debugging missing
|
||||
# sanitization can confirm the middleware was bypassed and why.
|
||||
logger.debug(
|
||||
"AnthropicToolIdSanitizationMiddleware skipped: bound model "
|
||||
"is %s, not ChatAnthropic.",
|
||||
type(request.model).__name__,
|
||||
)
|
||||
return False
|
||||
|
||||
def _build_id_map(self, messages: list[AnyMessage]) -> dict[str, str]:
|
||||
"""Build a deterministic mapping from invalid IDs to safe ones.
|
||||
|
||||
Args:
|
||||
messages: The outgoing message list to scan.
|
||||
|
||||
Returns:
|
||||
A dict mapping each invalid id to its safe replacement, or an
|
||||
empty dict when every id already conforms.
|
||||
"""
|
||||
invalid: list[str] = []
|
||||
seen: set[str] = set()
|
||||
valid_ids: set[str] = set()
|
||||
|
||||
for tool_id in _iter_all_ids(messages):
|
||||
if _is_valid_id(tool_id):
|
||||
valid_ids.add(tool_id)
|
||||
elif tool_id and tool_id not in seen:
|
||||
invalid.append(tool_id)
|
||||
seen.add(tool_id)
|
||||
|
||||
if not invalid:
|
||||
return {}
|
||||
|
||||
used = set(valid_ids)
|
||||
mapping: dict[str, str] = {}
|
||||
for original in invalid:
|
||||
new = _make_safe_id(original, used)
|
||||
mapping[original] = new
|
||||
used.add(new)
|
||||
return mapping
|
||||
|
||||
def _sanitize(self, request: ModelRequest) -> ModelRequest:
|
||||
"""Return a copy of `request` with sanitized tool-call IDs.
|
||||
|
||||
Args:
|
||||
request: The model request to sanitize.
|
||||
|
||||
Returns:
|
||||
A new request when any id was rewritten, otherwise the original.
|
||||
"""
|
||||
mapping = self._build_id_map(request.messages)
|
||||
rewritten, changed = _rewrite_messages(request.messages, mapping)
|
||||
if not changed:
|
||||
return request
|
||||
if logger.isEnabledFor(logging.DEBUG):
|
||||
logger.debug(
|
||||
"Sanitized %d tool-call id(s) for Anthropic compatibility: %s",
|
||||
len(mapping),
|
||||
{k: mapping[k] for k in sorted(mapping)},
|
||||
)
|
||||
return request.override(messages=rewritten)
|
||||
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
"""Sanitize tool-call IDs before delegating to the handler.
|
||||
|
||||
Args:
|
||||
request: The model request to potentially modify.
|
||||
handler: The handler to execute the (possibly rewritten) request.
|
||||
|
||||
Returns:
|
||||
The model response produced by `handler`.
|
||||
"""
|
||||
if not self._should_run(request):
|
||||
return handler(request)
|
||||
return handler(self._sanitize(request))
|
||||
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||
) -> ModelCallResult:
|
||||
"""Sanitize tool-call IDs before delegating to the async handler.
|
||||
|
||||
Args:
|
||||
request: The model request to potentially modify.
|
||||
handler: The async handler to execute the (possibly rewritten)
|
||||
request.
|
||||
|
||||
Returns:
|
||||
The model response produced by `handler`.
|
||||
"""
|
||||
if not self._should_run(request):
|
||||
return await handler(request)
|
||||
return await handler(self._sanitize(request))
|
||||
|
||||
|
||||
def _iter_all_ids(messages: list[AnyMessage]) -> list[str]:
|
||||
"""Collect every tool-call ID present in the message list."""
|
||||
ids: list[str] = []
|
||||
for msg in messages:
|
||||
if isinstance(msg, AIMessage):
|
||||
for tc in msg.tool_calls or []:
|
||||
tid = tc.get("id")
|
||||
if tid:
|
||||
ids.append(tid)
|
||||
if isinstance(msg.content, list):
|
||||
for block in msg.content:
|
||||
if isinstance(block, dict) and _is_tool_use_type(block.get("type")):
|
||||
tid = block.get("id")
|
||||
if isinstance(tid, str) and tid:
|
||||
ids.append(tid)
|
||||
elif isinstance(msg, ToolMessage):
|
||||
if msg.tool_call_id:
|
||||
ids.append(msg.tool_call_id)
|
||||
if isinstance(msg.content, list):
|
||||
for block in msg.content:
|
||||
if isinstance(block, dict) and _is_tool_result_type(
|
||||
block.get("type")
|
||||
):
|
||||
tid = block.get("tool_use_id")
|
||||
if isinstance(tid, str) and tid:
|
||||
ids.append(tid)
|
||||
return ids
|
||||
|
||||
|
||||
def _rewrite_messages(
|
||||
messages: list[AnyMessage], mapping: dict[str, str]
|
||||
) -> tuple[list[AnyMessage], bool]:
|
||||
"""Return messages rewritten with sanitized IDs and local drift aliases."""
|
||||
rewritten: list[AnyMessage] = []
|
||||
changed = False
|
||||
active_result_aliases: dict[str, str] = {}
|
||||
|
||||
for msg in messages:
|
||||
new_msg: AnyMessage
|
||||
if isinstance(msg, AIMessage):
|
||||
new_msg, active_result_aliases = _rewrite_ai_message(msg, mapping)
|
||||
elif isinstance(msg, ToolMessage):
|
||||
result_mapping = {**mapping, **active_result_aliases}
|
||||
new_msg = _rewrite_tool_message(msg, result_mapping)
|
||||
else:
|
||||
active_result_aliases = {}
|
||||
new_msg = msg
|
||||
|
||||
if new_msg is not msg:
|
||||
changed = True
|
||||
rewritten.append(new_msg)
|
||||
|
||||
return rewritten, changed
|
||||
|
||||
|
||||
def _rewrite_ai_message(
|
||||
msg: AIMessage, mapping: dict[str, str]
|
||||
) -> tuple[AIMessage, dict[str, str]]:
|
||||
"""Return a copy of `msg` with `tool_calls` and tool-use blocks aligned.
|
||||
|
||||
Applies `mapping` to each id, then enforces position-paired alignment so
|
||||
`tool_calls[i].id` and the i-th client `tool_use` content block share the
|
||||
same final id even if their inputs drifted.
|
||||
"""
|
||||
result_aliases: dict[str, str] = {}
|
||||
new_tool_calls: list[Any] | None = None
|
||||
if msg.tool_calls and any(tc.get("id") in mapping for tc in msg.tool_calls):
|
||||
new_tool_calls = []
|
||||
for tc in msg.tool_calls:
|
||||
tid = tc.get("id")
|
||||
if tid is not None and tid in mapping:
|
||||
new_tool_calls.append({**tc, "id": mapping[tid]})
|
||||
else:
|
||||
new_tool_calls.append(tc)
|
||||
|
||||
new_content: list[Any] | None = None
|
||||
if isinstance(msg.content, list) and any(
|
||||
isinstance(b, dict)
|
||||
and _is_tool_use_type(b.get("type"))
|
||||
and b.get("id") in mapping
|
||||
for b in msg.content
|
||||
):
|
||||
new_content = [
|
||||
{**b, "id": mapping[b["id"]]}
|
||||
if (
|
||||
isinstance(b, dict)
|
||||
and _is_tool_use_type(b.get("type"))
|
||||
and b.get("id") in mapping
|
||||
)
|
||||
else b
|
||||
for b in msg.content
|
||||
]
|
||||
|
||||
# Drift correction: even if `tool_calls[i].id` and the i-th client
|
||||
# `tool_use` block disagreed pre-mapping (or post-mapping due to distinct
|
||||
# invalid inputs collapsing through different sanitization branches), force
|
||||
# the content block to adopt the canonical `tool_calls[i].id`. `tool_calls`
|
||||
# is the langchain-canonical view.
|
||||
effective_tool_calls = (
|
||||
new_tool_calls if new_tool_calls is not None else msg.tool_calls
|
||||
)
|
||||
effective_content = new_content if new_content is not None else msg.content
|
||||
if effective_tool_calls and isinstance(effective_content, list):
|
||||
aligned, result_aliases, drift_seen = _align_client_tool_use_blocks(
|
||||
effective_tool_calls,
|
||||
effective_content,
|
||||
msg.content if isinstance(msg.content, list) else effective_content,
|
||||
)
|
||||
if drift_seen:
|
||||
new_content = aligned
|
||||
|
||||
updates: dict[str, Any] = {}
|
||||
if new_tool_calls is not None:
|
||||
updates["tool_calls"] = new_tool_calls
|
||||
if new_content is not None:
|
||||
updates["content"] = new_content
|
||||
return (msg.model_copy(update=updates) if updates else msg), result_aliases
|
||||
|
||||
|
||||
def _align_client_tool_use_blocks(
|
||||
tool_calls: list[Any], content: list[Any], original_content: list[Any]
|
||||
) -> tuple[list[Any], dict[str, str], bool]:
|
||||
"""Force tool-use content blocks to share IDs with corresponding tool_calls.
|
||||
|
||||
Returns the (possibly rewritten) content list plus aliases from drifted
|
||||
content-block IDs to the canonical tool-call ID. These aliases are applied
|
||||
only to the following `ToolMessage` results for this assistant turn.
|
||||
"""
|
||||
result_aliases: dict[str, str] = {}
|
||||
drift_seen = False
|
||||
aligned: list[Any] = list(content)
|
||||
tu_indices = [
|
||||
idx
|
||||
for idx, b in enumerate(content)
|
||||
if isinstance(b, dict) and _is_client_tool_use_type(b.get("type"))
|
||||
]
|
||||
for pair_idx, content_idx in enumerate(tu_indices):
|
||||
if pair_idx >= len(tool_calls):
|
||||
break
|
||||
canonical_id = tool_calls[pair_idx].get("id")
|
||||
block = aligned[content_idx]
|
||||
if not isinstance(block, dict):
|
||||
continue
|
||||
if canonical_id and block.get("id") != canonical_id:
|
||||
aligned[content_idx] = {**block, "id": canonical_id}
|
||||
drift_seen = True
|
||||
original_block = original_content[content_idx]
|
||||
if isinstance(original_block, dict):
|
||||
original_id = original_block.get("id")
|
||||
if isinstance(original_id, str) and original_id != canonical_id:
|
||||
result_aliases[original_id] = canonical_id
|
||||
current_id = block.get("id")
|
||||
if isinstance(current_id, str) and current_id != canonical_id:
|
||||
result_aliases[current_id] = canonical_id
|
||||
return aligned, result_aliases, drift_seen
|
||||
|
||||
|
||||
def _rewrite_tool_message(msg: ToolMessage, mapping: dict[str, str]) -> ToolMessage:
|
||||
"""Return a copy of `msg` with `tool_call_id` and tool-result blocks rewritten."""
|
||||
updates: dict[str, Any] = {}
|
||||
|
||||
if msg.tool_call_id in mapping:
|
||||
updates["tool_call_id"] = mapping[msg.tool_call_id]
|
||||
|
||||
if isinstance(msg.content, list) and any(
|
||||
isinstance(b, dict)
|
||||
and _is_tool_result_type(b.get("type"))
|
||||
and b.get("tool_use_id") in mapping
|
||||
for b in msg.content
|
||||
):
|
||||
updates["content"] = [
|
||||
{**b, "tool_use_id": mapping[b["tool_use_id"]]}
|
||||
if (
|
||||
isinstance(b, dict)
|
||||
and _is_tool_result_type(b.get("type"))
|
||||
and b.get("tool_use_id") in mapping
|
||||
)
|
||||
else b
|
||||
for b in msg.content
|
||||
]
|
||||
|
||||
return msg.model_copy(update=updates) if updates else msg
|
||||
@@ -0,0 +1,888 @@
|
||||
"""Tests for Anthropic tool-call ID sanitization middleware."""
|
||||
|
||||
import logging
|
||||
import re
|
||||
import warnings
|
||||
from typing import Any, cast
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from langchain.agents.middleware.types import AgentState, ModelRequest, ModelResponse
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AnyMessage,
|
||||
BaseMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
ToolMessage,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatResult
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from langchain_anthropic.chat_models import ChatAnthropic
|
||||
from langchain_anthropic.middleware import AnthropicToolIdSanitizationMiddleware
|
||||
|
||||
_ANTHROPIC_ID_RE = re.compile(r"^[a-zA-Z0-9_-]+$")
|
||||
|
||||
|
||||
class _FakeModel(BaseChatModel):
|
||||
"""Stand-in non-Anthropic model for unsupported-model tests."""
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: list[str] | None = None,
|
||||
run_manager: CallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
return ChatResult(
|
||||
generations=[ChatGeneration(message=AIMessage(content="ok", id="0"))]
|
||||
)
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: list[str] | None = None,
|
||||
run_manager: AsyncCallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
return ChatResult(
|
||||
generations=[ChatGeneration(message=AIMessage(content="ok", id="0"))]
|
||||
)
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "fake"
|
||||
|
||||
|
||||
def _make_request(
|
||||
messages: list[BaseMessage], *, anthropic: bool = True
|
||||
) -> ModelRequest:
|
||||
"""Build a `ModelRequest` for testing."""
|
||||
if anthropic:
|
||||
mock_model = MagicMock(spec=ChatAnthropic)
|
||||
mock_model._llm_type = "anthropic-chat"
|
||||
model: BaseChatModel = cast(BaseChatModel, mock_model)
|
||||
else:
|
||||
model = _FakeModel()
|
||||
return ModelRequest(
|
||||
model=model,
|
||||
messages=cast(list[AnyMessage], messages),
|
||||
system_prompt=None,
|
||||
tool_choice=None,
|
||||
tools=[],
|
||||
response_format=None,
|
||||
state=cast(AgentState[Any], {"messages": messages}),
|
||||
runtime=cast(Runtime, object()),
|
||||
model_settings={},
|
||||
)
|
||||
|
||||
|
||||
def _capture(
|
||||
messages: list[BaseMessage],
|
||||
) -> tuple[
|
||||
AnthropicToolIdSanitizationMiddleware,
|
||||
ModelRequest,
|
||||
list[ModelRequest],
|
||||
]:
|
||||
"""Wire a middleware + request + capturing handler list for a test."""
|
||||
middleware = AnthropicToolIdSanitizationMiddleware()
|
||||
request = _make_request(messages)
|
||||
captured: list[ModelRequest] = []
|
||||
|
||||
return middleware, request, captured
|
||||
|
||||
|
||||
def _handler_factory(captured: list[ModelRequest]) -> Any:
|
||||
def _handler(req: ModelRequest) -> ModelResponse:
|
||||
captured.append(req)
|
||||
return ModelResponse(result=[AIMessage(content="ok")])
|
||||
|
||||
return _handler
|
||||
|
||||
|
||||
def _valid_tool_call_id(value: str | None) -> str:
|
||||
"""Assert that an optional tool-call ID is present and Anthropic-safe."""
|
||||
assert value is not None
|
||||
assert _ANTHROPIC_ID_RE.match(value)
|
||||
return value
|
||||
|
||||
|
||||
def test_passthrough_when_all_ids_valid() -> None:
|
||||
"""Valid IDs short-circuit: every message reaches the handler unchanged."""
|
||||
messages: list[BaseMessage] = [
|
||||
HumanMessage("hi"),
|
||||
AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{"name": "grep", "args": {}, "id": "toolu_01abc", "type": "tool_call"}
|
||||
],
|
||||
),
|
||||
ToolMessage(content="done", tool_call_id="toolu_01abc"),
|
||||
]
|
||||
middleware, request, captured = _capture(messages)
|
||||
|
||||
middleware.wrap_model_call(request, _handler_factory(captured))
|
||||
|
||||
assert len(captured) == 1
|
||||
# Per-message identity: each individual message is the original instance.
|
||||
for original, sent in zip(request.messages, captured[0].messages, strict=True):
|
||||
assert sent is original
|
||||
|
||||
|
||||
def test_rewrites_kimi_k2_style_ids_in_pairs() -> None:
|
||||
"""Bad IDs on AIMessage tool_calls and matching ToolMessage are rewritten."""
|
||||
messages: list[BaseMessage] = [
|
||||
HumanMessage("hi"),
|
||||
AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
"name": "grep",
|
||||
"args": {"q": "x"},
|
||||
"id": "functions.grep:0",
|
||||
"type": "tool_call",
|
||||
},
|
||||
{
|
||||
"name": "grep",
|
||||
"args": {"q": "y"},
|
||||
"id": "functions.grep:1",
|
||||
"type": "tool_call",
|
||||
},
|
||||
],
|
||||
),
|
||||
ToolMessage(content="r0", tool_call_id="functions.grep:0"),
|
||||
ToolMessage(content="r1", tool_call_id="functions.grep:1"),
|
||||
]
|
||||
middleware, request, captured = _capture(messages)
|
||||
|
||||
middleware.wrap_model_call(request, _handler_factory(captured))
|
||||
|
||||
sent = captured[0].messages
|
||||
ai = sent[1]
|
||||
assert isinstance(ai, AIMessage)
|
||||
new_ids = [_valid_tool_call_id(tc["id"]) for tc in ai.tool_calls]
|
||||
tool_msgs = [m for m in sent if isinstance(m, ToolMessage)]
|
||||
assert [tm.tool_call_id for tm in tool_msgs] == new_ids
|
||||
assert new_ids[0] != new_ids[1]
|
||||
|
||||
|
||||
def test_rewrite_is_deterministic_across_invocations() -> None:
|
||||
"""The same input produces the same sanitized IDs across separate calls."""
|
||||
payload: list[BaseMessage] = [
|
||||
AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
"name": "grep",
|
||||
"args": {},
|
||||
"id": "functions.grep:0",
|
||||
"type": "tool_call",
|
||||
},
|
||||
{
|
||||
"name": "grep",
|
||||
"args": {},
|
||||
"id": "functions.grep:1",
|
||||
"type": "tool_call",
|
||||
},
|
||||
],
|
||||
),
|
||||
ToolMessage(content="r0", tool_call_id="functions.grep:0"),
|
||||
ToolMessage(content="r1", tool_call_id="functions.grep:1"),
|
||||
]
|
||||
|
||||
def _run() -> list[str]:
|
||||
# Build a fresh copy each invocation so the two runs are independent.
|
||||
msgs = [m.model_copy(deep=True) for m in payload]
|
||||
middleware = AnthropicToolIdSanitizationMiddleware()
|
||||
request = _make_request(msgs)
|
||||
captured: list[ModelRequest] = []
|
||||
middleware.wrap_model_call(request, _handler_factory(captured))
|
||||
ai = cast(AIMessage, captured[0].messages[0])
|
||||
return [_valid_tool_call_id(tc["id"]) for tc in ai.tool_calls]
|
||||
|
||||
assert _run() == _run()
|
||||
|
||||
|
||||
def test_rewrites_anthropic_content_blocks() -> None:
|
||||
"""`tool_use` and `tool_result` content blocks have their IDs rewritten too."""
|
||||
messages: list[BaseMessage] = [
|
||||
AIMessage(
|
||||
content=[
|
||||
{"type": "text", "text": "calling"},
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "functions.read_file:0",
|
||||
"name": "read_file",
|
||||
"input": {},
|
||||
},
|
||||
],
|
||||
tool_calls=[
|
||||
{
|
||||
"name": "read_file",
|
||||
"args": {},
|
||||
"id": "functions.read_file:0",
|
||||
"type": "tool_call",
|
||||
},
|
||||
],
|
||||
),
|
||||
ToolMessage(
|
||||
content=[
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "functions.read_file:0",
|
||||
"content": "hi",
|
||||
},
|
||||
],
|
||||
tool_call_id="functions.read_file:0",
|
||||
),
|
||||
]
|
||||
middleware, request, captured = _capture(messages)
|
||||
|
||||
middleware.wrap_model_call(request, _handler_factory(captured))
|
||||
|
||||
sent = captured[0].messages
|
||||
ai = sent[0]
|
||||
tm = sent[1]
|
||||
assert isinstance(ai, AIMessage)
|
||||
assert isinstance(tm, ToolMessage)
|
||||
|
||||
block_id = next(
|
||||
b["id"]
|
||||
for b in ai.content
|
||||
if isinstance(b, dict) and b.get("type") == "tool_use"
|
||||
)
|
||||
tool_call_id = _valid_tool_call_id(ai.tool_calls[0]["id"])
|
||||
result_id = next(
|
||||
b["tool_use_id"]
|
||||
for b in tm.content
|
||||
if isinstance(b, dict) and b.get("type") == "tool_result"
|
||||
)
|
||||
|
||||
assert block_id == tool_call_id == tm.tool_call_id == result_id
|
||||
for value in (block_id, tool_call_id, tm.tool_call_id, result_id):
|
||||
assert _ANTHROPIC_ID_RE.match(value)
|
||||
|
||||
|
||||
def test_drift_between_tool_calls_and_tool_use_blocks_is_corrected() -> None:
|
||||
"""When `tool_calls[i].id` and the i-th tool_use block disagree, alignment wins.
|
||||
|
||||
Without correction, two distinct invalid IDs would map to two different
|
||||
safe IDs (`a_b` vs `a_b_<digest>`), and Anthropic would 400 on the
|
||||
mismatched pair. The middleware forces both to share `tool_calls[i].id`.
|
||||
"""
|
||||
messages: list[BaseMessage] = [
|
||||
AIMessage(
|
||||
content=[
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "a:b", # drift — different illegal id
|
||||
"name": "tool",
|
||||
"input": {},
|
||||
},
|
||||
],
|
||||
tool_calls=[
|
||||
{
|
||||
"name": "tool",
|
||||
"args": {},
|
||||
"id": "a.b", # canonical
|
||||
"type": "tool_call",
|
||||
}
|
||||
],
|
||||
),
|
||||
ToolMessage(
|
||||
content=[
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "a:b",
|
||||
"content": "r",
|
||||
},
|
||||
],
|
||||
tool_call_id="a:b",
|
||||
),
|
||||
]
|
||||
middleware, request, captured = _capture(messages)
|
||||
|
||||
middleware.wrap_model_call(request, _handler_factory(captured))
|
||||
|
||||
sent = captured[0].messages
|
||||
ai = cast(AIMessage, sent[0])
|
||||
tm = cast(ToolMessage, sent[1])
|
||||
|
||||
final_tc_id = _valid_tool_call_id(ai.tool_calls[0]["id"])
|
||||
final_block_id = next(
|
||||
b["id"]
|
||||
for b in ai.content
|
||||
if isinstance(b, dict) and b.get("type") == "tool_use"
|
||||
)
|
||||
final_result_id = next(
|
||||
b["tool_use_id"]
|
||||
for b in tm.content
|
||||
if isinstance(b, dict) and b.get("type") == "tool_result"
|
||||
)
|
||||
|
||||
assert final_tc_id == final_block_id == tm.tool_call_id == final_result_id
|
||||
|
||||
|
||||
def test_valid_drift_between_tool_calls_and_tool_use_blocks_is_corrected() -> None:
|
||||
"""Valid but mismatched tool-call and `tool_use` IDs are aligned."""
|
||||
messages: list[BaseMessage] = [
|
||||
AIMessage(
|
||||
content=[
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "toolu_content",
|
||||
"name": "tool",
|
||||
"input": {},
|
||||
},
|
||||
],
|
||||
tool_calls=[
|
||||
{
|
||||
"name": "tool",
|
||||
"args": {},
|
||||
"id": "toolu_call",
|
||||
"type": "tool_call",
|
||||
}
|
||||
],
|
||||
),
|
||||
ToolMessage(
|
||||
content=[
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "toolu_content",
|
||||
"content": "r",
|
||||
},
|
||||
],
|
||||
tool_call_id="toolu_content",
|
||||
),
|
||||
]
|
||||
middleware, request, captured = _capture(messages)
|
||||
|
||||
middleware.wrap_model_call(request, _handler_factory(captured))
|
||||
|
||||
sent = captured[0].messages
|
||||
ai = cast(AIMessage, sent[0])
|
||||
tm = cast(ToolMessage, sent[1])
|
||||
block_id = next(
|
||||
b["id"]
|
||||
for b in ai.content
|
||||
if isinstance(b, dict) and b.get("type") == "tool_use"
|
||||
)
|
||||
result_id = next(
|
||||
b["tool_use_id"]
|
||||
for b in tm.content
|
||||
if isinstance(b, dict) and b.get("type") == "tool_result"
|
||||
)
|
||||
|
||||
assert ai.tool_calls[0]["id"] == "toolu_call"
|
||||
assert block_id == "toolu_call"
|
||||
assert tm.tool_call_id == "toolu_call"
|
||||
assert result_id == "toolu_call"
|
||||
|
||||
|
||||
def test_two_invalid_ids_with_same_base_get_distinct_safe_ids() -> None:
|
||||
"""`a.b` and `a:b` both sanitize to base `a_b` — second falls back to suffix.
|
||||
|
||||
Exercises the sha256-suffix branch of `_make_safe_id`. Both AIMessages
|
||||
use the same callable name so name-based heuristics can't disambiguate;
|
||||
only the `_make_safe_id` collision logic produces distinct safe IDs.
|
||||
"""
|
||||
messages: list[BaseMessage] = [
|
||||
AIMessage(
|
||||
content="",
|
||||
tool_calls=[{"name": "tool", "args": {}, "id": "a.b", "type": "tool_call"}],
|
||||
),
|
||||
ToolMessage(content="r1", tool_call_id="a.b"),
|
||||
AIMessage(
|
||||
content="",
|
||||
tool_calls=[{"name": "tool", "args": {}, "id": "a:b", "type": "tool_call"}],
|
||||
),
|
||||
ToolMessage(content="r2", tool_call_id="a:b"),
|
||||
]
|
||||
middleware, request, captured = _capture(messages)
|
||||
|
||||
middleware.wrap_model_call(request, _handler_factory(captured))
|
||||
|
||||
sent = captured[0].messages
|
||||
first_id = _valid_tool_call_id(cast(AIMessage, sent[0]).tool_calls[0]["id"])
|
||||
second_id = _valid_tool_call_id(cast(AIMessage, sent[2]).tool_calls[0]["id"])
|
||||
|
||||
# Both sanitize cleanly, both are valid, both are distinct.
|
||||
assert first_id != second_id
|
||||
# The first invalid wins the base; the second gets the suffix.
|
||||
assert first_id == "a_b"
|
||||
assert second_id.startswith("a_b_")
|
||||
# Pairs survive: ToolMessage ids match their AIMessage counterparts.
|
||||
assert cast(ToolMessage, sent[1]).tool_call_id == first_id
|
||||
assert cast(ToolMessage, sent[3]).tool_call_id == second_id
|
||||
|
||||
|
||||
def test_state_is_not_mutated() -> None:
|
||||
"""Original messages on the request and in graph state are not mutated."""
|
||||
original_id = "functions.grep:0"
|
||||
messages: list[BaseMessage] = [
|
||||
AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{"name": "grep", "args": {}, "id": original_id, "type": "tool_call"}
|
||||
],
|
||||
),
|
||||
ToolMessage(content="r", tool_call_id=original_id),
|
||||
]
|
||||
middleware, request, captured = _capture(messages)
|
||||
|
||||
middleware.wrap_model_call(request, _handler_factory(captured))
|
||||
|
||||
assert request.messages[0].tool_calls[0]["id"] == original_id # type: ignore[union-attr]
|
||||
assert cast(ToolMessage, request.messages[1]).tool_call_id == original_id
|
||||
assert captured[0].messages is not request.messages
|
||||
|
||||
|
||||
def test_collision_avoidance_with_existing_valid_id() -> None:
|
||||
"""A bad ID that sanitizes to an already-used valid ID gets a hash suffix."""
|
||||
messages: list[BaseMessage] = [
|
||||
AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
"name": "grep",
|
||||
"args": {},
|
||||
"id": "functions_grep_0",
|
||||
"type": "tool_call",
|
||||
},
|
||||
],
|
||||
),
|
||||
ToolMessage(content="a", tool_call_id="functions_grep_0"),
|
||||
AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
"name": "grep",
|
||||
"args": {},
|
||||
"id": "functions.grep:0",
|
||||
"type": "tool_call",
|
||||
},
|
||||
],
|
||||
),
|
||||
ToolMessage(content="b", tool_call_id="functions.grep:0"),
|
||||
]
|
||||
middleware, request, captured = _capture(messages)
|
||||
|
||||
middleware.wrap_model_call(request, _handler_factory(captured))
|
||||
|
||||
sent = captured[0].messages
|
||||
first_id = _valid_tool_call_id(cast(AIMessage, sent[0]).tool_calls[0]["id"])
|
||||
second_id = _valid_tool_call_id(cast(AIMessage, sent[2]).tool_calls[0]["id"])
|
||||
|
||||
assert first_id == "functions_grep_0"
|
||||
assert second_id != first_id
|
||||
|
||||
|
||||
def test_mixed_valid_and_invalid_ids_in_same_message() -> None:
|
||||
"""Valid IDs are preserved while siblings with invalid IDs are rewritten."""
|
||||
messages: list[BaseMessage] = [
|
||||
AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
"name": "good",
|
||||
"args": {},
|
||||
"id": "toolu_clean",
|
||||
"type": "tool_call",
|
||||
},
|
||||
{
|
||||
"name": "bad",
|
||||
"args": {},
|
||||
"id": "functions.bad:0",
|
||||
"type": "tool_call",
|
||||
},
|
||||
],
|
||||
),
|
||||
ToolMessage(content="g", tool_call_id="toolu_clean"),
|
||||
ToolMessage(content="b", tool_call_id="functions.bad:0"),
|
||||
]
|
||||
middleware, request, captured = _capture(messages)
|
||||
|
||||
middleware.wrap_model_call(request, _handler_factory(captured))
|
||||
|
||||
sent = captured[0].messages
|
||||
ai = cast(AIMessage, sent[0])
|
||||
valid_id = _valid_tool_call_id(ai.tool_calls[0]["id"])
|
||||
rewritten_id = _valid_tool_call_id(ai.tool_calls[1]["id"])
|
||||
|
||||
assert valid_id == "toolu_clean" # untouched
|
||||
assert rewritten_id != "functions.bad:0"
|
||||
assert cast(ToolMessage, sent[1]).tool_call_id == valid_id
|
||||
assert cast(ToolMessage, sent[2]).tool_call_id == rewritten_id
|
||||
|
||||
|
||||
def test_rewrites_server_and_mcp_block_types() -> None:
|
||||
"""`server_tool_use` / `mcp_tool_use` and their result variants are also handled."""
|
||||
messages: list[BaseMessage] = [
|
||||
AIMessage(
|
||||
content=[
|
||||
{
|
||||
"type": "server_tool_use",
|
||||
"id": "srv.1",
|
||||
"name": "web_search",
|
||||
"input": {},
|
||||
},
|
||||
{
|
||||
"type": "mcp_tool_use",
|
||||
"id": "mcp.2",
|
||||
"name": "fetch",
|
||||
"input": {},
|
||||
},
|
||||
],
|
||||
tool_calls=[],
|
||||
),
|
||||
ToolMessage(
|
||||
content=[
|
||||
{
|
||||
"type": "web_search_tool_result",
|
||||
"tool_use_id": "srv.1",
|
||||
"content": "hits",
|
||||
},
|
||||
],
|
||||
tool_call_id="srv.1",
|
||||
),
|
||||
ToolMessage(
|
||||
content=[
|
||||
{
|
||||
"type": "mcp_tool_result",
|
||||
"tool_use_id": "mcp.2",
|
||||
"content": "data",
|
||||
},
|
||||
],
|
||||
tool_call_id="mcp.2",
|
||||
),
|
||||
]
|
||||
middleware, request, captured = _capture(messages)
|
||||
|
||||
middleware.wrap_model_call(request, _handler_factory(captured))
|
||||
|
||||
sent = captured[0].messages
|
||||
ai = cast(AIMessage, sent[0])
|
||||
server_block_id = next(
|
||||
b["id"]
|
||||
for b in ai.content
|
||||
if isinstance(b, dict) and b.get("type") == "server_tool_use"
|
||||
)
|
||||
mcp_block_id = next(
|
||||
b["id"]
|
||||
for b in ai.content
|
||||
if isinstance(b, dict) and b.get("type") == "mcp_tool_use"
|
||||
)
|
||||
|
||||
for value in (
|
||||
server_block_id,
|
||||
mcp_block_id,
|
||||
cast(ToolMessage, sent[1]).tool_call_id,
|
||||
cast(ToolMessage, sent[2]).tool_call_id,
|
||||
):
|
||||
assert _ANTHROPIC_ID_RE.match(value)
|
||||
# Pair survival:
|
||||
web_result_id = next(
|
||||
b["tool_use_id"]
|
||||
for b in cast(ToolMessage, sent[1]).content
|
||||
if isinstance(b, dict) and b.get("type") == "web_search_tool_result"
|
||||
)
|
||||
mcp_result_id = next(
|
||||
b["tool_use_id"]
|
||||
for b in cast(ToolMessage, sent[2]).content
|
||||
if isinstance(b, dict) and b.get("type") == "mcp_tool_result"
|
||||
)
|
||||
assert web_result_id == server_block_id
|
||||
assert mcp_result_id == mcp_block_id
|
||||
|
||||
|
||||
def test_client_alignment_skips_server_and_mcp_tool_use_blocks() -> None:
|
||||
"""Server and MCP tool-use blocks do not position-pair with `tool_calls`."""
|
||||
messages: list[BaseMessage] = [
|
||||
AIMessage(
|
||||
content=[
|
||||
{
|
||||
"type": "server_tool_use",
|
||||
"id": "srv.1",
|
||||
"name": "web_search",
|
||||
"input": {},
|
||||
},
|
||||
{
|
||||
"type": "mcp_tool_use",
|
||||
"id": "mcp.2",
|
||||
"name": "fetch",
|
||||
"input": {},
|
||||
},
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "client.3",
|
||||
"name": "client_tool",
|
||||
"input": {},
|
||||
},
|
||||
],
|
||||
tool_calls=[
|
||||
{
|
||||
"name": "client_tool",
|
||||
"args": {},
|
||||
"id": "client.3",
|
||||
"type": "tool_call",
|
||||
},
|
||||
],
|
||||
),
|
||||
ToolMessage(
|
||||
content=[
|
||||
{
|
||||
"type": "web_search_tool_result",
|
||||
"tool_use_id": "srv.1",
|
||||
"content": "hits",
|
||||
},
|
||||
],
|
||||
tool_call_id="srv.1",
|
||||
),
|
||||
ToolMessage(
|
||||
content=[
|
||||
{
|
||||
"type": "mcp_tool_result",
|
||||
"tool_use_id": "mcp.2",
|
||||
"content": "data",
|
||||
},
|
||||
],
|
||||
tool_call_id="mcp.2",
|
||||
),
|
||||
ToolMessage(
|
||||
content=[
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "client.3",
|
||||
"content": "done",
|
||||
},
|
||||
],
|
||||
tool_call_id="client.3",
|
||||
),
|
||||
]
|
||||
middleware, request, captured = _capture(messages)
|
||||
|
||||
middleware.wrap_model_call(request, _handler_factory(captured))
|
||||
|
||||
sent = captured[0].messages
|
||||
ai = cast(AIMessage, sent[0])
|
||||
server_block_id = next(
|
||||
b["id"]
|
||||
for b in ai.content
|
||||
if isinstance(b, dict) and b.get("type") == "server_tool_use"
|
||||
)
|
||||
mcp_block_id = next(
|
||||
b["id"]
|
||||
for b in ai.content
|
||||
if isinstance(b, dict) and b.get("type") == "mcp_tool_use"
|
||||
)
|
||||
client_block_id = next(
|
||||
b["id"]
|
||||
for b in ai.content
|
||||
if isinstance(b, dict) and b.get("type") == "tool_use"
|
||||
)
|
||||
client_call_id = _valid_tool_call_id(ai.tool_calls[0]["id"])
|
||||
web_result_id = next(
|
||||
b["tool_use_id"]
|
||||
for b in cast(ToolMessage, sent[1]).content
|
||||
if isinstance(b, dict) and b.get("type") == "web_search_tool_result"
|
||||
)
|
||||
mcp_result_id = next(
|
||||
b["tool_use_id"]
|
||||
for b in cast(ToolMessage, sent[2]).content
|
||||
if isinstance(b, dict) and b.get("type") == "mcp_tool_result"
|
||||
)
|
||||
client_result_id = next(
|
||||
b["tool_use_id"]
|
||||
for b in cast(ToolMessage, sent[3]).content
|
||||
if isinstance(b, dict) and b.get("type") == "tool_result"
|
||||
)
|
||||
|
||||
assert server_block_id == web_result_id == cast(ToolMessage, sent[1]).tool_call_id
|
||||
assert mcp_block_id == mcp_result_id == cast(ToolMessage, sent[2]).tool_call_id
|
||||
assert client_block_id == client_call_id == client_result_id
|
||||
assert cast(ToolMessage, sent[3]).tool_call_id == client_call_id
|
||||
assert server_block_id != client_call_id
|
||||
assert mcp_block_id != client_call_id
|
||||
|
||||
|
||||
def test_partial_update_only_rewrites_what_changed() -> None:
|
||||
"""An AIMessage with a clean tool_calls but dirty content block triggers
|
||||
rewrite of the content only, not the tool_calls list (partial-update path).
|
||||
"""
|
||||
messages: list[BaseMessage] = [
|
||||
AIMessage(
|
||||
content=[
|
||||
{"type": "text", "text": "thinking"},
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "functions.x:0", # invalid
|
||||
"name": "x",
|
||||
"input": {},
|
||||
},
|
||||
],
|
||||
tool_calls=[
|
||||
# Valid id; should NOT be rewritten by mapping. Drift correction
|
||||
# will re-align the content block to use this id.
|
||||
{"name": "x", "args": {}, "id": "toolu_canon", "type": "tool_call"},
|
||||
],
|
||||
),
|
||||
ToolMessage(content="r", tool_call_id="toolu_canon"),
|
||||
]
|
||||
middleware, request, captured = _capture(messages)
|
||||
|
||||
middleware.wrap_model_call(request, _handler_factory(captured))
|
||||
|
||||
sent = captured[0].messages
|
||||
ai = cast(AIMessage, sent[0])
|
||||
block_id = next(
|
||||
b["id"]
|
||||
for b in ai.content
|
||||
if isinstance(b, dict) and b.get("type") == "tool_use"
|
||||
)
|
||||
# Drift correction collapses the divergent block id onto the canonical one.
|
||||
assert ai.tool_calls[0]["id"] == "toolu_canon"
|
||||
assert block_id == "toolu_canon"
|
||||
assert cast(ToolMessage, sent[1]).tool_call_id == "toolu_canon"
|
||||
|
||||
|
||||
def test_unrelated_message_subclasses_pass_through() -> None:
|
||||
"""SystemMessage and HumanMessage are returned untouched even mid-rewrite."""
|
||||
sys_msg = SystemMessage(content="you are an agent")
|
||||
human = HumanMessage(content="hi")
|
||||
messages: list[BaseMessage] = [
|
||||
sys_msg,
|
||||
human,
|
||||
AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{"name": "x", "args": {}, "id": "functions.x:0", "type": "tool_call"}
|
||||
],
|
||||
),
|
||||
ToolMessage(content="r", tool_call_id="functions.x:0"),
|
||||
]
|
||||
middleware, request, captured = _capture(messages)
|
||||
|
||||
middleware.wrap_model_call(request, _handler_factory(captured))
|
||||
|
||||
sent = captured[0].messages
|
||||
# Non-AI/non-Tool messages keep their identity.
|
||||
assert sent[0] is sys_msg
|
||||
assert sent[1] is human
|
||||
|
||||
|
||||
async def test_async_path_rewrites_ids() -> None:
|
||||
"""`awrap_model_call` rewrites IDs the same way as the sync path."""
|
||||
messages: list[BaseMessage] = [
|
||||
AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
"name": "grep",
|
||||
"args": {},
|
||||
"id": "functions.grep:0",
|
||||
"type": "tool_call",
|
||||
}
|
||||
],
|
||||
),
|
||||
ToolMessage(content="r", tool_call_id="functions.grep:0"),
|
||||
]
|
||||
middleware, request, captured = _capture(messages)
|
||||
|
||||
async def handler(req: ModelRequest) -> ModelResponse:
|
||||
captured.append(req)
|
||||
return ModelResponse(result=[AIMessage(content="ok")])
|
||||
|
||||
await middleware.awrap_model_call(request, handler)
|
||||
|
||||
sent = captured[0].messages
|
||||
rewritten_id = _valid_tool_call_id(cast(AIMessage, sent[0]).tool_calls[0]["id"])
|
||||
assert cast(ToolMessage, sent[1]).tool_call_id == rewritten_id
|
||||
|
||||
|
||||
def test_unsupported_model_ignore_default_skips_silently() -> None:
|
||||
"""Default `unsupported_model_behavior='ignore'` does not warn or rewrite."""
|
||||
messages: list[BaseMessage] = [
|
||||
AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
"name": "grep",
|
||||
"args": {},
|
||||
"id": "functions.grep:0",
|
||||
"type": "tool_call",
|
||||
}
|
||||
],
|
||||
),
|
||||
ToolMessage(content="r", tool_call_id="functions.grep:0"),
|
||||
]
|
||||
middleware = AnthropicToolIdSanitizationMiddleware()
|
||||
request = _make_request(messages, anthropic=False)
|
||||
captured: list[ModelRequest] = []
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("error")
|
||||
middleware.wrap_model_call(request, _handler_factory(captured))
|
||||
|
||||
assert captured[0].messages is request.messages
|
||||
|
||||
|
||||
def test_unsupported_model_warn_emits_warning() -> None:
|
||||
"""`unsupported_model_behavior='warn'` warns and skips rewrite."""
|
||||
middleware = AnthropicToolIdSanitizationMiddleware(
|
||||
unsupported_model_behavior="warn"
|
||||
)
|
||||
request = _make_request([HumanMessage("hi")], anthropic=False)
|
||||
captured: list[ModelRequest] = []
|
||||
|
||||
with pytest.warns(UserWarning, match="only supports Anthropic"):
|
||||
middleware.wrap_model_call(request, _handler_factory(captured))
|
||||
|
||||
|
||||
def test_unsupported_model_raise_errors() -> None:
|
||||
"""`unsupported_model_behavior='raise'` raises `ValueError`."""
|
||||
middleware = AnthropicToolIdSanitizationMiddleware(
|
||||
unsupported_model_behavior="raise"
|
||||
)
|
||||
request = _make_request([HumanMessage("hi")], anthropic=False)
|
||||
captured: list[ModelRequest] = []
|
||||
|
||||
with pytest.raises(ValueError, match="only supports Anthropic"):
|
||||
middleware.wrap_model_call(request, _handler_factory(captured))
|
||||
|
||||
|
||||
def test_invalid_unsupported_model_behavior_rejected() -> None:
|
||||
"""A typo in `unsupported_model_behavior` raises `ValueError` at construction.
|
||||
|
||||
`Literal` is enforced statically only; without runtime validation a typo
|
||||
silently falls through to ignore semantics — a footgun precisely for users
|
||||
who opted into `'raise'` to surface bugs.
|
||||
"""
|
||||
with pytest.raises(ValueError, match="unsupported_model_behavior"):
|
||||
AnthropicToolIdSanitizationMiddleware(
|
||||
unsupported_model_behavior="raies", # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
|
||||
def test_unsupported_model_ignore_logs_debug_breadcrumb(
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
) -> None:
|
||||
"""Default `'ignore'` mode emits a debug log so the bypass is observable."""
|
||||
middleware = AnthropicToolIdSanitizationMiddleware()
|
||||
request = _make_request([HumanMessage("hi")], anthropic=False)
|
||||
captured: list[ModelRequest] = []
|
||||
|
||||
with caplog.at_level(
|
||||
logging.DEBUG, logger="langchain_anthropic.middleware.tool_id_sanitization"
|
||||
):
|
||||
middleware.wrap_model_call(request, _handler_factory(captured))
|
||||
|
||||
assert any(
|
||||
"skipped" in record.message and "_FakeModel" in record.message
|
||||
for record in caplog.records
|
||||
)
|
||||
4
libs/partners/anthropic/uv.lock
generated
4
libs/partners/anthropic/uv.lock
generated
@@ -510,7 +510,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "langchain"
|
||||
version = "1.2.15"
|
||||
version = "1.2.16"
|
||||
source = { editable = "../../langchain_v1" }
|
||||
dependencies = [
|
||||
{ name = "langchain-core" },
|
||||
@@ -734,7 +734,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "langchain-tests"
|
||||
version = "1.1.6"
|
||||
version = "1.1.7"
|
||||
source = { editable = "../../standard-tests" }
|
||||
dependencies = [
|
||||
{ name = "httpx" },
|
||||
|
||||
Reference in New Issue
Block a user