feat(anthropic): AnthropicPromptCachingMiddleware: apply explicit caching to system message and tool definitions (#35969)

This commit is contained in:
ccurme
2026-03-17 10:58:56 -04:00
committed by GitHub
parent 55711b010b
commit 043ef0721a
2 changed files with 360 additions and 14 deletions

View File

@@ -5,10 +5,15 @@ Requires:
- `langchain-anthropic`: For `ChatAnthropic` model (already a dependency)
"""
from __future__ import annotations
from collections.abc import Awaitable, Callable
from typing import Literal
from typing import Any, Literal
from warnings import warn
from langchain_core.messages import SystemMessage
from langchain_core.tools import BaseTool
from langchain_anthropic.chat_models import ChatAnthropic
try:
@@ -34,6 +39,15 @@ class AnthropicPromptCachingMiddleware(AgentMiddleware):
Requires both `langchain` and `langchain-anthropic` packages to be installed.
Applies cache control breakpoints to:
- **System message**: Tags the last content block of the system message
with `cache_control` so static system prompt content is cached.
- **Tools**: Tags all tool definitions with `cache_control` so tool
schemas are cached across turns.
- **Last cacheable block**: Tags last cacheable block of message sequence using
Anthropic's automatic caching feature.
Learn more about Anthropic prompt caching
[here](https://platform.claude.com/docs/en/build-with-claude/prompt-caching).
"""
@@ -68,6 +82,10 @@ class AnthropicPromptCachingMiddleware(AgentMiddleware):
self.min_messages_to_cache = min_messages_to_cache
self.unsupported_model_behavior = unsupported_model_behavior
@property
def _cache_control(self) -> dict[str, str]:
return {"type": self.type, "ttl": self.ttl}
def _should_apply_caching(self, request: ModelRequest) -> bool:
"""Check if caching should be applied to the request.
@@ -98,6 +116,33 @@ class AnthropicPromptCachingMiddleware(AgentMiddleware):
)
return messages_count >= self.min_messages_to_cache
def _apply_caching(self, request: ModelRequest) -> ModelRequest:
"""Apply cache control to system message, tools, and model settings.
Args:
request: The model request to modify.
Returns:
New request with cache control applied.
"""
overrides: dict[str, Any] = {}
cache_control = self._cache_control
overrides["model_settings"] = {
**request.model_settings,
"cache_control": cache_control,
}
system_message = _tag_system_message(request.system_message, cache_control)
if system_message is not request.system_message:
overrides["system_message"] = system_message
tools = _tag_tools(request.tools, cache_control)
if tools is not request.tools:
overrides["tools"] = tools
return request.override(**overrides)
def wrap_model_call(
self,
request: ModelRequest,
@@ -115,12 +160,7 @@ class AnthropicPromptCachingMiddleware(AgentMiddleware):
if not self._should_apply_caching(request):
return handler(request)
model_settings = request.model_settings
new_model_settings = {
**model_settings,
"cache_control": {"type": self.type, "ttl": self.ttl},
}
return handler(request.override(model_settings=new_model_settings))
return handler(self._apply_caching(request))
async def awrap_model_call(
self,
@@ -139,9 +179,77 @@ class AnthropicPromptCachingMiddleware(AgentMiddleware):
if not self._should_apply_caching(request):
return await handler(request)
model_settings = request.model_settings
new_model_settings = {
**model_settings,
"cache_control": {"type": self.type, "ttl": self.ttl},
}
return await handler(request.override(model_settings=new_model_settings))
return await handler(self._apply_caching(request))
def _tag_system_message(
system_message: Any,
cache_control: dict[str, str],
) -> Any:
"""Tag the last content block of a system message with cache_control.
Returns the original system_message unchanged if there are no blocks
to tag.
Args:
system_message: The system message to tag.
cache_control: The cache control dict to apply.
Returns:
A new SystemMessage with cache_control on the last block, or the
original if no modification was needed.
"""
if system_message is None:
return system_message
content = system_message.content
if isinstance(content, str):
if not content:
return system_message
new_content: list[str | dict[str, Any]] = [
{"type": "text", "text": content, "cache_control": cache_control}
]
elif isinstance(content, list):
if not content:
return system_message
new_content = list(content)
last = new_content[-1]
base = last if isinstance(last, dict) else {}
new_content[-1] = {**base, "cache_control": cache_control}
else:
return system_message
return SystemMessage(content=new_content)
def _tag_tools(
tools: list[Any] | None,
cache_control: dict[str, str],
) -> list[Any] | None:
"""Tag the last tool with cache_control via its extras dict.
Only the last tool is tagged to minimize the number of explicit cache
breakpoints (Anthropic limits these to 4 per request). Since tool
definitions are sent as a contiguous block, a single breakpoint on the
last tool caches the entire set.
Creates a copy of the last tool with cache_control added to extras,
without mutating the original.
Args:
tools: The list of tools to tag.
cache_control: The cache control dict to apply.
Returns:
A new list with cache_control on the last tool's extras, or the
original if no tools are present.
"""
if not tools:
return tools
last = tools[-1]
if not isinstance(last, BaseTool):
return tools
new_extras = {**(last.extras or {}), "cache_control": cache_control}
return [*tools[:-1], last.model_copy(update={"extras": new_extras})]

View File

@@ -11,8 +11,9 @@ from langchain_core.callbacks import (
CallbackManagerForLLMRun,
)
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.tools import BaseTool, tool
from langgraph.runtime import Runtime
from langchain_anthropic.chat_models import ChatAnthropic
@@ -334,3 +335,240 @@ async def test_anthropic_prompt_caching_middleware_async_default_values() -> Non
assert modified_request.model_settings == {
"cache_control": {"type": "ephemeral", "ttl": "5m"}
}
class TestSystemMessageCaching:
"""Tests for system message cache_control tagging."""
def _make_request(
self,
system_message: SystemMessage | None = None,
**kwargs: Any,
) -> ModelRequest:
mock_model = MagicMock(spec=ChatAnthropic)
defaults: dict[str, Any] = {
"model": mock_model,
"messages": [HumanMessage("Hello")],
"system_message": system_message,
"tool_choice": None,
"tools": [],
"response_format": None,
"state": {"messages": [HumanMessage("Hello")]},
"runtime": cast(Runtime, object()),
"model_settings": {},
}
defaults.update(kwargs)
return ModelRequest(**defaults)
def _run(self, request: ModelRequest) -> ModelRequest:
middleware = AnthropicPromptCachingMiddleware()
captured: ModelRequest | None = None
def handler(req: ModelRequest) -> ModelResponse:
nonlocal captured
captured = req
return ModelResponse(result=[AIMessage(content="ok")])
middleware.wrap_model_call(request, handler)
assert captured is not None
return captured
def _get_content_blocks(self, result: ModelRequest) -> list[dict[str, Any]]:
assert result.system_message is not None
content = result.system_message.content
assert isinstance(content, list)
return cast("list[dict[str, Any]]", content)
def test_tags_last_block_of_string_system_message(self) -> None:
result = self._run(self._make_request(SystemMessage("Base prompt")))
blocks = self._get_content_blocks(result)
assert len(blocks) == 1
assert blocks[0]["text"] == "Base prompt"
assert blocks[0]["cache_control"] == {"type": "ephemeral", "ttl": "5m"}
def test_tags_only_last_block_of_multi_block_system_message(self) -> None:
msg = SystemMessage(
content=[
{"type": "text", "text": "Block 1"},
{"type": "text", "text": "Block 2"},
{"type": "text", "text": "Block 3"},
]
)
blocks = self._get_content_blocks(self._run(self._make_request(msg)))
assert len(blocks) == 3
assert "cache_control" not in blocks[0]
assert "cache_control" not in blocks[1]
assert blocks[2]["cache_control"] == {"type": "ephemeral", "ttl": "5m"}
def test_does_not_mutate_original_system_message(self) -> None:
original_content: list[str | dict[str, str]] = [
{"type": "text", "text": "Block 1"},
{"type": "text", "text": "Block 2"},
]
msg = SystemMessage(content=original_content)
self._run(self._make_request(msg))
assert "cache_control" not in original_content[1]
def test_passes_through_when_no_system_message(self) -> None:
result = self._run(self._make_request(system_message=None))
assert result.system_message is None
def test_passes_through_when_system_message_has_empty_string(self) -> None:
msg = SystemMessage(content="")
result = self._run(self._make_request(msg))
assert result.system_message is not None
assert result.system_message.content == ""
def test_passes_through_when_system_message_has_empty_list(self) -> None:
msg = SystemMessage(content=[])
result = self._run(self._make_request(msg))
assert result.system_message is not None
assert result.system_message.content == []
def test_preserves_non_text_block_types(self) -> None:
msg = SystemMessage(
content=[
{"type": "text", "text": "Prompt"},
{"type": "custom_type", "data": "value"},
]
)
blocks = self._get_content_blocks(self._run(self._make_request(msg)))
assert blocks[0] == {"type": "text", "text": "Prompt"}
assert blocks[1]["type"] == "custom_type"
assert blocks[1]["data"] == "value"
assert blocks[1]["cache_control"] == {"type": "ephemeral", "ttl": "5m"}
def test_respects_custom_ttl(self) -> None:
middleware = AnthropicPromptCachingMiddleware(ttl="1h")
request = self._make_request(SystemMessage("Prompt"))
captured: ModelRequest | None = None
def handler(req: ModelRequest) -> ModelResponse:
nonlocal captured
captured = req
return ModelResponse(result=[AIMessage(content="ok")])
middleware.wrap_model_call(request, handler)
assert captured is not None
blocks = self._get_content_blocks(captured)
assert blocks[0]["cache_control"] == {"type": "ephemeral", "ttl": "1h"}
class TestToolCaching:
"""Tests for tool definition cache_control tagging."""
def _make_request(
self,
tools: list[Any] | None = None,
**kwargs: Any,
) -> ModelRequest:
mock_model = MagicMock(spec=ChatAnthropic)
defaults: dict[str, Any] = {
"model": mock_model,
"messages": [HumanMessage("Hello")],
"system_message": None,
"tool_choice": None,
"tools": tools or [],
"response_format": None,
"state": {"messages": [HumanMessage("Hello")]},
"runtime": cast(Runtime, object()),
"model_settings": {},
}
defaults.update(kwargs)
return ModelRequest(**defaults)
def _run(self, request: ModelRequest) -> ModelRequest:
middleware = AnthropicPromptCachingMiddleware()
captured: ModelRequest | None = None
def handler(req: ModelRequest) -> ModelResponse:
nonlocal captured
captured = req
return ModelResponse(result=[AIMessage(content="ok")])
middleware.wrap_model_call(request, handler)
assert captured is not None
return captured
def test_tags_only_last_tool_with_cache_control(self) -> None:
@tool
def get_weather(location: str) -> str:
"""Get weather for a location."""
return "sunny"
@tool
def get_time(timezone: str) -> str:
"""Get time in a timezone."""
return "12:00"
result = self._run(self._make_request(tools=[get_weather, get_time]))
assert result.tools is not None
assert len(result.tools) == 2
first = result.tools[0]
assert isinstance(first, BaseTool)
assert first.extras is None or "cache_control" not in first.extras
last = result.tools[1]
assert isinstance(last, BaseTool)
assert last.extras is not None
assert last.extras["cache_control"] == {"type": "ephemeral", "ttl": "5m"}
def test_does_not_mutate_original_tools(self) -> None:
@tool
def my_tool(x: str) -> str:
"""A tool."""
return x
original_extras = my_tool.extras
self._run(self._make_request(tools=[my_tool]))
assert my_tool.extras is original_extras
def test_preserves_existing_extras(self) -> None:
@tool(extras={"defer_loading": True})
def my_tool(x: str) -> str:
"""A tool."""
return x
result = self._run(self._make_request(tools=[my_tool]))
assert result.tools is not None
t = result.tools[0]
assert isinstance(t, BaseTool)
assert t.extras is not None
assert t.extras["defer_loading"] is True
assert t.extras["cache_control"] == {
"type": "ephemeral",
"ttl": "5m",
}
def test_passes_through_empty_tools(self) -> None:
result = self._run(self._make_request(tools=[]))
assert result.tools == []
def test_passes_through_none_tools(self) -> None:
result = self._run(self._make_request(tools=None))
assert result.tools == []
def test_respects_custom_ttl(self) -> None:
@tool
def my_tool(x: str) -> str:
"""A tool."""
return x
middleware = AnthropicPromptCachingMiddleware(ttl="1h")
request = self._make_request(tools=[my_tool])
captured: ModelRequest | None = None
def handler(req: ModelRequest) -> ModelResponse:
nonlocal captured
captured = req
return ModelResponse(result=[AIMessage(content="ok")])
middleware.wrap_model_call(request, handler)
assert captured is not None
assert captured.tools is not None
t = captured.tools[0]
assert isinstance(t, BaseTool)
assert t.extras is not None
assert t.extras["cache_control"] == {
"type": "ephemeral",
"ttl": "1h",
}