feat(anthropic): add AnthropicToolIdSanitizationMiddleware

This commit is contained in:
Mason Daugherty
2026-04-30 03:04:17 -04:00
parent 38553c3f2d
commit efc1c48572
4 changed files with 1369 additions and 2 deletions

View File

@@ -13,9 +13,13 @@ from langchain_anthropic.middleware.file_search import (
from langchain_anthropic.middleware.prompt_caching import (
AnthropicPromptCachingMiddleware,
)
from langchain_anthropic.middleware.tool_id_sanitization import (
AnthropicToolIdSanitizationMiddleware,
)
__all__ = [
"AnthropicPromptCachingMiddleware",
"AnthropicToolIdSanitizationMiddleware",
"ClaudeBashToolMiddleware",
"FilesystemClaudeMemoryMiddleware",
"FilesystemClaudeTextEditorMiddleware",

View File

@@ -0,0 +1,475 @@
"""Anthropic tool-call ID sanitization middleware."""
from __future__ import annotations
import hashlib
import logging
import re
from collections.abc import Awaitable, Callable
from typing import Any, Literal
from warnings import warn
from langchain_core.messages import AIMessage, AnyMessage, ToolMessage
from langchain_anthropic.chat_models import ChatAnthropic
try:
from langchain.agents.middleware.types import (
AgentMiddleware,
ModelCallResult,
ModelRequest,
ModelResponse,
)
except ModuleNotFoundError as e:
msg = (
"AnthropicToolIdSanitizationMiddleware requires 'langchain' to be "
"installed. This middleware is designed for use with LangChain agents. "
"Install it with: pip install langchain"
)
raise ImportError(msg) from e
logger = logging.getLogger(__name__)
_ANTHROPIC_ID_PATTERN = re.compile(r"^[a-zA-Z0-9_-]+$")
"""Character set Anthropic enforces server-side on `tool_use.id`.
Mirrors the regex echoed in the API errors returned for non-conforming IDs:
`messages.N.content.M.tool_use.id: String should match pattern
'^[a-zA-Z0-9_-]+$'`.
"""
_INVALID_CHAR_PATTERN = re.compile(r"[^a-zA-Z0-9_-]")
"""Complement of the character class in `_ANTHROPIC_ID_PATTERN`.
Matches any single illegal character. Used by `_make_safe_id` to substitute
offending characters with `_`.
"""
def _is_tool_use_type(block_type: Any) -> bool:
"""Return `True` if `block_type` names an ID-bearing tool-use block.
Covers `tool_use` (client tools), `server_tool_use` (Anthropic-emitted
server tools), and `mcp_tool_use` (MCP). Uses suffix matching so future
`*_tool_use` variants are picked up automatically.
"""
return isinstance(block_type, str) and block_type.endswith("tool_use")
def _is_client_tool_use_type(block_type: Any) -> bool:
"""Return `True` if `block_type` is a client `tool_use` block."""
return block_type == "tool_use"
def _is_tool_result_type(block_type: Any) -> bool:
"""Return `True` if `block_type` names a `tool_use_id`-bearing result block.
Covers `tool_result`, `mcp_tool_result`, and Anthropic server-tool result
variants like `web_search_tool_result`, `code_execution_tool_result`, and
`bash_tool_result`. Uses suffix matching for forward compatibility.
"""
return isinstance(block_type, str) and block_type.endswith("tool_result")
def _is_valid_id(tool_id: str | None) -> bool:
"""Return `True` if `tool_id` matches Anthropic's required pattern."""
if not tool_id:
return False
return _ANTHROPIC_ID_PATTERN.match(tool_id) is not None
def _make_safe_id(original: str, used: set[str]) -> str:
"""Derive an Anthropic-safe id, avoiding collisions within a single call.
Replaces every illegal character with `_`. If the resulting `base` already
appears in `used`, appends a sha256-derived suffix (and a counter on
further collisions) so each invalid input gets a distinct safe output
within one sanitization pass.
Args:
original: The invalid tool-call id to sanitize.
used: The set of ids already taken in this pass — both pre-existing
valid ids and previously-allocated safe ids.
Returns:
A regex-conformant id distinct from every entry in `used`.
"""
base = _INVALID_CHAR_PATTERN.sub("_", original) or "tool"
if base not in used:
return base
digest = hashlib.sha256(original.encode("utf-8")).hexdigest()[:8]
candidate = f"{base}_{digest}"
counter = 0
while candidate in used:
counter += 1
candidate = f"{base}_{digest}_{counter}"
return candidate
class AnthropicToolIdSanitizationMiddleware(AgentMiddleware):
"""Rewrite illegal tool-call IDs before sending to Anthropic.
Anthropic enforces `tool_use.id` matching `^[a-zA-Z0-9_-]+$`. Conversation
histories produced by other providers can violate this — e.g. Kimi-K2 emits
IDs of the form `functions.<name>:<idx>`, where `.` and `:` are illegal.
Replaying such a thread against Claude raises a 400 error.
This middleware scans `request.messages` for offending IDs, builds a
deterministic `bad_id -> safe_id` map, and rewrites every occurrence:
- `AIMessage.tool_calls[*]["id"]`
- `AIMessage.content[*]["id"]` for `tool_use` / `server_tool_use` /
`mcp_tool_use` blocks
- `ToolMessage.tool_call_id`
- `ToolMessage.content[*]["tool_use_id"]` for `tool_result` and the
`*_tool_result` variants (`web_search_tool_result`,
`code_execution_tool_result`, `bash_tool_result`, `mcp_tool_result`)
Within an `AIMessage`, position-paired `tool_calls[i]` and `tool_use`
content blocks are forced to share the same final id even when their
inputs disagree (drift), so Anthropic never receives a mismatched pair.
Only the outgoing `ModelRequest` is modified — graph state and persisted
checkpoints are left untouched, so the sanitization is idempotent across
turns and safe to combine with HITL resume.
"""
def __init__(
self,
unsupported_model_behavior: Literal["ignore", "warn", "raise"] = "ignore",
) -> None:
"""Initialize the middleware.
Args:
unsupported_model_behavior: Behavior when the bound model is not
`ChatAnthropic`.
`'ignore'` skips sanitization silently (default — other
providers may accept the original IDs).
`'warn'` emits a warning and skips sanitization.
`'raise'` raises `ValueError`.
Raises:
ValueError: If `unsupported_model_behavior` is not one of the
three allowed strings. `Literal` is enforced statically only;
this guards against typos at runtime.
"""
allowed = ("ignore", "warn", "raise")
if unsupported_model_behavior not in allowed:
msg = (
f"unsupported_model_behavior must be one of {allowed}; "
f"got {unsupported_model_behavior!r}"
)
raise ValueError(msg)
self.unsupported_model_behavior = unsupported_model_behavior
def _should_run(self, request: ModelRequest) -> bool:
"""Return `True` if the bound model is `ChatAnthropic`.
Args:
request: The model request to inspect.
Returns:
`True` when sanitization should run; `False` to bypass.
Raises:
ValueError: When the bound model is not `ChatAnthropic` and
`unsupported_model_behavior='raise'`.
"""
if isinstance(request.model, ChatAnthropic):
return True
msg = (
"AnthropicToolIdSanitizationMiddleware only supports Anthropic "
f"models, not instances of {type(request.model)}"
)
if self.unsupported_model_behavior == "raise":
raise ValueError(msg)
if self.unsupported_model_behavior == "warn":
warn(msg, stacklevel=3)
else:
# `ignore` mode — leave a breadcrumb so users debugging missing
# sanitization can confirm the middleware was bypassed and why.
logger.debug(
"AnthropicToolIdSanitizationMiddleware skipped: bound model "
"is %s, not ChatAnthropic.",
type(request.model).__name__,
)
return False
def _build_id_map(self, messages: list[AnyMessage]) -> dict[str, str]:
"""Build a deterministic mapping from invalid IDs to safe ones.
Args:
messages: The outgoing message list to scan.
Returns:
A dict mapping each invalid id to its safe replacement, or an
empty dict when every id already conforms.
"""
invalid: list[str] = []
seen: set[str] = set()
valid_ids: set[str] = set()
for tool_id in _iter_all_ids(messages):
if _is_valid_id(tool_id):
valid_ids.add(tool_id)
elif tool_id and tool_id not in seen:
invalid.append(tool_id)
seen.add(tool_id)
if not invalid:
return {}
used = set(valid_ids)
mapping: dict[str, str] = {}
for original in invalid:
new = _make_safe_id(original, used)
mapping[original] = new
used.add(new)
return mapping
def _sanitize(self, request: ModelRequest) -> ModelRequest:
"""Return a copy of `request` with sanitized tool-call IDs.
Args:
request: The model request to sanitize.
Returns:
A new request when any id was rewritten, otherwise the original.
"""
mapping = self._build_id_map(request.messages)
rewritten, changed = _rewrite_messages(request.messages, mapping)
if not changed:
return request
if logger.isEnabledFor(logging.DEBUG):
logger.debug(
"Sanitized %d tool-call id(s) for Anthropic compatibility: %s",
len(mapping),
{k: mapping[k] for k in sorted(mapping)},
)
return request.override(messages=rewritten)
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
"""Sanitize tool-call IDs before delegating to the handler.
Args:
request: The model request to potentially modify.
handler: The handler to execute the (possibly rewritten) request.
Returns:
The model response produced by `handler`.
"""
if not self._should_run(request):
return handler(request)
return handler(self._sanitize(request))
async def awrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ModelCallResult:
"""Sanitize tool-call IDs before delegating to the async handler.
Args:
request: The model request to potentially modify.
handler: The async handler to execute the (possibly rewritten)
request.
Returns:
The model response produced by `handler`.
"""
if not self._should_run(request):
return await handler(request)
return await handler(self._sanitize(request))
def _iter_all_ids(messages: list[AnyMessage]) -> list[str]:
"""Collect every tool-call ID present in the message list."""
ids: list[str] = []
for msg in messages:
if isinstance(msg, AIMessage):
for tc in msg.tool_calls or []:
tid = tc.get("id")
if tid:
ids.append(tid)
if isinstance(msg.content, list):
for block in msg.content:
if isinstance(block, dict) and _is_tool_use_type(block.get("type")):
tid = block.get("id")
if isinstance(tid, str) and tid:
ids.append(tid)
elif isinstance(msg, ToolMessage):
if msg.tool_call_id:
ids.append(msg.tool_call_id)
if isinstance(msg.content, list):
for block in msg.content:
if isinstance(block, dict) and _is_tool_result_type(
block.get("type")
):
tid = block.get("tool_use_id")
if isinstance(tid, str) and tid:
ids.append(tid)
return ids
def _rewrite_messages(
messages: list[AnyMessage], mapping: dict[str, str]
) -> tuple[list[AnyMessage], bool]:
"""Return messages rewritten with sanitized IDs and local drift aliases."""
rewritten: list[AnyMessage] = []
changed = False
active_result_aliases: dict[str, str] = {}
for msg in messages:
new_msg: AnyMessage
if isinstance(msg, AIMessage):
new_msg, active_result_aliases = _rewrite_ai_message(msg, mapping)
elif isinstance(msg, ToolMessage):
result_mapping = {**mapping, **active_result_aliases}
new_msg = _rewrite_tool_message(msg, result_mapping)
else:
active_result_aliases = {}
new_msg = msg
if new_msg is not msg:
changed = True
rewritten.append(new_msg)
return rewritten, changed
def _rewrite_ai_message(
msg: AIMessage, mapping: dict[str, str]
) -> tuple[AIMessage, dict[str, str]]:
"""Return a copy of `msg` with `tool_calls` and tool-use blocks aligned.
Applies `mapping` to each id, then enforces position-paired alignment so
`tool_calls[i].id` and the i-th client `tool_use` content block share the
same final id even if their inputs drifted.
"""
result_aliases: dict[str, str] = {}
new_tool_calls: list[Any] | None = None
if msg.tool_calls and any(tc.get("id") in mapping for tc in msg.tool_calls):
new_tool_calls = []
for tc in msg.tool_calls:
tid = tc.get("id")
if tid is not None and tid in mapping:
new_tool_calls.append({**tc, "id": mapping[tid]})
else:
new_tool_calls.append(tc)
new_content: list[Any] | None = None
if isinstance(msg.content, list) and any(
isinstance(b, dict)
and _is_tool_use_type(b.get("type"))
and b.get("id") in mapping
for b in msg.content
):
new_content = [
{**b, "id": mapping[b["id"]]}
if (
isinstance(b, dict)
and _is_tool_use_type(b.get("type"))
and b.get("id") in mapping
)
else b
for b in msg.content
]
# Drift correction: even if `tool_calls[i].id` and the i-th client
# `tool_use` block disagreed pre-mapping (or post-mapping due to distinct
# invalid inputs collapsing through different sanitization branches), force
# the content block to adopt the canonical `tool_calls[i].id`. `tool_calls`
# is the langchain-canonical view.
effective_tool_calls = (
new_tool_calls if new_tool_calls is not None else msg.tool_calls
)
effective_content = new_content if new_content is not None else msg.content
if effective_tool_calls and isinstance(effective_content, list):
aligned, result_aliases, drift_seen = _align_client_tool_use_blocks(
effective_tool_calls,
effective_content,
msg.content if isinstance(msg.content, list) else effective_content,
)
if drift_seen:
new_content = aligned
updates: dict[str, Any] = {}
if new_tool_calls is not None:
updates["tool_calls"] = new_tool_calls
if new_content is not None:
updates["content"] = new_content
return (msg.model_copy(update=updates) if updates else msg), result_aliases
def _align_client_tool_use_blocks(
tool_calls: list[Any], content: list[Any], original_content: list[Any]
) -> tuple[list[Any], dict[str, str], bool]:
"""Force tool-use content blocks to share IDs with corresponding tool_calls.
Returns the (possibly rewritten) content list plus aliases from drifted
content-block IDs to the canonical tool-call ID. These aliases are applied
only to the following `ToolMessage` results for this assistant turn.
"""
result_aliases: dict[str, str] = {}
drift_seen = False
aligned: list[Any] = list(content)
tu_indices = [
idx
for idx, b in enumerate(content)
if isinstance(b, dict) and _is_client_tool_use_type(b.get("type"))
]
for pair_idx, content_idx in enumerate(tu_indices):
if pair_idx >= len(tool_calls):
break
canonical_id = tool_calls[pair_idx].get("id")
block = aligned[content_idx]
if not isinstance(block, dict):
continue
if canonical_id and block.get("id") != canonical_id:
aligned[content_idx] = {**block, "id": canonical_id}
drift_seen = True
original_block = original_content[content_idx]
if isinstance(original_block, dict):
original_id = original_block.get("id")
if isinstance(original_id, str) and original_id != canonical_id:
result_aliases[original_id] = canonical_id
current_id = block.get("id")
if isinstance(current_id, str) and current_id != canonical_id:
result_aliases[current_id] = canonical_id
return aligned, result_aliases, drift_seen
def _rewrite_tool_message(msg: ToolMessage, mapping: dict[str, str]) -> ToolMessage:
"""Return a copy of `msg` with `tool_call_id` and tool-result blocks rewritten."""
updates: dict[str, Any] = {}
if msg.tool_call_id in mapping:
updates["tool_call_id"] = mapping[msg.tool_call_id]
if isinstance(msg.content, list) and any(
isinstance(b, dict)
and _is_tool_result_type(b.get("type"))
and b.get("tool_use_id") in mapping
for b in msg.content
):
updates["content"] = [
{**b, "tool_use_id": mapping[b["tool_use_id"]]}
if (
isinstance(b, dict)
and _is_tool_result_type(b.get("type"))
and b.get("tool_use_id") in mapping
)
else b
for b in msg.content
]
return msg.model_copy(update=updates) if updates else msg

View File

@@ -0,0 +1,888 @@
"""Tests for Anthropic tool-call ID sanitization middleware."""
import logging
import re
import warnings
from typing import Any, cast
from unittest.mock import MagicMock
import pytest
from langchain.agents.middleware.types import AgentState, ModelRequest, ModelResponse
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import (
AIMessage,
AnyMessage,
BaseMessage,
HumanMessage,
SystemMessage,
ToolMessage,
)
from langchain_core.outputs import ChatGeneration, ChatResult
from langgraph.runtime import Runtime
from langchain_anthropic.chat_models import ChatAnthropic
from langchain_anthropic.middleware import AnthropicToolIdSanitizationMiddleware
_ANTHROPIC_ID_RE = re.compile(r"^[a-zA-Z0-9_-]+$")
class _FakeModel(BaseChatModel):
"""Stand-in non-Anthropic model for unsupported-model tests."""
def _generate(
self,
messages: list[BaseMessage],
stop: list[str] | None = None,
run_manager: CallbackManagerForLLMRun | None = None,
**kwargs: Any,
) -> ChatResult:
return ChatResult(
generations=[ChatGeneration(message=AIMessage(content="ok", id="0"))]
)
async def _agenerate(
self,
messages: list[BaseMessage],
stop: list[str] | None = None,
run_manager: AsyncCallbackManagerForLLMRun | None = None,
**kwargs: Any,
) -> ChatResult:
return ChatResult(
generations=[ChatGeneration(message=AIMessage(content="ok", id="0"))]
)
@property
def _llm_type(self) -> str:
return "fake"
def _make_request(
messages: list[BaseMessage], *, anthropic: bool = True
) -> ModelRequest:
"""Build a `ModelRequest` for testing."""
if anthropic:
mock_model = MagicMock(spec=ChatAnthropic)
mock_model._llm_type = "anthropic-chat"
model: BaseChatModel = cast(BaseChatModel, mock_model)
else:
model = _FakeModel()
return ModelRequest(
model=model,
messages=cast(list[AnyMessage], messages),
system_prompt=None,
tool_choice=None,
tools=[],
response_format=None,
state=cast(AgentState[Any], {"messages": messages}),
runtime=cast(Runtime, object()),
model_settings={},
)
def _capture(
messages: list[BaseMessage],
) -> tuple[
AnthropicToolIdSanitizationMiddleware,
ModelRequest,
list[ModelRequest],
]:
"""Wire a middleware + request + capturing handler list for a test."""
middleware = AnthropicToolIdSanitizationMiddleware()
request = _make_request(messages)
captured: list[ModelRequest] = []
return middleware, request, captured
def _handler_factory(captured: list[ModelRequest]) -> Any:
def _handler(req: ModelRequest) -> ModelResponse:
captured.append(req)
return ModelResponse(result=[AIMessage(content="ok")])
return _handler
def _valid_tool_call_id(value: str | None) -> str:
"""Assert that an optional tool-call ID is present and Anthropic-safe."""
assert value is not None
assert _ANTHROPIC_ID_RE.match(value)
return value
def test_passthrough_when_all_ids_valid() -> None:
"""Valid IDs short-circuit: every message reaches the handler unchanged."""
messages: list[BaseMessage] = [
HumanMessage("hi"),
AIMessage(
content="",
tool_calls=[
{"name": "grep", "args": {}, "id": "toolu_01abc", "type": "tool_call"}
],
),
ToolMessage(content="done", tool_call_id="toolu_01abc"),
]
middleware, request, captured = _capture(messages)
middleware.wrap_model_call(request, _handler_factory(captured))
assert len(captured) == 1
# Per-message identity: each individual message is the original instance.
for original, sent in zip(request.messages, captured[0].messages, strict=True):
assert sent is original
def test_rewrites_kimi_k2_style_ids_in_pairs() -> None:
"""Bad IDs on AIMessage tool_calls and matching ToolMessage are rewritten."""
messages: list[BaseMessage] = [
HumanMessage("hi"),
AIMessage(
content="",
tool_calls=[
{
"name": "grep",
"args": {"q": "x"},
"id": "functions.grep:0",
"type": "tool_call",
},
{
"name": "grep",
"args": {"q": "y"},
"id": "functions.grep:1",
"type": "tool_call",
},
],
),
ToolMessage(content="r0", tool_call_id="functions.grep:0"),
ToolMessage(content="r1", tool_call_id="functions.grep:1"),
]
middleware, request, captured = _capture(messages)
middleware.wrap_model_call(request, _handler_factory(captured))
sent = captured[0].messages
ai = sent[1]
assert isinstance(ai, AIMessage)
new_ids = [_valid_tool_call_id(tc["id"]) for tc in ai.tool_calls]
tool_msgs = [m for m in sent if isinstance(m, ToolMessage)]
assert [tm.tool_call_id for tm in tool_msgs] == new_ids
assert new_ids[0] != new_ids[1]
def test_rewrite_is_deterministic_across_invocations() -> None:
"""The same input produces the same sanitized IDs across separate calls."""
payload: list[BaseMessage] = [
AIMessage(
content="",
tool_calls=[
{
"name": "grep",
"args": {},
"id": "functions.grep:0",
"type": "tool_call",
},
{
"name": "grep",
"args": {},
"id": "functions.grep:1",
"type": "tool_call",
},
],
),
ToolMessage(content="r0", tool_call_id="functions.grep:0"),
ToolMessage(content="r1", tool_call_id="functions.grep:1"),
]
def _run() -> list[str]:
# Build a fresh copy each invocation so the two runs are independent.
msgs = [m.model_copy(deep=True) for m in payload]
middleware = AnthropicToolIdSanitizationMiddleware()
request = _make_request(msgs)
captured: list[ModelRequest] = []
middleware.wrap_model_call(request, _handler_factory(captured))
ai = cast(AIMessage, captured[0].messages[0])
return [_valid_tool_call_id(tc["id"]) for tc in ai.tool_calls]
assert _run() == _run()
def test_rewrites_anthropic_content_blocks() -> None:
"""`tool_use` and `tool_result` content blocks have their IDs rewritten too."""
messages: list[BaseMessage] = [
AIMessage(
content=[
{"type": "text", "text": "calling"},
{
"type": "tool_use",
"id": "functions.read_file:0",
"name": "read_file",
"input": {},
},
],
tool_calls=[
{
"name": "read_file",
"args": {},
"id": "functions.read_file:0",
"type": "tool_call",
},
],
),
ToolMessage(
content=[
{
"type": "tool_result",
"tool_use_id": "functions.read_file:0",
"content": "hi",
},
],
tool_call_id="functions.read_file:0",
),
]
middleware, request, captured = _capture(messages)
middleware.wrap_model_call(request, _handler_factory(captured))
sent = captured[0].messages
ai = sent[0]
tm = sent[1]
assert isinstance(ai, AIMessage)
assert isinstance(tm, ToolMessage)
block_id = next(
b["id"]
for b in ai.content
if isinstance(b, dict) and b.get("type") == "tool_use"
)
tool_call_id = _valid_tool_call_id(ai.tool_calls[0]["id"])
result_id = next(
b["tool_use_id"]
for b in tm.content
if isinstance(b, dict) and b.get("type") == "tool_result"
)
assert block_id == tool_call_id == tm.tool_call_id == result_id
for value in (block_id, tool_call_id, tm.tool_call_id, result_id):
assert _ANTHROPIC_ID_RE.match(value)
def test_drift_between_tool_calls_and_tool_use_blocks_is_corrected() -> None:
"""When `tool_calls[i].id` and the i-th tool_use block disagree, alignment wins.
Without correction, two distinct invalid IDs would map to two different
safe IDs (`a_b` vs `a_b_<digest>`), and Anthropic would 400 on the
mismatched pair. The middleware forces both to share `tool_calls[i].id`.
"""
messages: list[BaseMessage] = [
AIMessage(
content=[
{
"type": "tool_use",
"id": "a:b", # drift — different illegal id
"name": "tool",
"input": {},
},
],
tool_calls=[
{
"name": "tool",
"args": {},
"id": "a.b", # canonical
"type": "tool_call",
}
],
),
ToolMessage(
content=[
{
"type": "tool_result",
"tool_use_id": "a:b",
"content": "r",
},
],
tool_call_id="a:b",
),
]
middleware, request, captured = _capture(messages)
middleware.wrap_model_call(request, _handler_factory(captured))
sent = captured[0].messages
ai = cast(AIMessage, sent[0])
tm = cast(ToolMessage, sent[1])
final_tc_id = _valid_tool_call_id(ai.tool_calls[0]["id"])
final_block_id = next(
b["id"]
for b in ai.content
if isinstance(b, dict) and b.get("type") == "tool_use"
)
final_result_id = next(
b["tool_use_id"]
for b in tm.content
if isinstance(b, dict) and b.get("type") == "tool_result"
)
assert final_tc_id == final_block_id == tm.tool_call_id == final_result_id
def test_valid_drift_between_tool_calls_and_tool_use_blocks_is_corrected() -> None:
"""Valid but mismatched tool-call and `tool_use` IDs are aligned."""
messages: list[BaseMessage] = [
AIMessage(
content=[
{
"type": "tool_use",
"id": "toolu_content",
"name": "tool",
"input": {},
},
],
tool_calls=[
{
"name": "tool",
"args": {},
"id": "toolu_call",
"type": "tool_call",
}
],
),
ToolMessage(
content=[
{
"type": "tool_result",
"tool_use_id": "toolu_content",
"content": "r",
},
],
tool_call_id="toolu_content",
),
]
middleware, request, captured = _capture(messages)
middleware.wrap_model_call(request, _handler_factory(captured))
sent = captured[0].messages
ai = cast(AIMessage, sent[0])
tm = cast(ToolMessage, sent[1])
block_id = next(
b["id"]
for b in ai.content
if isinstance(b, dict) and b.get("type") == "tool_use"
)
result_id = next(
b["tool_use_id"]
for b in tm.content
if isinstance(b, dict) and b.get("type") == "tool_result"
)
assert ai.tool_calls[0]["id"] == "toolu_call"
assert block_id == "toolu_call"
assert tm.tool_call_id == "toolu_call"
assert result_id == "toolu_call"
def test_two_invalid_ids_with_same_base_get_distinct_safe_ids() -> None:
"""`a.b` and `a:b` both sanitize to base `a_b` — second falls back to suffix.
Exercises the sha256-suffix branch of `_make_safe_id`. Both AIMessages
use the same callable name so name-based heuristics can't disambiguate;
only the `_make_safe_id` collision logic produces distinct safe IDs.
"""
messages: list[BaseMessage] = [
AIMessage(
content="",
tool_calls=[{"name": "tool", "args": {}, "id": "a.b", "type": "tool_call"}],
),
ToolMessage(content="r1", tool_call_id="a.b"),
AIMessage(
content="",
tool_calls=[{"name": "tool", "args": {}, "id": "a:b", "type": "tool_call"}],
),
ToolMessage(content="r2", tool_call_id="a:b"),
]
middleware, request, captured = _capture(messages)
middleware.wrap_model_call(request, _handler_factory(captured))
sent = captured[0].messages
first_id = _valid_tool_call_id(cast(AIMessage, sent[0]).tool_calls[0]["id"])
second_id = _valid_tool_call_id(cast(AIMessage, sent[2]).tool_calls[0]["id"])
# Both sanitize cleanly, both are valid, both are distinct.
assert first_id != second_id
# The first invalid wins the base; the second gets the suffix.
assert first_id == "a_b"
assert second_id.startswith("a_b_")
# Pairs survive: ToolMessage ids match their AIMessage counterparts.
assert cast(ToolMessage, sent[1]).tool_call_id == first_id
assert cast(ToolMessage, sent[3]).tool_call_id == second_id
def test_state_is_not_mutated() -> None:
"""Original messages on the request and in graph state are not mutated."""
original_id = "functions.grep:0"
messages: list[BaseMessage] = [
AIMessage(
content="",
tool_calls=[
{"name": "grep", "args": {}, "id": original_id, "type": "tool_call"}
],
),
ToolMessage(content="r", tool_call_id=original_id),
]
middleware, request, captured = _capture(messages)
middleware.wrap_model_call(request, _handler_factory(captured))
assert request.messages[0].tool_calls[0]["id"] == original_id # type: ignore[union-attr]
assert cast(ToolMessage, request.messages[1]).tool_call_id == original_id
assert captured[0].messages is not request.messages
def test_collision_avoidance_with_existing_valid_id() -> None:
"""A bad ID that sanitizes to an already-used valid ID gets a hash suffix."""
messages: list[BaseMessage] = [
AIMessage(
content="",
tool_calls=[
{
"name": "grep",
"args": {},
"id": "functions_grep_0",
"type": "tool_call",
},
],
),
ToolMessage(content="a", tool_call_id="functions_grep_0"),
AIMessage(
content="",
tool_calls=[
{
"name": "grep",
"args": {},
"id": "functions.grep:0",
"type": "tool_call",
},
],
),
ToolMessage(content="b", tool_call_id="functions.grep:0"),
]
middleware, request, captured = _capture(messages)
middleware.wrap_model_call(request, _handler_factory(captured))
sent = captured[0].messages
first_id = _valid_tool_call_id(cast(AIMessage, sent[0]).tool_calls[0]["id"])
second_id = _valid_tool_call_id(cast(AIMessage, sent[2]).tool_calls[0]["id"])
assert first_id == "functions_grep_0"
assert second_id != first_id
def test_mixed_valid_and_invalid_ids_in_same_message() -> None:
"""Valid IDs are preserved while siblings with invalid IDs are rewritten."""
messages: list[BaseMessage] = [
AIMessage(
content="",
tool_calls=[
{
"name": "good",
"args": {},
"id": "toolu_clean",
"type": "tool_call",
},
{
"name": "bad",
"args": {},
"id": "functions.bad:0",
"type": "tool_call",
},
],
),
ToolMessage(content="g", tool_call_id="toolu_clean"),
ToolMessage(content="b", tool_call_id="functions.bad:0"),
]
middleware, request, captured = _capture(messages)
middleware.wrap_model_call(request, _handler_factory(captured))
sent = captured[0].messages
ai = cast(AIMessage, sent[0])
valid_id = _valid_tool_call_id(ai.tool_calls[0]["id"])
rewritten_id = _valid_tool_call_id(ai.tool_calls[1]["id"])
assert valid_id == "toolu_clean" # untouched
assert rewritten_id != "functions.bad:0"
assert cast(ToolMessage, sent[1]).tool_call_id == valid_id
assert cast(ToolMessage, sent[2]).tool_call_id == rewritten_id
def test_rewrites_server_and_mcp_block_types() -> None:
"""`server_tool_use` / `mcp_tool_use` and their result variants are also handled."""
messages: list[BaseMessage] = [
AIMessage(
content=[
{
"type": "server_tool_use",
"id": "srv.1",
"name": "web_search",
"input": {},
},
{
"type": "mcp_tool_use",
"id": "mcp.2",
"name": "fetch",
"input": {},
},
],
tool_calls=[],
),
ToolMessage(
content=[
{
"type": "web_search_tool_result",
"tool_use_id": "srv.1",
"content": "hits",
},
],
tool_call_id="srv.1",
),
ToolMessage(
content=[
{
"type": "mcp_tool_result",
"tool_use_id": "mcp.2",
"content": "data",
},
],
tool_call_id="mcp.2",
),
]
middleware, request, captured = _capture(messages)
middleware.wrap_model_call(request, _handler_factory(captured))
sent = captured[0].messages
ai = cast(AIMessage, sent[0])
server_block_id = next(
b["id"]
for b in ai.content
if isinstance(b, dict) and b.get("type") == "server_tool_use"
)
mcp_block_id = next(
b["id"]
for b in ai.content
if isinstance(b, dict) and b.get("type") == "mcp_tool_use"
)
for value in (
server_block_id,
mcp_block_id,
cast(ToolMessage, sent[1]).tool_call_id,
cast(ToolMessage, sent[2]).tool_call_id,
):
assert _ANTHROPIC_ID_RE.match(value)
# Pair survival:
web_result_id = next(
b["tool_use_id"]
for b in cast(ToolMessage, sent[1]).content
if isinstance(b, dict) and b.get("type") == "web_search_tool_result"
)
mcp_result_id = next(
b["tool_use_id"]
for b in cast(ToolMessage, sent[2]).content
if isinstance(b, dict) and b.get("type") == "mcp_tool_result"
)
assert web_result_id == server_block_id
assert mcp_result_id == mcp_block_id
def test_client_alignment_skips_server_and_mcp_tool_use_blocks() -> None:
"""Server and MCP tool-use blocks do not position-pair with `tool_calls`."""
messages: list[BaseMessage] = [
AIMessage(
content=[
{
"type": "server_tool_use",
"id": "srv.1",
"name": "web_search",
"input": {},
},
{
"type": "mcp_tool_use",
"id": "mcp.2",
"name": "fetch",
"input": {},
},
{
"type": "tool_use",
"id": "client.3",
"name": "client_tool",
"input": {},
},
],
tool_calls=[
{
"name": "client_tool",
"args": {},
"id": "client.3",
"type": "tool_call",
},
],
),
ToolMessage(
content=[
{
"type": "web_search_tool_result",
"tool_use_id": "srv.1",
"content": "hits",
},
],
tool_call_id="srv.1",
),
ToolMessage(
content=[
{
"type": "mcp_tool_result",
"tool_use_id": "mcp.2",
"content": "data",
},
],
tool_call_id="mcp.2",
),
ToolMessage(
content=[
{
"type": "tool_result",
"tool_use_id": "client.3",
"content": "done",
},
],
tool_call_id="client.3",
),
]
middleware, request, captured = _capture(messages)
middleware.wrap_model_call(request, _handler_factory(captured))
sent = captured[0].messages
ai = cast(AIMessage, sent[0])
server_block_id = next(
b["id"]
for b in ai.content
if isinstance(b, dict) and b.get("type") == "server_tool_use"
)
mcp_block_id = next(
b["id"]
for b in ai.content
if isinstance(b, dict) and b.get("type") == "mcp_tool_use"
)
client_block_id = next(
b["id"]
for b in ai.content
if isinstance(b, dict) and b.get("type") == "tool_use"
)
client_call_id = _valid_tool_call_id(ai.tool_calls[0]["id"])
web_result_id = next(
b["tool_use_id"]
for b in cast(ToolMessage, sent[1]).content
if isinstance(b, dict) and b.get("type") == "web_search_tool_result"
)
mcp_result_id = next(
b["tool_use_id"]
for b in cast(ToolMessage, sent[2]).content
if isinstance(b, dict) and b.get("type") == "mcp_tool_result"
)
client_result_id = next(
b["tool_use_id"]
for b in cast(ToolMessage, sent[3]).content
if isinstance(b, dict) and b.get("type") == "tool_result"
)
assert server_block_id == web_result_id == cast(ToolMessage, sent[1]).tool_call_id
assert mcp_block_id == mcp_result_id == cast(ToolMessage, sent[2]).tool_call_id
assert client_block_id == client_call_id == client_result_id
assert cast(ToolMessage, sent[3]).tool_call_id == client_call_id
assert server_block_id != client_call_id
assert mcp_block_id != client_call_id
def test_partial_update_only_rewrites_what_changed() -> None:
"""An AIMessage with a clean tool_calls but dirty content block triggers
rewrite of the content only, not the tool_calls list (partial-update path).
"""
messages: list[BaseMessage] = [
AIMessage(
content=[
{"type": "text", "text": "thinking"},
{
"type": "tool_use",
"id": "functions.x:0", # invalid
"name": "x",
"input": {},
},
],
tool_calls=[
# Valid id; should NOT be rewritten by mapping. Drift correction
# will re-align the content block to use this id.
{"name": "x", "args": {}, "id": "toolu_canon", "type": "tool_call"},
],
),
ToolMessage(content="r", tool_call_id="toolu_canon"),
]
middleware, request, captured = _capture(messages)
middleware.wrap_model_call(request, _handler_factory(captured))
sent = captured[0].messages
ai = cast(AIMessage, sent[0])
block_id = next(
b["id"]
for b in ai.content
if isinstance(b, dict) and b.get("type") == "tool_use"
)
# Drift correction collapses the divergent block id onto the canonical one.
assert ai.tool_calls[0]["id"] == "toolu_canon"
assert block_id == "toolu_canon"
assert cast(ToolMessage, sent[1]).tool_call_id == "toolu_canon"
def test_unrelated_message_subclasses_pass_through() -> None:
"""SystemMessage and HumanMessage are returned untouched even mid-rewrite."""
sys_msg = SystemMessage(content="you are an agent")
human = HumanMessage(content="hi")
messages: list[BaseMessage] = [
sys_msg,
human,
AIMessage(
content="",
tool_calls=[
{"name": "x", "args": {}, "id": "functions.x:0", "type": "tool_call"}
],
),
ToolMessage(content="r", tool_call_id="functions.x:0"),
]
middleware, request, captured = _capture(messages)
middleware.wrap_model_call(request, _handler_factory(captured))
sent = captured[0].messages
# Non-AI/non-Tool messages keep their identity.
assert sent[0] is sys_msg
assert sent[1] is human
async def test_async_path_rewrites_ids() -> None:
"""`awrap_model_call` rewrites IDs the same way as the sync path."""
messages: list[BaseMessage] = [
AIMessage(
content="",
tool_calls=[
{
"name": "grep",
"args": {},
"id": "functions.grep:0",
"type": "tool_call",
}
],
),
ToolMessage(content="r", tool_call_id="functions.grep:0"),
]
middleware, request, captured = _capture(messages)
async def handler(req: ModelRequest) -> ModelResponse:
captured.append(req)
return ModelResponse(result=[AIMessage(content="ok")])
await middleware.awrap_model_call(request, handler)
sent = captured[0].messages
rewritten_id = _valid_tool_call_id(cast(AIMessage, sent[0]).tool_calls[0]["id"])
assert cast(ToolMessage, sent[1]).tool_call_id == rewritten_id
def test_unsupported_model_ignore_default_skips_silently() -> None:
"""Default `unsupported_model_behavior='ignore'` does not warn or rewrite."""
messages: list[BaseMessage] = [
AIMessage(
content="",
tool_calls=[
{
"name": "grep",
"args": {},
"id": "functions.grep:0",
"type": "tool_call",
}
],
),
ToolMessage(content="r", tool_call_id="functions.grep:0"),
]
middleware = AnthropicToolIdSanitizationMiddleware()
request = _make_request(messages, anthropic=False)
captured: list[ModelRequest] = []
with warnings.catch_warnings():
warnings.simplefilter("error")
middleware.wrap_model_call(request, _handler_factory(captured))
assert captured[0].messages is request.messages
def test_unsupported_model_warn_emits_warning() -> None:
"""`unsupported_model_behavior='warn'` warns and skips rewrite."""
middleware = AnthropicToolIdSanitizationMiddleware(
unsupported_model_behavior="warn"
)
request = _make_request([HumanMessage("hi")], anthropic=False)
captured: list[ModelRequest] = []
with pytest.warns(UserWarning, match="only supports Anthropic"):
middleware.wrap_model_call(request, _handler_factory(captured))
def test_unsupported_model_raise_errors() -> None:
"""`unsupported_model_behavior='raise'` raises `ValueError`."""
middleware = AnthropicToolIdSanitizationMiddleware(
unsupported_model_behavior="raise"
)
request = _make_request([HumanMessage("hi")], anthropic=False)
captured: list[ModelRequest] = []
with pytest.raises(ValueError, match="only supports Anthropic"):
middleware.wrap_model_call(request, _handler_factory(captured))
def test_invalid_unsupported_model_behavior_rejected() -> None:
"""A typo in `unsupported_model_behavior` raises `ValueError` at construction.
`Literal` is enforced statically only; without runtime validation a typo
silently falls through to ignore semantics — a footgun precisely for users
who opted into `'raise'` to surface bugs.
"""
with pytest.raises(ValueError, match="unsupported_model_behavior"):
AnthropicToolIdSanitizationMiddleware(
unsupported_model_behavior="raies", # type: ignore[arg-type]
)
def test_unsupported_model_ignore_logs_debug_breadcrumb(
caplog: pytest.LogCaptureFixture,
) -> None:
"""Default `'ignore'` mode emits a debug log so the bypass is observable."""
middleware = AnthropicToolIdSanitizationMiddleware()
request = _make_request([HumanMessage("hi")], anthropic=False)
captured: list[ModelRequest] = []
with caplog.at_level(
logging.DEBUG, logger="langchain_anthropic.middleware.tool_id_sanitization"
):
middleware.wrap_model_call(request, _handler_factory(captured))
assert any(
"skipped" in record.message and "_FakeModel" in record.message
for record in caplog.records
)

View File

@@ -510,7 +510,7 @@ wheels = [
[[package]]
name = "langchain"
version = "1.2.15"
version = "1.2.16"
source = { editable = "../../langchain_v1" }
dependencies = [
{ name = "langchain-core" },
@@ -734,7 +734,7 @@ wheels = [
[[package]]
name = "langchain-tests"
version = "1.1.6"
version = "1.1.7"
source = { editable = "../../standard-tests" }
dependencies = [
{ name = "httpx" },