fix(anthropic): tag multiple system message blocks for better cache hits

When the system message has multiple content blocks (e.g. a static base
prompt followed by dynamic memory content from middleware), tag both the
second-to-last and last blocks with cache_control instead of only the
last block. This creates two breakpoints so that when the last block
changes (e.g. memory differs across conversations), earlier blocks still
get cache hits.

Single-block system messages are unchanged (tags that one block).

Also fixes class and function docstrings to accurately describe behavior
(tags last tool, not all tools).
This commit is contained in:
Kevin Frank
2026-04-10 16:14:22 -05:00
parent 9f232caa7a
commit 65309fc57f

View File

@@ -41,10 +41,11 @@ class AnthropicPromptCachingMiddleware(AgentMiddleware):
Applies cache control breakpoints to:
- **System message**: Tags the last content block of the system message
with `cache_control` so static system prompt content is cached.
- **Tools**: Tags all tool definitions with `cache_control` so tool
schemas are cached across turns.
- **System message**: Tags system message content blocks with
`cache_control`. When multiple blocks exist, tags both the
second-to-last and last blocks for finer-grained cache hits.
- **Tools**: Tags the last tool definition with `cache_control` so
tool schemas are cached across turns.
- **Last cacheable block**: Tags last cacheable block of message sequence using
Anthropic's automatic caching feature.
@@ -186,7 +187,14 @@ def _tag_system_message(
system_message: Any,
cache_control: dict[str, str],
) -> Any:
"""Tag the last content block of a system message with cache_control.
"""Tag system message content blocks with cache_control.
When the system message has multiple content blocks (e.g. a static
base prompt and dynamic memory), tags both the second-to-last and
last blocks. This creates two cache breakpoints so that when the
last block changes, the earlier blocks still get cache hits.
When there is only one block, tags that block.
Returns the original system_message unchanged if there are no blocks
to tag.
@@ -196,8 +204,8 @@ def _tag_system_message(
cache_control: The cache control dict to apply.
Returns:
A new SystemMessage with cache_control on the last block, or the
original if no modification was needed.
A new SystemMessage with cache_control applied, or the original
if no modification was needed.
"""
if system_message is None:
return system_message
@@ -213,6 +221,10 @@ def _tag_system_message(
if not content:
return system_message
new_content = list(content)
if len(new_content) >= 2:
second_last = new_content[-2]
base2 = second_last if isinstance(second_last, dict) else {}
new_content[-2] = {**base2, "cache_control": cache_control}
last = new_content[-1]
base = last if isinstance(last, dict) else {}
new_content[-1] = {**base, "cache_control": cache_control}