fix(anthropic): skip cache_control for code_execution blocks (#34579)

This commit is contained in:
lwtaiyty
2026-01-06 04:40:59 +08:00
committed by GitHub
parent 7979fd3d9f
commit 578cef9622
2 changed files with 291 additions and 11 deletions

View File

@@ -651,6 +651,65 @@ def _format_messages(
return system, formatted_messages
def _collect_code_execution_tool_ids(formatted_messages: list[dict]) -> set[str]:
"""Collect tool_use IDs that were called by code_execution.
These blocks cannot have cache_control applied per Anthropic API requirements.
"""
code_execution_tool_ids: set[str] = set()
for message in formatted_messages:
if message.get("role") != "assistant":
continue
content = message.get("content", [])
if not isinstance(content, list):
continue
for block in content:
if not isinstance(block, dict):
continue
if block.get("type") != "tool_use":
continue
caller = block.get("caller")
if isinstance(caller, dict):
caller_type = caller.get("type", "")
if caller_type.startswith("code_execution"):
tool_id = block.get("id")
if tool_id:
code_execution_tool_ids.add(tool_id)
return code_execution_tool_ids
def _is_code_execution_related_block(
block: dict,
code_execution_tool_ids: set[str],
) -> bool:
"""Check if a content block is related to code_execution.
Returns True for blocks that should NOT have cache_control applied.
"""
if not isinstance(block, dict):
return False
block_type = block.get("type")
# tool_use blocks called by code_execution
if block_type == "tool_use":
caller = block.get("caller")
if isinstance(caller, dict):
caller_type = caller.get("type", "")
if caller_type.startswith("code_execution"):
return True
# tool_result blocks for code_execution called tools
if block_type == "tool_result":
tool_use_id = block.get("tool_use_id")
if tool_use_id and tool_use_id in code_execution_tool_ids:
return True
return False
def _handle_anthropic_bad_request(e: anthropic.BadRequestError) -> None:
"""Handle Anthropic BadRequestError."""
if ("messages: at least one message is required") in e.message:
@@ -1008,17 +1067,33 @@ class ChatAnthropic(BaseChatModel):
system, formatted_messages = _format_messages(messages)
# If cache_control is provided in kwargs, add it to the last message with
# content (Anthropic requires cache_control to be nested within a message
# block).
# If cache_control is provided in kwargs, add it to the last eligible message
# block (Anthropic requires cache_control to be nested within a message block).
# Skip blocks related to code_execution as they cannot have cache_control.
cache_control = kwargs.pop("cache_control", None)
if cache_control and formatted_messages:
# Collect tool IDs called by code_execution
code_execution_tool_ids = _collect_code_execution_tool_ids(
formatted_messages
)
cache_applied = False
for formatted_message in reversed(formatted_messages):
if cache_applied:
break
content = formatted_message.get("content")
if isinstance(content, list) and content:
content[-1]["cache_control"] = cache_control
# Find last eligible block (not code_execution related)
for block in reversed(content):
if isinstance(block, dict):
if _is_code_execution_related_block(
block, code_execution_tool_ids
):
continue
block["cache_control"] = cache_control
cache_applied = True
break
if isinstance(content, str):
elif isinstance(content, str):
formatted_message["content"] = [
{
"type": "text",
@@ -1026,10 +1101,10 @@ class ChatAnthropic(BaseChatModel):
"cache_control": cache_control,
}
]
break
# If we didn't find a message with content we silently drop the control.
# Anthropic would reject a payload with empty content blocks.
cache_applied = True
# If we didn't find an eligible block we silently drop the control.
# Anthropic would reject a payload with cache_control on
# code_execution blocks.
payload = {
"model": self.model,
"max_tokens": self.max_tokens,

View File

@@ -15,7 +15,11 @@ from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from langgraph.runtime import Runtime
from langchain_anthropic.chat_models import ChatAnthropic
from langchain_anthropic.chat_models import (
ChatAnthropic,
_collect_code_execution_tool_ids,
_is_code_execution_related_block,
)
from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware
@@ -334,3 +338,204 @@ async def test_anthropic_prompt_caching_middleware_async_default_values() -> Non
assert modified_request.model_settings == {
"cache_control": {"type": "ephemeral", "ttl": "5m"}
}
class TestCollectCodeExecutionToolIds:
"""Tests for _collect_code_execution_tool_ids function."""
def test_empty_messages(self) -> None:
"""Test with empty messages list."""
result = _collect_code_execution_tool_ids([])
assert result == set()
def test_no_code_execution_calls(self) -> None:
"""Test messages without any code_execution calls."""
messages = [
{
"role": "user",
"content": [{"type": "text", "text": "Hello"}],
},
{
"role": "assistant",
"content": [
{
"type": "tool_use",
"id": "toolu_regular",
"name": "get_weather",
"input": {"location": "NYC"},
}
],
},
]
result = _collect_code_execution_tool_ids(messages)
assert result == set()
def test_single_code_execution_call(self) -> None:
"""Test with a single code_execution tool call."""
messages = [
{
"role": "assistant",
"content": [
{
"type": "tool_use",
"id": "toolu_code_exec_1",
"name": "get_weather",
"input": {"location": "NYC"},
"caller": {
"type": "code_execution_20250825",
"tool_id": "srvtoolu_abc123",
},
}
],
},
]
result = _collect_code_execution_tool_ids(messages)
assert result == {"toolu_code_exec_1"}
def test_multiple_code_execution_calls(self) -> None:
"""Test with multiple code_execution tool calls."""
messages = [
{
"role": "assistant",
"content": [
{
"type": "tool_use",
"id": "toolu_regular",
"name": "search",
"input": {"query": "test"},
},
{
"type": "tool_use",
"id": "toolu_code_exec_1",
"name": "get_weather",
"input": {"location": "NYC"},
"caller": {
"type": "code_execution_20250825",
"tool_id": "srvtoolu_abc",
},
},
{
"type": "tool_use",
"id": "toolu_code_exec_2",
"name": "get_weather",
"input": {"location": "SF"},
"caller": {
"type": "code_execution_20250825",
"tool_id": "srvtoolu_def",
},
},
],
},
]
result = _collect_code_execution_tool_ids(messages)
assert result == {"toolu_code_exec_1", "toolu_code_exec_2"}
assert "toolu_regular" not in result
def test_future_code_execution_version(self) -> None:
"""Test with a hypothetical future code_execution version."""
messages = [
{
"role": "assistant",
"content": [
{
"type": "tool_use",
"id": "toolu_future",
"name": "get_weather",
"input": {},
"caller": {
"type": "code_execution_20260101",
"tool_id": "srvtoolu_future",
},
}
],
},
]
result = _collect_code_execution_tool_ids(messages)
assert result == {"toolu_future"}
def test_ignores_user_messages(self) -> None:
"""Test that user messages are ignored."""
messages = [
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "toolu_123",
"content": "result",
}
],
},
]
result = _collect_code_execution_tool_ids(messages)
assert result == set()
def test_handles_string_content(self) -> None:
"""Test that string content is handled gracefully."""
messages = [
{
"role": "assistant",
"content": "Just a text response",
},
]
result = _collect_code_execution_tool_ids(messages)
assert result == set()
class TestIsCodeExecutionRelatedBlock:
"""Tests for _is_code_execution_related_block function."""
def test_regular_tool_use_block(self) -> None:
"""Test regular tool_use block without caller."""
block = {
"type": "tool_use",
"id": "toolu_regular",
"name": "get_weather",
"input": {"location": "NYC"},
}
assert not _is_code_execution_related_block(block, set())
def test_code_execution_tool_use_block(self) -> None:
"""Test tool_use block called by code_execution."""
block = {
"type": "tool_use",
"id": "toolu_code_exec",
"name": "get_weather",
"input": {"location": "NYC"},
"caller": {
"type": "code_execution_20250825",
"tool_id": "srvtoolu_abc",
},
}
assert _is_code_execution_related_block(block, set())
def test_regular_tool_result_block(self) -> None:
"""Test tool_result block for regular tool."""
block = {
"type": "tool_result",
"tool_use_id": "toolu_regular",
"content": "Sunny, 72°F",
}
code_exec_ids = {"toolu_code_exec"}
assert not _is_code_execution_related_block(block, code_exec_ids)
def test_code_execution_tool_result_block(self) -> None:
"""Test tool_result block for code_execution called tool."""
block = {
"type": "tool_result",
"tool_use_id": "toolu_code_exec",
"content": "Sunny, 72°F",
}
code_exec_ids = {"toolu_code_exec"}
assert _is_code_execution_related_block(block, code_exec_ids)
def test_text_block(self) -> None:
"""Test that text blocks are not flagged."""
block = {"type": "text", "text": "Hello world"}
assert not _is_code_execution_related_block(block, set())
def test_non_dict_block(self) -> None:
"""Test that non-dict values return False."""
assert not _is_code_execution_related_block("string", set()) # type: ignore[arg-type]
assert not _is_code_execution_related_block(None, set()) # type: ignore[arg-type]
assert not _is_code_execution_related_block(123, set()) # type: ignore[arg-type]