feat(anthropic): AnthropicPromptCachingMiddleware: apply explicit caching to system message and tool definitions (#35969)

This commit is contained in:
ccurme
2026-03-17 10:58:56 -04:00
committed by GitHub
parent 55711b010b
commit 043ef0721a
2 changed files with 360 additions and 14 deletions

View File

@@ -5,10 +5,15 @@ Requires:
- `langchain-anthropic`: For `ChatAnthropic` model (already a dependency)
"""
from __future__ import annotations
from collections.abc import Awaitable, Callable
from typing import Literal
from typing import Any, Literal
from warnings import warn
from langchain_core.messages import SystemMessage
from langchain_core.tools import BaseTool
from langchain_anthropic.chat_models import ChatAnthropic
try:
@@ -34,6 +39,15 @@ class AnthropicPromptCachingMiddleware(AgentMiddleware):
Requires both `langchain` and `langchain-anthropic` packages to be installed.
Applies cache control breakpoints to:
- **System message**: Tags the last content block of the system message
with `cache_control` so static system prompt content is cached.
- **Tools**: Tags all tool definitions with `cache_control` so tool
schemas are cached across turns.
- **Last cacheable block**: Tags last cacheable block of message sequence using
Anthropic's automatic caching feature.
Learn more about Anthropic prompt caching
[here](https://platform.claude.com/docs/en/build-with-claude/prompt-caching).
"""
@@ -68,6 +82,10 @@ class AnthropicPromptCachingMiddleware(AgentMiddleware):
self.min_messages_to_cache = min_messages_to_cache
self.unsupported_model_behavior = unsupported_model_behavior
@property
def _cache_control(self) -> dict[str, str]:
return {"type": self.type, "ttl": self.ttl}
def _should_apply_caching(self, request: ModelRequest) -> bool:
"""Check if caching should be applied to the request.
@@ -98,6 +116,33 @@ class AnthropicPromptCachingMiddleware(AgentMiddleware):
)
return messages_count >= self.min_messages_to_cache
def _apply_caching(self, request: ModelRequest) -> ModelRequest:
"""Apply cache control to system message, tools, and model settings.
Args:
request: The model request to modify.
Returns:
New request with cache control applied.
"""
overrides: dict[str, Any] = {}
cache_control = self._cache_control
overrides["model_settings"] = {
**request.model_settings,
"cache_control": cache_control,
}
system_message = _tag_system_message(request.system_message, cache_control)
if system_message is not request.system_message:
overrides["system_message"] = system_message
tools = _tag_tools(request.tools, cache_control)
if tools is not request.tools:
overrides["tools"] = tools
return request.override(**overrides)
def wrap_model_call(
self,
request: ModelRequest,
@@ -115,12 +160,7 @@ class AnthropicPromptCachingMiddleware(AgentMiddleware):
if not self._should_apply_caching(request):
return handler(request)
model_settings = request.model_settings
new_model_settings = {
**model_settings,
"cache_control": {"type": self.type, "ttl": self.ttl},
}
return handler(request.override(model_settings=new_model_settings))
return handler(self._apply_caching(request))
async def awrap_model_call(
self,
@@ -139,9 +179,77 @@ class AnthropicPromptCachingMiddleware(AgentMiddleware):
if not self._should_apply_caching(request):
return await handler(request)
model_settings = request.model_settings
new_model_settings = {
**model_settings,
"cache_control": {"type": self.type, "ttl": self.ttl},
}
return await handler(request.override(model_settings=new_model_settings))
return await handler(self._apply_caching(request))
def _tag_system_message(
system_message: Any,
cache_control: dict[str, str],
) -> Any:
"""Tag the last content block of a system message with cache_control.
Returns the original system_message unchanged if there are no blocks
to tag.
Args:
system_message: The system message to tag.
cache_control: The cache control dict to apply.
Returns:
A new SystemMessage with cache_control on the last block, or the
original if no modification was needed.
"""
if system_message is None:
return system_message
content = system_message.content
if isinstance(content, str):
if not content:
return system_message
new_content: list[str | dict[str, Any]] = [
{"type": "text", "text": content, "cache_control": cache_control}
]
elif isinstance(content, list):
if not content:
return system_message
new_content = list(content)
last = new_content[-1]
base = last if isinstance(last, dict) else {}
new_content[-1] = {**base, "cache_control": cache_control}
else:
return system_message
return SystemMessage(content=new_content)
def _tag_tools(
tools: list[Any] | None,
cache_control: dict[str, str],
) -> list[Any] | None:
"""Tag the last tool with cache_control via its extras dict.
Only the last tool is tagged to minimize the number of explicit cache
breakpoints (Anthropic limits these to 4 per request). Since tool
definitions are sent as a contiguous block, a single breakpoint on the
last tool caches the entire set.
Creates a copy of the last tool with cache_control added to extras,
without mutating the original.
Args:
tools: The list of tools to tag.
cache_control: The cache control dict to apply.
Returns:
A new list with cache_control on the last tool's extras, or the
original if no tools are present.
"""
if not tools:
return tools
last = tools[-1]
if not isinstance(last, BaseTool):
return tools
new_extras = {**(last.extras or {}), "cache_control": cache_control}
return [*tools[:-1], last.model_copy(update={"extras": new_extras})]