mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
feat(anthropic): AnthropicPromptCachingMiddleware: apply explicit caching to system message and tool definitions (#35969)
This commit is contained in:
@@ -5,10 +5,15 @@ Requires:
|
||||
- `langchain-anthropic`: For `ChatAnthropic` model (already a dependency)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Literal
|
||||
from typing import Any, Literal
|
||||
from warnings import warn
|
||||
|
||||
from langchain_core.messages import SystemMessage
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from langchain_anthropic.chat_models import ChatAnthropic
|
||||
|
||||
try:
|
||||
@@ -34,6 +39,15 @@ class AnthropicPromptCachingMiddleware(AgentMiddleware):
|
||||
|
||||
Requires both `langchain` and `langchain-anthropic` packages to be installed.
|
||||
|
||||
Applies cache control breakpoints to:
|
||||
|
||||
- **System message**: Tags the last content block of the system message
|
||||
with `cache_control` so static system prompt content is cached.
|
||||
- **Tools**: Tags all tool definitions with `cache_control` so tool
|
||||
schemas are cached across turns.
|
||||
- **Last cacheable block**: Tags last cacheable block of message sequence using
|
||||
Anthropic's automatic caching feature.
|
||||
|
||||
Learn more about Anthropic prompt caching
|
||||
[here](https://platform.claude.com/docs/en/build-with-claude/prompt-caching).
|
||||
"""
|
||||
@@ -68,6 +82,10 @@ class AnthropicPromptCachingMiddleware(AgentMiddleware):
|
||||
self.min_messages_to_cache = min_messages_to_cache
|
||||
self.unsupported_model_behavior = unsupported_model_behavior
|
||||
|
||||
@property
|
||||
def _cache_control(self) -> dict[str, str]:
|
||||
return {"type": self.type, "ttl": self.ttl}
|
||||
|
||||
def _should_apply_caching(self, request: ModelRequest) -> bool:
|
||||
"""Check if caching should be applied to the request.
|
||||
|
||||
@@ -98,6 +116,33 @@ class AnthropicPromptCachingMiddleware(AgentMiddleware):
|
||||
)
|
||||
return messages_count >= self.min_messages_to_cache
|
||||
|
||||
def _apply_caching(self, request: ModelRequest) -> ModelRequest:
|
||||
"""Apply cache control to system message, tools, and model settings.
|
||||
|
||||
Args:
|
||||
request: The model request to modify.
|
||||
|
||||
Returns:
|
||||
New request with cache control applied.
|
||||
"""
|
||||
overrides: dict[str, Any] = {}
|
||||
cache_control = self._cache_control
|
||||
|
||||
overrides["model_settings"] = {
|
||||
**request.model_settings,
|
||||
"cache_control": cache_control,
|
||||
}
|
||||
|
||||
system_message = _tag_system_message(request.system_message, cache_control)
|
||||
if system_message is not request.system_message:
|
||||
overrides["system_message"] = system_message
|
||||
|
||||
tools = _tag_tools(request.tools, cache_control)
|
||||
if tools is not request.tools:
|
||||
overrides["tools"] = tools
|
||||
|
||||
return request.override(**overrides)
|
||||
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
@@ -115,12 +160,7 @@ class AnthropicPromptCachingMiddleware(AgentMiddleware):
|
||||
if not self._should_apply_caching(request):
|
||||
return handler(request)
|
||||
|
||||
model_settings = request.model_settings
|
||||
new_model_settings = {
|
||||
**model_settings,
|
||||
"cache_control": {"type": self.type, "ttl": self.ttl},
|
||||
}
|
||||
return handler(request.override(model_settings=new_model_settings))
|
||||
return handler(self._apply_caching(request))
|
||||
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
@@ -139,9 +179,77 @@ class AnthropicPromptCachingMiddleware(AgentMiddleware):
|
||||
if not self._should_apply_caching(request):
|
||||
return await handler(request)
|
||||
|
||||
model_settings = request.model_settings
|
||||
new_model_settings = {
|
||||
**model_settings,
|
||||
"cache_control": {"type": self.type, "ttl": self.ttl},
|
||||
}
|
||||
return await handler(request.override(model_settings=new_model_settings))
|
||||
return await handler(self._apply_caching(request))
|
||||
|
||||
|
||||
def _tag_system_message(
|
||||
system_message: Any,
|
||||
cache_control: dict[str, str],
|
||||
) -> Any:
|
||||
"""Tag the last content block of a system message with cache_control.
|
||||
|
||||
Returns the original system_message unchanged if there are no blocks
|
||||
to tag.
|
||||
|
||||
Args:
|
||||
system_message: The system message to tag.
|
||||
cache_control: The cache control dict to apply.
|
||||
|
||||
Returns:
|
||||
A new SystemMessage with cache_control on the last block, or the
|
||||
original if no modification was needed.
|
||||
"""
|
||||
if system_message is None:
|
||||
return system_message
|
||||
|
||||
content = system_message.content
|
||||
if isinstance(content, str):
|
||||
if not content:
|
||||
return system_message
|
||||
new_content: list[str | dict[str, Any]] = [
|
||||
{"type": "text", "text": content, "cache_control": cache_control}
|
||||
]
|
||||
elif isinstance(content, list):
|
||||
if not content:
|
||||
return system_message
|
||||
new_content = list(content)
|
||||
last = new_content[-1]
|
||||
base = last if isinstance(last, dict) else {}
|
||||
new_content[-1] = {**base, "cache_control": cache_control}
|
||||
else:
|
||||
return system_message
|
||||
|
||||
return SystemMessage(content=new_content)
|
||||
|
||||
|
||||
def _tag_tools(
|
||||
tools: list[Any] | None,
|
||||
cache_control: dict[str, str],
|
||||
) -> list[Any] | None:
|
||||
"""Tag the last tool with cache_control via its extras dict.
|
||||
|
||||
Only the last tool is tagged to minimize the number of explicit cache
|
||||
breakpoints (Anthropic limits these to 4 per request). Since tool
|
||||
definitions are sent as a contiguous block, a single breakpoint on the
|
||||
last tool caches the entire set.
|
||||
|
||||
Creates a copy of the last tool with cache_control added to extras,
|
||||
without mutating the original.
|
||||
|
||||
Args:
|
||||
tools: The list of tools to tag.
|
||||
cache_control: The cache control dict to apply.
|
||||
|
||||
Returns:
|
||||
A new list with cache_control on the last tool's extras, or the
|
||||
original if no tools are present.
|
||||
"""
|
||||
if not tools:
|
||||
return tools
|
||||
|
||||
last = tools[-1]
|
||||
if not isinstance(last, BaseTool):
|
||||
return tools
|
||||
|
||||
new_extras = {**(last.extras or {}), "cache_control": cache_control}
|
||||
return [*tools[:-1], last.model_copy(update={"extras": new_extras})]
|
||||
|
||||
Reference in New Issue
Block a user