diff --git a/libs/partners/anthropic/langchain_anthropic/middleware/__init__.py b/libs/partners/anthropic/langchain_anthropic/middleware/__init__.py index d1a34993c8d..bc90a6fe2cc 100644 --- a/libs/partners/anthropic/langchain_anthropic/middleware/__init__.py +++ b/libs/partners/anthropic/langchain_anthropic/middleware/__init__.py @@ -13,9 +13,13 @@ from langchain_anthropic.middleware.file_search import ( from langchain_anthropic.middleware.prompt_caching import ( AnthropicPromptCachingMiddleware, ) +from langchain_anthropic.middleware.tool_id_sanitization import ( + AnthropicToolIdSanitizationMiddleware, +) __all__ = [ "AnthropicPromptCachingMiddleware", + "AnthropicToolIdSanitizationMiddleware", "ClaudeBashToolMiddleware", "FilesystemClaudeMemoryMiddleware", "FilesystemClaudeTextEditorMiddleware", diff --git a/libs/partners/anthropic/langchain_anthropic/middleware/tool_id_sanitization.py b/libs/partners/anthropic/langchain_anthropic/middleware/tool_id_sanitization.py new file mode 100644 index 00000000000..864ff43c445 --- /dev/null +++ b/libs/partners/anthropic/langchain_anthropic/middleware/tool_id_sanitization.py @@ -0,0 +1,475 @@ +"""Anthropic tool-call ID sanitization middleware.""" + +from __future__ import annotations + +import hashlib +import logging +import re +from collections.abc import Awaitable, Callable +from typing import Any, Literal +from warnings import warn + +from langchain_core.messages import AIMessage, AnyMessage, ToolMessage + +from langchain_anthropic.chat_models import ChatAnthropic + +try: + from langchain.agents.middleware.types import ( + AgentMiddleware, + ModelCallResult, + ModelRequest, + ModelResponse, + ) +except ModuleNotFoundError as e: + msg = ( + "AnthropicToolIdSanitizationMiddleware requires 'langchain' to be " + "installed. This middleware is designed for use with LangChain agents. " + "Install it with: pip install langchain" + ) + raise ImportError(msg) from e + + +logger = logging.getLogger(__name__) + + +_ANTHROPIC_ID_PATTERN = re.compile(r"^[a-zA-Z0-9_-]+$") +"""Character set Anthropic enforces server-side on `tool_use.id`. + +Mirrors the regex echoed in the API errors returned for non-conforming IDs: +`messages.N.content.M.tool_use.id: String should match pattern +'^[a-zA-Z0-9_-]+$'`. +""" + +_INVALID_CHAR_PATTERN = re.compile(r"[^a-zA-Z0-9_-]") +"""Complement of the character class in `_ANTHROPIC_ID_PATTERN`. + +Matches any single illegal character. Used by `_make_safe_id` to substitute +offending characters with `_`. +""" + + +def _is_tool_use_type(block_type: Any) -> bool: + """Return `True` if `block_type` names an ID-bearing tool-use block. + + Covers `tool_use` (client tools), `server_tool_use` (Anthropic-emitted + server tools), and `mcp_tool_use` (MCP). Uses suffix matching so future + `*_tool_use` variants are picked up automatically. + """ + return isinstance(block_type, str) and block_type.endswith("tool_use") + + +def _is_client_tool_use_type(block_type: Any) -> bool: + """Return `True` if `block_type` is a client `tool_use` block.""" + return block_type == "tool_use" + + +def _is_tool_result_type(block_type: Any) -> bool: + """Return `True` if `block_type` names a `tool_use_id`-bearing result block. + + Covers `tool_result`, `mcp_tool_result`, and Anthropic server-tool result + variants like `web_search_tool_result`, `code_execution_tool_result`, and + `bash_tool_result`. Uses suffix matching for forward compatibility. + """ + return isinstance(block_type, str) and block_type.endswith("tool_result") + + +def _is_valid_id(tool_id: str | None) -> bool: + """Return `True` if `tool_id` matches Anthropic's required pattern.""" + if not tool_id: + return False + return _ANTHROPIC_ID_PATTERN.match(tool_id) is not None + + +def _make_safe_id(original: str, used: set[str]) -> str: + """Derive an Anthropic-safe id, avoiding collisions within a single call. + + Replaces every illegal character with `_`. If the resulting `base` already + appears in `used`, appends a sha256-derived suffix (and a counter on + further collisions) so each invalid input gets a distinct safe output + within one sanitization pass. + + Args: + original: The invalid tool-call id to sanitize. + used: The set of ids already taken in this pass — both pre-existing + valid ids and previously-allocated safe ids. + + Returns: + A regex-conformant id distinct from every entry in `used`. + """ + base = _INVALID_CHAR_PATTERN.sub("_", original) or "tool" + if base not in used: + return base + digest = hashlib.sha256(original.encode("utf-8")).hexdigest()[:8] + candidate = f"{base}_{digest}" + counter = 0 + while candidate in used: + counter += 1 + candidate = f"{base}_{digest}_{counter}" + return candidate + + +class AnthropicToolIdSanitizationMiddleware(AgentMiddleware): + """Rewrite illegal tool-call IDs before sending to Anthropic. + + Anthropic enforces `tool_use.id` matching `^[a-zA-Z0-9_-]+$`. Conversation + histories produced by other providers can violate this — e.g. Kimi-K2 emits + IDs of the form `functions.:`, where `.` and `:` are illegal. + Replaying such a thread against Claude raises a 400 error. + + This middleware scans `request.messages` for offending IDs, builds a + deterministic `bad_id -> safe_id` map, and rewrites every occurrence: + + - `AIMessage.tool_calls[*]["id"]` + - `AIMessage.content[*]["id"]` for `tool_use` / `server_tool_use` / + `mcp_tool_use` blocks + - `ToolMessage.tool_call_id` + - `ToolMessage.content[*]["tool_use_id"]` for `tool_result` and the + `*_tool_result` variants (`web_search_tool_result`, + `code_execution_tool_result`, `bash_tool_result`, `mcp_tool_result`) + + Within an `AIMessage`, position-paired `tool_calls[i]` and `tool_use` + content blocks are forced to share the same final id even when their + inputs disagree (drift), so Anthropic never receives a mismatched pair. + + Only the outgoing `ModelRequest` is modified — graph state and persisted + checkpoints are left untouched, so the sanitization is idempotent across + turns and safe to combine with HITL resume. + """ + + def __init__( + self, + unsupported_model_behavior: Literal["ignore", "warn", "raise"] = "ignore", + ) -> None: + """Initialize the middleware. + + Args: + unsupported_model_behavior: Behavior when the bound model is not + `ChatAnthropic`. + + `'ignore'` skips sanitization silently (default — other + providers may accept the original IDs). + + `'warn'` emits a warning and skips sanitization. + + `'raise'` raises `ValueError`. + + Raises: + ValueError: If `unsupported_model_behavior` is not one of the + three allowed strings. `Literal` is enforced statically only; + this guards against typos at runtime. + """ + allowed = ("ignore", "warn", "raise") + if unsupported_model_behavior not in allowed: + msg = ( + f"unsupported_model_behavior must be one of {allowed}; " + f"got {unsupported_model_behavior!r}" + ) + raise ValueError(msg) + self.unsupported_model_behavior = unsupported_model_behavior + + def _should_run(self, request: ModelRequest) -> bool: + """Return `True` if the bound model is `ChatAnthropic`. + + Args: + request: The model request to inspect. + + Returns: + `True` when sanitization should run; `False` to bypass. + + Raises: + ValueError: When the bound model is not `ChatAnthropic` and + `unsupported_model_behavior='raise'`. + """ + if isinstance(request.model, ChatAnthropic): + return True + msg = ( + "AnthropicToolIdSanitizationMiddleware only supports Anthropic " + f"models, not instances of {type(request.model)}" + ) + if self.unsupported_model_behavior == "raise": + raise ValueError(msg) + if self.unsupported_model_behavior == "warn": + warn(msg, stacklevel=3) + else: + # `ignore` mode — leave a breadcrumb so users debugging missing + # sanitization can confirm the middleware was bypassed and why. + logger.debug( + "AnthropicToolIdSanitizationMiddleware skipped: bound model " + "is %s, not ChatAnthropic.", + type(request.model).__name__, + ) + return False + + def _build_id_map(self, messages: list[AnyMessage]) -> dict[str, str]: + """Build a deterministic mapping from invalid IDs to safe ones. + + Args: + messages: The outgoing message list to scan. + + Returns: + A dict mapping each invalid id to its safe replacement, or an + empty dict when every id already conforms. + """ + invalid: list[str] = [] + seen: set[str] = set() + valid_ids: set[str] = set() + + for tool_id in _iter_all_ids(messages): + if _is_valid_id(tool_id): + valid_ids.add(tool_id) + elif tool_id and tool_id not in seen: + invalid.append(tool_id) + seen.add(tool_id) + + if not invalid: + return {} + + used = set(valid_ids) + mapping: dict[str, str] = {} + for original in invalid: + new = _make_safe_id(original, used) + mapping[original] = new + used.add(new) + return mapping + + def _sanitize(self, request: ModelRequest) -> ModelRequest: + """Return a copy of `request` with sanitized tool-call IDs. + + Args: + request: The model request to sanitize. + + Returns: + A new request when any id was rewritten, otherwise the original. + """ + mapping = self._build_id_map(request.messages) + rewritten, changed = _rewrite_messages(request.messages, mapping) + if not changed: + return request + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + "Sanitized %d tool-call id(s) for Anthropic compatibility: %s", + len(mapping), + {k: mapping[k] for k in sorted(mapping)}, + ) + return request.override(messages=rewritten) + + def wrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], ModelResponse], + ) -> ModelCallResult: + """Sanitize tool-call IDs before delegating to the handler. + + Args: + request: The model request to potentially modify. + handler: The handler to execute the (possibly rewritten) request. + + Returns: + The model response produced by `handler`. + """ + if not self._should_run(request): + return handler(request) + return handler(self._sanitize(request)) + + async def awrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], Awaitable[ModelResponse]], + ) -> ModelCallResult: + """Sanitize tool-call IDs before delegating to the async handler. + + Args: + request: The model request to potentially modify. + handler: The async handler to execute the (possibly rewritten) + request. + + Returns: + The model response produced by `handler`. + """ + if not self._should_run(request): + return await handler(request) + return await handler(self._sanitize(request)) + + +def _iter_all_ids(messages: list[AnyMessage]) -> list[str]: + """Collect every tool-call ID present in the message list.""" + ids: list[str] = [] + for msg in messages: + if isinstance(msg, AIMessage): + for tc in msg.tool_calls or []: + tid = tc.get("id") + if tid: + ids.append(tid) + if isinstance(msg.content, list): + for block in msg.content: + if isinstance(block, dict) and _is_tool_use_type(block.get("type")): + tid = block.get("id") + if isinstance(tid, str) and tid: + ids.append(tid) + elif isinstance(msg, ToolMessage): + if msg.tool_call_id: + ids.append(msg.tool_call_id) + if isinstance(msg.content, list): + for block in msg.content: + if isinstance(block, dict) and _is_tool_result_type( + block.get("type") + ): + tid = block.get("tool_use_id") + if isinstance(tid, str) and tid: + ids.append(tid) + return ids + + +def _rewrite_messages( + messages: list[AnyMessage], mapping: dict[str, str] +) -> tuple[list[AnyMessage], bool]: + """Return messages rewritten with sanitized IDs and local drift aliases.""" + rewritten: list[AnyMessage] = [] + changed = False + active_result_aliases: dict[str, str] = {} + + for msg in messages: + new_msg: AnyMessage + if isinstance(msg, AIMessage): + new_msg, active_result_aliases = _rewrite_ai_message(msg, mapping) + elif isinstance(msg, ToolMessage): + result_mapping = {**mapping, **active_result_aliases} + new_msg = _rewrite_tool_message(msg, result_mapping) + else: + active_result_aliases = {} + new_msg = msg + + if new_msg is not msg: + changed = True + rewritten.append(new_msg) + + return rewritten, changed + + +def _rewrite_ai_message( + msg: AIMessage, mapping: dict[str, str] +) -> tuple[AIMessage, dict[str, str]]: + """Return a copy of `msg` with `tool_calls` and tool-use blocks aligned. + + Applies `mapping` to each id, then enforces position-paired alignment so + `tool_calls[i].id` and the i-th client `tool_use` content block share the + same final id even if their inputs drifted. + """ + result_aliases: dict[str, str] = {} + new_tool_calls: list[Any] | None = None + if msg.tool_calls and any(tc.get("id") in mapping for tc in msg.tool_calls): + new_tool_calls = [] + for tc in msg.tool_calls: + tid = tc.get("id") + if tid is not None and tid in mapping: + new_tool_calls.append({**tc, "id": mapping[tid]}) + else: + new_tool_calls.append(tc) + + new_content: list[Any] | None = None + if isinstance(msg.content, list) and any( + isinstance(b, dict) + and _is_tool_use_type(b.get("type")) + and b.get("id") in mapping + for b in msg.content + ): + new_content = [ + {**b, "id": mapping[b["id"]]} + if ( + isinstance(b, dict) + and _is_tool_use_type(b.get("type")) + and b.get("id") in mapping + ) + else b + for b in msg.content + ] + + # Drift correction: even if `tool_calls[i].id` and the i-th client + # `tool_use` block disagreed pre-mapping (or post-mapping due to distinct + # invalid inputs collapsing through different sanitization branches), force + # the content block to adopt the canonical `tool_calls[i].id`. `tool_calls` + # is the langchain-canonical view. + effective_tool_calls = ( + new_tool_calls if new_tool_calls is not None else msg.tool_calls + ) + effective_content = new_content if new_content is not None else msg.content + if effective_tool_calls and isinstance(effective_content, list): + aligned, result_aliases, drift_seen = _align_client_tool_use_blocks( + effective_tool_calls, + effective_content, + msg.content if isinstance(msg.content, list) else effective_content, + ) + if drift_seen: + new_content = aligned + + updates: dict[str, Any] = {} + if new_tool_calls is not None: + updates["tool_calls"] = new_tool_calls + if new_content is not None: + updates["content"] = new_content + return (msg.model_copy(update=updates) if updates else msg), result_aliases + + +def _align_client_tool_use_blocks( + tool_calls: list[Any], content: list[Any], original_content: list[Any] +) -> tuple[list[Any], dict[str, str], bool]: + """Force tool-use content blocks to share IDs with corresponding tool_calls. + + Returns the (possibly rewritten) content list plus aliases from drifted + content-block IDs to the canonical tool-call ID. These aliases are applied + only to the following `ToolMessage` results for this assistant turn. + """ + result_aliases: dict[str, str] = {} + drift_seen = False + aligned: list[Any] = list(content) + tu_indices = [ + idx + for idx, b in enumerate(content) + if isinstance(b, dict) and _is_client_tool_use_type(b.get("type")) + ] + for pair_idx, content_idx in enumerate(tu_indices): + if pair_idx >= len(tool_calls): + break + canonical_id = tool_calls[pair_idx].get("id") + block = aligned[content_idx] + if not isinstance(block, dict): + continue + if canonical_id and block.get("id") != canonical_id: + aligned[content_idx] = {**block, "id": canonical_id} + drift_seen = True + original_block = original_content[content_idx] + if isinstance(original_block, dict): + original_id = original_block.get("id") + if isinstance(original_id, str) and original_id != canonical_id: + result_aliases[original_id] = canonical_id + current_id = block.get("id") + if isinstance(current_id, str) and current_id != canonical_id: + result_aliases[current_id] = canonical_id + return aligned, result_aliases, drift_seen + + +def _rewrite_tool_message(msg: ToolMessage, mapping: dict[str, str]) -> ToolMessage: + """Return a copy of `msg` with `tool_call_id` and tool-result blocks rewritten.""" + updates: dict[str, Any] = {} + + if msg.tool_call_id in mapping: + updates["tool_call_id"] = mapping[msg.tool_call_id] + + if isinstance(msg.content, list) and any( + isinstance(b, dict) + and _is_tool_result_type(b.get("type")) + and b.get("tool_use_id") in mapping + for b in msg.content + ): + updates["content"] = [ + {**b, "tool_use_id": mapping[b["tool_use_id"]]} + if ( + isinstance(b, dict) + and _is_tool_result_type(b.get("type")) + and b.get("tool_use_id") in mapping + ) + else b + for b in msg.content + ] + + return msg.model_copy(update=updates) if updates else msg diff --git a/libs/partners/anthropic/tests/unit_tests/middleware/test_tool_id_sanitization.py b/libs/partners/anthropic/tests/unit_tests/middleware/test_tool_id_sanitization.py new file mode 100644 index 00000000000..606dae2b82e --- /dev/null +++ b/libs/partners/anthropic/tests/unit_tests/middleware/test_tool_id_sanitization.py @@ -0,0 +1,888 @@ +"""Tests for Anthropic tool-call ID sanitization middleware.""" + +import logging +import re +import warnings +from typing import Any, cast +from unittest.mock import MagicMock + +import pytest +from langchain.agents.middleware.types import AgentState, ModelRequest, ModelResponse +from langchain_core.callbacks import ( + AsyncCallbackManagerForLLMRun, + CallbackManagerForLLMRun, +) +from langchain_core.language_models import BaseChatModel +from langchain_core.messages import ( + AIMessage, + AnyMessage, + BaseMessage, + HumanMessage, + SystemMessage, + ToolMessage, +) +from langchain_core.outputs import ChatGeneration, ChatResult +from langgraph.runtime import Runtime + +from langchain_anthropic.chat_models import ChatAnthropic +from langchain_anthropic.middleware import AnthropicToolIdSanitizationMiddleware + +_ANTHROPIC_ID_RE = re.compile(r"^[a-zA-Z0-9_-]+$") + + +class _FakeModel(BaseChatModel): + """Stand-in non-Anthropic model for unsupported-model tests.""" + + def _generate( + self, + messages: list[BaseMessage], + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, + **kwargs: Any, + ) -> ChatResult: + return ChatResult( + generations=[ChatGeneration(message=AIMessage(content="ok", id="0"))] + ) + + async def _agenerate( + self, + messages: list[BaseMessage], + stop: list[str] | None = None, + run_manager: AsyncCallbackManagerForLLMRun | None = None, + **kwargs: Any, + ) -> ChatResult: + return ChatResult( + generations=[ChatGeneration(message=AIMessage(content="ok", id="0"))] + ) + + @property + def _llm_type(self) -> str: + return "fake" + + +def _make_request( + messages: list[BaseMessage], *, anthropic: bool = True +) -> ModelRequest: + """Build a `ModelRequest` for testing.""" + if anthropic: + mock_model = MagicMock(spec=ChatAnthropic) + mock_model._llm_type = "anthropic-chat" + model: BaseChatModel = cast(BaseChatModel, mock_model) + else: + model = _FakeModel() + return ModelRequest( + model=model, + messages=cast(list[AnyMessage], messages), + system_prompt=None, + tool_choice=None, + tools=[], + response_format=None, + state=cast(AgentState[Any], {"messages": messages}), + runtime=cast(Runtime, object()), + model_settings={}, + ) + + +def _capture( + messages: list[BaseMessage], +) -> tuple[ + AnthropicToolIdSanitizationMiddleware, + ModelRequest, + list[ModelRequest], +]: + """Wire a middleware + request + capturing handler list for a test.""" + middleware = AnthropicToolIdSanitizationMiddleware() + request = _make_request(messages) + captured: list[ModelRequest] = [] + + return middleware, request, captured + + +def _handler_factory(captured: list[ModelRequest]) -> Any: + def _handler(req: ModelRequest) -> ModelResponse: + captured.append(req) + return ModelResponse(result=[AIMessage(content="ok")]) + + return _handler + + +def _valid_tool_call_id(value: str | None) -> str: + """Assert that an optional tool-call ID is present and Anthropic-safe.""" + assert value is not None + assert _ANTHROPIC_ID_RE.match(value) + return value + + +def test_passthrough_when_all_ids_valid() -> None: + """Valid IDs short-circuit: every message reaches the handler unchanged.""" + messages: list[BaseMessage] = [ + HumanMessage("hi"), + AIMessage( + content="", + tool_calls=[ + {"name": "grep", "args": {}, "id": "toolu_01abc", "type": "tool_call"} + ], + ), + ToolMessage(content="done", tool_call_id="toolu_01abc"), + ] + middleware, request, captured = _capture(messages) + + middleware.wrap_model_call(request, _handler_factory(captured)) + + assert len(captured) == 1 + # Per-message identity: each individual message is the original instance. + for original, sent in zip(request.messages, captured[0].messages, strict=True): + assert sent is original + + +def test_rewrites_kimi_k2_style_ids_in_pairs() -> None: + """Bad IDs on AIMessage tool_calls and matching ToolMessage are rewritten.""" + messages: list[BaseMessage] = [ + HumanMessage("hi"), + AIMessage( + content="", + tool_calls=[ + { + "name": "grep", + "args": {"q": "x"}, + "id": "functions.grep:0", + "type": "tool_call", + }, + { + "name": "grep", + "args": {"q": "y"}, + "id": "functions.grep:1", + "type": "tool_call", + }, + ], + ), + ToolMessage(content="r0", tool_call_id="functions.grep:0"), + ToolMessage(content="r1", tool_call_id="functions.grep:1"), + ] + middleware, request, captured = _capture(messages) + + middleware.wrap_model_call(request, _handler_factory(captured)) + + sent = captured[0].messages + ai = sent[1] + assert isinstance(ai, AIMessage) + new_ids = [_valid_tool_call_id(tc["id"]) for tc in ai.tool_calls] + tool_msgs = [m for m in sent if isinstance(m, ToolMessage)] + assert [tm.tool_call_id for tm in tool_msgs] == new_ids + assert new_ids[0] != new_ids[1] + + +def test_rewrite_is_deterministic_across_invocations() -> None: + """The same input produces the same sanitized IDs across separate calls.""" + payload: list[BaseMessage] = [ + AIMessage( + content="", + tool_calls=[ + { + "name": "grep", + "args": {}, + "id": "functions.grep:0", + "type": "tool_call", + }, + { + "name": "grep", + "args": {}, + "id": "functions.grep:1", + "type": "tool_call", + }, + ], + ), + ToolMessage(content="r0", tool_call_id="functions.grep:0"), + ToolMessage(content="r1", tool_call_id="functions.grep:1"), + ] + + def _run() -> list[str]: + # Build a fresh copy each invocation so the two runs are independent. + msgs = [m.model_copy(deep=True) for m in payload] + middleware = AnthropicToolIdSanitizationMiddleware() + request = _make_request(msgs) + captured: list[ModelRequest] = [] + middleware.wrap_model_call(request, _handler_factory(captured)) + ai = cast(AIMessage, captured[0].messages[0]) + return [_valid_tool_call_id(tc["id"]) for tc in ai.tool_calls] + + assert _run() == _run() + + +def test_rewrites_anthropic_content_blocks() -> None: + """`tool_use` and `tool_result` content blocks have their IDs rewritten too.""" + messages: list[BaseMessage] = [ + AIMessage( + content=[ + {"type": "text", "text": "calling"}, + { + "type": "tool_use", + "id": "functions.read_file:0", + "name": "read_file", + "input": {}, + }, + ], + tool_calls=[ + { + "name": "read_file", + "args": {}, + "id": "functions.read_file:0", + "type": "tool_call", + }, + ], + ), + ToolMessage( + content=[ + { + "type": "tool_result", + "tool_use_id": "functions.read_file:0", + "content": "hi", + }, + ], + tool_call_id="functions.read_file:0", + ), + ] + middleware, request, captured = _capture(messages) + + middleware.wrap_model_call(request, _handler_factory(captured)) + + sent = captured[0].messages + ai = sent[0] + tm = sent[1] + assert isinstance(ai, AIMessage) + assert isinstance(tm, ToolMessage) + + block_id = next( + b["id"] + for b in ai.content + if isinstance(b, dict) and b.get("type") == "tool_use" + ) + tool_call_id = _valid_tool_call_id(ai.tool_calls[0]["id"]) + result_id = next( + b["tool_use_id"] + for b in tm.content + if isinstance(b, dict) and b.get("type") == "tool_result" + ) + + assert block_id == tool_call_id == tm.tool_call_id == result_id + for value in (block_id, tool_call_id, tm.tool_call_id, result_id): + assert _ANTHROPIC_ID_RE.match(value) + + +def test_drift_between_tool_calls_and_tool_use_blocks_is_corrected() -> None: + """When `tool_calls[i].id` and the i-th tool_use block disagree, alignment wins. + + Without correction, two distinct invalid IDs would map to two different + safe IDs (`a_b` vs `a_b_`), and Anthropic would 400 on the + mismatched pair. The middleware forces both to share `tool_calls[i].id`. + """ + messages: list[BaseMessage] = [ + AIMessage( + content=[ + { + "type": "tool_use", + "id": "a:b", # drift — different illegal id + "name": "tool", + "input": {}, + }, + ], + tool_calls=[ + { + "name": "tool", + "args": {}, + "id": "a.b", # canonical + "type": "tool_call", + } + ], + ), + ToolMessage( + content=[ + { + "type": "tool_result", + "tool_use_id": "a:b", + "content": "r", + }, + ], + tool_call_id="a:b", + ), + ] + middleware, request, captured = _capture(messages) + + middleware.wrap_model_call(request, _handler_factory(captured)) + + sent = captured[0].messages + ai = cast(AIMessage, sent[0]) + tm = cast(ToolMessage, sent[1]) + + final_tc_id = _valid_tool_call_id(ai.tool_calls[0]["id"]) + final_block_id = next( + b["id"] + for b in ai.content + if isinstance(b, dict) and b.get("type") == "tool_use" + ) + final_result_id = next( + b["tool_use_id"] + for b in tm.content + if isinstance(b, dict) and b.get("type") == "tool_result" + ) + + assert final_tc_id == final_block_id == tm.tool_call_id == final_result_id + + +def test_valid_drift_between_tool_calls_and_tool_use_blocks_is_corrected() -> None: + """Valid but mismatched tool-call and `tool_use` IDs are aligned.""" + messages: list[BaseMessage] = [ + AIMessage( + content=[ + { + "type": "tool_use", + "id": "toolu_content", + "name": "tool", + "input": {}, + }, + ], + tool_calls=[ + { + "name": "tool", + "args": {}, + "id": "toolu_call", + "type": "tool_call", + } + ], + ), + ToolMessage( + content=[ + { + "type": "tool_result", + "tool_use_id": "toolu_content", + "content": "r", + }, + ], + tool_call_id="toolu_content", + ), + ] + middleware, request, captured = _capture(messages) + + middleware.wrap_model_call(request, _handler_factory(captured)) + + sent = captured[0].messages + ai = cast(AIMessage, sent[0]) + tm = cast(ToolMessage, sent[1]) + block_id = next( + b["id"] + for b in ai.content + if isinstance(b, dict) and b.get("type") == "tool_use" + ) + result_id = next( + b["tool_use_id"] + for b in tm.content + if isinstance(b, dict) and b.get("type") == "tool_result" + ) + + assert ai.tool_calls[0]["id"] == "toolu_call" + assert block_id == "toolu_call" + assert tm.tool_call_id == "toolu_call" + assert result_id == "toolu_call" + + +def test_two_invalid_ids_with_same_base_get_distinct_safe_ids() -> None: + """`a.b` and `a:b` both sanitize to base `a_b` — second falls back to suffix. + + Exercises the sha256-suffix branch of `_make_safe_id`. Both AIMessages + use the same callable name so name-based heuristics can't disambiguate; + only the `_make_safe_id` collision logic produces distinct safe IDs. + """ + messages: list[BaseMessage] = [ + AIMessage( + content="", + tool_calls=[{"name": "tool", "args": {}, "id": "a.b", "type": "tool_call"}], + ), + ToolMessage(content="r1", tool_call_id="a.b"), + AIMessage( + content="", + tool_calls=[{"name": "tool", "args": {}, "id": "a:b", "type": "tool_call"}], + ), + ToolMessage(content="r2", tool_call_id="a:b"), + ] + middleware, request, captured = _capture(messages) + + middleware.wrap_model_call(request, _handler_factory(captured)) + + sent = captured[0].messages + first_id = _valid_tool_call_id(cast(AIMessage, sent[0]).tool_calls[0]["id"]) + second_id = _valid_tool_call_id(cast(AIMessage, sent[2]).tool_calls[0]["id"]) + + # Both sanitize cleanly, both are valid, both are distinct. + assert first_id != second_id + # The first invalid wins the base; the second gets the suffix. + assert first_id == "a_b" + assert second_id.startswith("a_b_") + # Pairs survive: ToolMessage ids match their AIMessage counterparts. + assert cast(ToolMessage, sent[1]).tool_call_id == first_id + assert cast(ToolMessage, sent[3]).tool_call_id == second_id + + +def test_state_is_not_mutated() -> None: + """Original messages on the request and in graph state are not mutated.""" + original_id = "functions.grep:0" + messages: list[BaseMessage] = [ + AIMessage( + content="", + tool_calls=[ + {"name": "grep", "args": {}, "id": original_id, "type": "tool_call"} + ], + ), + ToolMessage(content="r", tool_call_id=original_id), + ] + middleware, request, captured = _capture(messages) + + middleware.wrap_model_call(request, _handler_factory(captured)) + + assert request.messages[0].tool_calls[0]["id"] == original_id # type: ignore[union-attr] + assert cast(ToolMessage, request.messages[1]).tool_call_id == original_id + assert captured[0].messages is not request.messages + + +def test_collision_avoidance_with_existing_valid_id() -> None: + """A bad ID that sanitizes to an already-used valid ID gets a hash suffix.""" + messages: list[BaseMessage] = [ + AIMessage( + content="", + tool_calls=[ + { + "name": "grep", + "args": {}, + "id": "functions_grep_0", + "type": "tool_call", + }, + ], + ), + ToolMessage(content="a", tool_call_id="functions_grep_0"), + AIMessage( + content="", + tool_calls=[ + { + "name": "grep", + "args": {}, + "id": "functions.grep:0", + "type": "tool_call", + }, + ], + ), + ToolMessage(content="b", tool_call_id="functions.grep:0"), + ] + middleware, request, captured = _capture(messages) + + middleware.wrap_model_call(request, _handler_factory(captured)) + + sent = captured[0].messages + first_id = _valid_tool_call_id(cast(AIMessage, sent[0]).tool_calls[0]["id"]) + second_id = _valid_tool_call_id(cast(AIMessage, sent[2]).tool_calls[0]["id"]) + + assert first_id == "functions_grep_0" + assert second_id != first_id + + +def test_mixed_valid_and_invalid_ids_in_same_message() -> None: + """Valid IDs are preserved while siblings with invalid IDs are rewritten.""" + messages: list[BaseMessage] = [ + AIMessage( + content="", + tool_calls=[ + { + "name": "good", + "args": {}, + "id": "toolu_clean", + "type": "tool_call", + }, + { + "name": "bad", + "args": {}, + "id": "functions.bad:0", + "type": "tool_call", + }, + ], + ), + ToolMessage(content="g", tool_call_id="toolu_clean"), + ToolMessage(content="b", tool_call_id="functions.bad:0"), + ] + middleware, request, captured = _capture(messages) + + middleware.wrap_model_call(request, _handler_factory(captured)) + + sent = captured[0].messages + ai = cast(AIMessage, sent[0]) + valid_id = _valid_tool_call_id(ai.tool_calls[0]["id"]) + rewritten_id = _valid_tool_call_id(ai.tool_calls[1]["id"]) + + assert valid_id == "toolu_clean" # untouched + assert rewritten_id != "functions.bad:0" + assert cast(ToolMessage, sent[1]).tool_call_id == valid_id + assert cast(ToolMessage, sent[2]).tool_call_id == rewritten_id + + +def test_rewrites_server_and_mcp_block_types() -> None: + """`server_tool_use` / `mcp_tool_use` and their result variants are also handled.""" + messages: list[BaseMessage] = [ + AIMessage( + content=[ + { + "type": "server_tool_use", + "id": "srv.1", + "name": "web_search", + "input": {}, + }, + { + "type": "mcp_tool_use", + "id": "mcp.2", + "name": "fetch", + "input": {}, + }, + ], + tool_calls=[], + ), + ToolMessage( + content=[ + { + "type": "web_search_tool_result", + "tool_use_id": "srv.1", + "content": "hits", + }, + ], + tool_call_id="srv.1", + ), + ToolMessage( + content=[ + { + "type": "mcp_tool_result", + "tool_use_id": "mcp.2", + "content": "data", + }, + ], + tool_call_id="mcp.2", + ), + ] + middleware, request, captured = _capture(messages) + + middleware.wrap_model_call(request, _handler_factory(captured)) + + sent = captured[0].messages + ai = cast(AIMessage, sent[0]) + server_block_id = next( + b["id"] + for b in ai.content + if isinstance(b, dict) and b.get("type") == "server_tool_use" + ) + mcp_block_id = next( + b["id"] + for b in ai.content + if isinstance(b, dict) and b.get("type") == "mcp_tool_use" + ) + + for value in ( + server_block_id, + mcp_block_id, + cast(ToolMessage, sent[1]).tool_call_id, + cast(ToolMessage, sent[2]).tool_call_id, + ): + assert _ANTHROPIC_ID_RE.match(value) + # Pair survival: + web_result_id = next( + b["tool_use_id"] + for b in cast(ToolMessage, sent[1]).content + if isinstance(b, dict) and b.get("type") == "web_search_tool_result" + ) + mcp_result_id = next( + b["tool_use_id"] + for b in cast(ToolMessage, sent[2]).content + if isinstance(b, dict) and b.get("type") == "mcp_tool_result" + ) + assert web_result_id == server_block_id + assert mcp_result_id == mcp_block_id + + +def test_client_alignment_skips_server_and_mcp_tool_use_blocks() -> None: + """Server and MCP tool-use blocks do not position-pair with `tool_calls`.""" + messages: list[BaseMessage] = [ + AIMessage( + content=[ + { + "type": "server_tool_use", + "id": "srv.1", + "name": "web_search", + "input": {}, + }, + { + "type": "mcp_tool_use", + "id": "mcp.2", + "name": "fetch", + "input": {}, + }, + { + "type": "tool_use", + "id": "client.3", + "name": "client_tool", + "input": {}, + }, + ], + tool_calls=[ + { + "name": "client_tool", + "args": {}, + "id": "client.3", + "type": "tool_call", + }, + ], + ), + ToolMessage( + content=[ + { + "type": "web_search_tool_result", + "tool_use_id": "srv.1", + "content": "hits", + }, + ], + tool_call_id="srv.1", + ), + ToolMessage( + content=[ + { + "type": "mcp_tool_result", + "tool_use_id": "mcp.2", + "content": "data", + }, + ], + tool_call_id="mcp.2", + ), + ToolMessage( + content=[ + { + "type": "tool_result", + "tool_use_id": "client.3", + "content": "done", + }, + ], + tool_call_id="client.3", + ), + ] + middleware, request, captured = _capture(messages) + + middleware.wrap_model_call(request, _handler_factory(captured)) + + sent = captured[0].messages + ai = cast(AIMessage, sent[0]) + server_block_id = next( + b["id"] + for b in ai.content + if isinstance(b, dict) and b.get("type") == "server_tool_use" + ) + mcp_block_id = next( + b["id"] + for b in ai.content + if isinstance(b, dict) and b.get("type") == "mcp_tool_use" + ) + client_block_id = next( + b["id"] + for b in ai.content + if isinstance(b, dict) and b.get("type") == "tool_use" + ) + client_call_id = _valid_tool_call_id(ai.tool_calls[0]["id"]) + web_result_id = next( + b["tool_use_id"] + for b in cast(ToolMessage, sent[1]).content + if isinstance(b, dict) and b.get("type") == "web_search_tool_result" + ) + mcp_result_id = next( + b["tool_use_id"] + for b in cast(ToolMessage, sent[2]).content + if isinstance(b, dict) and b.get("type") == "mcp_tool_result" + ) + client_result_id = next( + b["tool_use_id"] + for b in cast(ToolMessage, sent[3]).content + if isinstance(b, dict) and b.get("type") == "tool_result" + ) + + assert server_block_id == web_result_id == cast(ToolMessage, sent[1]).tool_call_id + assert mcp_block_id == mcp_result_id == cast(ToolMessage, sent[2]).tool_call_id + assert client_block_id == client_call_id == client_result_id + assert cast(ToolMessage, sent[3]).tool_call_id == client_call_id + assert server_block_id != client_call_id + assert mcp_block_id != client_call_id + + +def test_partial_update_only_rewrites_what_changed() -> None: + """An AIMessage with a clean tool_calls but dirty content block triggers + rewrite of the content only, not the tool_calls list (partial-update path). + """ + messages: list[BaseMessage] = [ + AIMessage( + content=[ + {"type": "text", "text": "thinking"}, + { + "type": "tool_use", + "id": "functions.x:0", # invalid + "name": "x", + "input": {}, + }, + ], + tool_calls=[ + # Valid id; should NOT be rewritten by mapping. Drift correction + # will re-align the content block to use this id. + {"name": "x", "args": {}, "id": "toolu_canon", "type": "tool_call"}, + ], + ), + ToolMessage(content="r", tool_call_id="toolu_canon"), + ] + middleware, request, captured = _capture(messages) + + middleware.wrap_model_call(request, _handler_factory(captured)) + + sent = captured[0].messages + ai = cast(AIMessage, sent[0]) + block_id = next( + b["id"] + for b in ai.content + if isinstance(b, dict) and b.get("type") == "tool_use" + ) + # Drift correction collapses the divergent block id onto the canonical one. + assert ai.tool_calls[0]["id"] == "toolu_canon" + assert block_id == "toolu_canon" + assert cast(ToolMessage, sent[1]).tool_call_id == "toolu_canon" + + +def test_unrelated_message_subclasses_pass_through() -> None: + """SystemMessage and HumanMessage are returned untouched even mid-rewrite.""" + sys_msg = SystemMessage(content="you are an agent") + human = HumanMessage(content="hi") + messages: list[BaseMessage] = [ + sys_msg, + human, + AIMessage( + content="", + tool_calls=[ + {"name": "x", "args": {}, "id": "functions.x:0", "type": "tool_call"} + ], + ), + ToolMessage(content="r", tool_call_id="functions.x:0"), + ] + middleware, request, captured = _capture(messages) + + middleware.wrap_model_call(request, _handler_factory(captured)) + + sent = captured[0].messages + # Non-AI/non-Tool messages keep their identity. + assert sent[0] is sys_msg + assert sent[1] is human + + +async def test_async_path_rewrites_ids() -> None: + """`awrap_model_call` rewrites IDs the same way as the sync path.""" + messages: list[BaseMessage] = [ + AIMessage( + content="", + tool_calls=[ + { + "name": "grep", + "args": {}, + "id": "functions.grep:0", + "type": "tool_call", + } + ], + ), + ToolMessage(content="r", tool_call_id="functions.grep:0"), + ] + middleware, request, captured = _capture(messages) + + async def handler(req: ModelRequest) -> ModelResponse: + captured.append(req) + return ModelResponse(result=[AIMessage(content="ok")]) + + await middleware.awrap_model_call(request, handler) + + sent = captured[0].messages + rewritten_id = _valid_tool_call_id(cast(AIMessage, sent[0]).tool_calls[0]["id"]) + assert cast(ToolMessage, sent[1]).tool_call_id == rewritten_id + + +def test_unsupported_model_ignore_default_skips_silently() -> None: + """Default `unsupported_model_behavior='ignore'` does not warn or rewrite.""" + messages: list[BaseMessage] = [ + AIMessage( + content="", + tool_calls=[ + { + "name": "grep", + "args": {}, + "id": "functions.grep:0", + "type": "tool_call", + } + ], + ), + ToolMessage(content="r", tool_call_id="functions.grep:0"), + ] + middleware = AnthropicToolIdSanitizationMiddleware() + request = _make_request(messages, anthropic=False) + captured: list[ModelRequest] = [] + + with warnings.catch_warnings(): + warnings.simplefilter("error") + middleware.wrap_model_call(request, _handler_factory(captured)) + + assert captured[0].messages is request.messages + + +def test_unsupported_model_warn_emits_warning() -> None: + """`unsupported_model_behavior='warn'` warns and skips rewrite.""" + middleware = AnthropicToolIdSanitizationMiddleware( + unsupported_model_behavior="warn" + ) + request = _make_request([HumanMessage("hi")], anthropic=False) + captured: list[ModelRequest] = [] + + with pytest.warns(UserWarning, match="only supports Anthropic"): + middleware.wrap_model_call(request, _handler_factory(captured)) + + +def test_unsupported_model_raise_errors() -> None: + """`unsupported_model_behavior='raise'` raises `ValueError`.""" + middleware = AnthropicToolIdSanitizationMiddleware( + unsupported_model_behavior="raise" + ) + request = _make_request([HumanMessage("hi")], anthropic=False) + captured: list[ModelRequest] = [] + + with pytest.raises(ValueError, match="only supports Anthropic"): + middleware.wrap_model_call(request, _handler_factory(captured)) + + +def test_invalid_unsupported_model_behavior_rejected() -> None: + """A typo in `unsupported_model_behavior` raises `ValueError` at construction. + + `Literal` is enforced statically only; without runtime validation a typo + silently falls through to ignore semantics — a footgun precisely for users + who opted into `'raise'` to surface bugs. + """ + with pytest.raises(ValueError, match="unsupported_model_behavior"): + AnthropicToolIdSanitizationMiddleware( + unsupported_model_behavior="raies", # type: ignore[arg-type] + ) + + +def test_unsupported_model_ignore_logs_debug_breadcrumb( + caplog: pytest.LogCaptureFixture, +) -> None: + """Default `'ignore'` mode emits a debug log so the bypass is observable.""" + middleware = AnthropicToolIdSanitizationMiddleware() + request = _make_request([HumanMessage("hi")], anthropic=False) + captured: list[ModelRequest] = [] + + with caplog.at_level( + logging.DEBUG, logger="langchain_anthropic.middleware.tool_id_sanitization" + ): + middleware.wrap_model_call(request, _handler_factory(captured)) + + assert any( + "skipped" in record.message and "_FakeModel" in record.message + for record in caplog.records + ) diff --git a/libs/partners/anthropic/uv.lock b/libs/partners/anthropic/uv.lock index 5d3a07bc6dd..a62670bd1db 100644 --- a/libs/partners/anthropic/uv.lock +++ b/libs/partners/anthropic/uv.lock @@ -510,7 +510,7 @@ wheels = [ [[package]] name = "langchain" -version = "1.2.15" +version = "1.2.16" source = { editable = "../../langchain_v1" } dependencies = [ { name = "langchain-core" }, @@ -734,7 +734,7 @@ wheels = [ [[package]] name = "langchain-tests" -version = "1.1.6" +version = "1.1.7" source = { editable = "../../standard-tests" } dependencies = [ { name = "httpx" },