From 043ef0721ad9ea58a3e8d9f65df2e3e6c880434f Mon Sep 17 00:00:00 2001 From: ccurme Date: Tue, 17 Mar 2026 10:58:56 -0400 Subject: [PATCH] feat(anthropic): AnthropicPromptCachingMiddleware: apply explicit caching to system message and tool definitions (#35969) --- .../middleware/prompt_caching.py | 134 +++++++++- .../middleware/test_prompt_caching.py | 240 +++++++++++++++++- 2 files changed, 360 insertions(+), 14 deletions(-) diff --git a/libs/partners/anthropic/langchain_anthropic/middleware/prompt_caching.py b/libs/partners/anthropic/langchain_anthropic/middleware/prompt_caching.py index f27285c7f01..eb35b6b974f 100644 --- a/libs/partners/anthropic/langchain_anthropic/middleware/prompt_caching.py +++ b/libs/partners/anthropic/langchain_anthropic/middleware/prompt_caching.py @@ -5,10 +5,15 @@ Requires: - `langchain-anthropic`: For `ChatAnthropic` model (already a dependency) """ +from __future__ import annotations + from collections.abc import Awaitable, Callable -from typing import Literal +from typing import Any, Literal from warnings import warn +from langchain_core.messages import SystemMessage +from langchain_core.tools import BaseTool + from langchain_anthropic.chat_models import ChatAnthropic try: @@ -34,6 +39,15 @@ class AnthropicPromptCachingMiddleware(AgentMiddleware): Requires both `langchain` and `langchain-anthropic` packages to be installed. + Applies cache control breakpoints to: + + - **System message**: Tags the last content block of the system message + with `cache_control` so static system prompt content is cached. + - **Tools**: Tags the last tool definition with `cache_control` so all tool + schemas are cached across turns. + - **Last cacheable block**: Tags the last cacheable block of the message + sequence using Anthropic's automatic caching feature. + Learn more about Anthropic prompt caching [here](https://platform.claude.com/docs/en/build-with-claude/prompt-caching). 
""" @@ -68,6 +82,10 @@ class AnthropicPromptCachingMiddleware(AgentMiddleware): self.min_messages_to_cache = min_messages_to_cache self.unsupported_model_behavior = unsupported_model_behavior + @property + def _cache_control(self) -> dict[str, str]: + return {"type": self.type, "ttl": self.ttl} + def _should_apply_caching(self, request: ModelRequest) -> bool: """Check if caching should be applied to the request. @@ -98,6 +116,33 @@ class AnthropicPromptCachingMiddleware(AgentMiddleware): ) return messages_count >= self.min_messages_to_cache + def _apply_caching(self, request: ModelRequest) -> ModelRequest: + """Apply cache control to system message, tools, and model settings. + + Args: + request: The model request to modify. + + Returns: + New request with cache control applied. + """ + overrides: dict[str, Any] = {} + cache_control = self._cache_control + + overrides["model_settings"] = { + **request.model_settings, + "cache_control": cache_control, + } + + system_message = _tag_system_message(request.system_message, cache_control) + if system_message is not request.system_message: + overrides["system_message"] = system_message + + tools = _tag_tools(request.tools, cache_control) + if tools is not request.tools: + overrides["tools"] = tools + + return request.override(**overrides) + def wrap_model_call( self, request: ModelRequest, @@ -115,12 +160,7 @@ class AnthropicPromptCachingMiddleware(AgentMiddleware): if not self._should_apply_caching(request): return handler(request) - model_settings = request.model_settings - new_model_settings = { - **model_settings, - "cache_control": {"type": self.type, "ttl": self.ttl}, - } - return handler(request.override(model_settings=new_model_settings)) + return handler(self._apply_caching(request)) async def awrap_model_call( self, @@ -139,9 +179,77 @@ class AnthropicPromptCachingMiddleware(AgentMiddleware): if not self._should_apply_caching(request): return await handler(request) - model_settings = request.model_settings - 
new_model_settings = { - **model_settings, - "cache_control": {"type": self.type, "ttl": self.ttl}, - } - return await handler(request.override(model_settings=new_model_settings)) + return await handler(self._apply_caching(request)) + + +def _tag_system_message( + system_message: Any, + cache_control: dict[str, str], +) -> Any: + """Tag the last content block of a system message with cache_control. + + Returns the original system_message unchanged if there are no blocks + to tag. + + Args: + system_message: The system message to tag. + cache_control: The cache control dict to apply. + + Returns: + A new SystemMessage with cache_control on the last block, or the + original if no modification was needed. + """ + if system_message is None: + return system_message + + content = system_message.content + if isinstance(content, str): + if not content: + return system_message + new_content: list[str | dict[str, Any]] = [ + {"type": "text", "text": content, "cache_control": cache_control} + ] + elif isinstance(content, list): + if not content: + return system_message + new_content = list(content) + last = new_content[-1] + base = last if isinstance(last, dict) else {"type": "text", "text": last} + new_content[-1] = {**base, "cache_control": cache_control} + else: + return system_message + + return SystemMessage(content=new_content) + + +def _tag_tools( + tools: list[Any] | None, + cache_control: dict[str, str], +) -> list[Any] | None: + """Tag the last tool with cache_control via its extras dict. + + Only the last tool is tagged to minimize the number of explicit cache + breakpoints (Anthropic limits these to 4 per request). Since tool + definitions are sent as a contiguous block, a single breakpoint on the + last tool caches the entire set. + + Creates a copy of the last tool with cache_control added to extras, + without mutating the original. + + Args: + tools: The list of tools to tag. + cache_control: The cache control dict to apply. 
+ + Returns: + A new list with cache_control on the last tool's extras, or the + original if no tools are present. + """ + if not tools: + return tools + + last = tools[-1] + if not isinstance(last, BaseTool): + return tools + + new_extras = {**(last.extras or {}), "cache_control": cache_control} + return [*tools[:-1], last.model_copy(update={"extras": new_extras})] diff --git a/libs/partners/anthropic/tests/unit_tests/middleware/test_prompt_caching.py b/libs/partners/anthropic/tests/unit_tests/middleware/test_prompt_caching.py index 634d8ff91a8..d8221cfca17 100644 --- a/libs/partners/anthropic/tests/unit_tests/middleware/test_prompt_caching.py +++ b/libs/partners/anthropic/tests/unit_tests/middleware/test_prompt_caching.py @@ -11,8 +11,9 @@ from langchain_core.callbacks import ( CallbackManagerForLLMRun, ) from langchain_core.language_models import BaseChatModel -from langchain_core.messages import AIMessage, BaseMessage, HumanMessage +from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage from langchain_core.outputs import ChatGeneration, ChatResult +from langchain_core.tools import BaseTool, tool from langgraph.runtime import Runtime from langchain_anthropic.chat_models import ChatAnthropic @@ -334,3 +335,240 @@ async def test_anthropic_prompt_caching_middleware_async_default_values() -> Non assert modified_request.model_settings == { "cache_control": {"type": "ephemeral", "ttl": "5m"} } + + +class TestSystemMessageCaching: + """Tests for system message cache_control tagging.""" + + def _make_request( + self, + system_message: SystemMessage | None = None, + **kwargs: Any, + ) -> ModelRequest: + mock_model = MagicMock(spec=ChatAnthropic) + defaults: dict[str, Any] = { + "model": mock_model, + "messages": [HumanMessage("Hello")], + "system_message": system_message, + "tool_choice": None, + "tools": [], + "response_format": None, + "state": {"messages": [HumanMessage("Hello")]}, + "runtime": cast(Runtime, object()), + 
"model_settings": {}, + } + defaults.update(kwargs) + return ModelRequest(**defaults) + + def _run(self, request: ModelRequest) -> ModelRequest: + middleware = AnthropicPromptCachingMiddleware() + captured: ModelRequest | None = None + + def handler(req: ModelRequest) -> ModelResponse: + nonlocal captured + captured = req + return ModelResponse(result=[AIMessage(content="ok")]) + + middleware.wrap_model_call(request, handler) + assert captured is not None + return captured + + def _get_content_blocks(self, result: ModelRequest) -> list[dict[str, Any]]: + assert result.system_message is not None + content = result.system_message.content + assert isinstance(content, list) + return cast("list[dict[str, Any]]", content) + + def test_tags_last_block_of_string_system_message(self) -> None: + result = self._run(self._make_request(SystemMessage("Base prompt"))) + blocks = self._get_content_blocks(result) + assert len(blocks) == 1 + assert blocks[0]["text"] == "Base prompt" + assert blocks[0]["cache_control"] == {"type": "ephemeral", "ttl": "5m"} + + def test_tags_only_last_block_of_multi_block_system_message(self) -> None: + msg = SystemMessage( + content=[ + {"type": "text", "text": "Block 1"}, + {"type": "text", "text": "Block 2"}, + {"type": "text", "text": "Block 3"}, + ] + ) + blocks = self._get_content_blocks(self._run(self._make_request(msg))) + assert len(blocks) == 3 + assert "cache_control" not in blocks[0] + assert "cache_control" not in blocks[1] + assert blocks[2]["cache_control"] == {"type": "ephemeral", "ttl": "5m"} + + def test_does_not_mutate_original_system_message(self) -> None: + original_content: list[str | dict[str, str]] = [ + {"type": "text", "text": "Block 1"}, + {"type": "text", "text": "Block 2"}, + ] + msg = SystemMessage(content=original_content) + self._run(self._make_request(msg)) + assert "cache_control" not in original_content[1] + + def test_passes_through_when_no_system_message(self) -> None: + result = 
self._run(self._make_request(system_message=None)) + assert result.system_message is None + + def test_passes_through_when_system_message_has_empty_string(self) -> None: + msg = SystemMessage(content="") + result = self._run(self._make_request(msg)) + assert result.system_message is not None + assert result.system_message.content == "" + + def test_passes_through_when_system_message_has_empty_list(self) -> None: + msg = SystemMessage(content=[]) + result = self._run(self._make_request(msg)) + assert result.system_message is not None + assert result.system_message.content == [] + + def test_preserves_non_text_block_types(self) -> None: + msg = SystemMessage( + content=[ + {"type": "text", "text": "Prompt"}, + {"type": "custom_type", "data": "value"}, + ] + ) + blocks = self._get_content_blocks(self._run(self._make_request(msg))) + assert blocks[0] == {"type": "text", "text": "Prompt"} + assert blocks[1]["type"] == "custom_type" + assert blocks[1]["data"] == "value" + assert blocks[1]["cache_control"] == {"type": "ephemeral", "ttl": "5m"} + + def test_respects_custom_ttl(self) -> None: + middleware = AnthropicPromptCachingMiddleware(ttl="1h") + request = self._make_request(SystemMessage("Prompt")) + captured: ModelRequest | None = None + + def handler(req: ModelRequest) -> ModelResponse: + nonlocal captured + captured = req + return ModelResponse(result=[AIMessage(content="ok")]) + + middleware.wrap_model_call(request, handler) + assert captured is not None + blocks = self._get_content_blocks(captured) + assert blocks[0]["cache_control"] == {"type": "ephemeral", "ttl": "1h"} + + +class TestToolCaching: + """Tests for tool definition cache_control tagging.""" + + def _make_request( + self, + tools: list[Any] | None = None, + **kwargs: Any, + ) -> ModelRequest: + mock_model = MagicMock(spec=ChatAnthropic) + defaults: dict[str, Any] = { + "model": mock_model, + "messages": [HumanMessage("Hello")], + "system_message": None, + "tool_choice": None, + "tools": tools or [], 
+ "response_format": None, + "state": {"messages": [HumanMessage("Hello")]}, + "runtime": cast(Runtime, object()), + "model_settings": {}, + } + defaults.update(kwargs) + return ModelRequest(**defaults) + + def _run(self, request: ModelRequest) -> ModelRequest: + middleware = AnthropicPromptCachingMiddleware() + captured: ModelRequest | None = None + + def handler(req: ModelRequest) -> ModelResponse: + nonlocal captured + captured = req + return ModelResponse(result=[AIMessage(content="ok")]) + + middleware.wrap_model_call(request, handler) + assert captured is not None + return captured + + def test_tags_only_last_tool_with_cache_control(self) -> None: + @tool + def get_weather(location: str) -> str: + """Get weather for a location.""" + return "sunny" + + @tool + def get_time(timezone: str) -> str: + """Get time in a timezone.""" + return "12:00" + + result = self._run(self._make_request(tools=[get_weather, get_time])) + assert result.tools is not None + assert len(result.tools) == 2 + first = result.tools[0] + assert isinstance(first, BaseTool) + assert first.extras is None or "cache_control" not in first.extras + last = result.tools[1] + assert isinstance(last, BaseTool) + assert last.extras is not None + assert last.extras["cache_control"] == {"type": "ephemeral", "ttl": "5m"} + + def test_does_not_mutate_original_tools(self) -> None: + @tool + def my_tool(x: str) -> str: + """A tool.""" + return x + + original_extras = my_tool.extras + self._run(self._make_request(tools=[my_tool])) + assert my_tool.extras is original_extras + + def test_preserves_existing_extras(self) -> None: + @tool(extras={"defer_loading": True}) + def my_tool(x: str) -> str: + """A tool.""" + return x + + result = self._run(self._make_request(tools=[my_tool])) + assert result.tools is not None + t = result.tools[0] + assert isinstance(t, BaseTool) + assert t.extras is not None + assert t.extras["defer_loading"] is True + assert t.extras["cache_control"] == { + "type": "ephemeral", + 
"ttl": "5m", + } + + def test_passes_through_empty_tools(self) -> None: + result = self._run(self._make_request(tools=[])) + assert result.tools == [] + + def test_passes_through_none_tools(self) -> None: + result = self._run(self._make_request(tools=None)) + assert result.tools == [] + + def test_respects_custom_ttl(self) -> None: + @tool + def my_tool(x: str) -> str: + """A tool.""" + return x + + middleware = AnthropicPromptCachingMiddleware(ttl="1h") + request = self._make_request(tools=[my_tool]) + captured: ModelRequest | None = None + + def handler(req: ModelRequest) -> ModelResponse: + nonlocal captured + captured = req + return ModelResponse(result=[AIMessage(content="ok")]) + + middleware.wrap_model_call(request, handler) + assert captured is not None + assert captured.tools is not None + t = captured.tools[0] + assert isinstance(t, BaseTool) + assert t.extras is not None + assert t.extras["cache_control"] == { + "type": "ephemeral", + "ttl": "1h", + }