mirror of
https://github.com/hwchase17/langchain.git
synced 2026-03-18 02:53:16 +00:00
feat(anthropic): AnthropicPromptCachingMiddleware: apply explicit caching to system message and tool definitions (#35969)
This commit is contained in:
@@ -5,10 +5,15 @@ Requires:
|
||||
- `langchain-anthropic`: For `ChatAnthropic` model (already a dependency)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Literal
|
||||
from typing import Any, Literal
|
||||
from warnings import warn
|
||||
|
||||
from langchain_core.messages import SystemMessage
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from langchain_anthropic.chat_models import ChatAnthropic
|
||||
|
||||
try:
|
||||
@@ -34,6 +39,15 @@ class AnthropicPromptCachingMiddleware(AgentMiddleware):
|
||||
|
||||
Requires both `langchain` and `langchain-anthropic` packages to be installed.
|
||||
|
||||
Applies cache control breakpoints to:
|
||||
|
||||
- **System message**: Tags the last content block of the system message
|
||||
with `cache_control` so static system prompt content is cached.
|
||||
- **Tools**: Tags all tool definitions with `cache_control` so tool
|
||||
schemas are cached across turns.
|
||||
- **Last cacheable block**: Tags last cacheable block of message sequence using
|
||||
Anthropic's automatic caching feature.
|
||||
|
||||
Learn more about Anthropic prompt caching
|
||||
[here](https://platform.claude.com/docs/en/build-with-claude/prompt-caching).
|
||||
"""
|
||||
@@ -68,6 +82,10 @@ class AnthropicPromptCachingMiddleware(AgentMiddleware):
|
||||
self.min_messages_to_cache = min_messages_to_cache
|
||||
self.unsupported_model_behavior = unsupported_model_behavior
|
||||
|
||||
@property
def _cache_control(self) -> dict[str, str]:
    """Build the `cache_control` payload from the configured type and TTL.

    Returns:
        A dict of the form `{"type": ..., "ttl": ...}` suitable for
        Anthropic's `cache_control` field.
    """
    return {"type": self.type, "ttl": self.ttl}
|
||||
|
||||
def _should_apply_caching(self, request: ModelRequest) -> bool:
|
||||
"""Check if caching should be applied to the request.
|
||||
|
||||
@@ -98,6 +116,33 @@ class AnthropicPromptCachingMiddleware(AgentMiddleware):
|
||||
)
|
||||
return messages_count >= self.min_messages_to_cache
|
||||
|
||||
def _apply_caching(self, request: ModelRequest) -> ModelRequest:
    """Apply cache control to system message, tools, and model settings.

    Args:
        request: The model request to modify.

    Returns:
        New request with cache control applied.
    """
    overrides: dict[str, Any] = {}
    cache_control = self._cache_control

    # Always forward cache_control through model settings so the provider
    # can apply its automatic "last cacheable block" caching.
    overrides["model_settings"] = {
        **request.model_settings,
        "cache_control": cache_control,
    }

    # The tagging helpers return the *original* object when there is
    # nothing to tag, so identity checks tell us whether an override is
    # actually needed.
    system_message = _tag_system_message(request.system_message, cache_control)
    if system_message is not request.system_message:
        overrides["system_message"] = system_message

    tools = _tag_tools(request.tools, cache_control)
    if tools is not request.tools:
        overrides["tools"] = tools

    return request.override(**overrides)
|
||||
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
@@ -115,12 +160,7 @@ class AnthropicPromptCachingMiddleware(AgentMiddleware):
|
||||
if not self._should_apply_caching(request):
|
||||
return handler(request)
|
||||
|
||||
model_settings = request.model_settings
|
||||
new_model_settings = {
|
||||
**model_settings,
|
||||
"cache_control": {"type": self.type, "ttl": self.ttl},
|
||||
}
|
||||
return handler(request.override(model_settings=new_model_settings))
|
||||
return handler(self._apply_caching(request))
|
||||
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
@@ -139,9 +179,77 @@ class AnthropicPromptCachingMiddleware(AgentMiddleware):
|
||||
if not self._should_apply_caching(request):
|
||||
return await handler(request)
|
||||
|
||||
model_settings = request.model_settings
|
||||
new_model_settings = {
|
||||
**model_settings,
|
||||
"cache_control": {"type": self.type, "ttl": self.ttl},
|
||||
}
|
||||
return await handler(request.override(model_settings=new_model_settings))
|
||||
return await handler(self._apply_caching(request))
|
||||
|
||||
|
||||
def _tag_system_message(
|
||||
system_message: Any,
|
||||
cache_control: dict[str, str],
|
||||
) -> Any:
|
||||
"""Tag the last content block of a system message with cache_control.
|
||||
|
||||
Returns the original system_message unchanged if there are no blocks
|
||||
to tag.
|
||||
|
||||
Args:
|
||||
system_message: The system message to tag.
|
||||
cache_control: The cache control dict to apply.
|
||||
|
||||
Returns:
|
||||
A new SystemMessage with cache_control on the last block, or the
|
||||
original if no modification was needed.
|
||||
"""
|
||||
if system_message is None:
|
||||
return system_message
|
||||
|
||||
content = system_message.content
|
||||
if isinstance(content, str):
|
||||
if not content:
|
||||
return system_message
|
||||
new_content: list[str | dict[str, Any]] = [
|
||||
{"type": "text", "text": content, "cache_control": cache_control}
|
||||
]
|
||||
elif isinstance(content, list):
|
||||
if not content:
|
||||
return system_message
|
||||
new_content = list(content)
|
||||
last = new_content[-1]
|
||||
base = last if isinstance(last, dict) else {}
|
||||
new_content[-1] = {**base, "cache_control": cache_control}
|
||||
else:
|
||||
return system_message
|
||||
|
||||
return SystemMessage(content=new_content)
|
||||
|
||||
|
||||
def _tag_tools(
|
||||
tools: list[Any] | None,
|
||||
cache_control: dict[str, str],
|
||||
) -> list[Any] | None:
|
||||
"""Tag the last tool with cache_control via its extras dict.
|
||||
|
||||
Only the last tool is tagged to minimize the number of explicit cache
|
||||
breakpoints (Anthropic limits these to 4 per request). Since tool
|
||||
definitions are sent as a contiguous block, a single breakpoint on the
|
||||
last tool caches the entire set.
|
||||
|
||||
Creates a copy of the last tool with cache_control added to extras,
|
||||
without mutating the original.
|
||||
|
||||
Args:
|
||||
tools: The list of tools to tag.
|
||||
cache_control: The cache control dict to apply.
|
||||
|
||||
Returns:
|
||||
A new list with cache_control on the last tool's extras, or the
|
||||
original if no tools are present.
|
||||
"""
|
||||
if not tools:
|
||||
return tools
|
||||
|
||||
last = tools[-1]
|
||||
if not isinstance(last, BaseTool):
|
||||
return tools
|
||||
|
||||
new_extras = {**(last.extras or {}), "cache_control": cache_control}
|
||||
return [*tools[:-1], last.model_copy(update={"extras": new_extras})]
|
||||
|
||||
@@ -11,8 +11,9 @@ from langchain_core.callbacks import (
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
||||
from langchain_core.outputs import ChatGeneration, ChatResult
|
||||
from langchain_core.tools import BaseTool, tool
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from langchain_anthropic.chat_models import ChatAnthropic
|
||||
@@ -334,3 +335,240 @@ async def test_anthropic_prompt_caching_middleware_async_default_values() -> Non
|
||||
assert modified_request.model_settings == {
|
||||
"cache_control": {"type": "ephemeral", "ttl": "5m"}
|
||||
}
|
||||
|
||||
|
||||
class TestSystemMessageCaching:
    """Tests for system message cache_control tagging."""

    def _make_request(
        self,
        system_message: SystemMessage | None = None,
        **kwargs: Any,
    ) -> ModelRequest:
        # Build a minimal ModelRequest around a mocked ChatAnthropic model;
        # any field can be overridden via kwargs.
        mock_model = MagicMock(spec=ChatAnthropic)
        defaults: dict[str, Any] = {
            "model": mock_model,
            "messages": [HumanMessage("Hello")],
            "system_message": system_message,
            "tool_choice": None,
            "tools": [],
            "response_format": None,
            "state": {"messages": [HumanMessage("Hello")]},
            "runtime": cast(Runtime, object()),
            "model_settings": {},
        }
        defaults.update(kwargs)
        return ModelRequest(**defaults)

    def _run(self, request: ModelRequest) -> ModelRequest:
        # Run the middleware with a capturing handler and return the request
        # that the middleware actually forwarded to the model.
        middleware = AnthropicPromptCachingMiddleware()
        captured: ModelRequest | None = None

        def handler(req: ModelRequest) -> ModelResponse:
            nonlocal captured
            captured = req
            return ModelResponse(result=[AIMessage(content="ok")])

        middleware.wrap_model_call(request, handler)
        assert captured is not None
        return captured

    def _get_content_blocks(self, result: ModelRequest) -> list[dict[str, Any]]:
        # Extract the forwarded system message content as a list of block dicts.
        assert result.system_message is not None
        content = result.system_message.content
        assert isinstance(content, list)
        return cast("list[dict[str, Any]]", content)

    def test_tags_last_block_of_string_system_message(self) -> None:
        # A string system message is promoted to a single tagged text block.
        result = self._run(self._make_request(SystemMessage("Base prompt")))
        blocks = self._get_content_blocks(result)
        assert len(blocks) == 1
        assert blocks[0]["text"] == "Base prompt"
        assert blocks[0]["cache_control"] == {"type": "ephemeral", "ttl": "5m"}

    def test_tags_only_last_block_of_multi_block_system_message(self) -> None:
        # Only the final block receives the cache breakpoint.
        msg = SystemMessage(
            content=[
                {"type": "text", "text": "Block 1"},
                {"type": "text", "text": "Block 2"},
                {"type": "text", "text": "Block 3"},
            ]
        )
        blocks = self._get_content_blocks(self._run(self._make_request(msg)))
        assert len(blocks) == 3
        assert "cache_control" not in blocks[0]
        assert "cache_control" not in blocks[1]
        assert blocks[2]["cache_control"] == {"type": "ephemeral", "ttl": "5m"}

    def test_does_not_mutate_original_system_message(self) -> None:
        # Tagging must copy the content list, not modify the caller's blocks.
        original_content: list[str | dict[str, str]] = [
            {"type": "text", "text": "Block 1"},
            {"type": "text", "text": "Block 2"},
        ]
        msg = SystemMessage(content=original_content)
        self._run(self._make_request(msg))
        assert "cache_control" not in original_content[1]

    def test_passes_through_when_no_system_message(self) -> None:
        result = self._run(self._make_request(system_message=None))
        assert result.system_message is None

    def test_passes_through_when_system_message_has_empty_string(self) -> None:
        # Empty content has nothing to tag; the message is left untouched.
        msg = SystemMessage(content="")
        result = self._run(self._make_request(msg))
        assert result.system_message is not None
        assert result.system_message.content == ""

    def test_passes_through_when_system_message_has_empty_list(self) -> None:
        msg = SystemMessage(content=[])
        result = self._run(self._make_request(msg))
        assert result.system_message is not None
        assert result.system_message.content == []

    def test_preserves_non_text_block_types(self) -> None:
        # Non-text dict blocks keep their fields and gain only cache_control.
        msg = SystemMessage(
            content=[
                {"type": "text", "text": "Prompt"},
                {"type": "custom_type", "data": "value"},
            ]
        )
        blocks = self._get_content_blocks(self._run(self._make_request(msg)))
        assert blocks[0] == {"type": "text", "text": "Prompt"}
        assert blocks[1]["type"] == "custom_type"
        assert blocks[1]["data"] == "value"
        assert blocks[1]["cache_control"] == {"type": "ephemeral", "ttl": "5m"}

    def test_respects_custom_ttl(self) -> None:
        # A middleware configured with ttl="1h" must propagate that TTL.
        middleware = AnthropicPromptCachingMiddleware(ttl="1h")
        request = self._make_request(SystemMessage("Prompt"))
        captured: ModelRequest | None = None

        def handler(req: ModelRequest) -> ModelResponse:
            nonlocal captured
            captured = req
            return ModelResponse(result=[AIMessage(content="ok")])

        middleware.wrap_model_call(request, handler)
        assert captured is not None
        blocks = self._get_content_blocks(captured)
        assert blocks[0]["cache_control"] == {"type": "ephemeral", "ttl": "1h"}
|
||||
|
||||
|
||||
class TestToolCaching:
    """Tests for tool definition cache_control tagging."""

    def _make_request(
        self,
        tools: list[Any] | None = None,
        **kwargs: Any,
    ) -> ModelRequest:
        # Build a minimal ModelRequest around a mocked ChatAnthropic model;
        # any field can be overridden via kwargs.
        mock_model = MagicMock(spec=ChatAnthropic)
        defaults: dict[str, Any] = {
            "model": mock_model,
            "messages": [HumanMessage("Hello")],
            "system_message": None,
            "tool_choice": None,
            "tools": tools or [],
            "response_format": None,
            "state": {"messages": [HumanMessage("Hello")]},
            "runtime": cast(Runtime, object()),
            "model_settings": {},
        }
        defaults.update(kwargs)
        return ModelRequest(**defaults)

    def _run(self, request: ModelRequest) -> ModelRequest:
        # Run the middleware with a capturing handler and return the request
        # that the middleware actually forwarded to the model.
        middleware = AnthropicPromptCachingMiddleware()
        captured: ModelRequest | None = None

        def handler(req: ModelRequest) -> ModelResponse:
            nonlocal captured
            captured = req
            return ModelResponse(result=[AIMessage(content="ok")])

        middleware.wrap_model_call(request, handler)
        assert captured is not None
        return captured

    def test_tags_only_last_tool_with_cache_control(self) -> None:
        # Only the final tool carries the breakpoint; earlier tools are clean.
        @tool
        def get_weather(location: str) -> str:
            """Get weather for a location."""
            return "sunny"

        @tool
        def get_time(timezone: str) -> str:
            """Get time in a timezone."""
            return "12:00"

        result = self._run(self._make_request(tools=[get_weather, get_time]))
        assert result.tools is not None
        assert len(result.tools) == 2
        first = result.tools[0]
        assert isinstance(first, BaseTool)
        assert first.extras is None or "cache_control" not in first.extras
        last = result.tools[1]
        assert isinstance(last, BaseTool)
        assert last.extras is not None
        assert last.extras["cache_control"] == {"type": "ephemeral", "ttl": "5m"}

    def test_does_not_mutate_original_tools(self) -> None:
        # The middleware must copy the tool, leaving the original untouched.
        @tool
        def my_tool(x: str) -> str:
            """A tool."""
            return x

        original_extras = my_tool.extras
        self._run(self._make_request(tools=[my_tool]))
        assert my_tool.extras is original_extras

    def test_preserves_existing_extras(self) -> None:
        # Pre-existing extras entries survive alongside cache_control.
        @tool(extras={"defer_loading": True})
        def my_tool(x: str) -> str:
            """A tool."""
            return x

        result = self._run(self._make_request(tools=[my_tool]))
        assert result.tools is not None
        t = result.tools[0]
        assert isinstance(t, BaseTool)
        assert t.extras is not None
        assert t.extras["defer_loading"] is True
        assert t.extras["cache_control"] == {
            "type": "ephemeral",
            "ttl": "5m",
        }

    def test_passes_through_empty_tools(self) -> None:
        result = self._run(self._make_request(tools=[]))
        assert result.tools == []

    def test_passes_through_none_tools(self) -> None:
        # None is normalized to [] by _make_request's `tools or []`.
        result = self._run(self._make_request(tools=None))
        assert result.tools == []

    def test_respects_custom_ttl(self) -> None:
        # A middleware configured with ttl="1h" must propagate that TTL.
        @tool
        def my_tool(x: str) -> str:
            """A tool."""
            return x

        middleware = AnthropicPromptCachingMiddleware(ttl="1h")
        request = self._make_request(tools=[my_tool])
        captured: ModelRequest | None = None

        def handler(req: ModelRequest) -> ModelResponse:
            nonlocal captured
            captured = req
            return ModelResponse(result=[AIMessage(content="ok")])

        middleware.wrap_model_call(request, handler)
        assert captured is not None
        assert captured.tools is not None
        t = captured.tools[0]
        assert isinstance(t, BaseTool)
        assert t.extras is not None
        assert t.extras["cache_control"] == {
            "type": "ephemeral",
            "ttl": "1h",
        }
|
||||
|
||||
Reference in New Issue
Block a user