From 578cef96223fd3cf53d722ef97698975b6675887 Mon Sep 17 00:00:00 2001 From: lwtaiyty <146391307+lwtaiyty@users.noreply.github.com> Date: Tue, 6 Jan 2026 04:40:59 +0800 Subject: [PATCH] fix(anthropic): skip cache_control for code_execution blocks (#34579) --- .../langchain_anthropic/chat_models.py | 95 +++++++- .../middleware/test_prompt_caching.py | 207 +++++++++++++++++- 2 files changed, 291 insertions(+), 11 deletions(-) diff --git a/libs/partners/anthropic/langchain_anthropic/chat_models.py b/libs/partners/anthropic/langchain_anthropic/chat_models.py index c62b050fd6e..d014fd82d92 100644 --- a/libs/partners/anthropic/langchain_anthropic/chat_models.py +++ b/libs/partners/anthropic/langchain_anthropic/chat_models.py @@ -651,6 +651,65 @@ def _format_messages( return system, formatted_messages +def _collect_code_execution_tool_ids(formatted_messages: list[dict]) -> set[str]: + """Collect tool_use IDs that were called by code_execution. + + These blocks cannot have cache_control applied per Anthropic API requirements. + """ + code_execution_tool_ids: set[str] = set() + + for message in formatted_messages: + if message.get("role") != "assistant": + continue + content = message.get("content", []) + if not isinstance(content, list): + continue + for block in content: + if not isinstance(block, dict): + continue + if block.get("type") != "tool_use": + continue + caller = block.get("caller") + if isinstance(caller, dict): + caller_type = caller.get("type", "") + if caller_type.startswith("code_execution"): + tool_id = block.get("id") + if tool_id: + code_execution_tool_ids.add(tool_id) + + return code_execution_tool_ids + + +def _is_code_execution_related_block( + block: dict, + code_execution_tool_ids: set[str], +) -> bool: + """Check if a content block is related to code_execution. + + Returns True for blocks that should NOT have cache_control applied. + """ + if not isinstance(block, dict): + return False + + block_type = block.get("type") + + # tool_use blocks called by code_execution + if block_type == "tool_use": + caller = block.get("caller") + if isinstance(caller, dict): + caller_type = caller.get("type", "") + if caller_type.startswith("code_execution"): + return True + + # tool_result blocks for code_execution called tools + if block_type == "tool_result": + tool_use_id = block.get("tool_use_id") + if tool_use_id and tool_use_id in code_execution_tool_ids: + return True + + return False + + def _handle_anthropic_bad_request(e: anthropic.BadRequestError) -> None: """Handle Anthropic BadRequestError.""" if ("messages: at least one message is required") in e.message: @@ -1008,17 +1067,33 @@ class ChatAnthropic(BaseChatModel): system, formatted_messages = _format_messages(messages) - # If cache_control is provided in kwargs, add it to the last message with - # content (Anthropic requires cache_control to be nested within a message - # block). + # If cache_control is provided in kwargs, add it to the last eligible message + # block (Anthropic requires cache_control to be nested within a message block). + # Skip blocks related to code_execution as they cannot have cache_control. cache_control = kwargs.pop("cache_control", None) if cache_control and formatted_messages: + # Collect tool IDs called by code_execution + code_execution_tool_ids = _collect_code_execution_tool_ids( + formatted_messages + ) + + cache_applied = False for formatted_message in reversed(formatted_messages): + if cache_applied: + break content = formatted_message.get("content") if isinstance(content, list) and content: - content[-1]["cache_control"] = cache_control - break - if isinstance(content, str): + # Find last eligible block (not code_execution related) + for block in reversed(content): + if isinstance(block, dict): + if _is_code_execution_related_block( + block, code_execution_tool_ids + ): + continue + block["cache_control"] = cache_control + cache_applied = True + break + elif isinstance(content, str): formatted_message["content"] = [ { "type": "text", @@ -1026,10 +1101,10 @@ class ChatAnthropic(BaseChatModel): "cache_control": cache_control, } ] - break - # If we didn't find a message with content we silently drop the control. - # Anthropic would reject a payload with empty content blocks. - + cache_applied = True + # If we didn't find an eligible block we silently drop the control. + # Anthropic would reject a payload with cache_control on + # code_execution blocks. payload = { "model": self.model, "max_tokens": self.max_tokens, diff --git a/libs/partners/anthropic/tests/unit_tests/middleware/test_prompt_caching.py b/libs/partners/anthropic/tests/unit_tests/middleware/test_prompt_caching.py index 634d8ff91a8..50594a29092 100644 --- a/libs/partners/anthropic/tests/unit_tests/middleware/test_prompt_caching.py +++ b/libs/partners/anthropic/tests/unit_tests/middleware/test_prompt_caching.py @@ -15,7 +15,11 @@ from langchain_core.messages import AIMessage, BaseMessage, HumanMessage from langchain_core.outputs import ChatGeneration, ChatResult from langgraph.runtime import Runtime -from langchain_anthropic.chat_models import ChatAnthropic +from langchain_anthropic.chat_models import ( + ChatAnthropic, + _collect_code_execution_tool_ids, + _is_code_execution_related_block, +) from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware @@ -334,3 +338,204 @@ async def test_anthropic_prompt_caching_middleware_async_default_values() -> Non assert modified_request.model_settings == { "cache_control": {"type": "ephemeral", "ttl": "5m"} } + + +class TestCollectCodeExecutionToolIds: + """Tests for _collect_code_execution_tool_ids function.""" + + def test_empty_messages(self) -> None: + """Test with empty messages list.""" + result = _collect_code_execution_tool_ids([]) + assert result == set() + + def test_no_code_execution_calls(self) -> None: + """Test messages without any code_execution calls.""" + messages = [ + { + "role": "user", + "content": [{"type": "text", "text": "Hello"}], + }, + { + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": "toolu_regular", + "name": "get_weather", + "input": {"location": "NYC"}, + } + ], + }, + ] + result = _collect_code_execution_tool_ids(messages) + assert result == set() + + def test_single_code_execution_call(self) -> None: + """Test with a single code_execution tool call.""" + messages = [ + { + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": "toolu_code_exec_1", + "name": "get_weather", + "input": {"location": "NYC"}, + "caller": { + "type": "code_execution_20250825", + "tool_id": "srvtoolu_abc123", + }, + } + ], + }, + ] + result = _collect_code_execution_tool_ids(messages) + assert result == {"toolu_code_exec_1"} + + def test_multiple_code_execution_calls(self) -> None: + """Test with multiple code_execution tool calls.""" + messages = [ + { + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": "toolu_regular", + "name": "search", + "input": {"query": "test"}, + }, + { + "type": "tool_use", + "id": "toolu_code_exec_1", + "name": "get_weather", + "input": {"location": "NYC"}, + "caller": { + "type": "code_execution_20250825", + "tool_id": "srvtoolu_abc", + }, + }, + { + "type": "tool_use", + "id": "toolu_code_exec_2", + "name": "get_weather", + "input": {"location": "SF"}, + "caller": { + "type": "code_execution_20250825", + "tool_id": "srvtoolu_def", + }, + }, + ], + }, + ] + result = _collect_code_execution_tool_ids(messages) + assert result == {"toolu_code_exec_1", "toolu_code_exec_2"} + assert "toolu_regular" not in result + + def test_future_code_execution_version(self) -> None: + """Test with a hypothetical future code_execution version.""" + messages = [ + { + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": "toolu_future", + "name": "get_weather", + "input": {}, + "caller": { + "type": "code_execution_20260101", + "tool_id": "srvtoolu_future", + }, + } + ], + }, + ] + result = _collect_code_execution_tool_ids(messages) + assert result == {"toolu_future"} + + def test_ignores_user_messages(self) -> None: + """Test that user messages are ignored.""" + messages = [ + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "toolu_123", + "content": "result", + } + ], + }, + ] + result = _collect_code_execution_tool_ids(messages) + assert result == set() + + def test_handles_string_content(self) -> None: + """Test that string content is handled gracefully.""" + messages = [ + { + "role": "assistant", + "content": "Just a text response", + }, + ] + result = _collect_code_execution_tool_ids(messages) + assert result == set() + + +class TestIsCodeExecutionRelatedBlock: + """Tests for _is_code_execution_related_block function.""" + + def test_regular_tool_use_block(self) -> None: + """Test regular tool_use block without caller.""" + block = { + "type": "tool_use", + "id": "toolu_regular", + "name": "get_weather", + "input": {"location": "NYC"}, + } + assert not _is_code_execution_related_block(block, set()) + + def test_code_execution_tool_use_block(self) -> None: + """Test tool_use block called by code_execution.""" + block = { + "type": "tool_use", + "id": "toolu_code_exec", + "name": "get_weather", + "input": {"location": "NYC"}, + "caller": { + "type": "code_execution_20250825", + "tool_id": "srvtoolu_abc", + }, + } + assert _is_code_execution_related_block(block, set()) + + def test_regular_tool_result_block(self) -> None: + """Test tool_result block for regular tool.""" + block = { + "type": "tool_result", + "tool_use_id": "toolu_regular", + "content": "Sunny, 72°F", + } + code_exec_ids = {"toolu_code_exec"} + assert not _is_code_execution_related_block(block, code_exec_ids) + + def test_code_execution_tool_result_block(self) -> None: + """Test tool_result block for code_execution called tool.""" + block = { + "type": "tool_result", + "tool_use_id": "toolu_code_exec", + "content": "Sunny, 72°F", + } + code_exec_ids = {"toolu_code_exec"} + assert _is_code_execution_related_block(block, code_exec_ids) + + def test_text_block(self) -> None: + """Test that text blocks are not flagged.""" + block = {"type": "text", "text": "Hello world"} + assert not _is_code_execution_related_block(block, set()) + + def test_non_dict_block(self) -> None: + """Test that non-dict values return False.""" + assert not _is_code_execution_related_block("string", set()) # type: ignore[arg-type] + assert not _is_code_execution_related_block(None, set()) # type: ignore[arg-type] + assert not _is_code_execution_related_block(123, set()) # type: ignore[arg-type]